Browse Source

[1790] Add unit tests using mock objects

Mukund Sivaraman 13 years ago
parent
commit
95b852252c
2 changed files with 157 additions and 41 deletions
  1. 116 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 41 41
      src/bin/xfrin/xfrin.py.in

+ 116 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -2736,6 +2736,44 @@ class TestMain(unittest.TestCase):
         MockXfrin.check_command_hook = raise_exception
         main(MockXfrin, False)
 
+class TestXfrinProcessMockCC:
+    def __init__(self, config = []):
+        self.get_called = False
+        self.get_called_correctly = False
+        self.config = config
+
+    def get_remote_config_value(self, module, identifier):
+        self.get_called = True
+        if module == 'Auth' and identifier == 'datasources':
+            self.get_called_correctly = True
+            return (self.config, False)
+        else:
+            return (None, True)
+
+class TestXfrinProcessMockCCSession:
+    def __init__(self):
+        self.send_called = False
+        self.send_called_correctly = False
+        self.recv_called = False
+        self.recv_called_correctly = False
+
+    def group_sendmsg(self, msg, module):
+        self.send_called = True
+        if module == 'Auth' and msg['command'][0] == 'loadzone':
+            self.send_called_correctly = True
+            seq = "random-e068c2de26d760f20cf10afc4b87ef0f"
+        else:
+            seq = None
+
+        return seq
+
+    def group_recvmsg(self, message, seq):
+        self.recv_called = True
+        if message == False and seq == "random-e068c2de26d760f20cf10afc4b87ef0f":
+            self.recv_called_correctly = True
+        # return values are ignored
+        return (None, None)
+
 class TestXfrinProcess(unittest.TestCase):
     """
     Some tests for the xfrin_process function. This replaces the
@@ -2751,6 +2789,8 @@ class TestXfrinProcess(unittest.TestCase):
 
         Also sets up several internal variables to watch what happens.
         """
+        self._module_cc = TestXfrinProcessMockCC()
+        self._send_cc_session = TestXfrinProcessMockCCSession()
         # This will hold a "log" of what transfers were attempted.
         self.__transfers = []
         # This will "log" if failures or successes happened.
@@ -2795,6 +2835,9 @@ class TestXfrinProcess(unittest.TestCase):
         Part of pretending to be the server as well. This just logs the
         success/failure of the previous operation.
         """
+        if ret == XFRIN_OK:
+            xfrin._do_auth_loadzone(self, zone_name, rrclass)
+
         self.__published.append(ret)
 
     def close(self):
@@ -2825,12 +2868,19 @@ class TestXfrinProcess(unittest.TestCase):
         # Create a connection for each attempt
         self.assertEqual(len(transfers), self.__created_connections)
         self.assertEqual([published], self.__published)
+        if published == XFRIN_OK:
+            self.assertEqual(True, self._module_cc.get_called)
+            self.assertEqual(True, self._module_cc.get_called_correctly)
 
     def test_ixfr_ok(self):
         """
         Everything OK the first time, over IXFR.
         """
         self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(False, self._send_cc_session.send_called)
+        self.assertEqual(False, self._send_cc_session.send_called_correctly)
+        self.assertEqual(False, self._send_cc_session.recv_called)
+        self.assertEqual(False, self._send_cc_session.recv_called_correctly)
 
     def test_axfr_ok(self):
         """
