Browse Source

1. Add unittest and docstring to some functions in xfrin.
2. Minor fix to the xfrin.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1729 e5f2f494-b856-4b98-b285-d166d9295462

Likun Zhang 15 years ago
parent
commit
2b9742f844
2 changed files with 90 additions and 34 deletions
  1. 57 2
      src/bin/xfrin/tests/xfrin_test.py
  2. 33 32
      src/bin/xfrin/xfrin.py.in

File diff suppressed because it is too large
+ 57 - 2
src/bin/xfrin/tests/xfrin_test.py


+ 33 - 32
src/bin/xfrin/xfrin.py.in

@@ -50,15 +50,6 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"
 __version__ = 'BIND10'
 __version__ = 'BIND10'
 # define xfrin rcode
 # define xfrin rcode
 XFRIN_OK = 0
 XFRIN_OK = 0
-XFRIN_RECV_TIMEOUT = 1
-XFRIN_NO_NEWDATA = 2
-XFRIN_QUOTA_ERROR = 3
-XFRIN_IS_DOING = 4
-
-# define xfrin state
-XFRIN_QUERY_SOA = 1
-XFRIN_FIRST_AXFR = 2
-XFRIN_FIRST_IXFR = 3
 
 
 def log_error(msg):
 def log_error(msg):
     sys.stderr.write("[b10-xfrin] ")
     sys.stderr.write("[b10-xfrin] ")
@@ -68,21 +59,17 @@ def log_error(msg):
 class XfrinException(Exception): 
 class XfrinException(Exception): 
     pass
     pass
 
 
-
 class XfrinConnection(asyncore.dispatcher):
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''    
     '''Do xfrin in this class. '''    
 
 
-    def __init__(self, zone_name, db_file, 
+    def __init__(self, 
-                 shutdown_event,
+                 zone_name, db_file, shutdown_event, master_addr, 
-                 master_addr, 
+                 port = 53, verbose = False, idle_timeout = 60): 
-                 port = 53, 
-                 check_soa = True, 
-                 verbose = False,
-                 idle_timeout = 60): 
         ''' idle_timeout: max idle time for read data from socket.
         ''' idle_timeout: max idle time for read data from socket.
             db_file: specify the data source file.
             db_file: specify the data source file.
             check_soa: when it's true, check soa first before sending xfr query
             check_soa: when it's true, check soa first before sending xfr query
         '''
         '''
+
         asyncore.dispatcher.__init__(self)
         asyncore.dispatcher.__init__(self)
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self._zone_name = zone_name
         self._zone_name = zone_name
@@ -92,18 +79,22 @@ class XfrinConnection(asyncore.dispatcher):
         self.setblocking(1)
         self.setblocking(1)
         self._shutdown_event = shutdown_event
         self._shutdown_event = shutdown_event
         self._verbose = verbose
         self._verbose = verbose
+        self._master_addr = master_addr
+        self._port = port
 
 
-    def connect_to_master(self, master_addr, port):
+    def connect_to_master(self):
         '''Connect to master in TCP.'''
         '''Connect to master in TCP.'''
+
         try:
         try:
-            self.connect((master_addr, port))
+            self.connect((self._master_addr, self._port))
             return True
             return True
         except socket.error as e:
         except socket.error as e:
-            self.log_msg('Failed to connect:(%s:%d), %s' % (master_addr, port, str(e)))
+            self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr, self._port, str(e)))
             return False
             return False
 
 
     def _create_query(self, query_type):
     def _create_query(self, query_type):
         '''Create dns query message. '''
         '''Create dns query message. '''
+
         msg = message(message_mode.RENDER)
         msg = message(message_mode.RENDER)
         query_id = random.randint(1, 0xFFFF)
         query_id = random.randint(1, 0xFFFF)
         self._query_id = query_id
         self._query_id = query_id
@@ -123,6 +114,7 @@ class XfrinConnection(asyncore.dispatcher):
 
 
     def _send_query(self, query_type):
     def _send_query(self, query_type):
         '''Send query message over TCP. '''
         '''Send query message over TCP. '''
+
         msg = self._create_query(query_type)
         msg = self._create_query(query_type)
         obuf = output_buffer(0)
         obuf = output_buffer(0)
         render = message_render(obuf)
         render = message_render(obuf)
@@ -147,26 +139,23 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         return data
         return data
 
 
