Browse Source

[1783] add a test case where toWire() throws, and fix it using a cleaner class.

also make sure the internal (reusable) MessageRenderer is always used.
JINMEI Tatuya 13 years ago
parent
commit
f46b788569
2 changed files with 121 additions and 38 deletions
  1. 49 24
      src/bin/auth/auth_srv.cc
  2. 72 14
      src/bin/auth/tests/auth_srv_unittest.cc

+ 49 - 24
src/bin/auth/auth_srv.cc

@@ -78,6 +78,31 @@ using namespace isc::asiolink;
 using namespace isc::asiodns;
 using namespace isc::server_common::portconfig;
 
+namespace {
+// A helper class for cleaning up message renderer.
+//
+// A temporary object of this class is expected to be created before starting
+// response message rendering.  On construction, it (re)initialize the given
+// message renderer with the given buffer.  On destruction, it releases
+// the previously set buffer and then release any internal resource in the
+// renderer, no matter what happened during the rendering, especially even
+// when it resulted in an exception.
+class  RendererHolder {
+public:
+    RendererHolder(MessageRenderer& renderer, OutputBuffer* buffer) :
+        renderer_(renderer)
+    {
+        renderer.setBuffer(buffer);
+    }
+    ~RendererHolder() {
+        renderer_.setBuffer(NULL);
+        renderer_.clear();
+    }
+private:
+    MessageRenderer& renderer_;
+};
+}
+
 class AuthSrvImpl {
 private:
     // prohibit copy
@@ -277,8 +302,8 @@ public:
 };
 
 void
-makeErrorMessage(Message& message, OutputBuffer& buffer,
-                 const Rcode& rcode,
+makeErrorMessage(MessageRenderer& renderer, Message& message,
+                 OutputBuffer& buffer, const Rcode& rcode,
                  std::auto_ptr<TSIGContext> tsig_context =
                  std::auto_ptr<TSIGContext>())
 {
@@ -311,14 +336,12 @@ makeErrorMessage(Message& message, OutputBuffer& buffer,
 
     message.setRcode(rcode);
     
-    MessageRenderer renderer;
-    renderer.setBuffer(&buffer);
+    RendererHolder holder(renderer, &buffer);
     if (tsig_context.get() != NULL) {
         message.toWire(renderer, *tsig_context);
     } else {
         message.toWire(renderer);
     }
-    renderer.setBuffer(NULL);
     LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_ERROR_RESPONSE)
               .arg(renderer.getLength()).arg(message);
 }
@@ -447,13 +470,13 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
     } catch (const DNSProtocolError& error) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PROTOCOL_ERROR)
                   .arg(error.getRcode().toText()).arg(error.what());
-        makeErrorMessage(message, buffer, error.getRcode());
+        makeErrorMessage(impl_->renderer_, message, buffer, error.getRcode());
         impl_->resumeServer(server, message, true);
         return;
     } catch (const Exception& ex) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PARSE_ERROR)
                   .arg(ex.what());
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL());
+        makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
         impl_->resumeServer(server, message, true);
         return;
     } // other exceptions will be handled at a higher layer.
@@ -480,7 +503,8 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
     }
 
     if (tsig_error != TSIGError::NOERROR()) {
-        makeErrorMessage(message, buffer, tsig_error.toRcode(), tsig_context);
+        makeErrorMessage(impl_->renderer_, message, buffer,
+                         tsig_error.toRcode(), tsig_context);
         impl_->resumeServer(server, message, true);
         return;
     }
@@ -497,9 +521,11 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
         } else if (message.getOpcode() != Opcode::QUERY()) {
             LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_UNSUPPORTED_OPCODE)
                       .arg(message.getOpcode().toText());
-            makeErrorMessage(message, buffer, Rcode::NOTIMP(), tsig_context);
+            makeErrorMessage(impl_->renderer_, message, buffer,
+                             Rcode::NOTIMP(), tsig_context);
         } else if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
-            makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
+            makeErrorMessage(impl_->renderer_, message, buffer,
+                             Rcode::FORMERR(), tsig_context);
         } else {
             ConstQuestionPtr question = *message.beginQuestion();
             const RRType &qtype = question->getType();
@@ -517,10 +543,10 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
     } catch (const std::exception& ex) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE)
                   .arg(ex.what());
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL());
+        makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
     } catch (...) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE_UNKNOWN);
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL());
+        makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
     }
     impl_->resumeServer(server, message, send_answer);
 }
