Browse Source

merge trac299: Xfrout and Auth will communicate by long tcp connection,
Auth needs to make a new connection only on the first time or if an error occurred.


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@3482 e5f2f494-b856-4b98-b285-d166d9295462

Jerry 14 years ago
parent
commit
f7c0462610

+ 7 - 1
ChangeLog

@@ -1,3 +1,9 @@
+  116.	[bug]		jerry
+	src/bin/xfrout: Xfrout and Auth will communicate by long tcp
+	connection, Auth needs to make a new connection only on the first
+	time or if an error occurred.
+	(Trac #299, svn r3482)
+
   115.	[func]*		jinmei
 	src/lib/dns: Changed DNS message flags and section names from
 	separate classes to simpler enums, considering the balance between
@@ -10,7 +16,7 @@
 	(Trac #365, svn r3383)
 
   113.	[func]*		zhanglikun
-	Folder name 'utils'(the folder in /src/lib/python/isc/) has been 
+	Folder name 'utils'(the folder in /src/lib/python/isc/) has been
 	renamed	to 'util'. Programs that used 'import isc.utils.process'
 	now need to use 'import isc.util.process'. The folder
 	/src/lib/python/isc/Util is removed since it isn't used by any

+ 12 - 13
src/bin/auth/auth_srv.cc

@@ -77,7 +77,7 @@ public:
                             MessageRenderer& response_renderer);
     bool processAxfrQuery(const IOMessage& io_message, Message& message,
                             MessageRenderer& response_renderer);
-    bool processNotify(const IOMessage& io_message, Message& message, 
+    bool processNotify(const IOMessage& io_message, Message& message,
                             MessageRenderer& response_renderer);
     std::string db_file_;
     ModuleCCSession* config_session_;
@@ -307,7 +307,7 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
     ConstEDNSPtr remote_edns = message.getEDNS();
     const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
     const uint16_t remote_bufsize = remote_edns ? remote_edns->getUDPSize() :
-        Message::DEFAULT_MAX_UDPSIZE; 
+        Message::DEFAULT_MAX_UDPSIZE;
 
     message.makeResponse();
     message.setHeaderFlag(Message::HEADERFLAG_AA);
@@ -360,8 +360,10 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
     }
 
     try {
-        xfrout_client_.connect();
-        xfrout_connected_ = true;
+        if (!xfrout_connected_) {
+            xfrout_client_.connect();
+            xfrout_connected_ = true;
+        }
         xfrout_client_.sendXfroutRequestInfo(
             io_message.getSocket().getNative(),
             io_message.getData(),
@@ -375,7 +377,7 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
             xfrout_client_.disconnect();
             xfrout_connected_ = false;
         }
-        
+
         if (verbose_mode_) {
             cerr << "[b10-auth] Error in handling XFR request: " << err.what()
                  << endl;
@@ -385,15 +387,12 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
         return (true);
     }
 
-    xfrout_client_.disconnect();
-    xfrout_connected_ = false;
-
     return (false);
 }
 
 bool
-AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message, 
-                           MessageRenderer& response_renderer) 
+AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
+                           MessageRenderer& response_renderer)
 {
     // The incoming notify must contain exactly one question for SOA of the
     // zone name.
@@ -435,7 +434,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
         }
         return (false);
     }
-    
+
     const string remote_ip_address =
         io_message.getRemoteEndpoint().getAddress().toText();
     static const string command_template_start =
@@ -446,7 +445,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
 
     try {
         ConstElementPtr notify_command = Element::fromJSON(
-                command_template_start + question->getName().toText() + 
+                command_template_start + question->getName().toText() +
                 command_template_master + remote_ip_address +
                 command_template_rrclass + question->getClass().toText() +
                 command_template_end);
@@ -460,7 +459,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
         if (rcode != 0) {
             if (verbose_mode_) {
                 cerr << "[b10-auth] failed to notify Zonemgr: "
-                     << parsed_answer->str() << endl; 
+                     << parsed_answer->str() << endl;
             }
             return (false);
         }

+ 2 - 4
src/bin/auth/tests/auth_srv_unittest.cc

@@ -489,7 +489,7 @@ TEST_F(AuthSrvTest, AXFRSuccess) {
     // so we shouldn't have to respond.
     EXPECT_FALSE(server.processMessage(*io_message, parse_message,
                                        response_renderer));
-    EXPECT_FALSE(xfrout.isConnected());
+    EXPECT_TRUE(xfrout.isConnected());
 }
 
 TEST_F(AuthSrvTest, AXFRConnectFail) {
@@ -501,8 +501,6 @@ TEST_F(AuthSrvTest, AXFRConnectFail) {
                                       response_renderer));
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
-    // For a shot term workaround with xfrout we currently close the connection
-    // for each AXFR attempt
     EXPECT_FALSE(xfrout.isConnected());
 }
 
