Parcourir la source

[2158] updated counters handling in the classes

 - If it omits setting the counters, which are notifyoutv4, notifyoutv6,
   xfrrej, and xfrreqdone, they are set to None as defaults when the object is
   initiating. Then when each counter is invoked in some method, it checks
   whether the counter is Nonetype or not. Unless the counter is NoneType, it
   invokes the counter.  After that unless the counter is callable, a TypeError
   exception would be raised.

 - added some tests for this changes

 - removed some dead code from xfrout.py.in
Naoki Kambe il y a 12 ans
Parent
commit
2c05e5b0d8

+ 47 - 0
src/bin/xfrout/tests/xfrout_test.py.in

@@ -472,6 +472,25 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.check_transfer_acl(acl_setter)
         self.check_transfer_acl(acl_setter)
         self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
         self.assertEqual(self._zone_name_xfrrej, TEST_ZONE_NAME_STR)
 
 
+    def test_transfer_acl_with_nonetype_xfrrej(self):
+        # ACL checks only with the default ACL and NoneType xfrrej
+        # counter
+        def acl_setter(acl):
+            self.xfrsess._acl = acl
+        self.xfrsess._counter_xfrrej = None
+        self.assertIsNone(self._zone_name_xfrrej)
+        self.check_transfer_acl(acl_setter)
+        self.assertIsNone(self._zone_name_xfrrej)
+
+    def test_transfer_acl_with_notcallable_xfrrej(self):
+        # ACL checks only with the default ACL and not callable xfrrej
+        # counter
+        def acl_setter(acl):
+            self.xfrsess._acl = acl
+        self.xfrsess._counter_xfrrej = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self.check_transfer_acl, acl_setter)
+
     def test_transfer_zoneacl(self):
     def test_transfer_zoneacl(self):
         # ACL check with a per zone ACL + default ACL.  The per zone ACL
         # ACL check with a per zone ACL + default ACL.  The per zone ACL
         # should match the queryied zone, so it should be used.
         # should match the queryied zone, so it should be used.
@@ -853,6 +872,34 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.assertEqual(self.sock.readsent(), b"success")
         self.assertEqual(self.sock.readsent(), b"success")
         self.assertEqual(self._zone_name_xfrreqdone, TEST_ZONE_NAME_STR)
         self.assertEqual(self._zone_name_xfrreqdone, TEST_ZONE_NAME_STR)
 
 
+    def test_dns_xfrout_start_with_nonetype_xfrreqdone(self):
+        def noerror(msg, name, rrclass):
+            return Rcode.NOERROR()
+        self.xfrsess._xfrout_setup = noerror
+
+        def myreply(msg, sock):
+            self.sock.send(b"success")
+
+        self.assertIsNone(self._zone_name_xfrreqdone)
+        self.xfrsess._reply_xfrout_query = myreply
+        self.xfrsess._counter_xfrreqdone = None
+        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
+        self.assertIsNone(self._zone_name_xfrreqdone)
+
+    def test_dns_xfrout_start_with_notcallable_xfrreqdone(self):
+        def noerror(msg, name, rrclass):
+            return Rcode.NOERROR()
+        self.xfrsess._xfrout_setup = noerror
+
+        def myreply(msg, sock):
+            self.sock.send(b"success")
+
+        self.xfrsess._reply_xfrout_query = myreply
+        self.xfrsess._counter_xfrreqdone = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self.xfrsess.dns_xfrout_start, self.sock,
+                          self.mdata)
+
     def test_reply_xfrout_query_axfr(self):
     def test_reply_xfrout_query_axfr(self):
         self.xfrsess._soa = self.soa_rrset
         self.xfrsess._soa = self.soa_rrset
         self.xfrsess._iterator = [self.soa_rrset]
         self.xfrsess._iterator = [self.soa_rrset]

+ 8 - 12
src/bin/xfrout/xfrout.py.in

