Browse Source

merge trac299: Xfrout and Auth will communicate by long tcp connection,
Auth needs to make a new connection only on the first time or if an error occurred.


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@3482 e5f2f494-b856-4b98-b285-d166d9295462

Jerry 14 years ago
parent
commit
f7c0462610

+ 7 - 1
ChangeLog

@@ -1,3 +1,9 @@
+  116.	[bug]		jerry
+	src/bin/xfrout: Xfrout and Auth will communicate by long tcp
+	connection, Auth needs to make a new connection only on the first
+	time or if an error occurred.
+	(Trac #299, svn r3482)
+
   115.	[func]*		jinmei
   115.	[func]*		jinmei
 	src/lib/dns: Changed DNS message flags and section names from
 	src/lib/dns: Changed DNS message flags and section names from
 	separate classes to simpler enums, considering the balance between
 	separate classes to simpler enums, considering the balance between
@@ -10,7 +16,7 @@
 	(Trac #365, svn r3383)
 	(Trac #365, svn r3383)
 
 
   113.	[func]*		zhanglikun
   113.	[func]*		zhanglikun
-	Folder name 'utils'(the folder in /src/lib/python/isc/) has been 
+	Folder name 'utils'(the folder in /src/lib/python/isc/) has been
 	renamed	to 'util'. Programs that used 'import isc.utils.process'
 	renamed	to 'util'. Programs that used 'import isc.utils.process'
 	now need to use 'import isc.util.process'. The folder
 	now need to use 'import isc.util.process'. The folder
 	/src/lib/python/isc/Util is removed since it isn't used by any
 	/src/lib/python/isc/Util is removed since it isn't used by any

+ 12 - 13
src/bin/auth/auth_srv.cc

@@ -77,7 +77,7 @@ public:
                             MessageRenderer& response_renderer);
                             MessageRenderer& response_renderer);
     bool processAxfrQuery(const IOMessage& io_message, Message& message,
     bool processAxfrQuery(const IOMessage& io_message, Message& message,
                             MessageRenderer& response_renderer);
                             MessageRenderer& response_renderer);
-    bool processNotify(const IOMessage& io_message, Message& message, 
+    bool processNotify(const IOMessage& io_message, Message& message,
                             MessageRenderer& response_renderer);
                             MessageRenderer& response_renderer);
     std::string db_file_;
     std::string db_file_;
     ModuleCCSession* config_session_;
     ModuleCCSession* config_session_;
@@ -307,7 +307,7 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
     ConstEDNSPtr remote_edns = message.getEDNS();
     ConstEDNSPtr remote_edns = message.getEDNS();
     const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
     const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
     const uint16_t remote_bufsize = remote_edns ? remote_edns->getUDPSize() :
     const uint16_t remote_bufsize = remote_edns ? remote_edns->getUDPSize() :
-        Message::DEFAULT_MAX_UDPSIZE; 
+        Message::DEFAULT_MAX_UDPSIZE;
 
 
     message.makeResponse();
     message.makeResponse();
     message.setHeaderFlag(Message::HEADERFLAG_AA);
     message.setHeaderFlag(Message::HEADERFLAG_AA);
@@ -360,8 +360,10 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
     }
     }
 
 
     try {
     try {
-        xfrout_client_.connect();
-        xfrout_connected_ = true;
+        if (!xfrout_connected_) {
+            xfrout_client_.connect();
+            xfrout_connected_ = true;
+        }
         xfrout_client_.sendXfroutRequestInfo(
         xfrout_client_.sendXfroutRequestInfo(
             io_message.getSocket().getNative(),
             io_message.getSocket().getNative(),
             io_message.getData(),
             io_message.getData(),
@@ -375,7 +377,7 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
             xfrout_client_.disconnect();
             xfrout_client_.disconnect();
             xfrout_connected_ = false;
             xfrout_connected_ = false;
         }
         }
-        
+
         if (verbose_mode_) {
         if (verbose_mode_) {
             cerr << "[b10-auth] Error in handling XFR request: " << err.what()
             cerr << "[b10-auth] Error in handling XFR request: " << err.what()
                  << endl;
                  << endl;
@@ -385,15 +387,12 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
         return (true);
         return (true);
     }
     }
 
 
-    xfrout_client_.disconnect();
-    xfrout_connected_ = false;
-
     return (false);
     return (false);
 }
 }
 
 
 bool
 bool
-AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message, 
-                           MessageRenderer& response_renderer) 
+AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
+                           MessageRenderer& response_renderer)
 {
 {
     // The incoming notify must contain exactly one question for SOA of the
     // The incoming notify must contain exactly one question for SOA of the
     // zone name.
     // zone name.
@@ -435,7 +434,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
         }
         }
         return (false);
         return (false);
     }
     }
-    
+
     const string remote_ip_address =
     const string remote_ip_address =
         io_message.getRemoteEndpoint().getAddress().toText();
         io_message.getRemoteEndpoint().getAddress().toText();
     static const string command_template_start =
     static const string command_template_start =
@@ -446,7 +445,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
 
 
     try {
     try {
         ConstElementPtr notify_command = Element::fromJSON(
         ConstElementPtr notify_command = Element::fromJSON(
-                command_template_start + question->getName().toText() + 
+                command_template_start + question->getName().toText() +
                 command_template_master + remote_ip_address +
                 command_template_master + remote_ip_address +
                 command_template_rrclass + question->getClass().toText() +
                 command_template_rrclass + question->getClass().toText() +
                 command_template_end);
                 command_template_end);
@@ -460,7 +459,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
         if (rcode != 0) {
         if (rcode != 0) {
             if (verbose_mode_) {
             if (verbose_mode_) {
                 cerr << "[b10-auth] failed to notify Zonemgr: "
                 cerr << "[b10-auth] failed to notify Zonemgr: "
-                     << parsed_answer->str() << endl; 
+                     << parsed_answer->str() << endl;
             }
             }
             return (false);
             return (false);
         }
         }

+ 2 - 4
src/bin/auth/tests/auth_srv_unittest.cc

@@ -489,7 +489,7 @@ TEST_F(AuthSrvTest, AXFRSuccess) {
     // so we shouldn't have to respond.
     // so we shouldn't have to respond.
     EXPECT_FALSE(server.processMessage(*io_message, parse_message,
     EXPECT_FALSE(server.processMessage(*io_message, parse_message,
                                        response_renderer));
                                        response_renderer));
-    EXPECT_FALSE(xfrout.isConnected());
+    EXPECT_TRUE(xfrout.isConnected());
 }
 }
 
 
 TEST_F(AuthSrvTest, AXFRConnectFail) {
 TEST_F(AuthSrvTest, AXFRConnectFail) {
@@ -501,8 +501,6 @@ TEST_F(AuthSrvTest, AXFRConnectFail) {
                                       response_renderer));
                                       response_renderer));
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
-    // For a shot term workaround with xfrout we currently close the connection
-    // for each AXFR attempt
     EXPECT_FALSE(xfrout.isConnected());
     EXPECT_FALSE(xfrout.isConnected());
 }
 }
 
 
