Browse Source

[trac815] update XFRIN to use TSIG

chenzhengzhang 14 years ago
parent
commit
7efc144e0e
3 changed files with 101 additions and 39 deletions
  1. 43 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 53 39
      src/bin/xfrin/xfrin.py.in
  3. 5 0
      src/bin/xfrin/xfrin.spec

+ 43 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -51,6 +51,13 @@ default_answers = [soa_rrset]
 class XfrinTestException(Exception):
 class XfrinTestException(Exception):
     pass
     pass
 
 
+def strip_mutable_tsig_data(data):
+    # Unfortunately we cannot easily compare TSIG RR because we can't tweak
+    # current time.  As a work around this helper function strips off the time
+    # dependent part of TSIG RDATA, i.e., the MAC (assuming HMAC-MD5) and
+    # Time Signed.
+    return data[0:-32] + data[-26:-22] + data[-6:]
+
 class MockXfrin(Xfrin):
 class MockXfrin(Xfrin):
     # This is a class attribute of a callable object that specifies a non
     # This is a class attribute of a callable object that specifies a non
     # default behavior triggered in _cc_check_command().  Specific test methods
     # default behavior triggered in _cc_check_command().  Specific test methods
@@ -60,6 +67,7 @@ class MockXfrin(Xfrin):
     check_command_hook = None
     check_command_hook = None
 
 
     def _cc_setup(self):
     def _cc_setup(self):
+        self._tsig_key_str = None
         pass
         pass
 
 
     def _get_db_file(self):
     def _get_db_file(self):
@@ -196,6 +204,36 @@ class TestXfrinConnection(unittest.TestCase):
                          RRClass.CH())
                          RRClass.CH())
         c.close()
         c.close()
 
 
