Parcourir la source

[640] add test and modify XfroutServer.shutdown() a bit

Jelte Jansen il y a 13 ans
Parent
commit
5da80cb570
2 fichiers modifiés avec 31 ajouts et 6 suppressions
  1. 25 0
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 6 6
      src/bin/xfrout/xfrout.py.in

+ 25 - 0
src/bin/xfrout/tests/xfrout_test.py.in

@@ -19,6 +19,7 @@
 import unittest
 import os
 from isc.testutils.tsigctx_mock import MockTSIGContext
+from isc.testutils.ccsession_mock import MockModuleCCSession
 from isc.cc.session import *
 import isc.config
 from isc.dns import *
@@ -1423,6 +1424,30 @@ class TestInitialization(unittest.TestCase):
         xfrout.init_paths()
         self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")
 
+class MyNotifier():
+    def __init__(self):
+        self.shutdown_called = False
+
+    def shutdown(self):
+        self.shutdown_called = True
+
+class MyXfroutServer(XfroutServer):
+    def __init__(self):
+        self._cc = MockModuleCCSession()
+        self._shutdown_event = threading.Event()
+        self._notifier = MyNotifier()
+        self._unix_socket_server = None
+
+class TestXfroutServer(unittest.TestCase):
+    def setUp(self):
+        self.xfrout_server = MyXfroutServer()
+
+    def test_shutdown(self):
+        self.xfrout_server.shutdown()
+        self.assertTrue(self.xfrout_server._notifier.shutdown_called)
+        self.assertTrue(self.xfrout_server._cc.stopped)
+
+
 if __name__== "__main__":
     isc.log.resetUnitTestRootLogger()
     unittest.main()

+ 6 - 6
src/bin/xfrout/xfrout.py.in

@@ -975,12 +975,12 @@ class XfroutServer:
         if self._unix_socket_server:
             self._unix_socket_server.shutdown()
 
-        # Wait for all threads to terminate
-        main_thread = threading.currentThread()
-        for th in threading.enumerate():
-            if th is main_thread:
-                continue
-            th.join()
+            # Wait for all threads to terminate
+            main_thread = threading.currentThread()
+            for th in threading.enumerate():
+                if th is main_thread:
+                    continue
+                th.join()
 
     def command_handler(self, cmd, args):
         if cmd == "shutdown":