Browse Source

partially supported fromWire() rdata (A/AAA/NS only).
now allowed parsing a full DNS message.


git-svn-id: svn://bind10.isc.org/svn/bind10/branches/f2f200910@202 e5f2f494-b856-4b98-b285-d166d9295462

JINMEI Tatuya 15 years ago
parent
commit
5e96173c5a

+ 35 - 0
src/lib/dns/buffer.h

@@ -49,8 +49,11 @@ public:
     virtual size_t getSize() const = 0;
     virtual size_t getSpace() const = 0;
     virtual size_t getCurrent() const = 0;
+    virtual void setCurrent(size_t pos) = 0;
     virtual uint8_t readUint8() = 0;
     virtual uint16_t readUint16() = 0;
+    virtual uint32_t readUint32() = 0;
+    virtual void readData(void* data, size_t len) = 0;
     virtual int recvFrom(int s, struct sockaddr *from,
                          socklen_t *from_len) = 0;
 };
@@ -99,8 +102,16 @@ public:
     size_t getSize() const { return (buf_.size()); }
     size_t getSpace() const { return (buf_.size() - _readpos); }
     size_t getCurrent() const { return (_readpos); }
+    void setCurrent(size_t pos)
+    {
+        if (pos >= buf_.size())
+            throw isc::ISCBufferInvalidPosition();
+        _readpos = pos;
+    }
     uint8_t readUint8();
     uint16_t readUint16();
+    uint32_t readUint32();
+    void readData(void* data, size_t len);
     int recvFrom(int s, struct sockaddr* from, socklen_t* from_len);
 
 private:
@@ -130,6 +141,30 @@ SingleBuffer::readUint16()
 
     return (ntohs(data));
 }
+
+inline uint32_t
+SingleBuffer::readUint32()
+{
+    uint32_t data;
+
+    if (_readpos + sizeof(data) > buf_.size())
+        throw ISCBufferInvalidPosition();
+
+    memcpy((void*)&data, &buf_[_readpos], sizeof(data));
+    _readpos += sizeof(data);
+
+    return (ntohl(data));
+}
+
+inline void
+SingleBuffer::readData(void *data, size_t len)
+{
+    if (_readpos + len > buf_.size())
+        throw ISCBufferInvalidPosition();
+
+    memcpy(data, &buf_[_readpos], len);
+    _readpos += len;
+}
 }
 #endif  // __BUFFER_HH
 

+ 3 - 0
src/lib/dns/exceptions.h

@@ -47,6 +47,9 @@ class DNSInvalidMessageSection : public DNSException {};
 class DNSInvalidRendererPosition : public DNSException {};
 class DNSMessageTooShort : public DNSException {};
 class DNSCharStringTooLong : public DNSException {};
+class DNSNameDecompressionProhibited : public DNSException {};
+class DNSNameBadPointer : public DNSException {};
+class DNSInvalidRdata : public DNSException {};
 }
 }
 #endif  // __EXCEPTIONS_HH

+ 32 - 3
src/lib/dns/message.cc

@@ -24,16 +24,19 @@
 #include <boost/lexical_cast.hpp>
 
 #include <dns/buffer.h>
+#include <dns/name.h>
 #include <dns/rrset.h>
 #include <dns/message.h>
 
 using isc::dns::Name;
-using isc::dns::Message;
 using isc::dns::RRType;
 using isc::dns::RRClass;
+using isc::dns::TTL;
+using isc::dns::Message;
+using isc::dns::Rdata::Rdata;
+using isc::dns::Rdata::RdataPtr;
 using isc::dns::RRsetPtr;
 using isc::dns::RR;
-using isc::dns::TTL;
 
 Message::Message()
 {
@@ -161,7 +164,9 @@ Message::fromWire()
     counts_[SECTION_ADDITIONAL] = buffer_->readUint16();
 
     parse_question();
-    // parse other sections (TBD)
+    for (int section = SECTION_ANSWER; section < SECTION_MAX; ++section) {
+        parse_section(static_cast<section_t>(section)); // XXX cast
+    }
 }
 
 void
@@ -190,6 +195,30 @@ Message::parse_question()
     }
 }
 
