Browse Source

[master] Merge branch 'trac2018' with trivial conflicts

Conflicts:
	src/lib/python/isc/ddns/session.py
	src/lib/python/isc/ddns/tests/session_tests.py
Jelte Jansen 13 years ago
parent
commit
9b2f2c4d4f

+ 1 - 1
src/lib/python/isc/ddns/libddns_messages.mes

@@ -205,7 +205,7 @@ should give more information on what prerequisite type failed.
 If the result code is FORMERR, the prerequisite section was not well-formed.
 If the result code is FORMERR, the prerequisite section was not well-formed.
 An error response with the given result code is sent back to the client.
 An error response with the given result code is sent back to the client.
 
 
-% LIBDDNS_UPDATE_UNCAUGHT_EXCEPTION update client %1 for zone %2: uncaught exception while processing update section: %1
+% LIBDDNS_UPDATE_UNCAUGHT_EXCEPTION update client %1 for zone %2: uncaught exception while processing update section: %3
 An uncaught exception was encountered while processing the Update
 An uncaught exception was encountered while processing the Update
 section of a DDNS message. The specific exception is shown in the log message.
 section of a DDNS message. The specific exception is shown in the log message.
 To make sure DDNS service is not interrupted, this problem is caught instead
 To make sure DDNS service is not interrupted, this problem is caught instead

+ 70 - 67
src/lib/python/isc/ddns/session.py

@@ -189,7 +189,8 @@ class UpdateSession:
 
 
         '''
         '''
         try:
         try:
-            self.__get_update_zone()
+            self._get_update_zone()
+            self._create_diff()
             prereq_result = self.__check_prerequisites()
             prereq_result = self.__check_prerequisites()
             if prereq_result != Rcode.NOERROR():
             if prereq_result != Rcode.NOERROR():
                 self.__make_response(prereq_result)
                 self.__make_response(prereq_result)
@@ -219,7 +220,7 @@ class UpdateSession:
             self.__make_response(Rcode.SERVFAIL())
             self.__make_response(Rcode.SERVFAIL())
             return UPDATE_ERROR, None, None
             return UPDATE_ERROR, None, None
 
 
-    def __get_update_zone(self):
+    def _get_update_zone(self):
         '''Parse the zone section and find the zone to be updated.
         '''Parse the zone section and find the zone to be updated.
 
 
         If the zone section is valid and the specified zone is found in
         If the zone section is valid and the specified zone is found in
@@ -228,8 +229,11 @@ class UpdateSession:
                           zone
                           zone
         __zname: The zone name as a Name object
         __zname: The zone name as a Name object
         __zclass: The zone class as an RRClass object
         __zclass: The zone class as an RRClass object
-        __finder: A ZoneFinder for this zone
-        If this method raises an exception, these members are not set
+        If this method raises an exception, these members are not set.
+
+        Note: This method is protected for ease of use in tests, where
+        methods are tested that need the setup done here without calling
+        the full handle() method.
         '''
         '''
         # Validation: the zone section must contain exactly one question,
         # Validation: the zone section must contain exactly one question,
         # and it must be of type SOA.
         # and it must be of type SOA.
@@ -247,10 +251,9 @@ class UpdateSession:
         zclass = zrecord.get_class()
         zclass = zrecord.get_class()
         zone_type, datasrc_client = self.__zone_config.find_zone(zname, zclass)
         zone_type, datasrc_client = self.__zone_config.find_zone(zname, zclass)
         if zone_type == isc.ddns.zone_config.ZONE_PRIMARY:
         if zone_type == isc.ddns.zone_config.ZONE_PRIMARY:
-            _, self.__finder = datasrc_client.find_zone(zname)
+            self.__datasrc_client = datasrc_client
             self.__zname = zname
             self.__zname = zname
             self.__zclass = zclass
             self.__zclass = zclass
-            self.__datasrc_client = datasrc_client
             return
             return
         elif zone_type == isc.ddns.zone_config.ZONE_SECONDARY:
         elif zone_type == isc.ddns.zone_config.ZONE_SECONDARY:
             # We are a secondary server; since we don't yet support update
             # We are a secondary server; since we don't yet support update
@@ -265,6 +268,26 @@ class UpdateSession:
                      ZoneFormatter(zname, zclass))
                      ZoneFormatter(zname, zclass))
         raise UpdateError('notauth', zname, zclass, Rcode.NOTAUTH(), True)
         raise UpdateError('notauth', zname, zclass, Rcode.NOTAUTH(), True)
 
 
