Browse Source

[2158] updated counters handling in the classes

 - If it omits setting the counters, which are notifyoutv4, notifyoutv6,
   xfrrej, and xfrreqdone, they are set to None as defaults when the object is
   initiating. Then when each counter is invoked in some method, it checks
   whether the counter is Nonetype or not. Unless the counter is NoneType, it
   invokes the counter.  After that unless the counter is callable, a TypeError
   exception would be raised.

 - added some tests for this changes

 - removed some dead code from xfrout.py.in
Naoki Kambe 12 years ago
parent
commit
2c05e5b0d8

+ 47 - 0
src/bin/xfrout/tests/xfrout_test.py.in

@@ -472,6 +472,25 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.check_transfer_acl(acl_setter)
         self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
 
+    def test_transfer_acl_with_nonetype_xfrrej(self):
+        # ACL checks only with the default ACL and NoneType xfrrej
+        # counter
+        def acl_setter(acl):
+            self.xfrsess._acl = acl
+        self.xfrsess._counter_xfrrej = None
+        self.assertIsNone(self._zone_name_xfrrej)
+        self.check_transfer_acl(acl_setter)
+        self.assertIsNone(self._zone_name_xfrrej)
+
+    def test_transfer_acl_with_notcallable_xfrrej(self):
+        # ACL checks only with the default ACL and not callable xfrrej
+        # counter
+        def acl_setter(acl):
+            self.xfrsess._acl = acl
+        self.xfrsess._counter_xfrrej = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self.check_transfer_acl, acl_setter)
+
     def test_transfer_zoneacl(self):
         # ACL check with a per zone ACL + default ACL.  The per zone ACL
         # should match the queryied zone, so it should be used.
@@ -853,6 +872,34 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.assertEqual(self.sock.readsent(), b"success")
         self.assertEqual(self._zone_name_xfrreqdone, TEST_ZONE_NAME_STR)
 
+    def test_dns_xfrout_start_with_nonetype_xfrreqdone(self):
+        def noerror(msg, name, rrclass):
+            return Rcode.NOERROR()
+        self.xfrsess._xfrout_setup = noerror
+
+        def myreply(msg, sock):
+            self.sock.send(b"success")
+
+        self.assertIsNone(self._zone_name_xfrreqdone)
+        self.xfrsess._reply_xfrout_query = myreply
+        self.xfrsess._counter_xfrreqdone = None
+        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
+        self.assertIsNone(self._zone_name_xfrreqdone)
+
+    def test_dns_xfrout_start_with_notcallable_xfrreqdone(self):
+        def noerror(msg, name, rrclass):
+            return Rcode.NOERROR()
+        self.xfrsess._xfrout_setup = noerror
+
+        def myreply(msg, sock):
+            self.sock.send(b"success")
+
+        self.xfrsess._reply_xfrout_query = myreply
+        self.xfrsess._counter_xfrreqdone = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self.xfrsess.dns_xfrout_start, self.sock,
+                          self.mdata)
+
     def test_reply_xfrout_query_axfr(self):
         self.xfrsess._soa = self.soa_rrset
         self.xfrsess._iterator = [self.soa_rrset]

+ 8 - 12
src/bin/xfrout/xfrout.py.in

@@ -171,12 +171,8 @@ class XfroutSession():
         self._jnl_reader = None # will be set to a reader for IXFR
         # Set counter handlers for counting Xfr requests. An argument
         # is required for zone name.
-        self._counter_xfrrej = lambda x: None
-        if hasattr(counter_xfrrej, '__call__'):
-            self._counter_xfrrej = counter_xfrrej
-        self._counter_xfrreqdone = lambda x: None
-        if hasattr(counter_xfrreqdone, '__call__'):
-            self._counter_xfrreqdone = counter_xfrreqdone
+        self._counter_xfrrej = counter_xfrrej
+        self._counter_xfrreqdone = counter_xfrreqdone
         self._handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -279,8 +275,9 @@ class XfroutSession():
                          format_zone_str(zone_name, zone_class))
             return None, None
         elif acl_result == REJECT:
-            # count rejected Xfr request by each zone name
-            self._counter_xfrrej(zone_name.to_text())
+            if self._counter_xfrrej is not None:
+                # count rejected Xfr request by each zone name
+                self._counter_xfrrej(zone_name.to_text())
             logger.debug(DBG_XFROUT_TRACE, XFROUT_QUERY_REJECTED,
                          self._request_type, format_addrinfo(self._remote),
                          format_zone_str(zone_name, zone_class))