@@ -512,7 +510,7 @@ TEST_F(AuthSrvTest, AXFRSendFail) {
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
                         RRType::AXFR(), IPPROTO_TCP);
     server.processMessage(*io_message, parse_message, response_renderer);
     server.processMessage(*io_message, parse_message, response_renderer);
-    EXPECT_FALSE(xfrout.isConnected()); // see above
+    EXPECT_TRUE(xfrout.isConnected());
 
 
     xfrout.disableSend();
     xfrout.disableSend();
     parse_message.clear(Message::PARSE);
     parse_message.clear(Message::PARSE);

+ 36 - 20
src/bin/xfrout/tests/xfrout_test.py

@@ -47,22 +47,29 @@ class MySocket():
         result = self.sendqueue[:size]
         result = self.sendqueue[:size]
         self.sendqueue = self.sendqueue[size:]
         self.sendqueue = self.sendqueue[size:]
         return result
         return result
-    
+
     def read_msg(self):
     def read_msg(self):
         sent_data = self.readsent()
         sent_data = self.readsent()
         get_msg = Message(Message.PARSE)
         get_msg = Message(Message.PARSE)
         get_msg.from_wire(bytes(sent_data[2:]))
         get_msg.from_wire(bytes(sent_data[2:]))
         return get_msg
         return get_msg
-    
+
     def clear_send(self):
     def clear_send(self):
         del self.sendqueue[:]
         del self.sendqueue[:]
 
 
 # We subclass the Session class we're testing here, only
 # We subclass the Session class we're testing here, only
-# to override the __init__() method, which wants a socket,
+# to override the handle() and _send_data() method
 class MyXfroutSession(XfroutSession):
 class MyXfroutSession(XfroutSession):
     def handle(self):
     def handle(self):
         pass
         pass
-    
+
+    def _send_data(self, sock, data):
+        size = len(data)
+        total_count = 0
+        while total_count < size:
+            count = sock.send(data[total_count:])
+            total_count += count
+
 class Dbserver:
 class Dbserver:
     def __init__(self):
     def __init__(self):
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
@@ -80,12 +87,21 @@ class TestXfroutSession(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
         self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(request, None, None, self.log)
+        (self.write_sock, self.read_sock) = socket.socketpair()
+        self.xfrsess = MyXfroutSession(request, None, None, self.log, self.read_sock)
         self.xfrsess.server = Dbserver()
         self.xfrsess.server = Dbserver()
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
 
+    def test_receive_query_message(self):
+        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
+        msg_len = struct.pack('H', socket.htons(len(send_msg)))
+        self.write_sock.send(msg_len)
+        self.write_sock.send(send_msg)
+        recv_msg = self.xfrsess._receive_query_message(self.read_sock)
+        self.assertEqual(recv_msg, send_msg)
+
     def test_parse_query_message(self):
     def test_parse_query_message(self):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
         self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -93,7 +109,7 @@ class TestXfroutSession(unittest.TestCase):
     def test_get_query_zone_name(self):
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
-  
+
     def test_send_data(self):
     def test_send_data(self):
         self.xfrsess._send_data(self.sock, self.mdata)
         self.xfrsess._send_data(self.sock, self.mdata)
         senddata = self.sock.readsent()
         senddata = self.sock.readsent()
@@ -103,8 +119,8 @@ class TestXfroutSession(unittest.TestCase):
         msg = self.getmsg()
         msg = self.getmsg()
         self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
         self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
         get_msg = self.sock.read_msg()
         get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN") 
-     
+        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
+
     def test_clear_message(self):
     def test_clear_message(self):
         msg = self.getmsg()
         msg = self.getmsg()
         qid = msg.get_qid()
         qid = msg.get_qid()
@@ -118,7 +134,7 @@ class TestXfroutSession(unittest.TestCase):
         self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
         self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
 
 
     def test_reply_query_with_format_error(self):
     def test_reply_query_with_format_error(self):
-         
+
         msg = self.getmsg()
         msg = self.getmsg()
         self.xfrsess._reply_query_with_format_error(msg, self.sock)
         self.xfrsess._reply_query_with_format_error(msg, self.sock)
         get_msg = self.sock.read_msg()
         get_msg = self.sock.read_msg()
@@ -217,7 +233,7 @@ class TestXfroutSession(unittest.TestCase):
         sqlite3_ds.get_zone_soa = zone_soa
         sqlite3_ds.get_zone_soa = zone_soa
         self.assertEqual(self.xfrsess._zone_exist(True), True)
         self.assertEqual(self.xfrsess._zone_exist(True), True)
         self.assertEqual(self.xfrsess._zone_exist(False), False)
         self.assertEqual(self.xfrsess._zone_exist(False), False)
-    
+
     def test_check_xfrout_available(self):
     def test_check_xfrout_available(self):
         def zone_exist(zone):
         def zone_exist(zone):
             return zone
             return zone
@@ -243,7 +259,7 @@ class TestXfroutSession(unittest.TestCase):
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
         sent_data = self.sock.readsent()
         sent_data = self.sock.readsent()
         self.assertEqual(len(sent_data), 0)
         self.assertEqual(len(sent_data), 0)
-    
+
     def default(self, param):
     def default(self, param):
         return "example.com"
         return "example.com"
 
 
@@ -255,20 +271,20 @@ class TestXfroutSession(unittest.TestCase):
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         get_msg = self.sock.read_msg()
         get_msg = self.sock.read_msg()
         self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
         self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
-    
+
     def test_dns_xfrout_start_noerror(self):
     def test_dns_xfrout_start_noerror(self):
         self.xfrsess._get_query_zone_name = self.default
         self.xfrsess._get_query_zone_name = self.default
         def noerror(form):
         def noerror(form):
-            return Rcode.NOERROR() 
+            return Rcode.NOERROR()
         self.xfrsess._check_xfrout_available = noerror
         self.xfrsess._check_xfrout_available = noerror
-        
+
         def myreply(msg, sock, zonename):
         def myreply(msg, sock, zonename):
             self.sock.send(b"success")
             self.sock.send(b"success")
-        
+
         self.xfrsess._reply_xfrout_query = myreply
         self.xfrsess._reply_xfrout_query = myreply
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.assertEqual(self.sock.readsent(), b"success")
         self.assertEqual(self.sock.readsent(), b"success")
-    
+
     def test_reply_xfrout_query_noerror(self):
     def test_reply_xfrout_query_noerror(self):
         global sqlite3_ds
         global sqlite3_ds
         def get_zone_soa(zonename, file):
         def get_zone_soa(zonename, file):
@@ -292,7 +308,7 @@ class MyCCSession():
             return "initdb.file", False
             return "initdb.file", False
         else:
         else:
             return "unknown", False
             return "unknown", False
-    
+
 
 
 class MyUnixSockServer(UnixSockServer):
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
     def __init__(self):
@@ -306,7 +322,7 @@ class MyUnixSockServer(UnixSockServer):
 class TestUnixSockServer(unittest.TestCase):
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         self.unix = MyUnixSockServer()
         self.unix = MyUnixSockServer()
-     
+
     def test_updata_config_data(self):
     def test_updata_config_data(self):
         self.unix.update_config_data({'transfers_out':10 })
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)
         self.assertEqual(self.unix._max_transfers_out, 10)
