Browse Source

commit the code of notify-out. TODO:merge the code of secondary manager(in branch 215) to this branch, so that it's easy do the test.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac289@2611 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 14 years ago
parent
commit
ea33f60613

+ 3 - 0
configure.ac

@@ -423,6 +423,8 @@ AC_CONFIG_FILES([Makefile
                  src/lib/python/isc/config/tests/Makefile
                  src/lib/python/isc/log/Makefile
                  src/lib/python/isc/log/tests/Makefile
+                 src/lib/python/isc/notify/Makefile
+                 src/lib/python/isc/notify/tests/Makefile
                  src/lib/config/Makefile
                  src/lib/config/tests/Makefile
                  src/lib/dns/Makefile
@@ -469,6 +471,7 @@ AC_OUTPUT([src/bin/cfgmgr/b10-cfgmgr.py
            src/lib/python/isc/config/tests/config_test
            src/lib/python/isc/cc/tests/cc_test
            src/lib/python/isc/log/tests/log_test
+           src/lib/python/isc/notify/tests/notify_out_test
            src/lib/dns/gen-rdatacode.py
            src/lib/python/bind10_config.py
            src/lib/dns/tests/testdata/gen-wiredata.py

+ 24 - 11
src/bin/xfrin/xfrin.py.in

@@ -28,6 +28,7 @@ import socket
 import random
 from optparse import OptionParser, OptionValueError
 from isc.config.ccsession import *
+from isc.notify import notify_out
 try:
     from libdns_python import *
 except ImportError as e:
@@ -49,7 +50,7 @@ else:
 SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"
 AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + "/auth.spec"
 
-
+XFROUT_MODULE_NAME = 'Xfrout'
 __version__ = 'BIND10'
 # define xfrin rcode
 XFRIN_OK = 0
@@ -66,7 +67,7 @@ class XfrinException(Exception):
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''    
 
-    def __init__(self,
+    def __init__(self, server_,
                  sock_map, zone_name, rrclass, db_file, shutdown_event,
                  master_addrinfo, verbose = False, idle_timeout = 60): 
         ''' idle_timeout: max idle time for read data from socket.
@@ -77,6 +78,7 @@ class XfrinConnection(asyncore.dispatcher):
         asyncore.dispatcher.__init__(self, map=sock_map)
         self.create_socket(master_addrinfo[0], master_addrinfo[1])
         self._zone_name = zone_name
+        self._server = server_
         self._sock_map = sock_map
         self._rrclass = rrclass
         self._db_file = db_file
@@ -192,6 +194,7 @@ class XfrinConnection(asyncore.dispatcher):
                                             self._handle_xfrin_response)
 
                 self.log_msg(logstr + 'succeeded')
+                self._server.send_notify_command(self._zone_name)
                 ret = XFRIN_OK
 
         except XfrinException as e:
@@ -316,11 +319,11 @@ class XfrinConnection(asyncore.dispatcher):
             sys.stdout.write('[b10-xfrin] %s\n' % str(msg))
 
 
-def process_xfrin(xfrin_recorder, zone_name, rrclass, db_file, 
+def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file, 
                   shutdown_event, master_addrinfo, check_soa, verbose):
     xfrin_recorder.increment(zone_name)
     sock_map = {}
-    conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
+    conn = XfrinConnection(server, sock_map, zone_name, rrclass, db_file,
                            shutdown_event, master_addrinfo, verbose)
     if conn.connect_to_master():
         conn.do_xfrin(check_soa)
@@ -370,17 +373,20 @@ This method is used only as part of initialization, but is implemented
 separately for convenience of unit tests; by letting the test code override
 this method we can test most of this class without requiring a command channel.
 '''
-        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
+        # Create one session for sending command to other modules, because the 
+        # listening session will block the send operation.
+        self._send_cc_session = isc.cc.Session()
+        self._module_cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                               self.config_handler,
                                               self.command_handler)
-        self._cc.start()
+        self._module_cc.start()
 
     def _cc_check_command(self):
         '''
 This is a straightforward wrapper for cc.check_command, but provided as
 a separate method for the convenience of unit tests.
 '''
-        self._cc.check_command()
+        self._module_cc.check_command()
 
     def config_handler(self, new_config):
         # TODO, process new config data
@@ -420,6 +426,12 @@ a separate method for the convenience of unit tests.
 
         return answer
 
+    def send_notify_command(self, zone_name):
+        '''Send Notify command to xfrout module.'''
+        param = {'zone_name': zone_name}
+        msg = create_command(notify_out.ZONE_NOTIFY_CMD, param)
+        self._send_cc_session.group_sendmsg(msg, XFROUT_MODULE_NAME)
+
     def _parse_cmd_params(self, args):
         zone_name = args.get('zone_name')
         if not zone_name:
@@ -441,14 +453,14 @@ a separate method for the convenience of unit tests.
             # should add it on start, and not remove it here
             # (or, if we have writable ds, we might not need this in
             # the first place)
-            self._cc.add_remote_config(AUTH_SPECFILE_LOCATION)
-            db_file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
+            self._module_cc.add_remote_config(AUTH_SPECFILE_LOCATION)
+            db_file, is_default = self._module_cc.get_remote_config_value("Auth", "database_file")
             if is_default and "B10_FROM_BUILD" in os.environ:
                 # this too should be unnecessary, but currently the
                 # 'from build' override isn't stored in the config
                 # (and we don't have writable datasources yet)
                 db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
-            self._cc.remove_remote_config(AUTH_SPECFILE_LOCATION)
+            self._module_cc.remove_remote_config(AUTH_SPECFILE_LOCATION)
 
         return (zone_name, master_addrinfo, db_file)
 
@@ -469,7 +481,8 @@ a separate method for the convenience of unit tests.
             return (1, 'zone xfrin is in progress')
 
         xfrin_thread = threading.Thread(target = process_xfrin,
-                                        args = (self.recorder,
+                                        args = (self,
+                                                self.recorder,
                                                 zone_name, rrclass,
                                                 db_file,
                                                 self._shutdown_event,

+ 22 - 3
src/bin/xfrout/xfrout.py.in

@@ -28,6 +28,7 @@ import os
 from isc.config.ccsession import *
 from isc.log.log import *
 from isc.cc import SessionError
+from isc.notify import notify_out
 import socket
 import select
 import errno
@@ -303,7 +304,7 @@ class UnixSockServer(ThreadingUnixStreamServer):
         self._log = log
         self.update_config_data(config_data)
         self._cc = cc
-
+        
     def finish_request(self, request, client_address):
         '''Finish one request by instantiating RequestHandlerClass.'''
         self.RequestHandlerClass(request, client_address, self, self._log)
@@ -415,16 +416,25 @@ class XfroutServer:
                                 self._config_data.get('log_severity'), self._config_data.get('log_versions'),
                                 self._config_data.get('log_max_bytes'), True)
         self._start_xfr_query_listener()
+        self._start_notifier()
 
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
-    
         self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession, 
                                                   self._shutdown_event, self._config_data,
                                                   self._cc, self._log);
         listener = threading.Thread(target = listen_on_xfr_query, args = (self._unix_socket_server,))
         listener.start()
+        
+    def _start_notifier(self):
+        datasrc = self._unix_socket_server.get_db_file()
+        self._notifier = notify_out.NotifyOut(datasrc, self._log)
+        td = threading.Thread(target = notify_out.dispatcher, args = (self._notifier,))
+        td.daemon = True
+        td.start()
 
+    def send_notify(self, zone_name):
+        self._notifier.send_notify(zone_name)
 
     def config_handler(self, new_config):
         '''Update config data. TODO. Do error check'''
@@ -466,11 +476,20 @@ class XfroutServer:
             self._log.log_message("info", "Received shutdown command.")
             self.shutdown()
             answer = create_answer(0)
+        
+        elif cmd == notify_out.ZONE_NOTIFY_CMD:
+            zone_name = args.get('zone_name')
+            if zone_name:
+                self._log.log_message("info", "Receive notify command for zone " + zone_name)
+                self.send_notify(zone_name)
+                answer = create_answer(0)
+            else:
+                answer = create_answer(1, "Bad command parameter:" + str(args))
+
         else: 
             answer = create_answer(1, "Unknown command:" + str(cmd))
 
         return answer    
- 
 
     def run(self):
         '''Get and process all commands sent from cfgmgr or other modules. '''

+ 1 - 1
src/lib/python/isc/Makefile.am

@@ -1,4 +1,4 @@
-SUBDIRS = datasrc cc config log # Util
+SUBDIRS = datasrc cc config log notify # Util
 
 python_PYTHON = __init__.py
 

+ 33 - 0
src/lib/python/isc/datasrc/sqlite3_ds.py

@@ -120,6 +120,39 @@ def get_zone_soa(zonename, dbfile):
 
     return datas
 
+
+#########################################################################
+# get_zone_rrset
+#   returns the rrset of the zone with the given zone name, rrset name 
+#   and given rd type. 
+#   If the zone doesn't exist or rd type doesn't exist, return an empty list. 
+#########################################################################
+def get_zone_rrset(zonename, rr_name, rdtype, dbfile):
+    conn, cur = open(dbfile)
+    id = get_zoneid(zonename, cur)
+    cur.execute("SELECT * FROM records WHERE name = ? and zone_id = ? and rdtype = ?", 
+                [rr_name, id, rdtype])
+    datas = cur.fetchall()
+    cur.close()
+    conn.close()
+    return datas
+
+
+#########################################################################
+# get_zones_info:
+#   returns all the zones' information.
+#########################################################################
+def get_zones_info(db_file):
+    conn, cur = open(db_file)
+    cur.execute("SELECT name, rdclass FROM zones")
+    info = cur.fetchone()
+    while info:
+        yield info
+        info = cur.fetchone()
+
+    cur.close()
+    conn.close()
+
 #########################################################################
 # get_zoneid:
 #   returns the zone_id for a given zone name, or an empty

+ 5 - 0
src/lib/python/isc/notify/Makefile.am

@@ -0,0 +1,5 @@
+SUBDIRS = tests
+
+python_PYTHON = __init__.py notify_out.py
+
+pythondir = $(pyexecdir)/isc/notify

+ 1 - 0
src/lib/python/isc/notify/__init__.py

@@ -0,0 +1 @@
+from isc.notify.notify_out import *

+ 313 - 0
src/lib/python/isc/notify/notify_out.py

@@ -0,0 +1,313 @@
+import select
+import random
+import socket
+import threading
+import time
+from isc.datasrc import sqlite3_ds
+import isc
+try: 
+    from libdns_python import * 
+except ImportError as e: 
+    # C++ loadable module may not be installed; 
+    sys.stderr.write('[b10-xfrout] failed to import DNS or XFR module: %s\n' % str(e)) 
+
+ZONE_NOTIFY_CMD = 'zone_new_data_ready'
+_MAX_NOTIFY_NUM = 30
+_MAX_NOTIFY_TRY_NUM = 5
+_EVENT_NONE = 0
+_EVENT_READ = 1
+_EVENT_TIMEOUT = 2
+_NOTIFY_TIMEOUT = 2
+
+def addr_to_str(addr):
+    return '%s#%s' % (addr[0], addr[1])
+
+def dispatcher(notifier):
+    while True:
+        replied_zones, not_replied_zones = notifier._wait_for_notify_reply()
+        if len(replied_zones) == 0 and len(not_replied_zones) == 0:
+            time.sleep(0.5) # A better time?
+            continue
+
+        for name_ in replied_zones:
+            notifier._zone_notify_handler(replied_zones[name_], _EVENT_READ)
+            
+        for name_ in not_replied_zones:
+            if not_replied_zones[name_].notify_timeout < time.time():
+                notifier._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
+ 
+class ZoneNotifyInfo:
+    '''This class keeps track of notify-out information for one zone.
+    timeout_: absolute time for next notify reply.
+    '''    
+    def __init__(self, zone_name_, klass):
+        self._notify_slaves = []
+        self._notify_current = None
+        self._slave_index = 0
+        self._sock = None
+
+        self.zone_name = zone_name_
+        self.zone_class = klass
+        self.notify_msg_id = 0
+        self.notify_timeout = 0
+        # Notify times sending to one target.
+        self.notify_try_num = 0 
+       
+    def set_next_notify_target(self):
+        if self._slave_index < (len(self._notify_slaves) - 1):
+            self._slave_index += 1
+            self._notify_current = self._notify_slaves[self._slave_index]
+        else:
+            self._notify_current = None
+
+    def prepare_notify_out(self):
+        '''Create the socket and set notify timeout time to now'''
+        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) #TODO support IPv6?
+        self.notify_timeout = time.time()
+        self.notify_try_num = 0
+        self._slave_index = 0
+        if len(self._notify_slaves) > 0:
+            self._notify_current = self._notify_slaves[0]
+
+    def finish_notify_out(self):
+        if self._sock:
+            self._sock.close()
+            self._sock = None
+
+    def get_socket(self):
+        return self._sock
+
+    def get_current_notify_target(self):
+        return self._notify_current
+
+class NotifyOut:
+    def __init__(self, datasrc_file, log=None, verbose=True):
+        self._notify_infos = {}
+        self._waiting_zones = []
+        self._notifying_zones = []
+        self._log = log
+        self.notify_num = 0  # the count of in progress notifies
+        self._verbose = verbose
+        self._lock = threading.Lock()
+        self._db_file = datasrc_file
+        self._init_notify_out(datasrc_file)
+
+    def _init_notify_out(self, datasrc_file):
+        '''Get all the zones name and its notify target's address
+        TODO, currently the zones are got by going through the zone 
+        table in database. There should be a better way to get them 
+        and also the setting 'also_notify', and there should be one 
+        mechanism to cover the changed datasrc.'''
+        self._db_file = datasrc_file
+        for zone_name, zone_class in sqlite3_ds.get_zones_info(datasrc_file):
+            self._notify_infos[zone_name] = ZoneNotifyInfo(zone_name, zone_class)
+            slaves = self._get_notify_slaves_from_ns(zone_name)
+            for item in slaves:
+                self._notify_infos[zone_name]._notify_slaves.append((item, 53))
+
+    def _get_rdata_data(self, rr):
+        return rr[7].strip()
+
+    def _get_notify_slaves_from_ns(self, zone_name):
+        '''The simplest way to get the address of slaves, but now correct.
+        TODO. the function should be provided by one library.'''
+        ns_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'NS', self._db_file)
+        soa_rrset = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file)
+        ns_rr_name = []
+        for ns in ns_rrset:
+            ns_rr_name.append(self._get_rdata_data(ns)) 
+        
+        sname = (soa_rrset[0][7].split(' '))[0].strip() #TODO, bad hardcode to get rdata part
+        if sname in ns_rr_name:
+            ns_rr_name.remove(sname)
+
+        addr_list = []
+        for rr_name in ns_rr_name:
+            a_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'A', self._db_file)
+            aaaa_rrset = sqlite3_ds.get_zone_rrset(zone_name, rr_name, 'AAAA', self._db_file)
+            for rr in a_rrset:
+                addr_list.append(self._get_rdata_data(rr))
+            for rr in aaaa_rrset:
+                addr_list.append(self._get_rdata_data(rr))
+
+        return addr_list
+
+    def send_notify(self, zone_name):
+        print('=============begin to send notify', zone_name, '===', self._notify_infos)
+        print(self._notify_infos)
+        if zone_name not in self._notify_infos:
+            print('=============not eixst')
+            return
+
+        print('=============begin to send notify')
+        with self._lock:
+            if (self.notify_num >= _MAX_NOTIFY_NUM) or (zone_name in self._notifying_zones):
+                if zone_name not in self._waiting_zones:
+                    self._waiting_zones.append(zone_name)
+            else:
+                self._notify_infos[zone_name].prepare_notify_out()
+                self.notify_num += 1 
+                self._notifying_zones.append(zone_name)
+
+    def _wait_for_notify_reply(self):
+        '''receive notify replies in specified time. returned value 
+        is one tuple:(replied_zones, not_replied_zones)
+        replied_zones: the zones which receive notify reply.
+        not_replied_zones: the zones which haven't got notify reply.
+        '''
+        valid_socks = []
+        notifying_zones = {}
+        min_timeout = time.time()
+        for info in self._notify_infos:
+            sock = self._notify_infos[info].get_socket()
+            if sock:
+                valid_socks.append(sock)
+                notifying_zones[info] = self._notify_infos[info]
+                tmp_timeout = self._notify_infos[info].notify_timeout
+                if min_timeout > tmp_timeout:
+                    min_timeout = tmp_timeout
+        
+        block_timeout = min_timeout - time.time()
+        if block_timeout < 0:
+            block_timeout = 0
+        try:
+            r_fds, w, e = select.select(valid_socks, [], [], block_timeout)
+        except select.error as err:
+            if err.args[0] != EINTR:
+                return [], []
+        
+        not_replied_zones = {}
+        replied_zones = {}
+        for info in notifying_zones:
+            if notifying_zones[info].get_socket() in r_fds:
+                replied_zones[info] = notifying_zones[info]
+            else:
+                not_replied_zones[info] = notifying_zones[info]
+
+        return replied_zones, not_replied_zones
+
+    def _zone_notify_handler(self, zone_notify_info, event_type):
+        tgt = zone_notify_info.get_current_notify_target()
+        if event_type == _EVENT_READ:
+            reply = self._get_notify_reply(zone_notify_info.get_socket(), tgt)
+            if reply:
+                if self._handle_notify_reply(zone_notify_info, reply):
+                    self._notify_next_target(zone_notify_info)
+
+        elif event_type == _EVENT_TIMEOUT and zone_notify_info.notify_try_num > 0:
+            self._log_msg('info', 'notify retry to %s' % addr_to_str(tgt))
+
+        tgt = zone_notify_info.get_current_notify_target()
+        if tgt:
+            zone_notify_info.notify_try_num += 1
+            if zone_notify_info.notify_try_num > _MAX_NOTIFY_TRY_NUM:
+                self._log_msg('info', 'notify to %s: retried exceeded' % addr_to_str(tgt))
+                self._notify_next_target(zone_notify_info)
+            else:
+                retry_timeout = _NOTIFY_TIMEOUT * pow(2, zone_notify_info.notify_try_num)
+                # set exponential backoff according rfc1996 section 3.6
+                zone_notify_info.notify_timeout = time.time() + retry_timeout
+                self._send_notify_message_udp(zone_notify_info, tgt)
+
+    def _notify_next_target(self, zone_notify_info):
+        '''Notify next address for the same zone. If all the targets 
+        has been notified, notify the first zone in waiting list. '''
+        zone_notify_info.notify_try_num = 0
+        zone_notify_info.set_next_notify_target()
+        tgt = zone_notify_info.get_current_notify_target()
+        if not tgt:
+            zone_notify_info.finish_notify_out()
+            with self._lock:
+                self.notify_num -= 1 
+                self._notifying_zones.remove(zone_notify_info.zone_name) 
+                # trigger notify out for waiting zones
+                if len(self._waiting_zones) > 0:
+                    zone_name = self._waiting_zones.pop(0) 
+                    self._notify_infos[zone_name].prepare_notify_out()
+                    self.notify_num += 1 
+
+    def _send_notify_message_udp(self, zone_notify_info, addrinfo):
+        msg, qid = self._create_notify_message(zone_notify_info.zone_name, 
+                                               zone_notify_info.zone_class)
+        render = MessageRenderer()
+        render.set_length_limit(512) 
+        msg.to_wire(render)
+        zone_notify_info.notify_msg_id = qid
+        sock = zone_notify_info.get_socket()
+        try:
+            sock.sendto(render.get_data(), 0, addrinfo)
+            self._log_msg('info', 'sending notify to %s' % addr_to_str(addrinfo))
+        except socket.error as err:
+            self._log_msg('error', 'send notify to %s failed: %s' % (addr_to_str(addrinfo), str(err)))
+            return False
+
+        return True
+
+    def _create_rrset_from_db_record(self, record):
+        '''Create one rrset from one record of datasource, if the schema of record is changed, 
+        This function should be updated first. TODO, the function is copied from xfrout, there
+        should be library for creating one rrset. '''
+        rrtype_ = RRType(record[5])
+        rdata_ = Rdata(rrtype_, RRClass("IN"), " ".join(record[7:]))
+        rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
+        rrset_.add_rdata(rdata_)
+        return rrset_
+
+    def _create_notify_message(self, zone_name, zone_class):
+        msg = Message(Message.RENDER)
+        qid = random.randint(0, 0xFFFF)
+        msg.set_qid(qid)
+        msg.set_opcode(Opcode.NOTIFY())
+        msg.set_rcode(Rcode.NOERROR())
+        msg.set_header_flag(MessageFlag.AA())
+        question = Question(Name(zone_name), RRClass(zone_class), RRType('SOA'))
+        msg.add_question(question)
+        # Add soa record to answer section
+        soa_record = sqlite3_ds.get_zone_rrset(zone_name, zone_name, 'SOA', self._db_file) 
+        rrset_soa = self._create_rrset_from_db_record(soa_record[0])
+        msg.add_rrset(Section.ANSWER(), rrset_soa)
+        return msg, qid
+
+    def _handle_notify_reply(self, zone_notify_info, msg_data):
+        '''Parse the notify reply message.
+        TODO, the error message should be refined properly.'''
+        msg = Message(Message.PARSE)
+        try:
+            errstr = 'notify reply error: '
+            msg.from_wire(msg_data)
+            if (msg.get_rcode() != Rcode.NOERROR()):
+                self._log_msg('error', errstr + 'bad rcode')
+                return False
+
+            if not msg.get_header_flag(MessageFlag.QR()):
+                self._log_msg('error', errstr + 'bad flags')
+                return False
+
+            if msg.get_qid() != zone_notify_info.notify_msg_id: 
+                self._log_msg('error', errstr + 'bad query ID')
+                return False
+
+            if msg.get_opcode != Opcode.NOTIFY():
+                self._log_msg('error', errstr + 'bad opcode')
+                return False
+        except Exception as err:
+            # We don't care what exception, just report it? 
+            self._log_msg('error', errstr + str(err))
+            return False
+
+        return True
+
+    def _get_notify_reply(self, sock, tgt_addr):
+        try:
+            msg, addr = sock.recvfrom(512)
+        except socket.error:
+            self._log_msg('error', "notify to %s failed: can't read notify reply" % addr_to_str(tgt_addr))
+            return None
+
+        return msg
+
+
+    def _log_msg(self, level, msg):
+        if self._log:
+            self._log.log_message(level, msg)
+

+ 12 - 0
src/lib/python/isc/notify/tests/Makefile.am

@@ -0,0 +1,12 @@
+PYTESTS = notify_out_test.py
+EXTRA_DIST = $(PYTESTS)
+
+# later will have configure option to choose this, like: coverage run --branch
+PYCOVERAGE = $(PYTHON)
+# test using command-line arguments, so use check-local target instead of TESTS
+check-local:
+	for pytest in $(PYTESTS) ; do \
+	echo Running test: $$pytest ; \
+	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python \
+	$(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \
+	done

+ 28 - 0
src/lib/python/isc/notify/tests/notify_out_test.in

@@ -0,0 +1,28 @@
+#! /bin/sh
+
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+PYTHON_EXEC=${PYTHON_EXEC:-@PYTHON@}
+export PYTHON_EXEC
+
+NOTIFY_OUT_PATH=@abs_top_srcdir@/src/lib/python/isc/notify/tests
+
+PYTHONPATH=@abs_top_srcdir@/src/lib/python
+export PYTHONPATH
+
+cd ${BIND10_PATH}
+${PYTHON_EXEC} -O ${NOTIFY_OUT_PATH}/notify_out_test.py $*
+

+ 219 - 0
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -0,0 +1,219 @@
+import unittest
+import sys
+import os
+import tempfile
+import time
+import socket
+from isc.datasrc import sqlite3_ds
+import notify_out
+
+class TestZoneNotifyInfo(unittest.TestCase):
+    def setUp(self):
+        self.info = notify_out.ZoneNotifyInfo('cn.', 'IN')
+
+    def test_prepare_finish_notify_out(self):
+        self.info.prepare_notify_out()
+        self.assertNotEqual(self.info._sock, None)
+        self.assertIsNone(self.info._notify_current)
+
+        self.info.finish_notify_out()
+        self.assertEqual(self.info._sock, None)
+
+    def test_set_next_notify_target(self):
+        self.info._notify_slaves.append(('127.0.0.1', 53))
+        self.info._notify_slaves.append(('1.1.1.1', 5353))
+        self.info.prepare_notify_out()
+        self.assertEqual(self.info.get_current_notify_target(), ('127.0.0.1', 53))
+
+        self.assertEqual('127.0.0.1#53', notify_out.addr_to_str(('127.0.0.1', 53)))
+        self.info.set_next_notify_target()
+        self.assertEqual(self.info.get_current_notify_target(), ('1.1.1.1', 5353))
+        self.info.set_next_notify_target()
+        self.assertIsNone(self.info.get_current_notify_target())
+
+        temp_info = notify_out.ZoneNotifyInfo('com.', 'IN')
+        temp_info.prepare_notify_out()
+        self.assertIsNone(temp_info.get_current_notify_target())
+
+
+class TestNotifyOut(unittest.TestCase):
+    def setUp(self):
+        self.old_stdout = sys.stdout
+        sys.stdout = open(os.devnull, 'w')
+        self._db_file = tempfile.NamedTemporaryFile(delete=False)
+        sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
+        sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
+        self._notify = notify_out.NotifyOut(self._db_file.name)
+        self._notify._notify_infos['com.'] = notify_out.ZoneNotifyInfo('com.', 'IN')
+        self._notify._notify_infos['cn.'] = notify_out.ZoneNotifyInfo('cn.', 'IN')
+        self._notify._notify_infos['org.'] = notify_out.ZoneNotifyInfo('org.', 'IN')
+        
+        info = self._notify._notify_infos['cn.']
+        info._notify_slaves.append(('127.0.0.1', 53))
+        info._notify_slaves.append(('1.1.1.1', 5353))
+
+    def tearDown(self):
+        sys.stdout = self.old_stdout
+        self._db_file.close()
+        os.unlink(self._db_file.name)
+
+    def test_send_notify(self):
+        self._notify.send_notify('cn.')
+        self.assertEqual(self._notify.notify_num, 1)
+        self.assertEqual(self._notify._notifying_zones[0], 'cn.')
+
+        self._notify.send_notify('com.')
+        self.assertEqual(self._notify.notify_num, 2)
+        self.assertEqual(self._notify._notifying_zones[1], 'com.')
+    
+        notify_out._MAX_NOTIFY_NUM = 2
+        self._notify.send_notify('org.')
+        self.assertEqual(self._notify._waiting_zones[0], 'org.')
+        self._notify.send_notify('org.')
+        self.assertEqual(1, len(self._notify._waiting_zones))
+
+    def test_wait_for_notify_reply(self):
+        self._notify.send_notify('cn.')
+        self._notify.send_notify('com.')
+    
+        notify_out._MAX_NOTIFY_NUM = 2
+        self._notify.send_notify('org.')
+        replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
+        self.assertEqual(len(replied_zones), 0)
+        self.assertEqual(len(timeout_zones), 2)
+
+        # Now make one socket be readable
+        addr = ('localhost', 12340)
+        self._notify._notify_infos['cn.']._sock.bind(addr)
+        self._notify._notify_infos['cn.'].notify_timeout = time.time() + 10
+        self._notify._notify_infos['com.'].notify_timeout = time.time() + 10
+        
+        send_fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        #Send some data to socket 12340, to make the target socket be readable
+        send_fd.sendto(b'data', addr)
+        replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
+        self.assertEqual(len(replied_zones), 1)
+        self.assertEqual(len(timeout_zones), 1)
+        self.assertTrue('cn.' in replied_zones.keys())
+        self.assertTrue('com.' in timeout_zones.keys())
+        self.assertLess(time.time(), self._notify._notify_infos['com.'].notify_timeout)
+    
+    def test_notify_next_target(self):
+        self._notify.send_notify('cn.')
+        self._notify.send_notify('com.')
+        notify_out._MAX_NOTIFY_NUM = 2
+        self._notify.send_notify('org.')
+
+        info = self._notify._notify_infos['cn.']
+        self._notify._notify_next_target(info)
+        self.assertEqual(0, info.notify_try_num)
+        self.assertEqual(info.get_current_notify_target(), ('1.1.1.1', 5353))
+        self.assertEqual(2, self._notify.notify_num)
+
+        self._notify._notify_next_target(info)
+        self.assertEqual(0, info.notify_try_num)
+        self.assertIsNone(info.get_current_notify_target())
+        self.assertEqual(2, self._notify.notify_num)
+        self.assertEqual(0, len(self._notify._waiting_zones))
+
+        com_info = self._notify._notify_infos['com.']
+        self._notify._notify_next_target(com_info)
+        self.assertEqual(1, self._notify.notify_num)
+        self.assertEqual(0, len(self._notify._notifying_zones))
+    
+    def test_handle_notify_reply(self):
+        self.assertFalse(self._notify._handle_notify_reply(None, b'badmsg'))
+        com_info = self._notify._notify_infos['com.']
+        com_info.notify_msg_id = 0X2f18
+        data = b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x02tw\x02cn\x00\x00\x06\x00\x01'
+        self.assertTrue(self._notify._handle_notify_reply(com_info, data))
+
+    def test_send_notify_message_udp(self):
+        com_info = self._notify._notify_infos['cn.']
+        com_info.prepare_notify_out()
+        ret = self._notify._send_notify_message_udp(com_info, ('1.1.1.1', 53))
+        self.assertTrue(ret)
+
+    def test_zone_notify_handler(self):
+        old_send_msg = self._notify._send_notify_message_udp
+        def _fake_send_notify_message_udp(va1, va2): 
+            pass
+        self._notify._send_notify_message_udp = _fake_send_notify_message_udp
+        self._notify.send_notify('cn.')
+        self._notify.send_notify('com.')
+        notify_out._MAX_NOTIFY_NUM = 2
+        self._notify.send_notify('org.')
+
+        cn_info = self._notify._notify_infos['cn.']
+        cn_info.prepare_notify_out()
+
+        cn_info.notify_try_num = 2
+        self._notify._zone_notify_handler(cn_info, notify_out._EVENT_TIMEOUT)
+        self.assertEqual(3, cn_info.notify_try_num)
+
+        time1 = cn_info.notify_timeout
+        self._notify._zone_notify_handler(cn_info, notify_out._EVENT_TIMEOUT)
+        self.assertEqual(4, cn_info.notify_try_num)
+        self.assertGreater(cn_info.notify_timeout, time1 + 2) # bigger than 2 seconds
+
+        cur_tgt = cn_info._notify_current
+        cn_info.notify_try_num = notify_out._MAX_NOTIFY_TRY_NUM
+        self._notify._zone_notify_handler(cn_info, notify_out._EVENT_NONE)
+        self.assertNotEqual(cur_tgt, cn_info._notify_current)
+
+    def _cn_data_reader(self):
+        zone_data = [
+        ('cn.',         '1000',  'IN',  'SOA', 'a.dns.cn. mail.cn. 1 1 1 1 1'),
+        ('cn.',         '1000',  'IN',  'NS',  'a.dns.cn.'),
+        ('cn.',         '1000',  'IN',  'NS',  'b.dns.cn.'),
+        ('cn.',         '1000',  'IN',  'NS',  'c.dns.cn.'),
+        ('a.dns.cn.',   '1000',  'IN',  'A',    '1.1.1.1'),
+        ('a.dns.cn.',   '1000',  'IN',  'AAAA', '2.2.2.2'),
+        ('b.dns.cn.',   '1000',  'IN',  'A',    '3.3.3.3'),
+        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '4:4.4.4'),
+        ('b.dns.cn.',   '1000',  'IN',  'AAAA', '5:5.5.5'),
+        ('c.dns.cn.',   '1000',  'IN',  'A',    '6.6.6.6'),
+        ('c.dns.cn.',   '1000',  'IN',  'A',    '7.7.7.7'),
+        ('c.dns.cn.',   '1000',  'IN',  'AAAA', '8:8.8.8')]
+        for item in zone_data:
+            yield item
+
+    def _com_data_reader(self):
+        zone_data = [
+        ('com.',         '1000',  'IN',  'SOA', 'a.dns.com. mail.com. 1 1 1 1 1'),
+        ('com.',         '1000',  'IN',  'NS',  'a.dns.com.'),
+        ('com.',         '1000',  'IN',  'NS',  'b.dns.com.'),
+        ('com.',         '1000',  'IN',  'NS',  'c.dns.com.'),
+        ('a.dns.com.',   '1000',  'IN',  'A',    '1.1.1.1'),
+        ('b.dns.com.',   '1000',  'IN',  'A',    '3.3.3.3'),
+        ('b.dns.com.',   '1000',  'IN',  'AAAA', '4:4.4.4'),
+        ('b.dns.com.',   '1000',  'IN',  'AAAA', '5:5.5.5')]
+        for item in zone_data:
+            yield item
+
+    def test_get_notify_slaves_from_ns(self):
+        records = self._notify._get_notify_slaves_from_ns('cn.')
+        self.assertEqual(6, len(records))
+        self.assertEqual('8:8.8.8', records[5])
+        self.assertEqual('7.7.7.7', records[4])
+        self.assertEqual('6.6.6.6', records[3])
+        self.assertEqual('5:5.5.5', records[2])
+        self.assertEqual('4:4.4.4', records[1])
+        self.assertEqual('3.3.3.3', records[0])
+
+        records = self._notify._get_notify_slaves_from_ns('com.')
+        print('=============', records)
+        self.assertEqual(3, len(records))
+        self.assertEqual('5:5.5.5', records[2])
+        self.assertEqual('4:4.4.4', records[1])
+        self.assertEqual('3.3.3.3', records[0])
+    
+    def test_init_notify_out(self):
+        self._notify._init_notify_out(self._db_file.name)
+        self.assertListEqual([('3.3.3.3', 53), ('4:4.4.4', 53), ('5:5.5.5', 53)], 
+                             self._notify._notify_infos['com.']._notify_slaves)
+        
+if __name__== "__main__":
+    unittest.main()
+
+