@@ -563,13 +589,11 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
         }
     } catch (const Exception& ex) {
         LOG_ERROR(auth_logger, AUTH_PROCESS_FAIL).arg(ex.what());
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL());
+        makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL());
         return (true);
     }
 
-    renderer_.clear();
-    renderer_.setBuffer(&buffer);
-    
+    RendererHolder holder(renderer_, &buffer);
     const bool udp_buffer =
         (io_message.getSocket().getProtocol() == IPPROTO_UDP);
     renderer_.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
@@ -578,7 +602,6 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
     } else {
         message.toWire(renderer_);
     }
-    renderer_.setBuffer(NULL);
     LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_NORMAL_RESPONSE)
               .arg(renderer_.getLength()).arg(message);
     return (true);
@@ -594,7 +617,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message,
 
     if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_UDP);
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
+        makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
+                         tsig_context);
         return (true);
     }
 
@@ -619,7 +643,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message,
 
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_ERROR)
                   .arg(err.what());
-        makeErrorMessage(message, buffer, Rcode::SERVFAIL(), tsig_context);
+        makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL(),
+                         tsig_context);
         return (true);
     }
 
@@ -636,14 +661,16 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
     if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_QUESTIONS)
                   .arg(message.getRRCount(Message::SECTION_QUESTION));
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
+        makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
+                         tsig_context);
         return (true);
     }
     ConstQuestionPtr question = *message.beginQuestion();
     if (question->getType() != RRType::SOA()) {
         LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_RRTYPE)
                   .arg(question->getType().toText());
-        makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
+        makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
+                         tsig_context);
         return (true);
     }
 
@@ -698,14 +725,12 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
     message.setHeaderFlag(Message::HEADERFLAG_AA);
     message.setRcode(Rcode::NOERROR());
 
-    renderer_.clear();
-    renderer_.setBuffer(&buffer);
+    RendererHolder holder(renderer_, &buffer);
     if (tsig_context.get() != NULL) {
         message.toWire(renderer_, *tsig_context);
     } else {
         message.toWire(renderer_);
     }
-    renderer_.setBuffer(NULL);
     return (true);
 }
 

+ 72 - 14
src/bin/auth/tests/auth_srv_unittest.cc

@@ -1138,11 +1138,12 @@ checkThrow(ThrowWhen method, ThrowWhen throw_at, bool isc_exception) {
 class FakeZoneFinder : public isc::datasrc::ZoneFinder {
 public:
     FakeZoneFinder(isc::datasrc::ZoneFinderPtr zone_finder,
-                   ThrowWhen throw_when,
-                   bool isc_exception) :
+                   ThrowWhen throw_when, bool isc_exception,
+                   ConstRRsetPtr fake_rrset) :
         real_zone_finder_(zone_finder),
         throw_when_(throw_when),
-        isc_exception_(isc_exception)
+        isc_exception_(isc_exception),
+        fake_rrset_(fake_rrset)
     {}
 
     virtual isc::dns::Name
@@ -1162,7 +1163,18 @@ public:
          const isc::dns::RRType& type,
          isc::datasrc::ZoneFinder::FindOptions options)
     {
+        using namespace isc::datasrc;
         checkThrow(THROW_AT_FIND, throw_when_, isc_exception_);
+        // If faked RRset was specified on construction and it matches the
+        // query, return it instead of searching the real data source.
+        if (fake_rrset_ && fake_rrset_->getName() == name &&
+            fake_rrset_->getType() == type)
+        {
+            return (ZoneFinderContextPtr(new ZoneFinder::Context(
+                                             *this, options,
+                                             ResultContext(SUCCESS,
+                                                           fake_rrset_))));
+        }
         return (real_zone_finder_->find(name, type, options));
     }
 
@@ -1190,6 +1202,7 @@ private:
     isc::datasrc::ZoneFinderPtr real_zone_finder_;
     ThrowWhen throw_when_;
     bool isc_exception_;
+    ConstRRsetPtr fake_rrset_;
 };
 
 /// \brief Proxy InMemoryClient that can throw exceptions at specified times
@@ -1206,12 +1219,15 @@ public:
     ///        class or the related FakeZoneFinder)
     /// \param isc_exception if true, throw isc::Exception, otherwise,
     ///                      throw std::exception