+void
+Message::parse_section(section_t section)
+{
+    if (buffer_ == NULL)
+        throw DNSNoMessageBuffer();
+
+    for (int count = 0; count < this->counts_[section]; count++) {
+        Name name(*buffer_, getDecompressor());
+
+        // Get type, class, TTL
+        if (buffer_->getSpace() < 2 * sizeof(uint16_t) + sizeof(uint32_t))
+            throw DNSMessageTooShort();
+
+        RRType rrtype(buffer_->readUint16());
+        RRClass rrclass(buffer_->readUint16());
+        TTL ttl(buffer_->readUint32());
+        addRR(section, RR(name, rrclass, rrtype, ttl,
+                          RdataPtr(isc::dns::Rdata::Rdata::fromWire(rrclass,
+                                                                    rrtype,
+                                                                    *buffer_,
+                                                                    getDecompressor()))));
+    }
+}
+
 static const char *opcodetext[] = {
     "QUERY",
     "IQUERY",

+ 1 - 0
src/lib/dns/message.h

@@ -162,6 +162,7 @@ public:
 private:
     void initialize();
     void parse_question();
+    void parse_section(section_t section);
 
 private:
     // Open issues: should we rather have a header in wire-format

+ 29 - 2
src/lib/dns/name.cc

@@ -269,7 +269,8 @@ Name::Name(const std::string& namestr)
 Name::Name(Buffer& buffer, NameDecompressor& decompressor)
 {
     unsigned int nused, labels, n, nmax;
-    unsigned int current;
+    unsigned int cused; /* Bytes of compressed name data used */
+    unsigned int current, new_current, biggest_pointer, pos_begin;
     bool done;
     fw_state state = fw_start;
     unsigned int c;
@@ -287,6 +288,7 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
     labels = 0;
     done = false;
     nused = 0;
+    seen_pointer = false;
 
     /*
      * Find the maximum number of uncompressed target name
@@ -296,7 +298,10 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
      */
     nmax = MAXWIRE;
 
+    cused = 0;
     current = buffer.getCurrent();
+    pos_begin = current;
+    biggest_pointer = current;
 
     /*
      * Note:  The following code is not optimized for speed, but
@@ -305,6 +310,8 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
     while (current < buffer.getSize() && !done) {
         c = buffer.readUint8();
         current++;
+        if (!seen_pointer)
+            cused++;
 
         switch (state) {
         case fw_start:
@@ -333,7 +340,11 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
                 /*
                  * Ordinary 14-bit pointer.
                  */
-                throw DNSBadLabelType(); // XXX not implemented
+                if (!decompressor.isAllowed())
+                    throw DNSNameDecompressionProhibited();
+                new_current = c & 0x3F;
+                n = 1;
+                state = fw_newcurrent;
             } else
                 throw DNSBadLabelType();
             break;
@@ -348,6 +359,21 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
                 state = fw_start;
             break;
         case fw_newcurrent:
+            new_current *= 256;
+            new_current += c;
+            n--;
+            if (n != 0)
+                break;
+            if (new_current >= biggest_pointer)
+                throw DNSNameBadPointer();
+            biggest_pointer = new_current;
+            current = new_current;
+            buffer.setCurrent(current);
+            seen_pointer = true;
+            state = fw_start;
+            break;
+
+
             // XXX not implemented, fall through
         default:
             throw ISCUnexpected();
@@ -359,6 +385,7 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
 
     labels_ = labels;
     length_ = nused;
+    buffer.setCurrent(pos_begin + cused);
 }
 
 string

+ 6 - 2
src/lib/dns/name.h

@@ -24,9 +24,13 @@
 
 namespace isc {
 namespace dns {
-// Define them as an empty class for rapid prototyping
+// Define it as an empty class for rapid prototyping
 class NameCompressor {};
-class NameDecompressor {};
+// Define it as an almost-empty class for rapid prototyping
+class NameDecompressor {
+public:
+    bool isAllowed() { return (true); }
+};
 
 class NameComparisonResult {
 public:

+ 73 - 47
src/lib/dns/rrset.cc

@@ -31,6 +31,8 @@
 using std::pair;
 using std::map;
 
+using isc::Buffer;
+using isc::dns::NameDecompressor;
 using isc::dns::RRClass;
 using isc::dns::RRType;
 using isc::dns::TTL;
@@ -127,9 +129,12 @@ TTL::toWire(Buffer& buffer) const
     buffer.writeUint32(ttlval_);
 }
 
-typedef Rdata* (*RdataFactory)(const std::string& text_rdata);
+typedef Rdata* (*TextRdataFactory)(const std::string& text_rdata);
+typedef Rdata* (*WireRdataFactory)(Buffer& buffer,
+                                   NameDecompressor& decompressor);
 typedef pair<RRClass, RRType> RRClassTypePair;
-static map<RRClassTypePair, RdataFactory> rdata_factory_repository;
+static map<RRClassTypePair, TextRdataFactory> text_rdata_factory_repository;
+static map<RRClassTypePair, WireRdataFactory> wire_rdata_factory_repository;
 
 struct RdataFactoryRegister {
 public:
@@ -140,73 +145,91 @@ private:
 
 static RdataFactoryRegister rdata_factory;
 
-Rdata *
-createADataFromText(const std::string& text_rdata)
-{
-    return (new A(text_rdata));
-}
-
-Rdata *
-createAAAADataFromText(const std::string& text_rdata)
-{
-    return (new AAAA(text_rdata));
-}
-
-Rdata *
-createNSDataFromText(const std::string& text_rdata)
+template <typename T>
+Rdata*
+createDataFromText(const std::string& text_rdata)
 {
-    return (new NS(text_rdata));
+    return (new T(text_rdata));
 }
 
-Rdata *
-createTXTDataFromText(const std::string& text_rdata)
+template <typename T>
+Rdata*
+createDataFromWire(Buffer& buffer, NameDecompressor& decompressor)
 {
-    return (new TXT(text_rdata));
+    return (new T(buffer, decompressor));
 }
 
 RdataFactoryRegister::RdataFactoryRegister()
 {
-    rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
-                                    (RRClassTypePair(RRClass::IN, RRType::A),
-                                     createADataFromText));
-    rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
-                                    (RRClassTypePair(RRClass::IN, RRType::AAAA),
-                                     createAAAADataFromText));
+    text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::A),
+                              createDataFromText<isc::dns::Rdata::IN::A>));
+    text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::AAAA),
+                              createDataFromText<isc::dns::Rdata::IN::AAAA>));
     //XXX: NS/TXT belongs to the 'generic' class.  should revisit it.
-    rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
-                                    (RRClassTypePair(RRClass::IN, RRType::NS),
-                                     createNSDataFromText));
-    rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
-                                    (RRClassTypePair(RRClass::IN, RRType::TXT),
-                                     createTXTDataFromText));
+    text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::NS),
+                              createDataFromText<isc::dns::Rdata::Generic::NS>));
+    text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::TXT),
+                              createDataFromText<isc::dns::Rdata::Generic::TXT>));
     // XXX: we should treat class-agnostic type accordingly.
-    rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
-                                    (RRClassTypePair(RRClass::CH, RRType::TXT),
-                                     createTXTDataFromText));}
+    text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
+                             (RRClassTypePair(RRClass::CH, RRType::TXT),
+                              createDataFromText<isc::dns::Rdata::Generic::TXT>));
+
+    wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::A),
+                              createDataFromWire<isc::dns::Rdata::IN::A>));
+    wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::AAAA),
+                              createDataFromWire<isc::dns::Rdata::IN::AAAA>));
+    wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
+                             (RRClassTypePair(RRClass::IN, RRType::NS),
+                              createDataFromWire<isc::dns::Rdata::Generic::NS>));
+}
 
 Rdata *
 Rdata::fromText(const RRClass& rrclass, const RRType& rrtype,
                 const std::string& text_rdata)
 {
-    map<RRClassTypePair, RdataFactory>::const_iterator entry;
-    entry = rdata_factory_repository.find(RRClassTypePair(rrclass, rrtype));
-    if (entry != rdata_factory_repository.end()) {
+    map<RRClassTypePair, TextRdataFactory>::const_iterator entry;
+    entry = text_rdata_factory_repository.find(RRClassTypePair(rrclass,
+                                                               rrtype));
+    if (entry != text_rdata_factory_repository.end()) {
         return (entry->second(text_rdata));
     }
 
     throw DNSInvalidRRType();
 }
 
+Rdata *
+Rdata::fromWire(const RRClass& rrclass, const RRType& rrtype,
+                Buffer& buffer, NameDecompressor& decompressor)
+{
+    map<RRClassTypePair, WireRdataFactory>::const_iterator entry;
+    entry = wire_rdata_factory_repository.find(RRClassTypePair(rrclass,
+                                                               rrtype));
+    if (entry != wire_rdata_factory_repository.end()) {
+        return (entry->second(buffer, decompressor));
+    }
+
+    throw DNSInvalidRRType();
+}
+
 A::A(const std::string& addrstr)
 {
     if (inet_pton(AF_INET, addrstr.c_str(), &addr_) != 1)
         throw ISCInvalidAddressString();
 }
 
-void
-A::fromWire(Buffer& buffer, NameDecompressor& decompressor)
+A::A(Buffer& buffer, NameDecompressor& decompressor)
 {
-    //TBD
+    size_t len = buffer.readUint16();
+    if (len != sizeof(addr_))
+        throw DNSInvalidRdata();
+    buffer.readData(&addr_, sizeof(addr_));
 }
 
 void
@@ -239,10 +262,11 @@ AAAA::AAAA(const std::string& addrstr)
         throw ISCInvalidAddressString();
 }
 
