Browse Source

[trac931] Most of the signing

There's a need to construct 0-length signature on error.
Michal 'vorner' Vaner 14 years ago
parent
commit
b8da54961c
2 changed files with 75 additions and 21 deletions
  1. 72 21
      src/bin/auth/auth_srv.cc
  2. 3 0
      src/bin/auth/tests/auth_srv_unittest.cc

+ 72 - 21
src/bin/auth/auth_srv.cc

@@ -20,6 +20,7 @@
 #include <cassert>
 #include <iostream>
 #include <vector>
+#include <memory>
 
 #include <boost/bind.hpp>
 
@@ -43,6 +44,7 @@
 #include <dns/rrset.h>
 #include <dns/rrttl.h>
 #include <dns/message.h>
+#include <dns/tsig.h>
 
 #include <datasrc/query.h>
 #include <datasrc/data_source.h>
@@ -58,6 +60,8 @@
 #include <auth/query.h>
 #include <auth/statistics.h>
 
+#include <server_common/keyring.h>
+
 using namespace std;
 
 using namespace isc;
@@ -85,11 +89,14 @@ public:
     isc::data::ConstElementPtr setDbFile(isc::data::ConstElementPtr config);
 
     bool processNormalQuery(const IOMessage& io_message, MessagePtr message,
-                            OutputBufferPtr buffer);
+                            OutputBufferPtr buffer,
+                            auto_ptr<TSIGContext> tsig_context);
     bool processAxfrQuery(const IOMessage& io_message, MessagePtr message,
-                          OutputBufferPtr buffer);
+                          OutputBufferPtr buffer,
+                          auto_ptr<TSIGContext> tsig_context);
     bool processNotify(const IOMessage& io_message, MessagePtr message,
-                       OutputBufferPtr buffer);
+                       OutputBufferPtr buffer,
+                       auto_ptr<TSIGContext> tsig_context);
 
     IOService io_service_;
 
@@ -241,7 +248,9 @@ public:
 
 void
 makeErrorMessage(MessagePtr message, OutputBufferPtr buffer,
-                 const Rcode& rcode, const bool verbose_mode)
+                 const Rcode& rcode, const bool verbose_mode,
+                 std::auto_ptr<TSIGContext> tsig_context =
+                 std::auto_ptr<TSIGContext>())
 {
     // extract the parameters that should be kept.
     // XXX: with the current implementation, it's not easy to set EDNS0
@@ -272,7 +281,11 @@ makeErrorMessage(MessagePtr message, OutputBufferPtr buffer,
     message->setRcode(rcode);
 
     MessageRenderer renderer(*buffer);
-    message->toWire(renderer);
+    if (tsig_context.get() != NULL) {
+        message->toWire(renderer, *tsig_context);
+    } else {
+        message->toWire(renderer);
+    }
 
     if (verbose_mode) {
         cerr << "[b10-auth] sending an error response (" <<
@@ -446,29 +459,52 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
     }
 
     // Perform further protocol-level validation.
+    // TSIG first
+    // If this is set to something, we know we need to answer with TSIG as well
+    std::auto_ptr<TSIGContext> tsig_context;
+    const TSIGRecord* tsig_record(message->getTSIGRecord());
+    TSIGError tsig_error(TSIGError::NOERROR());
+
+    // Do we do TSIG?
+    // The keyring can be null if we're in test
+    if (server_common::keyring && tsig_record) {
+        tsig_context.reset(new TSIGContext(tsig_record->getName(),
+                                           tsig_record->getRdata().
+                                                getAlgorithm(),
+                                           *server_common::keyring));
+        tsig_error = tsig_context->verify(tsig_record, io_message.getData(),
+                                          io_message.getDataSize());
+    }
 
     bool sendAnswer = true;
-    if (message->getOpcode() == Opcode::NOTIFY()) {
-        sendAnswer = impl_->processNotify(io_message, message, buffer);
+    if (tsig_error != TSIGError::NOERROR()) {
+        // TODO We need to add a TSIG but with 0-length signature
+        makeErrorMessage(message, buffer, tsig_error.toRcode(),
+                         impl_->verbose_mode_);
+    } else if (message->getOpcode() == Opcode::NOTIFY()) {
+        sendAnswer = impl_->processNotify(io_message, message, buffer,
+                                          tsig_context);
     } else if (message->getOpcode() != Opcode::QUERY()) {
         if (impl_->verbose_mode_) {
             cerr << "[b10-auth] unsupported opcode" << endl;
         }
         makeErrorMessage(message, buffer, Rcode::NOTIMP(),
-                         impl_->verbose_mode_);
+                         impl_->verbose_mode_, tsig_context);
     } else if (message->getRRCount(Message::SECTION_QUESTION) != 1) {
         makeErrorMessage(message, buffer, Rcode::FORMERR(),
-                         impl_->verbose_mode_);
+                         impl_->verbose_mode_, tsig_context);
     } else {
         ConstQuestionPtr question = *message->beginQuestion();
         const RRType &qtype = question->getType();
         if (qtype == RRType::AXFR()) {
-            sendAnswer = impl_->processAxfrQuery(io_message, message, buffer);
+            sendAnswer = impl_->processAxfrQuery(io_message, message, buffer,
+                                                 tsig_context);
         } else if (qtype == RRType::IXFR()) {
             makeErrorMessage(message, buffer, Rcode::NOTIMP(),
-                             impl_->verbose_mode_);
+                             impl_->verbose_mode_, tsig_context);
         } else {
-            sendAnswer = impl_->processNormalQuery(io_message, message, buffer);
+            sendAnswer = impl_->processNormalQuery(io_message, message, buffer,
+                                                   tsig_context);
         }
     }
 