+    def _create_diff(self):
+        '''
+        Initializes the internal data structure used for searching current
+        data and for adding and deleting data. This is supposed to be called
+        after ACL checks but before prerequisite checks (since the latter
+        needs the find calls provided by the Diff class).
+        Adds the private member:
+        __diff: A buffer of changes made against the zone by this update
+                This object also contains find() calls, see documentation
+                of the Diff class.
+
+        Note: This method is protected for ease of use in tests, where
+        methods are tested that need the setup done here without calling
+        the full handle() method.
+        '''
+        self.__diff = isc.xfrin.diff.Diff(self.__datasrc_client,
+                                          self.__zname,
+                                          journaling=True,
+                                          single_update_mode=True)
+
     def __check_update_acl(self, zname, zclass):
     def __check_update_acl(self, zname, zclass):
         '''Apply update ACL for the zone to be updated.'''
         '''Apply update ACL for the zone to be updated.'''
         acl = self.__zone_config.get_update_acl(zname, zclass)
         acl = self.__zone_config.get_update_acl(zname, zclass)
@@ -308,9 +331,7 @@ class UpdateSession:
            only return what the result code would be (and not read/copy
            only return what the result code would be (and not read/copy
            any actual data).
            any actual data).
         '''
         '''
-        result, _, _ = self.__finder.find(rrset.get_name(), rrset.get_type(),
-                                          ZoneFinder.NO_WILDCARD |
-                                          ZoneFinder.FIND_GLUE_OK)
+        result, _, _ = self.__diff.find(rrset.get_name(), rrset.get_type())
         return result == ZoneFinder.SUCCESS
         return result == ZoneFinder.SUCCESS
 
 
     def __prereq_rrset_exists_value(self, rrset):
     def __prereq_rrset_exists_value(self, rrset):
@@ -319,10 +340,8 @@ class UpdateSession:
            RFC2136 Section 2.4.2
            RFC2136 Section 2.4.2
            Returns True if the prerequisite is satisfied, False otherwise.
            Returns True if the prerequisite is satisfied, False otherwise.
         '''
         '''
-        result, found_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                    rrset.get_type(),
-                                                    ZoneFinder.NO_WILDCARD |
-                                                    ZoneFinder.FIND_GLUE_OK)
+        result, found_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                  rrset.get_type())
         if result == ZoneFinder.SUCCESS and\
         if result == ZoneFinder.SUCCESS and\
            rrset.get_name() == found_rrset.get_name() and\
            rrset.get_name() == found_rrset.get_name() and\
            rrset.get_type() == found_rrset.get_type():
            rrset.get_type() == found_rrset.get_type():
@@ -361,9 +380,7 @@ class UpdateSession:
            to only return what the result code would be (and not read/copy
            to only return what the result code would be (and not read/copy
            any actual data).
            any actual data).
         '''
         '''
-        result, rrsets, flags = self.__finder.find_all(rrset.get_name(),
-                                                       ZoneFinder.NO_WILDCARD |
-                                                       ZoneFinder.FIND_GLUE_OK)
+        result, rrsets, flags = self.__diff.find_all(rrset.get_name())
         if result == ZoneFinder.SUCCESS and\
         if result == ZoneFinder.SUCCESS and\
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
             return True
             return True
@@ -556,20 +573,20 @@ class UpdateSession:
                 return Rcode.FORMERR()
                 return Rcode.FORMERR()
         return Rcode.NOERROR()
         return Rcode.NOERROR()
 
 
-    def __do_update_add_single_rr(self, diff, rr, existing_rrset):
+    def __do_update_add_single_rr(self, rr, existing_rrset):
         '''Helper for __do_update_add_rrs_to_rrset: only add the
         '''Helper for __do_update_add_rrs_to_rrset: only add the
            rr if it is not present yet
            rr if it is not present yet
            (note that rr here should already be a single-rr rrset)
            (note that rr here should already be a single-rr rrset)
         '''
         '''
         if existing_rrset is None:
         if existing_rrset is None:
-            diff.add_data(rr)
+            self.__diff.add_data(rr)
         else:
         else:
             rr_rdata = rr.get_rdata()[0]
             rr_rdata = rr.get_rdata()[0]
             if not rr_rdata in existing_rrset.get_rdata():
             if not rr_rdata in existing_rrset.get_rdata():
-                diff.add_data(rr)
+                self.__diff.add_data(rr)
 
 
-    def __do_update_add_rrs_to_rrset(self, diff, rrset):
-        '''Add the rrs from the given rrset to the diff.
+    def __do_update_add_rrs_to_rrset(self, rrset):
+        '''Add the rrs from the given rrset to the internal diff.
            There is handling for a number of special cases mentioned
            There is handling for a number of special cases mentioned
            in RFC2136;
            in RFC2136;
            - If the addition is a CNAME, but existing data at its
            - If the addition is a CNAME, but existing data at its
@@ -587,11 +604,9 @@ class UpdateSession:
         # is explicitely ignored here)
         # is explicitely ignored here)
         if rrset.get_type() == RRType.SOA():
         if rrset.get_type() == RRType.SOA():
             return
             return
