Browse Source

Merge branch 'master' into trac1003

Stephen Morris 13 years ago
parent
commit
a01cd4ac5a
100 changed files with 3531 additions and 1150 deletions
  1. 36 0
      ChangeLog
  2. 4 1
      configure.ac
  3. 3 0
      ext/asio/asio/impl/error_code.ipp
  4. 9 8
      src/bin/auth/auth_config.cc
  5. 16 16
      src/bin/auth/auth_srv.cc
  6. 13 13
      src/bin/auth/auth_srv.h
  7. 13 10
      src/bin/auth/command.cc
  8. 34 34
      src/bin/auth/query.cc
  9. 23 24
      src/bin/auth/query.h
  10. 5 5
      src/bin/auth/tests/auth_srv_unittest.cc
  11. 18 19
      src/bin/auth/tests/command_unittest.cc
  12. 23 23
      src/bin/auth/tests/config_unittest.cc
  13. 57 53
      src/bin/auth/tests/query_unittest.cc
  14. 7 3
      src/bin/bind10/Makefile.am
  15. 0 0
      src/bin/bind10/__init__.py
  16. 54 11
      src/bin/bind10/bind10_messages.mes
  17. 33 1
      src/bin/bind10/bind10.py.in
  18. 2 2
      src/bin/bind10/run_bind10.sh.in
  19. 226 0
      src/bin/bind10/sockcreator.py
  20. 3 4
      src/bin/bind10/tests/Makefile.am
  21. 11 1
      src/bin/bind10/tests/bind10_test.py.in
  22. 327 0
      src/bin/bind10/tests/sockcreator_test.py.in
  23. 15 6
      src/bin/bindctl/bindcmd.py
  24. 15 4
      src/bin/bindctl/bindctl_main.py.in
  25. 2 0
      src/bin/cmdctl/tests/cmdctl_test.py
  26. 2 2
      src/bin/dhcp6/tests/Makefile.am
  27. 1 1
      src/bin/dhcp6/tests/dhcp6_test.py
  28. 2 1
      src/bin/resolver/resolver.cc
  29. 2 1
      src/bin/resolver/tests/resolver_config_unittest.cc
  30. 1 1
      src/bin/sockcreator/README
  31. 16 2
      src/bin/stats/stats_httpd.py.in
  32. 89 0
      src/bin/stats/tests/b10-stats-httpd_test.py
  33. 2 0
      src/bin/xfrin/tests/xfrin_test.py
  34. 1 1
      src/bin/xfrout/tests/Makefile.am
  35. 149 22
      src/bin/xfrout/tests/xfrout_test.py.in
  36. 70 29
      src/bin/xfrout/xfrout.py.in
  37. 19 7
      src/bin/xfrout/xfrout.spec.pre.in
  38. 11 0
      src/bin/xfrout/xfrout_messages.mes
  39. 1 1
      src/lib/acl/Makefile.am
  40. 22 3
      src/lib/acl/dns.cc
  41. 19 5
      src/lib/acl/dns.h
  42. 83 0
      src/lib/acl/dnsname_check.h
  43. 2 0
      src/lib/acl/tests/Makefile.am
  44. 76 10
      src/lib/acl/tests/dns_test.cc
  45. 59 0
      src/lib/acl/tests/dnsname_check_unittest.cc
  46. 2 2
      src/lib/asiodns/tests/run_unittests.cc
  47. 2 0
      src/lib/cc/data.cc
  48. 21 5
      src/lib/config/module_spec.cc
  49. 9 0
      src/lib/config/tests/module_spec_unittests.cc
  50. 4 0
      src/lib/config/tests/testdata/Makefile.am
  51. 3 0
      src/lib/config/tests/testdata/data32_1.data
  52. 3 0
      src/lib/config/tests/testdata/data32_2.data
  53. 3 0
      src/lib/config/tests/testdata/data32_3.data
  54. 19 0
      src/lib/config/tests/testdata/spec32.spec
  55. 1 0
      src/lib/datasrc/Makefile.am
  56. 150 0
      src/lib/datasrc/client.h
  57. 35 35
      src/lib/datasrc/memory_datasrc.cc
  58. 59 97
      src/lib/datasrc/memory_datasrc.h
  59. 3 3
      src/lib/datasrc/rbtree.h
  60. 324 298
      src/lib/datasrc/tests/memory_datasrc_unittest.cc
  61. 19 17
      src/lib/datasrc/tests/zonetable_unittest.cc
  62. 10 10
      src/lib/datasrc/zone.h
  63. 6 6
      src/lib/datasrc/zonetable.cc
  64. 3 3
      src/lib/datasrc/zonetable.h
  65. 46 3
      src/lib/dns/message.cc
  66. 11 0
      src/lib/dns/message.h
  67. 9 0
      src/lib/dns/python/message_python.cc
  68. 105 16
      src/lib/dns/python/tests/message_python_test.py
  69. 8 2
      src/lib/dns/python/tests/question_python_test.py
  70. 9 0
      src/lib/dns/question.cc
  71. 8 8
      src/lib/dns/question.h
  72. 5 0
      src/lib/dns/rrtype-placeholder.h
  73. 216 10
      src/lib/dns/tests/message_unittest.cc
  74. 16 0
      src/lib/dns/tests/question_unittest.cc
  75. 5 1
      src/lib/dns/tests/testdata/Makefile.am
  76. 5 7
      src/lib/dns/tests/testdata/gen-wiredata.py.in
  77. 22 0
      src/lib/dns/tests/testdata/message_fromWire17.spec
  78. 23 0
      src/lib/dns/tests/testdata/message_fromWire18.spec
  79. 27 0
      src/lib/dns/tests/testdata/message_toWire4.spec
  80. 36 0
      src/lib/dns/tests/testdata/message_toWire5.spec
  81. 72 0
      src/lib/dns/tests/tsig_unittest.cc
  82. 87 16
      src/lib/dns/tsig.cc
  83. 21 0
      src/lib/dns/tsig.h
  84. 1 0
      src/lib/log/Makefile.am
  85. 1 149
      src/lib/log/logger_support.cc
  86. 9 53
      src/lib/log/logger_support.h
  87. 175 0
      src/lib/log/logger_unittest_support.cc
  88. 126 0
      src/lib/log/logger_unittest_support.h
  89. 3 3
      src/lib/log/tests/init_logger_test.sh.in
  90. 5 2
      src/lib/log/tests/logger_level_impl_unittest.cc
  91. 5 3
      src/lib/log/tests/logger_level_unittest.cc
  92. 13 2
      src/lib/log/tests/logger_support_unittest.cc
  93. 12 12
      src/lib/python/isc/acl/Makefile.am
  94. 29 0
      src/lib/python/isc/acl/_dns.py
  95. 2 2
      src/lib/python/isc/acl/dns.cc
  96. 58 18
      src/lib/python/isc/acl/dns.py
  97. 2 2
      src/lib/python/isc/acl/dns_requestacl_python.cc
  98. 11 8
      src/lib/python/isc/acl/dns_requestcontext_inc.cc
  99. 96 33
      src/lib/python/isc/acl/dns_requestcontext_python.cc
  100. 0 0
      src/lib/python/isc/acl/dns_requestloader_python.cc

+ 36 - 0
ChangeLog

