Browse Source

added preliminary level truncation support

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1199 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
85dc228ffd

+ 1 - 0
src/bin/auth/auth_srv.cc

@@ -124,6 +124,7 @@ AuthSrv::processMessage(const int fd)
 
         OutputBuffer obuffer(remote_bufsize);
         MessageRenderer renderer(obuffer);
+        renderer.setLengthLimit(remote_bufsize);
         msg.toWire(renderer);
         cout << "sending a response (" <<
             boost::lexical_cast<string>(obuffer.getLength())

+ 9 - 1
src/lib/dns/buffer.h

@@ -306,7 +306,7 @@ public:
     /// exception class of \c InvalidBufferPosition will be thrown.
     ///
     /// \param pos The position in the buffer to be returned.
-    uint8_t operator[](size_t pos) const
+    const uint8_t& operator[](size_t pos) const
     {
         if (pos >= data_.size()) {
             isc_throw(InvalidBufferPosition, "read at invalid position");
@@ -326,6 +326,14 @@ public:
     /// that is to be filled in later, e.g, by \ref writeUint16At().
     /// \param len The length of the gap to be inserted in bytes.
     void skip(size_t len) { data_.insert(data_.end(), len, 0); }
+    /// \brief TBD
+    void trim(size_t len)
+    {
+        if (len > data_.size()) {
+            isc_throw(OutOfRange, "trimming too large from output buffer");
+        }
+        data_.resize(data_.size() - len);
+    }
     /// \brief Clear buffer content.
     ///
     /// This method can be used to re-initialize and reuse the buffer without

+ 57 - 19
src/lib/dns/message.cc

@@ -386,17 +386,31 @@ namespace {
 template <typename T>
 struct RenderSection
 {
-    RenderSection(MessageRenderer& renderer) :
-        counter_(0), renderer_(renderer) {}
+    RenderSection(MessageRenderer& renderer, const bool partial_ok) :
+        counter_(0), renderer_(renderer), partial_ok_(partial_ok_),
+        truncated_(false)
+    {}
     void operator()(const T& entry)
     {
-        // TBD: if truncation is necessary, do something special.
-        // throw an exception, set an internal flag, etc.
+        // If it's already truncated, ignore the rest of the section.
+        if (truncated_) {
+            return;
+        }
+        size_t pos0 = renderer_.getLength();
         counter_ += entry->toWire(renderer_);
+        if (renderer_.isTruncated()) {
+            truncated_ = true;
+            if (!partial_ok_) {
+                // roll back to the end of the previous RRset.
+                renderer_.trim(renderer_.getLength() - pos0);
+            }
+        }
     }
     unsigned int getTotalCount() { return (counter_); }
     unsigned int counter_;
     MessageRenderer& renderer_;
+    const bool partial_ok_;
+    bool truncated_;
 };
 }
 
@@ -421,6 +435,13 @@ addEDNS(MessageImpl* mimpl, MessageRenderer& renderer)
         return (false);
     }
 
+    // If adding the OPT RR would exceed the size limit, don't do it.
+    // 11 = len(".") + type(2byte) + class(2byte) + TTL(4byte) + RDLEN(2byte)
+    // (RDATA is empty in this simple implementation)
+    if (renderer.getLength() + 11 > renderer.getLengthLimit()) {
+        return (false);
+    }
+
     // Render EDNS OPT RR
     uint32_t extrcode_flags = ((mimpl->rcode_.getCode() & 0xff0) << 24);
     if (mimpl->dnssec_ok_) {
@@ -446,31 +467,41 @@ Message::toWire(MessageRenderer& renderer)
     // reserve room for the header
     renderer.skip(HEADERLEN);
 
+    uint16_t ancount = 0, nscount = 0, arcount = 0;
+
     uint16_t qdcount =
         for_each(impl_->questions_.begin(), impl_->questions_.end(),
-                 RenderSection<QuestionPtr>(renderer)).getTotalCount();
+                 RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
 
     // TBD: sort RRsets in each section based on configuration policy.
-    uint16_t ancount =
-        for_each(impl_->rrsets_[sectionCodeToId(Section::ANSWER())].begin(),
-                 impl_->rrsets_[sectionCodeToId(Section::ANSWER())].end(),
-                 RenderSection<RRsetPtr>(renderer)).getTotalCount();
-    uint16_t nscount =
-        for_each(impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].begin(),
-                 impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].end(),
-                 RenderSection<RRsetPtr>(renderer)).getTotalCount();
-    uint16_t arcount =
-        for_each(impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].begin(),
-                 impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].end(),
-                 RenderSection<RRsetPtr>(renderer)).getTotalCount();
+    if (!renderer.isTruncated()) {
+        ancount =
+            for_each(impl_->rrsets_[sectionCodeToId(Section::ANSWER())].begin(),
+                     impl_->rrsets_[sectionCodeToId(Section::ANSWER())].end(),
+                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
+    }
+    if (!renderer.isTruncated()) {
+        nscount =
+            for_each(impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].begin(),
+                     impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].end(),
+                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
+    }
+    if (renderer.isTruncated()) {
+        setHeaderFlag(MessageFlag::TC());
+    } else {
+        arcount =
+            for_each(impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].begin(),
+                     impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].end(),
+                     RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
+    }
 
     // Added EDNS OPT RR if necessary (we want to avoid hardcoding specialized
     // logic, see the parser case)