@@ -171,12 +171,8 @@ class XfroutSession():
         self._jnl_reader = None # will be set to a reader for IXFR
         self._jnl_reader = None # will be set to a reader for IXFR
         # Set counter handlers for counting Xfr requests. An argument
         # Set counter handlers for counting Xfr requests. An argument
         # is required for zone name.
         # is required for zone name.
-        self._counter_xfrrej = lambda x: None
-        if hasattr(counter_xfrrej, '__call__'):
-            self._counter_xfrrej = counter_xfrrej
-        self._counter_xfrreqdone = lambda x: None
-        if hasattr(counter_xfrreqdone, '__call__'):
-            self._counter_xfrreqdone = counter_xfrreqdone
+        self._counter_xfrrej = counter_xfrrej
+        self._counter_xfrreqdone = counter_xfrreqdone
         self._handle()
         self._handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -279,8 +275,9 @@ class XfroutSession():
                          format_zone_str(zone_name, zone_class))
                          format_zone_str(zone_name, zone_class))
             return None, None
             return None, None
         elif acl_result == REJECT:
         elif acl_result == REJECT:
-            # count rejected Xfr request by each zone name
-            self._counter_xfrrej(zone_name.to_text())
+            if self._counter_xfrrej is not None:
+                # count rejected Xfr request by each zone name
+                self._counter_xfrrej(zone_name.to_text())
             logger.debug(DBG_XFROUT_TRACE, XFROUT_QUERY_REJECTED,
             logger.debug(DBG_XFROUT_TRACE, XFROUT_QUERY_REJECTED,
                          self._request_type, format_addrinfo(self._remote),
                          self._request_type, format_addrinfo(self._remote),
                          format_zone_str(zone_name, zone_class))
                          format_zone_str(zone_name, zone_class))
@@ -536,8 +533,9 @@ class XfroutSession():
         except Exception as err:
         except Exception as err:
             logger.error(XFROUT_XFR_TRANSFER_ERROR, self._request_typestr,
             logger.error(XFROUT_XFR_TRANSFER_ERROR, self._request_typestr,
                     format_addrinfo(self._remote), zone_str, err)
                     format_addrinfo(self._remote), zone_str, err)
-        # count done Xfr requests by each zone name
-        self._counter_xfrreqdone(zone_name.to_text())
+        if self._counter_xfrreqdone is not None:
+            # count done Xfr requests by each zone name
+            self._counter_xfrreqdone(zone_name.to_text())
         logger.info(XFROUT_XFR_TRANSFER_DONE, self._request_typestr,
         logger.info(XFROUT_XFR_TRANSFER_DONE, self._request_typestr,
                     format_addrinfo(self._remote), zone_str)
                     format_addrinfo(self._remote), zone_str)
 
 
@@ -1023,8 +1021,6 @@ class XfroutCounter:
                 with self._lock:
                 with self._lock:
                     self._add_perzone_counter(zone_name)
                     self._add_perzone_counter(zone_name)
                     self._statistics_data[self.perzone_prefix][zone_name][counter_name] += step
                     self._statistics_data[self.perzone_prefix][zone_name][counter_name] += step
-            #def __perzone_incrementer(zone_name, counter_name=item):
-            #    self._perzone_incrementer(zone_name, counter_name)
             setattr(self, 'inc_%s' % item, __perzone_incrementer)
             setattr(self, 'inc_%s' % item, __perzone_incrementer)
 
 
 
 

+ 8 - 8
src/lib/python/isc/notify/notify_out.py

@@ -145,12 +145,8 @@ class NotifyOut:
         self._nonblock_event = threading.Event()
         self._nonblock_event = threading.Event()
         # Set counter handlers for counting notifies. An argument is
         # Set counter handlers for counting notifies. An argument is
         # required for zone name.
         # required for zone name.
