Browse Source

[1288] switch to new API 4: update notify_out initialization using new API.

JINMEI Tatuya 13 years ago
parent
commit
5dc6be6feb

+ 45 - 26
src/lib/python/isc/notify/notify_out.py

@@ -21,6 +21,7 @@ import threading
 import time
 import errno
 from isc.datasrc import sqlite3_ds
+from isc.datasrc import DataSourceClient
 from isc.net import addr
 import isc
 from isc.log_messages.notify_out_messages import *
@@ -31,7 +32,7 @@ logger = isc.log.Logger("notify_out")
 # we can't import we should not start anyway, and logging an error
 # is a bad idea since the logging system is most likely not
 # initialized yet. see trac ticket #1103
-from pydnspp import *
+from isc.dns import *
 
 ZONE_NEW_DATA_READY_CMD = 'zone_new_data_ready'
 _MAX_NOTIFY_NUM = 30
@@ -123,16 +124,20 @@ class NotifyOut:
         self._nonblock_event = threading.Event()
 
     def _init_notify_out(self, datasrc_file):
-        '''Get all the zones name and its notify target's address
+        '''Get all the zones name and its notify target's address.
+
         TODO, currently the zones are got by going through the zone
         table in database. There should be a better way to get them
         and also the setting 'also_notify', and there should be one
-        mechanism to cover the changed datasrc.'''
+        mechanism to cover the changed datasrc.
+
+        '''
         self._db_file = datasrc_file
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
             zone_id = (zone_name, zone_class)
             self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
-            slaves = self._get_notify_slaves_from_ns(zone_name)
+            slaves = self._get_notify_slaves_from_ns(Name(zone_name),
+                                                     RRClass(zone_class))
             for item in slaves:
                 self._notify_infos[zone_id].notify_slaves.append((item, 53))
 
@@ -234,7 +239,7 @@ class NotifyOut:
     def _get_rdata_data(self, rr):
         return rr[7].strip()
 
-    def _get_notify_slaves_from_ns(self, zone_name):
+    def _get_notify_slaves_from_ns(self, zone_name, zone_class):
         '''Get all NS records, then remove the primary master from ns rrset,
         then use the name in NS record rdata part to get the a/aaaa records
         in the same zone. the targets listed in a/aaaa record rdata are treated
@@ -243,27 +248,41 @@ class NotifyOut:
         but not correct, it can't handle the delegation slaves, or the CNAME
         and DNAME logic.
         TODO. the function should be provided by one library.'''
-        ns_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'NS', self._db_file)
-        soa_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file)
-        ns_rr_name = []
-        for ns in ns_rrset:
-            ns_rr_name.append(self._get_rdata_data(ns))
-
-        if len(soa_rrset) > 0:
-            sname = (soa_rrset[0][sqlite3_ds.RR_RDATA_INDEX].split(' '))[0].strip() #TODO, bad hardcode to get rdata part
-            if sname in ns_rr_name:
-                ns_rr_name.remove(sname)
-
-        addr_list = []
-        for rr_name in ns_rr_name:
-            a_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'A', self._db_file)
-            aaaa_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'AAAA', self._db_file)
-            for rr in a_rrset:
-                addr_list.append(self._get_rdata_data(rr))
-            for rr in aaaa_rrset:
-                addr_list.append(self._get_rdata_data(rr))
-
-        return addr_list
+        datasrc_config = '{ \"database_file\": \"' + self._db_file + '\"}'
+        result, finder = DataSourceClient('sqlite3',
+                                          datasrc_config).find_zone(zone_name)
+        if result is not DataSourceClient.SUCCESS:
+            return []
+
+        result, ns_rrset = finder.find(zone_name, RRType.NS(), None,
+                                       finder.FIND_DEFAULT)
+        if result is not finder.SUCCESS or ns_rrset is None:
+            # TODO: Log it.
+            return []
+        result, soa_rrset = finder.find(zone_name, RRType.SOA(), None,
+                                       finder.FIND_DEFAULT)
+        if result is not finder.SUCCESS or soa_rrset is None or \
+                soa_rrset.get_rdata_count() != 1:
+            # TODO: Log it.
+            return []           # broken zone anyway, stop here.
+        soa_mname = Name(soa_rrset.get_rdata()[0].to_text().split(' ')[0])
+
+        addrs = []
+        for ns_rdata in ns_rrset.get_rdata():
+            ns_name = Name(ns_rdata.to_text())
+            if soa_mname == ns_name:
+                continue
+            result, rrset = finder.find(ns_name, RRType.A(), None,
+                                        finder.FIND_DEFAULT)
+            if result is finder.SUCCESS and rrset is not None:
+                addrs.extend([a.to_text() for a in rrset.get_rdata()])
+
+            result, rrset = finder.find(ns_name, RRType.AAAA(), None,
+                                        finder.FIND_DEFAULT)
+            if result is finder.SUCCESS and rrset is not None:
+                addrs.extend([aaaa.to_text() for aaaa in rrset.get_rdata()])
+
+        return addrs
 
     def _prepare_select_info(self):
         '''

+ 6 - 4
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -22,6 +22,7 @@ import socket
 from isc.datasrc import sqlite3_ds
 from isc.notify import notify_out, SOCK_DATA
 import isc.log
+from isc.dns import *
 
 # our fake socket, where we can read and insert messages
 class MockSocket():
@@ -341,7 +342,8 @@ class TestNotifyOut(unittest.TestCase):
             yield item
 
     def test_get_notify_slaves_from_ns(self):
-        records = self._notify._get_notify_slaves_from_ns('example.net.')
+        records = self._notify._get_notify_slaves_from_ns(Name('example.net.'),
+                                                          RRClass.IN())
         self.assertEqual(6, len(records))
         self.assertEqual('8:8::8:8', records[5])
         self.assertEqual('7.7.7.7', records[4])
@@ -350,7 +352,8 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('3.3.3.3', records[0])
 
-        records = self._notify._get_notify_slaves_from_ns('example.com.')
+        records = self._notify._get_notify_slaves_from_ns(Name('example.com.'),
+                                                          RRClass.IN())
         self.assertEqual(3, len(records))
         self.assertEqual('5:5::5:5', records[2])
         self.assertEqual('4:4::4:4', records[1])
@@ -417,6 +420,5 @@ class TestNotifyOut(unittest.TestCase):
 
 if __name__== "__main__":
     isc.log.init("bind10")
+    isc.log.resetUnitTestRootLogger()
     unittest.main()
-
-