Browse Source

1. Identify one zone with its name and class. 2. Avoid check the rcode of notify reply, since if we get the reply from the slave, it means the slave has get the notify message.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac289@2641 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 14 years ago
parent
commit
a9be5b02f2

+ 7 - 5
src/bin/xfrout/xfrout.py.in

@@ -433,8 +433,8 @@ class XfroutServer:
         td.daemon = True
         td.daemon = True
         td.start()
         td.start()
 
 
-    def send_notify(self, zone_name):
-        self._notifier.send_notify(zone_name)
+    def send_notify(self, zone_name, zone_class):
+        self._notifier.send_notify(zone_name, zone_class)
 
 
     def config_handler(self, new_config):
     def config_handler(self, new_config):
         '''Update config data. TODO. Do error check'''
         '''Update config data. TODO. Do error check'''
@@ -479,9 +479,11 @@ class XfroutServer:
         
         
         elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
         elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
             zone_name = args.get('zone_name')
             zone_name = args.get('zone_name')
-            if zone_name:
-                self._log.log_message("info", "Receive notify command for zone: '" + zone_name + "'")
-                self.send_notify(zone_name)
+            zone_class = args.get('zone_class')
+            if zone_name and zone_class:
+                self._log.log_message("info", "Receive notify command for zone:'%s/%s'" \
+                                     % (zone_name, zone_class))
+                self.send_notify(zone_name, zone_class)
                 answer = create_answer(0)
                 answer = create_answer(0)
             else:
             else:
                 answer = create_answer(1, "Bad command parameter:" + str(args))
                 answer = create_answer(1, "Bad command parameter:" + str(args))

+ 20 - 18
src/lib/python/isc/notify/notify_out.py

@@ -82,7 +82,7 @@ class ZoneNotifyInfo:
 
 
 class NotifyOut:
 class NotifyOut:
     def __init__(self, datasrc_file, log=None, verbose=True):
     def __init__(self, datasrc_file, log=None, verbose=True):
-        self._notify_infos = {}
+        self._notify_infos = {} # key is (zone_name, zone_class)
         self._waiting_zones = []
         self._waiting_zones = []
         self._notifying_zones = []
         self._notifying_zones = []
         self._log = log
         self._log = log
@@ -100,10 +100,11 @@ class NotifyOut:
         mechanism to cover the changed datasrc.'''
         mechanism to cover the changed datasrc.'''
         self._db_file = datasrc_file
         self._db_file = datasrc_file
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
         for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
-            self._notify_infos[zone_name] = ZoneNotifyInfo(zone_name, zone_class)
+            zone_id = (zone_name, zone_class)
+            self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
             slaves = self._get_notify_slaves_from_ns(zone_name)
             slaves = self._get_notify_slaves_from_ns(zone_name)
             for item in slaves:
             for item in slaves:
-                self._notify_infos[zone_name]._notify_slaves.append((item, 53))
+                self._notify_infos[zone_id]._notify_slaves.append((item, 53))
 
 
     def _get_rdata_data(self, rr):
     def _get_rdata_data(self, rr):
         return rr[7].strip()
         return rr[7].strip()
@@ -132,20 +133,22 @@ class NotifyOut:
 
 
         return addr_list
         return addr_list
 
 
-    def send_notify(self, zone_name):
+    def send_notify(self, zone_name, zone_class='IN'):
         if zone_name[len(zone_name) - 1] != '.':
         if zone_name[len(zone_name) - 1] != '.':
             zone_name += '.'
             zone_name += '.'
-        if zone_name not in self._notify_infos:
+
+        zone_id = (zone_name, zone_class)
+        if zone_id not in self._notify_infos:
             return
             return
 
 
         with self._lock:
         with self._lock:
-            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_name in self._notifying_zones):
-                if zone_name not in self._waiting_zones:
-                    self._waiting_zones.append(zone_name)
+            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_id in self._notifying_zones):
+                if zone_id not in self._waiting_zones:
+                    self._waiting_zones.append(zone_id)
             else:
             else:
-                self._notify_infos[zone_name].prepare_notify_out()
+                self._notify_infos[zone_id].prepare_notify_out()
                 self.notify_num += 1 
                 self.notify_num += 1 
-                self._notifying_zones.append(zone_name)
+                self._notifying_zones.append(zone_id)
 
 
     def _wait_for_notify_reply(self):
     def _wait_for_notify_reply(self):
         '''receive notify replies in specified time. returned value 
         '''receive notify replies in specified time. returned value 
@@ -217,11 +220,12 @@ class NotifyOut:
             zone_notify_info.finish_notify_out()
             zone_notify_info.finish_notify_out()
             with self._lock:
             with self._lock:
                 self.notify_num -= 1 
                 self.notify_num -= 1 
-                self._notifying_zones.remove(zone_notify_info.zone_name) 
+                self._notifying_zones.remove((zone_notify_info.zone_name, 
+                                              zone_notify_info.zone_class)) 
                 # trigger notify out for waiting zones
                 # trigger notify out for waiting zones
                 if len(self._waiting_zones) > 0:
                 if len(self._waiting_zones) > 0:
-                    zone_name = self._waiting_zones.pop(0) 
-                    self._notify_infos[zone_name].prepare_notify_out()
+                    zone_id = self._waiting_zones.pop(0) 
+                    self._notify_infos[zone_id].prepare_notify_out()
                     self.notify_num += 1 
                     self.notify_num += 1 
 
 
     def _send_notify_message_udp(self, zone_notify_info, addrinfo):
     def _send_notify_message_udp(self, zone_notify_info, addrinfo):
@@ -268,15 +272,13 @@ class NotifyOut:
 
 
     def _handle_notify_reply(self, zone_notify_info, msg_data):
     def _handle_notify_reply(self, zone_notify_info, msg_data):
         '''Parse the notify reply message.
         '''Parse the notify reply message.
-        TODO, the error message should be refined properly.'''
+        TODO, the error message should be refined properly.
+        rcode will not checked here, If we get the response
+        from the slave, it means the slaves has got the notify.'''
         msg = Message(Message.PARSE)
         msg = Message(Message.PARSE)
         try:
         try:
             errstr = 'notify reply error: '
             errstr = 'notify reply error: '
             msg.from_wire(msg_data)
             msg.from_wire(msg_data)
-            if (msg.get_rcode() != Rcode.NOERROR()):
-                self._log_msg('error', errstr + 'bad rcode')
-                return False
-
             if not msg.get_header_flag(MessageFlag.QR()):
             if not msg.get_header_flag(MessageFlag.QR()):
                 self._log_msg('error', errstr + 'bad flags')
                 self._log_msg('error', errstr + 'bad flags')
                 return False
                 return False

+ 33 - 22
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -44,11 +44,13 @@ class TestNotifyOut(unittest.TestCase):
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
         self._notify = notify_out.NotifyOut(self._db_file.name)
         self._notify = notify_out.NotifyOut(self._db_file.name)
-        self._notify._notify_infos['com.'] = notify_out.ZoneNotifyInfo('com.', 'IN')
-        self._notify._notify_infos['cn.'] = notify_out.ZoneNotifyInfo('cn.', 'IN')
-        self._notify._notify_infos['org.'] = notify_out.ZoneNotifyInfo('org.', 'IN')
+        self._notify._notify_infos[('com.', 'IN')] = notify_out.ZoneNotifyInfo('com.', 'IN')
+        self._notify._notify_infos[('com.', 'CH')] = notify_out.ZoneNotifyInfo('com.', 'CH')
+        self._notify._notify_infos[('cn.', 'IN')] = notify_out.ZoneNotifyInfo('cn.', 'IN')
+        self._notify._notify_infos[('org.', 'IN')] = notify_out.ZoneNotifyInfo('org.', 'IN')
+        self._notify._notify_infos[('org.', 'CH')] = notify_out.ZoneNotifyInfo('org.', 'CH')
         
         
-        info = self._notify._notify_infos['cn.']
+        info = self._notify._notify_infos[('cn.', 'IN')]
         info._notify_slaves.append(('127.0.0.1', 53))
         info._notify_slaves.append(('127.0.0.1', 53))
         info._notify_slaves.append(('1.1.1.1', 5353))
         info._notify_slaves.append(('1.1.1.1', 5353))
 
 
