Browse Source

[1574b] added another version of NSEC3Hash::match, which takes NSEC3PARAM

JINMEI Tatuya 13 years ago
parent
commit
dc8a5a5b10

+ 20 - 5
src/lib/dns/nsec3hash.cc

@@ -71,6 +71,9 @@ public:
     virtual std::string calculate(const Name& name) const;
 
     virtual bool match(const generic::NSEC3& nsec3) const;
+    virtual bool match(const generic::NSEC3PARAM& nsec3param) const;
+    bool match(uint8_t algorithm, uint16_t iterations,
+               const vector<uint8_t>& salt) const;
 
 private:
     const uint8_t algorithm_;
@@ -120,12 +123,24 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
 }
 
 bool
+NSEC3HashRFC5155::match(uint8_t algorithm, uint16_t iterations,
+                        const vector<uint8_t>& salt) const
+{
+    return (algorithm_ == algorithm && iterations_ == iterations &&
+            salt_.size() == salt.size() &&
+            (salt_.empty() || memcmp(&salt_[0], &salt[0], salt_.size()) == 0));
+}
+
+bool
 NSEC3HashRFC5155::match(const generic::NSEC3& nsec3) const {
-    return (algorithm_ == nsec3.getHashalg() &&
-            iterations_ == nsec3.getIterations() &&
-            salt_.size() == nsec3.getSalt().size() &&
-            (salt_.empty() ||
-             memcmp(&salt_[0], &nsec3.getSalt()[0], salt_.size()) == 0));
+    return (match(nsec3.getHashalg(), nsec3.getIterations(),
+                  nsec3.getSalt()));
+}
+
+bool
+NSEC3HashRFC5155::match(const generic::NSEC3PARAM& nsec3param) const {
+    return (match(nsec3param.getHashalg(), nsec3param.getIterations(),
+                  nsec3param.getSalt()));
 }
 } // end of unnamed namespace
 

+ 6 - 0
src/lib/dns/nsec3hash.h

@@ -146,6 +146,12 @@ public:
     /// \return true If the given parameters match the local ones; false
     /// otherwise.
     virtual bool match(const rdata::generic::NSEC3& nsec3) const = 0;
+
+    /// \brief Match given NSEC3PARAM parameters with that of the hash.
+    ///
+    /// This is similar to the other version, but extracts the parameters
+    /// to compare from an NSEC3PARAM RDATA object.
+    virtual bool match(const rdata::generic::NSEC3PARAM& nsec3param) const = 0;
 };
 
 }

+ 19 - 7
src/lib/dns/python/nsec3hash_python.cc

@@ -55,8 +55,8 @@ NSEC3Hash_init(PyObject* po_self, PyObject* args, PyObject*) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
                 PyErr_Format(PyExc_TypeError,
-                             "param must be an Rdata of type NSEC3/NSEC3HASH, "
-                             "not %.200s", po_rdata->ob_type->tp_name);
+                             "param must be an Rdata of type NSEC3/NSEC3PARAM,"
+                             " not %.200s", po_rdata->ob_type->tp_name);
                 return (-1);
             }
             const Rdata& rdata = PyRdata_ToRdata(po_rdata);
@@ -138,13 +138,25 @@ NSEC3Hash_match(PyObject* po_self, PyObject* args) {
         if (PyArg_ParseTuple(args, "O", &po_rdata)) {
             if (!PyRdata_Check(po_rdata)) {
                 PyErr_Format(PyExc_TypeError,
-                             "param must be an Rdata of type NSEC3, "
-                             "not %.200s", po_rdata->ob_type->tp_name);
+                             "param must be an Rdata of type NSEC3/NSEC3PARAM,"
+                             " not %.200s", po_rdata->ob_type->tp_name);
+                return (NULL);
+            }
+            const Rdata& rdata = PyRdata_ToRdata(po_rdata);
+            const generic::NSEC3PARAM* nsec3param =
+                dynamic_cast<const generic::NSEC3PARAM*>(&rdata);
+            const generic::NSEC3* nsec3 =
+                dynamic_cast<const generic::NSEC3*>(&rdata);
+            bool matched;
+            if (nsec3param != NULL) {
+                matched = self->cppobj->match(*nsec3param);
+            } else if (nsec3 != NULL) {
+                matched = self->cppobj->match(*nsec3);
+            } else {
+                PyErr_Format(PyExc_TypeError,
+                             "param must be an Rdata of type NSEC3/NSEC3HASH");
                 return (NULL);
             }
-            const bool matched = self->cppobj->match(
-                dynamic_cast<const generic::NSEC3&>(
-                    PyRdata_ToRdata(po_rdata)));
             PyObject* ret = matched ? Py_True : Py_False;
             Py_INCREF(ret);
             return (ret);

+ 5 - 5
src/lib/dns/python/nsec3hash_python_inc.cc

@@ -45,12 +45,12 @@ Return Value(s): Base32hex-encoded string of the hash value.\n\
 ";
 
 const char* const NSEC3Hash_match_doc = "\
-match(nsec3) -> bool\n                   \
+match(rdata) -> bool\n                   \
 \n\
-Match given NSEC3 parameters with that of the hash.\n\
+Match given NSEC3 or NSEC3PARAM parameters with that of the hash.\n\
 \n\
 This method compares NSEC3 parameters used for hash calculation in the\n\
-object with those in the given NSEC3 RDATA, and return true iff they\n\
+object with those in the given RDATA, and return true iff they\n\
 completely match. In the current implementation only the algorithm,\n\
 iterations and salt are compared; the flags are ignored (as they don't\n\
 affect hash calculation per RFC5155).\n\
@@ -59,8 +59,8 @@ Exceptions:\n\
   None\n\
 \n\
 Parameters:\n\
-  nsec3      An NSEC3 RDATA object whose hash parameters are to be\n\
-             matched\n\
+  rdata      An NSEC3 or NSEC3PARAM Rdata object whose hash parameters\n\
+             are to be matched\n\
 \n\
 Return Value(s): true If the given parameters match the local ones;\n\
 false otherwise.\n\

+ 21 - 23
src/lib/dns/python/tests/nsec3hash_python_test.py

@@ -88,42 +88,40 @@ class NSEC3HashTest(unittest.TestCase):
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
 
-    def check_match_with_nsec3(self, hash):
+    def check_match(self, hash, rrtype, postfix):
         # If all parameters match, it's considered to be matched.
-        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                         "1 0 12 aabbccdd " +
-                                         self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(rrtype, RRClass.IN(),
+                                         "1 0 12 aabbccdd" + postfix)))
         # Algorithm doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "2 0 12 aabbccdd " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "2 0 12 aabbccdd" + postfix)))
         # Iterations doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 1 aabbccdd " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 1 aabbccdd" + postfix)))
         # Salt doesn't match
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 12 aabbccde " +
-                                          self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 12 aabbccde" + postfix)))
         # Salt doesn't match: the other has an empty salt