@@ -512,7 +510,7 @@ TEST_F(AuthSrvTest, AXFRSendFail) {
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
     server.processMessage(*io_message, parse_message, response_renderer);
-    EXPECT_FALSE(xfrout.isConnected()); // see above
+    EXPECT_TRUE(xfrout.isConnected());
 
     xfrout.disableSend();
     parse_message.clear(Message::PARSE);

+ 36 - 20
src/bin/xfrout/tests/xfrout_test.py

@@ -47,22 +47,29 @@ class MySocket():
         result = self.sendqueue[:size]
         self.sendqueue = self.sendqueue[size:]
         return result
-    
+
     def read_msg(self):
         sent_data = self.readsent()
         get_msg = Message(Message.PARSE)
         get_msg.from_wire(bytes(sent_data[2:]))
         return get_msg
-    
+
     def clear_send(self):
         del self.sendqueue[:]
 
 # We subclass the Session class we're testing here, only
-# to override the __init__() method, which wants a socket,
+# to override the handle() and _send_data() method
 class MyXfroutSession(XfroutSession):
     def handle(self):
         pass
-    
+
+    def _send_data(self, sock, data):
+        size = len(data)
+        total_count = 0
+        while total_count < size:
+            count = sock.send(data[total_count:])
+            total_count += count
+
 class Dbserver:
     def __init__(self):
         self._shutdown_event = threading.Event()
@@ -80,12 +87,21 @@ class TestXfroutSession(unittest.TestCase):
     def setUp(self):
         request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(request, None, None, self.log)
+        (self.write_sock, self.read_sock) = socket.socketpair()
+        self.xfrsess = MyXfroutSession(request, None, None, self.log, self.read_sock)
         self.xfrsess.server = Dbserver()
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
+    def test_receive_query_message(self):
+        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
+        msg_len = struct.pack('H', socket.htons(len(send_msg)))
+        self.write_sock.send(msg_len)
+        self.write_sock.send(send_msg)
+        recv_msg = self.xfrsess._receive_query_message(self.read_sock)
+        self.assertEqual(recv_msg, send_msg)
+
     def test_parse_query_message(self):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -93,7 +109,7 @@ class TestXfroutSession(unittest.TestCase):
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
-  
+
     def test_send_data(self):
         self.xfrsess._send_data(self.sock, self.mdata)
         senddata = self.sock.readsent()
@@ -103,8 +119,8 @@ class TestXfroutSession(unittest.TestCase):
         msg = self.getmsg()
         self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
         get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN") 
-     
+        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
+
     def test_clear_message(self):
         msg = self.getmsg()
         qid = msg.get_qid()
@@ -118,7 +134,7 @@ class TestXfroutSession(unittest.TestCase):
         self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
 
     def test_reply_query_with_format_error(self):
-         
+
         msg = self.getmsg()
         self.xfrsess._reply_query_with_format_error(msg, self.sock)
         get_msg = self.sock.read_msg()
@@ -217,7 +233,7 @@ class TestXfroutSession(unittest.TestCase):
         sqlite3_ds.get_zone_soa = zone_soa
         self.assertEqual(self.xfrsess._zone_exist(True), True)
         self.assertEqual(self.xfrsess._zone_exist(False), False)
-    
+
     def test_check_xfrout_available(self):
         def zone_exist(zone):
             return zone
@@ -243,7 +259,7 @@ class TestXfroutSession(unittest.TestCase):
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
         sent_data = self.sock.readsent()
         self.assertEqual(len(sent_data), 0)
-    
+
     def default(self, param):
         return "example.com"
 
@@ -255,20 +271,20 @@ class TestXfroutSession(unittest.TestCase):
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         get_msg = self.sock.read_msg()
         self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
-    
+
     def test_dns_xfrout_start_noerror(self):
         self.xfrsess._get_query_zone_name = self.default
         def noerror(form):
-            return Rcode.NOERROR() 
+            return Rcode.NOERROR()
         self.xfrsess._check_xfrout_available = noerror
-        
+
         def myreply(msg, sock, zonename):
             self.sock.send(b"success")
-        
+
         self.xfrsess._reply_xfrout_query = myreply
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.assertEqual(self.sock.readsent(), b"success")
-    
+
     def test_reply_xfrout_query_noerror(self):
         global sqlite3_ds
         def get_zone_soa(zonename, file):
@@ -292,7 +308,7 @@ class MyCCSession():
             return "initdb.file", False
         else:
             return "unknown", False
-    
+
 
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
@@ -306,7 +322,7 @@ class MyUnixSockServer(UnixSockServer):
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
         self.unix = MyUnixSockServer()
-     
+
     def test_updata_config_data(self):
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)
@@ -324,7 +340,7 @@ class TestUnixSockServer(unittest.TestCase):
         count = self.unix._transfers_counter
         self.assertEqual(self.unix.increase_transfers_counter(), False)
         self.assertEqual(count, self.unix._transfers_counter)
- 
+
     def test_decrease_transfers_counter(self):
         count = self.unix._transfers_counter
         self.unix.decrease_transfers_counter()
@@ -335,7 +351,7 @@ class TestUnixSockServer(unittest.TestCase):
             os.remove(sock_file)
         except OSError:
             pass
- 
+
     def test_sock_file_in_use_file_exist(self):
         sock_file = 'temp.sock.file'
         self._remove_file(sock_file)

+ 108 - 73
src/bin/xfrout/xfrout.py.in

@@ -63,6 +63,7 @@ AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
 MAX_TRANSFERS_OUT = 10
 VERBOSE_MODE = False
 
+
 XFROUT_MAX_MESSAGE_SIZE = 65535
 
 def get_rrset_len(rrset):
@@ -73,46 +74,78 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession(BaseRequestHandler):
-    def __init__(self, request, client_address, server, log):
+    def __init__(self, request, client_address, server, log, sock):
         # The initializer for the superclass may call functions
         # that need _log to be set, so we set it first
         self._log = log
+        self._shutdown_sock = sock
         BaseRequestHandler.__init__(self, request, client_address, server)
 
     def handle(self):
-        fd = recv_fd(self.request.fileno())
-        
-        if fd < 0:
-            # This may happen when one xfrout process try to connect to
-            # xfrout unix socket server, to check whether there is another
-            # xfrout running. 
-            self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
-            return
-
-        data_len = self.request.recv(2)
-        msg_len = struct.unpack('!H', data_len)[0]
-        msgdata = self.request.recv(msg_len)
-        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
-        try:
-            self.dns_xfrout_start(sock, msgdata)
-            #TODO, avoid catching all exceptions
-        except Exception as e:
-            self._log.log_message("error", str(e))
+        '''Handle a request until shutdown or xfrout client is closed.'''
+        # check self.server._shutdown_event to ensure the real shutdown comes.
+        # Linux could trigger a spurious readable event on the _shutdown_sock 
+        # due to a bug, so we need perform a double check. 
+        while not self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
+            try:
+                (rlist, wlist, xlist) = select.select([self._shutdown_sock, self.request], [], [])
+            except select.error as e:
+                if e.args[0] == errno.EINTR:
+                    (rlist, wlist, xlist) = ([], [], [])
+                    continue
+                else:
+                    self._log.log_message("error", "Error with select(): %s" %e)
+                    break
+            # self.server._shutdown_evnet will be set by now, if it is not a false
+            # alarm
+            if self._shutdown_sock in rlist:
+                continue
 
-        try:
-            sock.shutdown(socket.SHUT_RDWR)
-        except socket.error:
-            # Avoid socket error caused by shutting down 
-            # one non-connected socket.
-            pass
+            sock_fd = recv_fd(self.request.fileno())
+
+            if sock_fd < 0:
+                # This may happen when one xfrout process try to connect to
+                # xfrout unix socket server, to check whether there is another
+                # xfrout running.
+                if sock_fd == XFR_FD_RECEIVE_FAIL:
+                    self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
+                break
 
-        sock.close()
-        os.close(fd)
-        pass
+            # receive query msg
+            msgdata = self._receive_query_message(self.request)
+            if not msgdata:
+                break
+
+            try:
+                self.dns_xfrout_start(sock_fd, msgdata)
+                #TODO, avoid catching all exceptions
+            except Exception as e:
+                self._log.log_message("error", str(e))
+
+            os.close(sock_fd)
+
+    def _receive_query_message(self, sock):
+        ''' receive query message from sock'''
+        # receive data length
+        data_len = sock.recv(2)
+        if not data_len:
+            return None
+        msg_len = struct.unpack('!H', data_len)[0]
+        # receive data
+        recv_size = 0
+        msgdata = b''
+        while recv_size < msg_len:
+            data = sock.recv(msg_len - recv_size)
+            if not data:
+                return None
+            recv_size += len(data)
+            msgdata += data
+
+        return msgdata
 
     def _parse_query_message(self, mdata):
         ''' parse query message to [socket,message]'''