+    def test_send_query(self):
+        def create_msg(query_type):
+            msg = Message(Message.RENDER)
+            query_id = 0x1035
+            msg.set_qid(query_id)
+            msg.set_opcode(Opcode.QUERY())
+            msg.set_rcode(Rcode.NOERROR())
+            query_question = Question(Name("example.com."), RRClass.IN(), query_type)
+            msg.add_question(query_question)
+            return msg
+        self.conn._create_query = create_msg
+        # soa request
+        self.conn._send_query(RRType.SOA())
+        self.assertEqual(self.conn.query_data, b'\x00\x1d\x105\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\x06\x00\x01')
+        # axfr request
+        self.conn._send_query(RRType.AXFR())
+        self.assertEqual(self.conn.query_data, b'\x00\x1d\x105\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
+
+        # soa request with tsig
+        tsig_key = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
+        self.conn._tsig_ctx = TSIGContext(tsig_key)
+        self.conn._send_query(RRType.SOA())
+        tsig_soa_data = strip_mutable_tsig_data(self.conn.query_data)
+        self.assertEqual(tsig_soa_data, b'\x00n\x105\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x07example\x03com\x00\x00\x06\x00\x01\x07example\x03com\x00\x00\xfa\x00\xff\x00\x00\x00\x00\x00:\x08hmac-md5\x07sig-alg\x03reg\x03int\x00\x01,\x00\x10\x105\x00\x00\x00\x00')
+
+        # axfr request with tsig
+        self.conn._send_query(RRType.AXFR())
+        tsig_axfr_data = strip_mutable_tsig_data(self.conn.query_data)
+        self.assertEqual(tsig_axfr_data, b'\x00n\x105\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x07example\x03com\x00\x00\xfc\x00\x01\x07example\x03com\x00\x00\xfa\x00\xff\x00\x00\x00\x00\x00:\x08hmac-md5\x07sig-alg\x03reg\x03int\x00\x01,\x00\x10\x105\x00\x00\x00\x00')
+
     def test_response_with_invalid_msg(self):
     def test_response_with_invalid_msg(self):
         self.conn.reply_data = b'aaaxxxx'
         self.conn.reply_data = b'aaaxxxx'
         self.assertRaises(XfrinTestException, self._handle_xfrin_response)
         self.assertRaises(XfrinTestException, self._handle_xfrin_response)
@@ -399,15 +437,20 @@ class TestXfrinRecorder(unittest.TestCase):
 
 
 class TestXfrin(unittest.TestCase):
 class TestXfrin(unittest.TestCase):
     def setUp(self):
     def setUp(self):
+        # redirect output
+        self.stderr_backup = sys.stderr
+        sys.stderr = open(os.devnull, 'w')
         self.xfr = MockXfrin()
         self.xfr = MockXfrin()
         self.args = {}
         self.args = {}
         self.args['zone_name'] = TEST_ZONE_NAME
         self.args['zone_name'] = TEST_ZONE_NAME
         self.args['port'] = TEST_MASTER_PORT
         self.args['port'] = TEST_MASTER_PORT
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['db_file'] = TEST_DB_FILE
         self.args['db_file'] = TEST_DB_FILE
+        self.args['tsig_key'] = ''
 
 
     def tearDown(self):
     def tearDown(self):
         self.xfr.shutdown()
         self.xfr.shutdown()
+        sys.stderr= self.stderr_backup
 
 
     def _do_parse_zone_name_class(self):
     def _do_parse_zone_name_class(self):
         return self.xfr._parse_zone_name_and_class(self.args)
         return self.xfr._parse_zone_name_and_class(self.args)

+ 53 - 39
src/bin/xfrin/xfrin.py.in

@@ -67,15 +67,16 @@ DEFAULT_MASTER = '127.0.0.1'
 def log_error(msg):
 def log_error(msg):
     sys.stderr.write("[b10-xfrin] %s\n" % str(msg))
     sys.stderr.write("[b10-xfrin] %s\n" % str(msg))
 
 
-class XfrinException(Exception): 
+class XfrinException(Exception):
     pass
     pass
 
 
 class XfrinConnection(asyncore.dispatcher):
 class XfrinConnection(asyncore.dispatcher):
-    '''Do xfrin in this class. '''    
+    '''Do xfrin in this class. '''
 
 
     def __init__(self,
     def __init__(self,
                  sock_map, zone_name, rrclass, db_file, shutdown_event,
                  sock_map, zone_name, rrclass, db_file, shutdown_event,
-                 master_addrinfo, verbose = False, idle_timeout = 60): 
+                 master_addrinfo, tsig_key_str = None, verbose = False,
+                 idle_timeout = 60):
         ''' idle_timeout: max idle time for read data from socket.
         ''' idle_timeout: max idle time for read data from socket.
             db_file: specify the data source file.
             db_file: specify the data source file.
             check_soa: when it's true, check soa first before sending xfr query
             check_soa: when it's true, check soa first before sending xfr query
@@ -93,6 +94,9 @@ class XfrinConnection(asyncore.dispatcher):
         self._shutdown_event = shutdown_event
         self._shutdown_event = shutdown_event
         self._verbose = verbose
         self._verbose = verbose
         self._master_address = master_addrinfo[2]
         self._master_address = master_addrinfo[2]
+        self._tsig_ctx = None
+        if tsig_key_str:
+            self._tsig_ctx = TSIGContext(TSIGKey(tsig_key_str))
 
 
     def connect_to_master(self):
     def connect_to_master(self):
         '''Connect to master in TCP.'''
         '''Connect to master in TCP.'''
@@ -130,9 +134,12 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         msg = self._create_query(query_type)
         msg = self._create_query(query_type)
         render = MessageRenderer()
         render = MessageRenderer()
-        msg.to_wire(render)
+        if self._tsig_ctx:
-        header_len = struct.pack('H', socket.htons(render.get_length()))
+            msg.to_wire(render, self._tsig_ctx)
+        else:
+            msg.to_wire(render)
 
 
+        header_len = struct.pack('H', socket.htons(render.get_length()))
         self._send_data(header_len)
         self._send_data(header_len)
         self._send_data(render.get_data())
         self._send_data(render.get_data())
 
 
@@ -142,7 +149,7 @@ class XfrinConnection(asyncore.dispatcher):
         _get_request_response so that we can test the rest of the code without
         _get_request_response so that we can test the rest of the code without
         involving actual communication with a remote server.'''
         involving actual communication with a remote server.'''
         asyncore.loop(self._idle_timeout, map=self._sock_map, count=1)
         asyncore.loop(self._idle_timeout, map=self._sock_map, count=1)
-    
+
     def _get_request_response(self, size):
     def _get_request_response(self, size):
         recv_size = 0
         recv_size = 0
         data = b''
         data = b''
@@ -176,7 +183,7 @@ class XfrinConnection(asyncore.dispatcher):
         # strict we should be (see the comment in _check_response_header())
         # strict we should be (see the comment in _check_response_header())
         self._check_response_header(msg)
         self._check_response_header(msg)
 
 
-        # TODO, need select soa record from data source then compare the two 
+        # TODO, need select soa record from data source then compare the two
         # serial, current just return OK, since this function hasn't been used
         # serial, current just return OK, since this function hasn't been used
         # now.
         # now.
         return XFRIN_OK
         return XFRIN_OK
@@ -290,14 +297,14 @@ class XfrinConnection(asyncore.dispatcher):
             msg = Message(Message.PARSE)
             msg = Message(Message.PARSE)
             msg.from_wire(recvdata)
             msg.from_wire(recvdata)
             self._check_response_status(msg)
             self._check_response_status(msg)
-            
+
             answer_section = msg.get_section(Message.SECTION_ANSWER)
             answer_section = msg.get_section(Message.SECTION_ANSWER)
             for rr in self._handle_answer_section(answer_section):
             for rr in self._handle_answer_section(answer_section):
                 yield rr
                 yield rr
 
 
             if self._soa_rr_count == 2:
             if self._soa_rr_count == 2:
                 break
                 break
-            
+
             if self._shutdown_event.is_set():
             if self._shutdown_event.is_set():
                 raise XfrinException('xfrin is forced to stop')
                 raise XfrinException('xfrin is forced to stop')
 
 
@@ -322,16 +329,18 @@ class XfrinConnection(asyncore.dispatcher):
             sys.stdout.write('[b10-xfrin] %s\n' % str(msg))
             sys.stdout.write('[b10-xfrin] %s\n' % str(msg))
 
 
 
 
-def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file, 
+def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
-                  shutdown_event, master_addrinfo, check_soa, verbose):
+                  shutdown_event, master_addrinfo, check_soa, verbose,
+                  tsig_key_str):
     xfrin_recorder.increment(zone_name)
     xfrin_recorder.increment(zone_name)
     sock_map = {}
     sock_map = {}
     conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
     conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
