Parcourir la source

1. Identify one zone with its name and class. 2. Avoid check the rcode of notify reply, since if we get the reply from the slave, it means the slave has get the notify message.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac289@2641 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang il y a 14 ans
Parent
commit
a9be5b02f2

+ 7 - 5
src/bin/xfrout/xfrout.py.in

@@ -433,8 +433,8 @@ class XfroutServer:
         td.daemon = True
         td.start()
 
-    def send_notify(self, zone_name):
-        self._notifier.send_notify(zone_name)
+    def send_notify(self, zone_name, zone_class):
+        self._notifier.send_notify(zone_name, zone_class)
 
     def config_handler(self, new_config):
         '''Update config data. TODO. Do error check'''
@@ -479,9 +479,11 @@ class XfroutServer:
         
         elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
             zone_name = args.get('zone_name')
-            if zone_name:
-                self._log.log_message("info", "Receive notify command for zone: '" + zone_name + "'")
-                self.send_notify(zone_name)
+            zone_class = args.get('zone_class')
+            if zone_name and zone_class:
+                self._log.log_message("info", "Receive notify command for zone:'%s/%s'" \
+                                     % (zone_name, zone_class))
+                self.send_notify(zone_name, zone_class)
                 answer = create_answer(0)
             else:
                 answer = create_answer(1, "Bad command parameter:" + str(args))

+ 20 - 18
src/lib/python/isc/notify/notify_out.py

@@ -82,7 +82,7 @@ class ZoneNotifyInfo:
 
 class NotifyOut:
     def __init__(self, datasrc_file, log=None, verbose=True):
-        self._notify_infos = {}
+        self._notify_infos = {} # key is (zone_name, zone_class)
         self._waiting_zones = []
         self._notifying_zones = []
         self._log = log
@@ -100,10 +100,11 @@ class NotifyOut:
         mechanism to cover the changed datasrc.'''
         self._db_file = datasrc_file
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
-            self._notify_infos[zone_name] = ZoneNotifyInfo(zone_name, zone_class)
+            zone_id = (zone_name, zone_class)
+            self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
             slaves = self._get_notify_slaves_from_ns(zone_name)
             for item in slaves:
-                self._notify_infos[zone_name]._notify_slaves.append((item, 53))
+                self._notify_infos[zone_id]._notify_slaves.append((item, 53))
 
     def _get_rdata_data(self, rr):
         return rr[7].strip()
@@ -132,20 +133,22 @@ class NotifyOut:
 
         return addr_list
 
-    def send_notify(self, zone_name):
+    def send_notify(self, zone_name, zone_class='IN'):
         if zone_name[len(zone_name) - 1] != '.':
             zone_name += '.'
-        if zone_name not in self._notify_infos:
+
+        zone_id = (zone_name, zone_class)
+        if zone_id not in self._notify_infos:
             return
 
         with self._lock:
-            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_name in self._notifying_zones):
-                if zone_name not in self._waiting_zones:
-                    self._waiting_zones.append(zone_name)
+            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_id in self._notifying_zones):
+                if zone_id not in self._waiting_zones:
+                    self._waiting_zones.append(zone_id)
             else:
-                self._notify_infos[zone_name].prepare_notify_out()
+                self._notify_infos[zone_id].prepare_notify_out()
                 self.notify_num += 1 
-                self._notifying_zones.append(zone_name)
+                self._notifying_zones.append(zone_id)
 
     def _wait_for_notify_reply(self):
         '''receive notify replies in specified time. returned value 
@@ -217,11 +220,12 @@ class NotifyOut:
             zone_notify_info.finish_notify_out()
             with self._lock:
                 self.notify_num -= 1 
-                self._notifying_zones.remove(zone_notify_info.zone_name) 
+                self._notifying_zones.remove((zone_notify_info.zone_name, 
+                                              zone_notify_info.zone_class)) 
                 # trigger notify out for waiting zones
                 if len(self._waiting_zones) > 0:
