Browse Source

[trac812next] implemented TSIG signing main part: add TSIGRecord::toWire() and have the Message class use it with a TSIGContext.
(There are some other small cleanups in this commit)

JINMEI Tatuya 14 years ago
parent
commit
53d9f46923

+ 141 - 120
src/lib/dns/message.cc

@@ -15,6 +15,7 @@
 #include <stdint.h>
 
 #include <algorithm>
+#include <cassert>
 #include <string>
 #include <sstream>
 #include <vector>
@@ -40,6 +41,7 @@
 #include <dns/rrtype.h>
 #include <dns/rrttl.h>
 #include <dns/rrset.h>
+#include <dns/tsig.h>
 
 using namespace std;
 using namespace boost;
@@ -123,6 +125,7 @@ public:
     void setRcode(const Rcode& rcode);
     int parseQuestion(InputBuffer& buffer);
     int parseSection(const Message::Section section, InputBuffer& buffer);
+    void toWire(MessageRenderer& renderer, TSIGContext* tsig_ctx);
 };
 
 MessageImpl::MessageImpl(Message::Mode mode) :
@@ -164,6 +167,139 @@ MessageImpl::setRcode(const Rcode& rcode) {
     rcode_ = &rcode_placeholder_;
 }
 
+namespace {
+template <typename T>
+struct RenderSection {
+    RenderSection(MessageRenderer& renderer, const bool partial_ok) :
+        counter_(0), renderer_(renderer), partial_ok_(partial_ok),
+        truncated_(false)
+    {}
+    void operator()(const T& entry) {
+        // If it's already truncated, ignore the rest of the section.
+        if (truncated_) {
+            return;
+        }
+        const size_t pos0 = renderer_.getLength();
+        counter_ += entry->toWire(renderer_);
+        if (renderer_.isTruncated()) {
+            truncated_ = true;
+            if (!partial_ok_) {
+                // roll back to the end of the previous RRset.
+                renderer_.trim(renderer_.getLength() - pos0);
+            }
+        }
+    }
+    unsigned int getTotalCount() { return (counter_); }
+    unsigned int counter_;
+    MessageRenderer& renderer_;
+    const bool partial_ok_;
+    bool truncated_;
+};
+}
+
+void
+MessageImpl::toWire(MessageRenderer& renderer, TSIGContext* tsig_ctx) {
+    if (mode_ != Message::RENDER) {
+        isc_throw(InvalidMessageOperation,
+                  "Message rendering attempted in non render mode");
+    }
+    if (rcode_ == NULL) {
+        isc_throw(InvalidMessageOperation,
+                  "Message rendering attempted without Rcode set");
+    }
+    if (opcode_ == NULL) {
+        isc_throw(InvalidMessageOperation,
+                  "Message rendering attempted without Opcode set");
+    }
+
+    // reserve room for the header
+    renderer.skip(HEADERLEN);
+
+    uint16_t qdcount =
+        for_each(questions_.begin(), questions_.end(),
+                 RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
+
+    // TBD: sort RRsets in each section based on configuration policy.
+    uint16_t ancount = 0;
+    if (!renderer.isTruncated()) {
+        ancount =
+            for_each(rrsets_[Message::SECTION_ANSWER].begin(),
+                     rrsets_[Message::SECTION_ANSWER].end(),
+                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
+    }
+    uint16_t nscount = 0;
+    if (!renderer.isTruncated()) {
+        nscount =
+            for_each(rrsets_[Message::SECTION_AUTHORITY].begin(),
+                     rrsets_[Message::SECTION_AUTHORITY].end(),
+                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
+    }
+    uint16_t arcount = 0;
+    if (renderer.isTruncated()) {
+        flags_ |= Message::HEADERFLAG_TC;
+    } else {
+        arcount =
+            for_each(rrsets_[Message::SECTION_ADDITIONAL].begin(),
+                     rrsets_[Message::SECTION_ADDITIONAL].end(),
+                     RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
+    }
+
+    // Add EDNS OPT RR if necessary.  Basically, we add it only when EDNS
+    // has been explicitly set.  However, if the RCODE would require it and
+    // no EDNS has been set we generate a temporary local EDNS and use it.
+    if (!renderer.isTruncated()) {
+        ConstEDNSPtr local_edns = edns_;
+        if (!local_edns && rcode_->getExtendedCode() != 0) {
+            local_edns = ConstEDNSPtr(new EDNS());
+        }
+        if (local_edns) {
+            arcount += local_edns->toWire(renderer, rcode_->getExtendedCode());
+        }
+    }
+
+    // Adjust the counter buffer.
+    // XXX: these may not be equal to the number of corresponding entries
+    // in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
+    // was inserted.  This is not good, and we should revisit the entire
+    // design.
+    counts_[Message::SECTION_QUESTION] = qdcount;
+    counts_[Message::SECTION_ANSWER] = ancount;
+    counts_[Message::SECTION_AUTHORITY] = nscount;
+    counts_[Message::SECTION_ADDITIONAL] = arcount;
+
+    // fill in the header
+    size_t header_pos = 0;
+    renderer.writeUint16At(qid_, header_pos);
+    header_pos += sizeof(uint16_t);
+
+    uint16_t codes_and_flags =
+        (opcode_->getCode() << OPCODE_SHIFT) & OPCODE_MASK;
+    codes_and_flags |= (rcode_->getCode() & RCODE_MASK);
+    codes_and_flags |= (flags_ & HEADERFLAG_MASK);
+    renderer.writeUint16At(codes_and_flags, header_pos);
+    header_pos += sizeof(uint16_t);
+    // XXX: should avoid repeated pattern (TODO)
+    renderer.writeUint16At(qdcount, header_pos);
+    header_pos += sizeof(uint16_t);
+    renderer.writeUint16At(ancount, header_pos);
+    header_pos += sizeof(uint16_t);
+    renderer.writeUint16At(nscount, header_pos);
+    header_pos += sizeof(uint16_t);
+    renderer.writeUint16At(arcount, header_pos);
+
+    // Add TSIG, if necessary, at the end of the message.
+    // TBD: truncate case consideration
+    if (tsig_ctx != NULL) {
+        tsig_ctx->sign(qid_, renderer.getData(),
+                       renderer.getLength())->toWire(renderer);
+
+        // update the ARCOUNT for the TSIG RR
+        ++arcount;
+        assert(arcount != 0);   // this should never happen for a sane message
+        renderer.writeUint16At(arcount, header_pos);
+    }
+}
+
 Message::Message(Mode mode) :
     impl_(new MessageImpl(mode))
 {}
@@ -363,129 +499,14 @@ Message::addQuestion(const Question& question) {
     addQuestion(QuestionPtr(new Question(question)));
 }
 
-namespace {
-template <typename T>
-struct RenderSection {
-    RenderSection(MessageRenderer& renderer, const bool partial_ok) :
-        counter_(0), renderer_(renderer), partial_ok_(partial_ok),
-        truncated_(false)
-    {}
-    void operator()(const T& entry) {
-        // If it's already truncated, ignore the rest of the section.
-        if (truncated_) {
-            return;
-        }
-        const size_t pos0 = renderer_.getLength();
-        counter_ += entry->toWire(renderer_);
-        if (renderer_.isTruncated()) {
-            truncated_ = true;
-            if (!partial_ok_) {
-                // roll back to the end of the previous RRset.
-                renderer_.trim(renderer_.getLength() - pos0);
-            }
-        }
-    }
-    unsigned int getTotalCount() { return (counter_); }
-    unsigned int counter_;
-    MessageRenderer& renderer_;
-    const bool partial_ok_;
-    bool truncated_;
-};
-}
-
 void
 Message::toWire(MessageRenderer& renderer) {
-    if (impl_->mode_ != Message::RENDER) {
-        isc_throw(InvalidMessageOperation,
-                  "Message rendering attempted in non render mode");
-    }
-    if (impl_->rcode_ == NULL) {
-        isc_throw(InvalidMessageOperation,
-                  "Message rendering attempted without Rcode set");
-    }
-    if (impl_->opcode_ == NULL) {
-        isc_throw(InvalidMessageOperation,
-                  "Message rendering attempted without Opcode set");
-    }
-
-    // reserve room for the header
-    renderer.skip(HEADERLEN);
-
-    uint16_t qdcount =
-        for_each(impl_->questions_.begin(), impl_->questions_.end(),
-                 RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
-
-    // TBD: sort RRsets in each section based on configuration policy.
-    uint16_t ancount = 0;
-    if (!renderer.isTruncated()) {
-        ancount =
-            for_each(impl_->rrsets_[SECTION_ANSWER].begin(),
-                     impl_->rrsets_[SECTION_ANSWER].end(),
-                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
-    }
-    uint16_t nscount = 0;
-    if (!renderer.isTruncated()) {
-        nscount =
-            for_each(impl_->rrsets_[SECTION_AUTHORITY].begin(),
-                     impl_->rrsets_[SECTION_AUTHORITY].end(),
-                     RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
-    }
-    uint16_t arcount = 0;
-    if (renderer.isTruncated()) {
-        setHeaderFlag(HEADERFLAG_TC, true);
-    } else {
-        arcount =
-            for_each(impl_->rrsets_[SECTION_ADDITIONAL].begin(),
-                     impl_->rrsets_[SECTION_ADDITIONAL].end(),
-                     RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
-    }
-
-    // Add EDNS OPT RR if necessary.  Basically, we add it only when EDNS
-    // has been explicitly set.  However, if the RCODE would require it and
-    // no EDNS has been set we generate a temporary local EDNS and use it.
-    if (!renderer.isTruncated()) {
-        ConstEDNSPtr local_edns = impl_->edns_;
-        if (!local_edns && impl_->rcode_->getExtendedCode() != 0) {
-            local_edns = ConstEDNSPtr(new EDNS());
-        }
-        if (local_edns) {
-            arcount += local_edns->toWire(renderer,
-                                          impl_->rcode_->getExtendedCode());
-        }
-    }
- 
-    // Adjust the counter buffer.
-    // XXX: these may not be equal to the number of corresponding entries
-    // in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
-    // was inserted.  This is not good, and we should revisit the entire
-    // design.
-    impl_->counts_[SECTION_QUESTION] = qdcount;
-    impl_->counts_[SECTION_ANSWER] = ancount;
-    impl_->counts_[SECTION_AUTHORITY] = nscount;
-    impl_->counts_[SECTION_ADDITIONAL] = arcount;
-
-    // TBD: TSIG, SIG(0) etc.
-
-    // fill in the header
-    size_t header_pos = 0;
-    renderer.writeUint16At(impl_->qid_, header_pos);
-    header_pos += sizeof(uint16_t);
+    impl_->toWire(renderer, NULL);
+}
 
-    uint16_t codes_and_flags =
-        (impl_->opcode_->getCode() << OPCODE_SHIFT) & OPCODE_MASK;
-    codes_and_flags |= (impl_->rcode_->getCode() & RCODE_MASK);
-    codes_and_flags |= (impl_->flags_ & HEADERFLAG_MASK);
-    renderer.writeUint16At(codes_and_flags, header_pos);
-    header_pos += sizeof(uint16_t);
-    // XXX: should avoid repeated pattern (TODO)
-    renderer.writeUint16At(qdcount, header_pos);
-    header_pos += sizeof(uint16_t);
-    renderer.writeUint16At(ancount, header_pos);
-    header_pos += sizeof(uint16_t);
-    renderer.writeUint16At(nscount, header_pos);
-    header_pos += sizeof(uint16_t);
-    renderer.writeUint16At(arcount, header_pos);
-    header_pos += sizeof(uint16_t);
+void
+Message::toWire(MessageRenderer& renderer, TSIGContext& tsig_ctx) {
+    impl_->toWire(renderer, &tsig_ctx);
 }
 
 void

+ 4 - 0
src/lib/dns/message.h

@@ -33,6 +33,7 @@ class InputBuffer;
 }
 
 namespace dns {
+class TSIGContext;
 
 ///
 /// \brief A standard DNS module exception that is thrown if a wire format
@@ -531,6 +532,9 @@ public:
     /// class \c InvalidMessageOperation will be thrown.
     void toWire(MessageRenderer& renderer);
 
+    // TBD
+    void toWire(MessageRenderer& renderer, TSIGContext& tsig_ctx);
+
     /// \brief Parse the header section of the \c Message.
     void parseHeader(isc::util::InputBuffer& buffer);
 

+ 2 - 10
src/lib/dns/rdata/any_255/tsig_250.cc

@@ -24,7 +24,7 @@
 #include <dns/messagerenderer.h>
 #include <dns/rdata.h>
 #include <dns/rdataclass.h>
-
+#include <dns/tsigerror.h>
 
 using namespace std;
 using namespace boost;
@@ -313,15 +313,7 @@ TSIG::toText() const {
         result += encodeBase64(impl_->mac_) + " ";
     }
     result += lexical_cast<string>(impl_->original_id_) + " ";
-    if (impl_->error_ == 16) {  // XXX: we'll soon introduce generic converter.
-        result += "BADSIG ";
-    } else if (impl_->error_ == 17) {
-        result += "BADKEY ";
-    } else if (impl_->error_ == 18) {
-        result += "BADTIME ";
-    } else {
-        result += lexical_cast<string>(impl_->error_) + " ";
-    }
+    result += TSIGError(impl_->error_).toText() + " ";
     result += lexical_cast<string>(impl_->other_data_.size());
     if (impl_->other_data_.size() > 0) {
         result += " " + encodeBase64(impl_->other_data_);

+ 1 - 0
src/lib/dns/tests/Makefile.am

@@ -50,6 +50,7 @@ run_unittests_SOURCES += message_unittest.cc
 run_unittests_SOURCES += tsig_unittest.cc
 run_unittests_SOURCES += tsigerror_unittest.cc
 run_unittests_SOURCES += tsigkey_unittest.cc
+run_unittests_SOURCES += tsigrecord_unittest.cc
 run_unittests_SOURCES += run_unittests.cc
 run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
 run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS)

+ 80 - 1
src/lib/dns/tests/message_unittest.cc

@@ -12,6 +12,8 @@
 // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 // PERFORMANCE OF THIS SOFTWARE.
 
+#include <boost/scoped_ptr.hpp>
+
 #include <exceptions/exceptions.h>
 
 #include <util/buffer.h>
@@ -26,6 +28,8 @@
 #include <dns/rrclass.h>
 #include <dns/rrttl.h>
 #include <dns/rrtype.h>
+#include <dns/tsig.h>
+#include <dns/tsigkey.h>
 
 #include <gtest/gtest.h>
 
@@ -53,6 +57,17 @@ using namespace isc::dns::rdata;
 const uint16_t Message::DEFAULT_MAX_UDPSIZE;
 const Name test_name("test.example.com");
 
+// See dnssectime.cc
+namespace isc {
+namespace dns {
+namespace tsig {
+namespace detail {
+extern int64_t (*gettimeFunction)();
+}
+}
+}
+}
+
 namespace {
 class MessageTest : public ::testing::Test {
 protected:
@@ -60,7 +75,9 @@ protected:
                     message_parse(Message::PARSE),
                     message_render(Message::RENDER),
                     bogus_section(static_cast<Message::Section>(
-                                      Message::SECTION_ADDITIONAL + 1))
+                                      Message::SECTION_ADDITIONAL + 1)),
+                    tsig_ctx(TSIGKey("www.example.com:"
+                                     "SFuWd/q99SzF8Yzd1QbB9g=="))
     {
         rrset_a = RRsetPtr(new RRset(test_name, RRClass::IN(),
                                      RRType::A(), RRTTL(3600)));
@@ -88,6 +105,9 @@ protected:
     RRsetPtr rrset_a;           // A RRset with two RDATAs
     RRsetPtr rrset_aaaa;        // AAAA RRset with one RDATA with RRSIG
     RRsetPtr rrset_rrsig;       // RRSIG for the AAAA RRset
+    TSIGContext tsig_ctx;
+    vector<unsigned char> expected_data;
+
     static void factoryFromFile(Message& message, const char* datafile);
 };
 
@@ -519,6 +539,65 @@ TEST_F(MessageTest, toWireInParseMode) {
     EXPECT_THROW(message_parse.toWire(renderer), InvalidMessageOperation);
 }
 
+// See dnssectime_unittest.cc
+template <int64_t NOW>
+int64_t
+testGetTime() {
+    return (NOW);
+}
+
+void
+commonTSIGToWireCheck(Message& message, MessageRenderer& renderer,
+                      TSIGContext& tsig_ctx, const char* const expected_file)
+{
+    message.setOpcode(Opcode::QUERY());
+    message.setRcode(Rcode::NOERROR());
+    message.setHeaderFlag(Message::HEADERFLAG_RD, true);
+    message.addQuestion(Question(Name("www.example.com"), RRClass::IN(),
+                                 RRType::A()));
+
+    message.toWire(renderer, tsig_ctx);
+    vector<unsigned char> expected_data;
+    UnitTestUtil::readWireData(expected_file, expected_data);
+    EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, renderer.getData(),
+                        renderer.getLength(),
+                        &expected_data[0], expected_data.size());
+}
+
+TEST_F(MessageTest, toWireWithTSIG) {
+    // Rendering a message with TSIG.  Various special cases specific to
+    // TSIG are tested in the tsig tests.  We only check the message contains
+    // a TSIG at the end and the ARCOUNT of the header is updated.
+
+    tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
+
+    message_render.setQid(0x2d65);
+
+    {
+        SCOPED_TRACE("Message sign with TSIG");
+        commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                              "message_toWire2.wire");
+    }
+}
+
+TEST_F(MessageTest, toWireWithEDNSAndTSIG) {
+    // Similar to the previous test, but with an EDNS before TSIG.
+    // The wire data check will confirm the ordering.
+    tsig::detail::gettimeFunction = testGetTime<0x4db60d1f>;
+
+    message_render.setQid(0x6cd);
+
+    EDNSPtr edns(new EDNS());
+    edns->setUDPSize(4096);
+    message_render.setEDNS(edns);
+
+    {
+        SCOPED_TRACE("Message sign with TSIG and EDNS");
+        commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                              "message_toWire3.wire");
+    }
+}
+
 TEST_F(MessageTest, toWireWithoutOpcode) {
     message_render.setRcode(Rcode::NOERROR());
     EXPECT_THROW(message_render.toWire(renderer), InvalidMessageOperation);

+ 6 - 3
src/lib/dns/tests/testdata/Makefile.am

@@ -3,6 +3,7 @@ CLEANFILES = *.wire
 BUILT_SOURCES = edns_toWire1.wire edns_toWire2.wire edns_toWire3.wire
 BUILT_SOURCES += edns_toWire4.wire
 BUILT_SOURCES += message_fromWire10.wire message_fromWire11.wire
+BUILT_SOURCES += message_toWire2.wire message_toWire3.wire
 BUILT_SOURCES += name_toWire5.wire name_toWire6.wire
 BUILT_SOURCES += rdatafields1.wire rdatafields2.wire rdatafields3.wire
 BUILT_SOURCES += rdatafields4.wire rdatafields5.wire rdatafields6.wire
@@ -33,6 +34,7 @@ BUILT_SOURCES += rdata_tsig_fromWire9.wire
 BUILT_SOURCES += rdata_tsig_toWire1.wire rdata_tsig_toWire2.wire
 BUILT_SOURCES += rdata_tsig_toWire3.wire rdata_tsig_toWire4.wire
 BUILT_SOURCES += rdata_tsig_toWire5.wire
+BUILT_SOURCES += tsigrecord_toWire1.wire tsigrecord_toWire2.wire
 
 # NOTE: keep this in sync with real file listing
 # so is included in tarball
@@ -46,7 +48,7 @@ EXTRA_DIST += message_fromWire5 message_fromWire6
 EXTRA_DIST += message_fromWire7 message_fromWire8
 EXTRA_DIST += message_fromWire9 message_fromWire10.spec
 EXTRA_DIST += message_fromWire11.spec
-EXTRA_DIST += message_toWire1
+EXTRA_DIST += message_toWire1 message_toWire2.spec message_toWire3.spec
 EXTRA_DIST += name_fromWire1 name_fromWire2 name_fromWire3_1 name_fromWire3_2
 EXTRA_DIST += name_fromWire4 name_fromWire6 name_fromWire7 name_fromWire8
 EXTRA_DIST += name_fromWire9 name_fromWire10 name_fromWire11 name_fromWire12
@@ -66,7 +68,8 @@ EXTRA_DIST += rdata_nsec_fromWire6.spec rdata_nsec_fromWire7.spec
 EXTRA_DIST += rdata_nsec_fromWire8.spec rdata_nsec_fromWire9.spec
 EXTRA_DIST += rdata_nsec_fromWire10.spec
 EXTRA_DIST += rdata_nsec3param_fromWire1
-EXTRA_DIST += rdata_nsec3_fromWire1 rdata_nsec3_fromWire3
+EXTRA_DIST += rdata_nsec3_fromWire1
+EXTRA_DIST += rdata_nsec3_fromWire2.spec rdata_nsec3_fromWire3
 EXTRA_DIST += rdata_nsec3_fromWire4.spec rdata_nsec3_fromWire5.spec
 EXTRA_DIST += rdata_nsec3_fromWire6.spec rdata_nsec3_fromWire7.spec
 EXTRA_DIST += rdata_nsec3_fromWire8.spec rdata_nsec3_fromWire9.spec
@@ -94,7 +97,7 @@ EXTRA_DIST += rdata_tsig_fromWire9.spec
 EXTRA_DIST += rdata_tsig_toWire1.spec rdata_tsig_toWire2.spec
 EXTRA_DIST += rdata_tsig_toWire3.spec rdata_tsig_toWire4.spec
 EXTRA_DIST += rdata_tsig_toWire5.spec
-EXTRA_DIST += rdata_nsec3_fromWire2.spec
+EXTRA_DIST += tsigrecord_toWire1.spec tsigrecord_toWire2.spec
 
 .spec.wire:
 	./gen-wiredata.py -o $@ $<

+ 15 - 2
src/lib/dns/tests/testdata/gen-wiredata.py.in

@@ -433,6 +433,11 @@ class RRSIG:
         f.write('%04x %s %s\n' % (self.tag, name_wire, sig_wire))
 
 class TSIG:
+    as_rr = False
+    rr_name = 'example.com' # only when as_rr is True, same for class/TTL
+    rr_class = parse_value('ANY', dict_rrclass)
+    rr_ttl = 0
+
     rdlen = None                # auto-calculate
     algorithm = 'hmac-sha256'
     time_signed = 1286978795    # arbitrarily chosen default
@@ -471,8 +476,16 @@ class TSIG:
         if rdlen is None:
             rdlen = int(len(name_wire) / 2 + 16 + len(mac) / 2 + \
                             len(other_data) / 2)
-        f.write('\n# TSIG RDATA (RDLEN=%d)\n' % rdlen)
-        f.write('%04x\n' % rdlen);
+        if self.as_rr:
+            f.write('\n# TSIG RR (QNAME=%s Class=%s TTL=%d RDLEN=%d)\n' %
+                    (self.rr_name, rdict_rrclass[self.rr_class],
+                     self.rr_ttl, rdlen))
+            f.write('%s %04x %04x %08x %04x\n' %
+                    (encode_name(self.rr_name), dict_rrtype['tsig'],
+                     self.rr_class, self.rr_ttl, rdlen))
+        else:
+            f.write('\n# TSIG RDATA (RDLEN=%d)\n' % rdlen)
+            f.write('%04x\n' % rdlen);
         f.write('# Algorithm=%s Time-Signed=%d Fudge=%d\n' %
                 (self.algorithm, self.time_signed, self.fudge))
         f.write('%s %012x %04x\n' % (name_wire, self.time_signed, self.fudge))

+ 2 - 1
src/lib/dns/tests/testdata/tsigrecord_toWire1.spec

@@ -1,5 +1,6 @@
 #
-# A simple DNS response message with TSIG signed
+# A simple TSIG RR (some of the parameters are taken from a live example
+# and don't have a specific meaning)
 #
 
 [custom]

+ 19 - 0
src/lib/dns/tests/testdata/tsigrecord_toWire2.spec

@@ -0,0 +1,19 @@
+#
+# TSIG RR after some names that could (unexpectedly) cause name compression
+#
+
+[custom]
+sections: name/1:name/2:tsig
+[name/1]
+name: hmac-md5.sig-alg.reg.int
+[name/2]
+name: foo.example.com
+[tsig]
+as_rr: True
+# TSIG QNAME won't be compressed
+rr_name: www.example.com
+algorithm: hmac-md5
+time_signed: 0x4da8877a
+mac_size: 16
+mac: 0xdadadadadadadadadadadadadadadada
+original_id: 0x2d65

+ 3 - 1
src/lib/dns/tsig.cc

@@ -118,8 +118,9 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
     // specified in Section 4.3 of RFC2845.
     if (error == TSIGError::BAD_SIG() || error == TSIGError::BAD_KEY()) {
         ConstTSIGRecordPtr tsig(new TSIGRecord(
+                                    impl_->key_.getKeyName(),
                                     any::TSIG(impl_->key_.getAlgorithmName(),
-                                              now, DEFAULT_FUDGE, NULL, 0,
+                                              now, DEFAULT_FUDGE, 0, NULL,
                                               qid, error.getCode(), 0, NULL)));
         impl_->previous_digest_.clear();
         impl_->state_ = SIGNED;
@@ -187,6 +188,7 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
     // Get the final digest, update internal state, then finish.
     vector<uint8_t> digest = hmac->sign();
     ConstTSIGRecordPtr tsig(new TSIGRecord(
+                                impl_->key_.getKeyName(),
                                 any::TSIG(impl_->key_.getAlgorithmName(),
                                           time_signed, DEFAULT_FUDGE,
                                           digest.size(), &digest[0],

+ 75 - 0
src/lib/dns/tsigrecord.cc

@@ -12,16 +12,91 @@
 // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 // PERFORMANCE OF THIS SOFTWARE.
 
+#include <ostream>
+#include <string>
+
 #include <util/buffer.h>
 
+#include <dns/messagerenderer.h>
 #include <dns/rrclass.h>
+#include <dns/rrttl.h>
 #include <dns/tsigrecord.h>
 
+using namespace isc::util;
+
+namespace {
+// Internally used constants:
+
+// Size in octets for the RR type, class TTL fields.
+const size_t RR_COMMON_LEN = 8;
+
+// Size in octets for the fixed part of TSIG RDATAs.
+// - Time Signed (6)
+// - Fudge (2)
+// - MAC Size (2)
+// - Original ID (2)
+// - Error (2)
+// - Other Len (2)
+const size_t RDATA_COMMON_LEN = 16;
+}
+
 namespace isc {
 namespace dns {
+TSIGRecord::TSIGRecord(const Name& key_name,
+                       const rdata::any::TSIG& tsig_rdata) :
+    key_name_(key_name), rdata_(tsig_rdata),
+    length_(RR_COMMON_LEN + RDATA_COMMON_LEN + key_name_.getLength() +
+            rdata_.getAlgorithm().getLength() +
+            rdata_.getMACSize() + rdata_.getOtherLen())
+{}
+
 const RRClass&
 TSIGRecord::getClass() {
     return (RRClass::ANY());
 }
+
+namespace {
+template <typename OUTPUT>
+void
+toWireCommon(OUTPUT& output, const rdata::any::TSIG& rdata) {
+    // RR type, class, TTL are fixed constants.
+    RRType::TSIG().toWire(output);
+    TSIGRecord::getClass().toWire(output);
+    output.writeUint32(TSIGRecord::TSIG_TTL);
+
+    // RDLEN
+    output.writeUint16(RDATA_COMMON_LEN + rdata.getAlgorithm().getLength() +
+                       rdata.getMACSize() + rdata.getOtherLen());
+
+    // TSIG RDATA
+    rdata.toWire(output);
+}
+}
+
+void
+TSIGRecord::toWire(AbstractMessageRenderer& renderer) const {
+    // key name = owner.  note that we disable compression.
+    renderer.writeName(key_name_, false);
+
+    toWireCommon(renderer, rdata_);
+}
+
+void
+TSIGRecord::toWire(OutputBuffer& buffer) const {
+    key_name_.toWire(buffer);
+    toWireCommon(buffer, rdata_);
+}
+
+std::string
+TSIGRecord::toText() const {
+    return (key_name_.toText() + " " + RRTTL(TSIG_TTL).toText() + " " +
+            getClass().toText() + " " + RRType::TSIG().toText() + " " +
+            rdata_.toText() + "\n");
+}
+
+std::ostream&
+operator<<(std::ostream& os, const TSIGRecord& record) {
+    return (os << record.toText());
+}
 } // namespace dns
 } // namespace isc

+ 39 - 11
src/lib/dns/tsigrecord.h

@@ -15,15 +15,21 @@
 #ifndef __TSIGRECORD_H
 #define __TSIGRECORD_H 1
 
-#include <boost/shared_ptr.hpp>
+#include <ostream>
+#include <string>
 
-#include <util/buffer.h>
+#include <boost/shared_ptr.hpp>
 
 #include <dns/name.h>
 #include <dns/rdataclass.h>
 
 namespace isc {
+namespace util {
+class OutputBuffer;
+}
 namespace dns {
+class AbstractMessageRenderer;
+
 /// TSIG resource record.
 ///
 /// A \c TSIGRecord class object represents a TSIG resource record and is
@@ -33,9 +39,8 @@ namespace dns {
 /// TSIG without knowing protocol details of TSIG, such as that it uses a
 /// fixed constant of TTL.
 ///
-/// \note So the plan is to eventually provide a \c toWire() method and
-/// the "from wire" constructor.  They are not yet provided in this initial
-/// step.
+/// \todo So the plan is to eventually provide  the "from wire" constructor.
+/// It's not yet provided in the current phase of development.
 ///
 /// \note
 /// This class could be a derived class of \c AbstractRRset.  That way
@@ -54,13 +59,13 @@ namespace dns {
 /// similar to why \c EDNS is a separate class.
 class TSIGRecord {
 public:
-    /// Constructor from TSIG RDATA
+    /// Constructor from TSIG key name and RDATA
     ///
-    /// \exception std::bad_alloc Resource allocation for copying the RDATA
-    /// fails
-    explicit TSIGRecord(const rdata::any::TSIG& tsig_rdata) :
-        rdata_(tsig_rdata)
-    {}
+    /// \exception std::bad_alloc Resource allocation for copying the name or
+    /// RDATA fails
+    TSIGRecord(const Name& key_name, const rdata::any::TSIG& tsig_rdata);
+
+    const Name& getName() const { return (key_name_); }
 
     /// Return the RDATA of the TSIG RR
     ///
@@ -79,12 +84,23 @@ public:
     /// \exception None
     static const RRClass& getClass();
 
+    // Note: More important for the "from wire" case.
+    size_t getLength() const { return (length_); }
+
+    void toWire(AbstractMessageRenderer& renderer) const;
+
+    void toWire(isc::util::OutputBuffer& buffer) const;
+
+    std::string toText() const;
+
     /// The TTL value to be used in TSIG RRs.
     static const uint32_t TSIG_TTL = 0;
     //@}
 
 private:
+    const Name key_name_;
     const rdata::any::TSIG rdata_;
+    const size_t length_;
 };
 
 /// A pointer-like type pointing to a \c TSIGRecord object.
@@ -92,6 +108,18 @@ typedef boost::shared_ptr<TSIGRecord> TSIGRecordPtr;
 
 /// A pointer-like type pointing to an immutable \c TSIGRecord object.
 typedef boost::shared_ptr<const TSIGRecord> ConstTSIGRecordPtr;
+
+/// Insert the \c TSIGRecord as a string into stream.
+///
+/// This method convert \c record into a string and inserts it into the
+/// output stream \c os.
+///
+/// \param os A \c std::ostream object on which the insertion operation is
+/// performed.
+/// \param record An \c TSIGRecord object output by the operation.
+/// \return A reference to the same \c std::ostream object referenced by
+/// parameter \c os after the insertion operation.
+std::ostream& operator<<(std::ostream& os, const TSIGRecord& record);
 }
 }