Parcourir la source

[2018] replace diff.get_updater by find() and find_all()

Jelte Jansen il y a 13 ans
Parent
commit
38292a6182

+ 28 - 34
src/lib/python/isc/ddns/session.py

@@ -247,12 +247,6 @@ class UpdateSession:
             self.__diff = isc.xfrin.diff.Diff(datasrc_client, zname,
                                               journaling=True,
                                               single_update_mode=True)
-            # Note that while it is really the ZoneUpdater that is set
-            # here, it is still called finder, as the only methods that
-            # are and should be used on this object are find() and find_all()
-            # (ZoneUpdater provides the ZoneFinder interface itself, no
-            # separate get_zone_finder())
-            self.__finder = self.__diff.get_updater()
             self.__zname = zname
             self.__zclass = zclass
             self.__datasrc_client = datasrc_client
@@ -313,9 +307,9 @@ class UpdateSession:
            only return what the result code would be (and not read/copy
            any actual data).
         '''
-        result, _, _ = self.__finder.find(rrset.get_name(), rrset.get_type(),
-                                          ZoneFinder.NO_WILDCARD |
-                                          ZoneFinder.FIND_GLUE_OK)
+        result, _, _ = self.__diff.find(rrset.get_name(), rrset.get_type(),
+                                        ZoneFinder.NO_WILDCARD |
+                                        ZoneFinder.FIND_GLUE_OK)
         return result == ZoneFinder.SUCCESS
 
     def __prereq_rrset_exists_value(self, rrset):
@@ -324,10 +318,10 @@ class UpdateSession:
            RFC2136 Section 2.4.2
            Returns True if the prerequisite is satisfied, False otherwise.
         '''
-        result, found_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                    rrset.get_type(),
-                                                    ZoneFinder.NO_WILDCARD |
-                                                    ZoneFinder.FIND_GLUE_OK)
+        result, found_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                  rrset.get_type(),
+                                                  ZoneFinder.NO_WILDCARD |
+                                                  ZoneFinder.FIND_GLUE_OK)
         if result == ZoneFinder.SUCCESS and\
            rrset.get_name() == found_rrset.get_name() and\
            rrset.get_type() == found_rrset.get_type():
@@ -366,9 +360,9 @@ class UpdateSession:
            to only return what the result code would be (and not read/copy
            any actual data).
         '''
-        result, rrsets, flags = self.__finder.find_all(rrset.get_name(),
-                                                       ZoneFinder.NO_WILDCARD |
-                                                       ZoneFinder.FIND_GLUE_OK)
+        result, rrsets, flags = self.__diff.find_all(rrset.get_name(),
+                                                     ZoneFinder.NO_WILDCARD |
+                                                     ZoneFinder.FIND_GLUE_OK)
         if result == ZoneFinder.SUCCESS and\
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
             return True
@@ -592,10 +586,10 @@ class UpdateSession:
         # is explicitely ignored here)
         if rrset.get_type() == RRType.SOA():
             return
-        result, orig_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                   rrset.get_type(),
-                                                   ZoneFinder.NO_WILDCARD |
-                                                   ZoneFinder.FIND_GLUE_OK)
+        result, orig_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                 rrset.get_type(),
+                                                 ZoneFinder.NO_WILDCARD |
+                                                 ZoneFinder.FIND_GLUE_OK)
         if result == ZoneFinder.CNAME:
             # Ignore non-cname rrs that try to update CNAME records
             # (if rrset itself is a CNAME, the finder result would be
@@ -626,10 +620,10 @@ class UpdateSession:
            Special cases: if the delete statement is for the
            zone's apex, and the type is either SOA or NS, it
            is ignored.'''
-        result, to_delete, _ = self.__finder.find(rrset.get_name(),
-                                                  rrset.get_type(),
-                                                  ZoneFinder.NO_WILDCARD |
-                                                  ZoneFinder.FIND_GLUE_OK)
+        result, to_delete, _ = self.__diff.find(rrset.get_name(),
+                                                rrset.get_type(),
+                                                ZoneFinder.NO_WILDCARD |
+                                                ZoneFinder.FIND_GLUE_OK)
         if result == ZoneFinder.SUCCESS:
             if to_delete.get_name() == self.__zname and\
                (to_delete.get_type() == RRType.SOA() or\
@@ -651,10 +645,10 @@ class UpdateSession:
         # (see ticket #2016)
         # The related test is currently disabled. When this is fixed,
         # enable that test again.
-        result, orig_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                   rrset.get_type(),
-                                                   ZoneFinder.NO_WILDCARD |
-                                                   ZoneFinder.FIND_GLUE_OK)
+        result, orig_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                 rrset.get_type(),
+                                                 ZoneFinder.NO_WILDCARD |
+                                                 ZoneFinder.FIND_GLUE_OK)
         # Even a real rrset comparison wouldn't help here...
         # The goal is to make sure that after deletion of the
         # given rrset, at least 1 NS record is left (at the apex).