@@ -536,8 +533,9 @@ class XfroutSession():
         except Exception as err:
             logger.error(XFROUT_XFR_TRANSFER_ERROR, self._request_typestr,
                     format_addrinfo(self._remote), zone_str, err)
-        # count done Xfr requests by each zone name
-        self._counter_xfrreqdone(zone_name.to_text())
+        if self._counter_xfrreqdone is not None:
+            # count done Xfr requests by each zone name
+            self._counter_xfrreqdone(zone_name.to_text())
         logger.info(XFROUT_XFR_TRANSFER_DONE, self._request_typestr,
                     format_addrinfo(self._remote), zone_str)
 
@@ -1023,8 +1021,6 @@ class XfroutCounter:
                 with self._lock:
                     self._add_perzone_counter(zone_name)
                     self._statistics_data[self.perzone_prefix][zone_name][counter_name] += step
-            #def __perzone_incrementer(zone_name, counter_name=item):
-            #    self._perzone_incrementer(zone_name, counter_name)
             setattr(self, 'inc_%s' % item, __perzone_incrementer)
 
 

+ 8 - 8
src/lib/python/isc/notify/notify_out.py

@@ -145,12 +145,8 @@ class NotifyOut:
         self._nonblock_event = threading.Event()
         # Set counter handlers for counting notifies. An argument is
         # required for zone name.
-        self._counter_notifyoutv4 = lambda x: None
-        if hasattr(counter_notifyoutv4, '__call__'):
-            self._counter_notifyoutv4 = counter_notifyoutv4
-        self._counter_notifyoutv6 = lambda x: None
-        if hasattr(counter_notifyoutv6, '__call__'):
-            self._counter_notifyoutv6 = counter_notifyoutv6
+        self._counter_notifyoutv4 = counter_notifyoutv4
+        self._counter_notifyoutv6 = counter_notifyoutv6
 
     def _init_notify_out(self, datasrc_file):
         '''Get all the zones name and its notify target's address.
@@ -488,9 +484,13 @@ class NotifyOut:
             sock = zone_notify_info.create_socket(addrinfo[0])
             sock.sendto(render.get_data(), 0, addrinfo)
             # count notifying by IPv4 or IPv6 for statistics
-            if zone_notify_info.get_socket().family == socket.AF_INET:
+            if zone_notify_info.get_socket().family \
+                    == socket.AF_INET \
+                    and self._counter_notifyoutv4 is not None:
                 self._counter_notifyoutv4(zone_notify_info.zone_name)
-            elif zone_notify_info.get_socket().family == socket.AF_INET6:
+            elif zone_notify_info.get_socket().family \
+                    == socket.AF_INET6 \
+                    and self._counter_notifyoutv6 is not None:
                 self._counter_notifyoutv6(zone_notify_info.zone_name)
             logger.info(NOTIFY_OUT_SENDING_NOTIFY, addrinfo[0],
                         addrinfo[1])

+ 36 - 0
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -289,6 +289,42 @@ class TestNotifyOut(unittest.TestCase):
         self.assertIsNone(self._notifiedv4_zone_name)
         self.assertEqual(self._notifiedv6_zone_name, 'example.net.')
 
+    def test_send_notify_message_udp_ipv4_with_nonetype_notifyoutv4(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        example_com_info.prepare_notify_out()
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+        self._notify._counter_notifyoutv4 = None
+        self._notify._send_notify_message_udp(example_com_info,
+                                              ('192.0.2.1', 53))
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+
+    def test_send_notify_message_udp_ipv4_with_notcallable_notifyoutv4(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        example_com_info.prepare_notify_out()
+        self._notify._counter_notifyoutv4 = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self._notify._send_notify_message_udp,
+                          example_com_info, ('192.0.2.1', 53))
+
+    def test_send_notify_message_udp_ipv6_with_nonetype_notifyoutv6(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+        self._notify._counter_notifyoutv6 = None
+        self._notify._send_notify_message_udp(example_com_info,
+                                              ('2001:db8::53', 53))
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+
+    def test_send_notify_message_udp_ipv6_with_notcallable_notifyoutv6(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        self._notify._counter_notifyoutv6 = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self._notify._send_notify_message_udp,
+                          example_com_info, ('2001:db8::53', 53))
+
     def test_send_notify_message_with_bogus_address(self):
         example_com_info = self._notify._notify_infos[('example.net.', 'IN')]