Parcourir la source

[trac893] updated TSIGKeyRing so that its find() method takes key name and
algorithm and matches both

JINMEI Tatuya il y a 14 ans
Parent
commit
a91fc737dd

+ 23 - 8
src/lib/dns/python/tests/tsigkey_python_test.py

@@ -68,6 +68,9 @@ class TSIGKeyTest(unittest.TestCase):
 
 class TSIGKeyRingTest(unittest.TestCase):
     key_name = Name('example.com')
+    md5_name = Name('hmac-md5.sig-alg.reg.int')
+    sha1_name = Name('hmac-sha1')
+    sha256_name = Name('hmac-sha256')
     secret = b'someRandomData'
 
     def setUp(self):
@@ -152,18 +155,26 @@ class TSIGKeyRingTest(unittest.TestCase):
 
     def test_find(self):
         self.assertEqual((TSIGKeyRing.NOTFOUND, None),
-                         self.keyring.find(self.key_name))
+                         self.keyring.find(self.key_name, self.md5_name))
 
         self.assertEqual(TSIGKeyRing.SUCCESS,
                          self.keyring.add(TSIGKey(self.key_name,
-                                                  TSIGKey.HMACSHA256_NAME,
+                                                  self.sha256_name,
                                                   self.secret)))
-        (code, key) = self.keyring.find(self.key_name)
+        (code, key) = self.keyring.find(self.key_name, self.sha256_name)
         self.assertEqual(TSIGKeyRing.SUCCESS, code)
         self.assertEqual(self.key_name, key.get_key_name())
         self.assertEqual(TSIGKey.HMACSHA256_NAME, key.get_algorithm_name())
         self.assertEqual(self.secret, key.get_secret())
 
+        (code, key) = self.keyring.find(Name('different-key.example'),
+                                        self.sha256_name)
+        self.assertEqual(TSIGKeyRing.NOTFOUND, code)
+        self.assertEqual(None, key)
+        (code, key) = self.keyring.find(self.key_name, self.md5_name)
+        self.assertEqual(TSIGKeyRing.NOTFOUND, code)
+        self.assertEqual(None, key)
+
         self.assertRaises(TypeError, self.keyring.find, 1)
         self.assertRaises(TypeError, self.keyring.find, 'should be a name')
         self.assertRaises(TypeError, self.keyring.find, self.key_name, 0)
@@ -171,24 +182,28 @@ class TSIGKeyRingTest(unittest.TestCase):
     def test_find_from_some(self):
         self.assertEqual(TSIGKeyRing.SUCCESS,
                          self.keyring.add(TSIGKey(self.key_name,
-                                                  TSIGKey.HMACSHA256_NAME,
+                                                  self.sha256_name,
                                                   self.secret)))
         self.assertEqual(TSIGKeyRing.SUCCESS,
                          self.keyring.add(TSIGKey(Name('another.example'),
-                                                  TSIGKey.HMACMD5_NAME,
+                                                  self.md5_name,
                                                   self.secret)))
         self.assertEqual(TSIGKeyRing.SUCCESS,
                          self.keyring.add(TSIGKey(Name('more.example'),
-                                                  TSIGKey.HMACSHA1_NAME,
+                                                  self.sha1_name,
                                                   self.secret)))
 
-        (code, key) = self.keyring.find(Name('another.example'))
+        (code, key) = self.keyring.find(Name('another.example'), self.md5_name)
         self.assertEqual(TSIGKeyRing.SUCCESS, code)
         self.assertEqual(Name('another.example'), key.get_key_name())
         self.assertEqual(TSIGKey.HMACMD5_NAME, key.get_algorithm_name())
 
         self.assertEqual((TSIGKeyRing.NOTFOUND, None),
-                         self.keyring.find(Name('noexist.example')))
+                         self.keyring.find(Name('noexist.example'),
+                                           self.sha1_name))
+        self.assertEqual((TSIGKeyRing.NOTFOUND, None),
+                         self.keyring.find(Name('another.example'),
+                                           self.sha1_name))
 
 if __name__ == '__main__':
     unittest.main()

+ 4 - 2
src/lib/dns/python/tsigkey_python.cc

