Browse Source

Update the code according stephen's review result.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac289@2687 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 14 years ago
parent
commit
b18fb328d5

+ 7 - 0
src/lib/python/isc/datasrc/sqlite3_ds.py

@@ -18,6 +18,13 @@
 import sqlite3, re, random
 import isc
 
+
+#define the index of different part of one record
+RR_TYPE_INDEX = 5
+RR_NAME_INDEX = 2
+RR_TTL_INDEX = 4
+RR_RDATA_INDEX = 7
+
 #########################################################################
 # define exceptions
 #########################################################################

+ 0 - 0
src/lib/python/isc/notify/TODO


+ 117 - 57
src/lib/python/isc/notify/notify_out.py

@@ -32,46 +32,63 @@ _MAX_NOTIFY_TRY_NUM = 5
 _EVENT_NONE = 0
 _EVENT_READ = 1
 _EVENT_TIMEOUT = 2
-_NOTIFY_TIMEOUT = 2
+_NOTIFY_TIMEOUT = 1
+_IDLE_SLEEP_TIME = 0.5
+
+# define the rcode for parsing notify reply message
+_REPLY_OK = 0
+_BAD_QUERY_ID = 1
+_BAD_QUERY_NAME = 2
+_BAD_OPCODE = 3
+_BAD_QR = 4
+_BAD_REPLY_PACKET = 5
 
 def addr_to_str(addr):
     return '%s#%s' % (addr[0], addr[1])
 
 def dispatcher(notifier):
+    '''The loop function for handling notify related events.
+    If one zone get the notify reply before timeout, call the
+    handle to process the reply. If one zone can't get the notify
+    before timeout, call the handler to resend notify or notify 
+    next slave.  
+    notifier: one object of class NotifyOut. '''
     while True:
         replied_zones, not_replied_zones = notifier._wait_for_notify_reply()
         if len(replied_zones) == 0 and len(not_replied_zones) == 0:
-            time.sleep(0.5) # A better time?
+            time.sleep(_IDLE_SLEEP_TIME) #TODO set a better time for idle sleep
             continue
 
         for name_ in replied_zones:
             notifier._zone_notify_handler(replied_zones[name_], _EVENT_READ)
             
         for name_ in not_replied_zones:
-            if not_replied_zones[name_].notify_timeout < time.time():
+            if not_replied_zones[name_].notify_timeout <= time.time():
                 notifier._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
  
 class ZoneNotifyInfo:
-    '''This class keeps track of notify-out information for one zone.
-    timeout_: absolute time for next notify reply.
-    '''    
-    def __init__(self, zone_name_, klass):
-        self._notify_slaves = []
+    '''This class keeps track of notify-out information for one zone.'''
+
+    def __init__(self, zone_name_, class_):
+        '''notify_timeout_: absolute time for next notify reply. when the zone 
+        is preparing for sending notify message, notify_timeout_ is set to now, 
+        that means the first sending is triggered by the 'Timeout' mechanism. 
+        '''
         self._notify_current = None
         self._slave_index = 0
         self._sock = None
 
+        self.notify_slaves = []
         self.zone_name = zone_name_
-        self.zone_class = klass
+        self.zone_class = class_
         self.notify_msg_id = 0
         self.notify_timeout = 0
-        # Notify times sending to one target.
-        self.notify_try_num = 0 
+        self.notify_try_num = 0  #Notify times sending to one target.
        
     def set_next_notify_target(self):
-        if self._slave_index < (len(self._notify_slaves) - 1):
+        if self._slave_index < (len(self.notify_slaves) - 1):
             self._slave_index += 1
-            self._notify_current = self._notify_slaves[self._slave_index]
+            self._notify_current = self.notify_slaves[self._slave_index]
         else:
             self._notify_current = None
 
@@ -81,8 +98,8 @@ class ZoneNotifyInfo:
         self.notify_timeout = time.time()
         self.notify_try_num = 0
         self._slave_index = 0
-        if len(self._notify_slaves) > 0:
-            self._notify_current = self._notify_slaves[0]
+        if len(self.notify_slaves) > 0:
+            self._notify_current = self.notify_slaves[0]
 
     def finish_notify_out(self):
         if self._sock:
@@ -96,6 +113,10 @@ class ZoneNotifyInfo:
         return self._notify_current
 
 class NotifyOut:
+    '''This class is used to handle notify logic for all zones(sending
+    notify message to its slaves).The only interface provided to 
+    the user is send_notify(). the object of this class should be 
+    used together with function dispatcher(). '''
     def __init__(self, datasrc_file, log=None, verbose=True):
         self._notify_infos = {} # key is (zone_name, zone_class)
         self._waiting_zones = []
@@ -119,13 +140,41 @@ class NotifyOut:
             self._notify_infos[zone_id] = ZoneNotifyInfo(zone_name, zone_class)
             slaves = self._get_notify_slaves_from_ns(zone_name)
             for item in slaves:
-                self._notify_infos[zone_id]._notify_slaves.append((item, 53))
+                self._notify_infos[zone_id].notify_slaves.append((item, 53))
+
+    def send_notify(self, zone_name, zone_class='IN'):
+        '''Send notify to one zone's slaves, this function is 
+        the only interface for class NotifyOut which can be called
+        by other object.
+          Internally, the function only set the zone's notify-reply
+        timeout to now, then notify message will be sent out. '''
+        if zone_name[len(zone_name) - 1] != '.':
+            zone_name += '.'
+
+        zone_id = (zone_name, zone_class)
+        if zone_id not in self._notify_infos:
+            return
+
+        with self._lock:
+            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_id in self._notifying_zones):
+                if zone_id not in self._waiting_zones:
+                    self._waiting_zones.append(zone_id)
+            else:
+                self._notify_infos[zone_id].prepare_notify_out()
+                self.notify_num += 1 
+                self._notifying_zones.append(zone_id)
 
     def _get_rdata_data(self, rr):
         return rr[7].strip()
 
     def _get_notify_slaves_from_ns(self, zone_name):
