Browse Source

committed the proposed patch from ticket #334 to trunk

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@2931 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 14 years ago
parent
commit
e0009e9dfa
2 changed files with 80 additions and 39 deletions
  1. 52 6
      src/bin/xfrout/tests/xfrout_test.py
  2. 28 33
      src/bin/xfrout/xfrout.py.in

+ 52 - 6
src/bin/xfrout/tests/xfrout_test.py

@@ -40,8 +40,12 @@ class MySocket():
         return len(data)
 
     def readsent(self):
-        result = self.sendqueue[:]
-        del self.sendqueue[:]
+        if len(self.sendqueue) >= 2:
+            size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
+        else:
+            size = 0
+        result = self.sendqueue[:size]
+        self.sendqueue = self.sendqueue[size:]
         return result
     
     def read_msg(self):
@@ -133,7 +137,7 @@ class TestXfroutSession(unittest.TestCase):
 
         msg = self.getmsg()
         msg.make_response()
-        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa)
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 0)
         get_msg = self.sock.read_msg()
 
         self.assertEqual(get_msg.get_rr_count(Section.QUESTION()), 1)
@@ -148,10 +152,52 @@ class TestXfroutSession(unittest.TestCase):
         rdata = answer.get_rdata()
         self.assertEqual(rdata[0].to_text(), self.soa_record[7])
 
-    def test_get_message_len(self):
+    def test_trigger_send_message_with_last_soa(self):
+        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
+        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
+        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+
         msg = self.getmsg()
-        msg.make_response()  
-        self.assertEqual(self.xfrsess._get_message_len(msg), 29)
+        msg.make_response()
+
+        msg.add_rrset(Section.ANSWER(), rrset_a)
+        # give the function a value that is larger than MAX-len(rrset)
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 65520)
+
+        # this should have triggered the sending of two messages
+        # (1 with the rrset we added manually, and 1 that triggered
+        # the sending in _with_last_soa)
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rr_count(Section.QUESTION()), 1)
+        self.assertEqual(get_msg.get_rr_count(Section.ANSWER()), 1)
+        self.assertEqual(get_msg.get_rr_count(Section.AUTHORITY()), 0)
+
+        answer = get_msg.get_section(Section.ANSWER())[0]
+        self.assertEqual(answer.get_name().to_text(), "example.com.")
+        self.assertEqual(answer.get_class(), RRClass("IN"))
+        self.assertEqual(answer.get_type().to_text(), "A")
+        rdata = answer.get_rdata()
+        self.assertEqual(rdata[0].to_text(), "192.0.2.1")
+
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rr_count(Section.QUESTION()), 0)
+        self.assertEqual(get_msg.get_rr_count(Section.ANSWER()), 1)
+        self.assertEqual(get_msg.get_rr_count(Section.AUTHORITY()), 0)
+
+        #answer_rrset_iter = section_iter(get_msg, section.ANSWER())
+        answer = get_msg.get_section(Section.ANSWER())[0]
+        self.assertEqual(answer.get_name().to_text(), "example.com.")
+        self.assertEqual(answer.get_class(), RRClass("IN"))
+        self.assertEqual(answer.get_type().to_text(), "SOA")
+        rdata = answer.get_rdata()
+        self.assertEqual(rdata[0].to_text(), self.soa_record[7])
+
+        # and it should not have sent anything else
+        self.assertEqual(0, len(self.sock.sendqueue))
+
+    def test_get_rrset_len(self):
+        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+        self.assertEqual(82, get_rrset_len(rrset_soa))
 
     def test_zone_is_empty(self):
         global sqlite3_ds

+ 28 - 33
src/bin/xfrout/xfrout.py.in

@@ -57,6 +57,15 @@ AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
 MAX_TRANSFERS_OUT = 10
 VERBOSE_MODE = False
 
+XFROUT_MAX_MESSAGE_SIZE = 65535
+
+def get_rrset_len(rrset):
+    """Returns the wire length of the given RRset"""
+    bytes = bytearray()
+    rrset.to_wire(bytes)
+    return len(bytes)
+
+
 class XfroutSession(BaseRequestHandler):
     def __init__(self, request, client_address, server, log):
         # The initializer for the superclass may call functions
@@ -121,10 +130,8 @@ class XfroutSession(BaseRequestHandler):
 
 
     def _send_message(self, sock, msg):
-        #obuf = output_buffer(0)
-        #render = message_render(obuf)
         render = MessageRenderer()
-        render.set_length_limit(65535)
+        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
         msg.to_wire(render)
         header_len = struct.pack('H', socket.htons(render.get_length()))
         self._send_data(sock, header_len)
@@ -227,34 +234,20 @@ class XfroutSession(BaseRequestHandler):
         rrset_.add_rdata(rdata_)
         return rrset_
          
-    def _send_message_with_last_soa(self, msg, sock, rrset_soa):
+    def _send_message_with_last_soa(self, msg, sock, rrset_soa, message_upper_len):
         '''Add the SOA record to the end of message. If it can't be
         added, a new message should be created to send out the last soa .
         '''
+        rrset_len = get_rrset_len(rrset_soa)
 
-        render = MessageRenderer()
-        msg.to_wire(render)
-        old_message_len = render.get_length()
-        msg.add_rrset(Section.ANSWER(), rrset_soa)
-
-        msg.to_wire(render)
-        message_len = render.get_length()
-
-        if message_len != old_message_len:
-            self._send_message(sock, msg)
+        if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
+            msg.add_rrset(Section.ANSWER(), rrset_soa)
         else:
+            self._send_message(sock, msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Section.ANSWER(), rrset_soa)
-            self._send_message(sock, msg)
 
-    def _get_message_len(self, msg):
-        '''Get message length, every time need do like this? Actually there should be 
-        a better way, I need check with jinmei later.
-        '''
-
-        render = MessageRenderer()
-        msg.to_wire(render)
-        return render.get_length()
+        self._send_message(sock, msg)
 
 
     def _reply_xfrout_query(self, msg, sock, zone_name):
@@ -265,9 +258,8 @@ class XfroutSession(BaseRequestHandler):
         rrset_soa = self._create_rrset_from_db_record(soa_record)
         msg.add_rrset(Section.ANSWER(), rrset_soa)
 
-        old_message_len = 0
-        # TODO, Since add_rrset() return nothing when rrset can't be added, so I have to compare
-        # the message length to know if the rrset has been added sucessfully.
+        message_upper_len = get_rrset_len(rrset_soa)
+
         for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
             if  self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
                 self._log.log_message("error", "shutdown!")
@@ -277,19 +269,22 @@ class XfroutSession(BaseRequestHandler):
                 continue
 
             rrset_ = self._create_rrset_from_db_record(rr_data)
-            msg.add_rrset(Section.ANSWER(), rrset_)
-            message_len = self._get_message_len(msg)
-            if message_len != old_message_len:
-                old_message_len = message_len
+
+            # We calculate the maximum size of the RRset (i.e. the
+            # size without compression) and use that to see if we
+            # may have reached the limit
+            rrset_len = get_rrset_len(rrset_)
+            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
+                msg.add_rrset(Section.ANSWER(), rrset_)
+                message_upper_len += rrset_len
                 continue
 
             self._send_message(sock, msg)
             msg = self._clear_message(msg)
             msg.add_rrset(Section.ANSWER(), rrset_) # Add the rrset to the new message
-            old_message_len = 0
-
-        self._send_message_with_last_soa(msg, sock, rrset_soa)
+            message_upper_len = rrset_len
 
+        self._send_message_with_last_soa(msg, sock, rrset_soa, message_upper_len)
 
 class UnixSockServer(ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''