Browse Source

[trac1104] extended RequestContext to support TSIG keys. implemented
NameCheck<RequestContext>::matches using that support.

JINMEI Tatuya 13 years ago
parent
commit
fafb108c23

+ 19 - 3
src/lib/acl/dns.cc

@@ -20,6 +20,9 @@
 
 
 #include <exceptions/exceptions.h>
 #include <exceptions/exceptions.h>
 
 
+#include <dns/name.h>
+#include <dns/tsigrecord.h>
+
 #include <cc/data.h>
 #include <cc/data.h>
 
 
 #include <acl/dns.h>
 #include <acl/dns.h>
@@ -29,6 +32,7 @@
 
 
 using namespace std;
 using namespace std;
 using boost::shared_ptr;
 using boost::shared_ptr;
+using namespace isc::dns;
 using namespace isc::data;
 using namespace isc::data;
 
 
 namespace isc {
 namespace isc {
@@ -39,9 +43,6 @@ namespace acl {
 /// It returns \c true if the remote (source) IP address of the request
 /// It returns \c true if the remote (source) IP address of the request
 /// matches the expression encapsulated in the \c IPCheck, and returns
 /// matches the expression encapsulated in the \c IPCheck, and returns
 /// \c false if not.
 /// \c false if not.
-///
-/// \note The match logic is expected to be extended as we add
-/// more match parameters (at least there's a plan for TSIG key).
 template <>
 template <>
 bool
 bool
 IPCheck<dns::RequestContext>::matches(
 IPCheck<dns::RequestContext>::matches(
@@ -53,6 +54,16 @@ IPCheck<dns::RequestContext>::matches(
 
 
 namespace dns {
 namespace dns {
 
 
+/// The specialization of \c NameCheck for access control with
+/// \c RequestContext.
+///
+/// TBD
+template<>
+bool
+NameCheck<RequestContext>::matches(const RequestContext& request) const {
+    return (request.tsig != NULL && request.tsig->getName() == name_);
+}
+
 vector<string>
 vector<string>
 internal::RequestCheckCreator::names() const {
 internal::RequestCheckCreator::names() const {
     // Probably we should eventually build this vector in a more
     // Probably we should eventually build this vector in a more
@@ -60,6 +71,7 @@ internal::RequestCheckCreator::names() const {
     // everything.
     // everything.
     vector<string> supported_names;
     vector<string> supported_names;
     supported_names.push_back("from");
     supported_names.push_back("from");
+    supported_names.push_back("key");
     return (supported_names);
     return (supported_names);
 }
 }
 
 
@@ -77,6 +89,10 @@ internal::RequestCheckCreator::create(const string& name,
     if (name == "from") {
     if (name == "from") {
         return (shared_ptr<internal::RequestIPCheck>(
         return (shared_ptr<internal::RequestIPCheck>(
                     new internal::RequestIPCheck(definition->stringValue())));
                     new internal::RequestIPCheck(definition->stringValue())));
+    } else if (name == "key") {
+        return (shared_ptr<internal::RequestKeyCheck>(
+                    new internal::RequestKeyCheck(
+                        Name(definition->stringValue()))));
     } else {
     } else {
         // This case shouldn't happen (normally) as it should have been
         // This case shouldn't happen (normally) as it should have been
         // rejected at the loader level.  But we explicitly catch the case
         // rejected at the loader level.  But we explicitly catch the case

+ 12 - 2
src/lib/acl/dns.h

@@ -23,9 +23,13 @@
 #include <cc/data.h>
 #include <cc/data.h>
 
 
 #include <acl/ip_check.h>
 #include <acl/ip_check.h>
+#include <acl/dnsname_check.h>
 #include <acl/loader.h>
 #include <acl/loader.h>
 
 
 namespace isc {
 namespace isc {
+namespace dns {
+class TSIGRecord;
+}
 namespace acl {
 namespace acl {
 namespace dns {
 namespace dns {
 
 
@@ -68,8 +72,10 @@ struct RequestContext {
     /// \exception None
     /// \exception None
     ///
     ///
     /// \parameter remote_address_param The remote IP address
     /// \parameter remote_address_param The remote IP address
-    explicit RequestContext(const IPAddress& remote_address_param) :
-        remote_address(remote_address_param)
+    explicit RequestContext(const IPAddress& remote_address_param,
+                            const isc::dns::TSIGRecord* tsig_param) :
+        remote_address(remote_address_param),
+        tsig(tsig_param)
     {}
     {}
 
 
     ///
     ///
@@ -83,6 +89,9 @@ struct RequestContext {
     //@{
     //@{
     /// \brief The remote IP address (eg. the client's IP address).
     /// \brief The remote IP address (eg. the client's IP address).
     const IPAddress& remote_address;
     const IPAddress& remote_address;
+
+    /// TBD
+    const isc::dns::TSIGRecord* const tsig;
     //@}
     //@}
 };
 };
 
 
@@ -114,6 +123,7 @@ namespace internal {
 
 
 // Shortcut typedef
 // Shortcut typedef
 typedef isc::acl::IPCheck<RequestContext> RequestIPCheck;
 typedef isc::acl::IPCheck<RequestContext> RequestIPCheck;
+typedef isc::acl::dns::NameCheck<RequestContext> RequestKeyCheck;
 
 
 class RequestCheckCreator : public acl::Loader<RequestContext>::CheckCreator {
 class RequestCheckCreator : public acl::Loader<RequestContext>::CheckCreator {
 public:
 public:

+ 2 - 0
src/lib/acl/dnsname_check.h

@@ -41,6 +41,8 @@ public:
     /// \param context Information to be matched
     /// \param context Information to be matched
     virtual bool matches(const Context& context) const;
     virtual bool matches(const Context& context) const;
 
 
+    const isc::dns::Name& getName() const { return (name_); }
+
 private:
 private:
     const isc::dns::Name name_;
     const isc::dns::Name name_;
 };
 };

+ 1 - 0
src/lib/acl/tests/Makefile.am

@@ -16,6 +16,7 @@ run_unittests_SOURCES += acl_test.cc
 run_unittests_SOURCES += check_test.cc
 run_unittests_SOURCES += check_test.cc
 run_unittests_SOURCES += dns_test.cc
 run_unittests_SOURCES += dns_test.cc
 run_unittests_SOURCES += ip_check_unittest.cc
 run_unittests_SOURCES += ip_check_unittest.cc
+run_unittests_SOURCES += dnsname_check_unittest.cc
 run_unittests_SOURCES += loader_test.cc
 run_unittests_SOURCES += loader_test.cc
 run_unittests_SOURCES += logcheck.h
 run_unittests_SOURCES += logcheck.h
 run_unittests_SOURCES += creators.h
 run_unittests_SOURCES += creators.h

+ 76 - 10
src/lib/acl/tests/dns_test.cc

@@ -23,6 +23,11 @@
 
 
 #include <exceptions/exceptions.h>
 #include <exceptions/exceptions.h>
 
 
+#include <dns/name.h>
+#include <dns/tsigkey.h>
+#include <dns/tsigrecord.h>
+#include <dns/rdataclass.h>
+
 #include <cc/data.h>
 #include <cc/data.h>
 #include <acl/dns.h>
 #include <acl/dns.h>
 #include <acl/loader.h>
 #include <acl/loader.h>
@@ -35,6 +40,8 @@
 
 
 using namespace std;
 using namespace std;
 using boost::scoped_ptr;
 using boost::scoped_ptr;
+using namespace isc::dns;
+using namespace isc::dns::rdata;
 using namespace isc::data;
 using namespace isc::data;
 using namespace isc::acl;
 using namespace isc::acl;
 using namespace isc::acl::dns;
 using namespace isc::acl::dns;
@@ -64,8 +71,10 @@ protected:
 };
 };
 
 
 TEST_F(RequestCheckCreatorTest, names) {
 TEST_F(RequestCheckCreatorTest, names) {
-    ASSERT_EQ(1, creator_.names().size());
-    EXPECT_EQ("from", creator_.names()[0]);
+    const vector<string> names = creator_.names();
+    EXPECT_EQ(2, names.size());
+    EXPECT_TRUE(find(names.begin(), names.end(), "from") != names.end());
+    EXPECT_TRUE(find(names.begin(), names.end(), "key") != names.end());
 }
 }
 
 
 TEST_F(RequestCheckCreatorTest, allowListAbbreviation) {
 TEST_F(RequestCheckCreatorTest, allowListAbbreviation) {
@@ -93,11 +102,11 @@ TEST_F(RequestCheckCreatorTest, createIPv6Check) {
     check_ = creator_.create("from",
     check_ = creator_.create("from",
                              Element::fromJSON("\"2001:db8::5300/120\""),
                              Element::fromJSON("\"2001:db8::5300/120\""),
                              getRequestLoader());
                              getRequestLoader());
-    const dns::internal::RequestIPCheck& ipcheck_ =
+    const dns::internal::RequestIPCheck& ipcheck =
         dynamic_cast<const dns::internal::RequestIPCheck&>(*check_);
         dynamic_cast<const dns::internal::RequestIPCheck&>(*check_);
-    EXPECT_EQ(AF_INET6, ipcheck_.getFamily());
-    EXPECT_EQ(120, ipcheck_.getPrefixlen());
-    const vector<uint8_t> check_address(ipcheck_.getAddress());
+    EXPECT_EQ(AF_INET6, ipcheck.getFamily());
+    EXPECT_EQ(120, ipcheck.getPrefixlen());
+    const vector<uint8_t> check_address(ipcheck.getAddress());
     ASSERT_EQ(16, check_address.size());
     ASSERT_EQ(16, check_address.size());
     const uint8_t expected_address[] = { 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00,
     const uint8_t expected_address[] = { 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00,
                                          0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                          0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@@ -106,6 +115,14 @@ TEST_F(RequestCheckCreatorTest, createIPv6Check) {
                       expected_address));
                       expected_address));
 }
 }
 
 
+TEST_F(RequestCheckCreatorTest, createTSIGKeyCheck) {
+    check_ = creator_.create("key", Element::fromJSON("\"key.example.com\""),
+                             getRequestLoader());
+    const dns::internal::RequestKeyCheck& keycheck =
+        dynamic_cast<const dns::internal::RequestKeyCheck&>(*check_);
+    EXPECT_EQ(Name("key.example.com"), keycheck.getName());
+}
+
 TEST_F(RequestCheckCreatorTest, badCreate) {
 TEST_F(RequestCheckCreatorTest, badCreate) {
     // Invalid name
     // Invalid name
     EXPECT_THROW(creator_.create("bad", Element::fromJSON("\"192.0.2.1\""),
     EXPECT_THROW(creator_.create("bad", Element::fromJSON("\"192.0.2.1\""),
@@ -118,12 +135,23 @@ TEST_F(RequestCheckCreatorTest, badCreate) {
     EXPECT_THROW(creator_.create("from", Element::fromJSON("[]"),
     EXPECT_THROW(creator_.create("from", Element::fromJSON("[]"),
                                  getRequestLoader()),
                                  getRequestLoader()),
                  isc::data::TypeError);
                  isc::data::TypeError);
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("1"),
+                                 getRequestLoader()),
+                 isc::data::TypeError);
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("{}"),
+                                 getRequestLoader()),
+                 isc::data::TypeError);
 
 
     // Syntax error for IPCheck
     // Syntax error for IPCheck
     EXPECT_THROW(creator_.create("from", Element::fromJSON("\"bad\""),
     EXPECT_THROW(creator_.create("from", Element::fromJSON("\"bad\""),
                                  getRequestLoader()),
                                  getRequestLoader()),
                  isc::InvalidParameter);
                  isc::InvalidParameter);
 
 
+    // Syntax error for Name (key) Check
+    EXPECT_THROW(creator_.create("key", Element::fromJSON("\"bad..name\""),
+                                 getRequestLoader()),
+                 EmptyLabel);
+
     // NULL pointer
     // NULL pointer
     EXPECT_THROW(creator_.create("from", ConstElementPtr(), getRequestLoader()),
     EXPECT_THROW(creator_.create("from", ConstElementPtr(), getRequestLoader()),
                  LoaderError);
                  LoaderError);
@@ -140,23 +168,43 @@ protected:
                                 getRequestLoader()));
                                 getRequestLoader()));
     }
     }
 
 
+    // A helper shortcut to create a single Name (key) check for the given
+    // name.
+    ConstRequestCheckPtr createKeyCheck(const string& key_name) {
+        return (creator_.create("key", Element::fromJSON(
+                                    string("\"") + key_name + string("\"")),
+                                getRequestLoader()));
+    }
+
     // create a one time request context for a specific test.  Note that
     // create a one time request context for a specific test.  Note that
     // getSockaddr() uses a static storage, so it cannot be called more than
     // getSockaddr() uses a static storage, so it cannot be called more than
     // once in a single test.
     // once in a single test.
-    const dns::RequestContext& getRequest4() {
+    const dns::RequestContext& getRequest4(const TSIGRecord* tsig = NULL) {
         ipaddr.reset(new IPAddress(tests::getSockAddr("192.0.2.1")));
         ipaddr.reset(new IPAddress(tests::getSockAddr("192.0.2.1")));
-        request.reset(new dns::RequestContext(*ipaddr));
+        request.reset(new dns::RequestContext(*ipaddr, tsig));
         return (*request);
         return (*request);
     }
     }
-    const dns::RequestContext& getRequest6() {
+    const dns::RequestContext& getRequest6(const TSIGRecord* tsig = NULL) {
         ipaddr.reset(new IPAddress(tests::getSockAddr("2001:db8::1")));
         ipaddr.reset(new IPAddress(tests::getSockAddr("2001:db8::1")));
-        request.reset(new dns::RequestContext(*ipaddr));
+        request.reset(new dns::RequestContext(*ipaddr, tsig));
         return (*request);
         return (*request);
     }
     }
 
 
+    // create a one time TSIG Record for a specific test.  The only parameter
+    // of the record that matters is the key name; others are hardcoded with
+    // arbitrarily chosen values.
+    const TSIGRecord* getTSIGRecord(const string& key_name) {
+        tsig_rdata.reset(new any::TSIG(TSIGKey::HMACMD5_NAME(), 0, 0, 0, NULL,
+                                       0, 0, 0, NULL));
+        tsig.reset(new TSIGRecord(Name(key_name), *tsig_rdata));
+        return (tsig.get());
+    }
+
 private:
 private:
     scoped_ptr<IPAddress> ipaddr;
     scoped_ptr<IPAddress> ipaddr;
     scoped_ptr<dns::RequestContext> request;
     scoped_ptr<dns::RequestContext> request;
+    scoped_ptr<any::TSIG> tsig_rdata;
+    scoped_ptr<TSIGRecord> tsig;
     dns::internal::RequestCheckCreator creator_;
     dns::internal::RequestCheckCreator creator_;
 };
 };
 
 
@@ -184,6 +232,24 @@ TEST_F(RequestCheckTest, checkIPv6) {
     EXPECT_FALSE(createIPCheck("32.1.13.184")->matches(getRequest6()));
     EXPECT_FALSE(createIPCheck("32.1.13.184")->matches(getRequest6()));
 }
 }
 
 
+TEST_F(RequestCheckTest, checkTSIGKey) {
+    EXPECT_TRUE(createKeyCheck("key.example.com")->matches(
+                    getRequest4(getTSIGRecord("key.example.com"))));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(
+                     getRequest4(getTSIGRecord("badkey.example.com"))));
+
+    // Same for IPv6 (which shouldn't matter)
+    EXPECT_TRUE(createKeyCheck("key.example.com")->matches(
+                    getRequest6(getTSIGRecord("key.example.com"))));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(
+                     getRequest6(getTSIGRecord("badkey.example.com"))));
+
+    // by default the test request doesn't have a TSIG key, which shouldn't
+    // match any key checks.
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(getRequest4()));
+    EXPECT_FALSE(createKeyCheck("key.example.com")->matches(getRequest6()));
+}
+
 // The following tests test only the creators are registered, they are tested
 // The following tests test only the creators are registered, they are tested
 // elsewhere
 // elsewhere
 
 

+ 10 - 0
src/lib/acl/tests/dnsname_check_unittest.cc

@@ -34,6 +34,16 @@ bool NameCheck<Name>::matches(const Name& name) const {
 } // namespace isc
 } // namespace isc
 
 
 namespace {
 namespace {
+TEST(DNSNameCheck, construct) {
+    EXPECT_EQ(Name("example.com"),
+              NameCheck<Name>(Name("example.com")).getName());
+
+    // Construct the same check with an explicit trailing dot.  Should result
+    // in the same result.
+    EXPECT_EQ(Name("example.com"),
+              NameCheck<Name>(Name("example.com.")).getName());
+}
+
 TEST(DNSNameCheck, match) {
 TEST(DNSNameCheck, match) {
     NameCheck<Name> check(Name("example.com"));
     NameCheck<Name> check(Name("example.com"));
     EXPECT_TRUE(check.matches(Name("example.com")));
     EXPECT_TRUE(check.matches(Name("example.com")));