Browse Source

[trac419] update request handling logic

chenzhengzhang 14 years ago
parent
commit
40f74edaaf

+ 9 - 9
src/bin/xfrout/tests/xfrout_test.py

@@ -88,20 +88,11 @@ class TestXfroutSession(unittest.TestCase):
         request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
         self.xfrsess = MyXfroutSession(request, None, None, self.log)
-        self.write_sock, self.read_sock = socket.socketpair()
         self.xfrsess.server = Dbserver()
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
-    def test_receive_query_message(self):
-        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
-        msg_len = struct.pack('H', socket.htons(len(send_msg)))
-        self.write_sock.send(msg_len)
-        self.write_sock.send(send_msg)
-        recv_msg = self.xfrsess._receive_query_message(self.read_sock)
-        self.assertEqual(recv_msg, send_msg)
-
     def test_parse_query_message(self):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -321,8 +312,17 @@ class MyUnixSockServer(UnixSockServer):
 
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
+        self.write_sock, self.read_sock = socket.socketpair()
         self.unix = MyUnixSockServer()
 
+    def test_receive_query_message(self):
+        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
+        msg_len = struct.pack('H', socket.htons(len(send_msg)))
+        self.write_sock.send(msg_len)
+        self.write_sock.send(send_msg)
+        recv_msg = self.unix._receive_query_message(self.read_sock)
+        self.assertEqual(recv_msg, send_msg)
+
     def test_updata_config_data(self):
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)

+ 61 - 60
src/bin/xfrout/xfrout.py.in

@@ -73,57 +73,25 @@ def get_rrset_len(rrset):
     return len(bytes)
 
 
-class XfroutSession(BaseRequestHandler):
-    def __init__(self, request, client_address, server, log):
+class XfroutSession():
+    def __init__(self, sock_fd, request_data, server, log):
         # The initializer for the superclass may call functions
         # that need _log to be set, so we set it first
+        self._sock_fd = sock_fd
+        self._request_data = request_data
+        self._server = server
         self._log = log
-        BaseRequestHandler.__init__(self, request, client_address, server)
+        self.handle()
 
     def handle(self):
-        ''' Handle a xfrout query. First, xfrout server receive
-        socket fd and query message from auth. Then, send xfrout
-        response via the socket fd.'''
-        sock_fd = recv_fd(self.request.fileno())
-        if sock_fd < 0:
-            # This may happen when one xfrout process try to connect to
-            # xfrout unix socket server, to check whether there is another
-            # xfrout running.
-            if sock_fd == XFR_FD_RECEIVE_FAIL:
-                self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
-            return
-
-        # receive query msg
-        msgdata = self._receive_query_message(self.request)
-        if not msgdata:
-            return
-
+        ''' Handle a xfrout query, send xfrout response '''
         try:
-            self.dns_xfrout_start(sock_fd, msgdata)
+            self.dns_xfrout_start(self._sock_fd, self._request_data)
             #TODO, avoid catching all exceptions
         except Exception as e:
             self._log.log_message("error", str(e))
 
-        os.close(sock_fd)
-
-    def _receive_query_message(self, sock):
-        ''' receive query message from sock'''
-        # receive data length
-        data_len = sock.recv(2)
-        if not data_len:
-            return None
-        msg_len = struct.unpack('!H', data_len)[0]
-        # receive data
-        recv_size = 0
-        msgdata = b''
-        while recv_size < msg_len:
-            data = sock.recv(msg_len - recv_size)
-            if not data:
-                return None
-            recv_size += len(data)
-            msgdata += data
-
-        return msgdata
+        os.close(self._sock_fd)
 
     def _parse_query_message(self, mdata):
         ''' parse query message to [socket,message]'''
@@ -176,7 +144,7 @@ class XfroutSession(BaseRequestHandler):
 
 
     def _zone_is_empty(self, zone):
-        if sqlite3_ds.get_zone_soa(zone, self.server.get_db_file()):
+        if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
             return False
 
         return True
@@ -184,7 +152,7 @@ class XfroutSession(BaseRequestHandler):
     def _zone_exist(self, zonename):
         # Find zone in datasource, should this works? maybe should ask
         # config manager.