-                           shutdown_event, master_addrinfo, verbose)
+                           shutdown_event, master_addrinfo,
+                           tsig_key_str, verbose)
     ret = XFRIN_FAIL
     ret = XFRIN_FAIL
     if conn.connect_to_master():
     if conn.connect_to_master():
         ret = conn.do_xfrin(check_soa)
         ret = conn.do_xfrin(check_soa)
-    
+
     # Publish the zone transfer result news, so zonemgr can reset the
     # Publish the zone transfer result news, so zonemgr can reset the
     # zone timer, and xfrout can notify the zone's slaves if the result
     # zone timer, and xfrout can notify the zone's slaves if the result
     # is success.
     # is success.
@@ -379,11 +388,11 @@ class Xfrin:
         self._verbose = verbose
         self._verbose = verbose
 
 
     def _cc_setup(self):
     def _cc_setup(self):
-        '''This method is used only as part of initialization, but is 
+        '''This method is used only as part of initialization, but is
-        implemented separately for convenience of unit tests; by letting 
+        implemented separately for convenience of unit tests; by letting
-        the test code override this method we can test most of this class 
+        the test code override this method we can test most of this class
         without requiring a command channel.'''
         without requiring a command channel.'''
-        # Create one session for sending command to other modules, because the 
+        # Create one session for sending command to other modules, because the
         # listening session will block the send operation.
         # listening session will block the send operation.
         self._send_cc_session = isc.cc.Session()
         self._send_cc_session = isc.cc.Session()
         self._module_cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
         self._module_cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
@@ -394,15 +403,17 @@ class Xfrin:
         self._max_transfers_in = config_data.get("transfers_in")
         self._max_transfers_in = config_data.get("transfers_in")
         self._master_addr = config_data.get('master_addr') or self._master_addr
         self._master_addr = config_data.get('master_addr') or self._master_addr
         self._master_port = config_data.get('master_port') or self._master_port
         self._master_port = config_data.get('master_port') or self._master_port
+        self._tsig_key_str = config_data.get('tsig_key') or None
 
 
     def _cc_check_command(self):
     def _cc_check_command(self):
-        '''This is a straightforward wrapper for cc.check_command, 
+        '''This is a straightforward wrapper for cc.check_command,
-        but provided as a separate method for the convenience 
+        but provided as a separate method for the convenience
         of unit tests.'''
         of unit tests.'''
         self._module_cc.check_command(False)
         self._module_cc.check_command(False)
 
 
     def config_handler(self, new_config):
     def config_handler(self, new_config):
         self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
         self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