-        self.assertFalse(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                          "1 0 12 - " + self.nsec3_common)))
+        self.assertFalse(hash.match(Rdata(rrtype, RRClass.IN(),
+                                          "1 0 12 -" + postfix)))
         # Flags doesn't matter
-        self.assertTrue(hash.match(Rdata(RRType.NSEC3(), RRClass.IN(),
-                                         "1 1 12 aabbccdd " +
-                                         self.nsec3_common)))
+        self.assertTrue(hash.match(Rdata(rrtype, RRClass.IN(),
+                                         "1 1 12 aabbccdd" + postfix)))
 
-    def test_match_with_nsec3(self):
-        self.check_match_with_nsec3(self.test_hash)
-        self.check_match_with_nsec3(self.test_hash_nsec3)
+    def test_match(self):
+        self.check_match(self.test_hash, RRType.NSEC3(),
+                         " " + self.nsec3_common)
+        self.check_match(self.test_hash_nsec3, RRType.NSEC3(),
+                         " " + self.nsec3_common)
+        self.check_match(self.test_hash, RRType.NSEC3PARAM(), "")
+        self.check_match(self.test_hash_nsec3, RRType.NSEC3PARAM(), "")
 
         # bad parameter checks
         self.assertRaises(TypeError, self.test_hash.match, 1)
         self.assertRaises(TypeError, self.test_hash.match,
                           Rdata(RRType.NSEC3(), RRClass.IN(),
                                 "1 0 12 aabbccdd " + self.nsec3_common), 1)
-        # this would result in bad_cast
-        self.assertRaises(IscException, self.test_hash.match,
+        self.assertRaises(TypeError, self.test_hash.match,
                           Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
 
 if __name__ == '__main__':

+ 24 - 16
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -100,37 +100,45 @@ TEST_F(NSEC3HashTest, calculate) {
               ->calculate(Name("example.org")));
 }
 
-// Common checks for match against NSEC3 parameters
+// Common checks for match cases
+template <typename RDATAType>
 void
-matchWithNSEC3Check(NSEC3Hash& hash) {
+matchCheck(NSEC3Hash& hash, const string& postfix) {
     // If all parameters match, it's considered to be matched.
-    EXPECT_TRUE(hash.match(generic::NSEC3("1 0 12 aabbccdd " +
-                                          string(nsec3_common))));
+    EXPECT_TRUE(hash.match(RDATAType("1 0 12 aabbccdd" + postfix)));
+
     // Algorithm doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("2 0 12 aabbccdd " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("2 0 12 aabbccdd" + postfix)));
     // Iterations doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 1 aabbccdd " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 1 aabbccdd" + postfix)));
     // Salt doesn't match
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 aabbccde " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 12 aabbccde" + postfix)));
     // Salt doesn't match: the other has an empty salt
-    EXPECT_FALSE(hash.match(generic::NSEC3("1 0 12 - " +
-                                           string(nsec3_common))));
+    EXPECT_FALSE(hash.match(RDATAType("1 0 12 -" + postfix)));
     // Flags doesn't matter
-    EXPECT_TRUE(hash.match(generic::NSEC3("1 1 12 aabbccdd " +
-                                          string(nsec3_common))));
+    EXPECT_TRUE(hash.match(RDATAType("1 1 12 aabbccdd" + postfix)));
 }
 
 TEST_F(NSEC3HashTest, matchWithNSEC3) {
     {
         SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
-        matchWithNSEC3Check(*test_hash);
+        matchCheck<generic::NSEC3>(*test_hash, " " + string(nsec3_common));
+    }
+    {
+        SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
+        matchCheck<generic::NSEC3>(*test_hash_nsec3,
+                                   " " + string(nsec3_common));
+    }
+}
+
+TEST_F(NSEC3HashTest, matchWithNSEC3PARAM) {
+    {
+        SCOPED_TRACE("match NSEC3PARAM based hash against NSEC3 parameters");
+        matchCheck<generic::NSEC3PARAM>(*test_hash, "");
     }
     {
         SCOPED_TRACE("match NSEC3 based hash against NSEC3 parameters");
-        matchWithNSEC3Check(*test_hash_nsec3);
+        matchCheck<generic::NSEC3PARAM>(*test_hash_nsec3, "");
     }
 }