@@ -684,9 +678,9 @@ class UpdateSession:
            Special case: if the name is the zone's apex, SOA and
            NS records are kept.
         '''
-        result, rrsets, flags = self.__finder.find_all(rrset.get_name(),
-                                                       ZoneFinder.NO_WILDCARD |
-                                                       ZoneFinder.FIND_GLUE_OK)
+        result, rrsets, flags = self.__diff.find_all(rrset.get_name(),
+                                                     ZoneFinder.NO_WILDCARD |
+                                                     ZoneFinder.FIND_GLUE_OK)
         if result == ZoneFinder.SUCCESS and\
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
             for to_delete in rrsets:
@@ -736,9 +730,9 @@ class UpdateSession:
         # serial magic and add the newly created one
 
         # get it from DS and to increment and stuff
-        result, old_soa, _ = self.__finder.find(self.__zname, RRType.SOA(),
-                                                ZoneFinder.NO_WILDCARD |
-                                                ZoneFinder.FIND_GLUE_OK)
+        result, old_soa, _ = self.__diff.find(self.__zname, RRType.SOA(),
+                                              ZoneFinder.NO_WILDCARD |
+                                              ZoneFinder.FIND_GLUE_OK)
 
         if self.__added_soa is not None:
             new_soa = self.__added_soa

+ 19 - 7
src/lib/python/isc/xfrin/diff.py

@@ -377,12 +377,24 @@ class Diff:
         else:
             return (self.__deletions, self.__additions)
 
-    def get_updater(self):
+    def find(self, name, rrtype, options=isc.datasrc.ZoneFinder.FIND_DEFAULT):
         """
-        Returns the ZoneUpdater associated with this Diff instance.
-        While update statements can be used on this updater, its main
-        goal is to provide the ZoneFinder interface for searching through
-        the zone as it was on the moment the updater was created.
-        If the Diff has been committed, this will return None.
+        Calls the find() method in the ZoneFinder associated with this
+        Diff's ZoneUpdater, i.e. the find() on the zone as it was on the
+        moment this Diff object got created.
+        See the ZoneFinder documentation for a full description.
+        Note that the result does not include changes made in this Diff
+        instance so far.
         """
-        return self.__updater
+        return self.__updater.find(name, rrtype, options)
+
+    def find_all(self, name, options=isc.datasrc.ZoneFinder.FIND_DEFAULT):
+        """
+        Calls the find() method in the ZoneFinder associated with this
+        Diff's ZoneUpdater, i.e. the find_all() on the zone as it was on the
+        moment this Diff object got created.
+        See the ZoneFinder documentation for a full description.
+        Note that the result does not include changes made in this Diff
+        instance so far.
+        """
+        return self.__updater.find_all(name, options)

+ 86 - 0
src/lib/python/isc/xfrin/tests/diff_tests.py

@@ -48,6 +48,13 @@ class DiffTest(unittest.TestCase):
         self.__broken_called = False
         self.__warn_called = False
         self.__should_replace = False
+        self.__find_called = False
+        self.__find_name = None
+        self.__find_type = None
+        self.__find_options = None
+        self.__find_all_called = False
+        self.__find_all_name = None
+        self.__find_all_options = None
         # Some common values
         self.__rrclass = RRClass.IN()
         self.__type = RRType.A()
@@ -156,6 +163,23 @@ class DiffTest(unittest.TestCase):
 
         return self
 
+    def find(self, name, rrtype, options=None):
+        self.__find_called = True
+        self.__find_name = name
+        self.__find_type = rrtype
+        self.__find_options = options
+        # Doesn't really matter what is returned, as long
+        # as the test can check that it's passed along
+        return "find_return"
+
+    def find_all(self, name, options=None):
+        self.__find_all_called = True
+        self.__find_all_name = name
+        self.__find_all_options = options
+        # Doesn't really matter what is returned, as long
+        # as the test can check that it's passed along
+        return "find_all_return"
+
     def test_create(self):
         """
         This test the case when the diff is successfuly created. It just
@@ -587,6 +611,68 @@ class DiffTest(unittest.TestCase):
         self.assertRaises(ValueError, diff.add_data, a)
         self.assertRaises(ValueError, diff.delete_data, a)
 
+    def test_find(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+        rrtype = RRType.A()
+
+        self.assertFalse(self.__find_called)
+        self.assertEqual(None, self.__find_name)
+        self.assertEqual(None, self.__find_type)
+        self.assertEqual(None, self.__find_options)
+
+        self.assertEqual("find_return", diff.find(name, rrtype))
+
+        self.assertTrue(self.__find_called)
+        self.assertEqual(name, self.__find_name)
+        self.assertEqual(rrtype, self.__find_type)
+        self.assertEqual(isc.datasrc.ZoneFinder.FIND_DEFAULT,
+                         self.__find_options)
+
+    def test_find_options(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('foo.example.org.')
+        rrtype = RRType.TXT()
+        options = isc.datasrc.ZoneFinder.NO_WILDCARD |\
+                  isc.datasrc.ZoneFinder.FIND_GLUE_OK
+
+        self.assertEqual("find_return", diff.find(name, rrtype, options))
+
+        self.assertTrue(self.__find_called)
+        self.assertEqual(name, self.__find_name)
+        self.assertEqual(rrtype, self.__find_type)
+        self.assertEqual(options, self.__find_options)
+
+    def test_find_all(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+
+        self.assertFalse(self.__find_all_called)
+        self.assertEqual(None, self.__find_all_name)
+        self.assertEqual(None, self.__find_all_options)
+
+        self.assertEqual("find_all_return", diff.find_all(name))
+
+        self.assertTrue(self.__find_all_called)
+        self.assertEqual(name, self.__find_all_name)
+        self.assertEqual(isc.datasrc.ZoneFinder.FIND_DEFAULT,
+                         self.__find_all_options)
+
+    def test_find_all_options(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+        options = isc.datasrc.ZoneFinder.NO_WILDCARD |\
+                  isc.datasrc.ZoneFinder.FIND_GLUE_OK
+
+        self.assertFalse(self.__find_all_called)
+        self.assertEqual(None, self.__find_all_name)
+        self.assertEqual(None, self.__find_all_options)
+
+        self.assertEqual("find_all_return", diff.find_all(name, options))
+
+        self.assertTrue(self.__find_all_called)
+        self.assertEqual(name, self.__find_all_name)
+        self.assertEqual(options, self.__find_all_options)
 
 if __name__ == "__main__":
     isc.log.init("bind10")