-        result, orig_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                   rrset.get_type(),
-                                                   ZoneFinder.NO_WILDCARD |
-                                                   ZoneFinder.FIND_GLUE_OK)
-        if result == self.__finder.CNAME:
+        result, orig_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                 rrset.get_type())
+        if result == ZoneFinder.CNAME:
             # Ignore non-cname rrs that try to update CNAME records
             # Ignore non-cname rrs that try to update CNAME records
             # (if rrset itself is a CNAME, the finder result would be
             # (if rrset itself is a CNAME, the finder result would be
             # SUCCESS, see next case)
             # SUCCESS, see next case)
@@ -601,7 +616,7 @@ class UpdateSession:
             if rrset.get_type() == RRType.CNAME():
             if rrset.get_type() == RRType.CNAME():
                 # Remove original CNAME record (the new one
                 # Remove original CNAME record (the new one
                 # is added below)
                 # is added below)
-                diff.delete_data(orig_rrset)
+                self.__diff.delete_data(orig_rrset)
             # We do not have WKS support at this time, but if there
             # We do not have WKS support at this time, but if there
             # are special Update equality rules such as for WKS, and
             # are special Update equality rules such as for WKS, and
             # we do have support for the type, this is where the check
             # we do have support for the type, this is where the check
@@ -612,19 +627,17 @@ class UpdateSession:
             if rrset.get_type() == RRType.CNAME():
             if rrset.get_type() == RRType.CNAME():
                 return
                 return
         for rr in foreach_rr(rrset):
         for rr in foreach_rr(rrset):
-            self.__do_update_add_single_rr(diff, rr, orig_rrset)
+            self.__do_update_add_single_rr(rr, orig_rrset)
 
 
-    def __do_update_delete_rrset(self, diff, rrset):
+    def __do_update_delete_rrset(self, rrset):
         '''Deletes the rrset with the name and type of the given
         '''Deletes the rrset with the name and type of the given
            rrset from the zone data (by putting all existing data
            rrset from the zone data (by putting all existing data
-           in the given diff as delete statements).
+           in the internal diff as delete statements).
            Special cases: if the delete statement is for the
            Special cases: if the delete statement is for the
            zone's apex, and the type is either SOA or NS, it
            zone's apex, and the type is either SOA or NS, it
            is ignored.'''
            is ignored.'''