-        self._counter_notifyoutv4 = lambda x: None
-        if hasattr(counter_notifyoutv4, '__call__'):
-            self._counter_notifyoutv4 = counter_notifyoutv4
-        self._counter_notifyoutv6 = lambda x: None
-        if hasattr(counter_notifyoutv6, '__call__'):
-            self._counter_notifyoutv6 = counter_notifyoutv6
+        self._counter_notifyoutv4 = counter_notifyoutv4
+        self._counter_notifyoutv6 = counter_notifyoutv6
 
 
     def _init_notify_out(self, datasrc_file):
     def _init_notify_out(self, datasrc_file):
         '''Get all the zones name and its notify target's address.
         '''Get all the zones name and its notify target's address.
@@ -488,9 +484,13 @@ class NotifyOut:
             sock = zone_notify_info.create_socket(addrinfo[0])
             sock = zone_notify_info.create_socket(addrinfo[0])
             sock.sendto(render.get_data(), 0, addrinfo)
             sock.sendto(render.get_data(), 0, addrinfo)
             # count notifying by IPv4 or IPv6 for statistics
             # count notifying by IPv4 or IPv6 for statistics
-            if zone_notify_info.get_socket().family == socket.AF_INET:
+            if zone_notify_info.get_socket().family \
+                    == socket.AF_INET \
+                    and self._counter_notifyoutv4 is not None:
                 self._counter_notifyoutv4(zone_notify_info.zone_name)
                 self._counter_notifyoutv4(zone_notify_info.zone_name)
-            elif zone_notify_info.get_socket().family == socket.AF_INET6:
+            elif zone_notify_info.get_socket().family \
+                    == socket.AF_INET6 \
+                    and self._counter_notifyoutv6 is not None:
                 self._counter_notifyoutv6(zone_notify_info.zone_name)
                 self._counter_notifyoutv6(zone_notify_info.zone_name)
             logger.info(NOTIFY_OUT_SENDING_NOTIFY, addrinfo[0],
             logger.info(NOTIFY_OUT_SENDING_NOTIFY, addrinfo[0],
                         addrinfo[1])
                         addrinfo[1])

+ 36 - 0
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -289,6 +289,42 @@ class TestNotifyOut(unittest.TestCase):
         self.assertIsNone(self._notifiedv4_zone_name)
         self.assertIsNone(self._notifiedv4_zone_name)
         self.assertEqual(self._notifiedv6_zone_name, 'example.net.')
         self.assertEqual(self._notifiedv6_zone_name, 'example.net.')
 
 
+    def test_send_notify_message_udp_ipv4_with_nonetype_notifyoutv4(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        example_com_info.prepare_notify_out()
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+        self._notify._counter_notifyoutv4 = None
+        self._notify._send_notify_message_udp(example_com_info,
+                                              ('192.0.2.1', 53))
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+
+    def test_send_notify_message_udp_ipv4_with_notcallable_notifyoutv4(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        example_com_info.prepare_notify_out()
+        self._notify._counter_notifyoutv4 = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self._notify._send_notify_message_udp,
+                          example_com_info, ('192.0.2.1', 53))
+
+    def test_send_notify_message_udp_ipv6_with_nonetype_notifyoutv6(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+        self._notify._counter_notifyoutv6 = None
+        self._notify._send_notify_message_udp(example_com_info,
+                                              ('2001:db8::53', 53))
+        self.assertIsNone(self._notifiedv4_zone_name)
+        self.assertIsNone(self._notifiedv6_zone_name)
+
+    def test_send_notify_message_udp_ipv6_with_notcallable_notifyoutv6(self):
+        example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
+        self._notify._counter_notifyoutv6 = 'NOT CALLABLE'
+        self.assertRaises(TypeError,
+                          self._notify._send_notify_message_udp,
+                          example_com_info, ('2001:db8::53', 53))
+
     def test_send_notify_message_with_bogus_address(self):
     def test_send_notify_message_with_bogus_address(self):
         example_com_info = self._notify._notify_infos[('example.net.', 'IN')]
         example_com_info = self._notify._notify_infos[('example.net.', 'IN')]