-        soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
+        soa = sqlite3_ds.get_zone_soa(zonename, self._server.get_db_file())
         if soa:
             return True
         return False
@@ -202,7 +170,7 @@ class XfroutSession(BaseRequestHandler):
             return Rcode.SERVFAIL()
 
         #TODO, check allow_transfer
-        if not self.server.increase_transfers_counter():
+        if not self._server.increase_transfers_counter():
             return Rcode.REFUSED()
 
         return Rcode.NOERROR()
@@ -228,7 +196,7 @@ class XfroutSession(BaseRequestHandler):
         except Exception as err:
             self._log.log_message("error", str(err))
 
-        self.server.decrease_transfers_counter()
+        self._server.decrease_transfers_counter()
         return
 
 
@@ -275,14 +243,14 @@ class XfroutSession(BaseRequestHandler):
         #TODO, there should be a better way to insert rrset.
         msg.make_response()
         msg.set_header_flag(Message.HEADERFLAG_AA)
-        soa_record = sqlite3_ds.get_zone_soa(zone_name, self.server.get_db_file())
+        soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
         rrset_soa = self._create_rrset_from_db_record(soa_record)
         msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
 
         message_upper_len = get_rrset_len(rrset_soa)
 
-        for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
-            if  self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
+        for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
+            if  self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
                 self._log.log_message("info", "xfrout process is being shutdown")
                 return
 
@@ -324,7 +292,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self.update_config_data(config_data)
         self._cc = cc
 
-    def _handle_request_noblock(self):
+    def handle_request(self):
         '''Rewrite _handle_request_noblock() from parent class ThreadingUnixStreamServer,
         enable server handle a request until shutdown or xfrout client is closed.'''
         try:
@@ -359,19 +327,52 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
                     self.close_request(request)
                     break
 
-    def process_request_thread(self, request, client_address):
-        ''' Rewrite process_request_thread() from parent class ThreadingUnixStreamServer,
-        server won't close the connection after handling a xfrout query, the connection
-        should be kept for handling upcoming xfrout queries.'''
-        try:
-            self.finish_request(request, client_address)
-        except Exception as e:
-            self.handle_error(request, client_address)
-            self.close_request(request)
+    def _receive_query_message(self, sock):
+        ''' receive request message from sock'''
+        # receive data length
+        data_len = sock.recv(2)
+        if not data_len:
+            return None
+        msg_len = struct.unpack('!H', data_len)[0]
+        # receive data
+        recv_size = 0
+        msgdata = b''
+        while recv_size < msg_len:
+            data = sock.recv(msg_len - recv_size)
+            if not data:
+                return None
+            recv_size += len(data)
+            msgdata += data
+
+        return msgdata
+
+    def process_request(self, request, client_address):
+        """Receive socket fd and query message from auth, then
+        start a new thread to process the request."""
+        sock_fd = recv_fd(request.fileno())
+        if sock_fd < 0:
+            # This may happen when one xfrout process try to connect to
+            # xfrout unix socket server, to check whether there is another
+            # xfrout running.
+            if sock_fd == XFR_FD_RECEIVE_FAIL:
+                self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
+            return
+
+        # receive request msg
+        request_data = self._receive_query_message(request)
+        if not request_data:
+            return
+
+        t = threading.Thread(target = self.finish_request,
+                             args = (sock_fd, request_data, client_address))
+        if self.daemon_threads:
+            t.daemon = True
+        t.start()
+
 
-    def finish_request(self, request, client_address):
+    def finish_request(self, sock_fd, request_data, client_address):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(request, client_address, self, self._log)
+        self.RequestHandlerClass(sock_fd, request_data, self, self._log)
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used

+ 1 - 1
src/lib/python/isc/util/socketserver_mixin.py

@@ -79,7 +79,7 @@ class NoPollMixIn:
                 break
             else:
                 # Create a new thread to handle requests for each auth
-                threading.Thread(target=self._handle_request_noblock).start()
+                threading.Thread(target=self.handle_request).start()
 
         self._is_shut_down.set()