-        '''The simplest way to get the address of slaves, but not correct.
+        '''Get all NS records, then remove the primary master from ns rrset,
+        then use the name in NS record rdata part to get the a/aaaa records
+        in the same zone. the targets listed in a/aaaa record rdata are treated
+        as the notify slaves.
+        Note: this is the simplest way to get the address of slaves, 
+        but not correct, it can't handle the delegation slaves, or the CNAME
+        and DNAME logic.
         TODO. the function should be provided by one library.'''
         ns_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'NS', self._db_file)
         soa_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file)
@@ -134,7 +183,7 @@ class NotifyOut:
             ns_rr_name.append(self._get_rdata_data(ns)) 
        
         if len(soa_rrset) > 0:
-            sname = (soa_rrset[0][7].split(' '))[0].strip() #TODO, bad hardcode to get rdata part
+            sname = (soa_rrset[0][sqlite3_ds.RR_RDATA_INDEX].split(' '))[0].strip() #TODO, bad hardcode to get rdata part
             if sname in ns_rr_name:
                 ns_rr_name.remove(sname)
 
@@ -149,44 +198,44 @@ class NotifyOut:
 
         return addr_list
 
-    def send_notify(self, zone_name, zone_class='IN'):
-        if zone_name[len(zone_name) - 1] != '.':
-            zone_name += '.'
-
-        zone_id = (zone_name, zone_class)
-        if zone_id not in self._notify_infos:
-            return
-
-        with self._lock:
-            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_id in self._notifying_zones):
-                if zone_id not in self._waiting_zones:
-                    self._waiting_zones.append(zone_id)
-            else:
-                self._notify_infos[zone_id].prepare_notify_out()
-                self.notify_num += 1 
-                self._notifying_zones.append(zone_id)
-
-    def _wait_for_notify_reply(self):
-        '''receive notify replies in specified time. returned value 
-        is one tuple:(replied_zones, not_replied_zones)
-        replied_zones: the zones which receive notify reply.
-        not_replied_zones: the zones which haven't got notify reply.
-        '''
+    def _prepare_select_info(self):
+        '''Prepare the information for select(), returned 
+        value is one tuple 
+        (block_timeout, valid_socks, notifying_zones)
+        block_timeout: the timeout for select()
+        valid_socks: sockets list for waiting ready reading.
+        notifying_zones: the zones which have been triggered 
+                        for notify. '''
         valid_socks = []
         notifying_zones = {}
-        min_timeout = time.time()
+        min_timeout = None 
         for info in self._notify_infos:
             sock = self._notify_infos[info].get_socket()
             if sock:
                 valid_socks.append(sock)
                 notifying_zones[info] = self._notify_infos[info]
                 tmp_timeout = self._notify_infos[info].notify_timeout
-                if min_timeout > tmp_timeout:
+                if min_timeout:
+                    if tmp_timeout < min_timeout:
+                        min_timeout = tmp_timeout
+                else:
                     min_timeout = tmp_timeout
+       
+        block_timeout = 0
+        if min_timeout:
+            block_timeout = min_timeout - time.time()
+            if block_timeout < 0:
+                block_timeout = 0
         
-        block_timeout = min_timeout - time.time()
-        if block_timeout < 0:
-            block_timeout = 0
+        return (block_timeout, valid_socks, notifying_zones)
+
+    def _wait_for_notify_reply(self):
+        '''receive notify replies in specified time. returned value 
+        is one tuple:(replied_zones, not_replied_zones)
+        replied_zones: the zones which receive notify reply.
+        not_replied_zones: the zones which haven't got notify reply.
+        '''
+        (block_timeout, valid_socks, notifying_zones) = self._prepare_select_info()
         try:
             r_fds, w, e = select.select(valid_socks, [], [], block_timeout)
         except select.error as err:
@@ -204,6 +253,11 @@ class NotifyOut:
         return replied_zones, not_replied_zones
 
     def _zone_notify_handler(self, zone_notify_info, event_type):
+        '''Notify handler for one zone. The first notify message is 
+        always triggered by the event "_EVENT_TIMEOUT" since when 
+        one zone prepares to notify its slaves, it's notify_timeout 
+        is set to now, which is used to trigger sending notify 
+        message when dispatcher() scanning zones. '''
         tgt = zone_notify_info.get_current_notify_target()
         if event_type == _EVENT_READ:
             reply = self._get_notify_reply(zone_notify_info.get_socket(), tgt)
@@ -261,13 +315,14 @@ class NotifyOut:
 
         return True
 
-    def _create_rrset_from_db_record(self, record):
+    def _create_rrset_from_db_record(self, record, zone_class):
         '''Create one rrset from one record of datasource, if the schema of record is changed, 
         This function should be updated first. TODO, the function is copied from xfrout, there
         should be library for creating one rrset. '''
-        rrtype_ = RRType(record[5])
-        rdata_ = Rdata(rrtype_, RRClass("IN"), " ".join(record[7:]))
-        rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
+        rrtype_ = RRType(record[sqlite3_ds.RR_TYPE_INDEX])
+        rdata_ = Rdata(rrtype_, RRClass(zone_class), " ".join(record[sqlite3_ds.RR_RDATA_INDEX:]))
+        rrset_ = RRset(Name(record[sqlite3_ds.RR_NAME_INDEX]), RRClass(zone_class), \
+                       rrtype_, RRTTL( int(record[sqlite3_ds.RR_TTL_INDEX])))
         rrset_.add_rdata(rdata_)
         return rrset_
 
@@ -282,7 +337,7 @@ class NotifyOut:
         msg.add_question(question)
         # Add soa record to answer section
         soa_record = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file) 
-        rrset_soa = self._create_rrset_from_db_record(soa_record[0])
+        rrset_soa = self._create_rrset_from_db_record(soa_record[0], zone_class)
         msg.add_rrset(Section.ANSWER(), rrset_soa)
         return msg, qid
 
@@ -297,21 +352,26 @@ class NotifyOut:
             msg.from_wire(msg_data)
             if not msg.get_header_flag(MessageFlag.QR()):
                 self._log_msg('error', errstr + 'bad flags')
-                return False
+                return _BAD_QR
 
             if msg.get_qid() != zone_notify_info.notify_msg_id: 
                 self._log_msg('error', errstr + 'bad query ID')
-                return False
+                return _BAD_QUERY_ID
+            
+            question = msg.get_question()[0]
+            if question.get_name() != Name(zone_notify_info.zone_name):
+                self._log_msg('error', errstr + 'bad query name')
+                return _BAD_QUERY_NAME
 
-            if msg.get_opcode != Opcode.NOTIFY():
+            if msg.get_opcode() != Opcode.NOTIFY():
                 self._log_msg('error', errstr + 'bad opcode')
-                return False
+                return _BAD_OPCODE
         except Exception as err:
             # We don't care what exception, just report it? 
             self._log_msg('error', errstr + str(err))
-            return False
+            return _BAD_REPLY_PACKET
 
-        return True
+        return _REPLY_OK
 
     def _get_notify_reply(self, sock, tgt_addr):
         try:

+ 61 - 21
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -35,8 +35,8 @@ class TestZoneNotifyInfo(unittest.TestCase):
         self.assertEqual(self.info._sock, None)
 
     def test_set_next_notify_target(self):
-        self.info._notify_slaves.append(('127.0.0.1', 53))
-        self.info._notify_slaves.append(('1.1.1.1', 5353))
+        self.info.notify_slaves.append(('127.0.0.1', 53))
+        self.info.notify_slaves.append(('1.1.1.1', 5353))
         self.info.prepare_notify_out()
         self.assertEqual(self.info.get_current_notify_target(), ('127.0.0.1', 53))
 
@@ -66,8 +66,8 @@ class TestNotifyOut(unittest.TestCase):
         self._notify._notify_infos[('org.', 'CH')] = notify_out.ZoneNotifyInfo('org.', 'CH')
         
         info = self._notify._notify_infos[('cn.', 'IN')]
-        info._notify_slaves.append(('127.0.0.1', 53))
-        info._notify_slaves.append(('1.1.1.1', 5353))
+        info.notify_slaves.append(('127.0.0.1', 53))
+        info.notify_slaves.append(('1.1.1.1', 5353))
 
     def tearDown(self):
         sys.stdout = self.old_stdout
@@ -148,11 +148,29 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual(0, len(self._notify._notifying_zones))
     
     def test_handle_notify_reply(self):
-        self.assertFalse(self._notify._handle_notify_reply(None, b'badmsg'))
+        self.assertEqual(notify_out._BAD_REPLY_PACKET, self._notify._handle_notify_reply(None, b'badmsg'))
         com_info = self._notify._notify_infos[('com.', 'IN')]
         com_info.notify_msg_id = 0X2f18
-        data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02tw\x02cn\x00\x00\x06\x00\x01'
-        self.assertTrue(self._notify._handle_notify_reply(com_info, data))
+
+        # test with right notify reply message
+        data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03com\x00\x00\x06\x00\x01'
+        self.assertEqual(notify_out._REPLY_OK, self._notify._handle_notify_reply(com_info, data))
+
+        # test with unright query id
+        data = b'\x2e\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03com\x00\x00\x06\x00\x01'
+        self.assertEqual(notify_out._BAD_QUERY_ID, self._notify._handle_notify_reply(com_info, data))
+
+        # test with unright query name
+        data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02cn\x00\x00\x06\x00\x01'
+        self.assertEqual(notify_out._BAD_QUERY_NAME, self._notify._handle_notify_reply(com_info, data))
+
+        # test with unright opcode
+        data = b'\x2f\x18\x80\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03com\x00\x00\x06\x00\x01'
+        self.assertEqual(notify_out._BAD_OPCODE, self._notify._handle_notify_reply(com_info, data))
+
+        # test with unright qr
+        data = b'\x2f\x18\x10\x10\x00\x01\x00\x00\x00\x00\x00\x00\x03com\x00\x00\x06\x00\x01'
+        self.assertEqual(notify_out._BAD_QR, self._notify._handle_notify_reply(com_info, data))
 
     def test_send_notify_message_udp(self):
         com_info = self._notify._notify_infos[('cn.', 'IN')]
@@ -194,13 +212,13 @@ class TestNotifyOut(unittest.TestCase):
         ('cn.',         '1000',  'IN',  'NS',  'b.dns.cn.'),
         ('cn.',         '1000',  'IN',  'NS',  'c.dns.cn.'),
         ('a.dns.cn.',   '1000',  'IN',  'A',    '1.1.1.1'),
-        ('a.dns.cn.',   '1000',  'IN',  'AAAA', '2.2.2.2'),
+        ('a.dns.cn.',   '1000',  'IN',  'AAAA', '2:2::2:2'),
         ('b.dns.cn.',   '1000',  'IN',  'A',    '3.3.3.3'),
-        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '4:4.4.4'),
-        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '5:5.5.5'),
+        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '4:4::4:4'),
+        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '5:5::5:5'),
         ('c.dns.cn.',   '1000',  'IN',  'A',    '6.6.6.6'),
         ('c.dns.cn.',   '1000',  'IN',  'A',    '7.7.7.7'),
-        ('c.dns.cn.',   '1000',  'IN',  'AAAA', '8:8.8.8')]
+        ('c.dns.cn.',   '1000',  'IN',  'AAAA', '8:8::8:8')]
         for item in zone_data:
             yield item
 
@@ -212,33 +230,55 @@ class TestNotifyOut(unittest.TestCase):
         ('com.',         '1000',  'IN',  'NS',  'c.dns.com.'),
         ('a.dns.com.',   '1000',  'IN',  'A',    '1.1.1.1'),
         ('b.dns.com.',   '1000',  'IN',  'A',    '3.3.3.3'),
-        ('b.dns.com.',   '1000',  'IN',  'AAAA', '4:4.4.4'),
-        ('b.dns.com.',   '1000',  'IN',  'AAAA', '5:5.5.5')]
+        ('b.dns.com.',   '1000',  'IN',  'AAAA', '4:4::4:4'),
+        ('b.dns.com.',   '1000',  'IN',  'AAAA', '5:5::5:5')]
         for item in zone_data:
             yield item
 
     def test_get_notify_slaves_from_ns(self):
         records = self._notify._get_notify_slaves_from_ns('cn.')
         self.assertEqual(6, len(records))
-        self.assertEqual('8:8.8.8', records[5])
+        self.assertEqual('8:8::8:8', records[5])
         self.assertEqual('7.7.7.7', records[4])
         self.assertEqual('6.6.6.6', records[3])
-        self.assertEqual('5:5.5.5', records[2])
-        self.assertEqual('4:4.4.4', records[1])
+        self.assertEqual('5:5::5:5', records[2])
+        self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('3.3.3.3', records[0])
 
         records = self._notify._get_notify_slaves_from_ns('com.')
-        print('=============', records)
         self.assertEqual(3, len(records))
-        self.assertEqual('5:5.5.5', records[2])
-        self.assertEqual('4:4.4.4', records[1])
+        self.assertEqual('5:5::5:5', records[2])
+        self.assertEqual('4:4::4:4', records[1])
         self.assertEqual('3.3.3.3', records[0])
     
     def test_init_notify_out(self):
         self._notify._init_notify_out(self._db_file.name)
-        self.assertListEqual([('3.3.3.3', 53), ('4:4.4.4', 53), ('5:5.5.5', 53)], 
-                             self._notify._notify_infos[('com.', 'IN')]._notify_slaves)
+        self.assertListEqual([('3.3.3.3', 53), ('4:4::4:4', 53), ('5:5::5:5', 53)], 
+                             self._notify._notify_infos[('com.', 'IN')].notify_slaves)
         
+    def test_prepare_select_info(self):
+        timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
+        self.assertEqual(0, timeout)
+        self.assertListEqual([], valid_fds)
+
+        self._notify._notify_infos[('cn.', 'IN')]._sock = 1
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() + 5
+        timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
+        self.assertGreater(timeout, 0)
+        self.assertListEqual([1], valid_fds)
+
+        self._notify._notify_infos[('cn.', 'IN')]._sock = 1
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() - 5
+        timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
+        self.assertEqual(timeout, 0)
+        self.assertListEqual([1], valid_fds)
+
+        self._notify._notify_infos[('com.', 'IN')]._sock = 2
+        self._notify._notify_infos[('com.', 'IN')].notify_timeout = time.time() + 5
+        timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
+        self.assertEqual(timeout, 0)
+        self.assertListEqual([2, 1], valid_fds)
+
 if __name__== "__main__":
     unittest.main()