@@ -60,18 +62,26 @@ class TestNotifyOut(unittest.TestCase):
     def test_send_notify(self):
     def test_send_notify(self):
         self._notify.send_notify('cn')
         self._notify.send_notify('cn')
         self.assertEqual(self._notify.notify_num, 1)
         self.assertEqual(self._notify.notify_num, 1)
-        self.assertEqual(self._notify._notifying_zones[0], 'cn.')
+        self.assertEqual(self._notify._notifying_zones[0], ('cn.','IN'))
 
 
         self._notify.send_notify('com')
         self._notify.send_notify('com')
         self.assertEqual(self._notify.notify_num, 2)
         self.assertEqual(self._notify.notify_num, 2)
-        self.assertEqual(self._notify._notifying_zones[1], 'com.')
+        self.assertEqual(self._notify._notifying_zones[1], ('com.','IN'))
+
+        notify_out._MAX_NOTIFY_NUM = 3
+        self._notify.send_notify('com', 'CH')
+        self.assertEqual(self._notify.notify_num, 3)
+        self.assertEqual(self._notify._notifying_zones[2], ('com.','CH'))
     
     
-        notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
         self._notify.send_notify('org.')
-        self.assertEqual(self._notify._waiting_zones[0], 'org.')
+        self.assertEqual(self._notify._waiting_zones[0], ('org.', 'IN'))
         self._notify.send_notify('org.')
         self._notify.send_notify('org.')
         self.assertEqual(1, len(self._notify._waiting_zones))
         self.assertEqual(1, len(self._notify._waiting_zones))
 
 
+        self._notify.send_notify('org.', 'CH')
+        self.assertEqual(2, len(self._notify._waiting_zones))
+        self.assertEqual(self._notify._waiting_zones[1], ('org.', 'CH'))
+
     def test_wait_for_notify_reply(self):
     def test_wait_for_notify_reply(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
         self._notify.send_notify('com.')
@@ -84,9 +94,9 @@ class TestNotifyOut(unittest.TestCase):
 
 
         # Now make one socket be readable
         # Now make one socket be readable
         addr = ('localhost', 12340)
         addr = ('localhost', 12340)
-        self._notify._notify_infos['cn.']._sock.bind(addr)
-        self._notify._notify_infos['cn.'].notify_timeout = time.time() + 10
-        self._notify._notify_infos['com.'].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('cn.', 'IN')]._sock.bind(addr)
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('com.', 'IN')].notify_timeout = time.time() + 10
         
         
         send_fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         send_fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         #Send some data to socket 12340, to make the target socket be readable
         #Send some data to socket 12340, to make the target socket be readable
@@ -94,17 +104,18 @@ class TestNotifyOut(unittest.TestCase):
         replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
         replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
         self.assertEqual(len(replied_zones), 1)
         self.assertEqual(len(replied_zones), 1)
         self.assertEqual(len(timeout_zones), 1)
         self.assertEqual(len(timeout_zones), 1)
