|
@@ -73,57 +73,25 @@ def get_rrset_len(rrset):
|
|
|
return len(bytes)
|
|
|
|
|
|
|
|
|
-class XfroutSession(BaseRequestHandler):
|
|
|
- def __init__(self, request, client_address, server, log):
|
|
|
+class XfroutSession():
|
|
|
+ def __init__(self, sock_fd, request_data, server, log):
|
|
|
# The initializer for the superclass may call functions
|
|
|
# that need _log to be set, so we set it first
|
|
|
+ self._sock_fd = sock_fd
|
|
|
+ self._request_data = request_data
|
|
|
+ self._server = server
|
|
|
self._log = log
|
|
|
- BaseRequestHandler.__init__(self, request, client_address, server)
|
|
|
+ self.handle()
|
|
|
|
|
|
def handle(self):
|
|
|
- ''' Handle a xfrout query. First, xfrout server receive
|
|
|
- socket fd and query message from auth. Then, send xfrout
|
|
|
- response via the socket fd.'''
|
|
|
- sock_fd = recv_fd(self.request.fileno())
|
|
|
- if sock_fd < 0:
|
|
|
- # This may happen when one xfrout process try to connect to
|
|
|
- # xfrout unix socket server, to check whether there is another
|
|
|
- # xfrout running.
|
|
|
- if sock_fd == XFR_FD_RECEIVE_FAIL:
|
|
|
- self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
|
|
|
- return
|
|
|
-
|
|
|
- # receive query msg
|
|
|
- msgdata = self._receive_query_message(self.request)
|
|
|
- if not msgdata:
|
|
|
- return
|
|
|
-
|
|
|
+ ''' Handle a xfrout query, send xfrout response '''
|
|
|
try:
|
|
|
- self.dns_xfrout_start(sock_fd, msgdata)
|
|
|
+ self.dns_xfrout_start(self._sock_fd, self._request_data)
|
|
|
#TODO, avoid catching all exceptions
|
|
|
except Exception as e:
|
|
|
self._log.log_message("error", str(e))
|
|
|
|
|
|
- os.close(sock_fd)
|
|
|
-
|
|
|
- def _receive_query_message(self, sock):
|
|
|
- ''' receive query message from sock'''
|
|
|
- # receive data length
|
|
|
- data_len = sock.recv(2)
|
|
|
- if not data_len:
|
|
|
- return None
|
|
|
- msg_len = struct.unpack('!H', data_len)[0]
|
|
|
- # receive data
|
|
|
- recv_size = 0
|
|
|
- msgdata = b''
|
|
|
- while recv_size < msg_len:
|
|
|
- data = sock.recv(msg_len - recv_size)
|
|
|
- if not data:
|
|
|
- return None
|
|
|
- recv_size += len(data)
|
|
|
- msgdata += data
|
|
|
-
|
|
|
- return msgdata
|
|
|
+ os.close(self._sock_fd)
|
|
|
|
|
|
def _parse_query_message(self, mdata):
|
|
|
''' parse query message to [socket,message]'''
|
|
@@ -176,7 +144,7 @@ class XfroutSession(BaseRequestHandler):
|
|
|
|
|
|
|
|
|
def _zone_is_empty(self, zone):
|
|
|
- if sqlite3_ds.get_zone_soa(zone, self.server.get_db_file()):
|
|
|
+ if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
|
|
|
return False
|
|
|
|
|
|
return True
|
|
@@ -184,7 +152,7 @@ class XfroutSession(BaseRequestHandler):
|
|
|
def _zone_exist(self, zonename):
|
|
|
# Find zone in datasource, should this works? maybe should ask
|
|
|
# config manager.
|
|
|
- soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
|
|
|
+ soa = sqlite3_ds.get_zone_soa(zonename, self._server.get_db_file())
|
|
|
if soa:
|
|
|
return True
|
|
|
return False
|
|
@@ -202,7 +170,7 @@ class XfroutSession(BaseRequestHandler):
|
|
|
return Rcode.SERVFAIL()
|
|
|
|
|
|
#TODO, check allow_transfer
|
|
|
- if not self.server.increase_transfers_counter():
|
|
|
+ if not self._server.increase_transfers_counter():
|
|
|
return Rcode.REFUSED()
|
|
|
|
|
|
return Rcode.NOERROR()
|
|
@@ -228,7 +196,7 @@ class XfroutSession(BaseRequestHandler):
|
|
|
except Exception as err:
|
|
|
self._log.log_message("error", str(err))
|
|
|
|
|
|
- self.server.decrease_transfers_counter()
|
|
|
+ self._server.decrease_transfers_counter()
|
|
|
return
|
|
|
|
|
|
|
|
@@ -275,14 +243,14 @@ class XfroutSession(BaseRequestHandler):
|
|
|
#TODO, there should be a better way to insert rrset.
|
|
|
msg.make_response()
|
|
|
msg.set_header_flag(Message.HEADERFLAG_AA)
|
|
|
- soa_record = sqlite3_ds.get_zone_soa(zone_name, self.server.get_db_file())
|
|
|
+ soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
|
|
|
rrset_soa = self._create_rrset_from_db_record(soa_record)
|
|
|
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
|
|
|
|
|
|
message_upper_len = get_rrset_len(rrset_soa)
|
|
|
|
|
|
- for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
|
|
|
- if self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
|
|
|
+ for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
|
|
|
+ if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
|
|
|
self._log.log_message("info", "xfrout process is being shutdown")
|
|
|
return
|
|
|
|
|
@@ -324,7 +292,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
|
|
|
self.update_config_data(config_data)
|
|
|
self._cc = cc
|
|
|
|
|
|
- def _handle_request_noblock(self):
|
|
|
+ def handle_request(self):
|
|
|
'''Rewrite _handle_request_noblock() from parent class ThreadingUnixStreamServer,
|
|
|
enable server handle a request until shutdown or xfrout client is closed.'''
|
|
|
try:
|
|
@@ -359,19 +327,52 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
|
|
|
self.close_request(request)
|
|
|
break
|
|
|
|
|
|
- def process_request_thread(self, request, client_address):
|
|
|
- ''' Rewrite process_request_thread() from parent class ThreadingUnixStreamServer,
|
|
|
- server won't close the connection after handling a xfrout query, the connection
|
|
|
- should be kept for handling upcoming xfrout queries.'''
|
|
|
- try:
|
|
|
- self.finish_request(request, client_address)
|
|
|
- except Exception as e:
|
|
|
- self.handle_error(request, client_address)
|
|
|
- self.close_request(request)
|
|
|
+ def _receive_query_message(self, sock):
|
|
|
+ ''' receive request message from sock'''
|
|
|
+ # receive data length
|
|
|
+ data_len = sock.recv(2)
|
|
|
+ if not data_len:
|
|
|
+ return None
|
|
|
+ msg_len = struct.unpack('!H', data_len)[0]
|
|
|
+ # receive data
|
|
|
+ recv_size = 0
|
|
|
+ msgdata = b''
|
|
|
+ while recv_size < msg_len:
|
|
|
+ data = sock.recv(msg_len - recv_size)
|
|
|
+ if not data:
|
|
|
+ return None
|
|
|
+ recv_size += len(data)
|
|
|
+ msgdata += data
|
|
|
+
|
|
|
+ return msgdata
|
|
|
+
|
|
|
+ def process_request(self, request, client_address):
|
|
|
+ """Receive socket fd and query message from auth, then
|
|
|
+ start a new thread to process the request."""
|
|
|
+ sock_fd = recv_fd(request.fileno())
|
|
|
+ if sock_fd < 0:
|
|
|
+ # This may happen when one xfrout process try to connect to
|
|
|
+ # xfrout unix socket server, to check whether there is another
|
|
|
+ # xfrout running.
|
|
|
+ if sock_fd == XFR_FD_RECEIVE_FAIL:
|
|
|
+ self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
|
|
|
+ return
|
|
|
+
|
|
|
+ # receive request msg
|
|
|
+ request_data = self._receive_query_message(request)
|
|
|
+ if not request_data:
|
|
|
+ return
|
|
|
+
|
|
|
+ t = threading.Thread(target = self.finish_request,
|
|
|
+ args = (sock_fd, request_data, client_address))
|
|
|
+ if self.daemon_threads:
|
|
|
+ t.daemon = True
|
|
|
+ t.start()
|
|
|
+
|
|
|
|
|
|
- def finish_request(self, request, client_address):
|
|
|
+ def finish_request(self, sock_fd, request_data, client_address):
|
|
|
'''Finish one request by instantiating RequestHandlerClass.'''
|
|
|
- self.RequestHandlerClass(request, client_address, self, self._log)
|
|
|
+ self.RequestHandlerClass(sock_fd, request_data, self, self._log)
|
|
|
|
|
|
def _remove_unused_sock_file(self, sock_file):
|
|
|
'''Try to remove the socket file. If the file is being used
|