@@ -1,3 +1,37 @@
+275.	[func]		jinmei
+	Added support for TSIG key matching in ACLs.  The xfrout ACL can
+	now refer to TSIG key names using the "key" attribute.  For
+	example, the following specifies an ACL that allows zone transfer
+	if and only if the request is signed with a TSIG of a key name
+	"key.example":
+	> config set Xfrout/query_acl[0] {"action": "ACCEPT", \
+	                                  "key": "key.example"}
+	(Trac #1104, git 9b2e89cabb6191db86f88ee717f7abc4171fa979)
+
+274.	[bug]		naokikambe
+	add unittests for functions xml_handler, xsd_handler and xsl_handler
+	respectively to make sure their behaviors are correct, regardless of
+	whether type which xml.etree.ElementTree.tostring() after Python3.2
+	returns is str or byte.
+	(Trac #1021, git 486bf91e0ecc5fbecfe637e1e75ebe373d42509b)
+
+273.    [func]		vorner
+	It is possible to specify ACL for the xfrout module. It is in the ACL
+	configuration key and has the usual ACL syntax. It currently supports
+	only the source address. Default ACL accepts everything.
+	(Trac #772, git 50070c824270d5da1db0b716db73b726d458e9f7)
+
+272.	[func]		jinmei
+	libdns++/pydnspp: TSIG signing now handles truncated DNS messages
+	(i.e. with TC bit on) with TSIG correctly.
+	(Trac #910, 8e00f359e81c3cb03c5075710ead0f87f87e3220)
+
+271.	[func]		stephen
+	Default logging for unit tests changed to severity DEBUG (level 99)
+	with the output routed to /dev/null.  This can be altered by setting
+	the B10_LOGGER_XXX environment variables.
+	(Trac #1024, git 72a0beb8dfe85b303f546d09986461886fe7a3d8)
+
 270.	[func]		jinmei
 	Added python bindings for ACLs using the DNS request as the
 	context.  They are accessible via the isc.acl.dns module.
@@ -13,6 +47,8 @@
 	unit tests.
 	(Trac #1071, git 05164f9d61006869233b498d248486b4307ea8b6)
 
+bind10-devel-20110705 released on July 05, 2011
+
 267.	[func]		tomek
 	Added a dummy module for DHCP6. This module does not actually
 	do anything at this point, and BIND 10 has no option for

+ 4 - 1
configure.ac

@@ -270,6 +270,8 @@ B10_CXXFLAGS="-Wall -Wextra -Wwrite-strings -Woverloaded-virtual -Wno-sign-compa
 case "$host" in
 *-solaris*)
 	MULTITHREADING_FLAG=-pthreads
+	# In Solaris, IN6ADDR_ANY_INIT and IN6ADDR_LOOPBACK_INIT need -Wno-missing-braces
+	B10_CXXFLAGS="$B10_CXXFLAGS -Wno-missing-braces"
 	;;
 *)
 	MULTITHREADING_FLAG=-pthread
@@ -902,9 +904,10 @@ AC_OUTPUT([doc/version.ent
            src/bin/zonemgr/run_b10-zonemgr.sh
            src/bin/stats/stats.py
            src/bin/stats/stats_httpd.py
-           src/bin/bind10/bind10.py
+           src/bin/bind10/bind10_src.py
            src/bin/bind10/run_bind10.sh
            src/bin/bind10/tests/bind10_test.py
+           src/bin/bind10/tests/sockcreator_test.py
            src/bin/bindctl/run_bindctl.sh
            src/bin/bindctl/bindctl_main.py
            src/bin/bindctl/tests/bindctl_test

+ 3 - 0
ext/asio/asio/impl/error_code.ipp

@@ -11,6 +11,9 @@
 #ifndef ASIO_IMPL_ERROR_CODE_IPP
 #define ASIO_IMPL_ERROR_CODE_IPP
 
+// strerror() needs <cstring>
+#include <cstring>
+
 #if defined(_MSC_VER) && (_MSC_VER >= 1200)
 # pragma once
 #endif // defined(_MSC_VER) && (_MSC_VER >= 1200)

+ 9 - 8
src/bin/auth/auth_config.cc

@@ -107,7 +107,7 @@ DatasourcesConfig::commit() {
     // server implementation details, and isn't scalable wrt the number of
     // data source types, and should eventually be improved.
     // Currently memory data source for class IN is the only possibility.
-    server_.setMemoryDataSrc(RRClass::IN(), AuthSrv::MemoryDataSrcPtr());
+    server_.setInMemoryClient(RRClass::IN(), AuthSrv::InMemoryClientPtr());
 
     BOOST_FOREACH(shared_ptr<AuthConfigParser> datasrc_config, datasources_) {
         datasrc_config->commit();
@@ -125,12 +125,12 @@ public:
     {}
     virtual void build(ConstElementPtr config_value);
     virtual void commit() {
-        server_.setMemoryDataSrc(rrclass_, memory_datasrc_);
+        server_.setInMemoryClient(rrclass_, memory_client_);
     }
 private:
     AuthSrv& server_;
     RRClass rrclass_;
-    AuthSrv::MemoryDataSrcPtr memory_datasrc_;
+    AuthSrv::InMemoryClientPtr memory_client_;
 };
 
 void
@@ -143,8 +143,8 @@ MemoryDatasourceConfig::build(ConstElementPtr config_value) {
     // We'd eventually optimize building zones (in case of reloading) by
     // selectively loading fresh zones.  Right now we simply check the
     // RR class is supported by the server implementation.
-    server_.getMemoryDataSrc(rrclass_);
-    memory_datasrc_ = AuthSrv::MemoryDataSrcPtr(new MemoryDataSrc());
+    server_.getInMemoryClient(rrclass_);
+    memory_client_ = AuthSrv::InMemoryClientPtr(new InMemoryClient());
 
     ConstElementPtr zones_config = config_value->get("zones");
     if (!zones_config) {
@@ -163,9 +163,10 @@ MemoryDatasourceConfig::build(ConstElementPtr config_value) {
             isc_throw(AuthConfigError, "Missing zone file for zone: "
                       << origin->str());
         }
-        shared_ptr<MemoryZone> new_zone(new MemoryZone(rrclass_,
+        shared_ptr<InMemoryZoneFinder> zone_finder(new
+                                                   InMemoryZoneFinder(rrclass_,
             Name(origin->stringValue())));
-        const result::Result result = memory_datasrc_->addZone(new_zone);
+        const result::Result result = memory_client_->addZone(zone_finder);
         if (result == result::EXIST) {
             isc_throw(AuthConfigError, "zone "<< origin->str()
                       << " already exists");
@@ -177,7 +178,7 @@ MemoryDatasourceConfig::build(ConstElementPtr config_value) {
          * need the load method to be split into some kind of build and
          * commit/abort parts.
          */
-        new_zone->load(file->stringValue());
+        zone_finder->load(file->stringValue());
     }
 }
 

+ 16 - 16
src/bin/auth/auth_srv.cc

@@ -108,8 +108,8 @@ public:
     AbstractSession* xfrin_session_;
 
     /// In-memory data source.  Currently class IN only for simplicity.
-    const RRClass memory_datasrc_class_;
-    AuthSrv::MemoryDataSrcPtr memory_datasrc_;
+    const RRClass memory_client_class_;
+    AuthSrv::InMemoryClientPtr memory_client_;
 
     /// Hot spot cache
     isc::datasrc::HotCache cache_;
@@ -145,7 +145,7 @@ AuthSrvImpl::AuthSrvImpl(const bool use_cache,
                          AbstractXfroutClient& xfrout_client) :
     config_session_(NULL),
     xfrin_session_(NULL),
-    memory_datasrc_class_(RRClass::IN()),
+    memory_client_class_(RRClass::IN()),
     statistics_timer_(io_service_),
     counters_(),
     keyring_(NULL),
@@ -290,7 +290,7 @@ makeErrorMessage(MessagePtr message, OutputBufferPtr buffer,
         message->toWire(renderer);
     }
     LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_ERROR_RESPONSE)
-              .arg(message->toText());
+              .arg(renderer.getLength()).arg(*message);
 }
 }
 
@@ -329,34 +329,34 @@ AuthSrv::getConfigSession() const {
     return (impl_->config_session_);
 }
 
-AuthSrv::MemoryDataSrcPtr
-AuthSrv::getMemoryDataSrc(const RRClass& rrclass) {
+AuthSrv::InMemoryClientPtr
+AuthSrv::getInMemoryClient(const RRClass& rrclass) {
     // XXX: for simplicity, we only support the IN class right now.
-    if (rrclass != impl_->memory_datasrc_class_) {
+    if (rrclass != impl_->memory_client_class_) {
         isc_throw(InvalidParameter,
                   "Memory data source is not supported for RR class "
                   << rrclass);
     }
-    return (impl_->memory_datasrc_);
+    return (impl_->memory_client_);
 }
 
 void
-AuthSrv::setMemoryDataSrc(const isc::dns::RRClass& rrclass,
-                          MemoryDataSrcPtr memory_datasrc)
+AuthSrv::setInMemoryClient(const isc::dns::RRClass& rrclass,
+                           InMemoryClientPtr memory_client)
 {
     // XXX: see above
-    if (rrclass != impl_->memory_datasrc_class_) {
+    if (rrclass != impl_->memory_client_class_) {
         isc_throw(InvalidParameter,
                   "Memory data source is not supported for RR class "
                   << rrclass);
-    } else if (!impl_->memory_datasrc_ && memory_datasrc) {
+    } else if (!impl_->memory_client_ && memory_client) {
         LOG_DEBUG(auth_logger, DBG_AUTH_OPS, AUTH_MEM_DATASRC_ENABLED)
                   .arg(rrclass);
-    } else if (impl_->memory_datasrc_ && !memory_datasrc) {
+    } else if (impl_->memory_client_ && !memory_client) {
         LOG_DEBUG(auth_logger, DBG_AUTH_OPS, AUTH_MEM_DATASRC_DISABLED)
                   .arg(rrclass);
     }
-    impl_->memory_datasrc_ = memory_datasrc;
+    impl_->memory_client_ = memory_client;
 }
 
 uint32_t
@@ -505,10 +505,10 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
         // If a memory data source is configured call the separate
         // Query::process()
         const ConstQuestionPtr question = *message->beginQuestion();
-        if (memory_datasrc_ && memory_datasrc_class_ == question->getClass()) {
+        if (memory_client_ && memory_client_class_ == question->getClass()) {
             const RRType& qtype = question->getType();
             const Name& qname = question->getName();
-            auth::Query(*memory_datasrc_, qname, qtype, *message).process();
+            auth::Query(*memory_client_, qname, qtype, *message).process();
         } else {
             datasrc::Query query(*message, cache_, dnssec_ok);
             data_sources_.doQuery(query);

+ 13 - 13
src/bin/auth/auth_srv.h

@@ -17,7 +17,7 @@
 
 #include <string>
 
-// For MemoryDataSrcPtr below.  This should be a temporary definition until
+// For InMemoryClientPtr below.  This should be a temporary definition until
 // we reorganize the data source framework.
 #include <boost/shared_ptr.hpp>
 
@@ -39,7 +39,7 @@
 
 namespace isc {
 namespace datasrc {
-class MemoryDataSrc;
+class InMemoryClient;
 }
 namespace xfr {
 class AbstractXfroutClient;
@@ -133,7 +133,7 @@ public:
     /// If there is a data source installed, it will be replaced with the
     /// new one.
     ///
-    /// In the current implementation, the SQLite data source and MemoryDataSrc
+    /// In the current implementation, the SQLite data source and InMemoryClient
     /// are assumed.
     /// We can enable memory data source and get the path of SQLite database by
     /// the \c config parameter.  If we disabled memory data source, the SQLite
@@ -233,16 +233,16 @@ public:
     ///
     void setXfrinSession(isc::cc::AbstractSession* xfrin_session);
 
-    /// A shared pointer type for \c MemoryDataSrc.
+    /// A shared pointer type for \c InMemoryClient.
     ///
     /// This is defined inside the \c AuthSrv class as it's supposed to be
     /// a short term interface until we integrate the in-memory and other
     /// data source frameworks.
-    typedef boost::shared_ptr<isc::datasrc::MemoryDataSrc> MemoryDataSrcPtr;
+    typedef boost::shared_ptr<isc::datasrc::InMemoryClient> InMemoryClientPtr;
 
-    /// An immutable shared pointer type for \c MemoryDataSrc.
-    typedef boost::shared_ptr<const isc::datasrc::MemoryDataSrc>
-    ConstMemoryDataSrcPtr;
+    /// An immutable shared pointer type for \c InMemoryClient.
+    typedef boost::shared_ptr<const isc::datasrc::InMemoryClient>
+    ConstInMemoryClientPtr;
 
     /// Returns the in-memory data source configured for the \c AuthSrv,
     /// if any.
@@ -260,11 +260,11 @@ public:
     /// \param rrclass The RR class of the requested in-memory data source.
     /// \return A pointer to the in-memory data source, if configured;
     /// otherwise NULL.
-    MemoryDataSrcPtr getMemoryDataSrc(const isc::dns::RRClass& rrclass);
+    InMemoryClientPtr getInMemoryClient(const isc::dns::RRClass& rrclass);
 
     /// Sets or replaces the in-memory data source of the specified RR class.
     ///
-    /// As noted in \c getMemoryDataSrc(), some RR classes may not be
+    /// As noted in \c getInMemoryClient(), some RR classes may not be
     /// supported, in which case an exception of class \c InvalidParameter
     /// will be thrown.
     /// This method never throws an exception otherwise.
@@ -275,9 +275,9 @@ public:
     /// in-memory data source.
     ///
     /// \param rrclass The RR class of the in-memory data source to be set.
-    /// \param memory_datasrc A (shared) pointer to \c MemoryDataSrc to be set.
-    void setMemoryDataSrc(const isc::dns::RRClass& rrclass,
-                          MemoryDataSrcPtr memory_datasrc);
+    /// \param memory_datasrc A (shared) pointer to \c InMemoryClient to be set.
+    void setInMemoryClient(const isc::dns::RRClass& rrclass,
+                           InMemoryClientPtr memory_client);
 
     /// \brief Set the communication session with Statistics.
     ///

+ 13 - 10
src/bin/auth/command.cc

@@ -136,19 +136,21 @@ public:
         // that doesn't block other server operations.
         // TODO: we may (should?) want to check the "last load time" and
         // the timestamp of the file and skip loading if the file isn't newer.
-        shared_ptr<MemoryZone> newzone(new MemoryZone(oldzone->getClass(),
-                                                      oldzone->getOrigin()));
-        newzone->load(oldzone->getFileName());
-        oldzone->swap(*newzone);
+        shared_ptr<InMemoryZoneFinder> zone_finder(
+            new InMemoryZoneFinder(old_zone_finder->getClass(),
+                                   old_zone_finder->getOrigin()));
+        zone_finder->load(old_zone_finder->getFileName());
+        old_zone_finder->swap(*zone_finder);
         LOG_DEBUG(auth_logger, DBG_AUTH_OPS, AUTH_LOAD_ZONE)
-                  .arg(newzone->getOrigin()).arg(newzone->getClass());
+                  .arg(zone_finder->getOrigin()).arg(zone_finder->getClass());
     }
 
 private:
-    shared_ptr<MemoryZone> oldzone; // zone to be updated with the new file.
+    // zone finder to be updated with the new file.
+    shared_ptr<InMemoryZoneFinder> old_zone_finder;
 
     // A helper private method to parse and validate command parameters.
-    // On success, it sets 'oldzone' to the zone to be updated.
+    // On success, it sets 'old_zone_finder' to the zone to be updated.
     // It returns true if everything is okay; and false if the command is
     // valid but there's no need for further process.
     bool validate(AuthSrv& server, isc::data::ConstElementPtr args) {
@@ -176,7 +178,7 @@ private:
         const RRClass zone_class = class_elem ?
             RRClass(class_elem->stringValue()) : RRClass::IN();
 
-        AuthSrv::MemoryDataSrcPtr datasrc(server.getMemoryDataSrc(zone_class));
+        AuthSrv::InMemoryClientPtr datasrc(server.getInMemoryClient(zone_class));
         if (datasrc == NULL) {
             isc_throw(AuthCommandError, "Memory data source is disabled");
         }
@@ -188,13 +190,14 @@ private:
         const Name origin(origin_elem->stringValue());
 
         // Get the current zone
-        const MemoryDataSrc::FindResult result = datasrc->findZone(origin);
+        const InMemoryClient::FindResult result = datasrc->findZone(origin);
         if (result.code != result::SUCCESS) {
             isc_throw(AuthCommandError, "Zone " << origin <<
                       " is not found in data source");
         }
 
-        oldzone = boost::dynamic_pointer_cast<MemoryZone>(result.zone);
+        old_zone_finder = boost::dynamic_pointer_cast<InMemoryZoneFinder>(
+            result.zone_finder);
 
         return (true);
     }

+ 34 - 34
src/bin/auth/query.cc

@@ -19,7 +19,7 @@
 #include <dns/rcode.h>
 #include <dns/rdataclass.h>
 
-#include <datasrc/memory_datasrc.h>
+#include <datasrc/client.h>
 
 #include <auth/query.h>
 
@@ -31,14 +31,14 @@ namespace isc {
 namespace auth {
 
 void
-Query::getAdditional(const Zone& zone, const RRset& rrset) const {
+Query::getAdditional(const ZoneFinder& zone, const RRset& rrset) const {
     RdataIteratorPtr rdata_iterator(rrset.getRdataIterator());
     for (; !rdata_iterator->isLast(); rdata_iterator->next()) {
         const Rdata& rdata(rdata_iterator->getCurrent());
         if (rrset.getType() == RRType::NS()) {
             // Need to perform the search in the "GLUE OK" mode.
             const generic::NS& ns = dynamic_cast<const generic::NS&>(rdata);
-            findAddrs(zone, ns.getNSName(), Zone::FIND_GLUE_OK);
+            findAddrs(zone, ns.getNSName(), ZoneFinder::FIND_GLUE_OK);
         } else if (rrset.getType() == RRType::MX()) {
             const generic::MX& mx(dynamic_cast<const generic::MX&>(rdata));
             findAddrs(zone, mx.getMXName());
@@ -47,8 +47,8 @@ Query::getAdditional(const Zone& zone, const RRset& rrset) const {
 }
 
 void
-Query::findAddrs(const Zone& zone, const Name& qname,
-                 const Zone::FindOptions options) const
+Query::findAddrs(const ZoneFinder& zone, const Name& qname,
+                 const ZoneFinder::FindOptions options) const
 {
     // Out of zone name
     NameComparisonResult result = zone.getOrigin().compare(qname);
@@ -66,9 +66,9 @@ Query::findAddrs(const Zone& zone, const Name& qname,
 
     // Find A rrset
     if (qname_ != qname || qtype_ != RRType::A()) {
-        Zone::FindResult a_result = zone.find(qname, RRType::A(), NULL,
+        ZoneFinder::FindResult a_result = zone.find(qname, RRType::A(), NULL,
                                               options);
-        if (a_result.code == Zone::SUCCESS) {
+        if (a_result.code == ZoneFinder::SUCCESS) {
             response_.addRRset(Message::SECTION_ADDITIONAL,
                     boost::const_pointer_cast<RRset>(a_result.rrset));
         }
@@ -76,9 +76,9 @@ Query::findAddrs(const Zone& zone, const Name& qname,
 
     // Find AAAA rrset
     if (qname_ != qname || qtype_ != RRType::AAAA()) {
-        Zone::FindResult aaaa_result =
+        ZoneFinder::FindResult aaaa_result =
             zone.find(qname, RRType::AAAA(), NULL, options);
-        if (aaaa_result.code == Zone::SUCCESS) {
+        if (aaaa_result.code == ZoneFinder::SUCCESS) {
             response_.addRRset(Message::SECTION_ADDITIONAL,
                     boost::const_pointer_cast<RRset>(aaaa_result.rrset));
         }
@@ -86,10 +86,10 @@ Query::findAddrs(const Zone& zone, const Name& qname,
 }
 
 void
-Query::putSOA(const Zone& zone) const {
-    Zone::FindResult soa_result(zone.find(zone.getOrigin(),
+Query::putSOA(const ZoneFinder& zone) const {
+    ZoneFinder::FindResult soa_result(zone.find(zone.getOrigin(),
         RRType::SOA()));
-    if (soa_result.code != Zone::SUCCESS) {
+    if (soa_result.code != ZoneFinder::SUCCESS) {
         isc_throw(NoSOA, "There's no SOA record in zone " <<
             zone.getOrigin().toText());
     } else {
@@ -104,11 +104,12 @@ Query::putSOA(const Zone& zone) const {
 }
 
 void
-Query::getAuthAdditional(const Zone& zone) const {
+Query::getAuthAdditional(const ZoneFinder& zone) const {
     // Fill in authority and addtional sections.
-    Zone::FindResult ns_result = zone.find(zone.getOrigin(), RRType::NS());
+    ZoneFinder::FindResult ns_result = zone.find(zone.getOrigin(),
+                                                 RRType::NS());
     // zone origin name should have NS records
-    if (ns_result.code != Zone::SUCCESS) {
+    if (ns_result.code != ZoneFinder::SUCCESS) {
         isc_throw(NoApexNS, "There's no apex NS records in zone " <<
                 zone.getOrigin().toText());
     } else {
@@ -125,8 +126,8 @@ Query::process() const {
     const bool qtype_is_any = (qtype_ == RRType::ANY());
 
     response_.setHeaderFlag(Message::HEADERFLAG_AA, false);
-    const MemoryDataSrc::FindResult result =
-        memory_datasrc_.findZone(qname_);
+    const DataSourceClient::FindResult result =
+        datasrc_client_.findZone(qname_);
 
     // If we have no matching authoritative zone for the query name, return
     // REFUSED.  In short, this is to be compatible with BIND 9, but the
@@ -145,11 +146,10 @@ Query::process() const {
     while (keep_doing) {
         keep_doing = false;
         std::auto_ptr<RRsetList> target(qtype_is_any ? new RRsetList : NULL);
-        const Zone::FindResult db_result(result.zone->find(qname_, qtype_,
-            target.get()));
-
+        const ZoneFinder::FindResult db_result(
+            result.zone_finder->find(qname_, qtype_, target.get()));
         switch (db_result.code) {
-            case Zone::DNAME: {
+            case ZoneFinder::DNAME: {
                 // First, put the dname into the answer
                 response_.addRRset(Message::SECTION_ANSWER,
                     boost::const_pointer_cast<RRset>(db_result.rrset));
@@ -191,7 +191,7 @@ Query::process() const {
                 response_.addRRset(Message::SECTION_ANSWER, cname);
                 break;
             }
-            case Zone::CNAME:
+            case ZoneFinder::CNAME:
                 /*
                  * We don't do chaining yet. Therefore handling a CNAME is
                  * mostly the same as handling SUCCESS, but we didn't get
@@ -204,46 +204,46 @@ Query::process() const {
                 response_.addRRset(Message::SECTION_ANSWER,
                     boost::const_pointer_cast<RRset>(db_result.rrset));
                 break;
-            case Zone::SUCCESS:
+            case ZoneFinder::SUCCESS:
                 if (qtype_is_any) {
                     // If quety type is ANY, insert all RRs under the domain
                     // into answer section.
                     BOOST_FOREACH(RRsetPtr rrset, *target) {
                         response_.addRRset(Message::SECTION_ANSWER, rrset);
                         // Handle additional for answer section
-                        getAdditional(*result.zone, *rrset.get());
+                        getAdditional(*result.zone_finder, *rrset.get());
                     }
                 } else {
                     response_.addRRset(Message::SECTION_ANSWER,
                         boost::const_pointer_cast<RRset>(db_result.rrset));
                     // Handle additional for answer section
-                    getAdditional(*result.zone, *db_result.rrset);
+                    getAdditional(*result.zone_finder, *db_result.rrset);
                 }
                 // If apex NS records haven't been provided in the answer
                 // section, insert apex NS records into the authority section
                 // and AAAA/A RRS of each of the NS RDATA into the additional
                 // section.
-                if (qname_ != result.zone->getOrigin() ||
-                    db_result.code != Zone::SUCCESS ||
+                if (qname_ != result.zone_finder->getOrigin() ||
+                    db_result.code != ZoneFinder::SUCCESS ||
                     (qtype_ != RRType::NS() && !qtype_is_any))
                 {
-                    getAuthAdditional(*result.zone);
+                    getAuthAdditional(*result.zone_finder);
                 }
                 break;
-            case Zone::DELEGATION:
+            case ZoneFinder::DELEGATION:
                 response_.setHeaderFlag(Message::HEADERFLAG_AA, false);
                 response_.addRRset(Message::SECTION_AUTHORITY,
                     boost::const_pointer_cast<RRset>(db_result.rrset));
-                getAdditional(*result.zone, *db_result.rrset);
+                getAdditional(*result.zone_finder, *db_result.rrset);
                 break;
-            case Zone::NXDOMAIN:
+            case ZoneFinder::NXDOMAIN:
                 // Just empty answer with SOA in authority section
                 response_.setRcode(Rcode::NXDOMAIN());
-                putSOA(*result.zone);
+                putSOA(*result.zone_finder);
                 break;
-            case Zone::NXRRSET:
+            case ZoneFinder::NXRRSET:
                 // Just empty answer with SOA in authority section
-                putSOA(*result.zone);
+                putSOA(*result.zone_finder);
                 break;
         }
     }

+ 23 - 24
src/bin/auth/query.h

@@ -26,7 +26,7 @@ class RRset;
 }
 
 namespace datasrc {
-class MemoryDataSrc;
+class DataSourceClient;
 }
 
 namespace auth {
@@ -36,10 +36,8 @@ namespace auth {
 ///
 /// Many of the design details for this class are still in flux.
 /// We'll revisit and update them as we add more functionality, for example:
-/// - memory_datasrc parameter of the constructor.  It is a data source that
-///   uses in memory dedicated backend.
 /// - as a related point, we may have to pass the RR class of the query.
-///   in the initial implementation the RR class is an attribute of memory
+///   in the initial implementation the RR class is an attribute of
 ///   datasource and omitted.  It's not clear if this assumption holds with
 ///   generic data sources.  On the other hand, it will help keep
 ///   implementation simpler, and we might rather want to modify the design
@@ -51,7 +49,7 @@ namespace auth {
 ///   separate attribute setter.
 /// - likewise, we'll eventually need to do per zone access control, for which
 ///   we need querier's information such as its IP address.
-/// - memory_datasrc and response may better be parameters to process() instead
+/// - datasrc_client and response may better be parameters to process() instead
 ///   of the constructor.
 ///
 /// <b>Note:</b> The class name is intentionally the same as the one used in
@@ -71,7 +69,7 @@ private:
     /// Adds a SOA of the zone into the authority zone of response_.
     /// Can throw NoSOA.
     ///
-    void putSOA(const isc::datasrc::Zone& zone) const;
+    void putSOA(const isc::datasrc::ZoneFinder& zone) const;
 
     /// \brief Look up additional data (i.e., address records for the names
     /// included in NS or MX records).
@@ -83,11 +81,11 @@ private:
     /// This method may throw a exception because its underlying methods may
     /// throw exceptions.
     ///
-    /// \param zone The Zone wherein the additional data to the query is bo be
-    /// found.
+    /// \param zone The ZoneFinder through which the additional data for the
+    /// query is to be found.
     /// \param rrset The RRset (i.e., NS or MX rrset) which require additional
     /// processing.
-    void getAdditional(const isc::datasrc::Zone& zone,
+    void getAdditional(const isc::datasrc::ZoneFinder& zone,
                        const isc::dns::RRset& rrset) const;
 
     /// \brief Find address records for a specified name.
@@ -102,18 +100,19 @@ private:
     /// The glue records must exactly match the name in the NS RDATA, without
     /// CNAME or wildcard processing.
     ///
-    /// \param zone The \c Zone wherein the address records is to be found.
+    /// \param zone The \c ZoneFinder through which the address records is to
+    /// be found.
     /// \param qname The name in rrset RDATA.
     /// \param options The search options.
-    void findAddrs(const isc::datasrc::Zone& zone,
+    void findAddrs(const isc::datasrc::ZoneFinder& zone,
                    const isc::dns::Name& qname,
-                   const isc::datasrc::Zone::FindOptions options
-                   = isc::datasrc::Zone::FIND_DEFAULT) const;
+                   const isc::datasrc::ZoneFinder::FindOptions options
+                   = isc::datasrc::ZoneFinder::FIND_DEFAULT) const;
 
-    /// \brief Look up \c Zone's NS and address records for the NS RDATA
-    /// (domain name) for authoritative answer.
+    /// \brief Look up a zone's NS RRset and their address records for an
+    /// authoritative answer.
     ///
-    /// On returning an authoritative answer, insert the \c Zone's NS into the
+    /// On returning an authoritative answer, insert a zone's NS into the
     /// authority section and AAAA/A RRs of each of the NS RDATA into the
     /// additional section.
     ///
@@ -126,24 +125,24 @@ private:
     /// include AAAA/A RRs under a zone cut in additional section. (BIND 9
     /// excludes under-cut RRs; NSD include them.)
     ///
-    /// \param zone The \c Zone wherein the additional data to the query is to
-    /// be found.
-    void getAuthAdditional(const isc::datasrc::Zone& zone) const;
+    /// \param zone The \c ZoneFinder through which the NS and additional data
+    /// for the query are to be found.
+    void getAuthAdditional(const isc::datasrc::ZoneFinder& zone) const;
 
 public:
     /// Constructor from query parameters.
     ///
     /// This constructor never throws an exception.
     ///
-    /// \param memory_datasrc The memory datasource wherein the answer to the query is
+    /// \param datasrc_client The datasource wherein the answer to the query is
     /// to be found.
     /// \param qname The query name
     /// \param qtype The RR type of the query
     /// \param response The response message to store the answer to the query.
-    Query(const isc::datasrc::MemoryDataSrc& memory_datasrc,
+    Query(const isc::datasrc::DataSourceClient& datasrc_client,
           const isc::dns::Name& qname, const isc::dns::RRType& qtype,
           isc::dns::Message& response) :
-        memory_datasrc_(memory_datasrc), qname_(qname), qtype_(qtype),
+        datasrc_client_(datasrc_client), qname_(qname), qtype_(qtype),
         response_(response)
     {}
 
@@ -157,7 +156,7 @@ public:
     /// successful search would result in adding a corresponding RRset to
     /// the answer section of the response.
     ///
-    /// If no matching zone is found in the memory datasource, the RCODE of
+    /// If no matching zone is found in the datasource, the RCODE of
     /// SERVFAIL will be set in the response.
     /// <b>Note:</b> this is different from the error code that BIND 9 returns
     /// by default when it's configured as an authoritative-only server (and
@@ -208,7 +207,7 @@ public:
     };
 
 private:
-    const isc::datasrc::MemoryDataSrc& memory_datasrc_;
+    const isc::datasrc::DataSourceClient& datasrc_client_;
     const isc::dns::Name& qname_;
     const isc::dns::RRType& qtype_;
     isc::dns::Message& response_;

+ 5 - 5
src/bin/auth/tests/auth_srv_unittest.cc

@@ -651,17 +651,17 @@ TEST_F(AuthSrvTest, updateConfigFail) {
                 QR_FLAG | AA_FLAG, 1, 1, 1, 0);
 }
 
-TEST_F(AuthSrvTest, updateWithMemoryDataSrc) {
+TEST_F(AuthSrvTest, updateWithInMemoryClient) {
     // Test configuring memory data source.  Detailed test cases are covered
     // in the configuration tests.  We only check the AuthSrv interface here.
 
     // By default memory data source isn't enabled
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
     updateConfig(&server,
                  "{\"datasources\": [{\"type\": \"memory\"}]}", true);
     // after successful configuration, we should have one (with empty zoneset).
-    ASSERT_NE(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
-    EXPECT_EQ(0, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    ASSERT_NE(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
+    EXPECT_EQ(0, server.getInMemoryClient(rrclass)->getZoneCount());
 
     // The memory data source is empty, should return REFUSED rcode.
     createDataFromFile("examplequery_fromWire.wire");
@@ -672,7 +672,7 @@ TEST_F(AuthSrvTest, updateWithMemoryDataSrc) {
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
 }
 
-TEST_F(AuthSrvTest, chQueryWithMemoryDataSrc) {
+TEST_F(AuthSrvTest, chQueryWithInMemoryClient) {
     // Configure memory data source for class IN
     updateConfig(&server, "{\"datasources\": "
                  "[{\"class\": \"IN\", \"type\": \"memory\"}]}", true);

+ 18 - 19
src/bin/auth/tests/command_unittest.cc

@@ -60,7 +60,6 @@ protected:
     MockSession statistics_session;
     MockXfroutClient xfrout;
     AuthSrv server;
-    AuthSrv::ConstMemoryDataSrcPtr memory_datasrc;
     ConstElementPtr result;
     int rcode;
 public:
@@ -110,18 +109,18 @@ TEST_F(AuthCommandTest, shutdown) {
 // zones, and checks the zones are correctly loaded.
 void
 zoneChecks(AuthSrv& server) {
-    EXPECT_TRUE(server.getMemoryDataSrc(RRClass::IN()));
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test1.example")).zone->
+    EXPECT_TRUE(server.getInMemoryClient(RRClass::IN()));
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test1.example")).zone_finder->
               find(Name("ns.test1.example"), RRType::A()).code);
-    EXPECT_EQ(Zone::NXRRSET, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test1.example")).zone->
+    EXPECT_EQ(ZoneFinder::NXRRSET, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test1.example")).zone_finder->
               find(Name("ns.test1.example"), RRType::AAAA()).code);
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test2.example")).zone->
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test2.example")).zone_finder->
               find(Name("ns.test2.example"), RRType::A()).code);
-    EXPECT_EQ(Zone::NXRRSET, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test2.example")).zone->
+    EXPECT_EQ(ZoneFinder::NXRRSET, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test2.example")).zone_finder->
               find(Name("ns.test2.example"), RRType::AAAA()).code);
 }
 
@@ -147,21 +146,21 @@ configureZones(AuthSrv& server) {
 
 void
 newZoneChecks(AuthSrv& server) {
-    EXPECT_TRUE(server.getMemoryDataSrc(RRClass::IN()));
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test1.example")).zone->
+    EXPECT_TRUE(server.getInMemoryClient(RRClass::IN()));
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test1.example")).zone_finder->
               find(Name("ns.test1.example"), RRType::A()).code);
     // now test1.example should have ns/AAAA
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test1.example")).zone->
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test1.example")).zone_finder->
               find(Name("ns.test1.example"), RRType::AAAA()).code);
 
     // test2.example shouldn't change
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test2.example")).zone->
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test2.example")).zone_finder->
               find(Name("ns.test2.example"), RRType::A()).code);
-    EXPECT_EQ(Zone::NXRRSET, server.getMemoryDataSrc(RRClass::IN())->
-              findZone(Name("ns.test2.example")).zone->
+    EXPECT_EQ(ZoneFinder::NXRRSET, server.getInMemoryClient(RRClass::IN())->
+              findZone(Name("ns.test2.example")).zone_finder->
               find(Name("ns.test2.example"), RRType::AAAA()).code);
 }
 

+ 23 - 23
src/bin/auth/tests/config_unittest.cc

@@ -57,12 +57,12 @@ protected:
 
 TEST_F(AuthConfigTest, datasourceConfig) {
     // By default, we don't have any in-memory data source.
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
     configureAuthServer(server, Element::fromJSON(
                             "{\"datasources\": [{\"type\": \"memory\"}]}"));
     // after successful configuration, we should have one (with empty zoneset).
-    ASSERT_NE(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
-    EXPECT_EQ(0, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    ASSERT_NE(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
+    EXPECT_EQ(0, server.getInMemoryClient(rrclass)->getZoneCount());
 }
 
 TEST_F(AuthConfigTest, databaseConfig) {
@@ -82,7 +82,7 @@ TEST_F(AuthConfigTest, versionConfig) {
 }
 
 TEST_F(AuthConfigTest, exceptionGuarantee) {
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
     // This configuration contains an invalid item, which will trigger
     // an exception.
     EXPECT_THROW(configureAuthServer(
@@ -92,7 +92,7 @@ TEST_F(AuthConfigTest, exceptionGuarantee) {
                          " \"no_such_config_var\": 1}")),
                  AuthConfigError);
     // The server state shouldn't change
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
 }
 
 TEST_F(AuthConfigTest, exceptionConversion) {
@@ -154,22 +154,22 @@ protected:
 TEST_F(MemoryDatasrcConfigTest, addZeroDataSrc) {
     parser->build(Element::fromJSON("[]"));
     parser->commit();
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
 }
 
 TEST_F(MemoryDatasrcConfigTest, addEmpty) {
     // By default, we don't have any in-memory data source.
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
     parser->build(Element::fromJSON("[{\"type\": \"memory\"}]"));
     parser->commit();
-    EXPECT_EQ(0, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(0, server.getInMemoryClient(rrclass)->getZoneCount());
 }
 
 TEST_F(MemoryDatasrcConfigTest, addZeroZone) {
     parser->build(Element::fromJSON("[{\"type\": \"memory\","
                                     "  \"zones\": []}]"));
     parser->commit();
-    EXPECT_EQ(0, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(0, server.getInMemoryClient(rrclass)->getZoneCount());
 }
 
 TEST_F(MemoryDatasrcConfigTest, addOneZone) {
@@ -179,10 +179,10 @@ TEST_F(MemoryDatasrcConfigTest, addOneZone) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(1, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(1, server.getInMemoryClient(rrclass)->getZoneCount());
     // Check it actually loaded something
-    EXPECT_EQ(Zone::SUCCESS, server.getMemoryDataSrc(rrclass)->findZone(
-        Name("ns.example.com.")).zone->find(Name("ns.example.com."),
+    EXPECT_EQ(ZoneFinder::SUCCESS, server.getInMemoryClient(rrclass)->findZone(
+        Name("ns.example.com.")).zone_finder->find(Name("ns.example.com."),
         RRType::A()).code);
 }
 
@@ -199,7 +199,7 @@ TEST_F(MemoryDatasrcConfigTest, addMultiZones) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.net.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(3, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(3, server.getInMemoryClient(rrclass)->getZoneCount());
 }
 
 TEST_F(MemoryDatasrcConfigTest, replace) {
@@ -209,9 +209,9 @@ TEST_F(MemoryDatasrcConfigTest, replace) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(1, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(1, server.getInMemoryClient(rrclass)->getZoneCount());
     EXPECT_EQ(isc::datasrc::result::SUCCESS,
-              server.getMemoryDataSrc(rrclass)->findZone(
+              server.getInMemoryClient(rrclass)->findZone(
                   Name("example.com")).code);
 
     // create a new parser, and install a new set of configuration.  It
@@ -227,9 +227,9 @@ TEST_F(MemoryDatasrcConfigTest, replace) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.net.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(2, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(2, server.getInMemoryClient(rrclass)->getZoneCount());
     EXPECT_EQ(isc::datasrc::result::NOTFOUND,
-              server.getMemoryDataSrc(rrclass)->findZone(
+              server.getInMemoryClient(rrclass)->findZone(
                   Name("example.com")).code);
 }
 
@@ -241,9 +241,9 @@ TEST_F(MemoryDatasrcConfigTest, exception) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(1, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(1, server.getInMemoryClient(rrclass)->getZoneCount());
     EXPECT_EQ(isc::datasrc::result::SUCCESS,
-              server.getMemoryDataSrc(rrclass)->findZone(
+              server.getInMemoryClient(rrclass)->findZone(
                   Name("example.com")).code);
 
     // create a new parser, and try to load something. It will throw,
@@ -262,9 +262,9 @@ TEST_F(MemoryDatasrcConfigTest, exception) {
     // commit it
 
     // The original should be untouched
-    EXPECT_EQ(1, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(1, server.getInMemoryClient(rrclass)->getZoneCount());
     EXPECT_EQ(isc::datasrc::result::SUCCESS,
-              server.getMemoryDataSrc(rrclass)->findZone(
+              server.getInMemoryClient(rrclass)->findZone(
                   Name("example.com")).code);
 }
 
@@ -275,13 +275,13 @@ TEST_F(MemoryDatasrcConfigTest, remove) {
                       "               \"file\": \"" TEST_DATA_DIR
                       "/example.zone\"}]}]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(1, server.getMemoryDataSrc(rrclass)->getZoneCount());
+    EXPECT_EQ(1, server.getInMemoryClient(rrclass)->getZoneCount());
 
     delete parser;
     parser = createAuthConfigParser(server, "datasources"); 
     EXPECT_NO_THROW(parser->build(Element::fromJSON("[]")));
     EXPECT_NO_THROW(parser->commit());
-    EXPECT_EQ(AuthSrv::MemoryDataSrcPtr(), server.getMemoryDataSrc(rrclass));
+    EXPECT_EQ(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
 }
 
 TEST_F(MemoryDatasrcConfigTest, adDuplicateZones) {

+ 57 - 53
src/bin/auth/tests/query_unittest.cc

@@ -93,9 +93,9 @@ const char* const other_zone_rrs =
     "mx.delegation.example.com. 3600 IN A 192.0.2.100\n";
 
 // This is a mock Zone class for testing.
-// It is a derived class of Zone for the convenient of tests.
+// It is a derived class of ZoneFinder for the convenient of tests.
 // Its find() method emulates the common behavior of protocol compliant
-// zone classes, but simplifies some minor cases and also supports broken
+// ZoneFinder classes, but simplifies some minor cases and also supports broken
 // behavior.
 // For simplicity, most names are assumed to be "in zone"; there's only
 // one zone cut at the point of name "delegation.example.com".
@@ -103,9 +103,9 @@ const char* const other_zone_rrs =
 // will result in DNAME.
 // This mock zone doesn't handle empty non terminal nodes (if we need to test
 // such cases find() should have specialized code for it).
-class MockZone : public Zone {
+class MockZoneFinder : public ZoneFinder {
 public:
-    MockZone() :
+    MockZoneFinder() :
         origin_(Name("example.com")),
         delegation_name_("delegation.example.com"),
         dname_name_("dname.example.com"),
@@ -120,7 +120,7 @@ public:
             other_zone_rrs;
 
         masterLoad(zone_stream, origin_, rrclass_,
-                   boost::bind(&MockZone::loadRRset, this, _1));
+                   boost::bind(&MockZoneFinder::loadRRset, this, _1));
     }
     virtual const isc::dns::Name& getOrigin() const { return (origin_); }
     virtual const isc::dns::RRClass& getClass() const { return (rrclass_); }
@@ -163,9 +163,9 @@ private:
     const RRClass rrclass_;
 };
 
-Zone::FindResult
-MockZone::find(const Name& name, const RRType& type,
-               RRsetList* target, const FindOptions options) const
+ZoneFinder::FindResult
+MockZoneFinder::find(const Name& name, const RRType& type,
+                     RRsetList* target, const FindOptions options) const
 {
     // Emulating a broken zone: mandatory apex RRs are missing if specifically
     // configured so (which are rare cases).
@@ -233,11 +233,15 @@ protected:
         response.setRcode(Rcode::NOERROR());
         response.setOpcode(Opcode::QUERY());
         // create and add a matching zone.
-        mock_zone = new MockZone();
-        memory_datasrc.addZone(ZonePtr(mock_zone));
+        mock_finder = new MockZoneFinder();
+        memory_client.addZone(ZoneFinderPtr(mock_finder));
     }
-    MockZone* mock_zone;
-    MemoryDataSrc memory_datasrc;
+    MockZoneFinder* mock_finder;
+    // We use InMemoryClient here. We could have some kind of mock client
+    // here, but historically, the Query supported only InMemoryClient
+    // (originally named MemoryDataSrc) and was tested with it, so we keep
+    // it like this for now.
+    InMemoryClient memory_client;
     const Name qname;
     const RRClass qclass;
     const RRType qtype;
@@ -286,14 +290,14 @@ responseCheck(Message& response, const isc::dns::Rcode& rcode,
 TEST_F(QueryTest, noZone) {
     // There's no zone in the memory datasource.  So the response should have
     // REFUSED.
-    MemoryDataSrc empty_memory_datasrc;
-    Query nozone_query(empty_memory_datasrc, qname, qtype, response);
+    InMemoryClient empty_memory_client;
+    Query nozone_query(empty_memory_client, qname, qtype, response);
     EXPECT_NO_THROW(nozone_query.process());
     EXPECT_EQ(Rcode::REFUSED(), response.getRcode());
 }
 
 TEST_F(QueryTest, exactMatch) {
-    Query query(memory_datasrc, qname, qtype, response);
+    Query query(memory_client, qname, qtype, response);
     EXPECT_NO_THROW(query.process());
     // find match rrset
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -303,7 +307,7 @@ TEST_F(QueryTest, exactMatch) {
 TEST_F(QueryTest, exactAddrMatch) {
     // find match rrset, omit additional data which has already been provided
     // in the answer section from the additional.
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("noglue.example.com"), qtype,
+    EXPECT_NO_THROW(Query(memory_client, Name("noglue.example.com"), qtype,
                           response).process());
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 2,
@@ -315,7 +319,7 @@ TEST_F(QueryTest, exactAddrMatch) {
 TEST_F(QueryTest, apexNSMatch) {
     // find match rrset, omit authority data which has already been provided
     // in the answer section from the authority section.
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("example.com"), RRType::NS(),
+    EXPECT_NO_THROW(Query(memory_client, Name("example.com"), RRType::NS(),
                           response).process());
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 3, 0, 3,
@@ -326,7 +330,7 @@ TEST_F(QueryTest, apexNSMatch) {
 TEST_F(QueryTest, exactAnyMatch) {
     // find match rrset, omit additional data which has already been provided
     // in the answer section from the additional.
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("noglue.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("noglue.example.com"),
                           RRType::ANY(), response).process());
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 2,
@@ -339,18 +343,18 @@ TEST_F(QueryTest, exactAnyMatch) {
 TEST_F(QueryTest, apexAnyMatch) {
     // find match rrset, omit additional data which has already been provided
     // in the answer section from the additional.
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("example.com"),
                           RRType::ANY(), response).process());
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 4, 0, 3,
                   "example.com. 3600 IN SOA . . 0 0 0 0 0\n"
                   "example.com. 3600 IN NS glue.delegation.example.com.\n"
                   "example.com. 3600 IN NS noglue.example.com.\n"
                   "example.com. 3600 IN NS example.net.\n",
-                  NULL, ns_addrs_txt, mock_zone->getOrigin());
+                  NULL, ns_addrs_txt, mock_finder->getOrigin());
 }
 
 TEST_F(QueryTest, mxANYMatch) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("mx.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("mx.example.com"),
                           RRType::ANY(), response).process());
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 3, 3, 4,
                   mx_txt, zone_ns_txt,
@@ -358,17 +362,17 @@ TEST_F(QueryTest, mxANYMatch) {
 }
 
 TEST_F(QueryTest, glueANYMatch) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("delegation.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("delegation.example.com"),
                           RRType::ANY(), response).process());
     responseCheck(response, Rcode::NOERROR(), 0, 0, 4, 3,
                   NULL, delegation_txt, ns_addrs_txt);
 }
 
 TEST_F(QueryTest, nodomainANY) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("nxdomain.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("nxdomain.example.com"),
                           RRType::ANY(), response).process());
     responseCheck(response, Rcode::NXDOMAIN(), AA_FLAG, 0, 1, 0,
-                  NULL, soa_txt, NULL, mock_zone->getOrigin());
+                  NULL, soa_txt, NULL, mock_finder->getOrigin());
 }
 
 // This tests that when we need to look up Zone's apex NS records for
@@ -376,15 +380,15 @@ TEST_F(QueryTest, nodomainANY) {
 // throw in that case.
 TEST_F(QueryTest, noApexNS) {
     // Disable apex NS record
-    mock_zone->setApexNSFlag(false);
+    mock_finder->setApexNSFlag(false);
 
-    EXPECT_THROW(Query(memory_datasrc, Name("noglue.example.com"), qtype,
+    EXPECT_THROW(Query(memory_client, Name("noglue.example.com"), qtype,
                        response).process(), Query::NoApexNS);
     // We don't look into the response, as it threw
 }
 
 TEST_F(QueryTest, delegation) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("delegation.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("delegation.example.com"),
                           qtype, response).process());
 
     responseCheck(response, Rcode::NOERROR(), 0, 0, 4, 3,
@@ -392,18 +396,18 @@ TEST_F(QueryTest, delegation) {
 }
 
 TEST_F(QueryTest, nxdomain) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("nxdomain.example.com"), qtype,
+    EXPECT_NO_THROW(Query(memory_client, Name("nxdomain.example.com"), qtype,
                           response).process());
     responseCheck(response, Rcode::NXDOMAIN(), AA_FLAG, 0, 1, 0,
-                  NULL, soa_txt, NULL, mock_zone->getOrigin());
+                  NULL, soa_txt, NULL, mock_finder->getOrigin());
 }
 
 TEST_F(QueryTest, nxrrset) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("www.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("www.example.com"),
                           RRType::TXT(), response).process());
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 0, 1, 0,
-                  NULL, soa_txt, NULL, mock_zone->getOrigin());
+                  NULL, soa_txt, NULL, mock_finder->getOrigin());
 }
 
 /*
@@ -412,22 +416,22 @@ TEST_F(QueryTest, nxrrset) {
  */
 TEST_F(QueryTest, noSOA) {
     // disable zone's SOA RR.
-    mock_zone->setSOAFlag(false);
+    mock_finder->setSOAFlag(false);
 
     // The NX Domain
-    EXPECT_THROW(Query(memory_datasrc, Name("nxdomain.example.com"),
+    EXPECT_THROW(Query(memory_client, Name("nxdomain.example.com"),
                        qtype, response).process(), Query::NoSOA);
     // Of course, we don't look into the response, as it throwed
 
     // NXRRSET
-    EXPECT_THROW(Query(memory_datasrc, Name("nxrrset.example.com"),
+    EXPECT_THROW(Query(memory_client, Name("nxrrset.example.com"),
                        qtype, response).process(), Query::NoSOA);
 }
 
 TEST_F(QueryTest, noMatchZone) {
     // there's a zone in the memory datasource but it doesn't match the qname.
     // should result in REFUSED.
-    Query(memory_datasrc, Name("example.org"), qtype, response).process();
+    Query(memory_client, Name("example.org"), qtype, response).process();
     EXPECT_EQ(Rcode::REFUSED(), response.getRcode());
 }
 
@@ -438,7 +442,7 @@ TEST_F(QueryTest, noMatchZone) {
  * A record, other to unknown out of zone one.
  */
 TEST_F(QueryTest, MX) {
-    Query(memory_datasrc, Name("mx.example.com"), RRType::MX(),
+    Query(memory_client, Name("mx.example.com"), RRType::MX(),
           response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 3, 3, 4,
@@ -452,7 +456,7 @@ TEST_F(QueryTest, MX) {
  * This should not trigger the additional processing for the exchange.
  */
 TEST_F(QueryTest, MXAlias) {
-    Query(memory_datasrc, Name("cnamemx.example.com"), RRType::MX(),
+    Query(memory_client, Name("cnamemx.example.com"), RRType::MX(),
           response).process();
 
     // there shouldn't be no additional RRs for the exchanges (we have 3
@@ -472,7 +476,7 @@ TEST_F(QueryTest, MXAlias) {
  * returned.
  */
 TEST_F(QueryTest, CNAME) {
-    Query(memory_datasrc, Name("cname.example.com"), RRType::A(),
+    Query(memory_client, Name("cname.example.com"), RRType::A(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 0, 0,
@@ -482,7 +486,7 @@ TEST_F(QueryTest, CNAME) {
 TEST_F(QueryTest, explicitCNAME) {
     // same owner name as the CNAME test but explicitly query for CNAME RR.
     // expect the same response as we don't provide a full chain yet.
-    Query(memory_datasrc, Name("cname.example.com"), RRType::CNAME(),
+    Query(memory_client, Name("cname.example.com"), RRType::CNAME(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -494,7 +498,7 @@ TEST_F(QueryTest, CNAME_NX_RRSET) {
     // note: with chaining, what should be expected is not trivial:
     // BIND 9 returns the CNAME in answer and SOA in authority, no additional.
     // NSD returns the CNAME, NS in authority, A/AAAA for NS in additional.
-    Query(memory_datasrc, Name("cname.example.com"), RRType::TXT(),
+    Query(memory_client, Name("cname.example.com"), RRType::TXT(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 0, 0,
@@ -503,7 +507,7 @@ TEST_F(QueryTest, CNAME_NX_RRSET) {
 
 TEST_F(QueryTest, explicitCNAME_NX_RRSET) {
     // same owner name as the NXRRSET test but explicitly query for CNAME RR.
-    Query(memory_datasrc, Name("cname.example.com"), RRType::CNAME(),
+    Query(memory_client, Name("cname.example.com"), RRType::CNAME(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -517,7 +521,7 @@ TEST_F(QueryTest, CNAME_NX_DOMAIN) {
     // RCODE being NXDOMAIN.
     // NSD returns the CNAME, NS in authority, A/AAAA for NS in additional,
     // RCODE being NOERROR.
-    Query(memory_datasrc, Name("cnamenxdom.example.com"), RRType::A(),
+    Query(memory_client, Name("cnamenxdom.example.com"), RRType::A(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 0, 0,
@@ -526,7 +530,7 @@ TEST_F(QueryTest, CNAME_NX_DOMAIN) {
 
 TEST_F(QueryTest, explicitCNAME_NX_DOMAIN) {
     // same owner name as the NXDOMAIN test but explicitly query for CNAME RR.
-    Query(memory_datasrc, Name("cnamenxdom.example.com"), RRType::CNAME(),
+    Query(memory_client, Name("cnamenxdom.example.com"), RRType::CNAME(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -542,7 +546,7 @@ TEST_F(QueryTest, CNAME_OUT) {
      * Then the same test should be done with .org included there and
      * see what it does (depends on what we want to do)
      */
-    Query(memory_datasrc, Name("cnameout.example.com"), RRType::A(),
+    Query(memory_client, Name("cnameout.example.com"), RRType::A(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 0, 0,
@@ -551,7 +555,7 @@ TEST_F(QueryTest, CNAME_OUT) {
 
 TEST_F(QueryTest, explicitCNAME_OUT) {
     // same owner name as the OUT test but explicitly query for CNAME RR.
-    Query(memory_datasrc, Name("cnameout.example.com"), RRType::CNAME(),
+    Query(memory_client, Name("cnameout.example.com"), RRType::CNAME(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -567,7 +571,7 @@ TEST_F(QueryTest, explicitCNAME_OUT) {
  * pointing to NXRRSET and NXDOMAIN cases (similarly as with CNAME).
  */
 TEST_F(QueryTest, DNAME) {
-    Query(memory_datasrc, Name("www.dname.example.com"), RRType::A(),
+    Query(memory_client, Name("www.dname.example.com"), RRType::A(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 2, 0, 0,
@@ -583,7 +587,7 @@ TEST_F(QueryTest, DNAME) {
  * DNAME.
  */
 TEST_F(QueryTest, DNAME_ANY) {
-    Query(memory_datasrc, Name("www.dname.example.com"), RRType::ANY(),
+    Query(memory_client, Name("www.dname.example.com"), RRType::ANY(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 2, 0, 0,
@@ -592,7 +596,7 @@ TEST_F(QueryTest, DNAME_ANY) {
 
 // Test when we ask for DNAME explicitly, it does no synthetizing.
 TEST_F(QueryTest, explicitDNAME) {
-    Query(memory_datasrc, Name("dname.example.com"), RRType::DNAME(),
+    Query(memory_client, Name("dname.example.com"), RRType::DNAME(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -604,7 +608,7 @@ TEST_F(QueryTest, explicitDNAME) {
  * the CNAME, it should return the RRset.
  */
 TEST_F(QueryTest, DNAME_A) {
-    Query(memory_datasrc, Name("dname.example.com"), RRType::A(),
+    Query(memory_client, Name("dname.example.com"), RRType::A(),
         response).process();
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 1, 3, 3,
@@ -616,11 +620,11 @@ TEST_F(QueryTest, DNAME_A) {
  * It should not synthetize the CNAME.
  */
 TEST_F(QueryTest, DNAME_NX_RRSET) {
-    EXPECT_NO_THROW(Query(memory_datasrc, Name("dname.example.com"),
+    EXPECT_NO_THROW(Query(memory_client, Name("dname.example.com"),
         RRType::TXT(), response).process());
 
     responseCheck(response, Rcode::NOERROR(), AA_FLAG, 0, 1, 0,
-        NULL, soa_txt, NULL, mock_zone->getOrigin());
+        NULL, soa_txt, NULL, mock_finder->getOrigin());
 }
 
 /*
@@ -636,7 +640,7 @@ TEST_F(QueryTest, LongDNAME) {
         "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa."
         "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa."
         "dname.example.com.");
-    EXPECT_NO_THROW(Query(memory_datasrc, longname, RRType::A(),
+    EXPECT_NO_THROW(Query(memory_client, longname, RRType::A(),
         response).process());
 
     responseCheck(response, Rcode::YXDOMAIN(), AA_FLAG, 1, 0, 0,
@@ -655,7 +659,7 @@ TEST_F(QueryTest, MaxLenDNAME) {
         "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa."
         "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa."
         "dname.example.com.");
-    EXPECT_NO_THROW(Query(memory_datasrc, longname, RRType::A(),
+    EXPECT_NO_THROW(Query(memory_client, longname, RRType::A(),
         response).process());
 
     // Check the answer is OK

+ 7 - 3
src/bin/bind10/Makefile.am

@@ -1,7 +1,11 @@
 SUBDIRS = . tests
 
 sbin_SCRIPTS = bind10
-CLEANFILES = bind10 bind10.pyc bind10_messages.py bind10_messages.pyc
+CLEANFILES = bind10 bind10_src.pyc bind10_messages.py bind10_messages.pyc \
+	sockcreator.pyc
+
+python_PYTHON = __init__.py sockcreator.py
+pythondir = $(pyexecdir)/bind10
 
 pkglibexecdir = $(libexecdir)/@PACKAGE@
 pyexec_DATA = bind10_messages.py
@@ -24,9 +28,9 @@ bind10_messages.py: bind10_messages.mes
 	$(top_builddir)/src/lib/log/compiler/message -p $(top_srcdir)/src/bin/bind10/bind10_messages.mes
 
 # this is done here since configure.ac AC_OUTPUT doesn't expand exec_prefix
-bind10: bind10.py
+bind10: bind10_src.py
 	$(SED) -e "s|@@PYTHONPATH@@|@pyexecdir@|" \
-	       -e "s|@@LIBEXECDIR@@|$(pkglibexecdir)|" bind10.py >$@
+	       -e "s|@@LIBEXECDIR@@|$(pkglibexecdir)|" bind10_src.py >$@
 	chmod a+x $@
 
 pytest:

+ 0 - 0
src/bin/bind10/__init__.py


+ 54 - 11
src/bin/bind10/bind10_messages.mes

@@ -32,15 +32,15 @@ started according to the configuration.
 The boss process was started with the -u option, to drop root privileges
 and continue running as the specified user, but the user is unknown.
 
+% BIND10_KILLING_ALL_PROCESSES killing all started processes
+The boss module was not able to start every process it needed to start
+during startup, and will now kill the processes that did get started.
+
 % BIND10_KILL_PROCESS killing process %1
 The boss module is sending a kill signal to process with the given name,
 as part of the process of killing all started processes during a failed
 startup, as described for BIND10_KILLING_ALL_PROCESSES
 
-% BIND10_KILLING_ALL_PROCESSES killing all started processes
-The boss module was not able to start every process it needed to start
-during startup, and will now kill the processes that did get started.
-
 % BIND10_MSGQ_ALREADY_RUNNING msgq daemon already running, cannot start
 There already appears to be a message bus daemon running. Either an
 old process was not shut down correctly, and needs to be killed, or
@@ -113,12 +113,49 @@ it shall send SIGKILL signals to the processes still alive.
 All child processes have been stopped, and the boss process will now
 stop itself.
 
-% BIND10_START_AS_NON_ROOT starting %1 as a user, not root. This might fail.
-The given module is being started or restarted without root privileges.
-If the module needs these privileges, it may have problems starting.
-Note that this issue should be resolved by the pending 'socket-creator'
-process; once that has been implemented, modules should not need root
-privileges anymore. See tickets #800 and #801 for more information.
+% BIND10_SOCKCREATOR_BAD_CAUSE unknown error cause from socket creator: %1
+The socket creator reported an error when creating a socket. But the function
+which failed is unknown (not one of 'S' for socket or 'B' for bind).
+
+% BIND10_SOCKCREATOR_BAD_RESPONSE unknown response for socket request: %1
+The boss requested a socket from the creator, but the answer is unknown. This
+looks like a programmer error.
+
+% BIND10_SOCKCREATOR_CRASHED the socket creator crashed
+The socket creator terminated unexpectadly. It is not possible to restart it
+(because the boss already gave up root privileges), so the system is going
+to terminate.
+
+% BIND10_SOCKCREATOR_EOF eof while expecting data from socket creator
+There should be more data from the socket creator, but it closed the socket.
+It probably crashed.
+
+% BIND10_SOCKCREATOR_INIT initializing socket creator parser
+The boss module initializes routines for parsing the socket creator
+protocol.
+
+% BIND10_SOCKCREATOR_KILL killing the socket creator
+The socket creator is being terminated the aggressive way, by sending it
+sigkill. This should not happen usually.
+
+% BIND10_SOCKCREATOR_TERMINATE terminating socket creator
+The boss module sends a request to terminate to the socket creator.
+
+% BIND10_SOCKCREATOR_TRANSPORT_ERROR transport error when talking to the socket creator: %1
+Either sending or receiving data from the socket creator failed with the given
+error. The creator probably crashed or some serious OS-level problem happened,
+as the communication happens only on local host.
+
+% BIND10_SOCKET_CREATED successfully created socket %1
+The socket creator successfully created and sent a requested socket, it has
+the given file number.
+
+% BIND10_SOCKET_ERROR error on %1 call in the creator: %2/%3
+The socket creator failed to create the requested socket. It failed on the
+indicated OS API function with given error.
+
+% BIND10_SOCKET_GET requesting socket [%1]:%2 of type %3 from the creator
+The boss forwards a request for a socket to the socket creator.
 
 % BIND10_STARTED_PROCESS started %1
 The given process has successfully been started.
@@ -147,6 +184,13 @@ All modules have been successfully started, and BIND 10 is now running.
 There was a fatal error when BIND10 was trying to start. The error is
 shown, and BIND10 will now shut down.
 
+% BIND10_START_AS_NON_ROOT starting %1 as a user, not root. This might fail.
+The given module is being started or restarted without root privileges.
+If the module needs these privileges, it may have problems starting.
+Note that this issue should be resolved by the pending 'socket-creator'
+process; once that has been implemented, modules should not need root
+privileges anymore. See tickets #800 and #801 for more information.
+
 % BIND10_STOP_PROCESS asking %1 to shut down
 The boss module is sending a shutdown command to the given module over
 the message channel.
@@ -154,4 +198,3 @@ the message channel.
 % BIND10_UNKNOWN_CHILD_PROCESS_ENDED unknown child pid %1 exited
 An unknown child process has exited. The PID is printed, but no further
 action will be taken by the boss process.
-

+ 33 - 1
src/bin/bind10/bind10.py.in

@@ -67,6 +67,7 @@ import isc.util.process
 import isc.net.parse
 import isc.log
 from bind10_messages import *
+import bind10.sockcreator
 
 isc.log.init("b10-boss")
 logger = isc.log.Logger("boss")
@@ -248,6 +249,7 @@ class BoB:
         self.config_filename = config_filename
         self.cmdctl_port = cmdctl_port
         self.brittle = brittle
+        self.sockcreator = None
 
     def config_handler(self, new_config):
         # If this is initial update, don't do anything now, leave it to startup
@@ -333,6 +335,20 @@ class BoB:
                                                             "Unknown command")
         return answer
 
+    def start_creator(self):
+        self.curproc = 'b10-sockcreator'
+        self.sockcreator = bind10.sockcreator.Creator("@@LIBEXECDIR@@:" +
+                                                      os.environ['PATH'])
+
+    def stop_creator(self, kill=False):
+        if self.sockcreator is None:
+            return
+        if kill:
+            self.sockcreator.kill()
+        else:
+            self.sockcreator.terminate()
+        self.sockcreator = None
+
     def kill_started_processes(self):
         """
             Called as part of the exception handling when a process fails to
@@ -341,6 +357,8 @@ class BoB:
         """
         logger.info(BIND10_KILLING_ALL_PROCESSES)
 
+        self.stop_creator(True)
+
         for pid in self.processes:
             logger.info(BIND10_KILL_PROCESS, self.processes[pid].name)
             self.processes[pid].process.kill()
@@ -571,6 +589,11 @@ class BoB:
             Starts up all the processes.  Any exception generated during the
             starting of the processes is handled by the caller.
         """
+        # The socket creator first, as it is the only thing that needs root
+        self.start_creator()
+        # TODO: Once everything uses the socket creator, we can drop root
+        # privileges right now
+
         c_channel_env = self.c_channel_env
         self.start_msgq(c_channel_env)
         self.start_cfgmgr(c_channel_env)
@@ -660,6 +683,8 @@ class BoB:
         self.cc_session.group_sendmsg(cmd, "Zonemgr", "Zonemgr")
         self.cc_session.group_sendmsg(cmd, "Stats", "Stats")
         self.cc_session.group_sendmsg(cmd, "StatsHttpd", "StatsHttpd")
+        # Terminate the creator last
+        self.stop_creator()
 
     def stop_process(self, process, recipient):
         """
@@ -746,7 +771,14 @@ class BoB:
                 # XXX: should be impossible to get any other error here
                 raise
             if pid == 0: break
-            if pid in self.processes:
+            if self.sockcreator is not None and self.sockcreator.pid() == pid:
+                # This is the socket creator, started and terminated
+                # differently. This can't be restarted.
+                if self.runnable:
+                    logger.fatal(BIND10_SOCKCREATOR_CRASHED)
+                    self.sockcreator = None
+                    self.runnable = False
+            elif pid in self.processes:
                 # One of the processes we know about.  Get information on it.
                 proc_info = self.processes.pop(pid)
                 proc_info.restart_schedule.set_run_stop_time()

+ 2 - 2
src/bin/bind10/run_bind10.sh.in

@@ -23,14 +23,14 @@ BIND10_PATH=@abs_top_builddir@/src/bin/bind10
 PATH=@abs_top_builddir@/src/bin/msgq:@abs_top_builddir@/src/bin/auth:@abs_top_builddir@/src/bin/resolver:@abs_top_builddir@/src/bin/cfgmgr:@abs_top_builddir@/src/bin/cmdctl:@abs_top_builddir@/src/bin/stats:@abs_top_builddir@/src/bin/xfrin:@abs_top_builddir@/src/bin/xfrout:@abs_top_builddir@/src/bin/zonemgr:@abs_top_builddir@/src/bin/dhcp6:$PATH
 export PATH
 
-PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/python/.libs:@abs_top_builddir@/src/lib/xfr/.libs:@abs_top_builddir@/src/lib/log/.libs:@abs_top_builddir@/src/lib/util/io/.libs:@abs_top_builddir@/src/lib/python/isc/config
+PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/python/.libs:@abs_top_builddir@/src/lib/xfr/.libs:@abs_top_builddir@/src/lib/log/.libs:@abs_top_builddir@/src/lib/util/io/.libs:@abs_top_builddir@/src/lib/python/isc/config:@abs_top_builddir@/src/lib/python/isc/acl/.libs:
 export PYTHONPATH
 
 # If necessary (rare cases), explicitly specify paths to dynamic libraries
 # required by loadable python modules.
 SET_ENV_LIBRARY_PATH=@SET_ENV_LIBRARY_PATH@
 if test $SET_ENV_LIBRARY_PATH = yes; then
-	@ENV_LIBRARY_PATH@=@abs_top_builddir@/src/lib/dns/.libs:@abs_top_builddir@/src/lib/cryptolink/.libs:@abs_top_builddir@/src/lib/cc/.libs:@abs_top_builddir@/src/lib/config/.libs:@abs_top_builddir@/src/lib/log/.libs:@abs_top_builddir@/src/lib/util/.libs:@abs_top_builddir@/src/lib/util/io/.libs:@abs_top_builddir@/src/lib/exceptions/.libs:$@ENV_LIBRARY_PATH@
+	@ENV_LIBRARY_PATH@=@abs_top_builddir@/src/lib/dns/.libs:@abs_top_builddir@/src/lib/cryptolink/.libs:@abs_top_builddir@/src/lib/cc/.libs:@abs_top_builddir@/src/lib/config/.libs:@abs_top_builddir@/src/lib/log/.libs:@abs_top_builddir@/src/lib/acl/.libs:@abs_top_builddir@/src/lib/util/.libs:@abs_top_builddir@/src/lib/util/io/.libs:@abs_top_builddir@/src/lib/exceptions/.libs:$@ENV_LIBRARY_PATH@
 	export @ENV_LIBRARY_PATH@
 fi
 

+ 226 - 0
src/bin/bind10/sockcreator.py

@@ -0,0 +1,226 @@
+# Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import socket
+import struct
+import os
+import subprocess
+from bind10_messages import *
+from libutil_io_python import recv_fd
+
+logger = isc.log.Logger("boss")
+
+"""
+Module that comunicates with the privileged socket creator (b10-sockcreator).
+"""
+
+class CreatorError(Exception):
+    """
+    Exception for socket creator related errors.
+
+    It has two members: fatal and errno and they are just holding the values
+    passed to the __init__ function.
+    """
+
+    def __init__(self, message, fatal, errno=None):
+        """
+        Creates the exception. The message argument is the usual string.
+        The fatal one tells if the error is fatal (eg. the creator crashed)
+        and errno is the errno value returned from socket creator, if
+        applicable.
+        """
+        Exception.__init__(self, message)
+        self.fatal = fatal
+        self.errno = errno
+
+class Parser:
+    """
+    This class knows the sockcreator language. It creates commands, sends them
+    and receives the answers and parses them.
+
+    It does not start it, the communication channel must be provided.
+
+    In theory, anything here can throw a fatal CreatorError exception, but it
+    happens only in case something like the creator process crashes. Any other
+    occasions are mentioned explicitly.
+    """
+
+    def __init__(self, creator_socket):
+        """
+        Creates the parser. The creator_socket is socket to the socket creator
+        process that will be used for communication. However, the object must
+        have a read_fd() method to read the file descriptor. This slightly
+        unusual trick with modifying an object is used to easy up testing.
+
+        You can use WrappedSocket in production code to add the method to any
+        ordinary socket.
+        """
+        self.__socket = creator_socket
+        logger.info(BIND10_SOCKCREATOR_INIT)
+
+    def terminate(self):
+        """
+        Asks the creator process to terminate and waits for it to close the
+        socket. Does not return anything. Raises a CreatorError if there is
+        still data on the socket, if there is an error closing the socket,
+        or if the socket had already been closed.
+        """
+        if self.__socket is None:
+            raise CreatorError('Terminated already', True)
+        logger.info(BIND10_SOCKCREATOR_TERMINATE)
+        try:
+            self.__socket.sendall(b'T')
+            # Wait for an EOF - it will return empty data
+            eof = self.__socket.recv(1)
+            if len(eof) != 0:
+                raise CreatorError('Protocol error - data after terminated',
+                                   True)
+            self.__socket = None
+        except socket.error as se:
+            self.__socket = None
+            raise CreatorError(str(se), True)
+
+    def get_socket(self, address, port, socktype):
+        """
+        Asks the socket creator process to create a socket. Pass an address
+        (the isc.net.IPaddr object), port number and socket type (either
+        string "UDP", "TCP" or constant socket.SOCK_DGRAM or
+        socket.SOCK_STREAM.
+
+        Blocks until it is provided by the socket creator process (which
+        should be fast, as it is on localhost) and returns the file descriptor
+        number. It raises a CreatorError exception if the creation fails.
+        """
+        if self.__socket is None:
+            raise CreatorError('Socket requested on terminated creator', True)
+        # First, assemble the request from parts
+        logger.info(BIND10_SOCKET_GET, address, port, socktype)
+        data = b'S'
+        if socktype == 'UDP' or socktype == socket.SOCK_DGRAM:
+            data += b'U'
+        elif socktype == 'TCP' or socktype == socket.SOCK_STREAM:
+            data += b'T'
+        else:
+            raise ValueError('Unknown socket type: ' + str(socktype))
+        if address.family == socket.AF_INET:
+            data += b'4'
+        elif address.family == socket.AF_INET6:
+            data += b'6'
+        else:
+            raise ValueError('Unknown address family in address')
+        data += struct.pack('!H', port)
+        data += address.addr
+        try:
+            # Send the request
+            self.__socket.sendall(data)
+            answer = self.__socket.recv(1)
+            if answer == b'S':
+                # Success!
+                result = self.__socket.read_fd()
+                logger.info(BIND10_SOCKET_CREATED, result)
+                return result
+            elif answer == b'E':
+                # There was an error, read the error as well
+                error = self.__socket.recv(1)
+                errno = struct.unpack('i',
+                                      self.__read_all(len(struct.pack('i',
+                                                                      0))))
+                if error == b'S':
+                    cause = 'socket'
+                elif error == b'B':
+                    cause = 'bind'
+                else:
+                    self.__socket = None
+                    logger.fatal(BIND10_SOCKCREATOR_BAD_CAUSE, error)
+                    raise CreatorError('Unknown error cause' + str(answer), True)
+                logger.error(BIND10_SOCKET_ERROR, cause, errno[0],
+                             os.strerror(errno[0]))
+                raise CreatorError('Error creating socket on ' + cause, False,
+                                   errno[0])
+            else:
+                self.__socket = None
+                logger.fatal(BIND10_SOCKCREATOR_BAD_RESPONSE, answer)
+                raise CreatorError('Unknown response ' + str(answer), True)
+        except socket.error as se:
+            self.__socket = None
+            logger.fatal(BIND10_SOCKCREATOR_TRANSPORT_ERROR, str(se))
+            raise CreatorError(str(se), True)
+
+    def __read_all(self, length):
+        """
+        Keeps reading until length data is read or EOF or error happens.
+
+        EOF is considered error as well and throws a CreatorError.
+        """
+        result = b''
+        while len(result) < length:
+            data = self.__socket.recv(length - len(result))
+            if len(data) == 0:
+                self.__socket = None
+                logger.fatal(BIND10_SOCKCREATOR_EOF)
+                raise CreatorError('Unexpected EOF', True)
+            result += data
+        return result
+
+class WrappedSocket:
+    """
+    This class wraps a socket and adds a read_fd method, so it can be used
+    for the Parser class conveniently. It simply copies all its guts into
+    itself and implements the method.
+    """
+    def __init__(self, socket):
+        # Copy whatever can be copied from the socket
+        for name in dir(socket):
+            if name not in ['__class__', '__weakref__']:
+                setattr(self, name, getattr(socket, name))
+        # Keep the socket, so we can prevent it from being garbage-collected
+        # and closed before we are removed ourself
+        self.__orig_socket = socket
+
+    def read_fd(self):
+        """
+        Read the file descriptor from the socket.
+        """
+        return recv_fd(self.fileno())
+
+# FIXME: Any idea how to test this? Starting an external process doesn't sound
+# OK
+class Creator(Parser):
+    """
+    This starts the socket creator and allows asking for the sockets.
+    """
+    def __init__(self, path):
+        (local, remote) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+        # Popen does not like, for some reason, having the same socket for
+        # stdin as well as stdout, so we dup it before passing it there.
+        remote2 = socket.fromfd(remote.fileno(), socket.AF_UNIX,
+                                socket.SOCK_STREAM)
+        env = os.environ
+        env['PATH'] = path
+        self.__process = subprocess.Popen(['b10-sockcreator'], env=env,
+                                          stdin=remote.fileno(),
+                                          stdout=remote2.fileno())
+        remote.close()
+        remote2.close()
+        Parser.__init__(self, WrappedSocket(local))
+
+    def pid(self):
+        return self.__process.pid
+
+    def kill(self):
+        logger.warn(BIND10_SOCKCREATOR_KILL)
+        if self.__process is not None:
+            self.__process.kill()
+            self.__process = None

+ 3 - 4
src/bin/bind10/tests/Makefile.am

@@ -1,14 +1,13 @@
 PYCOVERAGE_RUN = @PYCOVERAGE_RUN@
 #PYTESTS = args_test.py bind10_test.py
 # NOTE: this has a generated test found in the builddir
-PYTESTS = bind10_test.py
-EXTRA_DIST = $(PYTESTS)
+PYTESTS = bind10_test.py sockcreator_test.py
 
 # If necessary (rare cases), explicitly specify paths to dynamic libraries
 # required by loadable python modules.
 LIBRARY_PATH_PLACEHOLDER =
 if SET_ENV_LIBRARY_PATH
-LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$$$(ENV_LIBRARY_PATH)
+LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
 endif
 
 # test using command-line arguments, so use check-local target instead of TESTS
@@ -21,7 +20,7 @@ endif
 	for pytest in $(PYTESTS) ; do \
 	echo Running test: $$pytest ; \
 	$(LIBRARY_PATH_PLACEHOLDER) \
-	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_builddir)/src/bin/bind10 \
+	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_srcdir)/src/bin:$(abs_top_builddir)/src/bin/bind10:$(abs_top_builddir)/src/lib/util/io/.libs \
 	BIND10_MSGQ_SOCKET_FILE=$(abs_top_builddir)/msgq_socket \
 		$(PYCOVERAGE_RUN) $(abs_builddir)/$$pytest || exit ; \
 	done

+ 11 - 1
src/bin/bind10/tests/bind10_test.py.in

@@ -13,7 +13,7 @@
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-from bind10 import ProcessInfo, BoB, parse_args, dump_pid, unlink_pid_file, _BASETIME
+from bind10_src import ProcessInfo, BoB, parse_args, dump_pid, unlink_pid_file, _BASETIME
 
 # XXX: environment tests are currently disabled, due to the preprocessor
 #      setup that we have now complicating the environment
@@ -26,6 +26,7 @@ import socket
 from isc.net.addr import IPAddr
 import time
 import isc
+import isc.log
 
 from isc.testutils.parse_args import TestOptParser, OptsError
 
@@ -192,6 +193,13 @@ class MockBob(BoB):
         self.cmdctl = False
         self.c_channel_env = {}
         self.processes = { }
+        self.creator = False
+
+    def start_creator(self):
+        self.creator = True
+
+    def stop_creator(self, kill=False):
+        self.creator = False
 
     def read_bind10_config(self):
         # Configuration options are set directly
@@ -336,6 +344,7 @@ class TestStartStopProcessesBob(unittest.TestCase):
         self.assertEqual(bob.msgq, core)
         self.assertEqual(bob.cfgmgr, core)
         self.assertEqual(bob.ccsession, core)
+        self.assertEqual(bob.creator, core)
         self.assertEqual(bob.auth, auth)
         self.assertEqual(bob.resolver, resolver)
         self.assertEqual(bob.xfrout, auth)
@@ -764,4 +773,5 @@ class TestBrittle(unittest.TestCase):
         self.assertFalse(bob.runnable)
 
 if __name__ == '__main__':
+    isc.log.resetUnitTestRootLogger()
     unittest.main()

+ 327 - 0
src/bin/bind10/tests/sockcreator_test.py.in

@@ -0,0 +1,327 @@
+# Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+# This test file is generated .py.in -> .py just to be in the build dir,
+# same as the rest of the tests. Saves a lot of stuff in makefile.
+
+"""
+Tests for the bind10.sockcreator module.
+"""
+
+import unittest
+import struct
+import socket
+from isc.net.addr import IPAddr
+import isc.log
+from libutil_io_python import send_fd
+from bind10.sockcreator import Parser, CreatorError, WrappedSocket
+
+class FakeCreator:
+    """
+    Class emulating the socket to the socket creator. It can be given expected
+    data to receive (and check) and responses to give to the Parser class
+    during testing.
+    """
+
+    class InvalidPlan(Exception):
+        """
+        Raised when someone wants to recv when sending is planned or vice
+        versa.
+        """
+        pass
+
+    class InvalidData(Exception):
+        """
+        Raises when the data passed to sendall are not the same as expected.
+        """
+        pass
+
+    def __init__(self, plan):
+        """
+        Create the object. The plan variable contains list of expected actions,
+        in form:
+
+        [('r', 'Data to return from recv'), ('s', 'Data expected on sendall'),
+             , ('d', 'File descriptor number to return from read_sock'), ('e',
+             None), ...]
+
+        It modifies the array as it goes.
+        """
+        self.__plan = plan
+
+    def __get_plan(self, expected):
+        if len(self.__plan) == 0:
+            raise InvalidPlan('Nothing more planned')
+        (kind, data) = self.__plan[0]
+        if kind == 'e':
+            self.__plan.pop(0)
+            raise socket.error('False socket error')
+        if kind != expected:
+            raise InvalidPlan('Planned ' + kind + ', but ' + expected +
+                'requested')
+        return data
+
+    def recv(self, maxsize):
+        """
+        Emulate recv. Returs maxsize bytes from the current recv plan. If
+        there are data left from previous recv call, it is used first.
+
+        If no recv is planned, raises InvalidPlan.
+        """
+        data = self.__get_plan('r')
+        result, rest = data[:maxsize], data[maxsize:]
+        if len(rest) > 0:
+            self.__plan[0] = ('r', rest)
+        else:
+            self.__plan.pop(0)
+        return result
+
+    def read_fd(self):
+        """
+        Emulate the reading of file descriptor. Returns one from a plan.
+
+        It raises InvalidPlan if no socket is planned now.
+        """
+        fd = self.__get_plan('f')
+        self.__plan.pop(0)
+        return fd
+
+    def sendall(self, data):
+        """
+        Checks that the data passed are correct according to plan. It raises
+        InvalidData if the data differs or InvalidPlan when sendall is not
+        expected.
+        """
+        planned = self.__get_plan('s')
+        dlen = len(data)
+        prefix, rest = planned[:dlen], planned[dlen:]
+        if prefix != data:
+            raise InvalidData('Expected "' + str(prefix)+ '", got "' +
+                str(data) + '"')
+        if len(rest) > 0:
+            self.__plan[0] = ('s', rest)
+        else:
+            self.__plan.pop(0)
+
+    def all_used(self):
+        """
+        Returns if the whole plan was consumed.
+        """
+        return len(self.__plan) == 0
+
+class ParserTests(unittest.TestCase):
+    """
+    Testcases for the Parser class.
+
+    A lot of these test could be done by
+    `with self.assertRaises(CreatorError) as cm`. But some versions of python
+    take the scope wrong and don't work, so we use the primitive way of
+    try-except.
+    """
+    def __terminate(self):
+        creator = FakeCreator([('s', b'T'), ('r', b'')])
+        parser = Parser(creator)
+        self.assertEqual(None, parser.terminate())
+        self.assertTrue(creator.all_used())
+        return parser
+
+    def test_terminate(self):
+        """
+        Test if the command to terminate is correct and it waits for reading the
+        EOF.
+        """
+        self.__terminate()
+
+    def __terminate_raises(self, parser):
+        """
+        Check that terminate() raises a fatal exception.
+        """
+        try:
+            parser.terminate()
+            self.fail("Not raised")
+        except CreatorError as ce:
+            self.assertTrue(ce.fatal)
+            self.assertEqual(None, ce.errno)
+
+    def test_terminate_error1(self):
+        """
+        Test it reports an exception when there's error terminating the creator.
+        This one raises an error when receiving the EOF.
+        """
+        creator = FakeCreator([('s', b'T'), ('e', None)])
+        parser = Parser(creator)
+        self.__terminate_raises(parser)
+
+    def test_terminate_error2(self):
+        """
+        Test it reports an exception when there's error terminating the creator.
+        This one raises an error when sending data.
+        """
+        creator = FakeCreator([('e', None)])
+        parser = Parser(creator)
+        self.__terminate_raises(parser)
+
+    def test_terminate_error3(self):
+        """
+        Test it reports an exception when there's error terminating the creator.
+        This one sends data when it should have terminated.
+        """
+        creator = FakeCreator([('s', b'T'), ('r', b'Extra data')])
+        parser = Parser(creator)
+        self.__terminate_raises(parser)
+
+    def test_terminate_twice(self):
+        """
+        Test we can't terminate twice.
+        """
+        parser = self.__terminate()
+        self.__terminate_raises(parser)
+
+    def test_crash(self):
+        """
+        Tests that the parser correctly raises exception when it crashes
+        unexpectedly.
+        """
+        creator = FakeCreator([('s', b'SU4\0\0\0\0\0\0'), ('r', b'')])
+        parser = Parser(creator)
+        try:
+            parser.get_socket(IPAddr('0.0.0.0'), 0, 'UDP')
+            self.fail("Not raised")
+        except CreatorError as ce:
+            self.assertTrue(creator.all_used())
+            # Is the exception correct?
+            self.assertTrue(ce.fatal)
+            self.assertEqual(None, ce.errno)
+
+    def test_error(self):
+        """
+        Tests that the parser correctly raises non-fatal exception when
+        the socket can not be created.
+        """
+        # We split the int to see if it can cope with data coming in
+        # different packets
+        intpart = struct.pack('@i', 42)
+        creator = FakeCreator([('s', b'SU4\0\0\0\0\0\0'), ('r', b'ES' +
+            intpart[:1]), ('r', intpart[1:])])
+        parser = Parser(creator)
+        try:
+            parser.get_socket(IPAddr('0.0.0.0'), 0, 'UDP')
+            self.fail("Not raised")
+        except CreatorError as ce:
+            self.assertTrue(creator.all_used())
+            # Is the exception correct?
+            self.assertFalse(ce.fatal)
+            self.assertEqual(42, ce.errno)
+
+    def __error(self, plan):
+        creator = FakeCreator(plan)
+        parser = Parser(creator)
+        try:
+            parser.get_socket(IPAddr('0.0.0.0'), 0, socket.SOCK_DGRAM)
+            self.fail("Not raised")
+        except CreatorError as ce:
+            self.assertTrue(creator.all_used())
+            self.assertTrue(ce.fatal)
+
+    def test_error_send(self):
+        self.__error([('e', None)])
+
+    def test_error_recv(self):
+        self.__error([('s', b'SU4\0\0\0\0\0\0'), ('e', None)])
+
+    def test_error_read_fd(self):
+        self.__error([('s', b'SU4\0\0\0\0\0\0'), ('r', b'S'), ('e', None)])
+
+    def __create(self, addr, socktype, encoded):
+        creator = FakeCreator([('s', b'S' + encoded), ('r', b'S'), ('f', 42)])
+        parser = Parser(creator)
+        self.assertEqual(42, parser.get_socket(IPAddr(addr), 42, socktype))
+
+    def test_create1(self):
+        self.__create('192.0.2.0', 'UDP', b'U4\0\x2A\xC0\0\x02\0')
+
+    def test_create2(self):
+        self.__create('2001:db8::', socket.SOCK_STREAM,
+            b'T6\0\x2A\x20\x01\x0d\xb8\0\0\0\0\0\0\0\0\0\0\0\0')
+
+    def test_create_terminated(self):
+        """
+        Test we can't request sockets after it was terminated.
+        """
+        parser = self.__terminate()
+        try:
+            parser.get_socket(IPAddr('0.0.0.0'), 0, 'UDP')
+            self.fail("Not raised")
+        except CreatorError as ce:
+            self.assertTrue(ce.fatal)
+            self.assertEqual(None, ce.errno)
+
+    def test_invalid_socktype(self):
+        """
+        Test invalid socket type is rejected
+        """
+        self.assertRaises(ValueError, Parser(FakeCreator([])).get_socket,
+                          IPAddr('0.0.0.0'), 42, 'RAW')
+
+    def test_invalid_family(self):
+        """
+        Test it rejects invalid address family.
+        """
+        # Note: this produces a bad logger output, since this address
+        # can not be converted to string, so the original message with
+        # placeholders is output. This should not happen in practice, so
+        # it is harmless.
+        addr = IPAddr('0.0.0.0')
+        addr.family = 42
+        self.assertRaises(ValueError, Parser(FakeCreator([])).get_socket,
+                          addr, 42, socket.SOCK_DGRAM)
+
+class WrapTests(unittest.TestCase):
+    """
+    Tests for the wrap_socket function.
+    """
+    def test_wrap(self):
+        # We construct two pairs of socket. The receiving side of one pair will
+        # be wrapped. Then we send one of the other pair through this pair and
+        # check the received one can be used as a socket
+
+        # The transport socket
+        (t1, t2) = socket.socketpair()
+        # The payload socket
+        (p1, p2) = socket.socketpair()
+
+        t2 = WrappedSocket(t2)
+
+        # Transfer the descriptor
+        send_fd(t1.fileno(), p1.fileno())
+        p1 = socket.fromfd(t2.read_fd(), socket.AF_UNIX, socket.SOCK_STREAM)
+
+        # Now, pass some data trough the socket
+        p1.send(b'A')
+        data = p2.recv(1)
+        self.assertEqual(b'A', data)
+
+        # Test the wrapping didn't hurt the socket's usual methods
+        t1.send(b'B')
+        data = t2.recv(1)
+        self.assertEqual(b'B', data)
+        t2.send(b'C')
+        data = t1.recv(1)
+        self.assertEqual(b'C', data)
+
+if __name__ == '__main__':
+    isc.log.init("bind10") # FIXME Should this be needed?
+    isc.log.resetUnitTestRootLogger()
+    unittest.main()

+ 15 - 6
src/bin/bindctl/bindcmd.py

@@ -398,6 +398,8 @@ class BindCmdInterpreter(Cmd):
                 print("Error: " + str(dte))
             except isc.cc.data.DataNotFoundError as dnfe:
                 print("Error: " + str(dnfe))
+            except isc.cc.data.DataAlreadyPresentError as dape:
+                print("Error: " + str(dape))
             except KeyError as ke:
                 print("Error: missing " + str(ke))
         else:
@@ -634,7 +636,15 @@ class BindCmdInterpreter(Cmd):
                     # we have more data to show
                     line += "/"
                 else:
-                    line += "\t" + json.dumps(value_map['value'])
+                    # if type is named_set, don't print value if None
+                    # (it is either {} meaning empty, or None, meaning
+                    # there actually is data, but not to be shown with
+                    # the current command
+                    if value_map['type'] == 'named_set' and\
+                       value_map['value'] is None:
+                        line += "/\t"
+                    else:
+                        line += "\t" + json.dumps(value_map['value'])
                 line += "\t" + value_map['type']
                 line += "\t"
                 if value_map['default']:
@@ -649,10 +659,9 @@ class BindCmdInterpreter(Cmd):
                 data, default = self.config_data.get_value(identifier)
                 print(json.dumps(data))
         elif cmd.command == "add":
-            if 'value' in cmd.params:
-                self.config_data.add_value(identifier, cmd.params['value'])
-            else:
-                self.config_data.add_value(identifier)
+            self.config_data.add_value(identifier,
+                                       cmd.params.get('value_or_name'),
+                                       cmd.params.get('value_for_set'))
         elif cmd.command == "remove":
             if 'value' in cmd.params:
                 self.config_data.remove_value(identifier, cmd.params['value'])
@@ -679,7 +688,7 @@ class BindCmdInterpreter(Cmd):
             except isc.config.ModuleCCSessionError as mcse:
                 print(str(mcse))
         elif cmd.command == "diff":
-            print(self.config_data.get_local_changes());
+            print(self.config_data.get_local_changes())
         elif cmd.command == "go":
             self.go(identifier)
 

+ 15 - 4
src/bin/bindctl/bindctl_main.py.in

@@ -50,17 +50,28 @@ def prepare_config_commands(tool):
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "add", desc = "Add an entry to configuration list. If no value is given, a default value is added.")
+    cmd = CommandInfo(name = "add", desc =
+        "Add an entry to configuration list or a named set. "
+        "When adding to a list, the command has one optional argument, "
+        "a value to add to the list. The value must be in correct JSON "
+        "and complete. When adding to a named set, it has one "
+        "mandatory parameter (the name to add), and an optional "
+        "parameter value, similar to when adding to a list. "
+        "In either case, when no value is given, an entry will be "
+        "constructed with default values.")
     param = ParamInfo(name = "identifier", type = "string", optional=True, desc = DEFAULT_IDENTIFIER_DESC)
     cmd.add_param(param)
-    param = ParamInfo(name = "value", type = "string", optional=True, desc = "Specifies a value to add to the list. It must be in correct JSON format and complete.")
+    param = ParamInfo(name = "value_or_name", type = "string", optional=True, desc = "Specifies a value to add to the list, or the name when adding to a named set. It must be in correct JSON format and complete.")
+    cmd.add_param(param)
+    module.add_command(cmd)
+    param = ParamInfo(name = "value_for_set", type = "string", optional=True, desc = "Specifies an optional value to add to the named map. It must be in correct JSON format and complete.")
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "remove", desc = "Remove entry from configuration list.")
+    cmd = CommandInfo(name = "remove", desc = "Remove entry from configuration list or named set.")
     param = ParamInfo(name = "identifier", type = "string", optional=True, desc = DEFAULT_IDENTIFIER_DESC)
     cmd.add_param(param)
-    param = ParamInfo(name = "value", type = "string", optional=True, desc = "Specifies a value to remove from the list. It must be in correct JSON format and complete.")
+    param = ParamInfo(name = "value", type = "string", optional=True, desc = "When identifier is a list, specifies a value to remove from the list. It must be in correct JSON format and complete. When it is a named set, specifies the name to remove.")
     cmd.add_param(param)
     module.add_command(cmd)
 

+ 2 - 0
src/bin/cmdctl/tests/cmdctl_test.py

@@ -19,6 +19,7 @@ import socket
 import tempfile
 import sys
 from cmdctl import *
+import isc.log
 
 SPEC_FILE_PATH = '..' + os.sep
 if 'CMDCTL_SPEC_PATH' in os.environ:
@@ -447,6 +448,7 @@ class TestFuncNotInClass(unittest.TestCase):
 
 
 if __name__== "__main__":
+    isc.log.resetUnitTestRootLogger()
     unittest.main()
 
 

+ 2 - 2
src/bin/dhcp6/tests/Makefile.am

@@ -8,14 +8,14 @@ EXTRA_DIST = $(PYTESTS)
 # required by loadable python modules.
 LIBRARY_PATH_PLACEHOLDER =
 if SET_ENV_LIBRARY_PATH
-LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$$$(ENV_LIBRARY_PATH)
+LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
 endif
 
 # test using command-line arguments, so use check-local target instead of TESTS
 check-local:
 	for pytest in $(PYTESTS) ; do \
 	echo Running test: $$pytest ; \
-	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_builddir)/src/bin/bind10 \
+	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_srcdir)/src/bin:$(abs_top_builddir)/src/bin/bind10:$(abs_top_builddir)/src/lib/util/io/.libs \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	BIND10_MSGQ_SOCKET_FILE=$(abs_top_builddir)/msgq_socket \
 		$(PYCOVERAGE_RUN) $(abs_srcdir)/$$pytest || exit ; \

+ 1 - 1
src/bin/dhcp6/tests/dhcp6_test.py

@@ -13,7 +13,7 @@
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-from bind10 import ProcessInfo, parse_args, dump_pid, unlink_pid_file, _BASETIME
+from bind10_src import ProcessInfo, parse_args, dump_pid, unlink_pid_file, _BASETIME
 
 import unittest
 import sys

+ 2 - 1
src/bin/resolver/resolver.cc

@@ -520,7 +520,8 @@ ResolverImpl::processNormalQuery(const IOMessage& io_message,
     const Client client(io_message);
     const BasicAction query_action(
         getQueryACL().execute(acl::dns::RequestContext(
-                                  client.getRequestSourceIPAddress())));
+                                  client.getRequestSourceIPAddress(),
+                                  query_message->getTSIGRecord())));
     if (query_action == isc::acl::REJECT) {
         LOG_INFO(resolver_logger, RESOLVER_QUERY_REJECTED)
             .arg(question->getName()).arg(qtype).arg(qclass).arg(client);

+ 2 - 1
src/bin/resolver/tests/resolver_config_unittest.cc

@@ -72,7 +72,8 @@ protected:
                                           IOSocket::getDummyUDPSocket(),
                                           *endpoint));
         client.reset(new Client(*query_message));
-        request.reset(new RequestContext(client->getRequestSourceIPAddress()));
+        request.reset(new RequestContext(client->getRequestSourceIPAddress(),
+                                         NULL));
         return (*request);
     }
     void invalidTest(const string &JSON, const string& name);

+ 1 - 1
src/bin/sockcreator/README

@@ -3,7 +3,7 @@ The socket creator
 
 The only thing we need higher rights than standard user is binding sockets to
 ports lower than 1024. So we will have a separate process that keeps the
-rights, while the rests drop them for security reasons.
+rights, while the rest drops them for security reasons.
 
 This process is the socket creator. Its goal is to be as simple as possible
 and to contain as little code as possible to minimise the amount of code

+ 16 - 2
src/bin/stats/stats_httpd.py.in

@@ -385,7 +385,14 @@ class StatsHttpd:
             annotation.append(documentation)
             element.append(annotation)
             xsd_root.append(element)
-        xsd_string = xml.etree.ElementTree.tostring(xsd_root)
+        # The coding conversion is tricky. xml..tostring() of Python 3.2
+        # returns bytes (not string) regardless of the coding, while
+        # tostring() of Python 3.1 returns a string.  To support both
+        # cases transparently, we first make sure tostring() returns
+        # bytes by specifying utf-8 and then convert the result to a
+        # plain string (code below assume it).
+        xsd_string = str(xml.etree.ElementTree.tostring(xsd_root, encoding='utf-8'),
+                         encoding='us-ascii')
         self.xsd_body = self.open_template(XSD_TEMPLATE_LOCATION).substitute(
             xsd_string=xsd_string,
             xsd_namespace=XSD_NAMESPACE
@@ -410,7 +417,14 @@ class StatsHttpd:
             tr.append(td1)
             tr.append(td2)
             xsd_root.append(tr)
-        xsl_string = xml.etree.ElementTree.tostring(xsd_root)
+        # The coding conversion is tricky. xml..tostring() of Python 3.2
+        # returns bytes (not string) regardless of the coding, while
+        # tostring() of Python 3.1 returns a string.  To support both
+        # cases transparently, we first make sure tostring() returns
+        # bytes by specifying utf-8 and then convert the result to a
+        # plain string (code below assume it).
+        xsl_string = str(xml.etree.ElementTree.tostring(xsd_root, encoding='utf-8'),
+                         encoding='us-ascii')
         self.xsl_body = self.open_template(XSL_TEMPLATE_LOCATION).substitute(
             xsl_string=xsl_string,
             xsd_namespace=XSD_NAMESPACE)

+ 89 - 0
src/bin/stats/tests/b10-stats-httpd_test.py

@@ -402,6 +402,95 @@ class TestStatsHttpd(unittest.TestCase):
             )
         self.assertEqual(ret, 1)
 
+    def test_xml_handler(self):
+        orig_get_stats_data = stats_httpd.StatsHttpd.get_stats_data
+        stats_httpd.StatsHttpd.get_stats_data = lambda x: {'foo':'bar'}
+        xml_body1 = stats_httpd.StatsHttpd().open_template(
+            stats_httpd.XML_TEMPLATE_LOCATION).substitute(
+            xml_string='<foo>bar</foo>',
+            xsd_namespace=stats_httpd.XSD_NAMESPACE,
+            xsd_url_path=stats_httpd.XSD_URL_PATH,
+            xsl_url_path=stats_httpd.XSL_URL_PATH)
+        xml_body2 = stats_httpd.StatsHttpd().xml_handler()
+        self.assertEqual(type(xml_body1), str)
+        self.assertEqual(type(xml_body2), str)
+        self.assertEqual(xml_body1, xml_body2)
+        stats_httpd.StatsHttpd.get_stats_data = lambda x: {'bar':'foo'}
+        xml_body2 = stats_httpd.StatsHttpd().xml_handler()
+        self.assertNotEqual(xml_body1, xml_body2)
+        stats_httpd.StatsHttpd.get_stats_data = orig_get_stats_data
+
+    def test_xsd_handler(self):
+        orig_get_stats_spec = stats_httpd.StatsHttpd.get_stats_spec
+        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+            [{
+                "item_name": "foo",
+                "item_type": "string",
+                "item_optional": False,
+                "item_default": "bar",
+                "item_description": "foo is bar",
+                "item_title": "Foo"
+               }]
+        xsd_body1 = stats_httpd.StatsHttpd().open_template(
+            stats_httpd.XSD_TEMPLATE_LOCATION).substitute(
+            xsd_string='<all>' \
+                + '<element maxOccurs="1" minOccurs="1" name="foo" type="string">' \
+                + '<annotation><appinfo>Foo</appinfo>' \
+                + '<documentation>foo is bar</documentation>' \
+                + '</annotation></element></all>',
+            xsd_namespace=stats_httpd.XSD_NAMESPACE)
+        xsd_body2 = stats_httpd.StatsHttpd().xsd_handler()
+        self.assertEqual(type(xsd_body1), str)
+        self.assertEqual(type(xsd_body2), str)
+        self.assertEqual(xsd_body1, xsd_body2)
+        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+            [{
+                "item_name": "bar",
+                "item_type": "string",
+                "item_optional": False,
+                "item_default": "foo",
+                "item_description": "bar is foo",
+                "item_title": "bar"
+               }]
+        xsd_body2 = stats_httpd.StatsHttpd().xsd_handler()
+        self.assertNotEqual(xsd_body1, xsd_body2)
+        stats_httpd.StatsHttpd.get_stats_spec = orig_get_stats_spec
+
+    def test_xsl_handler(self):
+        orig_get_stats_spec = stats_httpd.StatsHttpd.get_stats_spec
+        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+            [{
+                "item_name": "foo",
+                "item_type": "string",
+                "item_optional": False,
+                "item_default": "bar",
+                "item_description": "foo is bar",
+                "item_title": "Foo"
+               }]
+        xsl_body1 = stats_httpd.StatsHttpd().open_template(
+            stats_httpd.XSL_TEMPLATE_LOCATION).substitute(
+            xsl_string='<xsl:template match="*"><tr>' \
+                + '<td class="title" title="foo is bar">Foo</td>' \
+                + '<td><xsl:value-of select="foo" /></td>' \
+                + '</tr></xsl:template>',
+            xsd_namespace=stats_httpd.XSD_NAMESPACE)
+        xsl_body2 = stats_httpd.StatsHttpd().xsl_handler()
+        self.assertEqual(type(xsl_body1), str)
+        self.assertEqual(type(xsl_body2), str)
+        self.assertEqual(xsl_body1, xsl_body2)
+        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+            [{
+                "item_name": "bar",
+                "item_type": "string",
+                "item_optional": False,
+                "item_default": "foo",
+                "item_description": "bar is foo",
+                "item_title": "bar"
+               }]
+        xsl_body2 = stats_httpd.StatsHttpd().xsl_handler()
+        self.assertNotEqual(xsl_body1, xsl_body2)
+        stats_httpd.StatsHttpd.get_stats_spec = orig_get_stats_spec
+
     def test_for_without_B10_FROM_SOURCE(self):
         # just lets it go through the code without B10_FROM_SOURCE env
         # variable

+ 2 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -18,6 +18,7 @@ import socket
 import io
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from xfrin import *
+import isc.log
 
 #
 # Commonly used (mostly constant) test parameters
@@ -1115,6 +1116,7 @@ class TestMain(unittest.TestCase):
 
 if __name__== "__main__":
     try:
+        isc.log.resetUnitTestRootLogger()
         unittest.main()
     except KeyboardInterrupt as e:
         print(e)

+ 1 - 1
src/bin/xfrout/tests/Makefile.am

@@ -6,7 +6,7 @@ EXTRA_DIST = $(PYTESTS)
 # required by loadable python modules.
 LIBRARY_PATH_PLACEHOLDER =
 if SET_ENV_LIBRARY_PATH
-LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/dns/.libs:$(abs_top_builddir)/src/lib/cryptolink/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
+LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/dns/.libs:$(abs_top_builddir)/src/lib/cryptolink/.libs:$(abs_top_builddir)/src/lib/acl/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
 endif
 
 # test using command-line arguments, so use check-local target instead of TESTS

+ 149 - 22
src/bin/xfrout/tests/xfrout_test.py.in

@@ -23,6 +23,8 @@ from isc.cc.session import *
 from pydnspp import *
 from xfrout import *
 import xfrout
+import isc.log
+import isc.acl.dns
 
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 
@@ -116,8 +118,11 @@ class TestXfroutSession(unittest.TestCase):
 
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
-        #self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), TSIGKeyRing())
+        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
+                                       TSIGKeyRing(), ('127.0.0.1', 12345),
+                                       # When not testing ACLs, simply accept
+                                       isc.acl.dns.REQUEST_LOADER.load(
+                                           [{"action": "ACCEPT"}]))
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
@@ -132,11 +137,90 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rcode.to_text(), "NOTAUTH")
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
         # NOERROR
-        self.xfrsess._tsig_key_ring.add(TSIG_KEY)
+        self.assertEqual(TSIGKeyRing.SUCCESS,
+                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
 
+        # ACL checks, put some ACL inside
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+            {
+                "from": "127.0.0.1",
+                "action": "ACCEPT"
+            },
+            {
+                "from": "192.0.2.1",
+                "action": "DROP"
+            }
+        ])
+        # Localhost (the default in this test) is accepted
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "NOERROR")
+        # This should be dropped completely, therefore returning None
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(None, rcode)
+        # This should be refused, therefore REFUSED
+        self.xfrsess._remote = ('192.0.2.2', 12345)
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+        # If the TSIG check fails, it should not check ACL
+        # (If it checked ACL as well, it would just drop the request)
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        self.xfrsess._tsig_key_ring = TSIGKeyRing()
+        rcode, msg = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOTAUTH")
+        self.assertTrue(self.xfrsess._tsig_ctx is not None)
+
+        # ACL using TSIG: successful case
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
+        ])
+        self.assertEqual(TSIGKeyRing.SUCCESS,
+                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOERROR")
+
+        # ACL using TSIG: key name doesn't match; should be rejected
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
+        ])
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+
+        # ACL using TSIG: no TSIG; should be rejected
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
+        ])
+        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+
+        #
+        # ACL using IP + TSIG: both should match
+        #
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
+                 "action": "ACCEPT"},
+                {"action": "REJECT"}
+        ])
+        # both matches
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOERROR")
+        # TSIG matches, but address doesn't
+        self.xfrsess._remote = ('192.0.2.2', 12345)
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+        # Address matches, but TSIG doesn't (not included)
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+        # Neither address nor TSIG matches
+        self.xfrsess._remote = ('192.0.2.2', 12345)
+        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
@@ -195,20 +279,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(msg.get_rcode(), rcode)
         self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
 
-    def test_reply_query_with_format_error(self):
-        msg = self.getmsg()
-        self.xfrsess._reply_query_with_format_error(msg, self.sock)
-        get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
-
-        # tsig signed message
-        msg = self.getmsg()
-        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
-        self.xfrsess._reply_query_with_format_error(msg, self.sock)
-        get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
-        self.assertTrue(self.message_has_tsig(get_msg))
-
     def test_create_rrset_from_db_record(self):
         rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(rrset.get_name().to_text(), "example.com.")
@@ -515,18 +585,42 @@ class MyCCSession():
 
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
-        self._lock = threading.Lock()
-        self._transfers_counter = 0
         self._shutdown_event = threading.Event()
         self._max_transfers_out = 10
         self._cc = MyCCSession()
-        #self._log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
+        self._common_init()
 
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
         self.write_sock, self.read_sock = socket.socketpair()
         self.unix = MyUnixSockServer()
 
+    def test_guess_remote(self):
+        """Test we can guess the remote endpoint when we have only the
+           file descriptor. This is needed, because we get only that one
+           from auth."""
+        # We test with UDP, as it can be "connected" without other
+        # endpoint
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.connect(('127.0.0.1', 12345))
+        self.assertEqual(('127.0.0.1', 12345),
+                         self.unix._guess_remote(sock.fileno()))
+        if socket.has_ipv6:
+            # Don't check IPv6 address on hosts not supporting them
+            sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+            sock.connect(('::1', 12345))
+            self.assertEqual(('::1', 12345, 0, 0),
+                             self.unix._guess_remote(sock.fileno()))
+            # Try when pretending there's no IPv6 support
+            # (No need to pretend when there's really no IPv6)
+            xfrout.socket.has_ipv6 = False
+            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            sock.connect(('127.0.0.1', 12345))
+            self.assertEqual(('127.0.0.1', 12345),
+                             self.unix._guess_remote(sock.fileno()))
+            # Return it back
+            xfrout.socket.has_ipv6 = True
+
     def test_receive_query_message(self):
         send_msg = b"\xd6=\x00\x00\x00\x01\x00"
         msg_len = struct.pack('H', socket.htons(len(send_msg)))
@@ -535,15 +629,37 @@ class TestUnixSockServer(unittest.TestCase):
         recv_msg = self.unix._receive_query_message(self.read_sock)
         self.assertEqual(recv_msg, send_msg)
 
-    def test_updata_config_data(self):
+    def check_default_ACL(self):
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
+                                             1234, 0, socket.SOCK_DGRAM,
+                                             socket.IPPROTO_UDP,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
+
+    def check_loaded_ACL(self):
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
+                                             1234, 0, socket.SOCK_DGRAM,
+                                             socket.IPPROTO_UDP,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("192.0.2.1",
+                                             1234, 0, socket.SOCK_DGRAM,
+                                             socket.IPPROTO_UDP,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.REJECT, self.unix._acl.execute(context))
+
+    def test_update_config_data(self):
+        self.check_default_ACL()
         tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
         tsig_key_list = [tsig_key_str]
         bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)
         self.assertTrue(self.unix.tsig_key_ring is not None)
+        self.check_default_ACL()
 
-        self.unix.update_config_data({'transfers_out':9, 'tsig_key_ring':tsig_key_list})
+        self.unix.update_config_data({'transfers_out':9,
+                                      'tsig_key_ring':tsig_key_list})
         self.assertEqual(self.unix._max_transfers_out, 9)
         self.assertEqual(self.unix.tsig_key_ring.size(), 1)
         self.unix.tsig_key_ring.remove(Name("example.com."))
@@ -554,6 +670,16 @@ class TestUnixSockServer(unittest.TestCase):
         self.assertRaises(None, self.unix.update_config_data(config_data))
         self.assertEqual(self.unix.tsig_key_ring.size(), 0)
 
+        # Load the ACL
+        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
+                                               'action': 'ACCEPT'}]})
+        self.check_loaded_ACL()
+        # Pass a wrong data there and check it does not replace the old one
+        self.assertRaises(isc.acl.acl.LoaderError,
+                          self.unix.update_config_data,
+                          {'query_acl': ['Something bad']})
+        self.check_loaded_ACL()
+
     def test_get_db_file(self):
         self.assertEqual(self.unix.get_db_file(), "initdb.file")
 
@@ -670,4 +796,5 @@ class TestInitialization(unittest.TestCase):
         self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")
 
 if __name__== "__main__":
+    isc.log.resetUnitTestRootLogger()
     unittest.main()

+ 70 - 29
src/bin/xfrout/xfrout.py.in

@@ -48,6 +48,9 @@ except ImportError as e:
     # must keep running, so we warn about it and move forward.
     log.error(XFROUT_IMPORT, str(e))
 
+from isc.acl.acl import ACCEPT, REJECT, DROP
+from isc.acl.dns import REQUEST_LOADER
+
 isc.util.process.rename()
 
 def init_paths():
@@ -92,16 +95,16 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
-    def __init__(self, sock_fd, request_data, server, tsig_key_ring):
-        # The initializer for the superclass may call functions
-        # that need _log to be set, so we set it first
+    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
+                 acl):
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
-        #self._log = log
         self._tsig_key_ring = tsig_key_ring
         self._tsig_ctx = None
         self._tsig_len = 0
+        self._remote = remote
+        self._acl = acl
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -114,7 +117,7 @@ class XfroutSession():
             self.dns_xfrout_start(self._sock_fd, self._request_data)
             #TODO, avoid catching all exceptions
         except Exception as e:
-            logger.error(XFROUT_HANDLE_QUERY_ERROR, str(e))
+            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
             pass
 
         os.close(self._sock_fd)
@@ -141,8 +144,26 @@ class XfroutSession():
             # TSIG related checks
             rcode = self._check_request_tsig(msg, mdata)
 
+            if rcode == Rcode.NOERROR():
+                # ACL checks
+                acl_result = self._acl.execute(
+                    isc.acl.dns.RequestContext(self._remote,
+                                               msg.get_tsig_record()))
+                if acl_result == DROP:
+                    logger.info(XFROUT_QUERY_DROPPED,
+                                self._get_query_zone_name(msg),
+                                self._get_query_zone_class(msg),
+                                self._remote[0], self._remote[1])
+                    return None, None
+                elif acl_result == REJECT:
+                    logger.info(XFROUT_QUERY_REJECTED,
+                                self._get_query_zone_name(msg),
+                                self._get_query_zone_class(msg),
+                                self._remote[0], self._remote[1])
+                    return Rcode.REFUSED(), msg
+
         except Exception as err:
-            logger.error(XFROUT_PARSE_QUERY_ERROR, str(err))
+            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             return Rcode.FORMERR(), None
 
         return rcode, msg
@@ -183,18 +204,11 @@ class XfroutSession():
 
 
     def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
-        msg.make_response()
-        msg.set_rcode(rcode_)
-        self._send_message(sock_fd, msg, self._tsig_ctx)
-
-
-    def _reply_query_with_format_error(self, msg, sock_fd):
-        '''query message format isn't legal.'''
         if not msg:
             return # query message is invalid. send nothing back.
 
         msg.make_response()
-        msg.set_rcode(Rcode.FORMERR())
+        msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
     def _zone_has_soa(self, zone):
@@ -244,10 +258,13 @@ class XfroutSession():
     def dns_xfrout_start(self, sock_fd, msg_query):
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
-        if rcode_ == Rcode.NOTAUTH():
+        if rcode_ is None: # Dropped by ACL
+            return
+        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
         elif rcode_ != Rcode.NOERROR():
-            return self._reply_query_with_format_error(msg, sock_fd)
+            return self._reply_query_with_error_rcode(msg, sock_fd,
+                                                      Rcode.FORMERR())
 
         zone_name = self._get_query_zone_name(msg)
         zone_class_str = self._get_query_zone_class(msg)
@@ -257,7 +274,7 @@ class XfroutSession():
         if rcode_ != Rcode.NOERROR():
             logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
                         zone_class_str, rcode_.to_text())
-            return self. _reply_query_with_error_rcode(msg, sock_fd, rcode_)
+            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
         try:
             logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
@@ -375,14 +392,20 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._sock_file = sock_file
         socketserver_mixin.NoPollMixIn.__init__(self)
         ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
-        self._lock = threading.Lock()
-        self._transfers_counter = 0
         self._shutdown_event = shutdown_event
         self._write_sock, self._read_sock = socket.socketpair()
-        #self._log = log
+        self._common_init()
         self.update_config_data(config_data)
         self._cc = cc
 
+    def _common_init(self):
+        self._lock = threading.Lock()
+        self._transfers_counter = 0
+        # This default value will probably get overwritten by the (same)
+        # default value from the spec file. This is here just to make
+        # sure and to make the default value in tests consistent.
+        self._acl = REQUEST_LOADER.load('[{"action": "ACCEPT"}]')
+
     def _receive_query_message(self, sock):
         ''' receive request message from sock'''
         # receive data length
@@ -465,10 +488,28 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             t.daemon = True
         t.start()
 
+    def _guess_remote(self, sock_fd):
+        """
+           Guess remote address and port of the socket. The sock_fd must be a
+           socket
+        """
+        # This uses a trick. If the socket is IPv4 in reality and we pretend
+        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
+        # to care about the SOCK_STREAM parameter at all (which it really is,
+        # except for testing)
+        if socket.has_ipv6:
+            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
+        else:
+            # To make it work even on hosts without IPv6 support
+            # (Any idea how to simulate this in test?)
+            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
+        return sock.getpeername()
 
     def finish_request(self, sock_fd, request_data):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(sock_fd, request_data, self, self.tsig_key_ring)
+        self.RequestHandlerClass(sock_fd, request_data, self,
+                                 self.tsig_key_ring,
+                                 self._guess_remote(sock_fd), self._acl)
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
@@ -512,6 +553,8 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     def update_config_data(self, new_config):
         '''Apply the new config setting of xfrout module. '''
         logger.info(XFROUT_NEW_CONFIG)
+        if 'query_acl' in new_config:
+            self._acl = REQUEST_LOADER.load(new_config['query_acl'])
         self._lock.acquire()
         self._max_transfers_out = new_config.get('transfers_out')
         self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
@@ -563,16 +606,12 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
 class XfroutServer:
     def __init__(self):
         self._unix_socket_server = None
-        #self._log = None
         self._listen_sock_file = UNIX_SOCKET_FILE
         self._shutdown_event = threading.Event()
         self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
         self._config_data = self._cc.get_full_config()
         self._cc.start()
         self._cc.add_remote_config(AUTH_SPECFILE_LOCATION);
-        #self._log = isc.log.NSLogger(self._config_data.get('log_name'), self._config_data.get('log_file'),
-        #                        self._config_data.get('log_severity'), self._config_data.get('log_versions'),
-        #                        self._config_data.get('log_max_bytes'), True)
         self._start_xfr_query_listener()
         self._start_notifier()
 
@@ -601,11 +640,13 @@ class XfroutServer:
                 continue
             self._config_data[key] = new_config[key]
 
-        #if self._log:
-        #    self._log.update_config(new_config)
-
         if self._unix_socket_server:
-            self._unix_socket_server.update_config_data(self._config_data)
+            try:
+                self._unix_socket_server.update_config_data(self._config_data)
+            except Exception as e:
+                answer = create_answer(1,
+                                       "Failed to handle new configuration: " +
+                                       str(e))
 
         return answer
 

+ 19 - 7
src/bin/xfrout/xfrout.spec.pre.in

@@ -16,27 +16,27 @@
        },
        {
          "item_name": "log_file",
-    	 "item_type": "string",
+         "item_type": "string",
          "item_optional": false,
          "item_default": "@@LOCALSTATEDIR@@/@PACKAGE@/log/Xfrout.log"
        },
        {
          "item_name": "log_severity",
-    	 "item_type": "string",
+         "item_type": "string",
          "item_optional": false,
-    	 "item_default": "debug"
+         "item_default": "debug"
        },
        {
          "item_name": "log_versions",
-    	 "item_type": "integer",
+         "item_type": "integer",
          "item_optional": false,
-    	 "item_default": 5
+         "item_default": 5
        },
        {
          "item_name": "log_max_bytes",
-    	 "item_type": "integer",
+         "item_type": "integer",
          "item_optional": false,
-    	 "item_default": 1048576
+         "item_default": 1048576
        },
        {
          "item_name": "tsig_key_ring",
@@ -49,6 +49,18 @@
              "item_type": "string",
              "item_optional": true
          }
+       },
+       {
+         "item_name": "query_acl",
+         "item_type": "list",
+         "item_optional": false,
+         "item_default": [{"action": "ACCEPT"}],
+         "list_item_spec":
+         {
+             "item_name": "acl_element",
+             "item_type": "any",
+             "item_optional": true
+         }
        }
       ],
       "commands": [

+ 11 - 0
src/bin/xfrout/xfrout_messages.mes

@@ -95,6 +95,17 @@ in the log message, but at this point no specific information other
 than that could be given. This points to incomplete exception handling
 in the code.
 
+% XFROUT_QUERY_DROPPED request to transfer %1/%2 to [%3]:%4 dropped
+The xfrout process silently dropped a request to transfer zone to given host.
+This is required by the ACLs. The %1 and %2 represent the zone name and class,
+the %3 and %4 the IP address and port of the peer requesting the transfer.
+
+% XFROUT_QUERY_REJECTED request to transfer %1/%2 to [%3]:%4 rejected
+The xfrout process rejected (by REFUSED rcode) a request to transfer zone to
+given host. This is because of ACLs. The %1 and %2 represent the zone name and
+class, the %3 and %4 the IP address and port of the peer requesting the
+transfer.
+
 % XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR error receiving the file descriptor for an XFR connection
 There was an error receiving the file descriptor for the transfer
 request. Normally, the request is received by b10-auth, and passed on

+ 1 - 1
src/lib/acl/Makefile.am

@@ -19,7 +19,7 @@ libacl_la_LIBADD += $(top_builddir)/src/lib/util/libutil.la
 # DNS specialized one
 lib_LTLIBRARIES += libdnsacl.la
 
-libdnsacl_la_SOURCES = dns.h dns.cc
+libdnsacl_la_SOURCES = dns.h dns.cc dnsname_check.h
 
 libdnsacl_la_LIBADD = libacl.la
 libdnsacl_la_LIBADD += $(top_builddir)/src/lib/dns/libdns++.la

+ 22 - 3
src/lib/acl/dns.cc

@@ -20,15 +20,20 @@
 
 #include <exceptions/exceptions.h>
 
+#include <dns/name.h>
+#include <dns/tsigrecord.h>
+
 #include <cc/data.h>
 
 #include <acl/dns.h>
 #include <acl/ip_check.h>
+#include <acl/dnsname_check.h>
 #include <acl/loader.h>
 #include <acl/logic_check.h>
 
 using namespace std;
 using boost::shared_ptr;
+using namespace isc::dns;
 using namespace isc::data;
 
 namespace isc {
@@ -39,9 +44,6 @@ namespace acl {
 /// It returns \c true if the remote (source) IP address of the request
 /// matches the expression encapsulated in the \c IPCheck, and returns
 /// \c false if not.
-///
-/// \note The match logic is expected to be extended as we add
-/// more match parameters (at least there's a plan for TSIG key).
 template <>
 bool
 IPCheck<dns::RequestContext>::matches(
@@ -53,6 +55,18 @@ IPCheck<dns::RequestContext>::matches(
 
 namespace dns {
 
+/// The specialization of \c NameCheck for access control with
+/// \c RequestContext.
+///
+/// It returns \c true if the request contains a TSIG record and its key
+/// (owner) name is equal to the name stored in the check; otherwise
+/// it returns \c false.
+template<>
+bool
+NameCheck<RequestContext>::matches(const RequestContext& request) const {
+    return (request.tsig != NULL && request.tsig->getName() == name_);
+}
+
 vector<string>
 internal::RequestCheckCreator::names() const {
     // Probably we should eventually build this vector in a more
@@ -60,6 +74,7 @@ internal::RequestCheckCreator::names() const {
     // everything.
     vector<string> supported_names;
     supported_names.push_back("from");
+    supported_names.push_back("key");
     return (supported_names);
 }
 
@@ -77,6 +92,10 @@ internal::RequestCheckCreator::create(const string& name,
     if (name == "from") {
         return (shared_ptr<internal::RequestIPCheck>(
                     new internal::RequestIPCheck(definition->stringValue())));
+    } else if (name == "key") {
+        return (shared_ptr<internal::RequestKeyCheck>(
+                    new internal::RequestKeyCheck(
+                        Name(definition->stringValue()))));
     } else {
         // This case shouldn't happen (normally) as it should have been
         // rejected at the loader level.  But we explicitly catch the case

+ 19 - 5
src/lib/acl/dns.h

@@ -23,9 +23,13 @@
 #include <cc/data.h>
 
 #include <acl/ip_check.h>
+#include <acl/dnsname_check.h>
 #include <acl/loader.h>
 
 namespace isc {
+namespace dns {
+class TSIGRecord;
+}
 namespace acl {
 namespace dns {
 
@@ -53,9 +57,9 @@ namespace dns {
  * used only for a very short period as stated above.
  *
  * Based on the minimalist philosophy, the initial implementation only
- * maintains the remote (source) IP address of the request.  The plan is
- * to add more parameters of the request.  A scheduled next step is to
- * support the TSIG key (if it's included in the request).  Other possibilities
+ * maintains the remote (source) IP address of the request and (optionally)
+ * the TSIG record included in the request.  We may add more parameters of
+ * the request as we see the need for them.  Possible additional parameters
  * are the local (destination) IP address, the remote and local port numbers,
  * various fields of the DNS request (e.g. a particular header flag value).
  */
@@ -68,8 +72,12 @@ struct RequestContext {
     /// \exception None
     ///
     /// \parameter remote_address_param The remote IP address
-    explicit RequestContext(const IPAddress& remote_address_param) :
-        remote_address(remote_address_param)
+    /// \parameter tsig_param A valid pointer to the TSIG record included in
+    /// the request or NULL if the request doesn't contain a TSIG.
+    RequestContext(const IPAddress& remote_address_param,
+                   const isc::dns::TSIGRecord* tsig_param) :
+        remote_address(remote_address_param),
+        tsig(tsig_param)
     {}
 
     ///
@@ -83,6 +91,11 @@ struct RequestContext {
     //@{
     /// \brief The remote IP address (eg. the client's IP address).
     const IPAddress& remote_address;
+
+    /// \brief The TSIG record included in the request message, if any.
+    ///
+    /// If the request doesn't include a TSIG, this member will be NULL.
+    const isc::dns::TSIGRecord* const tsig;
     //@}
 };
 
@@ -114,6 +127,7 @@ namespace internal {
 
 // Shortcut typedef
 typedef isc::acl::IPCheck<RequestContext> RequestIPCheck;
+typedef isc::acl::dns::NameCheck<RequestContext> RequestKeyCheck;
 
 class RequestCheckCreator : public acl::Loader<RequestContext>::CheckCreator {
 public:

+ 83 - 0
src/lib/acl/dnsname_check.h

@@ -0,0 +1,83 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#ifndef __DNSNAME_CHECK_H
+#define __DNSNAME_CHECK_H 1
+
+#include <dns/name.h>
+
+#include <acl/check.h>
+
+namespace isc {
+namespace acl {
+namespace dns {
+
+/// ACL check for DNS names
+///
+/// This class is intended to perform a match between a domain name
+/// specified in an ACL and a given name.  The primary usage of this class
+/// is an ACL match for TSIG keys, where an ACL would contain a list of
+/// acceptable key names and the \c match() method would compare the owner
+/// name of a TSIG record against the specified names.
+///
+/// This class could be used for other kinds of names such as the query name
+/// of normal DNS queries.
+///
+/// The class is templated on the type of a context structure passed to the
+/// matches() method, and a template specialisation for that method must be
+/// supplied for the class to be used.
+template <typename Context>
+class NameCheck : public Check<Context> {
+public:
+    /// The constructor
+    ///
+    /// \exception std::bad_alloc Resource allocation fails in copying the
+    /// name
+    ///
+    /// \param name The domain name to be matched in \c matches().
+    NameCheck(const isc::dns::Name& name) : name_(name) {}
+
+    /// Destructor
+    virtual ~NameCheck() {}
+
+    /// The check method
+    ///
+    /// Matches the passed argument to the condition stored here.  Different
+    /// specializations must be provided for different argument types, and the
+    /// program will fail to compile if a required specialisation is not
+    /// provided.
+    ///
+    /// \param context Information to be matched
+    virtual bool matches(const Context& context) const;
+
+    /// Returns the name specified on construction.
+    ///
+    /// This is mainly for testing purposes.
+    ///
+    /// \exception None
+    const isc::dns::Name& getName() const { return (name_); }
+
+private:
+    const isc::dns::Name name_;
+};
+
+} // namespace dns
+} // namespace acl
+} // namespace isc
+
+#endif // __DNSNAME_CHECK_H
+
+// Local Variables:
+// mode: c++
+// End:

+ 2 - 0
src/lib/acl/tests/Makefile.am

@@ -16,6 +16,7 @@ run_unittests_SOURCES += acl_test.cc
 run_unittests_SOURCES += check_test.cc
 run_unittests_SOURCES += dns_test.cc
 run_unittests_SOURCES += ip_check_unittest.cc
+run_unittests_SOURCES += dnsname_check_unittest.cc
 run_unittests_SOURCES += loader_test.cc
 run_unittests_SOURCES += logcheck.h
 run_unittests_SOURCES += creators.h
@@ -30,6 +31,7 @@ run_unittests_LDADD += $(top_builddir)/src/lib/util/unittests/libutil_unittests.
 run_unittests_LDADD += $(top_builddir)/src/lib/acl/libacl.la
 run_unittests_LDADD += $(top_builddir)/src/lib/util/libutil.la
 run_unittests_LDADD += $(top_builddir)/src/lib/cc/libcc.la
+run_unittests_LDADD += $(top_builddir)/src/lib/dns/libdns++.la
 run_unittests_LDADD += $(top_builddir)/src/lib/log/liblog.la
 run_unittests_LDADD += $(top_builddir)/src/lib/exceptions/libexceptions.la
 run_unittests_LDADD += $(top_builddir)/src/lib/acl/libdnsacl.la

+ 76 - 10
src/lib/acl/tests/dns_test.cc

@@ -23,6 +23,11 @@
 
 #include <exceptions/exceptions.h>
 
+#include <dns/name.h>
+#include <dns/tsigkey.h>
+#include <dns/tsigrecord.h>
+#include <dns/rdataclass.h>
+
 #include <cc/data.h>
 #include <acl/dns.h>
 #include <acl/loader.h>
@@ -35,6 +40,8 @@
 
 using namespace std;
 using boost::scoped_ptr;
+using namespace isc::dns;
+using namespace isc::dns::rdata;
 using namespace isc::data;
 using namespace isc::acl;
 using namespace isc::acl::dns;
@@ -64,8 +71,10 @@ protected:
 };
 
 TEST_F(RequestCheckCreatorTest, names) {
-    ASSERT_EQ(1, creator_.names().size());
-    EXPECT_EQ("from", creator_.names()[0]);
+    const vector<string> names = creator_.names();
+    EXPECT_EQ(2, names.size());
+    EXPECT_TRUE(find(names.begin(), names.end(), "from") != names.end());
+    EXPECT_TRUE(find(names.begin(), names.end(), "key") != names.end());
 }
 
 TEST_F(RequestCheckCreatorTest, allowListAbbreviation) {
@@ -93,11 +102,11 @@ TEST_F(RequestCheckCreatorTest, createIPv6Check) {
     check_ = creator_.create("from",
                              Element::fromJSON("\"2001:db8::5300/120\""),
                              getRequestLoader());
-    const dns::internal::RequestIPCheck& ipcheck_ =
+    const dns::internal::RequestIPCheck& ipcheck =
         dynamic_cast<const dns::internal::RequestIPCheck&>(*check_);
-    EXPECT_EQ(AF_INET6, ipcheck_.getFamily());
-    EXPECT_EQ(120, ipcheck_.getPrefixlen());
-    const vector<uint8_t> check_address(ipcheck_.getAddress());
+    EXPECT_EQ(AF_INET6, ipcheck.getFamily());
+    EXPECT_EQ(120, ipcheck.getPrefixlen());
+    const vector<uint8_t> check_address(ipcheck.getAddress());
     ASSERT_EQ(16, check_address.size());
     const uint8_t expected_address[] = { 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00,
                                          0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@@ -106,6 +115,14 @@ TEST_F(RequestCheckCreatorTest, createIPv6Check) {
                       expected_address));
 }
 
+TEST_F(RequestCheckCreatorTest, createTSIGKeyCheck) {
+    check_ = creator_.create("key", Element::fromJSON("\"key.example.com\""),
+                             getRequestLoader());
+    const dns::internal::RequestKeyCheck& keycheck =
+        dynamic_cast<const dns::internal::RequestKeyCheck&>(*check_);
+    EXPECT_EQ(Name("key.example.com"), keycheck.getName());
+}
+
 TEST_F(RequestCheckCreatorTest, badCreate) {
     // Invalid name
     EXPECT_THROW(creator_.create("bad", Element::fromJSON("\"192.0.2.1\""),
@@ -118,12 +135,23 @@ TEST_F(RequestCheckCreatorTest, badCreate) {
     EXPECT_THROW(creator_.create("from", Element::fromJSON("[]"),
                                  getRequestLoader()),
                  isc::data::TypeError);
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("1"),
+                                 getRequestLoader()),
+                 isc::data::TypeError);
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("{}"),
+                                 getRequestLoader()),
+                 isc::data::TypeError);
 
     // Syntax error for IPCheck
     EXPECT_THROW(creator_.create("from", Element::fromJSON("\"bad\""),
                                  getRequestLoader()),
                  isc::InvalidParameter);
 
+    // Syntax error for Name (key) Check
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("\"bad..name\""),
+                                 getRequestLoader()),
+                 EmptyLabel);
+
     // NULL pointer
     EXPECT_THROW(creator_.create("from", ConstElementPtr(), getRequestLoader()),
                  LoaderError);
@@ -140,23 +168,43 @@ protected:
                                 getRequestLoader()));
     }
 
+    // A helper shortcut to create a single Name (key) check for the given
+    // name.
+    ConstRequestCheckPtr createKeyCheck(const string& key_name) {
+        return (creator_.create("key", Element::fromJSON(
+                                    string("\"") + key_name + string("\"")),
+                                getRequestLoader()));
+    }
+
     // create a one time request context for a specific test.  Note that
     // getSockaddr() uses a static storage, so it cannot be called more than
     // once in a single test.
-    const dns::RequestContext& getRequest4() {
+    const dns::RequestContext& getRequest4(const TSIGRecord* tsig = NULL) {
         ipaddr.reset(new IPAddress(tests::getSockAddr("192.0.2.1")));
-        request.reset(new dns::RequestContext(*ipaddr));
+        request.reset(new dns::RequestContext(*ipaddr, tsig));
         return (*request);
     }
-    const dns::RequestContext& getRequest6() {
+    const dns::RequestContext& getRequest6(const TSIGRecord* tsig = NULL) {
         ipaddr.reset(new IPAddress(tests::getSockAddr("2001:db8::1")));
-        request.reset(new dns::RequestContext(*ipaddr));
+        request.reset(new dns::RequestContext(*ipaddr, tsig));
         return (*request);
     }
 
+    // create a one time TSIG Record for a specific test.  The only parameter
+    // of the record that matters is the key name; others are hardcoded with
+    // arbitrarily chosen values.
+    const TSIGRecord* getTSIGRecord(const string& key_name) {
+        tsig_rdata.reset(new any::TSIG(TSIGKey::HMACMD5_NAME(), 0, 0, 0, NULL,
+                                       0, 0, 0, NULL));
+        tsig.reset(new TSIGRecord(Name(key_name), *tsig_rdata));
+        return (tsig.get());
+    }
+
 private:
     scoped_ptr<IPAddress> ipaddr;
     scoped_ptr<dns::RequestContext> request;
+    scoped_ptr<any::TSIG> tsig_rdata;
+    scoped_ptr<TSIGRecord> tsig;
     dns::internal::RequestCheckCreator creator_;
 };
 
@@ -184,6 +232,24 @@ TEST_F(RequestCheckTest, checkIPv6) {
     EXPECT_FALSE(createIPCheck("32.1.13.184")->matches(getRequest6()));
 }
 
+TEST_F(RequestCheckTest, checkTSIGKey) {
+    EXPECT_TRUE(createKeyCheck("key.example.com")->matches(
+                    getRequest4(getTSIGRecord("key.example.com"))));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(
+                     getRequest4(getTSIGRecord("badkey.example.com"))));
+
+    // Same for IPv6 (which shouldn't matter)
+    EXPECT_TRUE(createKeyCheck("key.example.com")->matches(
+                    getRequest6(getTSIGRecord("key.example.com"))));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(
+                     getRequest6(getTSIGRecord("badkey.example.com"))));
+
+    // by default the test request doesn't have a TSIG key, which shouldn't
+    // match any key checks.
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(getRequest4()));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(getRequest6()));
+}
+
 // The following tests test only the creators are registered, they are tested
 // elsewhere
 

+ 59 - 0
src/lib/acl/tests/dnsname_check_unittest.cc

@@ -0,0 +1,59 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#include <gtest/gtest.h>
+
+#include <dns/name.h>
+
+#include <acl/dnsname_check.h>
+
+using namespace isc::dns;
+using namespace isc::acl::dns;
+
+// Provide a specialization of the DNSNameCheck::matches() method.
+namespace isc  {
+namespace acl {
+namespace dns {
+template <>
+bool NameCheck<Name>::matches(const Name& name) const {
+    return (name_ == name);
+}
+} // namespace dns
+} // namespace acl
+} // namespace isc
+
+namespace {
+TEST(DNSNameCheck, construct) {
+    EXPECT_EQ(Name("example.com"),
+              NameCheck<Name>(Name("example.com")).getName());
+
+    // Construct the same check with an explicit trailing dot.  Should result
+    // in the same result.
+    EXPECT_EQ(Name("example.com"),
+              NameCheck<Name>(Name("example.com.")).getName());
+}
+
+TEST(DNSNameCheck, match) {
+    NameCheck<Name> check(Name("example.com"));
+    EXPECT_TRUE(check.matches(Name("example.com")));
+    EXPECT_FALSE(check.matches(Name("example.org")));
+
+    // comparison is case insensitive
+    EXPECT_TRUE(check.matches(Name("EXAMPLE.COM")));
+
+    // this is exact match.  so super/sub domains don't match
+    EXPECT_FALSE(check.matches(Name("com")));
+    EXPECT_FALSE(check.matches(Name("www.example.com")));
+}
+} // Unnamed namespace

+ 2 - 2
src/lib/asiodns/tests/run_unittests.cc

@@ -15,14 +15,14 @@
 #include <gtest/gtest.h>
 #include <util/unittests/run_all.h>
 
-#include <log/logger_manager.h>
+#include <log/logger_support.h>
 #include <dns/tests/unittest_util.h>
 
 int
 main(int argc, char* argv[])
 {
     ::testing::InitGoogleTest(&argc, argv);         // Initialize Google test
-    isc::log::LoggerManager::init("unittest");      // Set a root logger name
+    isc::log::initLogger();                         // Initialize logging
     isc::UnitTestUtil::addDataPath(TEST_DATA_DIR);  // Add location of test data
 
     return (isc::util::unittests::run_all());

+ 2 - 0
src/lib/cc/data.cc

@@ -511,6 +511,8 @@ Element::nameToType(const std::string& type_name) {
         return (Element::list);
     } else if (type_name == "map") {
         return (Element::map);
+    } else if (type_name == "named_set") {
+        return (Element::map);
     } else if (type_name == "null") {
         return (Element::null);
     } else if (type_name == "any") {

+ 21 - 5
src/lib/config/module_spec.cc

@@ -67,10 +67,13 @@ check_config_item(ConstElementPtr spec) {
         check_leaf_item(spec, "list_item_spec", Element::map, true);
         check_config_item(spec->get("list_item_spec"));
     }
-    // todo: add stuff for type map
-    if (Element::nameToType(spec->get("item_type")->stringValue()) == Element::map) {
+
+    if (spec->get("item_type")->stringValue() == "map") {
         check_leaf_item(spec, "map_item_spec", Element::list, true);
         check_config_item_list(spec->get("map_item_spec"));
+    } else if (spec->get("item_type")->stringValue() == "named_set") {
+        check_leaf_item(spec, "named_set_item_spec", Element::map, true);
+        check_config_item(spec->get("named_set_item_spec"));
     }
 }
 
@@ -286,7 +289,8 @@ check_type(ConstElementPtr spec, ConstElementPtr element) {
             return (cur_item_type == "list");
             break;
         case Element::map:
-            return (cur_item_type == "map");
+            return (cur_item_type == "map" ||
+                    cur_item_type == "named_set");
             break;
     }
     return (false);
@@ -323,8 +327,20 @@ ModuleSpec::validateItem(ConstElementPtr spec, ConstElementPtr data,
         }
     }
     if (data->getType() == Element::map) {
-        if (!validateSpecList(spec->get("map_item_spec"), data, full, errors)) {
-            return (false);
+        // either a normal 'map' or a 'named set' (determined by which
+        // subspecification it has)
+        if (spec->contains("map_item_spec")) {
+            if (!validateSpecList(spec->get("map_item_spec"), data, full, errors)) {
+                return (false);
+            }
+        } else {
+            typedef std::pair<std::string, ConstElementPtr> maptype;
+
+            BOOST_FOREACH(maptype m, data->mapValue()) {
+                if (!validateItem(spec->get("named_set_item_spec"), m.second, full, errors)) {
+                    return (false);
+                }
+            }
         }
     }
     return (true);

+ 9 - 0
src/lib/config/tests/module_spec_unittests.cc

@@ -211,3 +211,12 @@ TEST(ModuleSpec, CommandValidation) {
     EXPECT_EQ(errors->get(0)->stringValue(), "Type mismatch");
 
 }
+
+TEST(ModuleSpec, NamedSetValidation) {
+    ModuleSpec dd = moduleSpecFromFile(specfile("spec32.spec"));
+
+    ElementPtr errors = Element::createList();
+    EXPECT_TRUE(dataTestWithErrors(dd, "data32_1.data", errors));
+    EXPECT_FALSE(dataTest(dd, "data32_2.data"));
+    EXPECT_FALSE(dataTest(dd, "data32_3.data"));
+}

+ 4 - 0
src/lib/config/tests/testdata/Makefile.am

@@ -22,6 +22,9 @@ EXTRA_DIST += data22_7.data
 EXTRA_DIST += data22_8.data
 EXTRA_DIST += data22_9.data
 EXTRA_DIST += data22_10.data
+EXTRA_DIST += data32_1.data
+EXTRA_DIST += data32_2.data
+EXTRA_DIST += data32_3.data
 EXTRA_DIST += spec1.spec
 EXTRA_DIST += spec2.spec
 EXTRA_DIST += spec3.spec
@@ -53,3 +56,4 @@ EXTRA_DIST += spec28.spec
 EXTRA_DIST += spec29.spec
 EXTRA_DIST += spec30.spec
 EXTRA_DIST += spec31.spec
+EXTRA_DIST += spec32.spec

+ 3 - 0
src/lib/config/tests/testdata/data32_1.data

@@ -0,0 +1,3 @@
+{
+    "named_set_item": { "foo": 1, "bar": 2 }
+}

+ 3 - 0
src/lib/config/tests/testdata/data32_2.data

@@ -0,0 +1,3 @@
+{
+    "named_set_item": { "foo": "wrongtype", "bar": 2 }
+}

+ 3 - 0
src/lib/config/tests/testdata/data32_3.data

@@ -0,0 +1,3 @@
+{
+    "named_set_item": []
+}

+ 19 - 0
src/lib/config/tests/testdata/spec32.spec

@@ -0,0 +1,19 @@
+{
+  "module_spec": {
+    "module_name": "Spec32",
+    "config_data": [
+      { "item_name": "named_set_item",
+        "item_type": "named_set",
+        "item_optional": false,
+        "item_default": { "a": 1, "b": 2 },
+        "named_set_item_spec": {
+          "item_name": "named_set_element",
+          "item_type": "integer",
+          "item_optional": false,
+          "item_default": 3
+        }
+      }
+    ]
+  }
+}
+

+ 1 - 0
src/lib/datasrc/Makefile.am

@@ -21,6 +21,7 @@ libdatasrc_la_SOURCES += memory_datasrc.h memory_datasrc.cc
 libdatasrc_la_SOURCES += zone.h
 libdatasrc_la_SOURCES += result.h
 libdatasrc_la_SOURCES += logger.h logger.cc
+libdatasrc_la_SOURCES += client.h
 nodist_libdatasrc_la_SOURCES = datasrc_messages.h datasrc_messages.cc
 
 libdatasrc_la_LIBADD = $(top_builddir)/src/lib/exceptions/libexceptions.la

+ 150 - 0
src/lib/datasrc/client.h

@@ -0,0 +1,150 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#ifndef __DATA_SOURCE_CLIENT_H
+#define __DATA_SOURCE_CLIENT_H 1
+
+#include <datasrc/zone.h>
+
+namespace isc {
+namespace datasrc {
+
+/// \brief The base class of data source clients.
+///
+/// This is an abstract base class that defines the common interface for
+/// various types of data source clients.  A data source client is a top level
+/// access point to a data source, allowing various operations on the data
+/// source such as lookups, traversing or updates.  The client class itself
+/// has limited focus and delegates the responsibility for these specific
+/// operations to other classes; in general methods of this class act as
+/// factories of these other classes.
+///
+/// The following derived classes are currently (expected to be) provided:
+/// - \c InMemoryClient: A client of a conceptual data source that stores
+/// all necessary data in memory for faster lookups
+/// - \c DatabaseClient: A client that uses a real database backend (such as
+/// an SQL database).  It would internally hold a connection to the underlying
+/// database system.
+///
+/// \note It is intentional that while the term these derived classes don't
+/// contain "DataSource" unlike their base class.  It's also noteworthy
+/// that the naming of the base class is somewhat redundant because the
+/// namespace \c datasrc would indicate that it's related to a data source.
+/// The redundant naming comes from the observation that namespaces are
+/// often omitted with \c using directives, in which case "Client"
+/// would be too generic.  On the other hand, concrete derived classes are
+/// generally not expected to be referenced directly from other modules and
+/// applications, so we'll give them more concise names such as InMemoryClient.
+///
+/// A single \c DataSourceClient object is expected to handle only a single
+/// RR class even if the underlying data source contains records for multiple
+/// RR classes.  Likewise, (when we support views) a \c DataSourceClient
+/// object is expected to handle only a single view.
+///
+/// If the application uses multiple threads, each thread will need to
+/// create and use a separate DataSourceClient.  This is because some
+/// database backend doesn't allow multiple threads to share the same
+/// connection to the database.
+///
+/// \note For a client using an in memory backend, this may result in
+/// having a multiple copies of the same data in memory, increasing the
+/// memory footprint substantially.  Depending on how to support multiple
+/// CPU cores for concurrent lookups on the same single data source (which
+/// is not fully fixed yet, and for which multiple threads may be used),
+/// this design may have to be revisited.
+///
+/// This class (and therefore its derived classes) are not copyable.
+/// This is because the derived classes would generally contain attributes
+/// that are not easy to copy (such as a large size of in memory data or a
+/// network connection to a database server).  In order to avoid a surprising
+/// disruption with a naive copy it's prohibited explicitly.  For the expected
+/// usage of the client classes the restriction should be acceptable.
+///
+/// \todo This class is not complete. It needs more factory methods, for
+///     accessing the whole zone, updating it, loading it, etc.
+class DataSourceClient : boost::noncopyable {
+public:
+    /// \brief A helper structure to represent the search result of
+    /// \c find().
+    ///
+    /// This is a straightforward pair of the result code and a share pointer
+    /// to the found zone to represent the result of \c find().
+    /// We use this in order to avoid overloading the return value for both
+    /// the result code ("success" or "not found") and the found object,
+    /// i.e., avoid using \c NULL to mean "not found", etc.
+    ///
+    /// This is a simple value class with no internal state, so for
+    /// convenience we allow the applications to refer to the members
+    /// directly.
+    ///
+    /// See the description of \c find() for the semantics of the member
+    /// variables.
+    struct FindResult {
+        FindResult(result::Result param_code,
+                   const ZoneFinderPtr param_zone_finder) :
+            code(param_code), zone_finder(param_zone_finder)
+        {}
+        const result::Result code;
+        const ZoneFinderPtr zone_finder;
+    };
+
+    ///
+    /// \name Constructors and Destructor.
+    ///
+protected:
+    /// Default constructor.
+    ///
+    /// This is intentionally defined as protected as this base class
+    /// should never be instantiated directly.
+    ///
+    /// The constructor of a concrete derived class may throw an exception.
+    /// This interface does not specify which exceptions can happen (at least
+    /// at this moment), and the caller should expect any type of exception
+    /// and react accordingly.
+    DataSourceClient() {}
+
+public:
+    /// The destructor.
+    virtual ~DataSourceClient() {}
+    //@}
+
+    /// Returns a \c ZoneFinder for a zone that best matches the given name.
+    ///
+    /// A concrete derived version of this method gets access to its backend
+    /// data source to search for a zone whose origin gives the longest match
+    /// against \c name.  It returns the search result in the form of a
+    /// \c FindResult object as follows:
+    /// - \c code: The result code of the operation.
+    ///   - \c result::SUCCESS: A zone that gives an exact match is found
+    ///   - \c result::PARTIALMATCH: A zone whose origin is a
+    ///   super domain of \c name is found (but there is no exact match)
+    ///   - \c result::NOTFOUND: For all other cases.
+    /// - \c zone_finder: Pointer to a \c ZoneFinder object for the found zone
+    /// if one is found; otherwise \c NULL.
+    ///
+    /// A specific derived version of this method may throw an exception.
+    /// This interface does not specify which exceptions can happen (at least
+    /// at this moment), and the caller should expect any type of exception
+    /// and react accordingly.
+    ///
+    /// \param name A domain name for which the search is performed.
+    /// \return A \c FindResult object enclosing the search result (see above).
+    virtual FindResult findZone(const isc::dns::Name& name) const = 0;
+};
+}
+}
+#endif  // DATA_SOURCE_CLIENT_H
+// Local Variables:
+// mode: c++
+// End:

+ 35 - 35
src/lib/datasrc/memory_datasrc.cc

@@ -32,10 +32,10 @@ using namespace isc::dns;
 namespace isc {
 namespace datasrc {
 
-// Private data and hidden methods of MemoryZone
-struct MemoryZone::MemoryZoneImpl {
+// Private data and hidden methods of InMemoryZoneFinder
+struct InMemoryZoneFinder::InMemoryZoneFinderImpl {
     // Constructor
-    MemoryZoneImpl(const RRClass& zone_class, const Name& origin) :
+    InMemoryZoneFinderImpl(const RRClass& zone_class, const Name& origin) :
         zone_class_(zone_class), origin_(origin), origin_data_(NULL),
         domains_(true)
     {
@@ -223,7 +223,7 @@ struct MemoryZone::MemoryZoneImpl {
      * Implementation of longer methods. We put them here, because the
      * access is without the impl_-> and it will get inlined anyway.
      */
-    // Implementation of MemoryZone::add
+    // Implementation of InMemoryZoneFinder::add
     result::Result add(const ConstRRsetPtr& rrset, DomainTree* domains) {
         // Sanitize input.  This will cause an exception to be thrown
         // if the input RRset is empty.
@@ -409,7 +409,7 @@ struct MemoryZone::MemoryZoneImpl {
         }
     }
 
-    // Implementation of MemoryZone::find
+    // Implementation of InMemoryZoneFinder::find
     FindResult find(const Name& name, RRType type,
                     RRsetList* target, const FindOptions options) const
     {
@@ -593,50 +593,50 @@ struct MemoryZone::MemoryZoneImpl {
     }
 };
 
-MemoryZone::MemoryZone(const RRClass& zone_class, const Name& origin) :
-    impl_(new MemoryZoneImpl(zone_class, origin))
+InMemoryZoneFinder::InMemoryZoneFinder(const RRClass& zone_class, const Name& origin) :
+    impl_(new InMemoryZoneFinderImpl(zone_class, origin))
 {
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_CREATE).arg(origin).
         arg(zone_class);
 }
 
-MemoryZone::~MemoryZone() {
+InMemoryZoneFinder::~InMemoryZoneFinder() {
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_DESTROY).arg(getOrigin()).
         arg(getClass());
     delete impl_;
 }
 
 const Name&
-MemoryZone::getOrigin() const {
+InMemoryZoneFinder::getOrigin() const {
     return (impl_->origin_);
 }
 
 const RRClass&
-MemoryZone::getClass() const {
+InMemoryZoneFinder::getClass() const {
     return (impl_->zone_class_);
 }
 
-Zone::FindResult
-MemoryZone::find(const Name& name, const RRType& type,
+ZoneFinder::FindResult
+InMemoryZoneFinder::find(const Name& name, const RRType& type,
                  RRsetList* target, const FindOptions options) const
 {
     return (impl_->find(name, type, target, options));
 }
 
 result::Result
-MemoryZone::add(const ConstRRsetPtr& rrset) {
+InMemoryZoneFinder::add(const ConstRRsetPtr& rrset) {
     return (impl_->add(rrset, &impl_->domains_));
 }
 
 
 void
-MemoryZone::load(const string& filename) {
+InMemoryZoneFinder::load(const string& filename) {
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_LOAD).arg(getOrigin()).
         arg(filename);
     // Load it into a temporary tree
-    MemoryZoneImpl::DomainTree tmp;
+    InMemoryZoneFinderImpl::DomainTree tmp;
     masterLoad(filename.c_str(), getOrigin(), getClass(),
-        boost::bind(&MemoryZoneImpl::addFromLoad, impl_, _1, &tmp));
+        boost::bind(&InMemoryZoneFinderImpl::addFromLoad, impl_, _1, &tmp));
     // If it went well, put it inside
     impl_->file_name_ = filename;
     tmp.swap(impl_->domains_);
@@ -644,61 +644,61 @@ MemoryZone::load(const string& filename) {
 }
 
 void
-MemoryZone::swap(MemoryZone& zone) {
+InMemoryZoneFinder::swap(InMemoryZoneFinder& zone_finder) {
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_SWAP).arg(getOrigin()).
-        arg(zone.getOrigin());
-    std::swap(impl_, zone.impl_);
+        arg(zone_finder.getOrigin());
+    std::swap(impl_, zone_finder.impl_);
 }
 
 const string
-MemoryZone::getFileName() const {
+InMemoryZoneFinder::getFileName() const {
     return (impl_->file_name_);
 }
 
-/// Implementation details for \c MemoryDataSrc hidden from the public
+/// Implementation details for \c InMemoryClient hidden from the public
 /// interface.
 ///
-/// For now, \c MemoryDataSrc only contains a \c ZoneTable object, which
-/// consists of (pointers to) \c MemoryZone objects, we may add more
+/// For now, \c InMemoryClient only contains a \c ZoneTable object, which
+/// consists of (pointers to) \c InMemoryZoneFinder objects, we may add more
 /// member variables later for new features.
-class MemoryDataSrc::MemoryDataSrcImpl {
+class InMemoryClient::InMemoryClientImpl {
 public:
-    MemoryDataSrcImpl() : zone_count(0) {}
+    InMemoryClientImpl() : zone_count(0) {}
     unsigned int zone_count;
     ZoneTable zone_table;
 };
 
-MemoryDataSrc::MemoryDataSrc() : impl_(new MemoryDataSrcImpl)
+InMemoryClient::InMemoryClient() : impl_(new InMemoryClientImpl)
 {}
 
-MemoryDataSrc::~MemoryDataSrc() {
+InMemoryClient::~InMemoryClient() {
     delete impl_;
 }
 
 unsigned int
-MemoryDataSrc::getZoneCount() const {
+InMemoryClient::getZoneCount() const {
     return (impl_->zone_count);
 }
 
 result::Result
-MemoryDataSrc::addZone(ZonePtr zone) {
-    if (!zone) {
+InMemoryClient::addZone(ZoneFinderPtr zone_finder) {
+    if (!zone_finder) {
         isc_throw(InvalidParameter,
-                  "Null pointer is passed to MemoryDataSrc::addZone()");
+                  "Null pointer is passed to InMemoryClient::addZone()");
     }
 
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_ADD_ZONE).
-        arg(zone->getOrigin()).arg(zone->getClass().toText());
+        arg(zone_finder->getOrigin()).arg(zone_finder->getClass().toText());
 
-    const result::Result result = impl_->zone_table.addZone(zone);
+    const result::Result result = impl_->zone_table.addZone(zone_finder);
     if (result == result::SUCCESS) {
         ++impl_->zone_count;
     }
     return (result);
 }
 
-MemoryDataSrc::FindResult
-MemoryDataSrc::findZone(const isc::dns::Name& name) const {
+InMemoryClient::FindResult
+InMemoryClient::findZone(const isc::dns::Name& name) const {
     LOG_DEBUG(logger, DBG_TRACE_DATA, DATASRC_MEM_FIND_ZONE).arg(name);
     return (FindResult(impl_->zone_table.findZone(name).code,
                        impl_->zone_table.findZone(name).zone));

+ 59 - 97
src/lib/datasrc/memory_datasrc.h

@@ -17,7 +17,10 @@
 
 #include <string>
 
+#include <boost/noncopyable.hpp>
+
 #include <datasrc/zonetable.h>
+#include <datasrc/client.h>
 
 namespace isc {
 namespace dns {
@@ -27,18 +30,17 @@ class RRsetList;
 
 namespace datasrc {
 
-/// A derived zone class intended to be used with the memory data source.
-class MemoryZone : public Zone {
+/// A derived zone finder class intended to be used with the memory data source.
+///
+/// Conceptually this "finder" maintains a local in-memory copy of all RRs
+/// of a single zone from some kind of source (right now it's a textual
+/// master file, but it could also be another data source with a database
+/// backend).  This is why the class has methods like \c load() or \c add().
+///
+/// This class is non copyable.
+class InMemoryZoneFinder : boost::noncopyable, public ZoneFinder {
     ///
     /// \name Constructors and Destructor.
-    ///
-    /// \b Note:
-    /// The copy constructor and the assignment operator are intentionally
-    /// defined as private, making this class non copyable.
-    //@{
-private:
-    MemoryZone(const MemoryZone& source);
-    MemoryZone& operator=(const MemoryZone& source);
 public:
     /// \brief Constructor from zone parameters.
     ///
@@ -48,10 +50,11 @@ public:
     ///
     /// \param rrclass The RR class of the zone.
     /// \param origin The origin name of the zone.
-    MemoryZone(const isc::dns::RRClass& rrclass, const isc::dns::Name& origin);
+    InMemoryZoneFinder(const isc::dns::RRClass& rrclass,
+                       const isc::dns::Name& origin);
 
     /// The destructor.
-    virtual ~MemoryZone();
+    virtual ~InMemoryZoneFinder();
     //@}
 
     /// \brief Returns the origin of the zone.
@@ -128,14 +131,14 @@ public:
     /// Return the master file name of the zone
     ///
     /// This method returns the name of the zone's master file to be loaded.
-    /// The returned string will be an empty unless the zone has successfully
-    /// loaded a zone.
+    /// The returned string will be an empty unless the zone finder has
+    /// successfully loaded a zone.
     ///
     /// This method should normally not throw an exception.  But the creation
     /// of the return string may involve a resource allocation, and if it
     /// fails, the corresponding standard exception will be thrown.
     ///
-    /// \return The name of the zone file loaded in the zone, or an empty
+    /// \return The name of the zone file loaded in the zone finder, or an empty
     /// string if the zone hasn't loaded any file.
     const std::string getFileName() const;
 
@@ -164,143 +167,102 @@ public:
     ///     configuration reloading is written.
     void load(const std::string& filename);
 
-    /// Exchanges the content of \c this zone with that of the given \c zone.
+    /// Exchanges the content of \c this zone finder with that of the given
+    /// \c zone_finder.
     ///
     /// This method never throws an exception.
     ///
-    /// \param zone Another \c MemoryZone object which is to be swapped with
-    /// \c this zone.
-    void swap(MemoryZone& zone);
+    /// \param zone_finder Another \c InMemoryZone object which is to
+    /// be swapped with \c this zone finder.
+    void swap(InMemoryZoneFinder& zone_finder);
 
 private:
     /// \name Hidden private data
     //@{
-    struct MemoryZoneImpl;
-    MemoryZoneImpl* impl_;
+    struct InMemoryZoneFinderImpl;
+    InMemoryZoneFinderImpl* impl_;
     //@}
 };
 
-/// \brief A data source that uses in memory dedicated backend.
+/// \brief A data source client that holds all necessary data in memory.
 ///
-/// The \c MemoryDataSrc class represents a data source and provides a
-/// basic interface to help DNS lookup processing. For a given domain
-/// name, its \c findZone() method searches the in memory dedicated backend
-/// for the zone that gives a longest match against that name.
+/// The \c InMemoryClient class provides an access to a conceptual data
+/// source that maintains all necessary data in a memory image, thereby
+/// allowing much faster lookups.  The in memory data is a copy of some
+/// real physical source - in the current implementation a list of zones
+/// are populated as a result of \c addZone() calls; zone data is given
+/// in a standard master file (but there's a plan to use database backends
+/// as a source of the in memory data).
 ///
-/// The in memory dedicated backend are assumed to be of the same RR class,
-/// but the \c MemoryDataSrc class does not enforce the assumption through
+/// Although every data source client is assumed to be of the same RR class,
+/// the \c InMemoryClient class does not enforce the assumption through
 /// its interface.
 /// For example, the \c addZone() method does not check if the new zone is of
-/// the same RR class as that of the others already in the dedicated backend.
+/// the same RR class as that of the others already in memory.
 /// It is caller's responsibility to ensure this assumption.
 ///
 /// <b>Notes to developer:</b>
 ///
-/// For now, we don't make it a derived class of AbstractDataSrc because the
-/// interface is so different (we'll eventually consider this as part of the
-/// generalization work).
-///
 /// The addZone() method takes a (Boost) shared pointer because it would be
 /// inconvenient to require the caller to maintain the ownership of zones,
 /// while it wouldn't be safe to delete unnecessary zones inside the dedicated
 /// backend.
 ///
-/// The findZone() method takes a domain name and returns the best matching \c
-/// MemoryZone in the form of (Boost) shared pointer, so that it can provide
-/// the general interface for all data sources.
-class MemoryDataSrc {
+/// The findZone() method takes a domain name and returns the best matching 
+/// \c InMemoryZoneFinder in the form of (Boost) shared pointer, so that it can
+/// provide the general interface for all data sources.
+class InMemoryClient : public DataSourceClient {
 public:
-    /// \brief A helper structure to represent the search result of
-    /// <code>MemoryDataSrc::find()</code>.
-    ///
-    /// This is a straightforward pair of the result code and a share pointer
-    /// to the found zone to represent the result of \c find().
-    /// We use this in order to avoid overloading the return value for both
-    /// the result code ("success" or "not found") and the found object,
-    /// i.e., avoid using \c NULL to mean "not found", etc.
-    ///
-    /// This is a simple value class with no internal state, so for
-    /// convenience we allow the applications to refer to the members
-    /// directly.
-    ///
-    /// See the description of \c find() for the semantics of the member
-    /// variables.
-    struct FindResult {
-        FindResult(result::Result param_code, const ZonePtr param_zone) :
-            code(param_code), zone(param_zone)
-        {}
-        const result::Result code;
-        const ZonePtr zone;
-    };
-
     ///
     /// \name Constructors and Destructor.
     ///
-    /// \b Note:
-    /// The copy constructor and the assignment operator are intentionally
-    /// defined as private, making this class non copyable.
     //@{
-private:
-    MemoryDataSrc(const MemoryDataSrc& source);
-    MemoryDataSrc& operator=(const MemoryDataSrc& source);
 
-public:
     /// Default constructor.
     ///
     /// This constructor internally involves resource allocation, and if
     /// it fails, a corresponding standard exception will be thrown.
     /// It never throws an exception otherwise.
-    MemoryDataSrc();
+    InMemoryClient();
 
     /// The destructor.
-    ~MemoryDataSrc();
+    ~InMemoryClient();
     //@}
 
-    /// Return the number of zones stored in the data source.
+    /// Return the number of zones stored in the client.
     ///
     /// This method never throws an exception.
     ///
-    /// \return The number of zones stored in the data source.
+    /// \return The number of zones stored in the client.
     unsigned int getZoneCount() const;
 
-    /// Add a \c Zone to the \c MemoryDataSrc.
+    /// Add a zone (in the form of \c ZoneFinder) to the \c InMemoryClient.
     ///
-    /// \c Zone must not be associated with a NULL pointer; otherwise
+    /// \c zone_finder must not be associated with a NULL pointer; otherwise
     /// an exception of class \c InvalidParameter will be thrown.
     /// If internal resource allocation fails, a corresponding standard
     /// exception will be thrown.
     /// This method never throws an exception otherwise.
     ///
-    /// \param zone A \c Zone object to be added.
-    /// \return \c result::SUCCESS If the zone is successfully
-    /// added to the memory data source.
+    /// \param zone_finder A \c ZoneFinder object to be added.
+    /// \return \c result::SUCCESS If the zone_finder is successfully
+    /// added to the client.
     /// \return \c result::EXIST The memory data source already
     /// stores a zone that has the same origin.
-    result::Result addZone(ZonePtr zone);
+    result::Result addZone(ZoneFinderPtr zone_finder);
 
-    /// Find a \c Zone that best matches the given name in the \c MemoryDataSrc.
-    ///
-    /// It searches the internal storage for a \c Zone that gives the
-    /// longest match against \c name, and returns the result in the
-    /// form of a \c FindResult object as follows:
-    /// - \c code: The result code of the operation.
-    ///   - \c result::SUCCESS: A zone that gives an exact match
-    //    is found
-    ///   - \c result::PARTIALMATCH: A zone whose origin is a
-    //    super domain of \c name is found (but there is no exact match)
-    ///   - \c result::NOTFOUND: For all other cases.
-    /// - \c zone: A "Boost" shared pointer to the found \c Zone object if one
-    //  is found; otherwise \c NULL.
-    ///
-    /// This method never throws an exception.
+    /// Returns a \c ZoneFinder for a zone_finder that best matches the given
+    /// name.
     ///
-    /// \param name A domain name for which the search is performed.
-    /// \return A \c FindResult object enclosing the search result (see above).
-    FindResult findZone(const isc::dns::Name& name) const;
+    /// This derived version of the method never throws an exception.
+    /// For other details see \c DataSourceClient::findZone().
+    virtual FindResult findZone(const isc::dns::Name& name) const;
 
 private:
-    class MemoryDataSrcImpl;
-    MemoryDataSrcImpl* impl_;
+    // TODO: Do we still need the PImpl if nobody should manipulate this class
+    // directly any more (it should be handled through DataSourceClient)?
+    class InMemoryClientImpl;
+    InMemoryClientImpl* impl_;
 };
 }
 }

+ 3 - 3
src/lib/datasrc/rbtree.h

@@ -704,9 +704,9 @@ public:
     /// \brief Find with callback and node chain.
     ///
     /// This version of \c find() is specifically designed for the backend
-    /// of the \c MemoryZone class, and implements all necessary features
-    /// for that purpose.  Other applications shouldn't need these additional
-    /// features, and should normally use the simpler versions.
+    /// of the \c InMemoryZoneFinder class, and implements all necessary
+    /// features for that purpose.  Other applications shouldn't need these
+    /// additional features, and should normally use the simpler versions.
     ///
     /// This version of \c find() calls the callback whenever traversing (on
     /// the way from root down the tree) a marked node on the way down through

File diff suppressed because it is too large
+ 324 - 298
src/lib/datasrc/tests/memory_datasrc_unittest.cc


+ 19 - 17
src/lib/datasrc/tests/zonetable_unittest.cc

@@ -18,7 +18,7 @@
 #include <dns/rrclass.h>
 
 #include <datasrc/zonetable.h>
-// We use MemoryZone to put something into the table
+// We use InMemoryZone to put something into the table
 #include <datasrc/memory_datasrc.h>
 
 #include <gtest/gtest.h>
@@ -28,31 +28,32 @@ using namespace isc::datasrc;
 
 namespace {
 TEST(ZoneTest, init) {
-    MemoryZone zone(RRClass::IN(), Name("example.com"));
+    InMemoryZoneFinder zone(RRClass::IN(), Name("example.com"));
     EXPECT_EQ(Name("example.com"), zone.getOrigin());
     EXPECT_EQ(RRClass::IN(), zone.getClass());
 
-    MemoryZone ch_zone(RRClass::CH(), Name("example"));
+    InMemoryZoneFinder ch_zone(RRClass::CH(), Name("example"));
     EXPECT_EQ(Name("example"), ch_zone.getOrigin());
     EXPECT_EQ(RRClass::CH(), ch_zone.getClass());
 }
 
 TEST(ZoneTest, find) {
-    MemoryZone zone(RRClass::IN(), Name("example.com"));
-    EXPECT_EQ(Zone::NXDOMAIN,
+    InMemoryZoneFinder zone(RRClass::IN(), Name("example.com"));
+    EXPECT_EQ(ZoneFinder::NXDOMAIN,
               zone.find(Name("www.example.com"), RRType::A()).code);
 }
 
 class ZoneTableTest : public ::testing::Test {
 protected:
-    ZoneTableTest() : zone1(new MemoryZone(RRClass::IN(),
-                                           Name("example.com"))),
-                      zone2(new MemoryZone(RRClass::IN(),
-                                           Name("example.net"))),
-                      zone3(new MemoryZone(RRClass::IN(), Name("example")))
+    ZoneTableTest() : zone1(new InMemoryZoneFinder(RRClass::IN(),
+                                                   Name("example.com"))),
+                      zone2(new InMemoryZoneFinder(RRClass::IN(),
+                                                   Name("example.net"))),
+                      zone3(new InMemoryZoneFinder(RRClass::IN(),
+                                                   Name("example")))
     {}
     ZoneTable zone_table;
-    ZonePtr zone1, zone2, zone3;
+    ZoneFinderPtr zone1, zone2, zone3;
 };
 
 TEST_F(ZoneTableTest, addZone) {
@@ -60,7 +61,8 @@ TEST_F(ZoneTableTest, addZone) {
     EXPECT_EQ(result::EXIST, zone_table.addZone(zone1));
     // names are compared in a case insensitive manner.
     EXPECT_EQ(result::EXIST, zone_table.addZone(
-                  ZonePtr(new MemoryZone(RRClass::IN(), Name("EXAMPLE.COM")))));
+                  ZoneFinderPtr(new InMemoryZoneFinder(RRClass::IN(),
+                                                       Name("EXAMPLE.COM")))));
 
     EXPECT_EQ(result::SUCCESS, zone_table.addZone(zone2));
     EXPECT_EQ(result::SUCCESS, zone_table.addZone(zone3));
@@ -68,11 +70,11 @@ TEST_F(ZoneTableTest, addZone) {
     // Zone table is indexed only by name.  Duplicate origin name with
     // different zone class isn't allowed.
     EXPECT_EQ(result::EXIST, zone_table.addZone(
-                  ZonePtr(new MemoryZone(RRClass::CH(),
-                                         Name("example.com")))));
+                  ZoneFinderPtr(new InMemoryZoneFinder(RRClass::CH(),
+                                                       Name("example.com")))));
 
     /// Bogus zone (NULL)
-    EXPECT_THROW(zone_table.addZone(ZonePtr()), isc::InvalidParameter);
+    EXPECT_THROW(zone_table.addZone(ZoneFinderPtr()), isc::InvalidParameter);
 }
 
 TEST_F(ZoneTableTest, DISABLED_removeZone) {
@@ -95,7 +97,7 @@ TEST_F(ZoneTableTest, findZone) {
 
     EXPECT_EQ(result::NOTFOUND,
               zone_table.findZone(Name("example.org")).code);
-    EXPECT_EQ(ConstZonePtr(),
+    EXPECT_EQ(ConstZoneFinderPtr(),
               zone_table.findZone(Name("example.org")).zone);
 
     // there's no exact match.  the result should be the longest match,
@@ -107,7 +109,7 @@ TEST_F(ZoneTableTest, findZone) {
 
     // make sure the partial match is indeed the longest match by adding
     // a zone with a shorter origin and query again.
-    ZonePtr zone_com(new MemoryZone(RRClass::IN(), Name("com")));
+    ZoneFinderPtr zone_com(new InMemoryZoneFinder(RRClass::IN(), Name("com")));
     EXPECT_EQ(result::SUCCESS, zone_table.addZone(zone_com));
     EXPECT_EQ(Name("example.com"),
               zone_table.findZone(Name("www.example.com")).zone->getOrigin());

+ 10 - 10
src/lib/datasrc/zone.h

@@ -27,7 +27,7 @@ namespace datasrc {
 /// a DNS zone as part of data source.
 ///
 /// At the moment this is provided mainly for making the \c ZoneTable class
-/// and the authoritative query logic  testable, and only provides a minimal
+/// and the authoritative query logic testable, and only provides a minimal
 /// set of features.
 /// This is why this class is defined in the same header file, but it may
 /// have to move to a separate header file when we understand what is
@@ -53,9 +53,9 @@ namespace datasrc {
 ///
 /// <b>Note:</b> Unlike some other abstract base classes we don't name the
 /// class beginning with "Abstract".  This is because we want to have
-/// commonly used definitions such as \c Result and \c ZonePtr, and we want
-/// to make them look more intuitive.
-class Zone {
+/// commonly used definitions such as \c Result and \c ZoneFinderPtr, and we
+/// want to make them look more intuitive.
+class ZoneFinder {
 public:
     /// Result codes of the \c find() method.
     ///
@@ -119,10 +119,10 @@ protected:
     ///
     /// This is intentionally defined as \c protected as this base class should
     /// never be instantiated (except as part of a derived class).
-    Zone() {}
+    ZoneFinder() {}
 public:
     /// The destructor.
-    virtual ~Zone() {}
+    virtual ~ZoneFinder() {}
     //@}
 
     ///
@@ -201,11 +201,11 @@ public:
     //@}
 };
 
-/// \brief A pointer-like type pointing to a \c Zone object.
-typedef boost::shared_ptr<Zone> ZonePtr;
+/// \brief A pointer-like type pointing to a \c ZoneFinder object.
+typedef boost::shared_ptr<ZoneFinder> ZoneFinderPtr;
 
-/// \brief A pointer-like type pointing to a \c Zone object.
-typedef boost::shared_ptr<const Zone> ConstZonePtr;
+/// \brief A pointer-like type pointing to a \c ZoneFinder object.
+typedef boost::shared_ptr<const ZoneFinder> ConstZoneFinderPtr;
 
 }
 }

+ 6 - 6
src/lib/datasrc/zonetable.cc

@@ -28,8 +28,8 @@ namespace datasrc {
 /// \short Private data and implementation of ZoneTable
 struct ZoneTable::ZoneTableImpl {
     // Type aliases to make it shorter
-    typedef RBTree<Zone> ZoneTree;
-    typedef RBNode<Zone> ZoneNode;
+    typedef RBTree<ZoneFinder> ZoneTree;
+    typedef RBNode<ZoneFinder> ZoneNode;
     // The actual storage
     ZoneTree zones_;
 
@@ -40,7 +40,7 @@ struct ZoneTable::ZoneTableImpl {
      */
 
     // Implementation of ZoneTable::addZone
-    result::Result addZone(ZonePtr zone) {
+    result::Result addZone(ZoneFinderPtr zone) {
         // Sanity check
         if (!zone) {
             isc_throw(InvalidParameter,
@@ -85,12 +85,12 @@ struct ZoneTable::ZoneTableImpl {
                 break;
             // We have no data there, so translate the pointer to NULL as well
             case ZoneTree::NOTFOUND:
-                return (FindResult(result::NOTFOUND, ZonePtr()));
+                return (FindResult(result::NOTFOUND, ZoneFinderPtr()));
             // Can Not Happen
             default:
                 assert(0);
                 // Because of warning
-                return (FindResult(result::NOTFOUND, ZonePtr()));
+                return (FindResult(result::NOTFOUND, ZoneFinderPtr()));
         }
 
         // Can Not Happen (remember, NOTFOUND is handled)
@@ -108,7 +108,7 @@ ZoneTable::~ZoneTable() {
 }
 
 result::Result
-ZoneTable::addZone(ZonePtr zone) {
+ZoneTable::addZone(ZoneFinderPtr zone) {
     return (impl_->addZone(zone));
 }
 

+ 3 - 3
src/lib/datasrc/zonetable.h

@@ -41,11 +41,11 @@ namespace datasrc {
 class ZoneTable {
 public:
     struct FindResult {
-        FindResult(result::Result param_code, const ZonePtr param_zone) :
+        FindResult(result::Result param_code, const ZoneFinderPtr param_zone) :
             code(param_code), zone(param_zone)
         {}
         const result::Result code;
-        const ZonePtr zone;
+        const ZoneFinderPtr zone;
     };
     ///
     /// \name Constructors and Destructor.
@@ -83,7 +83,7 @@ public:
     /// added to the zone table.
     /// \return \c result::EXIST The zone table already contains
     /// zone of the same origin.
-    result::Result addZone(ZonePtr zone);
+    result::Result addZone(ZoneFinderPtr zone);
 
     /// Remove a \c Zone of the given origin name from the \c ZoneTable.
     ///

+ 46 - 3
src/lib/dns/message.cc

@@ -239,7 +239,28 @@ MessageImpl::toWire(AbstractMessageRenderer& renderer, TSIGContext* tsig_ctx) {
                   "Message rendering attempted without Opcode set");
     }
 
+    // Reserve the space for TSIG (if needed) so that we can handle truncation
+    // case correctly later when that happens.  orig_xxx variables remember
+    // some configured parameters of renderer in case they are needed in
+    // truncation processing below.
+    const size_t tsig_len = (tsig_ctx != NULL) ? tsig_ctx->getTSIGLength() : 0;
+    const size_t orig_msg_len_limit = renderer.getLengthLimit();
+    const AbstractMessageRenderer::CompressMode orig_compress_mode =
+        renderer.getCompressMode();
+    if (tsig_len > 0) {
+        if (tsig_len > orig_msg_len_limit) {
+            isc_throw(InvalidParameter, "Failed to render DNS message: "
+                      "too small limit for a TSIG (" <<
+                      orig_msg_len_limit << ")");
+        }
+        renderer.setLengthLimit(orig_msg_len_limit - tsig_len);
+    }
+
     // reserve room for the header
+    if (renderer.getLengthLimit() < HEADERLEN) {
+        isc_throw(InvalidParameter, "Failed to render DNS message: "
+                  "too small limit for a Header");
+    }
     renderer.skip(HEADERLEN);
 
     uint16_t qdcount =
@@ -284,6 +305,22 @@ MessageImpl::toWire(AbstractMessageRenderer& renderer, TSIGContext* tsig_ctx) {
         }
     }
 
+    // If we're adding a TSIG to a truncated message, clear all RRsets
+    // from the message except for the question before adding the TSIG.
+    // If even (some of) the question doesn't fit, don't include it.
+    if (tsig_ctx != NULL && renderer.isTruncated()) {
+        renderer.clear();
+        renderer.setLengthLimit(orig_msg_len_limit - tsig_len);
+        renderer.setCompressMode(orig_compress_mode);
+        renderer.skip(HEADERLEN);
+        qdcount = for_each(questions_.begin(), questions_.end(),
+                           RenderSection<QuestionPtr>(renderer,
+                                                      false)).getTotalCount();
+        ancount = 0;
+        nscount = 0;
+        arcount = 0;
+    }
+
     // Adjust the counter buffer.
     // XXX: these may not be equal to the number of corresponding entries
     // in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
@@ -315,10 +352,16 @@ MessageImpl::toWire(AbstractMessageRenderer& renderer, TSIGContext* tsig_ctx) {
     renderer.writeUint16At(arcount, header_pos);
 
     // Add TSIG, if necessary, at the end of the message.
-    // TODO: truncate case consideration
     if (tsig_ctx != NULL) {
-        tsig_ctx->sign(qid_, renderer.getData(),
-                       renderer.getLength())->toWire(renderer);
+        // Release the reserved space in the renderer.
+        renderer.setLengthLimit(orig_msg_len_limit);
+
+        const int tsig_count =
+            tsig_ctx->sign(qid_, renderer.getData(),
+                           renderer.getLength())->toWire(renderer);
+        if (tsig_count != 1) {
+            isc_throw(Unexpected, "Failed to render a TSIG RR");
+        }
 
         // update the ARCOUNT for the TSIG RR.  Note that for a sane DNS
         // message arcount should never overflow to 0.

+ 11 - 0
src/lib/dns/message.h

@@ -565,6 +565,17 @@ public:
     /// \c tsig_ctx will be updated based on the fact it was used for signing
     /// and with the latest MAC.
     ///
+    /// \exception InvalidMessageOperation The message is not in the Render
+    /// mode, or either Rcode or Opcode is not set.
+    /// \exception InvalidParameter The allowable limit of \c renderer is too
+    /// small for a TSIG or the Header section.  Note that this shouldn't
+    /// happen with parameters as defined in the standard protocols,
+    /// so it's more likely a program bug.
+    /// \exception Unexpected Rendering the TSIG RR fails.  The implementation
+    /// internally makes sure this doesn't happen, so if that ever occurs
+    /// it should mean a bug either in the TSIG context or in the renderer
+    /// implementation.
+    ///
     /// \param renderer See the other version
     /// \param tsig_ctx A TSIG context that is to be used for signing the
     /// message

+ 9 - 0
src/lib/dns/python/message_python.cc

@@ -703,6 +703,15 @@ Message_toWire(s_Message* self, PyObject* args) {
             // python program has a bug.
             PyErr_SetString(po_TSIGContextError, ex.what());
             return (NULL);
+        } catch (const std::exception& ex) {
+            // Other exceptions should be rare (most likely an implementation
+            // bug)
+            PyErr_SetString(po_TSIGContextError, ex.what());
+            return (NULL);
+        } catch (...) {
+            PyErr_SetString(PyExc_RuntimeError,
+                            "Unexpected C++ exception in Message.to_wire");
+            return (NULL);
         }
     }
     PyErr_Clear();

+ 105 - 16
src/lib/dns/python/tests/message_python_test.py

@@ -21,6 +21,7 @@ import unittest
 import os
 from pydnspp import *
 from testutil import *
+from pyunittests_util import fix_current_time
 
 # helper functions for tests taken from c++ unittests
 if "TESTDATA_PATH" in os.environ:
@@ -31,7 +32,7 @@ else:
 def factoryFromFile(message, file):
     data = read_wire_data(file)
     message.from_wire(data)
-    pass
+    return data
 
 # we don't have direct comparison for rrsets right now (should we?
 # should go in the cpp version first then), so also no direct list
@@ -44,6 +45,15 @@ def compare_rrset_list(list1, list2):
             return False
     return True
 
+# These are used for TSIG + TC tests
+LONG_TXT1 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde";
+
+LONG_TXT2 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456";
+
+LONG_TXT3 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef01";
+
+LONG_TXT4 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0";
+
 # a complete message taken from cpp tests, for testing towire and totext
 def create_message():
     message_render = Message(Message.RENDER)
@@ -62,16 +72,12 @@ def create_message():
     message_render.add_rrset(Message.SECTION_ANSWER, rrset)
     return message_render
 
-def strip_mutable_tsig_data(data):
-    # Unfortunately we cannot easily compare TSIG RR because we can't tweak
-    # current time.  As a work around this helper function strips off the time
-    # dependent part of TSIG RDATA, i.e., the MAC (assuming HMAC-MD5) and
-    # Time Signed.
-    return data[0:-32] + data[-26:-22] + data[-6:]
-
 class MessageTest(unittest.TestCase):
 
     def setUp(self):
+        # make sure we don't use faked time unless explicitly do so in tests
+        fix_current_time(None)
+
         self.p = Message(Message.PARSE)
         self.r = Message(Message.RENDER)
 
@@ -90,6 +96,10 @@ class MessageTest(unittest.TestCase):
         self.tsig_key = TSIGKey("www.example.com:SFuWd/q99SzF8Yzd1QbB9g==")
         self.tsig_ctx = TSIGContext(self.tsig_key)
 
+    def tearDown(self):
+        # reset any faked current time setting (it would affect other tests)
+        fix_current_time(None)
+
     def test_init(self):
         self.assertRaises(TypeError, Message, -1)
         self.assertRaises(TypeError, Message, 3)
@@ -285,33 +295,112 @@ class MessageTest(unittest.TestCase):
         self.assertRaises(InvalidMessageOperation, self.r.to_wire,
                           MessageRenderer())
 
-    def __common_tsigquery_setup(self):
+    def __common_tsigmessage_setup(self, flags=[Message.HEADERFLAG_RD],
+                                   rrtype=RRType("A"), answer_data=None):
         self.r.set_opcode(Opcode.QUERY())
         self.r.set_rcode(Rcode.NOERROR())
-        self.r.set_header_flag(Message.HEADERFLAG_RD)
+        for flag in flags:
+            self.r.set_header_flag(flag)
+        if answer_data is not None:
+            rrset = RRset(Name("www.example.com"), RRClass("IN"),
+                          rrtype, RRTTL(86400))
+            for rdata in answer_data:
+                rrset.add_rdata(Rdata(rrtype, RRClass("IN"), rdata))
+            self.r.add_rrset(Message.SECTION_ANSWER, rrset)
         self.r.add_question(Question(Name("www.example.com"),
-                                     RRClass("IN"), RRType("A")))
+                                     RRClass("IN"), rrtype))
 
     def __common_tsig_checks(self, expected_file):
         renderer = MessageRenderer()
         self.r.to_wire(renderer, self.tsig_ctx)
-        actual_wire = strip_mutable_tsig_data(renderer.get_data())
-        expected_wire = strip_mutable_tsig_data(read_wire_data(expected_file))
-        self.assertEqual(expected_wire, actual_wire)
+        self.assertEqual(read_wire_data(expected_file), renderer.get_data())
 
     def test_to_wire_with_tsig(self):
+        fix_current_time(0x4da8877a)
         self.r.set_qid(0x2d65)
-        self.__common_tsigquery_setup()
+        self.__common_tsigmessage_setup()
         self.__common_tsig_checks("message_toWire2.wire")
 
     def test_to_wire_with_edns_tsig(self):
+        fix_current_time(0x4db60d1f)
         self.r.set_qid(0x6cd)
-        self.__common_tsigquery_setup()
+        self.__common_tsigmessage_setup()
         edns = EDNS()
         edns.set_udp_size(4096)
         self.r.set_edns(edns)
         self.__common_tsig_checks("message_toWire3.wire")
 
+    def test_to_wire_tsig_truncation(self):
+        fix_current_time(0x4e179212)
+        data = factoryFromFile(self.p, "message_fromWire17.wire")
+        self.assertEqual(TSIGError.NOERROR,
+                         self.tsig_ctx.verify(self.p.get_tsig_record(), data))
+        self.r.set_qid(0x22c2)
+        self.__common_tsigmessage_setup([Message.HEADERFLAG_QR,
+                                         Message.HEADERFLAG_AA,
+                                         Message.HEADERFLAG_RD],
+                                        RRType("TXT"),
+                                        [LONG_TXT1, LONG_TXT2])
+        self.__common_tsig_checks("message_toWire4.wire")
+
+    def test_to_wire_tsig_truncation2(self):
+        fix_current_time(0x4e179212)
+        data = factoryFromFile(self.p, "message_fromWire17.wire")
+        self.assertEqual(TSIGError.NOERROR,
+                         self.tsig_ctx.verify(self.p.get_tsig_record(), data))
+        self.r.set_qid(0x22c2)
+        self.__common_tsigmessage_setup([Message.HEADERFLAG_QR,
+                                         Message.HEADERFLAG_AA,
+                                         Message.HEADERFLAG_RD],
+                                        RRType("TXT"),
+                                        [LONG_TXT1, LONG_TXT3])
+        self.__common_tsig_checks("message_toWire4.wire")
+
+    def test_to_wire_tsig_truncation3(self):
+        self.r.set_opcode(Opcode.QUERY())
+        self.r.set_rcode(Rcode.NOERROR())
+        for i in range(1, 68):
+            self.r.add_question(Question(Name("www.example.com"),
+                                         RRClass("IN"), RRType(i)))
+        renderer = MessageRenderer()
+        self.r.to_wire(renderer, self.tsig_ctx)
+
+        self.p.from_wire(renderer.get_data())
+        self.assertTrue(self.p.get_header_flag(Message.HEADERFLAG_TC))
+        self.assertEqual(66, self.p.get_rr_count(Message.SECTION_QUESTION))
+        self.assertNotEqual(None, self.p.get_tsig_record())
+
+    def test_to_wire_tsig_no_truncation(self):
+        fix_current_time(0x4e17b38d)
+        data = factoryFromFile(self.p, "message_fromWire18.wire")
+        self.assertEqual(TSIGError.NOERROR,
+                         self.tsig_ctx.verify(self.p.get_tsig_record(), data))
+        self.r.set_qid(0xd6e2)
+        self.__common_tsigmessage_setup([Message.HEADERFLAG_QR,
+                                         Message.HEADERFLAG_AA,
+                                         Message.HEADERFLAG_RD],
+                                        RRType("TXT"),
+                                        [LONG_TXT1, LONG_TXT4])
+        self.__common_tsig_checks("message_toWire5.wire")
+
+    def test_to_wire_tsig_length_errors(self):
+        renderer = MessageRenderer()
+        renderer.set_length_limit(84) # 84 = expected TSIG length - 1
+        self.__common_tsigmessage_setup()
+        self.assertRaises(TSIGContextError,
+                          self.r.to_wire, renderer, self.tsig_ctx)
+
+        renderer.clear()
+        self.r.clear(Message.RENDER)
+        renderer.set_length_limit(86) # 86 = expected TSIG length + 1
+        self.__common_tsigmessage_setup()
+        self.assertRaises(TSIGContextError,
+                          self.r.to_wire, renderer, self.tsig_ctx)
+
+        # skip the last test of the corresponding C++ test: it requires
+        # subclassing MessageRenderer, which is (currently) not possible
+        # for python.  In any case, it's very unlikely to happen in practice.
+
     def test_to_text(self):
         message_render = create_message()
         

+ 8 - 2
src/lib/dns/python/tests/question_python_test.py

@@ -74,7 +74,6 @@ class QuestionTest(unittest.TestCase):
         self.assertEqual("foo.example.com. IN NS\n", str(self.test_question1))
         self.assertEqual("bar.example.com. CH A\n", self.test_question2.to_text())
     
-    
     def test_to_wire_buffer(self):
         obuffer = bytes()
         obuffer = self.test_question1.to_wire(obuffer)
@@ -82,7 +81,6 @@ class QuestionTest(unittest.TestCase):
         wiredata = read_wire_data("question_toWire1")
         self.assertEqual(obuffer, wiredata)
     
-    
     def test_to_wire_renderer(self):
         renderer = MessageRenderer()
         self.test_question1.to_wire(renderer)
@@ -91,5 +89,13 @@ class QuestionTest(unittest.TestCase):
         self.assertEqual(renderer.get_data(), wiredata)
         self.assertRaises(TypeError, self.test_question1.to_wire, 1)
 
+    def test_to_wire_truncated(self):
+        renderer = MessageRenderer()
+        renderer.set_length_limit(self.example_name1.get_length())
+        self.assertFalse(renderer.is_truncated())
+        self.test_question1.to_wire(renderer)
+        self.assertTrue(renderer.is_truncated())
+        self.assertEqual(0, renderer.get_length())
+
 if __name__ == '__main__':
     unittest.main()

+ 9 - 0
src/lib/dns/question.cc

@@ -57,10 +57,19 @@ Question::toWire(OutputBuffer& buffer) const {
 
 unsigned int
 Question::toWire(AbstractMessageRenderer& renderer) const {
+    const size_t pos0 = renderer.getLength();
+
     renderer.writeName(name_);
     rrtype_.toWire(renderer);
     rrclass_.toWire(renderer);
 
+    // Make sure the renderer has a room for the question
+    if (renderer.getLength() > renderer.getLengthLimit()) {
+        renderer.trim(renderer.getLength() - pos0);
+        renderer.setTruncated();
+        return (0);
+    }
+
     return (1);                 // number of "entries"
 }
 

+ 8 - 8
src/lib/dns/question.h

@@ -201,23 +201,23 @@ public:
     /// class description).
     ///
     /// The owner name will be compressed if possible, although it's an
-    /// unlikely event in practice because the %Question section a DNS
+    /// unlikely event in practice because the Question section a DNS
     /// message normally doesn't contain multiple question entries and
     /// it's located right after the Header section.
     /// Nevertheless, \c renderer records the information of the owner name
     /// so that it can be pointed by other RRs in other sections (which is
     /// more likely to happen).
     ///
-    /// In theory, an attempt to render a Question may cause truncation
-    /// (when the Question section contains a large number of entries),
-    /// but this implementation doesn't catch that situation.
-    /// It would make the code unnecessarily complicated (though perhaps
-    /// slightly) for almost impossible case in practice.
-    /// An upper layer will handle the pathological case as a general error.
+    /// It could be possible, though very rare in practice, that
+    /// an attempt to render a Question may cause truncation
+    /// (when the Question section contains a large number of entries).
+    /// In such a case this method avoid the rendering and indicate the
+    /// truncation in the \c renderer.  This method returns 0 in this case.
     ///
     /// \param renderer DNS message rendering context that encapsulates the
     /// output buffer and name compression information.
-    /// \return 1
+    ///
+    /// \return 1 on success; 0 if it causes truncation
     unsigned int toWire(AbstractMessageRenderer& renderer) const;
 
     /// \brief Render the Question in the wire format without name compression.

+ 5 - 0
src/lib/dns/rrtype-placeholder.h

@@ -22,6 +22,11 @@
 
 #include <exceptions/exceptions.h>
 
+// Solaris x86 defines DS in <sys/regset.h>, which gets pulled in by Boost
+#if defined(__sun) && defined(DS)
+# undef DS
+#endif
+
 namespace isc {
 namespace util {
 class InputBuffer;

+ 216 - 10
src/lib/dns/tests/message_unittest.cc

@@ -62,7 +62,6 @@ using namespace isc::dns::rdata;
 //
 
 const uint16_t Message::DEFAULT_MAX_UDPSIZE;
-const Name test_name("test.example.com");
 
 namespace isc {
 namespace util {
@@ -79,7 +78,8 @@ const uint16_t TSIGContext::DEFAULT_FUDGE;
 namespace {
 class MessageTest : public ::testing::Test {
 protected:
-    MessageTest() : obuffer(0), renderer(obuffer),
+    MessageTest() : test_name("test.example.com"), obuffer(0),
+                    renderer(obuffer),
                     message_parse(Message::PARSE),
                     message_render(Message::RENDER),
                     bogus_section(static_cast<Message::Section>(
@@ -103,8 +103,9 @@ protected:
                                              "FAKEFAKEFAKEFAKE"));
         rrset_aaaa->addRRsig(rrset_rrsig);
     }
-    
+
     static Question factoryFromFile(const char* datafile);
+    const Name test_name;
     OutputBuffer obuffer;
     MessageRenderer renderer;
     Message message_parse;
@@ -114,17 +115,18 @@ protected:
     RRsetPtr rrset_aaaa;        // AAAA RRset with one RDATA with RRSIG
     RRsetPtr rrset_rrsig;       // RRSIG for the AAAA RRset
     TSIGContext tsig_ctx;
+    vector<unsigned char> received_data;
     vector<unsigned char> expected_data;
 
-    static void factoryFromFile(Message& message, const char* datafile);
+    void factoryFromFile(Message& message, const char* datafile);
 };
 
 void
 MessageTest::factoryFromFile(Message& message, const char* datafile) {
-    std::vector<unsigned char> data;
-    UnitTestUtil::readWireData(datafile, data);
+    received_data.clear();
+    UnitTestUtil::readWireData(datafile, received_data);
 
-    InputBuffer buffer(&data[0], data.size());
+    InputBuffer buffer(&received_data[0], received_data.size());
     message.fromWire(buffer);
 }
 
@@ -618,15 +620,43 @@ testGetTime() {
     return (NOW);
 }
 
+// bit-wise constant flags to configure DNS header flags for test
+// messages.
+const unsigned int QR_FLAG = 0x1;
+const unsigned int AA_FLAG = 0x2;
+const unsigned int RD_FLAG = 0x4;
+
 void
 commonTSIGToWireCheck(Message& message, MessageRenderer& renderer,
-                      TSIGContext& tsig_ctx, const char* const expected_file)
+                      TSIGContext& tsig_ctx, const char* const expected_file,
+                      unsigned int message_flags = RD_FLAG,
+                      RRType qtype = RRType::A(),
+                      const vector<const char*>* answer_data = NULL)
 {
     message.setOpcode(Opcode::QUERY());
     message.setRcode(Rcode::NOERROR());
-    message.setHeaderFlag(Message::HEADERFLAG_RD, true);
+    if ((message_flags & QR_FLAG) != 0) {
+        message.setHeaderFlag(Message::HEADERFLAG_QR);
+    }
+    if ((message_flags & AA_FLAG) != 0) {
+        message.setHeaderFlag(Message::HEADERFLAG_AA);
+    }
+    if ((message_flags & RD_FLAG) != 0) {
+        message.setHeaderFlag(Message::HEADERFLAG_RD);
+    }
     message.addQuestion(Question(Name("www.example.com"), RRClass::IN(),
-                                 RRType::A()));
+                                 qtype));
+
+    if (answer_data != NULL) {
+        RRsetPtr ans_rrset(new RRset(Name("www.example.com"), RRClass::IN(),
+                                     qtype, RRTTL(86400)));
+        for (vector<const char*>::const_iterator it = answer_data->begin();
+             it != answer_data->end();
+             ++it) {
+            ans_rrset->addRdata(createRdata(qtype, RRClass::IN(), *it));
+        }
+        message.addRRset(Message::SECTION_ANSWER, ans_rrset);
+    }
 
     message.toWire(renderer, tsig_ctx);
     vector<unsigned char> expected_data;
@@ -670,6 +700,182 @@ TEST_F(MessageTest, toWireWithEDNSAndTSIG) {
     }
 }
 
+// Some of the following tests involve truncation.  We use the query name
+// "www.example.com" and some TXT question/answers.  The length of the
+// header and question will be 33 bytes.  If we also try to include a
+// TSIG of the same key name (not compressed) with HMAC-MD5, the TSIG RR
+// will be 85 bytes.
+
+// A long TXT RDATA.  With a fully compressed owner name, the corresponding
+// RR will be 268 bytes.
+const char* const long_txt1 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde";
+
+// With a fully compressed owner name, the corresponding RR will be 212 bytes.
+// It should result in truncation even without TSIG (33 + 268 + 212 = 513)
+const char* const long_txt2 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456";
+
+// With a fully compressed owner name, the corresponding RR will be 127 bytes.
+// So, it can fit in the standard 512 bytes with txt1 and without TSIG, but
+// adding a TSIG would result in truncation (33 + 268 + 127 + 85 = 513)
+const char* const long_txt3 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef01";
+
+// This is 1 byte shorter than txt3, which will result in a possible longest
+// message containing answer RRs and TSIG.
+const char* const long_txt4 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0";
+
+// Example output generated by
+// "dig -y www.example.com:SFuWd/q99SzF8Yzd1QbB9g== www.example.com txt
+// QID: 0x22c2
+// Time Signed: 0x00004e179212
+TEST_F(MessageTest, toWireTSIGTruncation) {
+    isc::util::detail::gettimeFunction = testGetTime<0x4e179212>;
+
+    // Verify a validly signed query so that we can use the TSIG context
+
+    factoryFromFile(message_parse, "message_fromWire17.wire");
+    EXPECT_EQ(TSIGError::NOERROR(),
+              tsig_ctx.verify(message_parse.getTSIGRecord(),
+                              &received_data[0], received_data.size()));
+
+    message_render.setQid(0x22c2);
+    vector<const char*> answer_data;
+    answer_data.push_back(long_txt1);
+    answer_data.push_back(long_txt2);
+    {
+        SCOPED_TRACE("Message sign with TSIG and TC bit on");
+        commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                              "message_toWire4.wire",
+                              QR_FLAG|AA_FLAG|RD_FLAG,
+                              RRType::TXT(), &answer_data);
+    }
+}
+
+TEST_F(MessageTest, toWireTSIGTruncation2) {
+    // Similar to the previous test, but without TSIG it wouldn't cause
+    // truncation.
+    isc::util::detail::gettimeFunction = testGetTime<0x4e179212>;
+    factoryFromFile(message_parse, "message_fromWire17.wire");
+    EXPECT_EQ(TSIGError::NOERROR(),
+              tsig_ctx.verify(message_parse.getTSIGRecord(),
+                              &received_data[0], received_data.size()));
+
+    message_render.setQid(0x22c2);
+    vector<const char*> answer_data;
+    answer_data.push_back(long_txt1);
+    answer_data.push_back(long_txt3);
+    {
+        SCOPED_TRACE("Message sign with TSIG and TC bit on (2)");
+        commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                              "message_toWire4.wire",
+                              QR_FLAG|AA_FLAG|RD_FLAG,
+                              RRType::TXT(), &answer_data);
+    }
+}
+
+TEST_F(MessageTest, toWireTSIGTruncation3) {
+    // Similar to previous ones, but truncation occurs due to too many
+    // Questions (very unusual, but not necessarily illegal).
+
+    // We are going to create a message starting with a standard
+    // header (12 bytes) and multiple questions in the Question
+    // section of the same owner name (changing the RRType, just so
+    // that it would be the form that would be accepted by the BIND 9
+    // parser).  The first Question is 21 bytes in length, and the subsequent
+    // ones are 6 bytes.  We'll also use a TSIG whose size is 85 bytes.
+    // Up to 66 questions can fit in the standard 512-byte buffer
+    // (12 + 21 + 6 * 65 + 85 = 508).  If we try to add one more it would
+    // result in truncation.
+    message_render.setOpcode(Opcode::QUERY());
+    message_render.setRcode(Rcode::NOERROR());
+    for (int i = 1; i <= 67; ++i) {
+        message_render.addQuestion(Question(Name("www.example.com"),
+                                            RRClass::IN(), RRType(i)));
+    }
+    message_render.toWire(renderer, tsig_ctx);
+
+    // Check the rendered data by parsing it.  We only check it has the
+    // TC bit on, has the correct number of questions, and has a TSIG RR.
+    // Checking the signature wouldn't be necessary for this rare case
+    // scenario.
+    InputBuffer buffer(renderer.getData(), renderer.getLength());
+    message_parse.fromWire(buffer);
+    EXPECT_TRUE(message_parse.getHeaderFlag(Message::HEADERFLAG_TC));
+    // Note that the number of questions are 66, not 67 as we tried to add.
+    EXPECT_EQ(66, message_parse.getRRCount(Message::SECTION_QUESTION));
+    EXPECT_TRUE(message_parse.getTSIGRecord() != NULL);
+}
+
+TEST_F(MessageTest, toWireTSIGNoTruncation) {
+    // A boundary case that shouldn't cause truncation: the resulting
+    // response message with a TSIG will be 512 bytes long.
+    isc::util::detail::gettimeFunction = testGetTime<0x4e17b38d>;
+    factoryFromFile(message_parse, "message_fromWire18.wire");
+    EXPECT_EQ(TSIGError::NOERROR(),
+              tsig_ctx.verify(message_parse.getTSIGRecord(),
+                              &received_data[0], received_data.size()));
+
+    message_render.setQid(0xd6e2);
+    vector<const char*> answer_data;
+    answer_data.push_back(long_txt1);
+    answer_data.push_back(long_txt4);
+    {
+        SCOPED_TRACE("Message sign with TSIG, no truncation");
+        commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                              "message_toWire5.wire",
+                              QR_FLAG|AA_FLAG|RD_FLAG,
+                              RRType::TXT(), &answer_data);
+    }
+}
+
+// This is a buggy renderer for testing.  It behaves like the straightforward
+// MessageRenderer, but once it has some data, its setLengthLimit() ignores
+// the given parameter and resets the limit to the current length, making
+// subsequent insertion result in truncation, which would make TSIG RR
+// rendering fail unexpectedly in the test that follows.
+class BadRenderer : public MessageRenderer {
+public:
+    BadRenderer(isc::util::OutputBuffer& buffer) :
+        MessageRenderer(buffer)
+    {}
+    virtual void setLengthLimit(size_t len) {
+        if (getLength() > 0) {
+            MessageRenderer::setLengthLimit(getLength());
+        } else {
+            MessageRenderer::setLengthLimit(len);
+        }
+    }
+};
+
+TEST_F(MessageTest, toWireTSIGLengthErrors) {
+    // specify an unusual short limit that wouldn't be able to hold
+    // the TSIG.
+    renderer.setLengthLimit(tsig_ctx.getTSIGLength() - 1);
+    // Use commonTSIGToWireCheck() only to call toWire() with otherwise valid
+    // conditions.  The checks inside it don't matter because we expect an
+    // exception before any of the checks.
+    EXPECT_THROW(commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                                       "message_toWire2.wire"),
+                 InvalidParameter);
+
+    // This one is large enough for TSIG, but the remaining limit isn't
+    // even enough for the Header section.
+    renderer.clear();
+    message_render.clear(Message::RENDER);
+    renderer.setLengthLimit(tsig_ctx.getTSIGLength() + 1);
+    EXPECT_THROW(commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
+                                       "message_toWire2.wire"),
+                 InvalidParameter);
+
+    // Trying to render a message with TSIG using a buggy renderer.
+    obuffer.clear();
+    BadRenderer bad_renderer(obuffer);
+    bad_renderer.setLengthLimit(512);
+    message_render.clear(Message::RENDER);
+    EXPECT_THROW(commonTSIGToWireCheck(message_render, bad_renderer, tsig_ctx,
+                                       "message_toWire2.wire"),
+                 Unexpected);
+}
+
 TEST_F(MessageTest, toWireWithoutOpcode) {
     message_render.setRcode(Rcode::NOERROR());
     EXPECT_THROW(message_render.toWire(renderer), InvalidMessageOperation);

+ 16 - 0
src/lib/dns/tests/question_unittest.cc

@@ -106,6 +106,22 @@ TEST_F(QuestionTest, toWireRenderer) {
                         obuffer.getLength(), &wiredata[0], wiredata.size());
 }
 
+TEST_F(QuestionTest, toWireTruncated) {
+    // If the available length in the renderer is too small, it would require
+    // truncation.  This won't happen in normal cases, but protocol wise it
+    // could still happen if and when we support some (possibly future) opcode
+    // that allows multiple questions.
+
+    // Set the length limit to the qname length so that the whole question
+    // would request truncated
+    renderer.setLengthLimit(example_name1.getLength());
+
+    EXPECT_FALSE(renderer.isTruncated()); // check pre-render condition
+    EXPECT_EQ(0, test_question1.toWire(renderer));
+    EXPECT_TRUE(renderer.isTruncated());
+    EXPECT_EQ(0, renderer.getLength()); // renderer shouldn't have any data
+}
+
 // test operator<<.  We simply confirm it appends the result of toText().
 TEST_F(QuestionTest, LeftShiftOperator) {
     ostringstream oss;

+ 5 - 1
src/lib/dns/tests/testdata/Makefile.am

@@ -5,8 +5,10 @@ BUILT_SOURCES += edns_toWire4.wire
 BUILT_SOURCES += message_fromWire10.wire message_fromWire11.wire
 BUILT_SOURCES += message_fromWire12.wire message_fromWire13.wire
 BUILT_SOURCES += message_fromWire14.wire message_fromWire15.wire
-BUILT_SOURCES += message_fromWire16.wire
+BUILT_SOURCES += message_fromWire16.wire message_fromWire17.wire
+BUILT_SOURCES += message_fromWire18.wire
 BUILT_SOURCES += message_toWire2.wire message_toWire3.wire
+BUILT_SOURCES += message_toWire4.wire message_toWire5.wire
 BUILT_SOURCES += message_toText1.wire message_toText2.wire
 BUILT_SOURCES += message_toText3.wire
 BUILT_SOURCES += name_toWire5.wire name_toWire6.wire
@@ -59,7 +61,9 @@ EXTRA_DIST += message_fromWire9 message_fromWire10.spec
 EXTRA_DIST += message_fromWire11.spec message_fromWire12.spec
 EXTRA_DIST += message_fromWire13.spec message_fromWire14.spec
 EXTRA_DIST += message_fromWire15.spec message_fromWire16.spec
+EXTRA_DIST += message_fromWire17.spec message_fromWire18.spec
 EXTRA_DIST += message_toWire1 message_toWire2.spec message_toWire3.spec
+EXTRA_DIST += message_toWire4.spec message_toWire5.spec
 EXTRA_DIST += message_toText1.txt message_toText1.spec
 EXTRA_DIST += message_toText2.txt message_toText2.spec
 EXTRA_DIST += message_toText3.txt message_toText3.spec

+ 5 - 7
src/lib/dns/tests/testdata/gen-wiredata.py.in

@@ -307,8 +307,8 @@ class SOA(RR):
                                                 self.retry, self.expire,
                                                 self.minimum))
 
-class TXT:
-    rdlen = -1                  # auto-calculate
+class TXT(RR):
+    rdlen = None                # auto-calculate
     nstring = 1                 # number of character-strings
     stringlen = -1              # default string length, auto-calculate
     string = 'Test String'      # default string
@@ -330,11 +330,9 @@ class TXT:
                 stringlen_list.append(self.stringlen)
             if stringlen_list[-1] < 0:
                 stringlen_list[-1] = int(len(wirestring_list[-1]) / 2)
-        rdlen = self.rdlen
-        if rdlen < 0:
-            rdlen = int(len(''.join(wirestring_list)) / 2) + self.nstring
-        f.write('\n# TXT RDATA (RDLEN=%d)\n' % rdlen)
-        f.write('%04x\n' % rdlen);
+        if self.rdlen is None:
+            self.rdlen = int(len(''.join(wirestring_list)) / 2) + self.nstring
+        self.dump_header(f, self.rdlen)
         for i in range(0, self.nstring):
             f.write('# String Len=%d, String=\"%s\"\n' %
                     (stringlen_list[i], string_list[i]))

+ 22 - 0
src/lib/dns/tests/testdata/message_fromWire17.spec

@@ -0,0 +1,22 @@
+#
+# A simple DNS query message with TSIG signed
+#
+
+[custom]
+sections: header:question:tsig
+[header]
+id: 0x22c2
+rd: 1
+arcount: 1
+[question]
+name: www.example.com
+rrtype: TXT
+[tsig]
+as_rr: True
+# TSIG QNAME won't be compressed
+rr_name: www.example.com
+algorithm: hmac-md5
+time_signed: 0x4e179212
+mac_size: 16
+mac: 0x8214b04634e32323d651ac60b08e6388
+original_id: 0x22c2

+ 23 - 0
src/lib/dns/tests/testdata/message_fromWire18.spec

@@ -0,0 +1,23 @@
+#
+# Another simple DNS query message with TSIG signed.  Only ID and time signed
+# (and MAC as a result) are different.
+#
+
+[custom]
+sections: header:question:tsig
+[header]
+id: 0xd6e2
+rd: 1
+arcount: 1
+[question]
+name: www.example.com
+rrtype: TXT
+[tsig]
+as_rr: True
+# TSIG QNAME won't be compressed
+rr_name: www.example.com
+algorithm: hmac-md5
+time_signed: 0x4e17b38d
+mac_size: 16
+mac: 0x903b5b194a799b03a37718820c2404f2
+original_id: 0xd6e2

+ 27 - 0
src/lib/dns/tests/testdata/message_toWire4.spec

@@ -0,0 +1,27 @@
+#
+# Truncated DNS response with TSIG signed
+# This is expected to be a response to "fromWire17"
+#
+
+[custom]
+sections: header:question:tsig
+[header]
+id: 0x22c2
+rd: 1
+qr: 1
+aa: 1
+# It's "truncated":
+tc: 1
+arcount: 1
+[question]
+name: www.example.com
+rrtype: TXT
+[tsig]
+as_rr: True
+# TSIG QNAME won't be compressed
+rr_name: www.example.com
+algorithm: hmac-md5
+time_signed: 0x4e179212
+mac_size: 16
+mac: 0x88adc3811d1d6bec7c684438906fc694
+original_id: 0x22c2

+ 36 - 0
src/lib/dns/tests/testdata/message_toWire5.spec

@@ -0,0 +1,36 @@
+#
+# A longest possible (without EDNS) DNS response with TSIG, i.e. totatl
+# length should be 512 bytes.
+#
+
+[custom]
+sections: header:question:txt/1:txt/2:tsig
+[header]
+id: 0xd6e2
+rd: 1
+qr: 1
+aa: 1
+ancount: 2
+arcount: 1
+[question]
+name: www.example.com
+rrtype: TXT
+[txt/1]
+as_rr: True
+# QNAME is fully compressed
+rr_name: ptr=12
+string: 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde
+[txt/2]
+as_rr: True
+# QNAME is fully compressed
+rr_name: ptr=12
+string: 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0
+[tsig]
+as_rr: True
+# TSIG QNAME won't be compressed
+rr_name: www.example.com
+algorithm: hmac-md5
+time_signed: 0x4e17b38d
+mac_size: 16
+mac: 0xbe2ba477373d2496891e2fda240ee4ec
+original_id: 0xd6e2

+ 72 - 0
src/lib/dns/tests/tsig_unittest.cc

@@ -927,4 +927,76 @@ TEST_F(TSIGTest, tooShortMAC) {
     }
 }
 
+TEST_F(TSIGTest, getTSIGLength) {
+    // Check for the most common case with various algorithms
+    // See the comment in TSIGContext::getTSIGLength() for calculation and
+    // parameter notation.
+    // The key name (www.example.com) is the same for most cases, where n1=17
+
+    // hmac-md5.sig-alg.reg.int.: n2=26, x=16
+    EXPECT_EQ(85, tsig_ctx->getTSIGLength());
+
+    // hmac-sha1: n2=11, x=20
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACSHA1_NAME(),
+                                           &dummy_data[0], 20)));
+    EXPECT_EQ(74, tsig_ctx->getTSIGLength());
+
+    // hmac-sha256: n2=13, x=32
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                           TSIGKey::HMACSHA256_NAME(),
+                                           &dummy_data[0], 32)));
+    EXPECT_EQ(88, tsig_ctx->getTSIGLength());
+
+    // hmac-sha224: n2=13, x=28
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                           TSIGKey::HMACSHA224_NAME(),
+                                           &dummy_data[0], 28)));
+    EXPECT_EQ(84, tsig_ctx->getTSIGLength());
+
+    // hmac-sha384: n2=13, x=48
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                           TSIGKey::HMACSHA384_NAME(),
+                                           &dummy_data[0], 48)));
+    EXPECT_EQ(104, tsig_ctx->getTSIGLength());
+
+    // hmac-sha512: n2=13, x=64
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                           TSIGKey::HMACSHA512_NAME(),
+                                           &dummy_data[0], 64)));
+    EXPECT_EQ(120, tsig_ctx->getTSIGLength());
+
+    // bad key case: n1=len(badkey.example.com)=20, n2=26, x=0
+    tsig_ctx.reset(new TSIGContext(badkey_name, TSIGKey::HMACMD5_NAME(),
+                                   keyring));
+    EXPECT_EQ(72, tsig_ctx->getTSIGLength());
+
+    // bad sig case: n1=17, n2=26, x=0
+    isc::util::detail::gettimeFunction = testGetTime<0x4da8877a>;
+    createMessageFromFile("message_toWire2.wire");
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACMD5_NAME(),
+                                           &dummy_data[0],
+                                           dummy_data.size())));
+    {
+        SCOPED_TRACE("Verify resulting in BADSIG");
+        commonVerifyChecks(*tsig_ctx, message.getTSIGRecord(),
+                           &received_data[0], received_data.size(),
+                           TSIGError::BAD_SIG(), TSIGContext::RECEIVED_REQUEST);
+    }
+    EXPECT_EQ(69, tsig_ctx->getTSIGLength());
+
+    // bad time case: n1=17, n2=26, x=16, y=6
+    isc::util::detail::gettimeFunction = testGetTime<0x4da8877a - 1000>;
+    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACMD5_NAME(),
+                                           &dummy_data[0],
+                                           dummy_data.size())));
+    {
+        SCOPED_TRACE("Verify resulting in BADTIME");
+        commonVerifyChecks(*tsig_ctx, message.getTSIGRecord(),
+                           &received_data[0], received_data.size(),
+                           TSIGError::BAD_TIME(),
+                           TSIGContext::RECEIVED_REQUEST);
+    }
+    EXPECT_EQ(91, tsig_ctx->getTSIGLength());
+}
+
 } // end namespace

+ 87 - 16
src/lib/dns/tsig.cc

@@ -58,10 +58,32 @@ getTSIGTime() {
 }
 
 struct TSIGContext::TSIGContextImpl {
-    TSIGContextImpl(const TSIGKey& key) :
-        state_(INIT), key_(key), error_(Rcode::NOERROR()),
-        previous_timesigned_(0)
-    {}
+    TSIGContextImpl(const TSIGKey& key,
+                    TSIGError error = TSIGError::NOERROR()) :
+        state_(INIT), key_(key), error_(error),
+        previous_timesigned_(0), digest_len_(0)
+    {
+        if (error == TSIGError::NOERROR()) {
+            // In normal (NOERROR) case, the key should be valid, and we
+            // should be able to pre-create a corresponding HMAC object,
+            // which will be likely to be used for sign or verify later.
+            // We do this in the constructor so that we can know the expected
+            // digest length in advance.  The creation should normally succeed,
+            // but the key information could be still broken, which could
+            // trigger an exception inside the cryptolink module.  We ignore
+            // it at this moment; a subsequent sign/verify operation will try
+            // to create the HMAC, which would also fail.
+            try {
+                hmac_.reset(CryptoLink::getCryptoLink().createHMAC(
+                                key_.getSecret(), key_.getSecretLength(),
+                                key_.getAlgorithm()),
+                            deleteHMAC);
+            } catch (const Exception&) {
+                return;
+            }
+            digest_len_ = hmac_->getOutputLength();
+        }
+    }
 
     // This helper method is used from verify().  It's expected to be called
     // just before verify() returns.  It updates internal state based on
@@ -85,6 +107,23 @@ struct TSIGContext::TSIGContextImpl {
         return (error);
     }
 
+    // A shortcut method to create an HMAC object for sign/verify.  If one
+    // has been successfully created in the constructor, return it; otherwise
+    // create a new one and return it.  In the former case, the ownership is
+    // transferred to the caller; the stored HMAC will be reset after the
+    // call.
+    HMACPtr createHMAC() {
+        if (hmac_) {
+            HMACPtr ret = HMACPtr();
+            ret.swap(hmac_);
+            return (ret);
+        }
+        return (HMACPtr(CryptoLink::getCryptoLink().createHMAC(
+                            key_.getSecret(), key_.getSecretLength(),
+                            key_.getAlgorithm()),
+                        deleteHMAC));
+    }
+
     // The following three are helper methods to compute the digest for
     // TSIG sign/verify in order to unify the common code logic for sign()
     // and verify() and to keep these callers concise.
@@ -111,6 +150,8 @@ struct TSIGContext::TSIGContextImpl {
     vector<uint8_t> previous_digest_;
     TSIGError error_;
     uint64_t previous_timesigned_; // only meaningful for response with BADTIME
+    size_t digest_len_;
+    HMACPtr hmac_;
 };
 
 void
@@ -221,8 +262,7 @@ TSIGContext::TSIGContext(const Name& key_name, const Name& algorithm_name,
         // be used in subsequent response with a TSIG indicating a BADKEY
         // error.
         impl_ = new TSIGContextImpl(TSIGKey(key_name, algorithm_name,
-                                            NULL, 0));
-        impl_->error_ = TSIGError::BAD_KEY();
+                                            NULL, 0), TSIGError::BAD_KEY());
     } else {
         impl_ = new TSIGContextImpl(*result.key);
     }
@@ -232,6 +272,45 @@ TSIGContext::~TSIGContext() {
     delete impl_;
 }
 
+size_t
+TSIGContext::getTSIGLength() const {
+    //
+    // The space required for an TSIG record is:
+    //
+    //	n1 bytes for the (key) name
+    //	2 bytes for the type
+    //	2 bytes for the class
+    //	4 bytes for the ttl
+    //	2 bytes for the rdlength
+    //	n2 bytes for the algorithm name
+    //	6 bytes for the time signed
+    //	2 bytes for the fudge
+    //	2 bytes for the MAC size
+    //	x bytes for the MAC
+    //	2 bytes for the original id
+    //	2 bytes for the error
+    //	2 bytes for the other data length
+    //	y bytes for the other data (at most)
+    // ---------------------------------
+    //     26 + n1 + n2 + x + y bytes
+    //
+
+    // Normally the digest length ("x") is the length of the underlying
+    // hash output.  If a key related error occurred, however, the
+    // corresponding TSIG will be "unsigned", and the digest length will be 0.
+    const size_t digest_len =
+        (impl_->error_ == TSIGError::BAD_KEY() ||
+         impl_->error_ == TSIGError::BAD_SIG()) ? 0 : impl_->digest_len_;
+
+    // Other Len ("y") is normally 0; if BAD_TIME error occurred, the
+    // subsequent TSIG will contain 48 bits of the server current time.
+    const size_t other_len = (impl_->error_ == TSIGError::BAD_TIME()) ? 6 : 0;
+
+    return (26 + impl_->key_.getKeyName().getLength() +
+            impl_->key_.getAlgorithmName().getLength() +
+            digest_len + other_len);
+}
+
 TSIGContext::State
 TSIGContext::getState() const {
     return (impl_->state_);
@@ -276,11 +355,7 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
         return (tsig);
     }
 
-    HMACPtr hmac(CryptoLink::getCryptoLink().createHMAC(
-                     impl_->key_.getSecret(),
-                     impl_->key_.getSecretLength(),
-                     impl_->key_.getAlgorithm()),
-                 deleteHMAC);
+    HMACPtr hmac(impl_->createHMAC());
 
     // If the context has previous MAC (either the Request MAC or its own
     // previous MAC), digest it.
@@ -406,11 +481,7 @@ TSIGContext::verify(const TSIGRecord* const record, const void* const data,
         return (impl_->postVerifyUpdate(error, NULL, 0));
     }
 
-    HMACPtr hmac(CryptoLink::getCryptoLink().createHMAC(
-                     impl_->key_.getSecret(),
-                     impl_->key_.getSecretLength(),
-                     impl_->key_.getAlgorithm()),
-                 deleteHMAC);
+    HMACPtr hmac(impl_->createHMAC());
 
     // If the context has previous MAC (either the Request MAC or its own
     // previous MAC), digest it.

+ 21 - 0
src/lib/dns/tsig.h

@@ -353,6 +353,27 @@ public:
     TSIGError verify(const TSIGRecord* const record, const void* const data,
                      const size_t data_len);
 
+    /// Return the expected length of TSIG RR after \c sign()
+    ///
+    /// This method returns the length of the TSIG RR that would be
+    /// produced as a result of \c sign() with the state of the context
+    /// at the time of the call.  The expected length can be decided
+    /// from the key and the algorithm (which determines the MAC size if
+    /// included) and the recorded TSIG error.  Specifically, if a key
+    /// related error has been identified, the MAC will be excluded; if
+    /// a time error has occurred, the TSIG will include "other data".
+    ///
+    /// This method is provided mainly for the convenience of the Message
+    /// class, which needs to know the expected TSIG length in rendering a
+    /// signed DNS message so that it can handle truncated messages with TSIG
+    /// correctly.  Normal applications wouldn't need this method.  The Python
+    /// binding for this method won't be provided for the same reason.
+    ///
+    /// \exception None
+    ///
+    /// \return The expected TISG RR length in bytes
+    size_t getTSIGLength() const;
+
     /// Return the current state of the context
     ///
     /// \note

+ 1 - 0
src/lib/log/Makefile.am

@@ -20,6 +20,7 @@ liblog_la_SOURCES += logger_manager_impl.cc logger_manager_impl.h
 liblog_la_SOURCES += logger_name.cc logger_name.h
 liblog_la_SOURCES += logger_specification.h
 liblog_la_SOURCES += logger_support.cc logger_support.h
+liblog_la_SOURCES += logger_unittest_support.cc logger_unittest_support.h
 liblog_la_SOURCES += macros.h
 liblog_la_SOURCES += log_messages.cc log_messages.h
 liblog_la_SOURCES += message_dictionary.cc message_dictionary.h

+ 1 - 149
src/lib/log/logger_support.cc

@@ -12,28 +12,9 @@
 // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 // PERFORMANCE OF THIS SOFTWARE
 
-/// \brief Temporary Logger Support
-///
-/// Performs run-time initialization of the logger system.  In particular, it
-/// is passed information from the command line and:
-///
-/// a) Sets the severity of the messages being logged (and debug level if
-/// appropriate).
-/// b) Reads in the local message file is one has been supplied.
-///
-/// These functions will be replaced once the code has been written to obtain
-/// the logging parameters from the configuration database.
-
-#include <iostream>
-#include <algorithm>
-#include <iostream>
 #include <string>
-
-#include <log/logger_level.h>
-#include <log/logger_manager.h>
-#include <log/logger_specification.h>
 #include <log/logger_support.h>
-#include <log/output_option.h>
+#include <log/logger_manager.h>
 
 using namespace std;
 
@@ -42,77 +23,6 @@ namespace {
 // Flag to hold logging initialization state.
 bool logging_init_state = false;
 
-
-// Set logging destination according to the setting of B10_LOGGER_DESTINATION.
-// (See header for initLogger() for more details.)  This is a no-op if the
-// environment variable is not defined.
-//
-// \param root Name of the root logger
-// \param severity Severity level to be assigned to the root logger
-// \param dbglevel Debug level
-
-void
-setDestination(const char* root, const isc::log::Severity severity,
-               const int dbglevel) {
-
-    using namespace isc::log;
-
-    const char* destination = getenv("B10_LOGGER_DESTINATION");
-    if (destination != NULL) {
-
-        // Constants: not declared static as this is function is expected to be
-        // called once only
-        const string STDOUT = "stdout";
-        const string STDERR = "stderr";
-        const string SYSLOG = "syslog";
-        const string SYSLOG_COLON = "syslog:";
-
-        // Prepare the objects to define the logging specification
-        LoggerSpecification spec(root, severity, dbglevel);
-        OutputOption option;
-
-        // Set up output option according to destination specification
-        const string dest = destination;
-        if (dest == STDOUT) {
-            option.destination = OutputOption::DEST_CONSOLE;
-            option.stream = OutputOption::STR_STDOUT;
-
-        } else if (dest == STDERR) {
-            option.destination = OutputOption::DEST_CONSOLE;
-            option.stream = OutputOption::STR_STDERR;
-
-        } else if (dest == SYSLOG) {
-            option.destination = OutputOption::DEST_SYSLOG;
-            // Use default specified in OutputOption constructor for the
-            // syslog destination
-
-        } else if (dest.find(SYSLOG_COLON) == 0) {
-            option.destination = OutputOption::DEST_SYSLOG;
-            // Must take account of the string actually being "syslog:"
-            if (dest == SYSLOG_COLON) {
-                cerr << "**ERROR** value for B10_LOGGER_DESTINATION of " <<
-                        SYSLOG_COLON << " is invalid, " << SYSLOG <<
-                        " will be used instead\n";
-                // Use default for logging facility
-
-            } else {
-                // Everything else in the string is the facility name
-                option.facility = dest.substr(SYSLOG_COLON.size());
-            }
-
-        } else {
-            // Not a recognised destination, assume a file
-            option.destination = OutputOption::DEST_FILE;
-            option.filename = dest;
-        }
-
-        // ... and set the destination
-        spec.addOutputOption(option);
-        LoggerManager manager;
-        manager.process(spec);
-    }
-}
-
 } // Anonymous namespace
 
 namespace isc {
@@ -140,63 +50,5 @@ initLogger(const string& root, isc::log::Severity severity, int dbglevel,
     LoggerManager::init(root, severity, dbglevel, file);
 }
 
-// Logger Run-Time Initialization via Environment Variables
-void initLogger(isc::log::Severity severity, int dbglevel) {
-
-    // Root logger name is defined by the environment variable B10_LOGGER_ROOT.
-    // If not present, the name is "bind10".
-    const char* DEFAULT_ROOT = "bind10";
-    const char* root = getenv("B10_LOGGER_ROOT");
-    if (! root) {
-        root = DEFAULT_ROOT;
-    }
-
-    // Set the logging severity.  The environment variable is
-    // B10_LOGGER_SEVERITY, and can be one of "DEBUG", "INFO", "WARN", "ERROR"
-    // of "FATAL".  Note that the string must be in upper case with no leading
-    // of trailing blanks.
-    const char* sev_char = getenv("B10_LOGGER_SEVERITY");
-    if (sev_char) {
-        severity = isc::log::getSeverity(sev_char);
-    }
-
-    // If the severity is debug, get the debug level (environment variable
-    // B10_LOGGER_DBGLEVEL), which should be in the range 0 to 99.
-    if (severity == isc::log::DEBUG) {
-        const char* dbg_char = getenv("B10_LOGGER_DBGLEVEL");
-        if (dbg_char) {
-            int level = 0;
-            try {
-                level = boost::lexical_cast<int>(dbg_char);
-                if (level < MIN_DEBUG_LEVEL) {
-                    cerr << "**ERROR** debug level of " << level
-                         << " is invalid - a value of " << MIN_DEBUG_LEVEL
-                         << " will be used\n";
-                    level = MIN_DEBUG_LEVEL;
-                } else if (level > MAX_DEBUG_LEVEL) {
-                    cerr << "**ERROR** debug level of " << level
-                         << " is invalid - a value of " << MAX_DEBUG_LEVEL
-                         << " will be used\n";
-                    level = MAX_DEBUG_LEVEL;
-                }
-            } catch (...) {
-                // Error, but not fatal to the test
-                cerr << "**ERROR** Unable to translate "
-                        "B10_LOGGER_DBGLEVEL - a value of 0 will be used\n";
-            }
-            dbglevel = level;
-        }
-    }
-
-    // Set the local message file
-    const char* localfile = getenv("B10_LOGGER_LOCALMSG");
-
-    // Initialize logging
-    initLogger(root, severity, dbglevel, localfile);
-
-    // Now set the destination for logging output
-    setDestination(root, severity, dbglevel);
-}
-
 } // namespace log
 } // namespace isc

+ 9 - 53
src/lib/log/logger_support.h

@@ -19,6 +19,13 @@
 
 #include <string>
 #include <log/logger.h>
+#include <log/logger_unittest_support.h>
+
+/// \file
+/// \brief Logging initialization functions
+///
+/// Contains a set of functions relating to logging initialization that are
+/// used by the production code.
 
 namespace isc {
 namespace log {
@@ -33,17 +40,13 @@ namespace log {
 /// \return true if logging has been initialized, false if not
 bool isLoggingInitialized();
 
-/// \brief Set "logging initialized" flag
-///
-/// Sets the state of the "logging initialized" flag.
+/// \brief Set state of "logging initialized" flag
 ///
 /// \param state State to set the flag to. (This is expected to be "true" - the
 ///        default - for all code apart from specific unit tests.)
 void setLoggingInitialized(bool state = true);
 
-
-
-/// \brief Run-Time Initialization
+/// \brief Run-time initialization
 ///
 /// Performs run-time initialization of the logger in particular supplying:
 ///
@@ -62,54 +65,7 @@ void initLogger(const std::string& root,
                 isc::log::Severity severity = isc::log::INFO,
                 int dbglevel = 0, const char* file = NULL);
 
-
-/// \brief Run-Time Initialization from Environment
-///
-/// Performs run-time initialization of the logger via the setting of
-/// environment variables.  These are:
-///
-/// - B10_LOGGER_ROOT\n
-/// Name of the root logger.  If not given, the string "bind10" will be used.
-///
-/// - B10_LOGGER_SEVERITY\n
-/// Severity of messages that will be logged.  This must be one of the strings
-/// "DEBUG", "INFO", "WARN", "ERROR", "FATAL" or "NONE". (Must be upper case
-/// and must not contain leading or trailing spaces.)  If not specified (or if
-/// specified but incorrect), the default passed as argument to this function
-/// (currently INFO) will be used.
-///
-/// - B10_LOGGER_DBGLEVEL\n
-/// Ignored if the level is not DEBUG, this should be a number between 0 and
-/// 99 indicating the logging severity.  The default is 0.  If outside these
-/// limits or if not a number, The value passed to this function (default
-/// of 0) is used.
-///
-/// - B10_LOGGER_LOCALMSG\n
-/// If defined, the path specification of a file that contains message
-/// definitions replacing ones in the default dictionary.
-///
-/// - B10_LOGGER_DESTINATION\n
-/// If defined, the destination of the logging output.  This can be one of:
-///   - \c stdout Send output to stdout.
-///   - \c stderr Send output to stderr
-///   - \c syslog Send output to syslog using the facility local0.
-///   - \c syslog:xxx  Send output to syslog, using the facility xxx. ("xxx"
-///     should be one of the syslog facilities such as "local0".)  There must
-///     be a colon between "syslog" and "xxx
-///   - \c other Anything else is interpreted as the name of a file to which
-///     output is appended.  If the file does not exist, it is created.
-///
-/// Any errors in the settings cause messages to be output to stderr.
-///
-/// This function is aimed at test programs, allowing the default settings to
-/// be overridden by the tester.  It is not intended for use in production
-/// code.
-
-void initLogger(isc::log::Severity severity = isc::log::INFO,
-                int dbglevel = 0);
-
 } // namespace log
 } // namespace isc
 
-
 #endif // __LOGGER_SUPPORT_H

+ 175 - 0
src/lib/log/logger_unittest_support.cc

@@ -0,0 +1,175 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#include <iostream>
+#include <algorithm>
+#include <string>
+
+#include <log/logger_level.h>
+#include <log/logger_name.h>
+#include <log/logger_manager.h>
+#include <log/logger_specification.h>
+#include <log/logger_unittest_support.h>
+#include <log/logger_support.h>
+#include <log/output_option.h>
+
+using namespace std;
+
+namespace isc {
+namespace log {
+
+// Get the logging severity.  This is defined by the environment variable
+// B10_LOGGER_SEVERITY, and can be one of "DEBUG", "INFO", "WARN", "ERROR"
+// of "FATAL".  (Note that the string must be in upper case with no leading
+// of trailing blanks.)  If not present, the default severity passed to the
+// function is returned.
+isc::log::Severity
+b10LoggerSeverity(isc::log::Severity defseverity) {
+    const char* sev_char = getenv("B10_LOGGER_SEVERITY");
+    if (sev_char) {
+        return (isc::log::getSeverity(sev_char));
+    }
+    return (defseverity);
+}
+
+// Get the debug level.  This is defined by the envornment variable
+// B10_LOGGER_DBGLEVEL.  If not defined, a default value passed to the function
+// is returned.
+int
+b10LoggerDbglevel(int defdbglevel) {
+    const char* dbg_char = getenv("B10_LOGGER_DBGLEVEL");
+    if (dbg_char) {
+        int level = 0;
+        try {
+            level = boost::lexical_cast<int>(dbg_char);
+            if (level < MIN_DEBUG_LEVEL) {
+                std::cerr << "**ERROR** debug level of " << level
+                          << " is invalid - a value of " << MIN_DEBUG_LEVEL
+                          << " will be used\n";
+                level = MIN_DEBUG_LEVEL;
+            } else if (level > MAX_DEBUG_LEVEL) {
+                std::cerr << "**ERROR** debug level of " << level
+                          << " is invalid - a value of " << MAX_DEBUG_LEVEL
+                          << " will be used\n";
+                level = MAX_DEBUG_LEVEL;
+            }
+        } catch (...) {
+            // Error, but not fatal to the test
+            std::cerr << "**ERROR** Unable to translate "
+                         "B10_LOGGER_DBGLEVEL - a value of 0 will be used\n";
+        }
+        return (level);
+    }
+
+    return (defdbglevel);
+}
+
+
+// Reset characteristics of the root logger to that set by the environment
+// variables B10_LOGGER_SEVERITY, B10_LOGGER_DBGLEVEL and B10_LOGGER_DESTINATION.
+
+void
+resetUnitTestRootLogger() {
+
+    using namespace isc::log;
+
+    // Constants: not declared static as this is function is expected to be
+    // called once only
+    const string DEVNULL = "/dev/null";
+    const string STDOUT = "stdout";
+    const string STDERR = "stderr";
+    const string SYSLOG = "syslog";
+    const string SYSLOG_COLON = "syslog:";
+
+    // Get the destination.  If not specified, assume /dev/null. (The default
+    // severity for unit tests is DEBUG, which generates a lot of output.
+    // Routing the logging to /dev/null will suppress that, whilst still
+    // ensuring that the code paths are tested.)
+    const char* destination = getenv("B10_LOGGER_DESTINATION");
+    const string dest((destination == NULL) ? DEVNULL : destination);
+
+    // Prepare the objects to define the logging specification
+    LoggerSpecification spec(getRootLoggerName(), 
+                             b10LoggerSeverity(isc::log::DEBUG),
+                             b10LoggerDbglevel(isc::log::MAX_DEBUG_LEVEL));
+    OutputOption option;
+
+    // Set up output option according to destination specification
+    if (dest == STDOUT) {
+        option.destination = OutputOption::DEST_CONSOLE;
+        option.stream = OutputOption::STR_STDOUT;
+
+    } else if (dest == STDERR) {
+        option.destination = OutputOption::DEST_CONSOLE;
+        option.stream = OutputOption::STR_STDERR;
+
+    } else if (dest == SYSLOG) {
+        option.destination = OutputOption::DEST_SYSLOG;
+        // Use default specified in OutputOption constructor for the
+        // syslog destination
+
+    } else if (dest.find(SYSLOG_COLON) == 0) {
+        option.destination = OutputOption::DEST_SYSLOG;
+        // Must take account of the string actually being "syslog:"
+        if (dest == SYSLOG_COLON) {
+            cerr << "**ERROR** value for B10_LOGGER_DESTINATION of " <<
+                    SYSLOG_COLON << " is invalid, " << SYSLOG <<
+                    " will be used instead\n";
+            // Use default for logging facility
+
+        } else {
+            // Everything else in the string is the facility name
+            option.facility = dest.substr(SYSLOG_COLON.size());
+        }
+
+    } else {
+        // Not a recognised destination, assume a file.
+        option.destination = OutputOption::DEST_FILE;
+        option.filename = dest;
+    }
+
+    // ... and set the destination
+    spec.addOutputOption(option);
+    LoggerManager manager;
+    manager.process(spec);
+}
+
+
+// Logger Run-Time Initialization via Environment Variables
+void initLogger(isc::log::Severity severity, int dbglevel) {
+
+    // Root logger name is defined by the environment variable B10_LOGGER_ROOT.
+    // If not present, the name is "bind10".
+    const char* DEFAULT_ROOT = "bind10";
+    const char* root = getenv("B10_LOGGER_ROOT");
+    if (! root) {
+        root = DEFAULT_ROOT;
+    }
+
+    // Set the local message file
+    const char* localfile = getenv("B10_LOGGER_LOCALMSG");
+
+    // Initialize logging
+    initLogger(root, isc::log::DEBUG, isc::log::MAX_DEBUG_LEVEL, localfile);
+
+    // Now set reset the output destination of the root logger, overriding
+    // the default severity, debug level and destination with those specified
+    // in the environment variables.  (The two-step approach is used as the
+    // setUnitTestRootLoggerCharacteristics() function is used in several
+    // places in the BIND 10 tests, and it avoid duplicating code.)
+    resetUnitTestRootLogger();
+} 
+
+} // namespace log
+} // namespace isc

+ 126 - 0
src/lib/log/logger_unittest_support.h

@@ -0,0 +1,126 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#ifndef __LOGGER_UNITTEST_SUPPORT_H
+#define __LOGGER_UNITTEST_SUPPORT_H
+
+#include <string>
+#include <log/logger.h>
+
+/// \file
+/// \brief Miscellaneous logging functions used by the unit tests.
+///
+/// As the configuration database is unsually unavailable during unit tests,
+/// the functions defined here allow a limited amount of logging configuration
+/// through the use of environment variables
+
+namespace isc {
+namespace log {
+
+/// \brief Run-Time Initialization for Unit Tests from Environment
+///
+/// Performs run-time initialization of the logger via the setting of
+/// environment variables.  These are:
+///
+/// - B10_LOGGER_ROOT\n
+/// Name of the root logger.  If not given, the string "bind10" will be used.
+///
+/// - B10_LOGGER_SEVERITY\n
+/// Severity of messages that will be logged.  This must be one of the strings
+/// "DEBUG", "INFO", "WARN", "ERROR", "FATAL" or "NONE". (Must be upper case
+/// and must not contain leading or trailing spaces.)  If not specified (or if
+/// specified but incorrect), the default passed as argument to this function
+/// (currently DEBUG) will be used.
+///
+/// - B10_LOGGER_DBGLEVEL\n
+/// Ignored if the level is not DEBUG, this should be a number between 0 and
+/// 99 indicating the logging severity.  The default is 0.  If outside these
+/// limits or if not a number, The value passed to this function (default
+/// of MAX_DEBUG_LEVEL) is used.
+///
+/// - B10_LOGGER_LOCALMSG\n
+/// If defined, the path specification of a file that contains message
+/// definitions replacing ones in the default dictionary.
+///
+/// - B10_LOGGER_DESTINATION\n
+/// If defined, the destination of the logging output.  This can be one of:
+///   - \c stdout Send output to stdout.
+///   - \c stderr Send output to stderr
+///   - \c syslog Send output to syslog using the facility local0.
+///   - \c syslog:xxx  Send output to syslog, using the facility xxx. ("xxx"
+///     should be one of the syslog facilities such as "local0".)  There must
+///     be a colon between "syslog" and "xxx
+///   - \c other Anything else is interpreted as the name of a file to which
+///     output is appended.  If the file does not exist, it is created.
+///
+/// Any errors in the settings cause messages to be output to stderr.
+///
+/// This function is aimed at test programs, allowing the default settings to
+/// be overridden by the tester.  It is not intended for use in production
+/// code.
+///
+/// TODO: Rename. This function overloads the initLogger() function that can
+///       be used to initialize production programs.  This may lead to confusion.
+void initLogger(isc::log::Severity severity = isc::log::DEBUG,
+                int dbglevel = isc::log::MAX_DEBUG_LEVEL);
+
+
+/// \brief Obtains logging severity from B10_LOGGER_SEVERITY
+///
+/// Support function called by the unit test logging initialization code.
+/// It returns the logging severity defined by B10_LOGGER_SEVERITY.  If
+/// not defined it returns the default passed to it.
+///
+/// \param defseverity Default severity used if B10_LOGGER_SEVERITY is not
+//         defined.
+///
+/// \return Severity to use for the logging.
+isc::log::Severity b10LoggerSeverity(isc::log::Severity defseverity);
+
+
+/// \brief Obtains logging debug level from B10_LOGGER_DBGLEVEL
+///
+/// Support function called by the unit test logging initialization code.
+/// It returns the logging debug level defined by B10_LOGGER_DBGLEVEL.  If
+/// not defined, it returns the default passed to it.
+///
+/// N.B. If there is an error, a message is written to stderr and a value
+/// related to the error is used. (This is because (a) logging is not yet
+/// initialized, hence only the error stream is known to exist, and (b) this
+/// function is only used in unit test logging initialization, so incorrect
+/// selection of a level is not really an issue.)
+///
+/// \param defdbglevel Default debug level to be used if B10_LOGGER_DBGLEVEL
+///        is not defined.
+///
+/// \return Debug level to use.
+int b10LoggerDbglevel(int defdbglevel);
+
+
+/// \brief Reset root logger characteristics
+///
+/// This is a simplified interface into the resetting of the characteristics
+/// of the root logger.  It is aimed for use in unit tests and resets the
+/// characteristics of the root logger to use a severity, debug level and
+/// destination set by the environment variables B10_LOGGER_SEVERITY,
+/// B10_LOGGER_DBGLEVEL and B10_LOGGER_DESTINATION.
+void
+resetUnitTestRootLogger();
+
+} // namespace log
+} // namespace isc
+
+
+
+#endif // __LOGGER_UNITTEST_SUPPORT_H

+ 3 - 3
src/lib/log/tests/init_logger_test.sh.in

@@ -44,7 +44,7 @@ WARN  [bind10.log] LOG_BAD_STREAM bad log console output stream: warn
 ERROR [bind10.log] LOG_DUPLICATE_MESSAGE_ID duplicate message ID (error) in compiled code
 FATAL [bind10.log] LOG_NO_MESSAGE_ID line fatal: message definition line found without a message ID
 .
-B10_LOGGER_SEVERITY=DEBUG B10_LOGGER_DBGLEVEL=99 ./init_logger_test 2>&1 | \
+B10_LOGGER_DESTINATION=stdout B10_LOGGER_SEVERITY=DEBUG B10_LOGGER_DBGLEVEL=99 ./init_logger_test | \
     cut -d' ' -f3- | diff $tempfile -
 passfail $?
 
@@ -57,7 +57,7 @@ WARN  [bind10.log] LOG_BAD_STREAM bad log console output stream: warn
 ERROR [bind10.log] LOG_DUPLICATE_MESSAGE_ID duplicate message ID (error) in compiled code
 FATAL [bind10.log] LOG_NO_MESSAGE_ID line fatal: message definition line found without a message ID
 .
-B10_LOGGER_SEVERITY=DEBUG B10_LOGGER_DBGLEVEL=50 ./init_logger_test 2>&1 | \
+B10_LOGGER_DESTINATION=stdout B10_LOGGER_SEVERITY=DEBUG B10_LOGGER_DBGLEVEL=50 ./init_logger_test | \
     cut -d' ' -f3- | diff $tempfile -
 passfail $?
 
@@ -67,7 +67,7 @@ WARN  [bind10.log] LOG_BAD_STREAM bad log console output stream: warn
 ERROR [bind10.log] LOG_DUPLICATE_MESSAGE_ID duplicate message ID (error) in compiled code
 FATAL [bind10.log] LOG_NO_MESSAGE_ID line fatal: message definition line found without a message ID
 .
-B10_LOGGER_SEVERITY=WARN ./init_logger_test 2>&1 | \
+B10_LOGGER_DESTINATION=stdout B10_LOGGER_SEVERITY=WARN ./init_logger_test | \
     cut -d' ' -f3- | diff $tempfile -
 passfail $?
 

+ 5 - 2
src/lib/log/tests/logger_level_impl_unittest.cc

@@ -20,6 +20,7 @@
 #include <boost/lexical_cast.hpp>
 
 #include <log/logger_level_impl.h>
+#include <log/logger_support.h>
 #include <log4cplus/logger.h>
 
 using namespace isc::log;
@@ -27,8 +28,10 @@ using namespace std;
 
 class LoggerLevelImplTest : public ::testing::Test {
 protected:
-    LoggerLevelImplTest()
-    {}
+    LoggerLevelImplTest() {
+        // Ensure logging set to default for unit tests
+        resetUnitTestRootLogger();
+    }
 
     ~LoggerLevelImplTest()
     {}

+ 5 - 3
src/lib/log/tests/logger_level_unittest.cc

@@ -20,7 +20,7 @@
 #include <log/logger.h>
 #include <log/logger_manager.h>
 #include <log/log_messages.h>
-#include <log/logger_name.h>
+#include <log/logger_support.h>
 
 using namespace isc;
 using namespace isc::log;
@@ -29,7 +29,9 @@ using namespace std;
 class LoggerLevelTest : public ::testing::Test {
 protected:
     LoggerLevelTest() {
-        // Logger initialization is done in main()
+        // Logger initialization is done in main().  As logging tests may
+        // alter the default logging output, it is reset here.
+        resetUnitTestRootLogger();
     }
     ~LoggerLevelTest() {
         LoggerManager::reset();
@@ -57,7 +59,7 @@ TEST_F(LoggerLevelTest, Creation) {
     EXPECT_EQ(42, level3.dbglevel);
 }
 
-TEST(LoggerLevel, getSeverity) {
+TEST_F(LoggerLevelTest, getSeverity) {
     EXPECT_EQ(DEBUG, getSeverity("DEBUG"));
     EXPECT_EQ(DEBUG, getSeverity("debug"));
     EXPECT_EQ(DEBUG, getSeverity("DeBuG"));

+ 13 - 2
src/lib/log/tests/logger_support_unittest.cc

@@ -18,12 +18,23 @@
 
 using namespace isc::log;
 
+class LoggerSupportTest : public ::testing::Test {
+protected:
+    LoggerSupportTest() {
+        // Logger initialization is done in main().  As logging tests may
+        // alter the default logging output, it is reset here.
+        resetUnitTestRootLogger();
+    }
+    ~LoggerSupportTest() {
+    }
+};
+
 // Check that the initialized flag can be manipulated.  This is a bit chicken-
 // -and-egg: we want to reset to the flag to the original value at the end
 // of the test, so use the functions to do that.  But we are trying to check
 // that these functions in fact work.
 
-TEST(LoggerSupportTest, InitializedFlag) {
+TEST_F(LoggerSupportTest, InitializedFlag) {
     bool current_flag = isLoggingInitialized();
 
     // check we can flip the flag.
@@ -51,7 +62,7 @@ TEST(LoggerSupportTest, InitializedFlag) {
 // Check that a logger will throw an exception if logging has not been
 // initialized.
 
-TEST(LoggerSupportTest, LoggingInitializationCheck) {
+TEST_F(LoggerSupportTest, LoggingInitializationCheck) {
 
     // Assert that logging has been initialized (it should be in main()).
     bool current_flag = isLoggingInitialized();

+ 12 - 12
src/lib/python/isc/acl/Makefile.am

@@ -4,10 +4,10 @@ AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib
 AM_CPPFLAGS += $(BOOST_INCLUDES)
 AM_CXXFLAGS = $(B10_CXXFLAGS)
 
-python_PYTHON = __init__.py
+python_PYTHON = __init__.py dns.py
 pythondir = $(PYTHON_SITEPKG_DIR)/isc/acl
 
-pyexec_LTLIBRARIES = acl.la dns.la
+pyexec_LTLIBRARIES = acl.la _dns.la
 pyexecdir = $(PYTHON_SITEPKG_DIR)/isc/acl
 
 acl_la_SOURCES = acl.cc
@@ -15,14 +15,14 @@ acl_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES)
 acl_la_LDFLAGS = $(PYTHON_LDFLAGS)
 acl_la_CXXFLAGS = $(AM_CXXFLAGS) $(PYTHON_CXXFLAGS)
 
-dns_la_SOURCES = dns.h dns.cc dns_requestacl_python.h dns_requestacl_python.cc
-dns_la_SOURCES += dns_requestcontext_python.h dns_requestcontext_python.cc
-dns_la_SOURCES += dns_requestloader_python.h dns_requestloader_python.cc
-dns_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES)
-dns_la_LDFLAGS = $(PYTHON_LDFLAGS)
+_dns_la_SOURCES = dns.h dns.cc dns_requestacl_python.h dns_requestacl_python.cc
+_dns_la_SOURCES += dns_requestcontext_python.h dns_requestcontext_python.cc
+_dns_la_SOURCES += dns_requestloader_python.h dns_requestloader_python.cc
+_dns_la_CPPFLAGS = $(AM_CPPFLAGS) $(PYTHON_INCLUDES)
+_dns_la_LDFLAGS = $(PYTHON_LDFLAGS)
 # Note: PYTHON_CXXFLAGS may have some -Wno... workaround, which must be
 # placed after -Wextra defined in AM_CXXFLAGS
-dns_la_CXXFLAGS = $(AM_CXXFLAGS) $(PYTHON_CXXFLAGS)
+_dns_la_CXXFLAGS = $(AM_CXXFLAGS) $(PYTHON_CXXFLAGS)
 
 # Python prefers .so, while some OSes (specifically MacOS) use a different
 # suffix for dynamic objects.  -module is necessary to work this around.
@@ -30,11 +30,11 @@ acl_la_LDFLAGS += -module
 acl_la_LIBADD = $(top_builddir)/src/lib/acl/libacl.la
 acl_la_LIBADD += $(PYTHON_LIB)
 
-dns_la_LDFLAGS += -module
-dns_la_LIBADD = $(top_builddir)/src/lib/acl/libdnsacl.la
-dns_la_LIBADD += $(PYTHON_LIB)
+_dns_la_LDFLAGS += -module
+_dns_la_LIBADD = $(top_builddir)/src/lib/acl/libdnsacl.la
+_dns_la_LIBADD += $(PYTHON_LIB)
 
-EXTRA_DIST = acl.py dns.py
+EXTRA_DIST = acl.py _dns.py
 EXTRA_DIST += acl_inc.cc
 EXTRA_DIST += dnsacl_inc.cc dns_requestacl_inc.cc dns_requestcontext_inc.cc
 EXTRA_DIST += dns_requestloader_inc.cc

+ 29 - 0
src/lib/python/isc/acl/_dns.py

@@ -0,0 +1,29 @@
+# Copyright (C) 2011  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+# This file is not installed; The .so version will be installed into the right
+# place at installation time.
+# This helper script is only to find it in the .libs directory when we run
+# as a test or from the build directory.
+
+import os
+import sys
+
+for base in sys.path[:]:
+    bindingdir = os.path.join(base, 'isc/acl/.libs')
+    if os.path.exists(bindingdir):
+        sys.path.insert(0, bindingdir)
+
+from _dns import *

+ 2 - 2
src/lib/python/isc/acl/dns.cc

@@ -52,7 +52,7 @@ PyMethodDef methods[] = {
 
 PyModuleDef dnsacl = {
     { PyObject_HEAD_INIT(NULL) NULL, 0, NULL},
-    "isc.acl.dns",
+    "isc.acl._dns",
     dnsacl_doc,
     -1,
     methods,
@@ -90,7 +90,7 @@ getACLException(const char* ex_name) {
 }
 
 PyMODINIT_FUNC
-PyInit_dns(void) {
+PyInit__dns(void) {
     PyObject* mod = PyModule_Create(&dnsacl);
     if (mod == NULL) {
         return (NULL);

+ 58 - 18
src/lib/python/isc/acl/dns.py

@@ -13,21 +13,61 @@
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-# This file is not installed. The log.so is installed into the right place.
-# It is only to find it in the .libs directory when we run as a test or
-# from the build directory.
-# But as nobody gives us the builddir explicitly (and we can't use generation
-# from .in file, as it would put us into the builddir and we wouldn't be found)
-# we guess from current directory. Any idea for something better? This should
-# be enough for the tests, but would it work for B10_FROM_SOURCE as well?
-# Should we look there? Or define something in bind10_config?
-
-import os
-import sys
-
-for base in sys.path[:]:
-    bindingdir = os.path.join(base, 'isc/acl/.libs')
-    if os.path.exists(bindingdir):
-        sys.path.insert(0, bindingdir)
-
-from dns import *
+"""\
+This module provides Python bindings for the C++ classes in the
+isc::acl::dns namespace.  Specifically, it defines Python interfaces of
+handling access control lists (ACLs) with DNS related contexts.
+The actual binding is implemented in an effectively hidden module,
+isc.acl._dns; this frontend module is in terms of implementation so that
+the C++ binding code doesn't have to deal with complicated operations
+that could be done in a more straightforward way in native Python.
+
+For further details of the actual module, see the documentation of the
+_dns module.
+"""
+
+import pydnspp
+
+import isc.acl._dns
+from isc.acl._dns import *
+
+class RequestACL(isc.acl._dns.RequestACL):
+    """A straightforward wrapper subclass of isc.acl._dns.RequestACL.
+
+    See the base class documentation for more implementation.
+    """
+    pass
+
+class RequestLoader(isc.acl._dns.RequestLoader):
+    """A straightforward wrapper subclass of isc.acl._dns.RequestLoader.
+
+    See the base class documentation for more implementation.
+    """
+    pass
+
+class RequestContext(isc.acl._dns.RequestContext):
+    """A straightforward wrapper subclass of isc.acl._dns.RequestContext.
+
+    See the base class documentation for more implementation.
+    """
+
+    def __init__(self, remote_address, tsig=None):
+        """Wrapper for the RequestContext constructor.
+
+        Internal implementation details that the users don't have to
+        worry about: To avoid dealing with pydnspp bindings in the C++ code,
+        this wrapper converts the TSIG record in its wire format in the form
+        of byte data, and has the binding re-construct the record from it.
+        """
+        tsig_wire = b''
+        if tsig is not None:
+            if not isinstance(tsig, pydnspp.TSIGRecord):
+                raise TypeError("tsig must be a TSIGRecord, not %s" %
+                                tsig.__class__.__name__)
+            tsig_wire = tsig.to_wire(tsig_wire)
+        isc.acl._dns.RequestContext.__init__(self, remote_address, tsig_wire)
+
+    def __str__(self):
+        """Wrap __str__() to convert the module name."""
+        s = isc.acl._dns.RequestContext.__str__(self)
+        return s.replace('<isc.acl._dns', '<isc.acl.dns')

+ 2 - 2
src/lib/python/isc/acl/dns_requestacl_python.cc

@@ -114,7 +114,7 @@ namespace python {
 // Most of the functions are not actually implemented and NULL here.
 PyTypeObject requestacl_type = {
     PyVarObject_HEAD_INIT(NULL, 0)
-    "isc.acl.dns.RequestACL",
+    "isc.acl._dns.RequestACL",
     sizeof(s_RequestACL),                 // tp_basicsize
     0,                                  // tp_itemsize
     RequestACL_destroy,                // tp_dealloc
@@ -132,7 +132,7 @@ PyTypeObject requestacl_type = {
     NULL,                               // tp_getattro
     NULL,                               // tp_setattro
     NULL,                               // tp_as_buffer
-    Py_TPFLAGS_DEFAULT,                 // tp_flags
+    Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE, // tp_flags
     RequestACL_doc,
     NULL,                               // tp_traverse
     NULL,                               // tp_clear

+ 11 - 8
src/lib/python/isc/acl/dns_requestcontext_inc.cc

@@ -5,18 +5,18 @@ DNS request to be checked.\n\
 This plays the role of ACL context for the RequestACL object.\n\
 \n\
 Based on the minimalist philosophy, the initial implementation only\n\
-maintains the remote (source) IP address of the request. The plan is\n\
-to add more parameters of the request. A scheduled next step is to\n\
-support the TSIG key (if it's included in the request). Other\n\
-possibilities are the local (destination) IP address, the remote and\n\
-local port numbers, various fields of the DNS request (e.g. a\n\
-particular header flag value).\n\
+maintains the remote (source) IP address of the request and\n\
+(optionally) the TSIG record included in the request. We may add more\n\
+parameters of the request as we see the need for them. Possible\n\
+additional parameters are the local (destination) IP address, the\n\
+remote and local port numbers, various fields of the DNS request (e.g.\n\
+a particular header flag value).\n\
 \n\
-RequestContext(remote_address)\n\
+RequestContext(remote_address, tsig)\n\
 \n\
     In this initial implementation, the constructor only takes a\n\
     remote IP address in the form of a socket address as used in the\n\
-    Python socket module.\n\
+    Python socket module, and optionally a pydnspp.TSIGRecord object.\n\
 \n\
     Exceptions:\n\
       isc.acl.ACLError Normally shouldn't happen, but still possible\n\
@@ -25,6 +25,9 @@ RequestContext(remote_address)\n\
 \n\
     Parameters:\n\
       remote_address The remote IP address\n\
+      tsig   The TSIG record included in the request message, if any.\n\
+             If the request doesn't include a TSIG, this will be None.\n\
+             If this parameter is omitted None will be assumed.\n\
 \n\
 ";
 } // unnamed namespace

+ 96 - 33
src/lib/python/isc/acl/dns_requestcontext_python.cc

@@ -14,7 +14,7 @@
 
 // Enable this if you use s# variants with PyArg_ParseTuple(), see
 // http://docs.python.org/py3k/c-api/arg.html#strings-and-buffers
-//#define PY_SSIZE_T_CLEAN
+#define PY_SSIZE_T_CLEAN
 
 // Python.h needs to be placed at the head of the program file, see:
 // http://docs.python.org/py3k/extending/extending.html#a-simple-example
@@ -37,8 +37,16 @@
 
 #include <exceptions/exceptions.h>
 
+#include <util/buffer.h>
 #include <util/python/pycppwrapper_util.h>
 
+#include <dns/name.h>
+#include <dns/rrclass.h>
+#include <dns/rrtype.h>
+#include <dns/rrttl.h>
+#include <dns/rdata.h>
+#include <dns/tsigrecord.h>
+
 #include <acl/dns.h>
 #include <acl/ip_check.h>
 
@@ -49,6 +57,8 @@ using namespace std;
 using boost::scoped_ptr;
 using boost::lexical_cast;
 using namespace isc;
+using namespace isc::dns;
+using namespace isc::dns::rdata;
 using namespace isc::util::python;
 using namespace isc::acl::dns;
 using namespace isc::acl::dns::python;
@@ -59,11 +69,39 @@ namespace dns {
 namespace python {
 
 struct s_RequestContext::Data {
-    // The constructor.  Currently it only accepts the information of the
-    // request source address, and contains all necessary logic in the body
-    // of the constructor.  As it's extended we may have refactor it by
-    // introducing helper methods.
-    Data(const char* const remote_addr, const unsigned short remote_port) {
+    // The constructor.
+    Data(const char* const remote_addr, const unsigned short remote_port,
+         const char* tsig_data, const Py_ssize_t tsig_len)
+    {
+        createRemoteAddr(remote_addr, remote_port);
+        createTSIGRecord(tsig_data, tsig_len);
+    }
+
+    // A convenient type converter from sockaddr_storage to sockaddr
+    const struct sockaddr& getRemoteSockaddr() const {
+        const void* p = &remote_ss;
+        return (*static_cast<const struct sockaddr*>(p));
+    }
+
+    // The remote (source) IP address of the request.  Note that it needs
+    // a reference to remote_ss.  That's why the latter is stored within
+    // this structure.
+    scoped_ptr<IPAddress> remote_ipaddr;
+
+    // The effective length of remote_ss.  It's necessary for getnameinfo()
+    // called from sockaddrToText (__str__ backend).
+    socklen_t remote_salen;
+
+    // The TSIG record included in the request, if any.  If the request
+    // doesn't contain a TSIG, this will be NULL.
+    scoped_ptr<TSIGRecord> tsig_record;
+
+private:
+    // A helper method for the constructor that is responsible for constructing
+    // the remote address.
+    void createRemoteAddr(const char* const remote_addr,
+                          const unsigned short remote_port)
+    {
         struct addrinfo hints, *res;
         memset(&hints, 0, sizeof(hints));
         hints.ai_family = AF_UNSPEC;
@@ -85,20 +123,31 @@ struct s_RequestContext::Data {
         remote_ipaddr.reset(new IPAddress(getRemoteSockaddr()));
     }
 
-    // A convenient type converter from sockaddr_storage to sockaddr
-    const struct sockaddr& getRemoteSockaddr() const {
-        const void* p = &remote_ss;
-        return (*static_cast<const struct sockaddr*>(p));
-    }
-
-    // The remote (source) IP address the request.  Note that it needs
-    // a reference to remote_ss.  That's why the latter is stored within
-    // this structure.
-    scoped_ptr<IPAddress> remote_ipaddr;
+    // A helper method for the constructor that is responsible for constructing
+    // the request TSIG.
+    void createTSIGRecord(const char* tsig_data, const Py_ssize_t tsig_len) {
+        if (tsig_len == 0) {
+            return;
+        }
 
-    // The effective length of remote_ss.  It's necessary for getnameinf()
-    // called from sockaddrToText (__str__ backend).
-    socklen_t remote_salen;
+        // Re-construct the TSIG record from the passed binary.  This should
+        // normally succeed because we are generally expected to be called
+        // from the frontend .py, which converts a valid TSIGRecord in its
+        // wire format.  If some evil or buggy python program directly calls
+        // us with bogus data, validation in libdns++ will trigger an
+        // exception, which will be caught and converted to a Python exception
+        // in RequestContext_init().
+        isc::util::InputBuffer b(tsig_data, tsig_len);
+        const Name key_name(b);
+        const RRType tsig_type(b.readUint16());
+        const RRClass tsig_class(b.readUint16());
+        const RRTTL ttl(b.readUint32());
+        const size_t rdlen(b.readUint16());
+        const ConstRdataPtr rdata = createRdata(tsig_type, tsig_class, b,
+                                                rdlen);
+        tsig_record.reset(new TSIGRecord(key_name, tsig_class, ttl,
+                                         *rdata, 0));
+    }
 
 private:
     struct sockaddr_storage remote_ss;
@@ -145,31 +194,41 @@ RequestContext_init(PyObject* po_self, PyObject* args, PyObject*) {
     s_RequestContext* const self = static_cast<s_RequestContext*>(po_self);
 
     try {
-        // In this initial implementation, the constructor is simply: It
-        // takes a single parameter, which should be a Python socket address
-        // object.  For IPv4, it's ('address test', numeric_port); for IPv6,
+        // In this initial implementation, the constructor is simple: It
+        // takes two parameters.  The first parameter should be a Python
+        // socket address object.
+        // For IPv4, it's ('address test', numeric_port); for IPv6,
         // it's ('address text', num_port, num_flowid, num_zoneid).
+        // The second parameter is wire-format TSIG record in the form of
+        // Python byte data.  If the TSIG isn't included in the request,
+        // its length will be 0.
         // Below, we parse the argument in the most straightforward way.
         // As the constructor becomes more complicated, we should probably
         // make it more structural (for example, we should first retrieve
-        // the socket address as a PyObject, and parse it recursively)
+        // the python objects, and parse them recursively)
 
         const char* remote_addr;
         unsigned short remote_port;
         unsigned int remote_flowinfo; // IPv6 only, unused here
         unsigned int remote_zoneid; // IPv6 only, unused here
-
-        if (PyArg_ParseTuple(args, "(sH)", &remote_addr, &remote_port) ||
-            PyArg_ParseTuple(args, "(sHII)", &remote_addr, &remote_port,
-                             &remote_flowinfo, &remote_zoneid))
+        const char* tsig_data;
+        Py_ssize_t tsig_len;
+
+        if (PyArg_ParseTuple(args, "(sH)y#", &remote_addr, &remote_port,
+                             &tsig_data, &tsig_len) ||
+            PyArg_ParseTuple(args, "(sHII)y#", &remote_addr, &remote_port,
+                             &remote_flowinfo, &remote_zoneid,
+                             &tsig_data, &tsig_len))
         {
-            // We need to clear the error in case the first call to PareTuple
+            // We need to clear the error in case the first call to ParseTuple
             // fails.
             PyErr_Clear();
 
             auto_ptr<s_RequestContext::Data> dataptr(
-                new s_RequestContext::Data(remote_addr, remote_port));
-            self->cppobj = new RequestContext(*dataptr->remote_ipaddr);
+                new s_RequestContext::Data(remote_addr, remote_port,
+                                           tsig_data, tsig_len));
+            self->cppobj = new RequestContext(*dataptr->remote_ipaddr,
+                                              dataptr->tsig_record.get());
             self->data_ = dataptr.release();
             return (0);
         }
@@ -224,7 +283,11 @@ RequestContext_str(PyObject* po_self) {
         objss << "<" << requestcontext_type.tp_name << " object, "
               << "remote_addr="
               << sockaddrToText(self->data_->getRemoteSockaddr(),
-                                self->data_->remote_salen) << ">";
+                                self->data_->remote_salen);
+        if (self->data_->tsig_record) {
+            objss << ", key=" << self->data_->tsig_record->getName();
+        }
+        objss << ">";
         return (Py_BuildValue("s", objss.str().c_str()));
     } catch (const exception& ex) {
         const string ex_what =
@@ -248,7 +311,7 @@ namespace python {
 // Most of the functions are not actually implemented and NULL here.
 PyTypeObject requestcontext_type = {
     PyVarObject_HEAD_INIT(NULL, 0)
-    "isc.acl.dns.RequestContext",
+    "isc.acl._dns.RequestContext",
     sizeof(s_RequestContext),                 // tp_basicsize
     0,                                  // tp_itemsize
     RequestContext_destroy,             // tp_dealloc
@@ -266,7 +329,7 @@ PyTypeObject requestcontext_type = {
     NULL,                               // tp_getattro
     NULL,                               // tp_setattro
     NULL,                               // tp_as_buffer
-    Py_TPFLAGS_DEFAULT,                 // tp_flags
+    Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE, // tp_flags
     RequestContext_doc,
     NULL,                               // tp_traverse
     NULL,                               // tp_clear

+ 0 - 0
src/lib/python/isc/acl/dns_requestloader_python.cc


Some files were not shown because too many files changed in this diff