Parcourir la source

[2879] fixed the main defect with revised tests.

JINMEI Tatuya il y a 12 ans
Parent
commit
f217f1e624

+ 6 - 2
src/lib/python/isc/notify/notify_out.py

@@ -435,9 +435,13 @@ class NotifyOut:
         """
         tgt = zone_notify_info.get_current_notify_target()
         if event_type == _EVENT_READ:
+            # Note: _get_notify_reply() should also check the response's
+            # source address (see #2924).  When it's done the following code
+            # should also be adjusted a bit.
             reply = self._get_notify_reply(zone_notify_info.get_socket(), tgt)
             if reply is not None:
-                if self._handle_notify_reply(zone_notify_info, reply, tgt):
+                if (self._handle_notify_reply(zone_notify_info, reply, tgt) ==
+                    _REPLY_OK):
                     self._notify_next_target(zone_notify_info)
 
         else:
@@ -453,7 +457,7 @@ class NotifyOut:
                             _MAX_NOTIFY_TRY_NUM)
                 self._notify_next_target(zone_notify_info)
             else:
-                # set exponential backoff according rfc1996 section 3.6
+                # set exponential backoff according to rfc1996 section 3.6
                 retry_timeout = (_NOTIFY_TIMEOUT *
                                  pow(2, zone_notify_info.notify_try_num))
                 zone_notify_info.notify_timeout = time.time() + retry_timeout

+ 66 - 6
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -25,10 +25,23 @@ from isc.dns import *
 
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 
+def get_notify_msgdata(zone_name, qid=0):
+    m = Message(Message.RENDER)
+    m.set_opcode(Opcode.NOTIFY)
+    m.set_rcode(Rcode.NOERROR)
+    m.set_qid(qid)
+    m.set_header_flag(Message.HEADERFLAG_QR)
+    m.add_question(Question(zone_name, RRClass.IN, RRType.SOA))
+
+    renderer = MessageRenderer()
+    m.to_wire(renderer)
+    return renderer.get_data()
+
 # our fake socket, where we can read and insert messages
 class MockSocket():
     def __init__(self):
         self._local_sock, self._remote_sock = socket.socketpair()
+        self.__raise_on_recv = False # see set_raise_on_recv()
 
     def connect(self, to):
         pass
@@ -44,6 +57,8 @@ class MockSocket():
         return self._local_sock.send(data)
 
     def recvfrom(self, length):
+        if self.__raise_on_recv:
+            raise socket.error('fake error')
         data = self._local_sock.recv(length)
         return (data, None)
 
@@ -51,6 +66,14 @@ class MockSocket():
     def remote_end(self):
         return self._remote_sock
 
+    def set_raise_on_recv(self, on):
+        """A helper to force recvfrom() to raise an exception or cancel it.
+
+        The next call to recvfrom() will result in an exception iff parameter
+        'on' (bool) is set to True.
+        """
+        self.__raise_on_recv = on
+
 # We subclass the ZoneNotifyInfo class we're testing here, only
 # to override the create_socket() method.
 class MockZoneNotifyInfo(notify_out.ZoneNotifyInfo):
@@ -341,7 +364,6 @@ class TestNotifyOut(unittest.TestCase):
         self._notify.send_notify('example.net.')
 
         example_net_info = self._notify._notify_infos[('example.net.', 'IN')]
-        example_net_info.prepare_notify_out()
 
         # On timeout, the request will be resent until try_num reaches the max
         self.assertEqual([], sent_addrs)
@@ -367,14 +389,52 @@ class TestNotifyOut(unittest.TestCase):
         self.assertRaises(AssertionError, self._notify._zone_notify_handler,
                           example_net_info, notify_out._EVENT_TIMEOUT + 1)
 
-        cur_tgt = example_net_info._notify_current
+    def test_zone_notify_read_handler(self):
+        """Similar to the previous test, but focus on the READ events.
+
+        """
+        sent_addrs = []
+        def _fake_send_notify_message_udp(notify_info, addrinfo):
+            sent_addrs.append(addrinfo)
+            pass
+        self._notify._send_notify_message_udp = _fake_send_notify_message_udp
+        self._notify.send_notify('example.net.')
+
+        example_net_info = self._notify._notify_infos[('example.net.', 'IN')]
         example_net_info.create_socket('127.0.0.1')
-        # dns message, will result in bad_qid, but what we are testing
-        # here is whether handle_notify_reply is called correctly
-        example_net_info._sock.remote_end().send(b'\x2f\x18\xa0\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\03com\x00\x00\x06\x00\x01')
+
+        # A successful case: an expected notify response is received, and
+        # another notify will be sent to the next slave immediately.
+        example_net_info._sock.remote_end().send(
+            get_notify_msgdata(Name('example.net')))
+        self._notify._zone_notify_handler(example_net_info,
+                                          notify_out._EVENT_READ)
+        self.assertEqual(1, example_net_info.notify_try_num)
+        expected_sent_addrs = [('192.0.2.1', 5353)]
+        self.assertEqual(expected_sent_addrs, sent_addrs)
+        self.assertEqual(('192.0.2.1', 5353), example_net_info._notify_current)
+
+        # response's QID doesn't match.  the request will be resent.
+        example_net_info._sock.remote_end().send(
+            get_notify_msgdata(Name('example.net'), qid=1))
         self._notify._zone_notify_handler(example_net_info,
                                           notify_out._EVENT_READ)
-        self.assertNotEqual(cur_tgt, example_net_info._notify_current)
+        self.assertEqual(2, example_net_info.notify_try_num)
+        expected_sent_addrs.append(('192.0.2.1', 5353))
+        self.assertEqual(expected_sent_addrs, sent_addrs)
+        self.assertEqual(('192.0.2.1', 5353), example_net_info._notify_current)
+
+        # emulate exception from socket.recvfrom().  It will have the same
+        # effect as a bad response.
+        example_net_info._sock.set_raise_on_recv(True)
+        example_net_info._sock.remote_end().send(
+            get_notify_msgdata(Name('example.net')))
+        self._notify._zone_notify_handler(example_net_info,
+                                          notify_out._EVENT_READ)
+        self.assertEqual(3, example_net_info.notify_try_num)
+        expected_sent_addrs.append(('192.0.2.1', 5353))
+        self.assertEqual(expected_sent_addrs, sent_addrs)
+        self.assertEqual(('192.0.2.1', 5353), example_net_info._notify_current)
 
     def test_get_notify_slaves_from_ns(self):
         records = self._notify._get_notify_slaves_from_ns(Name('example.net.'),