Parcourir la source

tweaked the code to handle the EDNS BADVERS case.
style guideline conformance about the position of opening curly brace.


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1357 e5f2f494-b856-4b98-b285-d166d9295462

JINMEI Tatuya il y a 15 ans
Parent
commit
32bf6c47a7
2 fichiers modifiés avec 79 ajouts et 109 suppressions
  1. 78 109
      src/lib/dns/message.cc
  2. 1 0
      src/lib/dns/message.h

+ 78 - 109
src/lib/dns/message.cc

@@ -162,21 +162,18 @@ static const char *sectiontext[] = {
 }
 
 string
-Opcode::toText() const
-{
+Opcode::toText() const {
     return (opcodetext[code_]);
 }
 
-Rcode::Rcode(uint16_t code) : code_(code)
-{
+Rcode::Rcode(uint16_t code) : code_(code) {
     if (code_ > MAX_RCODE) {
         isc_throw(OutOfRange, "Rcode is too large to construct");
     }
 }
 
 string
-Rcode::toText() const
-{
+Rcode::toText() const {
     if (code_ < sizeof(rcodetext) / sizeof (const char *)) {
         return (rcodetext[code_]);
     }
@@ -188,8 +185,7 @@ Rcode::toText() const
 
 namespace {
 inline unsigned int
-sectionCodeToId(const Section& section)
-{
+sectionCodeToId(const Section& section) {
     unsigned int code = section.getCode();
     assert(code > 0);
     return (section.getCode() - 1);
@@ -235,8 +231,7 @@ MessageImpl::MessageImpl(Message::Mode mode) :
 }
 
 void
-MessageImpl::init()
-{
+MessageImpl::init() {
     flags_ = 0;
     qid_ = 0;
     rcode_ = Rcode::NOERROR();  // XXX
@@ -260,41 +255,34 @@ MessageImpl::init()
 
 Message::Message(Mode mode) :
     impl_(new MessageImpl(mode))
-{
-}
+{}
 
-Message::~Message()
-{
+Message::~Message() {
     delete impl_;
 }
 
 bool
-Message::getHeaderFlag(const MessageFlag& flag) const
-{
+Message::getHeaderFlag(const MessageFlag& flag) const {
     return ((impl_->flags_ & flag.getBit()) != 0);
 }
 
 void
-Message::setHeaderFlag(const MessageFlag& flag)
-{
+Message::setHeaderFlag(const MessageFlag& flag) {
     impl_->flags_ |= flag.getBit();
 }
 
 void
-Message::clearHeaderFlag(const MessageFlag& flag)
-{
+Message::clearHeaderFlag(const MessageFlag& flag) {
     impl_->flags_ &= ~flag.getBit();
 }
 
 bool
-Message::isDNSSECSupported() const
-{
+Message::isDNSSECSupported() const {
     return (impl_->dnssec_ok_);
 }
 
 void
-Message::setDNSSECSupported(bool on)
-{
+Message::setDNSSECSupported(bool on) {
     if (impl_->mode_ != Message::RENDER) {
         isc_throw(InvalidMessageOperation,
                   "setDNSSECSupported performed in non-render mode");
@@ -303,14 +291,12 @@ Message::setDNSSECSupported(bool on)
 }
 
 uint16_t
-Message::getUDPSize() const
-{
+Message::getUDPSize() const {
     return (impl_->udpsize_);
 }
 
 void
-Message::setUDPSize(uint16_t size)
-{
+Message::setUDPSize(uint16_t size) {
     if (impl_->mode_ != Message::RENDER) {
         isc_throw(InvalidMessageOperation,
                   "setUDPSize performed in non-render mode");
@@ -323,50 +309,42 @@ Message::setUDPSize(uint16_t size)
 }
 
 qid_t
-Message::getQid() const
-{
+Message::getQid() const {
     return (impl_->qid_);
 }
 
 void
-Message::setQid(qid_t qid)
-{
+Message::setQid(qid_t qid) {
     impl_->qid_ = qid;
 }
 
 const Rcode&
-Message::getRcode() const
-{
+Message::getRcode() const {
     return (impl_->rcode_);
 }
 
 void
-Message::setRcode(const Rcode& rcode)
-{
+Message::setRcode(const Rcode& rcode) {
     impl_->rcode_ = rcode;
 }
 
 const Opcode&
-Message::getOpcode() const
-{
+Message::getOpcode() const {
     return (*impl_->opcode_);
 }
 
 void
-Message::setOpcode(const Opcode& opcode)
-{
+Message::setOpcode(const Opcode& opcode) {
     impl_->opcode_ = &opcode;
 }
 
 unsigned int
-Message::getRRCount(const Section& section) const
-{
+Message::getRRCount(const Section& section) const {
     return (impl_->counts_[section.getCode()]);
 }
 
 void
-Message::addRRset(const Section& section, RRsetPtr rrset, bool sign)
-{
+Message::addRRset(const Section& section, RRsetPtr rrset, bool sign) {
     // Note: should check duplicate (TBD)
     impl_->rrsets_[sectionCodeToId(section)].push_back(rrset);
     impl_->counts_[section.getCode()] += rrset->getRdataCount();
@@ -379,22 +357,19 @@ Message::addRRset(const Section& section, RRsetPtr rrset, bool sign)
 }
 
 void
-Message::addQuestion(const QuestionPtr question)
-{
+Message::addQuestion(const QuestionPtr question) {
     impl_->questions_.push_back(question);
     impl_->counts_[Section::QUESTION().getCode()]++;
 }
 
 void
-Message::addQuestion(const Question& question)
-{
+Message::addQuestion(const Question& question) {
     addQuestion(QuestionPtr(new Question(question)));
 }
 
 namespace {
 template <typename T>
-struct RenderSection
-{
+struct RenderSection {
     RenderSection(MessageRenderer& renderer, const bool partial_ok) :
         counter_(0), renderer_(renderer), partial_ok_(partial_ok),
         truncated_(false)
@@ -425,23 +400,26 @@ struct RenderSection
 
 namespace {
 bool
-addEDNS(MessageImpl* mimpl, MessageRenderer& renderer)
-{
-    bool is_query = ((mimpl->flags_ & MessageFlag::QR().getBit()) == 0); 
-
-    // If this is a reply and the request didn't have EDNS, we shouldn't add it.
-    if (mimpl->remote_edns_ == NULL && !is_query) {
-        return (false);
-    }
-
-    // For queries, we add EDNS only when necessary:
-    // Local UDP size is not the default value, or
-    // DNSSEC DO bit is to be set, or
-    // Extended Rcode is to be specified.
-    if (is_query && mimpl->udpsize_ == Message::DEFAULT_MAX_UDPSIZE &&
-        !mimpl->dnssec_ok_ &&
-        mimpl->rcode_.getCode() < 0x10) {
-        return (false);
+addEDNS(MessageImpl* mimpl, MessageRenderer& renderer) {
+    const bool is_query = ((mimpl->flags_ & MessageFlag::QR().getBit()) == 0); 
+
+    // If this is a reply, add EDNS either when the request had it, or
+    // if the Rcode is BADVERS, which is EDNS specific.
+    // XXX: this logic is tricky.  We should revisit this later.
+    if (!is_query) {
+        if (mimpl->remote_edns_ == NULL && mimpl->rcode_ != Rcode::BADVERS()) {
+            return (false);
+        }
+    } else {
+        // For queries, we add EDNS only when necessary:
+        // Local UDP size is not the default value, or
+        // DNSSEC DO bit is to be set, or
+        // Extended Rcode is to be specified.
+        if (mimpl->udpsize_ == Message::DEFAULT_MAX_UDPSIZE &&
+            !mimpl->dnssec_ok_ &&
+            mimpl->rcode_.getCode() < 0x10) {
+            return (false);
+        }
     }
 
     // If adding the OPT RR would exceed the size limit, don't do it.
@@ -469,8 +447,7 @@ addEDNS(MessageImpl* mimpl, MessageRenderer& renderer)
 }
 
 void
-Message::toWire(MessageRenderer& renderer)
-{
+Message::toWire(MessageRenderer& renderer) {
     uint16_t codes_and_flags;
 
     // reserve room for the header
@@ -510,6 +487,16 @@ Message::toWire(MessageRenderer& renderer)
         ++arcount;
     }
 
+    // Adjust the counter buffer.
+    // XXX: these may not be equal to the number of corresponding entries
+    // in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
+    // was inserted.  This is not good, and we should revisit the entire
+    // design.
+    impl_->counts_[Section::QUESTION().getCode()] = qdcount;
+    impl_->counts_[Section::ANSWER().getCode()] = ancount;
+    impl_->counts_[Section::AUTHORITY().getCode()] = nscount;
+    impl_->counts_[Section::ADDITIONAL().getCode()] = arcount;
+
     // TBD: TSIG, SIG(0) etc.
 
     // fill in the header
@@ -662,7 +649,10 @@ MessageImpl::parseSection(const Section& section, InputBuffer& buffer) {
                 // This is probably because why BIND 9 does the version check
                 // in the client code.
                 // This is a TODO item.  Right now we simply reject it.
-                isc_throw(DNSMessageBADVERS, "unsupported EDNS version");
+                const unsigned int ver =
+                    (ttl.getValue() & EDNSVERSION_MASK) >> 16;
+                isc_throw(DNSMessageBADVERS, "unsupported EDNS version: " <<
+                          ver);
             }
             if (name != Name::ROOT_NAME()) {
                 isc_throw(DNSMessageFORMERR,
@@ -701,8 +691,7 @@ MessageImpl::parseSection(const Section& section, InputBuffer& buffer) {
 
 namespace {
 template <typename T>
-struct SectionFormatter
-{
+struct SectionFormatter {
     SectionFormatter(const Section& section, string& output) :
         section_(section), output_(output) {}
     void operator()(const T& entry)
@@ -718,8 +707,7 @@ struct SectionFormatter
 }
 
 string
-Message::toText() const
-{
+Message::toText() const {
     string s;
 
     s += ";; ->>HEADER<<- opcode: " + impl_->opcode_->toText();
@@ -755,9 +743,6 @@ Message::toText() const
     if (!getHeaderFlag(MessageFlag::QR()) && impl_->remote_edns_ != NULL) {
         edns_rrset = impl_->remote_edns_;
         ++arcount;
-    } else if (getHeaderFlag(MessageFlag::QR()) && impl_->local_edns_ != NULL) {
-        edns_rrset = impl_->local_edns_;
-        ++arcount;
     }
     s += ", ADDITIONAL: " + lexical_cast<string>(arcount) + "\n";
 
@@ -813,8 +798,7 @@ Message::toText() const
 }
 
 void
-Message::clear(Mode mode)
-{
+Message::clear(Mode mode) {
     impl_->init();
     impl_->mode_ = mode;
 }
@@ -856,8 +840,7 @@ struct SectionIteratorImpl {
 };
 
 template <typename T>
-SectionIterator<T>::SectionIterator(const SectionIteratorImpl<T>& impl)
-{
+SectionIterator<T>::SectionIterator(const SectionIteratorImpl<T>& impl) {
     impl_ = new SectionIteratorImpl<T>(impl.it_);
 }
 
@@ -874,8 +857,7 @@ SectionIterator<T>::SectionIterator(const SectionIterator<T>& source) :
 
 template <typename T>
 void
-SectionIterator<T>::operator=(const SectionIterator<T>& source)
-{
+SectionIterator<T>::operator=(const SectionIterator<T>& source) {
     if (impl_ == source.impl_) {
         return;
     }
@@ -887,16 +869,14 @@ SectionIterator<T>::operator=(const SectionIterator<T>& source)
 
 template <typename T>
 SectionIterator<T>&
-SectionIterator<T>::operator++()
-{
+SectionIterator<T>::operator++() {
     ++(impl_->it_);
     return (*this);
 }
 
 template <typename T>
 SectionIterator<T>
-SectionIterator<T>::operator++(int)
-{
+SectionIterator<T>::operator++(int) {
     SectionIterator<T> tmp(*this);
     ++(*this);
     return (tmp);
@@ -904,29 +884,25 @@ SectionIterator<T>::operator++(int)
 
 template <typename T>
 const T&
-SectionIterator<T>::operator*() const
-{
+SectionIterator<T>::operator*() const {
     return (*(impl_->it_));
 }
 
 template <typename T>
 const T*
-SectionIterator<T>::operator->() const
-{
+SectionIterator<T>::operator->() const {
     return (impl_->it_.operator->());
 }
 
 template <typename T>
 bool
-SectionIterator<T>::operator==(const SectionIterator<T>& other) const
-{
+SectionIterator<T>::operator==(const SectionIterator<T>& other) const {
     return (impl_->it_ == other.impl_->it_);
 }
 
 template <typename T>
 bool
-SectionIterator<T>::operator!=(const SectionIterator<T>& other) const
-{
+SectionIterator<T>::operator!=(const SectionIterator<T>& other) const {
     return (impl_->it_ != other.impl_->it_);
 }
 
@@ -946,14 +922,12 @@ typedef SectionIteratorImpl<RRsetPtr> RRsetIteratorImpl;
 /// Question iterator
 ///
 const QuestionIterator
-Message::beginQuestion() const
-{
+Message::beginQuestion() const {
     return (QuestionIterator(QuestionIteratorImpl(impl_->questions_.begin())));
 }
 
 const QuestionIterator
-Message::endQuestion() const
-{
+Message::endQuestion() const {
     return (QuestionIterator(QuestionIteratorImpl(impl_->questions_.end())));
 }
 
@@ -961,8 +935,7 @@ Message::endQuestion() const
 /// RRsets iterators
 ///
 const SectionIterator<RRsetPtr>
-Message::beginSection(const Section& section) const
-{
+Message::beginSection(const Section& section) const {
     if (section == Section::QUESTION()) {
         isc_throw(InvalidMessageSection,
                   "RRset iterator is requested for question");
@@ -974,8 +947,7 @@ Message::beginSection(const Section& section) const
 }
 
 const SectionIterator<RRsetPtr>
-Message::endSection(const Section& section) const
-{
+Message::endSection(const Section& section) const {
     if (section == Section::QUESTION()) {
         isc_throw(InvalidMessageSection,
                   "RRset iterator is requested for question");
@@ -987,20 +959,17 @@ Message::endSection(const Section& section) const
 }
 
 ostream&
-operator<<(ostream& os, const Opcode& opcode)
-{
+operator<<(ostream& os, const Opcode& opcode) {
     return (os << opcode.toText());
 }
 
 ostream&
-operator<<(ostream& os, const Rcode& rcode)
-{
+operator<<(ostream& os, const Rcode& rcode) {
     return (os << rcode.toText());
 }
 
 ostream&
-operator<<(ostream& os, const Message& message)
-{
+operator<<(ostream& os, const Message& message) {
     return (os << message.toText());
 }
 

+ 1 - 0
src/lib/dns/message.h

@@ -276,6 +276,7 @@ public:
     Rcode(uint16_t code);
     uint16_t getCode() const { return (code_); }
     bool operator==(const Rcode& other) const { return (code_ == other.code_); }
+    bool operator!=(const Rcode& other) const { return (code_ != other.code_); }
     std::string toText() const;
     static const Rcode& NOERROR();
     static const Rcode& FORMERR();