+    /// \param fake_rrset If non NULL, it will be used as an answer to
+    /// find() for that name and type.
     FakeInMemoryClient(AuthSrv::InMemoryClientPtr real_client,
-                       ThrowWhen throw_when,
-                       bool isc_exception) :
+                       ThrowWhen throw_when, bool isc_exception,
+                       ConstRRsetPtr fake_rrset = ConstRRsetPtr()) :
         real_client_(real_client),
         throw_when_(throw_when),
-        isc_exception_(isc_exception)
+        isc_exception_(isc_exception),
+        fake_rrset_(fake_rrset)
     {}
 
     /// \brief proxy call for findZone
@@ -1226,14 +1242,16 @@ public:
         const FindResult result = real_client_->findZone(name);
         return (FindResult(result.code, isc::datasrc::ZoneFinderPtr(
                                         new FakeZoneFinder(result.zone_finder,
-                                        throw_when_,
-                                        isc_exception_))));
+                                                           throw_when_,
+                                                           isc_exception_,
+                                                           fake_rrset_))));
     }
 
 private:
     AuthSrv::InMemoryClientPtr real_client_;
     ThrowWhen throw_when_;
     bool isc_exception_;
+    ConstRRsetPtr fake_rrset_;
 };
 
 } // end anonymous namespace for throwing proxy classes
@@ -1248,9 +1266,7 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) {
 
     AuthSrv::InMemoryClientPtr fake_client(
         new FakeInMemoryClient(server.getInMemoryClient(rrclass),
-                               THROW_NEVER,
-                               false));
-
+                               THROW_NEVER, false));
     ASSERT_NE(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
     server.setInMemoryClient(rrclass, fake_client);
 
@@ -1267,9 +1283,11 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) {
 // to throw in the given method
 // If isc_exception is true, it will throw isc::Exception, otherwise
 // it will throw std::exception
+// If non null rrset is given, it will be passed to the proxy so it can
+// return some faked response.
 void
 setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when,
-                   bool isc_exception)
+           bool isc_exception, ConstRRsetPtr rrset = ConstRRsetPtr())
 {
     // Set real inmem client to proxy
     updateConfig(server, config, true);
@@ -1279,8 +1297,7 @@ setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when,
     AuthSrv::InMemoryClientPtr fake_client(
         new FakeInMemoryClient(
             server->getInMemoryClient(isc::dns::RRClass::IN()),
-            throw_when,
-            isc_exception));
+            throw_when, isc_exception, rrset));
 
     ASSERT_NE(AuthSrv::InMemoryClientPtr(),
               server->getInMemoryClient(isc::dns::RRClass::IN()));
@@ -1324,4 +1341,45 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxyGetClass) {
                 opcode.getCode(), QR_FLAG | AA_FLAG, 1, 1, 2, 1);
 }
 
+TEST_F(AuthSrvTest, queryWithThrowingInToWire) {
+    // Set up a faked data source.  It will return an empty RRset for the
+    // query.
+    ConstRRsetPtr empty_rrset(new RRset(Name("foo.example"),
+                                        RRClass::IN(), RRType::TXT(),
+                                        RRTTL(0)));
+    setupThrow(&server, CONFIG_INMEMORY_EXAMPLE, THROW_NEVER, true,
+               empty_rrset);
+
+    // Repeat the query processing two times.  Due to the faked RRset,
+    // toWire() should throw, and it should result in SERVFAIL.
+    OutputBufferPtr orig_buffer;
+    for (int i = 0; i < 2; ++i) {
+        UnitTestUtil::createDNSSECRequestMessage(request_message, opcode,
+                                                 default_qid,
+                                                 Name("foo.example."),
+                                                 RRClass::IN(), RRType::TXT());
+        createRequestPacket(request_message, IPPROTO_UDP);
+        server.processMessage(*io_message, *parse_message, *response_obuffer,
+                              &dnsserv);
+        headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(),
+                    opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
+
+        // Make a backup of the original buffer for latest tests and replace
+        // it with a new one
+        if (!orig_buffer) {
+            orig_buffer = response_obuffer;
+            response_obuffer.reset(new OutputBuffer(0));
+        }
+        request_message.clear(Message::RENDER);
+        parse_message->clear(Message::PARSE);
+    }
+
+    // Now check if the original buffer is intact
+    parse_message->clear(Message::PARSE);
+    InputBuffer ibuffer(orig_buffer->getData(), orig_buffer->getLength());
+    parse_message->fromWire(ibuffer);
+    headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(),
+                opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
+}
+
 }