@@ -324,7 +340,7 @@ class TestUnixSockServer(unittest.TestCase):
         count = self.unix._transfers_counter
         count = self.unix._transfers_counter
         self.assertEqual(self.unix.increase_transfers_counter(), False)
         self.assertEqual(self.unix.increase_transfers_counter(), False)
         self.assertEqual(count, self.unix._transfers_counter)
         self.assertEqual(count, self.unix._transfers_counter)
- 
+
     def test_decrease_transfers_counter(self):
     def test_decrease_transfers_counter(self):
         count = self.unix._transfers_counter
         count = self.unix._transfers_counter
         self.unix.decrease_transfers_counter()
         self.unix.decrease_transfers_counter()
@@ -335,7 +351,7 @@ class TestUnixSockServer(unittest.TestCase):
             os.remove(sock_file)
             os.remove(sock_file)
         except OSError:
         except OSError:
             pass
             pass
- 
+
     def test_sock_file_in_use_file_exist(self):
     def test_sock_file_in_use_file_exist(self):
         sock_file = 'temp.sock.file'
         sock_file = 'temp.sock.file'
         self._remove_file(sock_file)
         self._remove_file(sock_file)

+ 108 - 73
src/bin/xfrout/xfrout.py.in

@@ -63,6 +63,7 @@ AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
 MAX_TRANSFERS_OUT = 10
 MAX_TRANSFERS_OUT = 10
 VERBOSE_MODE = False
 VERBOSE_MODE = False
 
 
+
 XFROUT_MAX_MESSAGE_SIZE = 65535
 XFROUT_MAX_MESSAGE_SIZE = 65535
 
 
 def get_rrset_len(rrset):
 def get_rrset_len(rrset):
@@ -73,46 +74,78 @@ def get_rrset_len(rrset):
 
 
 
 
 class XfroutSession(BaseRequestHandler):
 class XfroutSession(BaseRequestHandler):
-    def __init__(self, request, client_address, server, log):
+    def __init__(self, request, client_address, server, log, sock):
         # The initializer for the superclass may call functions
         # The initializer for the superclass may call functions
         # that need _log to be set, so we set it first
         # that need _log to be set, so we set it first
         self._log = log
         self._log = log
+        self._shutdown_sock = sock
         BaseRequestHandler.__init__(self, request, client_address, server)
         BaseRequestHandler.__init__(self, request, client_address, server)
 
 
     def handle(self):
     def handle(self):
-        fd = recv_fd(self.request.fileno())
-        
-        if fd < 0:
-            # This may happen when one xfrout process try to connect to
-            # xfrout unix socket server, to check whether there is another
-            # xfrout running. 
-            self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
-            return
-
-        data_len = self.request.recv(2)
-        msg_len = struct.unpack('!H', data_len)[0]
-        msgdata = self.request.recv(msg_len)
-        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
-        try:
-            self.dns_xfrout_start(sock, msgdata)
-            #TODO, avoid catching all exceptions
-        except Exception as e:
-            self._log.log_message("error", str(e))
+        '''Handle a request until shutdown or xfrout client is closed.'''
+        # check self.server._shutdown_event to ensure the real shutdown comes.
+        # Linux could trigger a spurious readable event on the _shutdown_sock 
+        # due to a bug, so we need perform a double check. 
+        while not self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
+            try:
+                (rlist, wlist, xlist) = select.select([self._shutdown_sock, self.request], [], [])
+            except select.error as e:
+                if e.args[0] == errno.EINTR:
+                    (rlist, wlist, xlist) = ([], [], [])
+                    continue
+                else:
+                    self._log.log_message("error", "Error with select(): %s" %e)
+                    break
+            # self.server._shutdown_evnet will be set by now, if it is not a false
+            # alarm
+            if self._shutdown_sock in rlist:
+                continue
 
 
-        try:
-            sock.shutdown(socket.SHUT_RDWR)
-        except socket.error:
-            # Avoid socket error caused by shutting down 
-            # one non-connected socket.
-            pass
+            sock_fd = recv_fd(self.request.fileno())
+
+            if sock_fd < 0:
+                # This may happen when one xfrout process try to connect to
+                # xfrout unix socket server, to check whether there is another
+                # xfrout running.
+                if sock_fd == XFR_FD_RECEIVE_FAIL:
+                    self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
+                break
 
 
-        sock.close()
-        os.close(fd)
-        pass
+            # receive query msg
+            msgdata = self._receive_query_message(self.request)
+            if not msgdata:
+                break
+
+            try:
+                self.dns_xfrout_start(sock_fd, msgdata)
+                #TODO, avoid catching all exceptions
+            except Exception as e:
+                self._log.log_message("error", str(e))
+
+            os.close(sock_fd)
+
+    def _receive_query_message(self, sock):
+        ''' receive query message from sock'''
+        # receive data length
+        data_len = sock.recv(2)
+        if not data_len:
+            return None
+        msg_len = struct.unpack('!H', data_len)[0]
+        # receive data
+        recv_size = 0
+        msgdata = b''
+        while recv_size < msg_len:
+            data = sock.recv(msg_len - recv_size)
+            if not data:
+                return None
+            recv_size += len(data)
+            msgdata += data
+
+        return msgdata
 
 
     def _parse_query_message(self, mdata):
     def _parse_query_message(self, mdata):
         ''' parse query message to [socket,message]'''
         ''' parse query message to [socket,message]'''
-        #TODO, need to add parseHeader() in case the message header is invalid 
+        #TODO, need to add parseHeader() in case the message header is invalid
         try:
         try:
             msg = Message(Message.PARSE)
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
             Message.from_wire(msg, mdata)
@@ -127,37 +160,37 @@ class XfroutSession(BaseRequestHandler):
         return question.get_name().to_text()
         return question.get_name().to_text()
 
 
 
 
-    def _send_data(self, sock, data):
+    def _send_data(self, sock_fd, data):
         size = len(data)
         size = len(data)
         total_count = 0
         total_count = 0
         while total_count < size:
         while total_count < size:
-            count = sock.send(data[total_count:])
+            count = os.write(sock_fd, data[total_count:])
             total_count += count
             total_count += count
 
 
 
 
-    def _send_message(self, sock, msg):
+    def _send_message(self, sock_fd, msg):
         render = MessageRenderer()
         render = MessageRenderer()
         render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
         render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
         msg.to_wire(render)
         msg.to_wire(render)
         header_len = struct.pack('H', socket.htons(render.get_length()))
         header_len = struct.pack('H', socket.htons(render.get_length()))
-        self._send_data(sock, header_len)
-        self._send_data(sock, render.get_data())
+        self._send_data(sock_fd, header_len)
+        self._send_data(sock_fd, render.get_data())
 
 
 
 
-    def _reply_query_with_error_rcode(self, msg, sock, rcode_):
+    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
         msg.make_response()
         msg.make_response()
         msg.set_rcode(rcode_)
         msg.set_rcode(rcode_)
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
 
 
-    def _reply_query_with_format_error(self, msg, sock):
+    def _reply_query_with_format_error(self, msg, sock_fd):
         '''query message format isn't legal.'''
         '''query message format isn't legal.'''
         if not msg:
         if not msg:
-            return # query message is invalid. send nothing back. 
+            return # query message is invalid. send nothing back.
 
 
         msg.make_response()
         msg.make_response()
         msg.set_rcode(Rcode.FORMERR())
         msg.set_rcode(Rcode.FORMERR())
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
 
 
     def _zone_is_empty(self, zone):
     def _zone_is_empty(self, zone):
@@ -167,24 +200,24 @@ class XfroutSession(BaseRequestHandler):
         return True
         return True
 
 
     def _zone_exist(self, zonename):
     def _zone_exist(self, zonename):
-        # Find zone in datasource, should this works? maybe should ask 
+        # Find zone in datasource, should this works? maybe should ask
         # config manager.
         # config manager.
         soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
         soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
         if soa:
         if soa:
             return True
             return True
         return False
         return False
 
 
-    
+
     def _check_xfrout_available(self, zone_name):
     def _check_xfrout_available(self, zone_name):
         '''Check if xfr request can be responsed.
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
            TODO, Get zone's configuration from cfgmgr or some other place
-           eg. check allow_transfer setting, 
+           eg. check allow_transfer setting,
         '''
         '''
         if not self._zone_exist(zone_name):
         if not self._zone_exist(zone_name):
             return Rcode.NOTAUTH()
             return Rcode.NOTAUTH()
 
 
         if self._zone_is_empty(zone_name):
         if self._zone_is_empty(zone_name):
-            return Rcode.SERVFAIL() 
+            return Rcode.SERVFAIL()
 
 
         #TODO, check allow_transfer
         #TODO, check allow_transfer
         if not self.server.increase_transfers_counter():
         if not self.server.increase_transfers_counter():
@@ -193,35 +226,35 @@ class XfroutSession(BaseRequestHandler):
         return Rcode.NOERROR()
         return Rcode.NOERROR()
 
 
 
 
-    def dns_xfrout_start(self, sock, msg_query):
+    def dns_xfrout_start(self, sock_fd, msg_query):
         rcode_, msg = self._parse_query_message(msg_query)
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
         #TODO. create query message and parse header
         if rcode_ != Rcode.NOERROR():
         if rcode_ != Rcode.NOERROR():
-            return self._reply_query_with_format_error(msg, sock)
+            return self._reply_query_with_format_error(msg, sock_fd)
 
 
         zone_name = self._get_query_zone_name(msg)
         zone_name = self._get_query_zone_name(msg)
         rcode_ = self._check_xfrout_available(zone_name)
         rcode_ = self._check_xfrout_available(zone_name)
         if rcode_ != Rcode.NOERROR():
         if rcode_ != Rcode.NOERROR():
             self._log.log_message("info", "transfer of '%s/IN' failed: %s",
             self._log.log_message("info", "transfer of '%s/IN' failed: %s",
                                   zone_name, rcode_.to_text())
                                   zone_name, rcode_.to_text())
-            return self. _reply_query_with_error_rcode(msg, sock, rcode_)
+            return self. _reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
 
         try:
         try:
             self._log.log_message("info", "transfer of '%s/IN': AXFR started" % zone_name)
             self._log.log_message("info", "transfer of '%s/IN': AXFR started" % zone_name)
-            self._reply_xfrout_query(msg, sock, zone_name)
+            self._reply_xfrout_query(msg, sock_fd, zone_name)
             self._log.log_message("info", "transfer of '%s/IN': AXFR end" % zone_name)
             self._log.log_message("info", "transfer of '%s/IN': AXFR end" % zone_name)
         except Exception as err:
         except Exception as err:
             self._log.log_message("error", str(err))
             self._log.log_message("error", str(err))
 
 
         self.server.decrease_transfers_counter()
         self.server.decrease_transfers_counter()
-        return    
+        return
 
 
 
 
     def _clear_message(self, msg):
     def _clear_message(self, msg):
         qid = msg.get_qid()
         qid = msg.get_qid()
         opcode = msg.get_opcode()
         opcode = msg.get_opcode()
         rcode = msg.get_rcode()
         rcode = msg.get_rcode()
-        
+
         msg.clear(Message.RENDER)
         msg.clear(Message.RENDER)
         msg.set_qid(qid)
         msg.set_qid(qid)
         msg.set_opcode(opcode)
         msg.set_opcode(opcode)
@@ -231,7 +264,7 @@ class XfroutSession(BaseRequestHandler):
         return msg
         return msg
 
 
     def _create_rrset_from_db_record(self, record):
     def _create_rrset_from_db_record(self, record):
-        '''Create one rrset from one record of datasource, if the schema of record is changed, 
+        '''Create one rrset from one record of datasource, if the schema of record is changed,
         This function should be updated first.
         This function should be updated first.
         '''
         '''
         rrtype_ = RRType(record[5])
         rrtype_ = RRType(record[5])
@@ -239,8 +272,8 @@ class XfroutSession(BaseRequestHandler):
         rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
         rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
         rrset_.add_rdata(rdata_)
         rrset_.add_rdata(rdata_)
         return rrset_
         return rrset_
-         
-    def _send_message_with_last_soa(self, msg, sock, rrset_soa, message_upper_len):
+
+    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len):
         '''Add the SOA record to the end of message. If it can't be
         '''Add the SOA record to the end of message. If it can't be
         added, a new message should be created to send out the last soa .
         added, a new message should be created to send out the last soa .
         '''
         '''
@@ -249,14 +282,14 @@ class XfroutSession(BaseRequestHandler):
         if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
         if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
         else:
         else:
-            self._send_message(sock, msg)
+            self._send_message(sock_fd, msg)
             msg = self._clear_message(msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
 
 
-        self._send_message(sock, msg)
+        self._send_message(sock_fd, msg)
 
 
 
 
-    def _reply_xfrout_query(self, msg, sock, zone_name):
+    def _reply_xfrout_query(self, msg, sock_fd, zone_name):
         #TODO, there should be a better way to insert rrset.
         #TODO, there should be a better way to insert rrset.
         msg.make_response()
         msg.make_response()
         msg.set_header_flag(Message.HEADERFLAG_AA)
         msg.set_header_flag(Message.HEADERFLAG_AA)
@@ -286,12 +319,12 @@ class XfroutSession(BaseRequestHandler):
                 message_upper_len += rrset_len
                 message_upper_len += rrset_len
                 continue
                 continue
 
 
-            self._send_message(sock, msg)
+            self._send_message(sock_fd, msg)
             msg = self._clear_message(msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
             msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
             message_upper_len = rrset_len
             message_upper_len = rrset_len
 
 
-        self._send_message_with_last_soa(msg, sock, rrset_soa, message_upper_len)
+        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
 
 
 class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
 class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
     '''The unix domain socket server which accept xfr query sent from auth server.'''
@@ -304,22 +337,23 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._lock = threading.Lock()
         self._lock = threading.Lock()
         self._transfers_counter = 0
         self._transfers_counter = 0
         self._shutdown_event = shutdown_event
         self._shutdown_event = shutdown_event
+        self._write_sock, self._read_sock = socket.socketpair()
         self._log = log
         self._log = log
         self.update_config_data(config_data)
         self.update_config_data(config_data)
         self._cc = cc
         self._cc = cc
-        
+
     def finish_request(self, request, client_address):
     def finish_request(self, request, client_address):
         '''Finish one request by instantiating RequestHandlerClass.'''
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(request, client_address, self, self._log)
+        self.RequestHandlerClass(request, client_address, self, self._log, self._read_sock)
 
 
     def _remove_unused_sock_file(self, sock_file):
     def _remove_unused_sock_file(self, sock_file):
-        '''Try to remove the socket file. If the file is being used 
-        by one running xfrout process, exit from python. 
+        '''Try to remove the socket file. If the file is being used
+        by one running xfrout process, exit from python.
         If it's not a socket file or nobody is listening
         If it's not a socket file or nobody is listening
         , it will be removed. If it can't be removed, exit from python. '''
         , it will be removed. If it can't be removed, exit from python. '''
         if self._sock_file_in_use(sock_file):
         if self._sock_file_in_use(sock_file):