@@ -477,7 +513,8 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
 
 bool
 AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
-                                OutputBufferPtr buffer)
+                                OutputBufferPtr buffer,
+                                auto_ptr<TSIGContext> tsig_context)
 {
     ConstEDNSPtr remote_edns = message->getEDNS();
     const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
@@ -523,7 +560,11 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
     const bool udp_buffer =
         (io_message.getSocket().getProtocol() == IPPROTO_UDP);
     renderer.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
-    message->toWire(renderer);
+    if (tsig_context.get() != NULL) {
+        message->toWire(renderer, *tsig_context);
+    } else {
+        message->toWire(renderer);
+    }
 
     if (verbose_mode_) {
         cerr << "[b10-auth] sending a response ("
@@ -536,7 +577,8 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
 
 bool
 AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, MessagePtr message,
-                              OutputBufferPtr buffer)
+                              OutputBufferPtr buffer,
+                              auto_ptr<TSIGContext> tsig_context)
 {
     // Increment query counter.
     incCounter(io_message.getSocket().getProtocol());
@@ -545,7 +587,8 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, MessagePtr message,
         if (verbose_mode_) {
             cerr << "[b10-auth] AXFR query over UDP isn't allowed" << endl;
         }
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_);
+        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_,
+                         tsig_context);
         return (true);
     }
 
@@ -572,7 +615,8 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, MessagePtr message,
             cerr << "[b10-auth] Error in handling XFR request: " << err.what()
                  << endl;
         }
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL(), verbose_mode_);
+        makeErrorMessage(message, buffer, Rcode::SERVFAIL(), verbose_mode_,
+                         tsig_context);
         return (true);
     }
 
@@ -581,7 +625,8 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, MessagePtr message,
 
 bool
 AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message, 
-                           OutputBufferPtr buffer)
+                           OutputBufferPtr buffer,
+                           std::auto_ptr<TSIGContext> tsig_context)
 {
     // The incoming notify must contain exactly one question for SOA of the
     // zone name.
@@ -590,7 +635,8 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message,
                 cerr << "[b10-auth] invalid number of questions in notify: "
                      << message->getRRCount(Message::SECTION_QUESTION) << endl;
         }
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_);
+        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_,
+                         tsig_context);
         return (true);
     }
     ConstQuestionPtr question = *message->beginQuestion();
@@ -599,7 +645,8 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message,
                 cerr << "[b10-auth] invalid question RR type in notify: "
                      << question->getType() << endl;
         }
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_);
+        makeErrorMessage(message, buffer, Rcode::FORMERR(), verbose_mode_,
+                         tsig_context);
         return (true);
     }
 
@@ -662,7 +709,11 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message,
     message->setRcode(Rcode::NOERROR());
 
     MessageRenderer renderer(*buffer);
-    message->toWire(renderer);
+    if (tsig_context.get() != NULL) {
+        message->toWire(renderer, *tsig_context);
+    } else {
+        message->toWire(renderer);
+    }
     return (true);
 }
 

+ 3 - 0
src/bin/auth/tests/auth_srv_unittest.cc

@@ -269,6 +269,7 @@ TEST_F(AuthSrvTest, TSIGSigned) {
     // We need to parse the message ourself, or getTSIGRecord won't work
     InputBuffer ib(response_obuffer->getData(), response_obuffer->getLength());
     Message m(Message::PARSE);
+    m.fromWire(ib);
 
     const TSIGRecord* tsig = m.getTSIGRecord();
     ASSERT_TRUE(tsig) << "Missing TSIG signature";
@@ -303,6 +304,7 @@ TEST_F(AuthSrvTest, TSIGSignedNoKey) {
     // We need to parse the message ourself, or getTSIGRecord won't work
     InputBuffer ib(response_obuffer->getData(), response_obuffer->getLength());
     Message m(Message::PARSE);
+    m.fromWire(ib);
 
     const TSIGRecord* tsig = m.getTSIGRecord();
     ASSERT_TRUE(tsig) <<
@@ -334,6 +336,7 @@ TEST_F(AuthSrvTest, TSIGBadSig) {
     // We need to parse the message ourself, or getTSIGRecord won't work
     InputBuffer ib(response_obuffer->getData(), response_obuffer->getLength());
     Message m(Message::PARSE);
+    m.fromWire(ib);
 
     const TSIGRecord* tsig = m.getTSIGRecord();
     ASSERT_TRUE(tsig) <<