-        #TODO, need to add parseHeader() in case the message header is invalid 
+        #TODO, need to add parseHeader() in case the message header is invalid
         try:
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
@@ -127,37 +160,37 @@ class XfroutSession(BaseRequestHandler):
         return question.get_name().to_text()
 
 
-    def _send_data(self, sock, data):
+    def _send_data(self, sock_fd, data):
         size = len(data)
         total_count = 0
         while total_count < size:
-            count = sock.send(data[total_count:])
+            count = os.write(sock_fd, data[total_count:])
             total_count += count
 
 
-    def _send_message(self, sock, msg):
+    def _send_message(self, sock_fd, msg):
         render = MessageRenderer()
         render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
         msg.to_wire(render)
         header_len = struct.pack('H', socket.htons(render.get_length()))
-        self._send_data(sock, header_len)
-        self._send_data(sock, render.get_data())
+        self._send_data(sock_fd, header_len)
+        self._send_data(sock_fd, render.get_data())
 
 
-    def _reply_query_with_error_rcode(self, msg, sock, rcode_):
+    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
         msg.make_response()
         msg.set_rcode(rcode_)
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
-    def _reply_query_with_format_error(self, msg, sock):
+    def _reply_query_with_format_error(self, msg, sock_fd):
         '''query message format isn't legal.'''
         if not msg:
-            return # query message is invalid. send nothing back. 
+            return # query message is invalid. send nothing back.
 
         msg.make_response()
         msg.set_rcode(Rcode.FORMERR())
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
     def _zone_is_empty(self, zone):
@@ -167,24 +200,24 @@ class XfroutSession(BaseRequestHandler):
         return True
 
     def _zone_exist(self, zonename):
-        # Find zone in datasource, should this works? maybe should ask 
+        # Find zone in datasource, should this works? maybe should ask
         # config manager.
         soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
         if soa:
             return True
         return False
 
-    
+
     def _check_xfrout_available(self, zone_name):
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
-           eg. check allow_transfer setting, 
+           eg. check allow_transfer setting,
         '''
         if not self._zone_exist(zone_name):
             return Rcode.NOTAUTH()
 
         if self._zone_is_empty(zone_name):
-            return Rcode.SERVFAIL() 
+            return Rcode.SERVFAIL()
 
         #TODO, check allow_transfer
         if not self.server.increase_transfers_counter():
@@ -193,35 +226,35 @@ class XfroutSession(BaseRequestHandler):
         return Rcode.NOERROR()
 
 
-    def dns_xfrout_start(self, sock, msg_query):
+    def dns_xfrout_start(self, sock_fd, msg_query):
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
         if rcode_ != Rcode.NOERROR():
-            return self._reply_query_with_format_error(msg, sock)
+            return self._reply_query_with_format_error(msg, sock_fd)
 
         zone_name = self._get_query_zone_name(msg)
         rcode_ = self._check_xfrout_available(zone_name)
         if rcode_ != Rcode.NOERROR():
             self._log.log_message("info", "transfer of '%s/IN' failed: %s",
                                   zone_name, rcode_.to_text())
-            return self. _reply_query_with_error_rcode(msg, sock, rcode_)
+            return self. _reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
         try:
             self._log.log_message("info", "transfer of '%s/IN': AXFR started" % zone_name)
-            self._reply_xfrout_query(msg, sock, zone_name)
+            self._reply_xfrout_query(msg, sock_fd, zone_name)
             self._log.log_message("info", "transfer of '%s/IN': AXFR end" % zone_name)
         except Exception as err:
             self._log.log_message("error", str(err))
 
         self.server.decrease_transfers_counter()
-        return    
+        return
 
 
     def _clear_message(self, msg):
         qid = msg.get_qid()
         opcode = msg.get_opcode()
         rcode = msg.get_rcode()
-        
+
         msg.clear(Message.RENDER)
         msg.set_qid(qid)
         msg.set_opcode(opcode)
@@ -231,7 +264,7 @@ class XfroutSession(BaseRequestHandler):
         return msg
 
     def _create_rrset_from_db_record(self, record):
-        '''Create one rrset from one record of datasource, if the schema of record is changed, 
+        '''Create one rrset from one record of datasource, if the schema of record is changed,
         This function should be updated first.
         '''
         rrtype_ = RRType(record[5])
@@ -239,8 +272,8 @@ class XfroutSession(BaseRequestHandler):
         rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
         rrset_.add_rdata(rdata_)
         return rrset_
-         
-    def _send_message_with_last_soa(self, msg, sock, rrset_soa, message_upper_len):
+
+    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len):
         '''Add the SOA record to the end of message. If it can't be
         added, a new message should be created to send out the last soa .
         '''
