Browse Source

[1288] switch to new API 4: update notify_out initialization using new API.

JINMEI Tatuya 13 years ago
parent
commit
5dc6be6feb

+ 45 - 26
src/lib/python/isc/notify/notify_out.py

@@ -21,6 +21,7 @@ import threading
 import time
 import time
 import errno
 import errno
 from isc.datasrc import sqlite3_ds
 from isc.datasrc import sqlite3_ds
+from isc.datasrc import DataSourceClient
 from isc.net import addr
 from isc.net import addr
 import isc
 import isc
 from isc.log_messages.notify_out_messages import *
 from isc.log_messages.notify_out_messages import *
@@ -31,7 +32,7 @@ logger = isc.log.Logger("notify_out")
 # we can't import we should not start anyway, and logging an error
 # we can't import we should not start anyway, and logging an error
 # is a bad idea since the logging system is most likely not
 # is a bad idea since the logging system is most likely not
 # initialized yet. see trac ticket #1103
 # initialized yet. see trac ticket #1103
-from pydnspp import *
+from isc.dns import *
 
 
 ZONE_NEW_DATA_READY_CMD = 'zone_new_data_ready'
 ZONE_NEW_DATA_READY_CMD = 'zone_new_data_ready'
 _MAX_NOTIFY_NUM = 30
 _MAX_NOTIFY_NUM = 30
@@ -123,16 +124,20 @@ class NotifyOut:
         self._nonblock_event = threading.Event()
         self._nonblock_event = threading.Event()
 
 
     def _init_notify_out(self, datasrc_file):
     def _init_notify_out(self, datasrc_file):
-        '''Get all the zones name and its notify target's address
+        '''Get all the zones name and its notify target's address.
+
         TODO, currently the zones are got by going through the zone
         TODO, currently the zones are got by going through the zone
         table in database. There should be a better way to get them
         table in database. There should be a better way to get them
         and also the setting 'also_notify', and there should be one
         and also the setting 'also_notify', and there should be one
-        mechanism to cover the changed datasrc.'''
+        mechanism to cover the changed datasrc.
+
+        '''
         self._db_file = datasrc_file
         self._db_file = datasrc_file
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
             zone_id = (zone_name, zone_class)
             zone_id = (zone_name, zone_class)
             self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
             self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
-            slaves = self._get_notify_slaves_from_ns(zone_name)
+            slaves = self._get_notify_slaves_from_ns(Name(zone_name),
+                                                     RRClass(zone_class))
             for item in slaves:
             for item in slaves:
                 self._notify_infos[zone_id].notify_slaves.append((item, 53))
                 self._notify_infos[zone_id].notify_slaves.append((item, 53))
 
 
@@ -234,7 +239,7 @@ class NotifyOut:
     def _get_rdata_data(self, rr):
     def _get_rdata_data(self, rr):
         return rr[7].strip()
         return rr[7].strip()
 
 
-    def _get_notify_slaves_from_ns(self, zone_name):
+    def _get_notify_slaves_from_ns(self, zone_name, zone_class):
         '''Get all NS records, then remove the primary master from ns rrset,
         '''Get all NS records, then remove the primary master from ns rrset,
         then use the name in NS record rdata part to get the a/aaaa records
         then use the name in NS record rdata part to get the a/aaaa records
         in the same zone. the targets listed in a/aaaa record rdata are treated
         in the same zone. the targets listed in a/aaaa record rdata are treated
@@ -243,27 +248,41 @@ class NotifyOut:
         but not correct, it can't handle the delegation slaves, or the CNAME
         but not correct, it can't handle the delegation slaves, or the CNAME
         and DNAME logic.
         and DNAME logic.
         TODO. the function should be provided by one library.'''
         TODO. the function should be provided by one library.'''