-    if (addEDNS(this->impl_, renderer)) {
+    if (!renderer.isTruncated() && addEDNS(this->impl_, renderer)) {
         ++arcount;
     }
 
-    // TBD: EDNS, TSIG, etc.
+    // TBD: TSIG, SIG(0) etc.
 
     // fill in the header
     size_t header_pos = 0;
@@ -767,6 +798,13 @@ Message::clear()
 }
 
 void
+Message::clear(Mode mode)
+{
+    impl_->init();
+    impl_->mode_ = mode;
+}
+
+void
 Message::makeResponse()
 {
     if (impl_->mode_ != Message::PARSE) {

+ 1 - 0
src/lib/dns/message.h

@@ -564,6 +564,7 @@ public:
     //void removeRR(const Section& section, const RR& rr);
 
     void clear();
+    void clear(Mode mode);
 
     // prepare for making a response from a request.  This will clear the
     // DNS header except those fields that should be kept for the response,

+ 37 - 1
src/lib/dns/messagerenderer.cc

@@ -135,7 +135,9 @@ struct MessageRendererImpl {
     /// \param buffer An \c OutputBuffer object to which wire format data is
     /// written.
     MessageRendererImpl(OutputBuffer& buffer) :
-        buffer_(buffer), nbuffer_(Name::MAX_WIRE) {}
+        buffer_(buffer), nbuffer_(Name::MAX_WIRE), msglength_limit_(512),
+        truncated_(false)
+    {}
     /// The buffer that holds the entire DNS message.
     OutputBuffer& buffer_;
     /// A local working buffer to convert each given name into wire format.
@@ -145,6 +147,10 @@ struct MessageRendererImpl {
     OutputBuffer nbuffer_;
     /// A set of compression pointers.
     std::set<NameCompressNode, NameCompare> nodeset_;
+
+    /// TBD
+    uint16_t msglength_limit_;
+    bool truncated_;
 };
 
 MessageRenderer::MessageRenderer(OutputBuffer& buffer) :
@@ -163,6 +169,12 @@ MessageRenderer::skip(size_t len)
 }
 
 void
+MessageRenderer::trim(size_t len)
+{
+    impl_->buffer_.trim(len);
+}
+
+void
 MessageRenderer::clear()
 {
     impl_->buffer_.clear();
@@ -212,6 +224,30 @@ MessageRenderer::getLength() const
     return (impl_->buffer_.getLength());
 }
 
+size_t
+MessageRenderer::getLengthLimit() const
+{
+    return (impl_->msglength_limit_);
+}
+
+void
+MessageRenderer::setLengthLimit(size_t len)
+{
+    impl_->msglength_limit_ = len;
+}
+
+bool
+MessageRenderer::isTruncated() const
+{
+    return (impl_->truncated_);
+}
+
+void
+MessageRenderer::setTruncated()
+{
+    impl_->truncated_ = true;
+}
+
 void
 MessageRenderer::writeName(const Name& name, bool compress)
 {

+ 20 - 0
src/lib/dns/messagerenderer.h

@@ -99,6 +99,23 @@ public:
     const void* getData() const;
     /// \brief Return the length of data written in the internal buffer.
     size_t getLength() const;
+
+    /// \brief TBD
+    bool isTruncated() const;
+
+    /// \brief TBD
+    size_t getLengthLimit() const;
+    //@}
+
+    ///
+    /// \name Setter Methods
+    ///
+    //@{
+    /// \brief TBD
+    void setLengthLimit(size_t len);
+
+    /// \brief TBD
+    void setTruncated();
     //@}
 
     ///
@@ -113,6 +130,9 @@ public:
     ///
     /// \param len The length of the gap to be inserted in bytes.
     void skip(size_t len);
+
+    /// \brief TBD
+    void trim(size_t len);
     /// \brief Clear the internal buffer and other internal resources.
     ///
     /// This method can be used to re-initialize and reuse the renderer

+ 18 - 4
src/lib/dns/rrset.cc

@@ -64,7 +64,7 @@ AbstractRRset::toText() const
 namespace {
 template <typename T>
 inline unsigned int
-rrsetToWire(const AbstractRRset& rrset, T& output)
+rrsetToWire(const AbstractRRset& rrset, T& output, const size_t limit)
 {
     unsigned int n = 0;
     RdataIteratorPtr it = rrset.getRdataIterator();
@@ -77,16 +77,25 @@ rrsetToWire(const AbstractRRset& rrset, T& output)
     // sort the set of Rdata based on rrset-order and sortlist, and possible
     // other options.  Details to be considered.
     do {
+        const size_t pos0 = output.getLength();
+        assert(pos0 < 65536);
+
         rrset.getName().toWire(output);
         rrset.getType().toWire(output);
         rrset.getClass().toWire(output);
         rrset.getTTL().toWire(output);
 
-        size_t pos = output.getLength();
+        const size_t pos = output.getLength();
         output.skip(sizeof(uint16_t)); // leave the space for RDLENGTH
         it->getCurrent().toWire(output);
         output.writeUint16At(output.getLength() - pos - sizeof(uint16_t), pos);
 
+        if (limit > 0 && output.getLength() > limit) {
+            // truncation is needed
+            output.trim(output.getLength() - pos0);
+            return (n);
+        }
+
         it->next();
         ++n;
     } while (!it->isLast());
@@ -98,13 +107,18 @@ rrsetToWire(const AbstractRRset& rrset, T& output)
 unsigned int
 AbstractRRset::toWire(OutputBuffer& buffer) const
 {
-    return (rrsetToWire<OutputBuffer>(*this, buffer));
+    return (rrsetToWire<OutputBuffer>(*this, buffer, 0));
 }
 
 unsigned int
 AbstractRRset::toWire(MessageRenderer& renderer) const
 {
-    return (rrsetToWire<MessageRenderer>(*this, renderer));
+    const unsigned int rrs_written = rrsetToWire<MessageRenderer>(
+        *this, renderer, renderer.getLengthLimit());
+    if (getRdataCount() > rrs_written) {
+        renderer.setTruncated();
+    }
+    return (rrs_written);
 }
 
 ostream&

+ 18 - 0
src/lib/dns/tests/buffer_unittest.cc

@@ -14,10 +14,14 @@
 
 // $Id$
 
+#include <exceptions/exceptions.h>
+
 #include <dns/buffer.h>
 
 #include <gtest/gtest.h>
 
+using namespace isc;
+
 namespace {
 
 using isc::dns::InputBuffer;
@@ -158,6 +162,20 @@ TEST_F(BufferTest, outputBufferSkip)
     EXPECT_EQ(6, obuffer.getLength());
 }
 
+TEST_F(BufferTest, outputBufferTrim)
+{
+    obuffer.writeData(testdata, sizeof(testdata));
+    EXPECT_EQ(5, obuffer.getLength());
+
+    obuffer.trim(1);
+    EXPECT_EQ(4, obuffer.getLength());
+
+    obuffer.trim(2);
+    EXPECT_EQ(2, obuffer.getLength());
+
+    EXPECT_THROW(obuffer.trim(3), OutOfRange);
+}
+
 TEST_F(BufferTest, outputBufferReadat)
 {
     obuffer.writeData(testdata, sizeof(testdata));