-    def handle_read(self):
-        '''Read query's response from socket. '''
-        self._recvd_data = self.recv(self._need_recv_size)
-        self._recvd_size = len(self._recvd_data)
-        self._recv_time_out = False
-
     def _check_soa_serial(self):
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
         the soa serial in local, Finish xfrin.
         False: soa serial in master is less or equal to the local one.
         False: soa serial in master is less or equal to the local one.
         True: soa serial in master is bigger
         True: soa serial in master is bigger
         '''
         '''
+
         self._send_query(rr_type.SOA())
         self._send_query(rr_type.SOA())
         data_size = self._get_request_response(2)
         data_size = self._get_request_response(2)
         soa_reply = self._get_request_response(int(data_size))
         soa_reply = self._get_request_response(int(data_size))
         #TODO, need select soa record from data source then compare the two 
         #TODO, need select soa record from data source then compare the two 
-        #serial 
+        #serial, current just return OK, since this function hasn't been used now 
         return XFRIN_OK
         return XFRIN_OK
 
 
     def do_xfrin(self, check_soa, ixfr_first = False):
     def do_xfrin(self, check_soa, ixfr_first = False):
+        '''Do xfr by sending xfr request and parsing response. '''
+
         try:
         try:
             ret = XFRIN_OK
             ret = XFRIN_OK
             if check_soa:
             if check_soa:
@@ -194,6 +183,8 @@ class XfrinConnection(asyncore.dispatcher):
         return ret
         return ret
     
     
     def _check_response_status(self, msg):
     def _check_response_status(self, msg):
+        '''Check validation of xfr response. '''
+
         #TODO, check more?
         #TODO, check more?
         msg_rcode = msg.get_rcode()
         msg_rcode = msg.get_rcode()
         if msg_rcode != rcode.NOERROR():
         if msg_rcode != rcode.NOERROR():
@@ -212,6 +203,8 @@ class XfrinConnection(asyncore.dispatcher):
             raise XfrinException('query section count greater than 1')
             raise XfrinException('query section count greater than 1')
 
 
     def _handle_answer_section(self, rrset_iter):
     def _handle_answer_section(self, rrset_iter):
+        '''Return a generator for the reponse in one tcp package to a zone transfer.'''
+
         while not rrset_iter.is_last():
         while not rrset_iter.is_last():
             rrset = rrset_iter.get_rrset()
             rrset = rrset_iter.get_rrset()
             rrset_iter.next()
             rrset_iter.next()
@@ -242,6 +235,8 @@ class XfrinConnection(asyncore.dispatcher):
                 rdata_iter.next()
                 rdata_iter.next()
 
 
     def _handle_xfrin_response(self):
     def _handle_xfrin_response(self):
+        '''Return a generator for the response to a zone transfer. '''
+
         while True:
         while True:
             data_len = self._get_request_response(2)
             data_len = self._get_request_response(2)
             msg_len = socket.htons(struct.unpack('H', data_len)[0])
             msg_len = socket.htons(struct.unpack('H', data_len)[0])
@@ -258,12 +253,18 @@ class XfrinConnection(asyncore.dispatcher):
                 break
                 break
             
             
             if self._shutdown_event.is_set():
             if self._shutdown_event.is_set():
-                #Check if xfrin process is shutdown.
-                #TODO, xfrin may be blocked in one loop. 
                 raise XfrinException('xfrin is forced to stop')
                 raise XfrinException('xfrin is forced to stop')
 
 
+    def handle_read(self):
+        '''Read query's response from socket. '''
+
+        self._recvd_data = self.recv(self._need_recv_size)
+        self._recvd_size = len(self._recvd_data)
+        self._recv_time_out = False
+
     def writable(self):
     def writable(self):
         '''Ignore the writable socket. '''
         '''Ignore the writable socket. '''
+
         return False
         return False
 
 
     def log_info(self, msg, type='info'):
     def log_info(self, msg, type='info'):
@@ -282,9 +283,9 @@ def process_xfrin(xfrin_recorder, zone_name, db_file,
     port = int(port)
     port = int(port)
     xfrin_recorder.increment(zone_name)
     xfrin_recorder.increment(zone_name)
     conn = XfrinConnection(zone_name, db_file, shutdown_event, 
     conn = XfrinConnection(zone_name, db_file, shutdown_event, 
-                           master_addr, port, check_soa, verbose)
+                           master_addr, port, verbose)
-    if conn.connect_to_master(master_addr, port):
+    if conn.connect_to_master():
-        conn.do_xfrin(False)
+        conn.do_xfrin(check_soa)
 
 
     xfrin_recorder.decrement(zone_name)
     xfrin_recorder.decrement(zone_name)