-        result, to_delete, _ = self.__finder.find(rrset.get_name(),
-                                                  rrset.get_type(),
-                                                  ZoneFinder.NO_WILDCARD |
-                                                  ZoneFinder.FIND_GLUE_OK)
+        result, to_delete, _ = self.__diff.find(rrset.get_name(),
+                                                rrset.get_type())
         if result == ZoneFinder.SUCCESS:
         if result == ZoneFinder.SUCCESS:
             if to_delete.get_name() == self.__zname and\
             if to_delete.get_name() == self.__zname and\
                (to_delete.get_type() == RRType.SOA() or\
                (to_delete.get_type() == RRType.SOA() or\
@@ -632,9 +645,9 @@ class UpdateSession:
                 # ignore
                 # ignore
                 return
                 return
             for rr in foreach_rr(to_delete):
             for rr in foreach_rr(to_delete):
-                diff.delete_data(rr)
+                self.__diff.delete_data(rr)
 
 
-    def __ns_deleter_helper(self, diff, rrset):
+    def __ns_deleter_helper(self, rrset):
         '''Special case helper for deleting NS resource records
         '''Special case helper for deleting NS resource records
            at the zone apex. In that scenario, the last NS record
            at the zone apex. In that scenario, the last NS record
            may never be removed (and any action that would do so
            may never be removed (and any action that would do so
@@ -646,10 +659,8 @@ class UpdateSession:
         # (see ticket #2016)
         # (see ticket #2016)
         # The related test is currently disabled. When this is fixed,
         # The related test is currently disabled. When this is fixed,
         # enable that test again.
         # enable that test again.
-        result, orig_rrset, _ = self.__finder.find(rrset.get_name(),
-                                                   rrset.get_type(),
-                                                   ZoneFinder.NO_WILDCARD |
-                                                   ZoneFinder.FIND_GLUE_OK)
+        result, orig_rrset, _ = self.__diff.find(rrset.get_name(),
+                                                 rrset.get_type())
         # Even a real rrset comparison wouldn't help here...
         # Even a real rrset comparison wouldn't help here...
         # The goal is to make sure that after deletion of the
         # The goal is to make sure that after deletion of the
         # given rrset, at least 1 NS record is left (at the apex).
         # given rrset, at least 1 NS record is left (at the apex).
@@ -670,18 +681,16 @@ class UpdateSession:
                                           rrset.get_ttl())
                                           rrset.get_ttl())
                 to_delete.add_rdata(rdata)
                 to_delete.add_rdata(rdata)
                 orig_rrset_rdata.remove(rdata)
                 orig_rrset_rdata.remove(rdata)
-                diff.delete_data(to_delete)
+                self.__diff.delete_data(to_delete)
 
 
-    def __do_update_delete_name(self, diff, rrset):
+    def __do_update_delete_name(self, rrset):
         '''Delete all data at the name of the given rrset,
         '''Delete all data at the name of the given rrset,
            by adding all data found by find_all as delete statements
            by adding all data found by find_all as delete statements
-           to the given diff.
+           to the internal diff.
            Special case: if the name is the zone's apex, SOA and
            Special case: if the name is the zone's apex, SOA and
            NS records are kept.
            NS records are kept.
         '''
         '''
-        result, rrsets, flags = self.__finder.find_all(rrset.get_name(),
-                                                       ZoneFinder.NO_WILDCARD |
-                                                       ZoneFinder.FIND_GLUE_OK)
+        result, rrsets, flags = self.__diff.find_all(rrset.get_name())
         if result == ZoneFinder.SUCCESS and\
         if result == ZoneFinder.SUCCESS and\
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
            (flags & ZoneFinder.RESULT_WILDCARD == 0):
             for to_delete in rrsets:
             for to_delete in rrsets:
@@ -692,9 +701,9 @@ class UpdateSession:
                     continue
                     continue
                 else:
                 else:
                     for rr in foreach_rr(to_delete):
                     for rr in foreach_rr(to_delete):
-                        diff.delete_data(rr)
+                        self.__diff.delete_data(rr)
 
 
-    def __do_update_delete_rrs_from_rrset(self, diff, rrset):
+    def __do_update_delete_rrs_from_rrset(self, rrset):
         '''Deletes all resource records in the given rrset from the
         '''Deletes all resource records in the given rrset from the
            zone. Resource records that do not exist are ignored.
            zone. Resource records that do not exist are ignored.
            If the rrset if of type SOA, it is ignored.
            If the rrset if of type SOA, it is ignored.
@@ -715,25 +724,23 @@ class UpdateSession:
             elif rrset.get_type() == RRType.NS():
             elif rrset.get_type() == RRType.NS():
                 # hmm. okay. annoying. There must be at least one left,
                 # hmm. okay. annoying. There must be at least one left,
                 # delegate to helper method
                 # delegate to helper method
-                self.__ns_deleter_helper(diff, to_delete)
+                self.__ns_deleter_helper(to_delete)
                 return
                 return
         for rr in foreach_rr(to_delete):
         for rr in foreach_rr(to_delete):
-            diff.delete_data(rr)
+            self.__diff.delete_data(rr)
 
 
-    def __update_soa(self, diff):
+    def __update_soa(self):
         '''Checks the member value __added_soa, and depending on
         '''Checks the member value __added_soa, and depending on
            whether it has been set and what its value is, creates
            whether it has been set and what its value is, creates
            a new SOA if necessary.
            a new SOA if necessary.
            Then removes the original SOA and adds the new one,
            Then removes the original SOA and adds the new one,
-           by adding the needed operations to the given diff.'''
+           by adding the needed operations to the internal diff.'''
         # Get the existing SOA
         # Get the existing SOA
         # if a new soa was specified, add that one, otherwise, do the
         # if a new soa was specified, add that one, otherwise, do the
         # serial magic and add the newly created one
         # serial magic and add the newly created one
 
 
         # get it from DS and to increment and stuff
         # get it from DS and to increment and stuff
-        result, old_soa, _ = self.__finder.find(self.__zname, RRType.SOA(),
-                                                ZoneFinder.NO_WILDCARD |
-                                                ZoneFinder.FIND_GLUE_OK)
+        result, old_soa, _ = self.__diff.find(self.__zname, RRType.SOA())
 
 
         if self.__added_soa is not None:
         if self.__added_soa is not None:
             new_soa = self.__added_soa
             new_soa = self.__added_soa
@@ -742,8 +749,8 @@ class UpdateSession:
             new_soa = old_soa
             new_soa = old_soa
             # increment goes here
             # increment goes here
 
 
-        diff.delete_data(old_soa)
-        diff.add_data(new_soa)
+        self.__diff.delete_data(old_soa)
+        self.__diff.add_data(new_soa)
 
 
     def __do_update(self):
     def __do_update(self):
         '''Scan, check, and execute the Update section in the
         '''Scan, check, and execute the Update section in the
@@ -758,12 +765,8 @@ class UpdateSession:
 
 
         # update
         # update
         try:
         try:
-            # create an ixfr-out-friendly diff structure to work on
-            diff = isc.xfrin.diff.Diff(self.__datasrc_client, self.__zname,
-                                       journaling=True, single_update_mode=True)
-
             # Do special handling for SOA first
             # Do special handling for SOA first
-            self.__update_soa(diff)
+            self.__update_soa()
 
 
             # Algorithm from RFC2136 Section 3.4
             # Algorithm from RFC2136 Section 3.4
             # Note that this works on full rrsets, not individual RRs.
             # Note that this works on full rrsets, not individual RRs.
@@ -777,16 +780,16 @@ class UpdateSession:
             # do_update statements)
             # do_update statements)
             for rrset in self.__message.get_section(SECTION_UPDATE):
             for rrset in self.__message.get_section(SECTION_UPDATE):
                 if rrset.get_class() == self.__zclass:
                 if rrset.get_class() == self.__zclass:
-                    self.__do_update_add_rrs_to_rrset(diff, rrset)
+                    self.__do_update_add_rrs_to_rrset(rrset)
                 elif rrset.get_class() == RRClass.ANY():
                 elif rrset.get_class() == RRClass.ANY():
                     if rrset.get_type() == RRType.ANY():
                     if rrset.get_type() == RRType.ANY():
-                        self.__do_update_delete_name(diff, rrset)
+                        self.__do_update_delete_name(rrset)
                     else:
                     else:
-                        self.__do_update_delete_rrset(diff, rrset)
+                        self.__do_update_delete_rrset(rrset)
                 elif rrset.get_class() == RRClass.NONE():
                 elif rrset.get_class() == RRClass.NONE():
-                    self.__do_update_delete_rrs_from_rrset(diff, rrset)
+                    self.__do_update_delete_rrs_from_rrset(rrset)
 
 
-            diff.commit()
+            self.__diff.commit()
             return Rcode.NOERROR()
             return Rcode.NOERROR()
         except isc.datasrc.Error as dse:
         except isc.datasrc.Error as dse:
             logger.info(LIBDDNS_UPDATE_DATASRC_ERROR, dse)
             logger.info(LIBDDNS_UPDATE_DATASRC_ERROR, dse)

+ 103 - 3
src/lib/python/isc/ddns/tests/session_tests.py

@@ -94,6 +94,97 @@ def create_rrset(name, rrclass, rrtype, ttl, rdatas = []):
         add_rdata(rrset, rdata)
         add_rdata(rrset, rdata)
     return rrset
     return rrset
 
 
+class SessionModuleTests(unittest.TestCase):
+    '''Tests for module-level functions in the session.py module'''
+
+    def test_foreach_rr_in_rrset(self):
+        rrset = create_rrset("www.example.org", TEST_RRCLASS,
+                             RRType.A(), 3600, [ "192.0.2.1" ])
+
+        l = []
+        for rr in foreach_rr(rrset):
+            l.append(str(rr))
+        self.assertEqual(["www.example.org. 3600 IN A 192.0.2.1\n"], l)
+
+        add_rdata(rrset, "192.0.2.2")
+        add_rdata(rrset, "192.0.2.3")
+
+        # but through the generator, there should be several 1-line entries
+        l = []
+        for rr in foreach_rr(rrset):
+            l.append(str(rr))
+        self.assertEqual(["www.example.org. 3600 IN A 192.0.2.1\n",
+                          "www.example.org. 3600 IN A 192.0.2.2\n",
+                          "www.example.org. 3600 IN A 192.0.2.3\n",
+                         ], l)
+
+    def test_convert_rrset_class(self):
+        # Converting an RRSET to a different class should work
+        # if the rdata types can be converted
+        rrset = create_rrset("www.example.org", RRClass.NONE(), RRType.A(),
+                             3600, [ b'\xc0\x00\x02\x01', b'\xc0\x00\x02\x02'])
+
+        rrset2 = convert_rrset_class(rrset, RRClass.IN())
+        self.assertEqual("www.example.org. 3600 IN A 192.0.2.1\n" +
+                         "www.example.org. 3600 IN A 192.0.2.2\n",
+                         str(rrset2))
+
+        rrset3 = convert_rrset_class(rrset2, RRClass.NONE())
+        self.assertEqual("www.example.org. 3600 CLASS254 A \\# 4 " +
+                         "c0000201\nwww.example.org. 3600 CLASS254 " +
+                         "A \\# 4 c0000202\n",
+                         str(rrset3))
+
+        # depending on what type of bad data is given, a number
+        # of different exceptions could be raised (TODO: i recall
+        # there was a ticket about making a better hierarchy for
+        # dns/parsing related exceptions)
+        self.assertRaises(InvalidRdataLength, convert_rrset_class,
+                          rrset, RRClass.CH())
+        add_rdata(rrset, b'\xc0\x00')
+        self.assertRaises(DNSMessageFORMERR, convert_rrset_class,
+                          rrset, RRClass.IN())
+
+    def test_collect_rrsets(self):
+        '''
+        Tests the 'rrset collector' method, which collects rrsets
+        with the same name and type
+        '''
+        collected = []
+
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.IN(),
+                                               RRType.A(), 0, [ "192.0.2.1" ]))
+        # Same name and class, different type
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.IN(),
+                                               RRType.TXT(), 0, [ "one" ]))
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.IN(),
+                                               RRType.A(), 0, [ "192.0.2.2" ]))
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.IN(),
+                                               RRType.TXT(), 0, [ "two" ]))
+        # Same class and type as an existing one, different name
+        collect_rrsets(collected, create_rrset("b.example.org", RRClass.IN(),
+                                               RRType.A(), 0, [ "192.0.2.3" ]))
+        # Same name and type as an existing one, different class
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.CH(),
+                                               RRType.TXT(), 0, [ "one" ]))
+        collect_rrsets(collected, create_rrset("b.example.org", RRClass.IN(),
+                                               RRType.A(), 0, [ "192.0.2.4" ]))
+        collect_rrsets(collected, create_rrset("a.example.org", RRClass.CH(),
+                                               RRType.TXT(), 0, [ "two" ]))
+
+        strings = [ rrset.to_text() for rrset in collected ]
+        # note + vs , in this list
+        expected = ['a.example.org. 0 IN A 192.0.2.1\n' +
+                    'a.example.org. 0 IN A 192.0.2.2\n',
+                    'a.example.org. 0 IN TXT "one"\n' +
+                    'a.example.org. 0 IN TXT "two"\n',
+                    'b.example.org. 0 IN A 192.0.2.3\n' +
+                    'b.example.org. 0 IN A 192.0.2.4\n',
+                    'a.example.org. 0 CH TXT "one"\n' +
+                    'a.example.org. 0 CH TXT "two"\n']
+
+        self.assertEqual(expected, strings)
+
 class SessionTestBase(unittest.TestCase):
 class SessionTestBase(unittest.TestCase):
     '''Base class for all sesion related tests.
     '''Base class for all sesion related tests.
 
 