-void
-AAAA::fromWire(Buffer& buffer, NameDecompressor& decompressor)
+AAAA::AAAA(Buffer& buffer, NameDecompressor& decompressor)
 {
-    //TBD
+    if (buffer.readUint16() != sizeof(addr_))
+        throw DNSInvalidRdata();
+    buffer.readData(&addr_, sizeof(addr_));
 }
 
 void
@@ -269,10 +293,12 @@ AAAA::copy() const
     return (new AAAA(toText()));
 }
 
-void
-NS::fromWire(Buffer& buffer, NameDecompressor& decompressor)
+NS::NS(Buffer& buffer, NameDecompressor& decompressor)
 {
-    //TBD
+    size_t len = buffer.readUint16();
+    nsname_ = Name(buffer, decompressor);
+    if (nsname_.getLength() < len)
+        throw DNSInvalidRdata();
 }
 
 void

+ 6 - 5
src/lib/dns/rrset.h

@@ -131,14 +131,15 @@ public:
     virtual unsigned int count() const = 0;
     virtual const RRType& getType() const = 0;
     virtual std::string toText() const = 0;
-    virtual void fromWire(Buffer& b, NameDecompressor& c) = 0;
     virtual void toWire(Buffer& b, NameCompressor& c) const = 0;
     // need generic method for getting n-th field? c.f. ldns
     // e.g. string getField(int n);
 
-    // A semi polymorphic factory.
+    // semi-polymorphic factories.
     static Rdata* fromText(const RRClass& rrclass, const RRType& rrtype,
                            const std::string& text_rdata);
+    static Rdata* fromWire(const RRClass& rrclass, const RRType& rrtype,
+                           Buffer& b, NameDecompressor& d);
 
     // polymorphic copy constructor (XXX should revisit it)
     virtual Rdata* copy() const = 0;
@@ -150,11 +151,11 @@ public:
     NS() {}
     explicit NS(const std::string& namestr) : nsname_(namestr) {}
     explicit NS(const Name& nsname) : nsname_(nsname) {}
+    explicit NS(Buffer& buffer, NameDecompressor& decompressor);
     unsigned int count() const { return (1); }
     const RRType& getType() const { return (RRType::NS); }
     static const RRType& getTypeStatic() { return (RRType::NS); }
     std::string toText() const;
-    void fromWire(Buffer& b, NameDecompressor& c);
     void toWire(Buffer& b, NameCompressor& c) const;
     const std::string getNsname() const { return (nsname_.toText(false)); }
     bool operator==(const NS &other) const
@@ -194,11 +195,11 @@ public:
     A() {}
     // constructor from a textual IPv4 address
     explicit A(const std::string& addrstr);
+    explicit A(Buffer& buffer, NameDecompressor& decompressor);
     unsigned int count() const { return (1); }
     const RRType& getType() const { return (RRType::A); }
     static const RRType& getTypeStatic() { return (RRType::A); }
     std::string toText() const;
-    void fromWire(Buffer& b, NameDecompressor& c);
     void toWire(Buffer& b, NameCompressor& c) const;
     const struct in_addr& getAddress() const { return (addr_); }
     bool operator==(const A &other) const
@@ -216,11 +217,11 @@ public:
     AAAA() {}
     // constructor from a textual IPv6 address
     explicit AAAA(const std::string& addrstr);
+    explicit AAAA(Buffer& buffer, NameDecompressor& decompressor);
     unsigned int count() const { return (1); }
     std::string toText() const;
     const RRType& getType() const { return (RRType::AAAA); }
     static const RRType& getTypeStatic() { return (RRType::AAAA); }
-    void fromWire(Buffer& b, NameDecompressor& c);
     void toWire(Buffer& b, NameCompressor& c) const;
     const struct in6_addr& getAddress() const { return (addr_); }
     bool operator==(const AAAA &other) const