@@ -249,14 +282,14 @@ class XfroutSession(BaseRequestHandler):
         if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
         else:
-            self._send_message(sock, msg)
+            self._send_message(sock_fd, msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
 
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
-    def _reply_xfrout_query(self, msg, sock, zone_name):
+    def _reply_xfrout_query(self, msg, sock_fd, zone_name):
         #TODO, there should be a better way to insert rrset.
         msg.make_response()
         msg.set_header_flag(Message.HEADERFLAG_AA)
@@ -286,12 +319,12 @@ class XfroutSession(BaseRequestHandler):
                 message_upper_len += rrset_len
                 continue
 
-            self._send_message(sock, msg)
+            self._send_message(sock_fd, msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
             message_upper_len = rrset_len
 
-        self._send_message_with_last_soa(msg, sock, rrset_soa, message_upper_len)
+        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
 
 class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
@@ -304,22 +337,23 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._lock = threading.Lock()
         self._transfers_counter = 0
         self._shutdown_event = shutdown_event
+        self._write_sock, self._read_sock = socket.socketpair()
         self._log = log
         self.update_config_data(config_data)
         self._cc = cc
-        
+
     def finish_request(self, request, client_address):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(request, client_address, self, self._log)
+        self.RequestHandlerClass(request, client_address, self, self._log, self._read_sock)
 
     def _remove_unused_sock_file(self, sock_file):
-        '''Try to remove the socket file. If the file is being used 
-        by one running xfrout process, exit from python. 
+        '''Try to remove the socket file. If the file is being used
+        by one running xfrout process, exit from python.
         If it's not a socket file or nobody is listening
         , it will be removed. If it can't be removed, exit from python. '''
         if self._sock_file_in_use(sock_file):
-            sys.stderr.write("[b10-xfrout] Fail to start xfrout process, unix socket" 
-                  " file '%s' is being used by another xfrout process\n" % sock_file)
+            self._log.log_message("error", "Fail to start xfrout process, unix socket file '%s'"
+                                 " is being used by another xfrout process\n" % sock_file)
             sys.exit(0)
         else:
             if not os.path.exists(sock_file):
@@ -328,12 +362,12 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             try:
                 os.unlink(sock_file)
             except OSError as err:
-                sys.stderr.write('[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
+                self._log.log_message("error", '[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
                 sys.exit(0)
-   
+
     def _sock_file_in_use(self, sock_file):
-        '''Check whether the socket file 'sock_file' exists and 
-        is being used by one running xfrout process. If it is, 
+        '''Check whether the socket file 'sock_file' exists and
+        is being used by one running xfrout process. If it is,
         return True, or else return False. '''
         try:
             sock = socket.socket(socket.AF_UNIX)
@@ -341,9 +375,10 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         except socket.error as err:
             return False
         else:
-            return True 
+            return True
 
     def shutdown(self):
+        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
         super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
         try:
             os.unlink(self._sock_file)
@@ -390,7 +425,7 @@ class XfroutServer:
     def __init__(self):
         self._unix_socket_server = None
         self._log = None
-        self._listen_sock_file = UNIX_SOCKET_FILE 
+        self._listen_sock_file = UNIX_SOCKET_FILE
         self._shutdown_event = threading.Event()
         self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
         self._config_data = self._cc.get_full_config()
@@ -404,12 +439,12 @@ class XfroutServer:
 
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
-        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession, 
+        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession,
                                                   self._shutdown_event, self._config_data,
                                                   self._cc, self._log);
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener.start()
-        
+
     def _start_notifier(self):
         datasrc = self._unix_socket_server.get_db_file()
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
@@ -472,7 +507,7 @@ class XfroutServer:
             else:
                 answer = create_answer(1, "Bad command parameter:" + str(args))
 
-        else: 
+        else:
             answer = create_answer(1, "Unknown command:" + str(cmd))
 
         return answer
@@ -514,7 +549,7 @@ if '__main__' == __name__:
         sys.stderr.write("[b10-xfrout] Error creating xfrout, "
                            "is the command channel daemon running?\n")
     except SessionTimeout as e:
-        sys.stderr.write("[b10-xfrout] Error creating xfrout, " 
+        sys.stderr.write("[b10-xfrout] Error creating xfrout, "
                            "is the configuration manager running?\n")
     except ModuleCCSessionError as e:
         sys.stderr.write("[b10-xfrout] exit xfrout process:%s\n" % str(e))

+ 32 - 32
src/bin/zonemgr/zonemgr.py.in

@@ -90,26 +90,26 @@ class ZonemgrException(Exception):
 
 class ZonemgrRefresh:
     """This class will maintain and manage zone refresh info.
-    It also provides methods to keep track of zone timers and 
-    do zone refresh. 
-    Zone timers can be started by calling run_timer(), and it 
+    It also provides methods to keep track of zone timers and
+    do zone refresh.
+    Zone timers can be started by calling run_timer(), and it
     can be stopped by calling shutdown() in another thread.
 
     """
 
     def __init__(self, cc, db_file, slave_socket, config_data):
         self._cc = cc
-        self._check_sock = slave_socket 
+        self._check_sock = slave_socket
         self._db_file = db_file
         self.update_config_data(config_data)
-        self._zonemgr_refresh_info = {} 
+        self._zonemgr_refresh_info = {}
         self._build_zonemgr_refresh_info()
         self._running = False
-    
+
     def _random_jitter(self, max, jitter):
         """Imposes some random jitters for refresh and
         retry timers to avoid many zones need to do refresh
-        at the same time. 
+        at the same time.
         The value should be between (max - jitter) and max.
         """
         if 0 == jitter:
@@ -120,7 +120,7 @@ class ZonemgrRefresh:
         return time.time()
 
     def _set_zone_timer(self, zone_name_class, max, jitter):
-        """Set zone next refresh time. 
+        """Set zone next refresh time.
         jitter should not be bigger than half the original value."""
         self._set_zone_next_refresh_time(zone_name_class, self._get_current_time() + \
                                             self._random_jitter(max, jitter))
@@ -143,7 +143,7 @@ class ZonemgrRefresh:
 
     def _set_zone_notify_timer(self, zone_name_class):
         """Set zone next refresh time after receiving notify
-           next_refresh_time = now 
+           next_refresh_time = now
         """
         self._set_zone_timer(zone_name_class, 0, 0)
 
@@ -199,7 +199,7 @@ class ZonemgrRefresh:
             raise ZonemgrException("[b10-zonemgr] zone (%s, %s) doesn't have soa." % zone_name_class)
         zone_info["zone_soa_rdata"] = zone_soa[7]
         zone_info["zone_state"] = ZONE_OK
-        zone_info["last_refresh_time"] = self._get_current_time() 
+        zone_info["last_refresh_time"] = self._get_current_time()
         zone_info["next_refresh_time"] = self._get_current_time() + \
                                          float(zone_soa[7].split(" ")[REFRESH_OFFSET])
         self._zonemgr_refresh_info[zone_name_class] = zone_info
@@ -233,7 +233,7 @@ class ZonemgrRefresh:
 
     def _get_zone_notifier_master(self, zone_name_class):
         if ("notify_master" in self._zonemgr_refresh_info[zone_name_class].keys()):
-            return self._zonemgr_refresh_info[zone_name_class]["notify_master"] 
+            return self._zonemgr_refresh_info[zone_name_class]["notify_master"]
 
         return None
 
@@ -248,7 +248,7 @@ class ZonemgrRefresh:
         return self._zonemgr_refresh_info[zone_name_class]["zone_state"]
 
     def _set_zone_state(self, zone_name_class, zone_state):
-        self._zonemgr_refresh_info[zone_name_class]["zone_state"] = zone_state 
+        self._zonemgr_refresh_info[zone_name_class]["zone_state"] = zone_state
 
     def _get_zone_refresh_timeout(self, zone_name_class):
         return self._zonemgr_refresh_info[zone_name_class]["refresh_timeout"]
@@ -268,7 +268,7 @@ class ZonemgrRefresh:
         try:
             self._cc.group_sendmsg(msg, module_name)
         except socket.error:
-            sys.stderr.write("[b10-zonemgr] Failed to send to module %s, the session has been closed." % module_name) 
+            sys.stderr.write("[b10-zonemgr] Failed to send to module %s, the session has been closed." % module_name)
 
     def _find_need_do_refresh_zone(self):
         """Find the first zone need do refresh, if no zone need
@@ -281,10 +281,10 @@ class ZonemgrRefresh:
             if (ZONE_REFRESHING == zone_state and
                 (self._get_zone_refresh_timeout(zone_name_class) > self._get_current_time())):
                 continue
-                    
-            # Get the zone with minimum next_refresh_time 
-            if ((zone_need_refresh is None) or 
-                (self._get_zone_next_refresh_time(zone_name_class) < 
+
+            # Get the zone with minimum next_refresh_time
+            if ((zone_need_refresh is None) or
+                (self._get_zone_next_refresh_time(zone_name_class) <
                  self._get_zone_next_refresh_time(zone_need_refresh))):
                 zone_need_refresh = zone_name_class
 
@@ -292,14 +292,14 @@ class ZonemgrRefresh:
             if (self._get_zone_next_refresh_time(zone_need_refresh) < self._get_current_time()):
                 break
 
-        return zone_need_refresh 
+        return zone_need_refresh
+
 
-    
     def _do_refresh(self, zone_name_class):
         """Do zone refresh."""
         log_msg("Do refresh for zone (%s, %s)." % zone_name_class)
         self._set_zone_state(zone_name_class, ZONE_REFRESHING)
-        self._set_zone_refresh_timeout(zone_name_class, self._get_current_time() + self._max_transfer_timeout) 
+        self._set_zone_refresh_timeout(zone_name_class, self._get_current_time() + self._max_transfer_timeout)
         notify_master = self._get_zone_notifier_master(zone_name_class)
         # If the zone has notify master, send notify command to xfrin module
         if notify_master:
@@ -307,7 +307,7 @@ class ZonemgrRefresh:
                      "zone_class" : zone_name_class[1],
                      "master" : notify_master
                      }
-            self._send_command(XFRIN_MODULE_NAME, ZONE_NOTIFY_COMMAND, param) 
+            self._send_command(XFRIN_MODULE_NAME, ZONE_NOTIFY_COMMAND, param)
             self._clear_zone_notifier_master(zone_name_class)
         # Send refresh command to xfrin module
         else:
@@ -328,19 +328,19 @@ class ZonemgrRefresh:
         while self._running:
             # If zonemgr has no zone, set timer timeout to self._lowerbound_retry.
             if self._zone_mgr_is_empty():
-                timeout = self._lowerbound_retry 
+                timeout = self._lowerbound_retry
             else:
                 zone_need_refresh = self._find_need_do_refresh_zone()
-                # If don't get zone with minimum next refresh time, set timer timeout to self._lowerbound_retry 
+                # If don't get zone with minimum next refresh time, set timer timeout to self._lowerbound_retry.
                 if not zone_need_refresh:
-                    timeout = self._lowerbound_retry 
+                    timeout = self._lowerbound_retry
                 else:
                     timeout = self._get_zone_next_refresh_time(zone_need_refresh) - self._get_current_time()
                     if (timeout < 0):
                         self._do_refresh(zone_need_refresh)
                         continue
 
-            """ Wait for the socket notification for a maximum time of timeout 
+            """ Wait for the socket notification for a maximum time of timeout
             in seconds (as float)."""
             try:
                 rlist, wlist, xlist = select.select([self._check_sock, self._read_sock], [], [], timeout)
@@ -352,7 +352,7 @@ class ZonemgrRefresh:
                     break
 
             for fd in rlist:
-                if fd == self._read_sock: # awaken by shutdown socket 
+                if fd == self._read_sock: # awaken by shutdown socket
                     # self._running will be False by now, if it is not a false
                     # alarm
                     continue
@@ -416,7 +416,7 @@ class Zonemgr:
         self._zone_refresh = None
         self._setup_session()
         self._db_file = self.get_db_file()
-        # Create socket pair for communicating between main thread and zonemgr timer thread 
+        # Create socket pair for communicating between main thread and zonemgr timer thread
         self._master_socket, self._slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         self._zone_refresh = ZonemgrRefresh(self._cc, self._db_file, self._slave_socket, self._config_data)
         self._zone_refresh.run_timer()
@@ -426,7 +426,7 @@ class Zonemgr:
         self.running = False
 
     def _setup_session(self):
-        """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving 
+        """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving
         commands and config data sent from other modules, another one (self._cc)
         is used to send commands to proper modules."""
         self._cc = isc.cc.Session()
@@ -450,7 +450,7 @@ class Zonemgr:
     def shutdown(self):
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         timers should be terminated.
-        """ 
+        """
         self._zone_refresh.shutdown()
 
         self._slave_socket.close()
@@ -503,7 +503,7 @@ class Zonemgr:
 
     def command_handler(self, command, args):
         """Handle command receivd from command channel.
-        ZONE_NOTIFY_COMMAND is issued by Auth process; ZONE_XFRIN_SUCCESS_COMMAND 
+        ZONE_NOTIFY_COMMAND is issued by Auth process; ZONE_XFRIN_SUCCESS_COMMAND
         and ZONE_XFRIN_FAILED_COMMAND are issued by Xfrin process; shutdown is issued
         by a user or Boss process. """
         answer = create_answer(0)
@@ -572,10 +572,10 @@ if '__main__' == __name__:
     except KeyboardInterrupt:
         sys.stderr.write("[b10-zonemgr] exit zonemgr process\n")
     except isc.cc.session.SessionError as e:
-        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, " 
+        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, "
                            "is the command channel daemon running?\n")
     except isc.cc.session.SessionTimeout as e:
-        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, " 
+        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, "
                            "is the configuration manager running?\n")
     except isc.config.ModuleCCSessionError as e:
         sys.stderr.write("[b10-zonemgr] exit zonemgr process: %s\n" % str(e))

+ 1 - 1
src/lib/xfr/fd_share.cc

@@ -93,7 +93,7 @@ recv_fd(const int sock) {
 
     if (recvmsg(sock, &msghdr, 0) < 0) {
         free(msghdr.msg_control);
-        return (-1);
+        return (XFR_FD_RECEIVE_FAIL);
     }
     const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
     int fd = -1;

+ 7 - 3
src/lib/xfr/fd_share.h

@@ -20,12 +20,16 @@
 namespace isc {
 namespace xfr {
 
+/// Failed to receive xfr socket descriptor "fd" on unix domain socket 'sock'
+const int XFR_FD_RECEIVE_FAIL = -2;
+
 // Receive socket descriptor on unix domain socket 'sock'.
 // Returned value is the socket descriptor received.
+// Returned XFR_FD_RECEIVE_FAIL if failed to receive xfr socket descriptor
 // Errors are indicated by a return value of -1.
 int recv_fd(const int sock);
 
-// Send socket descriptor "fd" to server over unix domain socket 'sock', 
+// Send socket descriptor "fd" to server over unix domain socket 'sock',
 // the connection from socket 'sock' to unix domain server should be established first.
 // Errors are indicated by a return value of -1.
 int send_fd(const int sock, const int fd);
@@ -35,6 +39,6 @@ int send_fd(const int sock, const int fd);
 
 #endif
 
-// Local Variables: 
+// Local Variables:
 // mode: c++
-// End: 
+// End:

+ 16 - 3
src/lib/xfr/fdshare_python.cc

@@ -22,8 +22,9 @@
 
 #include <xfr/fd_share.h>
 
+
 static PyObject*
-fdshare_recv_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
+fdshare_recv_fd(PyObject* self UNUSED_PARAM, PyObject* args) {
     int sock, fd;
     if (!PyArg_ParseTuple(args, "i", &sock)) {
         return (NULL);
@@ -33,7 +34,7 @@ fdshare_recv_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
 }
 
 static PyObject*
-fdshare_send_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
+fdshare_send_fd(PyObject* self UNUSED_PARAM, PyObject* args) {
     int sock, fd, result;
     if (!PyArg_ParseTuple(args, "ii", &sock, &fd)) {
         return (NULL);
@@ -63,11 +64,23 @@ static PyModuleDef bind10_fdshare_python = {
 
 PyMODINIT_FUNC
 PyInit_libxfr_python(void) {
-    PyObject *mod = PyModule_Create(&bind10_fdshare_python);
+    PyObject* mod = PyModule_Create(&bind10_fdshare_python);
     if (mod == NULL) {
         return (NULL);
     }
 
+    PyObject* XFR_FD_RECEIVE_FAIL = Py_BuildValue("i", isc::xfr::XFR_FD_RECEIVE_FAIL);
+    if (XFR_FD_RECEIVE_FAIL == NULL) {
+        Py_XDECREF(mod);
+        return (NULL);
+    }
+    int ret = PyModule_AddObject(mod, "XFR_FD_RECEIVE_FAIL", XFR_FD_RECEIVE_FAIL);
+    if (-1 == ret) {
+        Py_XDECREF(XFR_FD_RECEIVE_FAIL);
+        Py_XDECREF(mod);
+        return (NULL);
+    }
+
     return (mod);
 }
 

+ 1 - 7
src/lib/xfr/xfrout_client.cc

@@ -69,7 +69,7 @@ XfroutClient::disconnect() {
     }
 }
 
-int 
+int
 XfroutClient::sendXfroutRequestInfo(const int tcp_sock,
                                     const void* const msg_data,
                                     const uint16_t msg_len)
@@ -93,12 +93,6 @@ XfroutClient::sendXfroutRequestInfo(const int tcp_sock,
         isc_throw(XfroutError,
                   "failed to send XFR request data to xfrout module");
     }
-    
-    int databuf = 0;
-    if (recv(impl_->socket_.native(), &databuf, sizeof(int), 0) != 0) {
-        isc_throw(XfroutError,
-                  "xfr query hasn't been processed properly by xfrout module");
-    }
 
     return (0);
 }