Parcourir la source

[1261] introduced the XfrinState class and implemented the InitialSOA subclass.

JINMEI Tatuya il y a 13 ans
Parent
commit
4fda2b6eef
2 fichiers modifiés avec 84 ajouts et 0 suppressions
  1. 29 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 55 0
      src/bin/xfrin/xfrin.py.in

+ 29 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -174,6 +174,35 @@ class MockXfrinConnection(XfrinConnection):
 
         return reply_data
 
+class TestXfrinState(unittest.TestCase):
+    def setUp(self):
+        self.sock_map = {}
+        self.conn = MockXfrinConnection(self.sock_map, TEST_ZONE_NAME_STR,
+                                        TEST_RRCLASS, TEST_DB_FILE,
+                                        threading.Event(),
+                                        TEST_MASTER_IPV4_ADDRINFO)
+        self.ns_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.NS(),
+                              RRTTL(3600))
+        self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
+                                      'ns.example.com'))
+
+class TestXfrinInitialSOA(TestXfrinState):
+    def setUp(self):
+        super().setUp()
+        self.state = XfrinInitialSOA()
+
+    def test_handle_rr(self):
+        # normal case
+        self.state.handle_rr(self.conn, soa_rrset)
+        self.assertEqual(type(XfrinFirstData()),
+                         type(self.conn.get_xfrstate()))
+        self.assertEqual(1234, self.conn._end_serial)
+
+    def test_handle_not_soa(self):
+        # The given RR is not of SOA
+        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
+                          self.ns_rrset)
+
 class TestXfrinConnection(unittest.TestCase):
     def setUp(self):
         if os.path.exists(TEST_DB_FILE):

+ 55 - 0
src/bin/xfrin/xfrin.py.in

@@ -77,6 +77,11 @@ XFRIN_FAIL = 1
 class XfrinException(Exception):
     pass
 
+class XfrinProtocolError(Exception):
+    '''An exception raised for errors encountered in xfrin protocol handling.
+    '''
+    pass
+
 class XfrinZoneInfoException(Exception):
     """This exception is raised if there is an error in the given
        configuration (part), or when a command does not have a required
@@ -112,6 +117,48 @@ def _check_zone_class(zone_class_str):
     except InvalidRRClass as irce:
         raise XfrinZoneInfoException("bad zone class: " + zone_class_str + " (" + str(irce) + ")")
 
+def get_soa_serial(soa_rdata):
+    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+
+    We don't have to be very efficient here, so we first dump the entire RDATA
+    as a string and convert the first corresponding field.  This should be
+    sufficient in practice, but may not always work when the MNAME or RNAME
+    contains an (escaped) space character in their labels.  Ideally there
+    should be a more direct and convenient way to get access to the SOA
+    fields.
+    '''
+    return int(soa_rdata.to_text().split()[2])
+
+class XfrinState:
+    '''
+    The states of the incomding *XFR state machine.
+    '''
+    def set_xfrstate(self, conn, new_state):
+        '''Set the XfrConnection to a given new state
+
+        As a "friend" class, this method intentionally gets access to the
+        connection's "private" method.
+        '''
+        conn._XfrinConnection__set_xfrstate(new_state)
+
+class XfrinInitialSOA(XfrinState):
+    def handle_rr(self, conn, rr):
+        if rr.get_type() != RRType.SOA():
+            raise XfrinProtocolError('First RR in zone transfer must be SOA ('
+                                     + rr.get_type().to_text() + ' given)')
+        conn._end_serial = get_soa_serial(rr.get_rdata()[0])
+
+        # FIXME: we need to check the serial is actually greater than ours.
+        # To do so, however, we need a way to find records from datasource.
+        # Complete that part later as a separate task.  (Always performing
+        # xfr could be inefficient, but shouldn't do any harm otherwise)
+
+        self.set_xfrstate(conn, XfrinFirstData())
+
+class XfrinFirstData(XfrinState):
+    def handle_rr(self, conn, rr):
+        pass
+
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''
 
@@ -125,6 +172,8 @@ class XfrinConnection(asyncore.dispatcher):
         '''
 
         asyncore.dispatcher.__init__(self, map=sock_map)
+        self.__state = None
+        self._end_serial = None # essentially private
         self.create_socket(master_addrinfo[0], master_addrinfo[1])
         self._zone_name = zone_name
         self._sock_map = sock_map
@@ -145,6 +194,12 @@ class XfrinConnection(asyncore.dispatcher):
     def __create_tsig_ctx(self, key):
         return TSIGContext(key)
 
+    def __set_xfrstate(self, new_state):
+        self.__state = new_state
+
+    def get_xfrstate(self):
+        return self.__state
+
     def connect_to_master(self):
         '''Connect to master in TCP.'''