Browse Source

[2222] added tests for counting the number of AXFR/IXFR running

 - added tests which check properly counting the number of AXFR/IXFR
   running

 - added changes related to the changes of XfroutCounter

 - added misc refactoring for testing

 - used validate_statistics() for checking statistics data format to be
   returned to the stats module

 - updated copyright
Naoki Kambe 12 years ago
parent
commit
8ab3a360bb
1 changed files with 126 additions and 45 deletions
  1. 126 45
      src/bin/xfrout/tests/xfrout_test.py.in

+ 126 - 45
src/bin/xfrout/tests/xfrout_test.py.in

@@ -1,4 +1,4 @@
-# Copyright (C) 2010  Internet Systems Consortium.
+# Copyright (C) 2010-2012  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -270,6 +270,7 @@ class TestXfroutSessionBase(unittest.TestCase):
 
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
+        self.setup_counters()
         self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                        TSIGKeyRing(),
                                        (socket.AF_INET, socket.SOCK_STREAM,
@@ -278,22 +279,50 @@ class TestXfroutSessionBase(unittest.TestCase):
                                        isc.acl.dns.REQUEST_LOADER.load(
                                            [{"action": "ACCEPT"}]),
                                        {},
-                                       counter_xfrrej=self._counter_xfrrej,
-                                       counter_xfrreqdone=self._counter_xfrreqdone)
+                                       **self._counters)
         self.set_request_type(RRType.AXFR()) # test AXFR by default
         self.mdata = self.create_request_data()
         self.soa_rrset = create_soa(SOA_CURRENT_VERSION)
         # some test replaces a module-wide function.  We should ensure the
         # original is used elsewhere.
         self.orig_get_rrset_len = xfrout.get_rrset_len
-        self._zone_name_xfrrej = None
-        self._zone_name_xfrreqdone = None
 
-    def _counter_xfrrej(self, zone_name):
-        self._zone_name_xfrrej = zone_name
-
-    def _counter_xfrreqdone(self, zone_name):
-        self._zone_name_xfrreqdone = zone_name
+    def setup_counters(self):
+        self._statistics_data = {
+            'zones' : {
+                TEST_ZONE_NAME_STR : {
+                    'xfrrej': 0,
+                    'xfrreqdone': 0
+                    }
+                },
+            'axfr_started': 0,
+            'ixfr_started': 0,
+            'axfr_ended': 0,
+            'ixfr_ended': 0
+            }
+        def _counter_xfrrej(zone_name):
+            self._statistics_data['zones'][zone_name]['xfrrej'] += 1
+        def _counter_xfrreqdone(zone_name):
+            self._statistics_data['zones'][zone_name]['xfrreqdone'] += 1
+        def _inc_ixfr_running():
+            self._statistics_data['ixfr_started'] += 1
+        def _dec_ixfr_running():
+            self._statistics_data['ixfr_ended'] += 1
+        def _inc_axfr_running():
+            self._statistics_data['axfr_started'] += 1
+        def _dec_axfr_running():
+            self._statistics_data['axfr_ended'] += 1
+        self._counters = {
+            'counter_xfrrej': _counter_xfrrej,
+            'counter_xfrreqdone': _counter_xfrreqdone,
+            'inc_ixfr_running': _inc_ixfr_running,
+            'dec_ixfr_running': _dec_ixfr_running,
+            'inc_axfr_running': _inc_axfr_running,
+            'dec_axfr_running': _dec_axfr_running
+            }
+        self.get_counter = lambda n: \
+            self._statistics_data[n] if 'ixfr_' in n or 'axfr_' in n \
+            else self._statistics_data['zones'][TEST_ZONE_NAME_STR][n]
 
     def tearDown(self):
         xfrout.get_rrset_len = self.orig_get_rrset_len
@@ -468,9 +497,9 @@ class TestXfroutSession(TestXfroutSessionBase):
         # ACL checks only with the default ACL
         def acl_setter(acl):
             self.xfrsess._acl = acl
-        self.assertIsNone(self._zone_name_xfrrej)
+        self.assertEqual(self.get_counter('xfrrej'), 0)
         self.check_transfer_acl(acl_setter)
-        self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
+        self.assertGreater(self.get_counter('xfrrej'), 0)
 
     def test_transfer_acl_with_nonetype_xfrrej(self):
         # ACL checks only with the default ACL and NoneType xfrrej
@@ -500,9 +529,9 @@ class TestXfroutSession(TestXfroutSessionBase):
             self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl
             self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
                     {"from": "127.0.0.1", "action": "DROP"}])
-        self.assertIsNone(self._zone_name_xfrrej)
+        self.assertEqual(self.get_counter('xfrrej'), 0)
         self.check_transfer_acl(acl_setter)
-        self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
+        self.assertGreater(self.get_counter('xfrrej'), 0)
 
     def test_transfer_zoneacl_nomatch(self):
         # similar to the previous one, but the per zone doesn't match the
@@ -514,9 +543,9 @@ class TestXfroutSession(TestXfroutSessionBase):
                 isc.acl.dns.REQUEST_LOADER.load([
                     {"from": "127.0.0.1", "action": "DROP"}])
             self.xfrsess._acl = acl
-        self.assertIsNone(self._zone_name_xfrrej)
+        self.assertEqual(self.get_counter('xfrrej'), 0)
         self.check_transfer_acl(acl_setter)
-        self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
+        self.assertGreater(self.get_counter('xfrrej'), 0)
 
     def test_get_transfer_acl(self):
         # set the default ACL.  If there's no specific zone ACL, this one
@@ -866,11 +895,11 @@ class TestXfroutSession(TestXfroutSessionBase):
         def myreply(msg, sock):
             self.sock.send(b"success")
 
-        self.assertIsNone(self._zone_name_xfrreqdone)
+        self.assertEqual(self.get_counter('xfrreqdone'), 0)
         self.xfrsess._reply_xfrout_query = myreply
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.assertEqual(self.sock.readsent(), b"success")
-        self.assertEqual(self._zone_name_xfrreqdone, TEST_ZONE_NAME_STR)
+        self.assertGreater(self.get_counter('xfrreqdone'), 0)
 
     def test_dns_xfrout_start_with_nonetype_xfrreqdone(self):
         def noerror(msg, name, rrclass):
@@ -1154,10 +1183,20 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
             self.assertTrue(rrsets_equal(expected_rr, actual_rr))
 
     def test_axfr_normal_session(self):
+        self.assertEqual(self.get_counter('axfr_started'), 0)
+        self.assertEqual(self.get_counter('axfr_ended'), 0)
+        self.assertEqual(self.get_counter('ixfr_started'), 0)
+        self.assertEqual(self.get_counter('ixfr_ended'), 0)
         XfroutSession._handle(self.xfrsess)
         response = self.sock.read_msg(Message.PRESERVE_ORDER);
         self.assertEqual(Rcode.NOERROR(), response.get_rcode())
         self.check_axfr_stream(response)
+        self.assertEqual(self.xfrsess._request_type, RRType.AXFR())
+        self.assertNotEqual(self.xfrsess._request_type, RRType.IXFR())
+        self.assertEqual(self.get_counter('axfr_started'), 1)
+        self.assertEqual(self.get_counter('axfr_ended'), 1)
+        self.assertEqual(self.get_counter('ixfr_started'), 0)
+        self.assertEqual(self.get_counter('ixfr_ended'), 0)
 
     def test_ixfr_to_axfr(self):
         self.xfrsess._request_data = \
@@ -1176,6 +1215,10 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         # two beginning and trailing SOAs.
         self.xfrsess._request_data = \
             self.create_request_data(ixfr=IXFR_OK_VERSION)
+        self.assertEqual(self.get_counter('axfr_started'), 0)
+        self.assertEqual(self.get_counter('axfr_ended'), 0)
+        self.assertEqual(self.get_counter('ixfr_started'), 0)
+        self.assertEqual(self.get_counter('ixfr_ended'), 0)
         XfroutSession._handle(self.xfrsess)
         response = self.sock.read_msg(Message.PRESERVE_ORDER)
         actual_records = response.get_section(Message.SECTION_ANSWER)
@@ -1191,6 +1234,12 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         self.assertEqual(len(expected_records), len(actual_records))
         for (expected_rr, actual_rr) in zip(expected_records, actual_records):
             self.assertTrue(rrsets_equal(expected_rr, actual_rr))
+        self.assertNotEqual(self.xfrsess._request_type, RRType.AXFR())
+        self.assertEqual(self.xfrsess._request_type, RRType.IXFR())
+        self.assertEqual(self.get_counter('axfr_started'), 0)
+        self.assertEqual(self.get_counter('axfr_ended'), 0)
+        self.assertEqual(self.get_counter('ixfr_started'), 1)
+        self.assertEqual(self.get_counter('ixfr_ended'), 1)
 
     def ixfr_soa_only_common_checks(self, request_serial):
         self.xfrsess._request_data = \
@@ -1578,9 +1627,9 @@ class MyXfroutServer(XfroutServer):
 
 class TestXfroutCounter(unittest.TestCase):
     def setUp(self):
-        statistics_spec = \
-            isc.config.module_spec_from_file(\
-            xfrout.SPECFILE_LOCATION).get_statistics_spec()
+        self._module_spec = isc.config.module_spec_from_file(\
+            xfrout.SPECFILE_LOCATION)
+        statistics_spec = self._module_spec.get_statistics_spec()
         self.xfrout_counter = XfroutCounter(statistics_spec)
         self._counters = isc.config.spec_name_list(\
             isc.config.find_spec_part(\
@@ -1591,22 +1640,23 @@ class TestXfroutCounter(unittest.TestCase):
         self._cycle = 10000 # number of counting per thread
 
     def test_get_default_statistics_data(self):
-        self.assertEqual(self.xfrout_counter._get_default_statistics_data(),
-                         {XfroutCounter.perzone_prefix: {
-                            XfroutCounter.entire_server: \
-                              dict([(cnt, 0) for cnt in self._counters])
-                         }})
-
-    def setup_incrementer(self, incrementer):
+        self.assertTrue(\
+            self._module_spec.validate_statistics(\
+                True,
+                self.xfrout_counter._get_default_statistics_data(),
+                )
+            )
+
+    def setup_incrementer(self, incrementer, *args):
         self._started.wait()
-        for i in range(self._cycle): incrementer(TEST_ZONE_NAME_STR)
+        for i in range(self._cycle): incrementer(*args)
 
-    def start_incrementer(self, incrementer):
+    def start_incrementer(self, incrementer, *args):
         threads = []
         for i in range(self._number):
             threads.append(threading.Thread(\
-                    target=self.setup_incrementer,\
-                        args=(incrementer,)\
+                    target=self.setup_incrementer, \
+                        args=(incrementer,) + args \
                         ))
         for th in threads: th.start()
         self._started.set()
@@ -1618,24 +1668,55 @@ class TestXfroutCounter(unittest.TestCase):
                 '%s/%s/%s' % (XfroutCounter.perzone_prefix,\
                                   zone_name, counter_name))
 
-    def test_incrementers(self):
+    def test_xxcrementers(self):
+        # for per-zone counters
         result = { XfroutCounter.entire_server: {},
                    TEST_ZONE_NAME_STR: {} }
         for counter_name in self._counters:
-                incrementer = getattr(self.xfrout_counter, 'inc_%s' % counter_name)
-                self.start_incrementer(incrementer)
-                self.assertEqual(self.get_count(\
-                            TEST_ZONE_NAME_STR, counter_name), \
-                                     self._number * self._cycle)
-                self.assertEqual(self.get_count(\
-                        XfroutCounter.entire_server, counter_name), \
-                                     self._number * self._cycle)
-                result[XfroutCounter.entire_server][counter_name] = \
-                    result[TEST_ZONE_NAME_STR][counter_name] = \
-                    self._number * self._cycle
+            incrementer = \
+                dict(self.xfrout_counter.get_counters_for_xfroutsession(), \
+                         **self.xfrout_counter.get_counters_for_notifyout())\
+                         ['counter_%s' % counter_name]
+            self.start_incrementer(incrementer, TEST_ZONE_NAME_STR)
+            self.assertEqual(self.get_count(\
+                        TEST_ZONE_NAME_STR, counter_name), \
+                                 self._number * self._cycle)
+            self.assertEqual(self.get_count(\
+                    XfroutCounter.entire_server, counter_name), \
+                                 self._number * self._cycle)
+            result[XfroutCounter.entire_server][counter_name] = \
+                result[TEST_ZONE_NAME_STR][counter_name] = \
+                self._number * self._cycle
+        statistics_data = {XfroutCounter.perzone_prefix: result}
+
+        # for {a|i}xfrrunning counters
+        for counter_name in self.xfrout_counter._xfrrunning_names:
+            incrementer = \
+                dict(self.xfrout_counter.get_counters_for_xfroutsession(), \
+                         **self.xfrout_counter.get_counters_for_notifyout())\
+                         ['inc_%s' % counter_name]
+            self.start_incrementer(incrementer)
+            self.assertEqual(
+                self.xfrout_counter.get_statistics()[counter_name],
+                self._number * self._cycle
+                )
+            decrementer = \
+                dict(self.xfrout_counter.get_counters_for_xfroutsession(), \
+                         **self.xfrout_counter.get_counters_for_notifyout())\
+                         ['dec_%s' % counter_name]
+            self.start_incrementer(decrementer)
+            self.assertEqual(
+                self.xfrout_counter.get_statistics()[counter_name],
+                0)
+            statistics_data[counter_name] = 0
         self.assertEqual(
             self.xfrout_counter.get_statistics(),
-            {XfroutCounter.perzone_prefix: result})
+            statistics_data)
+        self.assertTrue(\
+            self._module_spec.validate_statistics(\
+                True, statistics_data
+                )
+            )
 
     def test_add_perzone_counter(self):
         for counter_name in self._counters: