Browse Source

[1574b] added another version of NSEC3Hash::match, which takes NSEC3PARAM

JINMEI Tatuya 13 years ago
parent
commit
dc8a5a5b10

+ 20 - 5
src/lib/dns/nsec3hash.cc

@@ -71,6 +71,9 @@ public:
     virtual std::string calculate(const Name& name) const;
     virtual std::string calculate(const Name& name) const;
 
 
     virtual bool match(const generic::NSEC3& nsec3) const;
     virtual bool match(const generic::NSEC3& nsec3) const;
+    virtual bool match(const generic::NSEC3PARAM& nsec3param) const;
+    bool match(uint8_t algorithm, uint16_t iterations,
+               const vector<uint8_t>& salt) const;
 
 
 private:
 private:
     const uint8_t algorithm_;
     const uint8_t algorithm_;
@@ -120,12 +123,24 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
 }
 }
 
 
 bool
 bool
+NSEC3HashRFC5155::match(uint8_t algorithm, uint16_t iterations,
+                        const vector<uint8_t>& salt) const
+{
+    return (algorithm_ == algorithm && iterations_ == iterations &&
+            salt_.size() == salt.size() &&
+            (salt_.empty() || memcmp(&salt_[0], &salt[0], salt_.size()) == 0));
+}
+
+bool
 NSEC3HashRFC5155::match(const generic::NSEC3& nsec3) const {
 NSEC3HashRFC5155::match(const generic::NSEC3& nsec3) const {
-    return (algorithm_ == nsec3.getHashalg() &&
-            iterations_ == nsec3.getIterations() &&
-            salt_.size() == nsec3.getSalt().size() &&
-            (salt_.empty() ||
-             memcmp(&salt_[0], &nsec3.getSalt()[0], salt_.size()) == 0));
+    return (match(nsec3.getHashalg(), nsec3.getIterations(),
+                  nsec3.getSalt()));
+}
+
+bool
+NSEC3HashRFC5155::match(const generic::NSEC3PARAM& nsec3param) const {
+    return (match(nsec3param.getHashalg(), nsec3param.getIterations(),
+                  nsec3param.getSalt()));
 }
 }
 } // end of unnamed namespace
 } // end of unnamed namespace
 
 

+ 6 - 0
src/lib/dns/nsec3hash.h

@@ -146,6 +146,12 @@ public:
     /// \return true If the given parameters match the local ones; false
     /// \return true If the given parameters match the local ones; false
     /// otherwise.
     /// otherwise.
     virtual bool match(const rdata::generic::NSEC3& nsec3) const = 0;
     virtual bool match(const rdata::generic::NSEC3& nsec3) const = 0;
+
+    /// \brief Match given NSEC3PARAM parameters with that of the hash.
+    ///
+    /// This is similar to the other version, but extracts the parameters
+    /// to compare from an NSEC3PARAM RDATA object.
+    virtual bool match(const rdata::generic::NSEC3PARAM& nsec3param) const = 0;
 };
 };
 
 
 }
 }

+ 19 - 7
src/lib/dns/python/nsec3hash_python.cc

@@ -55,8 +55,8 @@ NSEC3Hash_init(PyObject* po_self, PyObject* args, PyObject*) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
                 PyErr_Format(PyExc_TypeError,
                 PyErr_Format(PyExc_TypeError,
-                             "param must be an Rdata of type NSEC3/NSEC3HASH, "
-                             "not %.200s", po_rdata->ob_type->tp_name);
+                             "param must be an Rdata of type NSEC3/NSEC3PARAM,"
+                             " not %.200s", po_rdata->ob_type->tp_name);
                 return (-1);
                 return (-1);
             }
             }
             const Rdata& rdata = PyRdata_ToRdata(po_rdata);
             const Rdata& rdata = PyRdata_ToRdata(po_rdata);
@@ -138,13 +138,25 @@ NSEC3Hash_match(PyObject* po_self, PyObject* args) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
                 PyErr_Format(PyExc_TypeError,
                 PyErr_Format(PyExc_TypeError,
-                             "param must be an Rdata of type NSEC3, "
-                             "not %.200s", po_rdata->ob_type->tp_name);
+                             "param must be an Rdata of type NSEC3/NSEC3PARAM,"
+                             " not %.200s", po_rdata->ob_type->tp_name);
+                return (NULL);
+            }
+            const Rdata& rdata = PyRdata_ToRdata(po_rdata);
+            const generic::NSEC3PARAM* nsec3param =
+                dynamic_cast<const generic::NSEC3PARAM*>(&rdata);
+            const generic::NSEC3* nsec3 =
+                dynamic_cast<const generic::NSEC3*>(&rdata);
+            bool matched;
+            if (nsec3param != NULL) {
+                matched = self->cppobj->match(*nsec3param);
+            } else if (nsec3 != NULL) {
+                matched = self->cppobj->match(*nsec3);
+            } else {
+                PyErr_Format(PyExc_TypeError,
+                             "param must be an Rdata of type NSEC3/NSEC3HASH");
                 return (NULL);
                 return (NULL);
             }
             }
-            const bool matched = self->cppobj->match(
-                dynamic_cast<const generic::NSEC3&>(
-                    PyRdata_ToRdata(po_rdata)));
             PyObject* ret = matched ? Py_True : Py_False;
             PyObject* ret = matched ? Py_True : Py_False;
             Py_INCREF(ret);
             Py_INCREF(ret);
             return (ret);
             return (ret);

+ 5 - 5
src/lib/dns/python/nsec3hash_python_inc.cc

@@ -45,12 +45,12 @@ Return Value(s): Base32hex-encoded string of the hash value.\n\
 ";
 ";
 
 
 const char* const NSEC3Hash_match_doc = "\
 const char* const NSEC3Hash_match_doc = "\
-match(nsec3) -> bool\n                   \
+match(rdata) -> bool\n                   \
 \n\
 \n\
-Match given NSEC3 parameters with that of the hash.\n\
+Match given NSEC3 or NSEC3PARAM parameters with that of the hash.\n\
 \n\
 \n\
 This method compares NSEC3 parameters used for hash calculation in the\n\
 This method compares NSEC3 parameters used for hash calculation in the\n\
-object with those in the given NSEC3 RDATA, and return true iff they\n\
+object with those in the given RDATA, and return true iff they\n\
 completely match. In the current implementation only the algorithm,\n\
 completely match. In the current implementation only the algorithm,\n\
 iterations and salt are compared; the flags are ignored (as they don't\n\
 iterations and salt are compared; the flags are ignored (as they don't\n\
 affect hash calculation per RFC5155).\n\
 affect hash calculation per RFC5155).\n\
@@ -59,8 +59,8 @@ Exceptions:\n\
   None\n\
   None\n\
 \n\
 \n\
 Parameters:\n\
 Parameters:\n\
-  nsec3      An NSEC3 RDATA object whose hash parameters are to be\n\
-             matched\n\
+  rdata      An NSEC3 or NSEC3PARAM Rdata object whose hash parameters\n\
+             are to be matched\n\
 \n\
 \n\
 Return Value(s): true If the given parameters match the local ones;\n\
 Return Value(s): true If the given parameters match the local ones;\n\
 false otherwise.\n\
 false otherwise.\n\

+ 21 - 23
src/lib/dns/python/tests/nsec3hash_python_test.py

@@ -88,42 +88,40 @@ class NSEC3HashTest(unittest.TestCase):
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
 
 
-    def check_match_with_nsec3(self, hash):
+    def check_match(self, hash, rrtype, postfix):
         # If all parameters match, it's considered to be matched.
         # If all parameters match, it's considered to be matched.
-        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                         "1 0 12 aabbccdd " +
-                                         self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(rrtype, RRClass.IN(),
+                                         "1 0 12 aabbccdd" + postfix)))
         # Algorithm doesn't match
         # Algorithm doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "2 0 12 aabbccdd " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "2 0 12 aabbccdd" + postfix)))
         # Iterations doesn't match
         # Iterations doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 1 aabbccdd " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 1 aabbccdd" + postfix)))
         # Salt doesn't match
         # Salt doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 12 aabbccde " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 12 aabbccde" + postfix)))
         # Salt doesn't match: the other has an empty salt
         # Salt doesn't match: the other has an empty salt
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 12 - " + self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 12 -" + postfix)))
         # Flags doesn't matter
         # Flags doesn't matter
-        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                         "1 1 12 aabbccdd " +
-                                         self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(rrtype, RRClass.IN(),
+                                         "1 1 12 aabbccdd" + postfix)))
 
 
-    def test_match_with_nsec3(self):
-        self.check_match_with_nsec3(self.test_hash)
-        self.check_match_with_nsec3(self.test_hash_nsec3)
+    def test_match(self):
+        self.check_match(self.test_hash, RRType.NSEC3(),
+                         " " + self.nsec3_common)
+        self.check_match(self.test_hash_nsec3, RRType.NSEC3(),
+                         " " + self.nsec3_common)
+        self.check_match(self.test_hash, RRType.NSEC3PARAM(), "")
+        self.check_match(self.test_hash_nsec3, RRType.NSEC3PARAM(), "")
 
 
         # bad parameter checks
         # bad parameter checks
         self.assertRaises(TypeError, self.test_hash.match, 1)
         self.assertRaises(TypeError, self.test_hash.match, 1)
         self.assertRaises(TypeError, self.test_hash.match,
         self.assertRaises(TypeError, self.test_hash.match,
                           Rdata(RRType.NSEC3(), RRClass.IN(),
                           Rdata(RRType.NSEC3(), RRClass.IN(),
                                 "1 0 12 aabbccdd " + self.nsec3_common), 1)
                                 "1 0 12 aabbccdd " + self.nsec3_common), 1)
-        # this would result in bad_cast
-        self.assertRaises(IscException, self.test_hash.match,
+        self.assertRaises(TypeError, self.test_hash.match,
                           Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
                           Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':

+ 24 - 16
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -100,37 +100,45 @@ TEST_F(NSEC3HashTest, calculate) {
               ->calculate(Name("example.org")));
               ->calculate(Name("example.org")));
 }
 }
 
 
-// Common checks for match against NSEC3 parameters
+// Common checks for match cases
+template <typename RDATAType>
 void
 void
-matchWithNSEC3Check(NSEC3Hash& hash) {
+matchCheck(NSEC3Hash& hash, const string& postfix) {
     // If all parameters match, it's considered to be matched.
     // If all parameters match, it's considered to be matched.
-    EXPECT_TRUE(hash.match(generic::NSEC3("1 0 12 aabbccdd " +
-                                          string(nsec3_common))));
+    EXPECT_TRUE(hash.match(RDATAType("1 0 12 aabbccdd" + postfix)));
+
     // Algorithm doesn't match
     // Algorithm doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("2 0 12 aabbccdd " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("2 0 12 aabbccdd" + postfix)));
     // Iterations doesn't match
     // Iterations doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 1 aabbccdd " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 1 aabbccdd" + postfix)));
     // Salt doesn't match
     // Salt doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 aabbccde " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 12 aabbccde" + postfix)));
     // Salt doesn't match: the other has an empty salt
     // Salt doesn't match: the other has an empty salt
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 - " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 12 -" + postfix)));
     // Flags doesn't matter
     // Flags doesn't matter
-    EXPECT_TRUE(hash.match(generic::NSEC3("1 1 12 aabbccdd " +
-                                          string(nsec3_common))));
+    EXPECT_TRUE(hash.match(RDATAType("1 1 12 aabbccdd" + postfix)));
 }
 }
 
 
 TEST_F(NSEC3HashTest, matchWithNSEC3) {
 TEST_F(NSEC3HashTest, matchWithNSEC3) {
     {
     {
         SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
         SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
-        matchWithNSEC3Check(*test_hash);
+        matchCheck<generic::NSEC3>(*test_hash, " " + string(nsec3_common));
+    }
+    {
+        SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
+        matchCheck<generic::NSEC3>(*test_hash_nsec3,
+                                   " " + string(nsec3_common));
+    }
+}
+
+TEST_F(NSEC3HashTest, matchWithNSEC3PARAM) {
+    {
+        SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
+        matchCheck<generic::NSEC3PARAM>(*test_hash, "");
     }
     }
     {
     {
         SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
         SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
-        matchWithNSEC3Check(*test_hash_nsec3);
+        matchCheck<generic::NSEC3PARAM>(*test_hash_nsec3, "");
     }
     }
 }
 }