Browse Source

[1574b] added match() method to NSEC3Hash to check parameter consistency

JINMEI Tatuya 13 years ago
parent
commit
a2943fb1b3

+ 12 - 0
src/lib/dns/nsec3hash.cc

@@ -15,6 +15,7 @@
 #include <stdint.h>
 
 #include <cassert>
+#include <cstring>
 #include <string>
 #include <vector>
 
@@ -69,6 +70,8 @@ public:
 
     virtual std::string calculate(const Name& name) const;
 
+    virtual bool match(const generic::NSEC3& nsec3) const;
+
 private:
     const uint8_t algorithm_;
     const uint16_t iterations_;
@@ -115,6 +118,15 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
 
     return (encodeBase32Hex(digest_));
 }
+
+bool
+NSEC3HashRFC5155::match(const generic::NSEC3& nsec3) const {
+    return (algorithm_ == nsec3.getHashalg() &&
+            iterations_ == nsec3.getIterations() &&
+            salt_.size() == nsec3.getSalt().size() &&
+            (salt_.empty() ||
+             memcmp(&salt_[0], &nsec3.getSalt()[0], salt_.size()) == 0));
+}
 } // end of unnamed namespace
 
 namespace isc {

+ 17 - 0
src/lib/dns/nsec3hash.h

@@ -25,6 +25,7 @@ class Name;
 
 namespace rdata {
 namespace generic {
+class NSEC3;
 class NSEC3PARAM;
 }
 }
@@ -123,6 +124,22 @@ public:
     /// calculated.
     /// \return Base32hex-encoded string of the hash value.
     virtual std::string calculate(const Name& name) const = 0;
+
+    /// \brief Match given NSEC3 parameters with that of the hash.
+    ///
+    /// This method compares NSEC3 parameters used for hash calculation
+    /// in the object with those in the given NSEC3 RDATA, and return
+    /// true iff they completely match.  In the current implementation
+    /// only the algorithm, iterations and salt are compared; the flags
+    /// are ignored (as they don't affect hash calculation per RFC5155).
+    ///
+    /// \throw None
+    ///
+    /// \param nsec3 An NSEC3 RDATA object whose hash parameters are to be
+    /// matched
+    /// \return true If the given parameters match the local ones; false
+    /// otherwise.
+    virtual bool match(const rdata::generic::NSEC3& nsec3) const = 0;
 };
 
 }

+ 32 - 0
src/lib/dns/python/nsec3hash_python.cc

@@ -118,6 +118,37 @@ NSEC3Hash_calculate(PyObject* po_self, PyObject* args) {
     return (NULL);
 }
 
+PyObject*
+NSEC3Hash_match(PyObject* po_self, PyObject* args) {
+    s_NSEC3Hash* const self = static_cast<s_NSEC3Hash*>(po_self);
+
+    try {
+        PyObject* po_rdata;
+        if (PyArg_ParseTuple(args, "O", &po_rdata)) {
+            if (!PyRdata_Check(po_rdata)) {
+                PyErr_Format(PyExc_TypeError,
+                             "param must be an Rdata of type NSEC3, "
+                             "not %.200s", po_rdata->ob_type->tp_name);
+                return (NULL);
+            }
+            const bool matched = self->cppobj->match(
+                dynamic_cast<const generic::NSEC3&>(
+                    PyRdata_ToRdata(po_rdata)));
+            return (matched ? Py_True : Py_False);
+        }
+    } catch (const exception& ex) {
+        const string ex_what = "Unexpected failure in NSEC3Hash.match: " +
+            string(ex.what());
+        PyErr_SetString(po_IscException, ex_what.c_str());
+        return (NULL);
+    } catch (...) {
+        PyErr_SetString(PyExc_SystemError, "Unexpected C++ exception");
+        return (NULL);
+    }
+
+    return (NULL);
+}
+
 // This list contains the actual set of functions we have in
 // python. Each entry has
 // 1. Python method name
@@ -126,6 +157,7 @@ NSEC3Hash_calculate(PyObject* po_self, PyObject* args) {
 // 4. Documentation
 PyMethodDef NSEC3Hash_methods[] = {
     { "calculate", NSEC3Hash_calculate, METH_VARARGS, NSEC3Hash_calculate_doc },
+    { "match", NSEC3Hash_match, METH_VARARGS, NSEC3Hash_match_doc },
     { NULL, NULL, 0, NULL }
 };
 } // end of unnamed namespace

+ 23 - 1
src/lib/dns/python/nsec3hash_python_inc.cc

@@ -26,7 +26,7 @@ NSEC3Hash(param)\n\
 ";
 
 const char* const NSEC3Hash_calculate_doc = "\
-calculate(Name) -> string\n\
+calculate(name) -> string\n\
 \n\
 Calculate the NSEC3 hash.\n\
 \n\
@@ -42,4 +42,26 @@ Parameters:\n\
 \n\
 Return Value(s): Base32hex-encoded string of the hash value.\n\
 ";
+
+const char* const NSEC3Hash_match_doc = "\
+match(nsec3) -> bool\n                   \
+\n\
+Match given NSEC3 parameters with that of the hash.\n\
+\n\
+This method compares NSEC3 parameters used for hash calculation in the\n\
+object with those in the given NSEC3 RDATA, and return true iff they\n\
+completely match. In the current implementation only the algorithm,\n\
+iterations and salt are compared; the flags are ignored (as they don't\n\
+affect hash calculation per RFC5155).\n\
+\n\
+Exceptions:\n\
+  None\n\
+\n\
+Parameters:\n\
+  nsec3      An NSEC3 RDATA object whose hash parameters are to be\n\
+             matched\n\
+\n\
+Return Value(s): true If the given parameters match the local ones;\n\
+false otherwise.\n\
+";
 } // unnamed namespace

+ 42 - 0
src/lib/dns/python/tests/nsec3hash_python_test.py

@@ -25,6 +25,7 @@ class NSEC3HashTest(unittest.TestCase):
     def setUp(self):
         self.test_hash = NSEC3Hash(Rdata(RRType.NSEC3PARAM(), RRClass.IN(),
                                          "1 0 12 aabbccdd"))
+        self.nsec3_common = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR A RRSIG"
 
     def test_bad_construct(self):
         # missing parameter
@@ -73,5 +74,46 @@ class NSEC3HashTest(unittest.TestCase):
         self.assertRaises(TypeError, self.test_hash.calculate)
         self.assertRaises(TypeError, self.test_hash.calculate, Name("."), 1)
 
+    def test_match_with_nsec3(self):
+        # If all parameters match, it's considered to be matched.
+        self.assertTrue(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                   RRClass.IN(),
+                                                   "1 0 12 aabbccdd " +
+                                                   self.nsec3_common)))
+        # Algorithm doesn't match
+        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                    RRClass.IN(),
+                                                    "2 0 12 aabbccdd " +
+                                                    self.nsec3_common)))
+        # Iterations doesn't match
+        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                    RRClass.IN(),
+                                                    "1 0 1 aabbccdd " +
+                                                    self.nsec3_common)))
+        # Salt doesn't match
+        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                    RRClass.IN(),
+                                                    "1 0 12 aabbccde " +
+                                                    self.nsec3_common)))
+        # Salt doesn't match: the other has an empty salt
+        self.assertFalse(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                    RRClass.IN(),
+                                                    "1 0 12 - " +
+                                                    self.nsec3_common)))
+        # Flags doesn't matter
+        self.assertTrue(self.test_hash.match(Rdata(RRType.NSEC3(),
+                                                   RRClass.IN(),
+                                                   "1 1 12 aabbccdd " +
+                                                   self.nsec3_common)))
+
+        # bad parameter checks
+        self.assertRaises(TypeError, self.test_hash.match, 1)
+        self.assertRaises(TypeError, self.test_hash.match,
+                          Rdata(RRType.NSEC3(), RRClass.IN(),
+                                "1 0 12 aabbccdd " + self.nsec3_common), 1)
+        # this would result in bad_cast
+        self.assertRaises(IscException, self.test_hash.match,
+                          Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
+
 if __name__ == '__main__':
     unittest.main()

+ 27 - 0
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -12,6 +12,8 @@
 // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 // PERFORMANCE OF THIS SOFTWARE.
 
+#include <string>
+
 #include <gtest/gtest.h>
 
 #include <boost/scoped_ptr.hpp>
@@ -20,12 +22,16 @@
 #include <dns/rdataclass.h>
 
 using boost::scoped_ptr;
+using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::rdata;
 
 namespace {
 typedef scoped_ptr<NSEC3Hash> NSEC3HashPtr;
 
+// Commonly used NSEC3 suffix, defined to reduce amount of type
+const char* const nsec3_common = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR A RRSIG";
+
 class NSEC3HashTest : public ::testing::Test {
 protected:
     NSEC3HashTest() :
@@ -70,4 +76,25 @@ TEST_F(NSEC3HashTest, calculate) {
               ->calculate(Name("example.org")));
 }
 
+TEST_F(NSEC3HashTest, matchWithNSEC3) {
+    // If all parameters match, it's considered to be matched.
+    EXPECT_TRUE(test_hash->match(generic::NSEC3("1 0 12 aabbccdd " +
+                                                string(nsec3_common))));
+    // Algorithm doesn't match
+    EXPECT_FALSE(test_hash->match(generic::NSEC3("2 0 12 aabbccdd " +
+                                                 string(nsec3_common))));
+    // Iterations doesn't match
+    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 1 aabbccdd " +
+                                                 string(nsec3_common))));
+    // Salt doesn't match
+    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 12 aabbccde " +
+                                                 string(nsec3_common))));
+    // Salt doesn't match: the other has an empty salt
+    EXPECT_FALSE(test_hash->match(generic::NSEC3("1 0 12 - " +
+                                                 string(nsec3_common))));
+    // Flags doesn't matter
+    EXPECT_TRUE(test_hash->match(generic::NSEC3("1 1 12 aabbccdd " +
+                                                 string(nsec3_common))));
+}
+
 } // end namespace