@@ -416,10 +416,12 @@ TSIGKeyRing_remove(const s_TSIGKeyRing* self, PyObject* args) {
 PyObject*
 TSIGKeyRing_find(const s_TSIGKeyRing* self, PyObject* args) {
     s_Name* key_name;
+    s_Name* algorithm_name;
 
-    if (PyArg_ParseTuple(args, "O!", &name_type, &key_name)) {
+    if (PyArg_ParseTuple(args, "O!O!", &name_type, &key_name,
+                         &name_type, &algorithm_name)) {
         const TSIGKeyRing::FindResult result =
-            self->keyring->find(*key_name->name);
+            self->keyring->find(*key_name->name, *algorithm_name->name);
         if (result.key != NULL) {
             s_TSIGKey* key = PyObject_New(s_TSIGKey, &tsigkey_type);
             if (key == NULL) {

+ 48 - 26
src/lib/dns/tests/tsigkey_unittest.cc

@@ -33,7 +33,7 @@ class TSIGKeyTest : public ::testing::Test {
 protected:
     TSIGKeyTest() : secret("someRandomData"), key_name("example.com") {}
     string secret;
-    Name key_name;
+    const Name key_name;
 };
 
 TEST_F(TSIGKeyTest, algorithmNames) {
@@ -125,12 +125,18 @@ class TSIGKeyRingTest : public ::testing::Test {
 protected:
     TSIGKeyRingTest() :
         key_name("example.com"),
+        md5_name("hmac-md5.sig-alg.reg.int"),
+        sha1_name("hmac-sha1"),
+        sha256_name("hmac-sha256"),
         secretstring("anotherRandomData"),
         secret(secretstring.c_str()),
         secret_len(secretstring.size())
     {}
     TSIGKeyRing keyring;
-    Name key_name;
+    const Name key_name;
+    const Name md5_name;
+    const Name sha1_name;
+    const Name sha256_name;
 private:
     const string secretstring;
 protected:
@@ -203,44 +209,60 @@ TEST_F(TSIGKeyRingTest, removeFromSome) {
 }
 
 TEST_F(TSIGKeyRingTest, find) {
-    EXPECT_EQ(TSIGKeyRing::NOTFOUND, keyring.find(key_name).code);
-    EXPECT_EQ(static_cast<const TSIGKey*>(NULL), keyring.find(key_name).key);
-
-    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(
-                  TSIGKey(key_name, TSIGKey::HMACSHA256_NAME(),
-                          secret, secret_len)));
-    const TSIGKeyRing::FindResult result(keyring.find(key_name));
-    EXPECT_EQ(TSIGKeyRing::SUCCESS, result.code);
-    EXPECT_EQ(key_name, result.key->getKeyName());
-    EXPECT_EQ(TSIGKey::HMACSHA256_NAME(), result.key->getAlgorithmName());
+    // If the keyring is empty the search should fail.
+    EXPECT_EQ(TSIGKeyRing::NOTFOUND, keyring.find(key_name, md5_name).code);
+    EXPECT_EQ(static_cast<const TSIGKey*>(NULL),
+              keyring.find(key_name, md5_name).key);
+
+    // Add a key and try to find it.  Should succeed.
+    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(TSIGKey(key_name, sha256_name,
+                                                        secret, secret_len)));
+    const TSIGKeyRing::FindResult result1(keyring.find(key_name, sha256_name));
+    EXPECT_EQ(TSIGKeyRing::SUCCESS, result1.code);
+    EXPECT_EQ(key_name, result1.key->getKeyName());
+    EXPECT_EQ(TSIGKey::HMACSHA256_NAME(), result1.key->getAlgorithmName());
     EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, secret, secret_len,
-                        result.key->getSecret(),
-                        result.key->getSecretLength());
+                        result1.key->getSecret(),
+                        result1.key->getSecretLength());
+
+    // If either key name or algorithm doesn't match, search should fail.
+    const TSIGKeyRing::FindResult result2 =
+        keyring.find(Name("different-key.example"), sha256_name);
+    EXPECT_EQ(TSIGKeyRing::NOTFOUND, result2.code);
+    EXPECT_EQ(static_cast<const TSIGKey*>(NULL), result2.key);
+
+    const TSIGKeyRing::FindResult result3 = keyring.find(key_name, md5_name);
+    EXPECT_EQ(TSIGKeyRing::NOTFOUND, result3.code);
+    EXPECT_EQ(static_cast<const TSIGKey*>(NULL), result3.key);
 }
 
 TEST_F(TSIGKeyRingTest, findFromSome) {
     // essentially the same test, but search a larger set
 
-    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(
-                  TSIGKey(key_name, TSIGKey::HMACSHA256_NAME(),
-                          secret, secret_len)));
-    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(
-                  TSIGKey(Name("another.example"), TSIGKey::HMACMD5_NAME(),
-                          secret, secret_len)));
-    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(
-                  TSIGKey(Name("more.example"), TSIGKey::HMACSHA1_NAME(),
-                          secret, secret_len)));
+    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(TSIGKey(key_name, sha256_name,
+                                                        secret, secret_len)));
+    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(TSIGKey(Name("another.example"),
+                                                        md5_name,
+                                                        secret, secret_len)));
+    EXPECT_EQ(TSIGKeyRing::SUCCESS, keyring.add(TSIGKey(Name("more.example"),
+                                                        sha1_name,
+                                                        secret, secret_len)));
 
     const TSIGKeyRing::FindResult result(
-        keyring.find(Name("another.example")));
+        keyring.find(Name("another.example"), md5_name));
     EXPECT_EQ(TSIGKeyRing::SUCCESS, result.code);
     EXPECT_EQ(Name("another.example"), result.key->getKeyName());
     EXPECT_EQ(TSIGKey::HMACMD5_NAME(), result.key->getAlgorithmName());
 
     EXPECT_EQ(TSIGKeyRing::NOTFOUND,
-              keyring.find(Name("noexist.example")).code);
+              keyring.find(Name("noexist.example"), sha1_name).code);
+    EXPECT_EQ(static_cast<const TSIGKey*>(NULL),
+              keyring.find(Name("noexist.example"), sha256_name).key);
+
+    EXPECT_EQ(TSIGKeyRing::NOTFOUND,
+              keyring.find(Name("another.example"), sha1_name).code);
     EXPECT_EQ(static_cast<const TSIGKey*>(NULL),
-              keyring.find(Name("noexist.example")).key);
+              keyring.find(Name("another.example"), sha256_name).key);
 }
 
 TEST(TSIGStringTest, TSIGKeyFromToString) {

+ 3 - 2
src/lib/dns/tsigkey.cc

@@ -230,10 +230,11 @@ TSIGKeyRing::remove(const Name& key_name) {
 }
 
 TSIGKeyRing::FindResult
-TSIGKeyRing::find(const Name& key_name) {
+TSIGKeyRing::find(const Name& key_name, const Name& algorithm_name) {
     TSIGKeyRingImpl::TSIGKeyMap::const_iterator found =
         impl_->keys.find(key_name);
-    if (found == impl_->keys.end()) {
+    if (found == impl_->keys.end() ||
+        (*found).second.getAlgorithmName() != algorithm_name) {
         return (FindResult(NOTFOUND, NULL));
     }
     return (FindResult(SUCCESS, &((*found).second)));

+ 5 - 2
src/lib/dns/tsigkey.h

@@ -304,7 +304,9 @@ public:
     /// Find a \c TSIGKey for the given name in the \c TSIGKeyRing.
     ///
     /// It searches the internal storage for a \c TSIGKey whose name is
-    /// \c key_name, and returns the result in the form of a \c FindResult
+    /// \c key_name and that uses the hash algorithm identified by
+    /// \c algorithm_name.
+    /// It returns the result in the form of a \c FindResult
     /// object as follows:
     /// - \c code: \c SUCCESS if a key is found; otherwise \c NOTFOUND.
     /// - \c key: A pointer to the found \c TSIGKey object if one is found;
@@ -318,8 +320,9 @@ public:
     /// This method never throws an exception.
     ///
     /// \param key_name The name of the key to be found.
+    /// \param algorithm_name The name of the algorithm of the found key.
     /// \return A \c FindResult object enclosing the search result (see above).
-    FindResult find(const Name& key_name);
+    FindResult find(const Name& key_name, const Name& algorithm_name);
 private:
     struct TSIGKeyRingImpl;
     TSIGKeyRingImpl* impl_;