@@ -112,7 +203,14 @@ class SessionTestBase(unittest.TestCase):
                                       ZoneConfig([], TEST_RRCLASS,
                                       ZoneConfig([], TEST_RRCLASS,
                                                  self._datasrc_client,
                                                  self._datasrc_client,
                                                  self._acl_map))
                                                  self._acl_map))
-        self._session._UpdateSession__get_update_zone()
+        self._session._get_update_zone()
+        self._session._create_diff()
+
+    def tearDown(self):
+        # With the Updater created in _get_update_zone, and tests
+        # doing all kinds of crazy stuff, one might get database locked
+        # errors if it doesn't clean up explicitely after each test
+        self._session = None
 
 
     def check_response(self, msg, expected_rcode):
     def check_response(self, msg, expected_rcode):
         '''Perform common checks on update resposne message.'''
         '''Perform common checks on update resposne message.'''
@@ -463,7 +561,8 @@ class SessionTest(SessionTestBase):
         zconfig = ZoneConfig([], TEST_RRCLASS, self._datasrc_client,
         zconfig = ZoneConfig([], TEST_RRCLASS, self._datasrc_client,
                              self._acl_map)
                              self._acl_map)
         session = UpdateSession(msg, TEST_CLIENT4, zconfig)
         session = UpdateSession(msg, TEST_CLIENT4, zconfig)