-            sys.stderr.write("[b10-xfrout] Fail to start xfrout process, unix socket" 
-                  " file '%s' is being used by another xfrout process\n" % sock_file)
+            self._log.log_message("error", "Fail to start xfrout process, unix socket file '%s'"
+                                 " is being used by another xfrout process\n" % sock_file)
             sys.exit(0)
             sys.exit(0)
         else:
         else:
             if not os.path.exists(sock_file):
             if not os.path.exists(sock_file):
@@ -328,12 +362,12 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             try:
             try:
                 os.unlink(sock_file)
                 os.unlink(sock_file)
             except OSError as err:
             except OSError as err:
-                sys.stderr.write('[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
+                self._log.log_message("error", '[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
                 sys.exit(0)
                 sys.exit(0)
-   
+
     def _sock_file_in_use(self, sock_file):
     def _sock_file_in_use(self, sock_file):
-        '''Check whether the socket file 'sock_file' exists and 
-        is being used by one running xfrout process. If it is, 
+        '''Check whether the socket file 'sock_file' exists and
+        is being used by one running xfrout process. If it is,
         return True, or else return False. '''
         return True, or else return False. '''
         try:
         try:
             sock = socket.socket(socket.AF_UNIX)
             sock = socket.socket(socket.AF_UNIX)
@@ -341,9 +375,10 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         except socket.error as err:
         except socket.error as err:
             return False
             return False
         else:
         else:
-            return True 
+            return True
 
 
     def shutdown(self):
     def shutdown(self):
+        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
         super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
         super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
         try:
         try:
             os.unlink(self._sock_file)
             os.unlink(self._sock_file)
@@ -390,7 +425,7 @@ class XfroutServer:
     def __init__(self):
     def __init__(self):
         self._unix_socket_server = None
         self._unix_socket_server = None
         self._log = None
         self._log = None
-        self._listen_sock_file = UNIX_SOCKET_FILE 
+        self._listen_sock_file = UNIX_SOCKET_FILE
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
         self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
         self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
         self._config_data = self._cc.get_full_config()
         self._config_data = self._cc.get_full_config()
@@ -404,12 +439,12 @@ class XfroutServer:
 
 
     def _start_xfr_query_listener(self):
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
         '''Start a new thread to accept xfr query. '''
-        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession, 
+        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession,
                                                   self._shutdown_event, self._config_data,
                                                   self._shutdown_event, self._config_data,
                                                   self._cc, self._log);
                                                   self._cc, self._log);
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener.start()
         listener.start()
-        
+
     def _start_notifier(self):
     def _start_notifier(self):
         datasrc = self._unix_socket_server.get_db_file()
         datasrc = self._unix_socket_server.get_db_file()
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
@@ -472,7 +507,7 @@ class XfroutServer:
             else:
             else:
                 answer = create_answer(1, "Bad command parameter:" + str(args))
                 answer = create_answer(1, "Bad command parameter:" + str(args))
 
 
-        else: 
+        else:
             answer = create_answer(1, "Unknown command:" + str(cmd))
             answer = create_answer(1, "Unknown command:" + str(cmd))
 
 
         return answer
         return answer
@@ -514,7 +549,7 @@ if '__main__' == __name__:
         sys.stderr.write("[b10-xfrout] Error creating xfrout, "
         sys.stderr.write("[b10-xfrout] Error creating xfrout, "
                            "is the command channel daemon running?\n")
                            "is the command channel daemon running?\n")
     except SessionTimeout as e:
     except SessionTimeout as e:
-        sys.stderr.write("[b10-xfrout] Error creating xfrout, " 
+        sys.stderr.write("[b10-xfrout] Error creating xfrout, "
                            "is the configuration manager running?\n")
                            "is the configuration manager running?\n")
     except ModuleCCSessionError as e:
     except ModuleCCSessionError as e:
         sys.stderr.write("[b10-xfrout] exit xfrout process:%s\n" % str(e))
         sys.stderr.write("[b10-xfrout] exit xfrout process:%s\n" % str(e))

+ 32 - 32
src/bin/zonemgr/zonemgr.py.in