-        self.assertTrue('cn.' in replied_zones.keys())
-        self.assertTrue('com.' in timeout_zones.keys())
-        self.assertLess(time.time(), self._notify._notify_infos['com.'].notify_timeout)
+        self.assertTrue(('cn.', 'IN') in replied_zones.keys())
+        self.assertTrue(('com.', 'IN') in timeout_zones.keys())
+        self.assertLess(time.time(), self._notify._notify_infos[('com.', 'IN')].notify_timeout)
     
     
     def test_notify_next_target(self):
     def test_notify_next_target(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
         self._notify.send_notify('com.')
         notify_out._MAX_NOTIFY_NUM = 2
         notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
         self._notify.send_notify('org.')
+        self._notify.send_notify('com.', 'CH')
 
 
-        info = self._notify._notify_infos['cn.']
+        info = self._notify._notify_infos[('cn.', 'IN')]
         self._notify._notify_next_target(info)
         self._notify._notify_next_target(info)
         self.assertEqual(0, info.notify_try_num)
         self.assertEqual(0, info.notify_try_num)
         self.assertEqual(info.get_current_notify_target(), ('1.1.1.1', 5353))
         self.assertEqual(info.get_current_notify_target(), ('1.1.1.1', 5353))
@@ -114,22 +125,22 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual(0, info.notify_try_num)
         self.assertEqual(0, info.notify_try_num)
         self.assertIsNone(info.get_current_notify_target())
         self.assertIsNone(info.get_current_notify_target())
         self.assertEqual(2, self._notify.notify_num)
         self.assertEqual(2, self._notify.notify_num)
-        self.assertEqual(0, len(self._notify._waiting_zones))
+        self.assertEqual(1, len(self._notify._waiting_zones))
 
 
-        com_info = self._notify._notify_infos['com.']
+        com_info = self._notify._notify_infos[('com.', 'IN')]
         self._notify._notify_next_target(com_info)
         self._notify._notify_next_target(com_info)
-        self.assertEqual(1, self._notify.notify_num)
+        self.assertEqual(2, self._notify.notify_num)
         self.assertEqual(0, len(self._notify._notifying_zones))
         self.assertEqual(0, len(self._notify._notifying_zones))
     
     
     def test_handle_notify_reply(self):
     def test_handle_notify_reply(self):
         self.assertFalse(self._notify._handle_notify_reply(None, b'badmsg'))
         self.assertFalse(self._notify._handle_notify_reply(None, b'badmsg'))
-        com_info = self._notify._notify_infos['com.']
+        com_info = self._notify._notify_infos[('com.', 'IN')]
         com_info.notify_msg_id = 0X2f18
         com_info.notify_msg_id = 0X2f18
         data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02tw\x02cn\x00\x00\x06\x00\x01'
         data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02tw\x02cn\x00\x00\x06\x00\x01'
         self.assertTrue(self._notify._handle_notify_reply(com_info, data))
         self.assertTrue(self._notify._handle_notify_reply(com_info, data))
 
 
     def test_send_notify_message_udp(self):
     def test_send_notify_message_udp(self):
-        com_info = self._notify._notify_infos['cn.']
+        com_info = self._notify._notify_infos[('cn.', 'IN')]
         com_info.prepare_notify_out()
         com_info.prepare_notify_out()
         ret = self._notify._send_notify_message_udp(com_info, ('1.1.1.1', 53))
         ret = self._notify._send_notify_message_udp(com_info, ('1.1.1.1', 53))
         self.assertTrue(ret)
         self.assertTrue(ret)
@@ -144,7 +155,7 @@ class TestNotifyOut(unittest.TestCase):
         notify_out._MAX_NOTIFY_NUM = 2
         notify_out._MAX_NOTIFY_NUM = 2
         self._notify.send_notify('org.')
         self._notify.send_notify('org.')
 
 
-        cn_info = self._notify._notify_infos['cn.']
+        cn_info = self._notify._notify_infos[('cn.', 'IN')]
         cn_info.prepare_notify_out()
         cn_info.prepare_notify_out()
 
 
         cn_info.notify_try_num = 2
         cn_info.notify_try_num = 2
@@ -211,7 +222,7 @@ class TestNotifyOut(unittest.TestCase):
     def test_init_notify_out(self):
     def test_init_notify_out(self):
         self._notify._init_notify_out(self._db_file.name)
         self._notify._init_notify_out(self._db_file.name)
         self.assertListEqual([('3.3.3.3', 53), ('4:4.4.4', 53), ('5:5.5.5', 53)], 
         self.assertListEqual([('3.3.3.3', 53), ('4:4.4.4', 53), ('5:5.5.5', 53)], 
-                             self._notify._notify_infos['com.']._notify_slaves)
+                             self._notify._notify_infos[('com.', 'IN')]._notify_slaves)
         
         
 if __name__== "__main__":
 if __name__== "__main__":
     unittest.main()
     unittest.main()