-        session._UpdateSession__get_update_zone()
+        session._get_update_zone()
+        session._create_diff()
         # compare the to_text output of the rcodes (nicer error messages)
         # compare the to_text output of the rcodes (nicer error messages)
         # This call itself should also be done by handle(),
         # This call itself should also be done by handle(),
         # but just for better failures, it is first called on its own
         # but just for better failures, it is first called on its own
@@ -488,7 +587,8 @@ class SessionTest(SessionTestBase):
         zconfig = ZoneConfig([], TEST_RRCLASS, self._datasrc_client,
         zconfig = ZoneConfig([], TEST_RRCLASS, self._datasrc_client,
                              self._acl_map)
                              self._acl_map)
         session = UpdateSession(msg, TEST_CLIENT4, zconfig)
         session = UpdateSession(msg, TEST_CLIENT4, zconfig)
-        session._UpdateSession__get_update_zone()
+        session._get_update_zone()
+        session._create_diff()
         # compare the to_text output of the rcodes (nicer error messages)
         # compare the to_text output of the rcodes (nicer error messages)
         # This call itself should also be done by handle(),
         # This call itself should also be done by handle(),
         # but just for better failures, it is first called on its own
         # but just for better failures, it is first called on its own

+ 35 - 4
src/lib/python/isc/xfrin/diff.py

@@ -25,6 +25,7 @@ But for now, it lives here.
 
 
 import isc.dns
 import isc.dns
 import isc.log
 import isc.log
+from isc.datasrc import ZoneFinder
 from isc.log_messages.libxfrin_messages import *
 from isc.log_messages.libxfrin_messages import *
 
 
 class NoSuchZone(Exception):
 class NoSuchZone(Exception):
@@ -119,7 +120,7 @@ class Diff:
         else:
         else:
             self.__buffer = []
             self.__buffer = []
 
 
-    def __check_commited(self):
+    def __check_committed(self):
         """
         """
         This checks if the diff is already commited or broken. If it is, it
         This checks if the diff is already commited or broken. If it is, it
         raises ValueError. This check is for methods that need to work only on
         raises ValueError. This check is for methods that need to work only on
@@ -169,7 +170,7 @@ class Diff:
         - in single_update_mode if any later rr is of type SOA (both for
         - in single_update_mode if any later rr is of type SOA (both for
           addition and deletion)
           addition and deletion)
         """
         """
-        self.__check_commited()
+        self.__check_committed()
         if rr.get_rdata_count() != 1:
         if rr.get_rdata_count() != 1:
             raise ValueError('The rrset must contain exactly 1 Rdata, but ' +
             raise ValueError('The rrset must contain exactly 1 Rdata, but ' +
                              'it holds ' + str(rr.get_rdata_count()))
                              'it holds ' + str(rr.get_rdata_count()))
@@ -298,7 +299,7 @@ class Diff:
                 else:
                 else:
                     raise ValueError('Unknown operation ' + operation)
                     raise ValueError('Unknown operation ' + operation)
 
 
-        self.__check_commited()
+        self.__check_committed()
         # First, compact the data
         # First, compact the data
         self.compact()
         self.compact()
         try:
         try:
@@ -330,7 +331,7 @@ class Diff:
 
 
         This might raise isc.datasrc.Error.
         This might raise isc.datasrc.Error.
         """
         """
-        self.__check_commited()
+        self.__check_committed()
         # Push the data inside the data source
         # Push the data inside the data source
         self.apply()
         self.apply()
         # Make sure they are visible.
         # Make sure they are visible.
@@ -376,3 +377,33 @@ class Diff:
             raise ValueError("Separate buffers requested in single-update mode")
             raise ValueError("Separate buffers requested in single-update mode")
         else:
         else:
             return (self.__deletions, self.__additions)
             return (self.__deletions, self.__additions)
+
+    def find(self, name, rrtype,
+             options=(ZoneFinder.NO_WILDCARD | ZoneFinder.FIND_GLUE_OK)):
+        """
+        Calls the find() method in the ZoneFinder associated with this
+        Diff's ZoneUpdater, i.e. the find() on the zone as it was on the
+        moment this Diff object got created.
+        See the ZoneFinder documentation for a full description.
+        Note that the result does not include changes made in this Diff
+        instance so far.
+        Options default to NO_WILDCARD and FIND_GLUE_OK.
+        Raises a ValueError if the Diff has been committed already
+        """
+        self.__check_committed()
+        return self.__updater.find(name, rrtype, options)
+
+    def find_all(self, name,
+                 options=(ZoneFinder.NO_WILDCARD | ZoneFinder.FIND_GLUE_OK)):
+        """
+        Calls the find() method in the ZoneFinder associated with this
+        Diff's ZoneUpdater, i.e. the find_all() on the zone as it was on the
+        moment this Diff object got created.
+        See the ZoneFinder documentation for a full description.
+        Note that the result does not include changes made in this Diff
+        instance so far.
+        Options default to NO_WILDCARD and FIND_GLUE_OK.
+        Raises a ValueError if the Diff has been committed already
+        """
+        self.__check_committed()
+        return self.__updater.find_all(name, options)

+ 88 - 1
src/lib/python/isc/xfrin/tests/diff_tests.py

@@ -15,7 +15,7 @@
 
 
 import isc.log
 import isc.log
 import unittest
 import unittest
-import isc.datasrc
+from isc.datasrc import ZoneFinder
 from isc.dns import Name, RRset, RRClass, RRType, RRTTL, Rdata
 from isc.dns import Name, RRset, RRClass, RRType, RRTTL, Rdata
 from isc.xfrin.diff import Diff, NoSuchZone
 from isc.xfrin.diff import Diff, NoSuchZone
 
 
@@ -48,6 +48,13 @@ class DiffTest(unittest.TestCase):
         self.__broken_called = False
         self.__broken_called = False
         self.__warn_called = False
         self.__warn_called = False
         self.__should_replace = False
         self.__should_replace = False