-                    zone_name = self._waiting_zones.pop(0) 
-                    self._notify_infos[zone_name].prepare_notify_out()
+                    zone_id = self._waiting_zones.pop(0) 
+                    self._notify_infos[zone_id].prepare_notify_out()
                     self.notify_num += 1 
 
     def _send_notify_message_udp(self, zone_notify_info, addrinfo):
@@ -268,15 +272,13 @@ class NotifyOut:
 
     def _handle_notify_reply(self, zone_notify_info, msg_data):
         '''Parse the notify reply message.
-        TODO, the error message should be refined properly.'''
+        TODO, the error message should be refined properly.
+        rcode will not checked here, If we get the response
+        from the slave, it means the slaves has got the notify.'''
         msg = Message(Message.PARSE)
         try:
             errstr = 'notify reply error: '
             msg.from_wire(msg_data)
-            if (msg.get_rcode() != Rcode.NOERROR()):
-                self._log_msg('error', errstr + 'bad rcode')
-                return False
-
             if not msg.get_header_flag(MessageFlag.QR()):
                 self._log_msg('error', errstr + 'bad flags')
                 return False

+ 33 - 22
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -44,11 +44,13 @@ class TestNotifyOut(unittest.TestCase):
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
         self._notify = notify_out.NotifyOut(self._db_file.name)
-        self._notify._notify_infos['com.'] = notify_out.ZoneNotifyInfo('com.', 'IN')
-        self._notify._notify_infos['cn.'] = notify_out.ZoneNotifyInfo('cn.', 'IN')
-        self._notify._notify_infos['org.'] = notify_out.ZoneNotifyInfo('org.', 'IN')
+        self._notify._notify_infos[('com.', 'IN')] = notify_out.ZoneNotifyInfo('com.', 'IN')
+        self._notify._notify_infos[('com.', 'CH')] = notify_out.ZoneNotifyInfo('com.', 'CH')
+        self._notify._notify_infos[('cn.', 'IN')] = notify_out.ZoneNotifyInfo('cn.', 'IN')
+        self._notify._notify_infos[('org.', 'IN')] = notify_out.ZoneNotifyInfo('org.', 'IN')
+        self._notify._notify_infos[('org.', 'CH')] = notify_out.ZoneNotifyInfo('org.', 'CH')
         
-        info = self._notify._notify_infos['cn.']
+        info = self._notify._notify_infos[('cn.', 'IN')]
         info._notify_slaves.append(('127.0.0.1', 53))
         info._notify_slaves.append(('1.1.1.1', 5353))
 
@@ -60,18 +62,26 @@ class TestNotifyOut(unittest.TestCase):
     def test_send_notify(self):
         self._notify.send_notify('cn')
         self.assertEqual(self._notify.notify_num, 1)
-        self.assertEqual(self._notify._notifying_zones[0], 'cn.')
+        self.assertEqual(self._notify._notifying_zones[0], ('cn.','IN'))
 
         self._notify.send_notify('com')
         self.assertEqual(self._notify.notify_num, 2)
-        self.assertEqual(self._notify._notifying_zones[1], 'com.')
+        self.assertEqual(self._notify._notifying_zones[1], ('com.','IN'))
+
+        notify_out._MAX_NOTIFY_NUM = 3
+        self._notify.send_notify('com', 'CH')
+        self.assertEqual(self._notify.notify_num, 3)
+        self.assertEqual(self._notify._notifying_zones[2], ('com.','CH'))
     
-        notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
-        self.assertEqual(self._notify._waiting_zones[0], 'org.')
+        self.assertEqual(self._notify._waiting_zones[0], ('org.', 'IN'))
         self._notify.send_notify('org.')
         self.assertEqual(1, len(self._notify._waiting_zones))
 
+        self._notify.send_notify('org.', 'CH')
+        self.assertEqual(2, len(self._notify._waiting_zones))
+        self.assertEqual(self._notify._waiting_zones[1], ('org.', 'CH'))
+
     def test_wait_for_notify_reply(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
@@ -84,9 +94,9 @@ class TestNotifyOut(unittest.TestCase):
 
         # Now make one socket be readable
         addr = ('localhost', 12340)
-        self._notify._notify_infos['cn.']._sock.bind(addr)
-        self._notify._notify_infos['cn.'].notify_timeout = time.time() + 10
-        self._notify._notify_infos['com.'].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('cn.', 'IN')]._sock.bind(addr)
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('com.', 'IN')].notify_timeout = time.time() + 10
         
         send_fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         #Send some data to socket 12340, to make the target socket be readable
@@ -94,17 +104,18 @@ class TestNotifyOut(unittest.TestCase):
         replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
         self.assertEqual(len(replied_zones), 1)
         self.assertEqual(len(timeout_zones), 1)
-        self.assertTrue('cn.' in replied_zones.keys())
-        self.assertTrue('com.' in timeout_zones.keys())
-        self.assertLess(time.time(), self._notify._notify_infos['com.'].notify_timeout)
+        self.assertTrue(('cn.', 'IN') in replied_zones.keys())
+        self.assertTrue(('com.', 'IN') in timeout_zones.keys())
+        self.assertLess(time.time(), self._notify._notify_infos[('com.', 'IN')].notify_timeout)
     
     def test_notify_next_target(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
         notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
+        self._notify.send_notify('com.', 'CH')
 
-        info = self._notify._notify_infos['cn.']
+        info = self._notify._notify_infos[('cn.', 'IN')]
         self._notify._notify_next_target(info)
         self.assertEqual(0, info.notify_try_num)
         self.assertEqual(info.get_current_notify_target(), ('1.1.1.1', 5353))
@@ -114,22 +125,22 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual(0, info.notify_try_num)
         self.assertIsNone(info.get_current_notify_target())
         self.assertEqual(2, self._notify.notify_num)
-        self.assertEqual(0, len(self._notify._waiting_zones))
+        self.assertEqual(1, len(self._notify._waiting_zones))
 
-        com_info = self._notify._notify_infos['com.']
+        com_info = self._notify._notify_infos[('com.', 'IN')]
         self._notify._notify_next_target(com_info)
-        self.assertEqual(1, self._notify.notify_num)
+        self.assertEqual(2, self._notify.notify_num)
         self.assertEqual(0, len(self._notify._notifying_zones))
     
     def test_handle_notify_reply(self):
         self.assertFalse(self._notify._handle_notify_reply(None, b'badmsg'))
-        com_info = self._notify._notify_infos['com.']
+        com_info = self._notify._notify_infos[('com.', 'IN')]
         com_info.notify_msg_id = 0X2f18
         data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02tw\x02cn\x00\x00\x06\x00\x01'
         self.assertTrue(self._notify._handle_notify_reply(com_info, data))
 
     def test_send_notify_message_udp(self):
-        com_info = self._notify._notify_infos['cn.']
+        com_info = self._notify._notify_infos[('cn.', 'IN')]
         com_info.prepare_notify_out()
         ret = self._notify._send_notify_message_udp(com_info, ('1.1.1.1', 53))
         self.assertTrue(ret)
@@ -144,7 +155,7 @@ class TestNotifyOut(unittest.TestCase):
         notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
 
-        cn_info = self._notify._notify_infos['cn.']
+        cn_info = self._notify._notify_infos[('cn.', 'IN')]
         cn_info.prepare_notify_out()
 
         cn_info.notify_try_num = 2
@@ -211,7 +222,7 @@ class TestNotifyOut(unittest.TestCase):
     def test_init_notify_out(self):
         self._notify._init_notify_out(self._db_file.name)
         self.assertListEqual([('3.3.3.3', 53), ('4:4.4.4', 53), ('5:5.5.5', 53)], 
-                             self._notify._notify_infos['com.']._notify_slaves)
+                             self._notify._notify_infos[('com.', 'IN')]._notify_slaves)
         
 if __name__== "__main__":
     unittest.main()