-        ns_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'NS', self._db_file)
-        soa_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file)
-        ns_rr_name = []
-        for ns in ns_rrset:
-            ns_rr_name.append(self._get_rdata_data(ns))
-
-        if len(soa_rrset) > 0:
-            sname = (soa_rrset[0][sqlite3_ds.RR_RDATA_INDEX].split(' '))[0].strip() #TODO, bad hardcode to get rdata part
-            if sname in ns_rr_name:
-                ns_rr_name.remove(sname)
-
-        addr_list = []
-        for rr_name in ns_rr_name:
-            a_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'A', self._db_file)
-            aaaa_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'AAAA', self._db_file)
-            for rr in a_rrset:
-                addr_list.append(self._get_rdata_data(rr))
-            for rr in aaaa_rrset:
-                addr_list.append(self._get_rdata_data(rr))
-
-        return addr_list
+        datasrc_config = '{ \"database_file\": \"' + self._db_file + '\"}'
+        result, finder = DataSourceClient('sqlite3',
+                                          datasrc_config).find_zone(zone_name)
+        if result is not DataSourceClient.SUCCESS:
+            return []
+
+        result, ns_rrset = finder.find(zone_name, RRType.NS(), None,
+                                       finder.FIND_DEFAULT)
+        if result is not finder.SUCCESS or ns_rrset is None:
+            # TODO: Log it.
+            return []
+        result, soa_rrset = finder.find(zone_name, RRType.SOA(), None,
+                                       finder.FIND_DEFAULT)
+        if result is not finder.SUCCESS or soa_rrset is None or \
+                soa_rrset.get_rdata_count() != 1:
+            # TODO: Log it.
+            return []           # broken zone anyway, stop here.
+        soa_mname = Name(soa_rrset.get_rdata()[0].to_text().split(' ')[0])
+
+        addrs = []
+        for ns_rdata in ns_rrset.get_rdata():
+            ns_name = Name(ns_rdata.to_text())
+            if soa_mname == ns_name:
+                continue
+            result, rrset = finder.find(ns_name, RRType.A(), None,
+                                        finder.FIND_DEFAULT)
+            if result is finder.SUCCESS and rrset is not None:
+                addrs.extend([a.to_text() for a in rrset.get_rdata()])
+
+            result, rrset = finder.find(ns_name, RRType.AAAA(), None,
+                                        finder.FIND_DEFAULT)
+            if result is finder.SUCCESS and rrset is not None:
+                addrs.extend([aaaa.to_text() for aaaa in rrset.get_rdata()])
+
+        return addrs
 
 
     def _prepare_select_info(self):
     def _prepare_select_info(self):
         '''
         '''

+ 6 - 4
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -22,6 +22,7 @@ import socket
 from isc.datasrc import sqlite3_ds
 from isc.datasrc import sqlite3_ds
 from isc.notify import notify_out, SOCK_DATA
 from isc.notify import notify_out, SOCK_DATA
 import isc.log
 import isc.log
+from isc.dns import *
 
 
 # our fake socket, where we can read and insert messages
 # our fake socket, where we can read and insert messages
 class MockSocket():
 class MockSocket():
@@ -341,7 +342,8 @@ class TestNotifyOut(unittest.TestCase):
             yield item
             yield item
 
 
     def test_get_notify_slaves_from_ns(self):
     def test_get_notify_slaves_from_ns(self):
-        records = self._notify._get_notify_slaves_from_ns('example.net.')
+        records = self._notify._get_notify_slaves_from_ns(Name('example.net.'),
+                                                          RRClass.IN())
         self.assertEqual(6, len(records))
         self.assertEqual(6, len(records))
         self.assertEqual('8:8::8:8', records[5])
         self.assertEqual('8:8::8:8', records[5])
         self.assertEqual('7.7.7.7', records[4])
         self.assertEqual('7.7.7.7', records[4])
@@ -350,7 +352,8 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('3.3.3.3', records[0])
         self.assertEqual('3.3.3.3', records[0])
 
 
-        records = self._notify._get_notify_slaves_from_ns('example.com.')
+        records = self._notify._get_notify_slaves_from_ns(Name('example.com.'),
+                                                          RRClass.IN())
         self.assertEqual(3, len(records))
         self.assertEqual(3, len(records))
         self.assertEqual('5:5::5:5', records[2])
         self.assertEqual('5:5::5:5', records[2])
         self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('4:4::4:4', records[1])
@@ -417,6 +420,5 @@ class TestNotifyOut(unittest.TestCase):
 
 
 if __name__== "__main__":
 if __name__== "__main__":
     isc.log.init("bind10")
     isc.log.init("bind10")
+    isc.log.resetUnitTestRootLogger()
     unittest.main()
     unittest.main()
-
-