@@ -90,26 +90,26 @@ class ZonemgrException(Exception):
 
 
 class ZonemgrRefresh:
 class ZonemgrRefresh:
     """This class will maintain and manage zone refresh info.
     """This class will maintain and manage zone refresh info.
-    It also provides methods to keep track of zone timers and 
-    do zone refresh. 
-    Zone timers can be started by calling run_timer(), and it 
+    It also provides methods to keep track of zone timers and
+    do zone refresh.
+    Zone timers can be started by calling run_timer(), and it
     can be stopped by calling shutdown() in another thread.
     can be stopped by calling shutdown() in another thread.
 
 
     """
     """
 
 
     def __init__(self, cc, db_file, slave_socket, config_data):
     def __init__(self, cc, db_file, slave_socket, config_data):
         self._cc = cc
         self._cc = cc
-        self._check_sock = slave_socket 
+        self._check_sock = slave_socket
         self._db_file = db_file
         self._db_file = db_file
         self.update_config_data(config_data)
         self.update_config_data(config_data)
-        self._zonemgr_refresh_info = {} 
+        self._zonemgr_refresh_info = {}
         self._build_zonemgr_refresh_info()
         self._build_zonemgr_refresh_info()
         self._running = False
         self._running = False
-    
+
     def _random_jitter(self, max, jitter):
     def _random_jitter(self, max, jitter):
         """Imposes some random jitters for refresh and
         """Imposes some random jitters for refresh and
         retry timers to avoid many zones need to do refresh
         retry timers to avoid many zones need to do refresh
-        at the same time. 
+        at the same time.
         The value should be between (max - jitter) and max.
         The value should be between (max - jitter) and max.
         """
         """
         if 0 == jitter:
         if 0 == jitter:
@@ -120,7 +120,7 @@ class ZonemgrRefresh:
         return time.time()
         return time.time()
 
 
     def _set_zone_timer(self, zone_name_class, max, jitter):
     def _set_zone_timer(self, zone_name_class, max, jitter):
-        """Set zone next refresh time. 
+        """Set zone next refresh time.
         jitter should not be bigger than half the original value."""
         jitter should not be bigger than half the original value."""
         self._set_zone_next_refresh_time(zone_name_class, self._get_current_time() + \
         self._set_zone_next_refresh_time(zone_name_class, self._get_current_time() + \
                                             self._random_jitter(max, jitter))
                                             self._random_jitter(max, jitter))
@@ -143,7 +143,7 @@ class ZonemgrRefresh:
 
 
     def _set_zone_notify_timer(self, zone_name_class):
     def _set_zone_notify_timer(self, zone_name_class):
         """Set zone next refresh time after receiving notify
         """Set zone next refresh time after receiving notify
-           next_refresh_time = now 
+           next_refresh_time = now
         """
         """
         self._set_zone_timer(zone_name_class, 0, 0)
         self._set_zone_timer(zone_name_class, 0, 0)
 
 
@@ -199,7 +199,7 @@ class ZonemgrRefresh:
             raise ZonemgrException("[b10-zonemgr] zone (%s, %s) doesn't have soa." % zone_name_class)
             raise ZonemgrException("[b10-zonemgr] zone (%s, %s) doesn't have soa." % zone_name_class)
         zone_info["zone_soa_rdata"] = zone_soa[7]
         zone_info["zone_soa_rdata"] = zone_soa[7]
         zone_info["zone_state"] = ZONE_OK
         zone_info["zone_state"] = ZONE_OK
-        zone_info["last_refresh_time"] = self._get_current_time() 
+        zone_info["last_refresh_time"] = self._get_current_time()
         zone_info["next_refresh_time"] = self._get_current_time() + \
         zone_info["next_refresh_time"] = self._get_current_time() + \
                                          float(zone_soa[7].split(" ")[REFRESH_OFFSET])
                                          float(zone_soa[7].split(" ")[REFRESH_OFFSET])
         self._zonemgr_refresh_info[zone_name_class] = zone_info
         self._zonemgr_refresh_info[zone_name_class] = zone_info
@@ -233,7 +233,7 @@ class ZonemgrRefresh:
 
 
     def _get_zone_notifier_master(self, zone_name_class):
     def _get_zone_notifier_master(self, zone_name_class):
         if ("notify_master" in self._zonemgr_refresh_info[zone_name_class].keys()):
         if ("notify_master" in self._zonemgr_refresh_info[zone_name_class].keys()):
-            return self._zonemgr_refresh_info[zone_name_class]["notify_master"] 
+            return self._zonemgr_refresh_info[zone_name_class]["notify_master"]
 
 
         return None
         return None
 
 
@@ -248,7 +248,7 @@ class ZonemgrRefresh:
         return self._zonemgr_refresh_info[zone_name_class]["zone_state"]
         return self._zonemgr_refresh_info[zone_name_class]["zone_state"]
 
 
     def _set_zone_state(self, zone_name_class, zone_state):
     def _set_zone_state(self, zone_name_class, zone_state):
-        self._zonemgr_refresh_info[zone_name_class]["zone_state"] = zone_state 
+        self._zonemgr_refresh_info[zone_name_class]["zone_state"] = zone_state
 
 
     def _get_zone_refresh_timeout(self, zone_name_class):
     def _get_zone_refresh_timeout(self, zone_name_class):
         return self._zonemgr_refresh_info[zone_name_class]["refresh_timeout"]
         return self._zonemgr_refresh_info[zone_name_class]["refresh_timeout"]
@@ -268,7 +268,7 @@ class ZonemgrRefresh:
         try:
         try:
             self._cc.group_sendmsg(msg, module_name)
             self._cc.group_sendmsg(msg, module_name)
         except socket.error:
         except socket.error:
-            sys.stderr.write("[b10-zonemgr] Failed to send to module %s, the session has been closed." % module_name) 
+            sys.stderr.write("[b10-zonemgr] Failed to send to module %s, the session has been closed." % module_name)
 
 
     def _find_need_do_refresh_zone(self):
     def _find_need_do_refresh_zone(self):
         """Find the first zone need do refresh, if no zone need
         """Find the first zone need do refresh, if no zone need
@@ -281,10 +281,10 @@ class ZonemgrRefresh:
             if (ZONE_REFRESHING == zone_state and
             if (ZONE_REFRESHING == zone_state and
                 (self._get_zone_refresh_timeout(zone_name_class) > self._get_current_time())):
                 (self._get_zone_refresh_timeout(zone_name_class) > self._get_current_time())):
                 continue
                 continue
-                    
-            # Get the zone with minimum next_refresh_time 
-            if ((zone_need_refresh is None) or 
-                (self._get_zone_next_refresh_time(zone_name_class) < 
+
+            # Get the zone with minimum next_refresh_time
+            if ((zone_need_refresh is None) or
+                (self._get_zone_next_refresh_time(zone_name_class) <
                  self._get_zone_next_refresh_time(zone_need_refresh))):
                  self._get_zone_next_refresh_time(zone_need_refresh))):
                 zone_need_refresh = zone_name_class
                 zone_need_refresh = zone_name_class
 
 