+        self._tsig_key_str = new_config.get('tsig_key') or None
         if ('master_addr' in new_config) or ('master_port' in new_config):
         if ('master_addr' in new_config) or ('master_port' in new_config):
             # User should change the port and address together.
             # User should change the port and address together.
             try:
             try:
@@ -420,7 +431,7 @@ class Xfrin:
         return create_answer(0)
         return create_answer(0)
 
 
     def shutdown(self):
     def shutdown(self):
-        ''' shutdown the xfrin process. the thread which is doing xfrin should be 
+        ''' shutdown the xfrin process. the thread which is doing xfrin should be
         terminated.
         terminated.
         '''
         '''
         self._shutdown_event.set()
         self._shutdown_event.set()
@@ -436,30 +447,32 @@ class Xfrin:
             if command == 'shutdown':
             if command == 'shutdown':
                 self._shutdown_event.set()
                 self._shutdown_event.set()
             elif command == 'notify' or command == REFRESH_FROM_ZONEMGR:
             elif command == 'notify' or command == REFRESH_FROM_ZONEMGR:
-                # Xfrin receives the refresh/notify command from zone manager. 
+                # Xfrin receives the refresh/notify command from zone manager.
-                # notify command maybe has the parameters which 
+                # notify command maybe has the parameters which
                 # specify the notifyfrom address and port, according the RFC1996, zone
                 # specify the notifyfrom address and port, according the RFC1996, zone
                 # transfer should starts first from the notifyfrom, but now, let 'TODO' it.
                 # transfer should starts first from the notifyfrom, but now, let 'TODO' it.
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 (master_addr) = build_addr_info(self._master_addr, self._master_port)
                 (master_addr) = build_addr_info(self._master_addr, self._master_port)
-                ret = self.xfrin_start(zone_name, 
+                ret = self.xfrin_start(zone_name,
-                                       rrclass, 
+                                       rrclass,
                                        self._get_db_file(),
                                        self._get_db_file(),
                                        master_addr,
                                        master_addr,
+                                       self._tsig_key_str,
                                        True)
                                        True)
                 answer = create_answer(ret[0], ret[1])
                 answer = create_answer(ret[0], ret[1])
 
 
             elif command == 'retransfer' or command == 'refresh':
             elif command == 'retransfer' or command == 'refresh':
                 # Xfrin receives the retransfer/refresh from cmdctl(sent by bindctl).
                 # Xfrin receives the retransfer/refresh from cmdctl(sent by bindctl).
-                # If the command has specified master address, do transfer from the 
+                # If the command has specified master address, do transfer from the
-                # master address, or else do transfer from the configured masters.                
+                # master address, or else do transfer from the configured masters.
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 master_addr = self._parse_master_and_port(args)
                 master_addr = self._parse_master_and_port(args)
                 db_file = args.get('db_file') or self._get_db_file()
                 db_file = args.get('db_file') or self._get_db_file()
-                ret = self.xfrin_start(zone_name, 
+                ret = self.xfrin_start(zone_name,
-                                       rrclass, 
+                                       rrclass,
-                                       db_file, 
+                                       db_file,
                                        master_addr,
                                        master_addr,
+                                       self._tsig_key_str,
                                        (False if command == 'retransfer' else True))
                                        (False if command == 'retransfer' else True))
                 answer = create_answer(ret[0], ret[1])
                 answer = create_answer(ret[0], ret[1])
 
 
@@ -483,14 +496,14 @@ class Xfrin:
                 rrclass = RRClass(rrclass)
                 rrclass = RRClass(rrclass)
             except InvalidRRClass as e:
             except InvalidRRClass as e:
                 raise XfrinException('invalid RRClass: ' + rrclass)
                 raise XfrinException('invalid RRClass: ' + rrclass)
-        
+
         return zone_name, rrclass
         return zone_name, rrclass
 
 
     def _parse_master_and_port(self, args):
     def _parse_master_and_port(self, args):
         port = args.get('port') or self._master_port
         port = args.get('port') or self._master_port
         master = args.get('master') or self._master_addr
         master = args.get('master') or self._master_addr
         return build_addr_info(master, port)
         return build_addr_info(master, port)
