Browse Source

[1574b] added another factory method, which creates NSEC3HASH from NSEC3 RDATA.

JINMEI Tatuya 13 years ago
parent
commit
7af06f97cd

+ 12 - 5
src/lib/dns/nsec3hash.cc

@@ -56,10 +56,10 @@ private:
     static const uint8_t NSEC3_HASH_SHA1 = 1;
     static const uint8_t NSEC3_HASH_SHA1 = 1;
 
 
 public:
 public:
-    NSEC3HashRFC5155(const generic::NSEC3PARAM& param) :
-        algorithm_(param.getHashalg()),
-        iterations_(param.getIterations()),
-        salt_(param.getSalt()), digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
+    NSEC3HashRFC5155(uint8_t algorithm, uint16_t iterations,
+                     const vector<uint8_t>& salt) :
+        algorithm_(algorithm), iterations_(iterations),
+        salt_(salt), digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
     {
     {
         if (algorithm_ != NSEC3_HASH_SHA1) {
         if (algorithm_ != NSEC3_HASH_SHA1) {
             isc_throw(UnknownNSEC3HashAlgorithm, "Unknown NSEC3 algorithm: " <<
             isc_throw(UnknownNSEC3HashAlgorithm, "Unknown NSEC3 algorithm: " <<
@@ -134,7 +134,14 @@ namespace dns {
 
 
 NSEC3Hash*
 NSEC3Hash*
 NSEC3Hash::create(const generic::NSEC3PARAM& param) {
 NSEC3Hash::create(const generic::NSEC3PARAM& param) {
-    return (new NSEC3HashRFC5155(param));
+    return (new NSEC3HashRFC5155(param.getHashalg(), param.getIterations(),
+                                 param.getSalt()));
+}
+
+NSEC3Hash*
+NSEC3Hash::create(const generic::NSEC3& nsec3) {
+    return (new NSEC3HashRFC5155(nsec3.getHashalg(), nsec3.getIterations(),
+                                 nsec3.getSalt()));
 }
 }
 
 
 } // namespace dns
 } // namespace dns

+ 6 - 0
src/lib/dns/nsec3hash.h

@@ -109,6 +109,12 @@ public:
     /// \return A pointer to a concrete derived object of \c NSEC3Hash.
     /// \return A pointer to a concrete derived object of \c NSEC3Hash.
     static NSEC3Hash* create(const rdata::generic::NSEC3PARAM& param);
     static NSEC3Hash* create(const rdata::generic::NSEC3PARAM& param);
 
 
+    /// \brief Factory method of NSECHash from NSEC3 RDATA.
+    ///
+    /// This is similar to the other version, but extracts the parameters
+    /// for hash calculation from an NSEC3 RDATA object.
+    static NSEC3Hash* create(const rdata::generic::NSEC3& nsec3);
+
     /// \brief The destructor.
     /// \brief The destructor.
     virtual ~NSEC3Hash() {}
     virtual ~NSEC3Hash() {}
 
 

+ 15 - 4
src/lib/dns/python/nsec3hash_python.cc

@@ -55,13 +55,24 @@ NSEC3Hash_init(PyObject* po_self, PyObject* args, PyObject*) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
                 PyErr_Format(PyExc_TypeError,
                 PyErr_Format(PyExc_TypeError,
-                             "param must be an Rdata of type NSEC3HASH, "
+                             "param must be an Rdata of type NSEC3/NSEC3HASH, "
                              "not %.200s", po_rdata->ob_type->tp_name);
                              "not %.200s", po_rdata->ob_type->tp_name);
                 return (-1);
                 return (-1);
             }
             }
-            self->cppobj = NSEC3Hash::create(
-                dynamic_cast<const generic::NSEC3PARAM&>(
-                    PyRdata_ToRdata(po_rdata)));
+            const Rdata& rdata = PyRdata_ToRdata(po_rdata);
+            const generic::NSEC3PARAM* nsec3param =
+                dynamic_cast<const generic::NSEC3PARAM*>(&rdata);
+            const generic::NSEC3* nsec3 =
+                dynamic_cast<const generic::NSEC3*>(&rdata);
+            if (nsec3param != NULL) {
+                self->cppobj = NSEC3Hash::create(*nsec3param);
+            } else if (nsec3 != NULL) {
+                self->cppobj = NSEC3Hash::create(*nsec3);
+            } else {
+                PyErr_Format(PyExc_TypeError,
+                             "param must be an Rdata of type NSEC3/NSEC3HASH");
+                return (-1);
+            }
             return (0);
             return (0);
         }
         }
     } catch (const UnknownNSEC3HashAlgorithm& ex) {
     } catch (const UnknownNSEC3HashAlgorithm& ex) {

+ 3 - 2
src/lib/dns/python/nsec3hash_python_inc.cc

@@ -10,7 +10,7 @@ NSEC3 hash values as defined in RFC5155.\n\
 \n\
 \n\
 NSEC3Hash(param)\n\
 NSEC3Hash(param)\n\
 \n\
 \n\
-    Constructor from NSEC3PARAM RDATA.\n\
+    Constructor.\n\
 \n\
 \n\
     The hash algorithm given via param must be known to the\n\
     The hash algorithm given via param must be known to the\n\
     implementation. Otherwise UnknownNSEC3HashAlgorithm exception will\n\
     implementation. Otherwise UnknownNSEC3HashAlgorithm exception will\n\
@@ -21,7 +21,8 @@ NSEC3Hash(param)\n\
                  unknown.\n\
                  unknown.\n\
 \n\
 \n\
     Parameters:\n\
     Parameters:\n\
-      param      NSEC3 parameters used for subsequent calculation.\n\
+      param      NSEC3PARAM or NSEC3 Rdata object whose parameters are\n\
+                 to be used for subsequent calculation.\n\
 \n\
 \n\
 ";
 ";
 
 

+ 48 - 37
src/lib/dns/python/tests/nsec3hash_python_test.py

@@ -23,10 +23,12 @@ class NSEC3HashTest(unittest.TestCase):
     '''
     '''
 
 
     def setUp(self):
     def setUp(self):
+        self.nsec3_common = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR A RRSIG"
         self.test_hash = NSEC3Hash(Rdata(RRType.NSEC3PARAM(), RRClass.IN(),
         self.test_hash = NSEC3Hash(Rdata(RRType.NSEC3PARAM(), RRClass.IN(),
                                          "1 0 12 aabbccdd"))
                                          "1 0 12 aabbccdd"))
-        self.nsec3_common = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR A RRSIG"
-
+        self.test_hash_nsec3 = NSEC3Hash(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                               "1 0 12 aabbccdd " +
+                                               self.nsec3_common))
     def test_bad_construct(self):
     def test_bad_construct(self):
         # missing parameter
         # missing parameter
         self.assertRaises(TypeError, NSEC3Hash)
         self.assertRaises(TypeError, NSEC3Hash)
@@ -39,28 +41,33 @@ class NSEC3HashTest(unittest.TestCase):
                                                       RRClass.IN(),
                                                       RRClass.IN(),
                                                       "1 0 12 aabbccdd"), 1)
                                                       "1 0 12 aabbccdd"), 1)
 
 
+        # Invaid type of RDATA
+        self.assertRaises(TypeError, NSEC3Hash, Rdata(RRType.A(), RRClass.IN(),
+                                                      "192.0.2.1"))
+
     def test_unknown_algorithm(self):
     def test_unknown_algorithm(self):
         self.assertRaises(UnknownNSEC3HashAlgorithm, NSEC3Hash,
         self.assertRaises(UnknownNSEC3HashAlgorithm, NSEC3Hash,
                           Rdata(RRType.NSEC3PARAM(), RRClass.IN(),
                           Rdata(RRType.NSEC3PARAM(), RRClass.IN(),
                                 "2 0 12 aabbccdd"))
                                 "2 0 12 aabbccdd"))
+        self.assertRaises(UnknownNSEC3HashAlgorithm, NSEC3Hash,
+                          Rdata(RRType.NSEC3(), RRClass.IN(),
+                                "2 0 12 aabbccdd " + self.nsec3_common))
 
 
-    def test_calculate(self):
+    def calculate_check(self, hash):
         # A couple of normal cases from the RFC5155 example.
         # A couple of normal cases from the RFC5155 example.
         self.assertEqual("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
         self.assertEqual("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
-                         self.test_hash.calculate(Name("example")))
+                         hash.calculate(Name("example")))
         self.assertEqual("35MTHGPGCU1QG68FAB165KLNSNK3DPVL",
         self.assertEqual("35MTHGPGCU1QG68FAB165KLNSNK3DPVL",
-                         self.test_hash.calculate(Name("a.example")))
+                         hash.calculate(Name("a.example")))
 
 
         # Check case-insensitiveness
         # Check case-insensitiveness
         self.assertEqual("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
         self.assertEqual("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
-                         self.test_hash.calculate(Name("EXAMPLE")))
+                         hash.calculate(Name("EXAMPLE")))
 
 
-        # Some boundary cases: 0-iteration and empty salt.  Borrowed from the
-        # .com zone data.
-        self.test_hash = NSEC3Hash(Rdata(RRType.NSEC3PARAM(),
-                                         RRClass.IN(),"1 0 0 -"))
-        self.assertEqual("CK0POJMG874LJREF7EFN8430QVIT8BSM",
-                         self.test_hash.calculate(Name("com")))
+
+    def test_calculate(self):
+        self.calculate_check(self.test_hash)
+        self.calculate_check(self.test_hash_nsec3)
 
 
         # Using unusually large iterations, something larger than the 8-bit
         # Using unusually large iterations, something larger than the 8-bit
         #range.  (expected hash value generated by BIND 9's dnssec-signzone)
         #range.  (expected hash value generated by BIND 9's dnssec-signzone)
@@ -69,42 +76,46 @@ class NSEC3HashTest(unittest.TestCase):
         self.assertEqual("COG6A52MJ96MNMV3QUCAGGCO0RHCC2Q3",
         self.assertEqual("COG6A52MJ96MNMV3QUCAGGCO0RHCC2Q3",
                          self.test_hash.calculate(Name("example.org")))
                          self.test_hash.calculate(Name("example.org")))
 
 
+        # Some boundary cases: 0-iteration and empty salt.  Borrowed from the
+        # .com zone data.
+        self.test_hash = NSEC3Hash(Rdata(RRType.NSEC3PARAM(),
+                                         RRClass.IN(),"1 0 0 -"))
+        self.assertEqual("CK0POJMG874LJREF7EFN8430QVIT8BSM",
+                         self.test_hash.calculate(Name("com")))
+
     def test_calculate_badparam(self):
     def test_calculate_badparam(self):
         self.assertRaises(TypeError, self.test_hash.calculate, "example")
         self.assertRaises(TypeError, self.test_hash.calculate, "example")
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
 
 
-    def test_match_with_nsec3(self):
+    def check_match_with_nsec3(self, hash):
         # If all parameters match, it's considered to be matched.
         # If all parameters match, it's considered to be matched.
-        self.assertTrue(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                   RRClass.IN(),
-                                                   "1 0 12 aabbccdd " +
-                                                   self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                         "1 0 12 aabbccdd " +
+                                         self.nsec3_common)))
         # Algorithm doesn't match
         # Algorithm doesn't match
-        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                    RRClass.IN(),
-                                                    "2 0 12 aabbccdd " +
-                                                    self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                          "2 0 12 aabbccdd " +
+                                          self.nsec3_common)))
         # Iterations doesn't match
         # Iterations doesn't match
-        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                    RRClass.IN(),
-                                                    "1 0 1 aabbccdd " +
-                                                    self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                          "1 0 1 aabbccdd " +
+                                          self.nsec3_common)))
         # Salt doesn't match
         # Salt doesn't match
-        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                    RRClass.IN(),
-                                                    "1 0 12 aabbccde " +
-                                                    self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                          "1 0 12 aabbccde " +
+                                          self.nsec3_common)))
         # Salt doesn't match: the other has an empty salt
         # Salt doesn't match: the other has an empty salt
-        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                    RRClass.IN(),
-                                                    "1 0 12 - " +
-                                                    self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                          "1 0 12 - " + self.nsec3_common)))
         # Flags doesn't matter
         # Flags doesn't matter
-        self.assertTrue(self.test_hash.match(Rdata(RRType.NSEC3(),
-                                                   RRClass.IN(),
-                                                   "1 1 12 aabbccdd " +
-                                                   self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
+                                         "1 1 12 aabbccdd " +
+                                         self.nsec3_common)))
+
+    def test_match_with_nsec3(self):
+        self.check_match_with_nsec3(self.test_hash)
+        self.check_match_with_nsec3(self.test_hash_nsec3)
 
 
         # bad parameter checks
         # bad parameter checks
         self.assertRaises(TypeError, self.test_hash.match, 1)
         self.assertRaises(TypeError, self.test_hash.match, 1)

+ 55 - 18
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -35,13 +35,19 @@ const char* const nsec3_common = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR A RRSIG";
 class NSEC3HashTest : public ::testing::Test {
 class NSEC3HashTest : public ::testing::Test {
 protected:
 protected:
     NSEC3HashTest() :
     NSEC3HashTest() :
-        test_hash(NSEC3Hash::create(generic::NSEC3PARAM("1 0 12 aabbccdd")))
+        test_hash(NSEC3Hash::create(generic::NSEC3PARAM("1 0 12 aabbccdd"))),
+        test_hash_nsec3(NSEC3Hash::create(generic::NSEC3
+                                          ("1 0 12 aabbccdd " +
+                                           string(nsec3_common))))
     {}
     {}
 
 
     // An NSEC3Hash object commonly used in tests.  Parameters are borrowed
     // An NSEC3Hash object commonly used in tests.  Parameters are borrowed
     // from the RFC5155 example.  Construction of this object implicitly
     // from the RFC5155 example.  Construction of this object implicitly
     // checks a successful case of the creation.
     // checks a successful case of the creation.
     NSEC3HashPtr test_hash;
     NSEC3HashPtr test_hash;
+
+    // Similar to test_hash, but created from NSEC3 RR.
+    NSEC3HashPtr test_hash_nsec3;
 };
 };
 
 
 TEST_F(NSEC3HashTest, unknownAlgorithm) {
 TEST_F(NSEC3HashTest, unknownAlgorithm) {
@@ -49,18 +55,36 @@ TEST_F(NSEC3HashTest, unknownAlgorithm) {
                      NSEC3Hash::create(
                      NSEC3Hash::create(
                          generic::NSEC3PARAM("2 0 12 aabbccdd"))),
                          generic::NSEC3PARAM("2 0 12 aabbccdd"))),
                      UnknownNSEC3HashAlgorithm);
                      UnknownNSEC3HashAlgorithm);
+    EXPECT_THROW(NSEC3HashPtr(
+                     NSEC3Hash::create(
+                         generic::NSEC3("2 0 12 aabbccdd " +
+                                        string(nsec3_common)))),
+                     UnknownNSEC3HashAlgorithm);
 }
 }
 
 
-TEST_F(NSEC3HashTest, calculate) {
+// Common checks for NSEC3 hash calculation
+void
+calculateCheck(NSEC3Hash& hash) {
     // A couple of normal cases from the RFC5155 example.
     // A couple of normal cases from the RFC5155 example.
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
-              test_hash->calculate(Name("example")));
+              hash.calculate(Name("example")));
     EXPECT_EQ("35MTHGPGCU1QG68FAB165KLNSNK3DPVL",
     EXPECT_EQ("35MTHGPGCU1QG68FAB165KLNSNK3DPVL",
-              test_hash->calculate(Name("a.example")));
+              hash.calculate(Name("a.example")));
 
 
     // Check case-insensitiveness
     // Check case-insensitiveness
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
-              test_hash->calculate(Name("EXAMPLE")));
+              hash.calculate(Name("EXAMPLE")));
+}
+
+TEST_F(NSEC3HashTest, calculate) {
+    {
+        SCOPED_TRACE("calculate check with NSEC3PARAM based hash");
+        calculateCheck(*test_hash);
+    }
+    {
+        SCOPED_TRACE("calculate check with NSEC3 based hash");
+        calculateCheck(*test_hash_nsec3);
+    }
 
 
     // Some boundary cases: 0-iteration and empty salt.  Borrowed from the
     // Some boundary cases: 0-iteration and empty salt.  Borrowed from the
     // .com zone data.
     // .com zone data.
@@ -76,25 +100,38 @@ TEST_F(NSEC3HashTest, calculate) {
               ->calculate(Name("example.org")));
               ->calculate(Name("example.org")));
 }
 }
 
 
-TEST_F(NSEC3HashTest, matchWithNSEC3) {
+// Common checks for match against NSEC3 parameters
+void
+matchWithNSEC3Check(NSEC3Hash& hash) {
     // If all parameters match, it's considered to be matched.
     // If all parameters match, it's considered to be matched.
-    EXPECT_TRUE(test_hash->match(generic::NSEC3("1 0 12 aabbccdd " +
-                                                string(nsec3_common))));
+    EXPECT_TRUE(hash.match(generic::NSEC3("1 0 12 aabbccdd " +
+                                          string(nsec3_common))));
     // Algorithm doesn't match
     // Algorithm doesn't match
-    EXPECT_FALSE(test_hash->match(generic::NSEC3("2 0 12 aabbccdd " +
-                                                 string(nsec3_common))));
+    EXPECT_FALSE(hash.match(generic::NSEC3("2 0 12 aabbccdd " +
+                                           string(nsec3_common))));
     // Iterations doesn't match
     // Iterations doesn't match
-    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 1 aabbccdd " +
-                                                 string(nsec3_common))));
+    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 1 aabbccdd " +
+                                           string(nsec3_common))));
     // Salt doesn't match
     // Salt doesn't match
-    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 12 aabbccde " +
-                                                 string(nsec3_common))));
+    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 aabbccde " +
+                                           string(nsec3_common))));
     // Salt doesn't match: the other has an empty salt
     // Salt doesn't match: the other has an empty salt
-    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 12 - " +
-                                                 string(nsec3_common))));
+    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 - " +
+                                           string(nsec3_common))));
     // Flags doesn't matter
     // Flags doesn't matter
-    EXPECT_TRUE(test_hash->match(generic::NSEC3("1 1 12 aabbccdd " +
-                                                 string(nsec3_common))));
+    EXPECT_TRUE(hash.match(generic::NSEC3("1 1 12 aabbccdd " +
+                                          string(nsec3_common))));
+}
+
+TEST_F(NSEC3HashTest, matchWithNSEC3) {
+    {
+        SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
+        matchWithNSEC3Check(*test_hash);
+    }
+    {
+        SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
+        matchWithNSEC3Check(*test_hash_nsec3);
+    }
 }
 }
 
 
 } // end namespace
 } // end namespace