@@ -292,14 +292,14 @@ class ZonemgrRefresh:
             if (self._get_zone_next_refresh_time(zone_need_refresh) < self._get_current_time()):
             if (self._get_zone_next_refresh_time(zone_need_refresh) < self._get_current_time()):
                 break
                 break
 
 
-        return zone_need_refresh 
+        return zone_need_refresh
+
 
 
-    
     def _do_refresh(self, zone_name_class):
     def _do_refresh(self, zone_name_class):
         """Do zone refresh."""
         """Do zone refresh."""
         log_msg("Do refresh for zone (%s, %s)." % zone_name_class)
         log_msg("Do refresh for zone (%s, %s)." % zone_name_class)
         self._set_zone_state(zone_name_class, ZONE_REFRESHING)
         self._set_zone_state(zone_name_class, ZONE_REFRESHING)
-        self._set_zone_refresh_timeout(zone_name_class, self._get_current_time() + self._max_transfer_timeout) 
+        self._set_zone_refresh_timeout(zone_name_class, self._get_current_time() + self._max_transfer_timeout)
         notify_master = self._get_zone_notifier_master(zone_name_class)
         notify_master = self._get_zone_notifier_master(zone_name_class)
         # If the zone has notify master, send notify command to xfrin module
         # If the zone has notify master, send notify command to xfrin module
         if notify_master:
         if notify_master:
@@ -307,7 +307,7 @@ class ZonemgrRefresh:
                      "zone_class" : zone_name_class[1],
                      "zone_class" : zone_name_class[1],
                      "master" : notify_master
                      "master" : notify_master
                      }
                      }
-            self._send_command(XFRIN_MODULE_NAME, ZONE_NOTIFY_COMMAND, param) 
+            self._send_command(XFRIN_MODULE_NAME, ZONE_NOTIFY_COMMAND, param)
             self._clear_zone_notifier_master(zone_name_class)
             self._clear_zone_notifier_master(zone_name_class)
         # Send refresh command to xfrin module
         # Send refresh command to xfrin module
         else:
         else:
@@ -328,19 +328,19 @@ class ZonemgrRefresh:
         while self._running:
         while self._running:
             # If zonemgr has no zone, set timer timeout to self._lowerbound_retry.
             # If zonemgr has no zone, set timer timeout to self._lowerbound_retry.
             if self._zone_mgr_is_empty():
             if self._zone_mgr_is_empty():
-                timeout = self._lowerbound_retry 
+                timeout = self._lowerbound_retry
             else:
             else:
                 zone_need_refresh = self._find_need_do_refresh_zone()
                 zone_need_refresh = self._find_need_do_refresh_zone()
-                # If don't get zone with minimum next refresh time, set timer timeout to self._lowerbound_retry 
+                # If don't get zone with minimum next refresh time, set timer timeout to self._lowerbound_retry.
                 if not zone_need_refresh:
                 if not zone_need_refresh:
-                    timeout = self._lowerbound_retry 
+                    timeout = self._lowerbound_retry
                 else:
                 else:
                     timeout = self._get_zone_next_refresh_time(zone_need_refresh) - self._get_current_time()
                     timeout = self._get_zone_next_refresh_time(zone_need_refresh) - self._get_current_time()
                     if (timeout < 0):
                     if (timeout < 0):
                         self._do_refresh(zone_need_refresh)
                         self._do_refresh(zone_need_refresh)
                         continue
                         continue
 
 
-            """ Wait for the socket notification for a maximum time of timeout 
+            """ Wait for the socket notification for a maximum time of timeout
             in seconds (as float)."""
             in seconds (as float)."""
             try:
             try:
                 rlist, wlist, xlist = select.select([self._check_sock, self._read_sock], [], [], timeout)
                 rlist, wlist, xlist = select.select([self._check_sock, self._read_sock], [], [], timeout)
@@ -352,7 +352,7 @@ class ZonemgrRefresh:
                     break
                     break
 
 
             for fd in rlist:
             for fd in rlist:
-                if fd == self._read_sock: # awaken by shutdown socket 
+                if fd == self._read_sock: # awaken by shutdown socket
                     # self._running will be False by now, if it is not a false
                     # self._running will be False by now, if it is not a false
                     # alarm
                     # alarm
                     continue
                     continue
@@ -416,7 +416,7 @@ class Zonemgr:
         self._zone_refresh = None
         self._zone_refresh = None
         self._setup_session()
         self._setup_session()
         self._db_file = self.get_db_file()
         self._db_file = self.get_db_file()
-        # Create socket pair for communicating between main thread and zonemgr timer thread 
+        # Create socket pair for communicating between main thread and zonemgr timer thread
         self._master_socket, self._slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         self._master_socket, self._slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         self._zone_refresh = ZonemgrRefresh(self._cc, self._db_file, self._slave_socket, self._config_data)
         self._zone_refresh = ZonemgrRefresh(self._cc, self._db_file, self._slave_socket, self._config_data)
         self._zone_refresh.run_timer()
         self._zone_refresh.run_timer()
@@ -426,7 +426,7 @@ class Zonemgr:
         self.running = False
         self.running = False
 
 
     def _setup_session(self):
     def _setup_session(self):
-        """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving 
+        """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving
         commands and config data sent from other modules, another one (self._cc)
         commands and config data sent from other modules, another one (self._cc)
         is used to send commands to proper modules."""
         is used to send commands to proper modules."""
         self._cc = isc.cc.Session()
         self._cc = isc.cc.Session()
@@ -450,7 +450,7 @@ class Zonemgr:
     def shutdown(self):
     def shutdown(self):
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         timers should be terminated.
         timers should be terminated.
-        """ 
+        """
         self._zone_refresh.shutdown()
         self._zone_refresh.shutdown()
 
 
         self._slave_socket.close()
         self._slave_socket.close()
@@ -503,7 +503,7 @@ class Zonemgr:
 
 
     def command_handler(self, command, args):
     def command_handler(self, command, args):
         """Handle command receivd from command channel.
         """Handle command receivd from command channel.
-        ZONE_NOTIFY_COMMAND is issued by Auth process; ZONE_XFRIN_SUCCESS_COMMAND 
+        ZONE_NOTIFY_COMMAND is issued by Auth process; ZONE_XFRIN_SUCCESS_COMMAND
         and ZONE_XFRIN_FAILED_COMMAND are issued by Xfrin process; shutdown is issued
         and ZONE_XFRIN_FAILED_COMMAND are issued by Xfrin process; shutdown is issued
         by a user or Boss process. """
         by a user or Boss process. """
         answer = create_answer(0)
         answer = create_answer(0)