+        self.__find_called = False
+        self.__find_name = None
+        self.__find_type = None
+        self.__find_options = None
+        self.__find_all_called = False
+        self.__find_all_name = None
+        self.__find_all_options = None
         # Some common values
         # Some common values
         self.__rrclass = RRClass.IN()
         self.__rrclass = RRClass.IN()
         self.__type = RRType.A()
         self.__type = RRType.A()
@@ -156,6 +163,23 @@ class DiffTest(unittest.TestCase):
 
 
         return self
         return self
 
 
+    def find(self, name, rrtype, options=None):
+        self.__find_called = True
+        self.__find_name = name
+        self.__find_type = rrtype
+        self.__find_options = options
+        # Doesn't really matter what is returned, as long
+        # as the test can check that it's passed along
+        return "find_return"
+
+    def find_all(self, name, options=None):
+        self.__find_all_called = True
+        self.__find_all_name = name
+        self.__find_all_options = options
+        # Doesn't really matter what is returned, as long
+        # as the test can check that it's passed along
+        return "find_all_return"
+
     def test_create(self):
     def test_create(self):
         """
         """
         This test the case when the diff is successfuly created. It just
         This test the case when the diff is successfuly created. It just
@@ -265,6 +289,9 @@ class DiffTest(unittest.TestCase):
         self.assertRaises(ValueError, diff.commit)
         self.assertRaises(ValueError, diff.commit)
         self.assertRaises(ValueError, diff.add_data, self.__rrset2)
         self.assertRaises(ValueError, diff.add_data, self.__rrset2)
         self.assertRaises(ValueError, diff.delete_data, self.__rrset1)
         self.assertRaises(ValueError, diff.delete_data, self.__rrset1)
+        self.assertRaises(ValueError, diff.find, Name('foo.example.org.'),
+                          RRType.A())
+        self.assertRaises(ValueError, diff.find_all, Name('foo.example.org.'))
         diff.apply = orig_apply
         diff.apply = orig_apply
         self.assertRaises(ValueError, diff.apply)
         self.assertRaises(ValueError, diff.apply)
         # This one does not state it should raise, so check it doesn't
         # This one does not state it should raise, so check it doesn't
@@ -587,6 +614,66 @@ class DiffTest(unittest.TestCase):
         self.assertRaises(ValueError, diff.add_data, a)
         self.assertRaises(ValueError, diff.add_data, a)
         self.assertRaises(ValueError, diff.delete_data, a)
         self.assertRaises(ValueError, diff.delete_data, a)
 
 
+    def test_find(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+        rrtype = RRType.A()
+
+        self.assertFalse(self.__find_called)
+        self.assertEqual(None, self.__find_name)
+        self.assertEqual(None, self.__find_type)
+        self.assertEqual(None, self.__find_options)
+
+        self.assertEqual("find_return", diff.find(name, rrtype))
+
+        self.assertTrue(self.__find_called)
+        self.assertEqual(name, self.__find_name)
+        self.assertEqual(rrtype, self.__find_type)
+        self.assertEqual(ZoneFinder.NO_WILDCARD | ZoneFinder.FIND_GLUE_OK,
+                         self.__find_options)
+
+    def test_find_options(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('foo.example.org.')
+        rrtype = RRType.TXT()
+        options = ZoneFinder.NO_WILDCARD
+
+        self.assertEqual("find_return", diff.find(name, rrtype, options))
+
+        self.assertTrue(self.__find_called)
+        self.assertEqual(name, self.__find_name)
+        self.assertEqual(rrtype, self.__find_type)
+        self.assertEqual(options, self.__find_options)
+
+    def test_find_all(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+
+        self.assertFalse(self.__find_all_called)
+        self.assertEqual(None, self.__find_all_name)
+        self.assertEqual(None, self.__find_all_options)
+
+        self.assertEqual("find_all_return", diff.find_all(name))
+
+        self.assertTrue(self.__find_all_called)
+        self.assertEqual(name, self.__find_all_name)
+        self.assertEqual(ZoneFinder.NO_WILDCARD | ZoneFinder.FIND_GLUE_OK,
+                         self.__find_all_options)
+
+    def test_find_all_options(self):
+        diff = Diff(self, Name('example.org.'))
+        name = Name('www.example.org.')
+        options = isc.datasrc.ZoneFinder.NO_WILDCARD
+
+        self.assertFalse(self.__find_all_called)
+        self.assertEqual(None, self.__find_all_name)
+        self.assertEqual(None, self.__find_all_options)
+
+        self.assertEqual("find_all_return", diff.find_all(name, options))
+
+        self.assertTrue(self.__find_all_called)
+        self.assertEqual(name, self.__find_all_name)
+        self.assertEqual(options, self.__find_all_options)
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     isc.log.init("bind10")
     isc.log.init("bind10")