- 
+
     def _get_db_file(self):
     def _get_db_file(self):
         #TODO, the db file path should be got in auth server's configuration
         #TODO, the db file path should be got in auth server's configuration
         # if we need access to this configuration more often, we
         # if we need access to this configuration more often, we
@@ -506,12 +519,12 @@ class Xfrin:
             db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
             db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
         self._module_cc.remove_remote_config(AUTH_SPECFILE_LOCATION)
         self._module_cc.remove_remote_config(AUTH_SPECFILE_LOCATION)
         return db_file
         return db_file
-       
+
     def publish_xfrin_news(self, zone_name, zone_class,  xfr_result):
     def publish_xfrin_news(self, zone_name, zone_class,  xfr_result):
         '''Send command to xfrout/zone manager module.
         '''Send command to xfrout/zone manager module.
-        If xfrin has finished successfully for one zone, tell the good 
+        If xfrin has finished successfully for one zone, tell the good
         news(command: zone_new_data_ready) to zone manager and xfrout.
         news(command: zone_new_data_ready) to zone manager and xfrout.
-        if xfrin failed, just tell the bad news to zone manager, so that 
+        if xfrin failed, just tell the bad news to zone manager, so that
         it can reset the refresh timer for that zone. '''
         it can reset the refresh timer for that zone. '''
         param = {'zone_name': zone_name, 'zone_class': zone_class.to_text()}
         param = {'zone_name': zone_name, 'zone_class': zone_class.to_text()}
         if xfr_result == XFRIN_OK:
         if xfr_result == XFRIN_OK:
@@ -531,8 +544,8 @@ class Xfrin:
                                                                       seq)
                                                                       seq)
                 except isc.cc.session.SessionTimeout:
                 except isc.cc.session.SessionTimeout:
                     pass        # for now we just ignore the failure
                     pass        # for now we just ignore the failure
-            except socket.error as err: 
+            except socket.error as err:
-                log_error("Fail to send message to %s and %s, msgq may has been killed" 
+                log_error("Fail to send message to %s and %s, msgq may has been killed"
                           % (XFROUT_MODULE_NAME, ZONE_MANAGER_MODULE_NAME))
                           % (XFROUT_MODULE_NAME, ZONE_MANAGER_MODULE_NAME))
         else:
         else:
             msg = create_command(ZONE_XFRIN_FAILED, param)
             msg = create_command(ZONE_XFRIN_FAILED, param)
@@ -545,14 +558,14 @@ class Xfrin:
                 except isc.cc.session.SessionTimeout:
                 except isc.cc.session.SessionTimeout:
                     pass        # for now we just ignore the failure
                     pass        # for now we just ignore the failure
             except socket.error as err:
             except socket.error as err:
-                log_error("Fail to send message to %s, msgq may has been killed" 
+                log_error("Fail to send message to %s, msgq may has been killed"
                           % ZONE_MANAGER_MODULE_NAME)
                           % ZONE_MANAGER_MODULE_NAME)
 
 
     def startup(self):
     def startup(self):
         while not self._shutdown_event.is_set():
         while not self._shutdown_event.is_set():
             self._cc_check_command()
             self._cc_check_command()
 
 
-    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
+    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo, tsig_key_str,
                     check_soa = True):
                     check_soa = True):
         if "pydnspp" not in sys.modules:
         if "pydnspp" not in sys.modules:
             return (1, "xfrin failed, can't load dns message python library: 'pydnspp'")
             return (1, "xfrin failed, can't load dns message python library: 'pydnspp'")
@@ -571,7 +584,8 @@ class Xfrin:
                                                 db_file,
                                                 db_file,
                                                 self._shutdown_event,
                                                 self._shutdown_event,
                                                 master_addrinfo, check_soa,
                                                 master_addrinfo, check_soa,
-                                                self._verbose))
+                                                self._verbose,
+                                                tsig_key_str))
 
 
         xfrin_thread.start()
         xfrin_thread.start()
         return (0, 'zone xfrin is started')
         return (0, 'zone xfrin is started')

+ 5 - 0
src/bin/xfrin/xfrin.spec

@@ -19,6 +19,11 @@
         "item_type": "integer",
         "item_type": "integer",
         "item_optional": false,
         "item_optional": false,
         "item_default": 53
         "item_default": 53
+      },
+      { "item_name": "tsig_key",
+        "item_type": "string",
+        "item_optional": true,
+        "item_default": ""
       }
       }
     ],
     ],
     "commands": [
     "commands": [