@@ -572,10 +572,10 @@ if '__main__' == __name__:
     except KeyboardInterrupt:
     except KeyboardInterrupt:
         sys.stderr.write("[b10-zonemgr] exit zonemgr process\n")
         sys.stderr.write("[b10-zonemgr] exit zonemgr process\n")
     except isc.cc.session.SessionError as e:
     except isc.cc.session.SessionError as e:
-        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, " 
+        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, "
                            "is the command channel daemon running?\n")
                            "is the command channel daemon running?\n")
     except isc.cc.session.SessionTimeout as e:
     except isc.cc.session.SessionTimeout as e:
-        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, " 
+        sys.stderr.write("[b10-zonemgr] Error creating zonemgr, "
                            "is the configuration manager running?\n")
                            "is the configuration manager running?\n")
     except isc.config.ModuleCCSessionError as e:
     except isc.config.ModuleCCSessionError as e:
         sys.stderr.write("[b10-zonemgr] exit zonemgr process: %s\n" % str(e))
         sys.stderr.write("[b10-zonemgr] exit zonemgr process: %s\n" % str(e))

+ 1 - 1
src/lib/xfr/fd_share.cc

@@ -93,7 +93,7 @@ recv_fd(const int sock) {
 
 
     if (recvmsg(sock, &msghdr, 0) < 0) {
     if (recvmsg(sock, &msghdr, 0) < 0) {
         free(msghdr.msg_control);
         free(msghdr.msg_control);
-        return (-1);
+        return (XFR_FD_RECEIVE_FAIL);
     }
     }
     const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
     const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
     int fd = -1;
     int fd = -1;

+ 7 - 3
src/lib/xfr/fd_share.h

@@ -20,12 +20,16 @@
 namespace isc {
 namespace isc {
 namespace xfr {
 namespace xfr {
 
 
+/// Failed to receive xfr socket descriptor "fd" on unix domain socket 'sock'
+const int XFR_FD_RECEIVE_FAIL = -2;
+
 // Receive socket descriptor on unix domain socket 'sock'.
 // Receive socket descriptor on unix domain socket 'sock'.
 // Returned value is the socket descriptor received.
 // Returned value is the socket descriptor received.
+// Returned XFR_FD_RECEIVE_FAIL if failed to receive xfr socket descriptor
 // Errors are indicated by a return value of -1.
 // Errors are indicated by a return value of -1.
 int recv_fd(const int sock);
 int recv_fd(const int sock);
 
 
-// Send socket descriptor "fd" to server over unix domain socket 'sock', 
+// Send socket descriptor "fd" to server over unix domain socket 'sock',
 // the connection from socket 'sock' to unix domain server should be established first.
 // the connection from socket 'sock' to unix domain server should be established first.
 // Errors are indicated by a return value of -1.
 // Errors are indicated by a return value of -1.
 int send_fd(const int sock, const int fd);
 int send_fd(const int sock, const int fd);
@@ -35,6 +39,6 @@ int send_fd(const int sock, const int fd);
 
 
 #endif
 #endif
 
 
-// Local Variables: 
+// Local Variables:
 // mode: c++
 // mode: c++
-// End: 
+// End:

+ 16 - 3
src/lib/xfr/fdshare_python.cc

@@ -22,8 +22,9 @@
 
 
 #include <xfr/fd_share.h>
 #include <xfr/fd_share.h>
 
 
+
 static PyObject*
 static PyObject*
-fdshare_recv_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
+fdshare_recv_fd(PyObject* self UNUSED_PARAM, PyObject* args) {
     int sock, fd;
     int sock, fd;
     if (!PyArg_ParseTuple(args, "i", &sock)) {
     if (!PyArg_ParseTuple(args, "i", &sock)) {
         return (NULL);
         return (NULL);
@@ -33,7 +34,7 @@ fdshare_recv_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
 }
 }
 
 
 static PyObject*
 static PyObject*
-fdshare_send_fd(PyObject *self UNUSED_PARAM, PyObject *args) {
+fdshare_send_fd(PyObject* self UNUSED_PARAM, PyObject* args) {
     int sock, fd, result;
     int sock, fd, result;
     if (!PyArg_ParseTuple(args, "ii", &sock, &fd)) {
     if (!PyArg_ParseTuple(args, "ii", &sock, &fd)) {
         return (NULL);
         return (NULL);
@@ -63,11 +64,23 @@ static PyModuleDef bind10_fdshare_python = {
 
 
 PyMODINIT_FUNC
 PyMODINIT_FUNC
 PyInit_libxfr_python(void) {
 PyInit_libxfr_python(void) {
-    PyObject *mod = PyModule_Create(&bind10_fdshare_python);
+    PyObject* mod = PyModule_Create(&bind10_fdshare_python);
     if (mod == NULL) {
     if (mod == NULL) {
         return (NULL);
         return (NULL);
     }
     }
 
 
+    PyObject* XFR_FD_RECEIVE_FAIL = Py_BuildValue("i", isc::xfr::XFR_FD_RECEIVE_FAIL);
+    if (XFR_FD_RECEIVE_FAIL == NULL) {
+        Py_XDECREF(mod);
+        return (NULL);
+    }
+    int ret = PyModule_AddObject(mod, "XFR_FD_RECEIVE_FAIL", XFR_FD_RECEIVE_FAIL);
+    if (-1 == ret) {
+        Py_XDECREF(XFR_FD_RECEIVE_FAIL);
+        Py_XDECREF(mod);
+        return (NULL);
+    }
+
     return (mod);
     return (mod);
 }
 }
 
 

+ 1 - 7
src/lib/xfr/xfrout_client.cc

@@ -69,7 +69,7 @@ XfroutClient::disconnect() {
     }
     }
 }
 }
 
 
-int 
+int
 XfroutClient::sendXfroutRequestInfo(const int tcp_sock,
 XfroutClient::sendXfroutRequestInfo(const int tcp_sock,
                                     const void* const msg_data,
                                     const void* const msg_data,
                                     const uint16_t msg_len)
                                     const uint16_t msg_len)
@@ -93,12 +93,6 @@ XfroutClient::sendXfroutRequestInfo(const int tcp_sock,
         isc_throw(XfroutError,
         isc_throw(XfroutError,
                   "failed to send XFR request data to xfrout module");
                   "failed to send XFR request data to xfrout module");
     }
     }
-    
-    int databuf = 0;
-    if (recv(impl_->socket_.native(), &databuf, sizeof(int), 0) != 0) {
-        isc_throw(XfroutError,
-                  "xfr query hasn't been processed properly by xfrout module");
-    }
 
 
     return (0);
     return (0);
 }
 }