@@ -2861,6 +2911,72 @@ class TestXfrinProcess(unittest.TestCase):
         """
         self.__do_test([XFRIN_FAIL, XFRIN_FAIL],
                        [RRType.IXFR(), RRType.AXFR()], RRType.IXFR())
+
+    def test_inmem_ok(self):
+        """
+        Inmem configuration 1.
+        """
+        self._module_cc.config = [{'zones': [{'origin': 'example.org', 'filetype': 'sqlite3',
+                                              'file': 'data/inmem-xfrin.sqlite3'}],
+                                   'type': 'memory', 'class': 'IN'}]
+        self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(True, self._send_cc_session.send_called)
+        self.assertEqual(True, self._send_cc_session.send_called_correctly)
+        self.assertEqual(True, self._send_cc_session.recv_called)
+        self.assertEqual(True, self._send_cc_session.recv_called_correctly)
+
+    def test_inmem_not_memory(self):
+        """
+        Inmem configuration 2.
+        """
+        self._module_cc.config = [{'zones': [{'origin': 'example.org', 'filetype': 'sqlite3',
+                                              'file': 'data/inmem-xfrin.sqlite3'}],
+                                   'type': 'punched-card', 'class': 'IN'}]
+        self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(False, self._send_cc_session.send_called)
+        self.assertEqual(False, self._send_cc_session.send_called_correctly)
+        self.assertEqual(False, self._send_cc_session.recv_called)
+        self.assertEqual(False, self._send_cc_session.recv_called_correctly)
+
+    def test_inmem_not_sqlite3(self):
+        """
+        Inmem configuration 3.
+        """
+        self._module_cc.config = [{'zones': [{'origin': 'example.org', 'filetype': 'postgresql',
+                                              'file': 'data/inmem-xfrin.sqlite3'}],
+                                   'type': 'memory', 'class': 'IN'}]
+        self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(False, self._send_cc_session.send_called)
+        self.assertEqual(False, self._send_cc_session.send_called_correctly)
+        self.assertEqual(False, self._send_cc_session.recv_called)
+        self.assertEqual(False, self._send_cc_session.recv_called_correctly)
+
+    def test_inmem_not_of_same_class(self):
+        """
+        Inmem configuration 4.
+        """
+        self._module_cc.config = [{'zones': [{'origin': 'example.org', 'filetype': 'sqlite3',
+                                              'file': 'data/inmem-xfrin.sqlite3'}],
+                                   'type': 'memory', 'class': 'XX'}]
+        self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(False, self._send_cc_session.send_called)
+        self.assertEqual(False, self._send_cc_session.send_called_correctly)
+        self.assertEqual(False, self._send_cc_session.recv_called)
+        self.assertEqual(False, self._send_cc_session.recv_called_correctly)
+
+    def test_inmem_not_present(self):
+        """
+        Inmem configuration 5.
+        """
+        self._module_cc.config = [{'zones': [{'origin': 'isc.org', 'filetype': 'sqlite3',
+                                              'file': 'data/inmem-xfrin.sqlite3'}],
+                                   'type': 'memory', 'class': 'IN'}]
+        self.__do_test([XFRIN_OK], [RRType.IXFR()], RRType.IXFR())
+        self.assertEqual(False, self._send_cc_session.send_called)
+        self.assertEqual(False, self._send_cc_session.send_called_correctly)
+        self.assertEqual(False, self._send_cc_session.recv_called)
+        self.assertEqual(False, self._send_cc_session.recv_called_correctly)
+
 class TestFormatting(unittest.TestCase):
     # If the formatting functions are moved to a more general library
     # (ticket #1379), these tests should be moved with them.

+ 41 - 41
src/bin/xfrin/xfrin.py.in

@@ -1247,6 +1247,46 @@ class ZoneInfo:
         return (self.master_addr.family, socket.SOCK_STREAM,
                 (str(self.master_addr), self.master_port))
 
+def _do_auth_loadzone(server, zone_name, zone_class):
+    # On a successful zone transfer, if the zone is served by
+    # b10-auth in the in-memory data source using sqlite3 as a
+    # backend, send the "loadzone" command for the zone to auth.
+    datasources, is_default =\
+        server._module_cc.get_remote_config_value(AUTH_MODULE_NAME, "datasources")
+    if is_default:
+        return
+    for d in datasources:
+        try:
+            if "class" in d:
+                dclass = RRClass(d["class"])
+            else:
+                dclass = RRClass("IN")
+        except InvalidRRClass as err:
+            logger.info(XFRIN_AUTH_CONFIG_RRCLASS_ERROR, str(err))
+            continue
+
+        if d["type"].lower() == "memory" and dclass == zone_class:
+            for zone in d["zones"]:
+                if "filetype" not in zone:
+                    continue
+                try:
+                    name = Name(zone["origin"])
+                except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape, TooLongName, IncompleteName):
+                    logger.info(XFRIN_AUTH_CONFIG_NAME_PARSER_ERROR, str(err))
+                    continue
+
+                if zone["filetype"].lower() == "sqlite3" and name == zone_name:
+                    param = {"origin": zone_name.to_text(),
+                             "class": zone_class.to_text(),
+                             "datasrc": d["type"]}
+
+                    logger.debug(DBG_XFRIN_TRACE, XFRIN_AUTH_LOADZONE,
+                                 param["origin"], param["class"], param["datasrc"])
+
+                    msg = create_command("loadzone", param)
+                    seq = server._send_cc_session.group_sendmsg(msg, AUTH_MODULE_NAME)
+                    answer, env = server._send_cc_session.group_recvmsg(False, seq)
+
 class Xfrin:
     def __init__(self):
         self._max_transfers_in = 10
@@ -1540,46 +1580,6 @@ class Xfrin:
                       "bind10_zones.sqlite3"
         self._db_file = db_file
 
-    def _do_auth_loadzone(self, zone_name, zone_class):
-        # On a successful zone transfer, if the zone is served by
-        # b10-auth in the in-memory data source using sqlite3 as a
-        # backend, send the "loadzone" command for the zone to auth.
-        datasources, is_default =\
-            self._module_cc.get_remote_config_value(AUTH_MODULE_NAME, "datasources")
-        if is_default:
-            return
-        for d in datasources:
-            try:
-                if "class" in d:
-                    dclass = RRClass(d["class"])
-                else:
-                    dclass = RRClass("IN")
-            except InvalidRRClass as err:
-                logger.info(XFRIN_AUTH_CONFIG_RRCLASS_ERROR, str(err))
-                continue
-
-            if d["type"].lower() == "memory" and dclass == zone_class:
-                for zone in d["zones"]:
-                    if "filetype" not in zone:
-                        continue
-                    try:
-                        name = Name(zone["origin"])
-                    except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape, TooLongName, IncompleteName):
-                        logger.info(XFRIN_AUTH_CONFIG_NAME_PARSER_ERROR, str(err))
-                        continue
-
-                    if zone["filetype"].lower() == "sqlite3" and name == zone_name:
-                        param = {"origin": zone_name.to_text(),
-                                 "class": zone_class.to_text(),
-                                 "datasrc": d["type"]}
-
-                        logger.debug(DBG_XFRIN_TRACE, XFRIN_AUTH_LOADZONE,
-                                     param["origin"], param["class"], param["datasrc"])
-
-                        msg = create_command("loadzone", param)
-                        seq = self._send_cc_session.group_sendmsg(msg, AUTH_MODULE_NAME)
-                        answer, env = self._send_cc_session.group_recvmsg(False, seq)
-
     def publish_xfrin_news(self, zone_name, zone_class, xfr_result):
         '''Send command to xfrout/zone manager module.
         If xfrin has finished successfully for one zone, tell the good
@@ -1589,7 +1589,7 @@ class Xfrin:
         param = {'zone_name': zone_name.to_text(),
                  'zone_class': zone_class.to_text()}
         if xfr_result == XFRIN_OK:
-            self._do_auth_loadzone(zone_name, zone_class)
+            _do_auth_loadzone(self, zone_name, zone_class)
             msg = create_command(notify_out.ZONE_NEW_DATA_READY_CMD, param)
             # catch the exception, in case msgq has been killed.
             try: