Browse Source

cleanup: eliminated the need for isc.auth.sqlite3_ds.AXFRInDB by directly calling sqlite3_ds.load() from the xfrin module.

Other cleanups:
- cosmetic: removed redundant blank lines and white spaces after EOL
- grammar fix in comments
- catch Sqlite3DSError explicitly (but I suspect the exception handling
  in the xfrin module is naive overall, which should be fixed)


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1684 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
613b2880fb
2 changed files with 24 additions and 105 deletions
  1. 17 26
      src/bin/xfrin/xfrin.py.in
  2. 7 79
      src/lib/python/isc/auth/sqlite3_ds.py

+ 17 - 26
src/bin/xfrin/xfrin.py.in

@@ -85,7 +85,6 @@ class XfrinConnection(asyncore.dispatcher):
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self._zone_name = zone_name
         self._zone_name = zone_name
         self._db_file = db_file
         self._db_file = db_file
-        self._axfrin_db = isc.auth.sqlite3_ds.AXFRInDB(self._db_file, self._zone_name)
         self._soa_rr_count = 0
         self._soa_rr_count = 0
         self._idle_timeout = idle_timeout
         self._idle_timeout = idle_timeout
         self.setblocking(1)
         self.setblocking(1)
@@ -112,7 +111,6 @@ class XfrinConnection(asyncore.dispatcher):
             count = self.send(data[total_count:])
             count = self.send(data[total_count:])
             total_count += count
             total_count += count
 
 
-
     def _send_query(self, query_type):
     def _send_query(self, query_type):
         '''Send query message over TCP. '''
         '''Send query message over TCP. '''
         msg = self._create_query(query_type)
         msg = self._create_query(query_type)
@@ -123,7 +121,6 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         self._send_data(header_len)
         self._send_data(header_len)
         self._send_data(obuf.get_data())
         self._send_data(obuf.get_data())
-
     
     
     def _get_request_response(self, size):
     def _get_request_response(self, size):
         recv_size = 0
         recv_size = 0
@@ -140,14 +137,12 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         return data
         return data
 
 
-
     def handle_read(self):
     def handle_read(self):
         '''Read query's response from socket. '''
         '''Read query's response from socket. '''
         self._recvd_data = self.recv(self._need_recv_size)
         self._recvd_data = self.recv(self._need_recv_size)
         self._recvd_size = len(self._recvd_data)
         self._recvd_size = len(self._recvd_data)
         self._recv_time_out = False
         self._recv_time_out = False
 
 
-
     def _check_soa_serial(self):
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
         the soa serial in local, Finish xfrin.
@@ -169,9 +164,9 @@ class XfrinConnection(asyncore.dispatcher):
             
             
             self.log_msg('transfer of \'%s\': AXFR started' % self._zone_name)
             self.log_msg('transfer of \'%s\': AXFR started' % self._zone_name)
             if ret == XFRIN_OK:    
             if ret == XFRIN_OK:    
-                self._axfrin_db.prepare_axfrin()
                 self._send_query(rr_type.AXFR())
                 self._send_query(rr_type.AXFR())
-                ret = self._handle_xfrin_response()
+                isc.auth.sqlite3_ds.load(self._db_file, self._zone_name,
+                                         self._handle_xfrin_response)
 
 
             endmsg = 'succeeded' if ret == XFRIN_OK else 'failed'
             endmsg = 'succeeded' if ret == XFRIN_OK else 'failed'
             self.log_msg('transfer of \'%s\' AXFR %s' % (self._zone_name,
             self.log_msg('transfer of \'%s\' AXFR %s' % (self._zone_name,
@@ -179,11 +174,11 @@ class XfrinConnection(asyncore.dispatcher):
         except XfrinException as e:
         except XfrinException as e:
             self.log_msg(e)
             self.log_msg(e)
             self.log_msg('Error happened during xfrin!')
             self.log_msg('Error happened during xfrin!')
-            #TODO, recover data source. 
+            #TODO, recover data source.
+        except isc.auth.sqlite3_ds.Sqlite3DSError as e:
+            self.log_msg(e)
         finally:
         finally:
            self.close()
            self.close()
-           if ret == XFRIN_OK:
-               self._axfrin_db.finish_axfrin()
 
 
         return ret
         return ret
     
     
@@ -204,7 +199,6 @@ class XfrinConnection(asyncore.dispatcher):
         if msg.get_rr_count(section.QUESTION()) > 1:
         if msg.get_rr_count(section.QUESTION()) > 1:
             raise XfrinException('query section count greater than 1')
             raise XfrinException('query section count greater than 1')
 
 
-
     def _handle_answer_section(self, rrset_iter):
     def _handle_answer_section(self, rrset_iter):
         while not rrset_iter.is_last():
         while not rrset_iter.is_last():
             rrset = rrset_iter.get_rrset()
             rrset = rrset_iter.get_rrset()
@@ -231,11 +225,10 @@ class XfrinConnection(asyncore.dispatcher):
                         break
                         break
 
 
                 rdata_text = rdata_iter.get_current().to_text()
                 rdata_text = rdata_iter.get_current().to_text()
-                rr_data = (rrset_name, rrset_ttl, rrset_class, rrset_type, rdata_text)
-                self._axfrin_db.insert_axfr_record([rr_data]) 
+                yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
+                       rdata_text)
                 rdata_iter.next()
                 rdata_iter.next()
 
 
-
     def _handle_xfrin_response(self):
     def _handle_xfrin_response(self):
         while True:
         while True:
             data_len = self._get_request_response(2)
             data_len = self._get_request_response(2)
@@ -246,23 +239,21 @@ class XfrinConnection(asyncore.dispatcher):
             self._check_response_status(msg)
             self._check_response_status(msg)
             
             
             rrset_iter = section_iter(msg, section.ANSWER())
             rrset_iter = section_iter(msg, section.ANSWER())
-            self._handle_answer_section(rrset_iter)
+            for rr in self._handle_answer_section(rrset_iter):
+                yield rr
+
             if self._soa_rr_count == 2:
             if self._soa_rr_count == 2:
-                return XFRIN_OK
+                break
             
             
             if self._shutdown_event.is_set():
             if self._shutdown_event.is_set():
                 #Check if xfrin process is shutdown.
                 #Check if xfrin process is shutdown.
                 #TODO, xfrin may be blocked in one loop. 
                 #TODO, xfrin may be blocked in one loop. 
                 raise XfrinException('xfrin is forced to stop')
                 raise XfrinException('xfrin is forced to stop')
 
 
-        return XFRIN_OK
-
-
     def writable(self):
     def writable(self):
         '''Ignore the writable socket. '''
         '''Ignore the writable socket. '''
         return False
         return False
 
 
-
     def log_info(self, msg, type='info'):
     def log_info(self, msg, type='info'):
         # Overwrite the log function, log nothing
         # Overwrite the log function, log nothing
         pass
         pass
@@ -276,7 +267,7 @@ class XfrinConnection(asyncore.dispatcher):
 
 
 def process_xfrin(xfrin_recorder, zone_name, db_file, 
 def process_xfrin(xfrin_recorder, zone_name, db_file, 
                   shutdown_event, master_addr, port, check_soa, verbose):
                   shutdown_event, master_addr, port, check_soa, verbose):
-    xfrin_recorder.increment(zone_name)
+    xfrin_recorder.increment(name)
     try:
     try:
         conn = XfrinConnection(zone_name, db_file, shutdown_event, 
         conn = XfrinConnection(zone_name, db_file, shutdown_event, 
                            master_addr, int(port), check_soa, verbose)
                            master_addr, int(port), check_soa, verbose)
@@ -405,12 +396,12 @@ class Xfrin():
         if self.recorder.xfrin_in_progress(zone_name):
         if self.recorder.xfrin_in_progress(zone_name):
             return (1, 'zone xfrin is in progress')
             return (1, 'zone xfrin is in progress')
 
 
-        xfrin_thread = threading.Thread(target = process_xfrin, 
-                                        args = (self.recorder, 
-                                                zone_name, 
-                                                db_file, 
+        xfrin_thread = threading.Thread(target = process_xfrin,
+                                        args = (self.recorder,
+                                                zone_name,
+                                                db_file,
                                                 self._shutdown_event,
                                                 self._shutdown_event,
-                                                master_addr, 
+                                                master_addr,
                                                 port, check_soa, self._verbose))
                                                 port, check_soa, self._verbose))
                                                 
                                                 
         xfrin_thread.start()
         xfrin_thread.start()

+ 7 - 79
src/lib/python/isc/auth/sqlite3_ds.py

@@ -148,17 +148,20 @@ def reverse_name(name):
         new.pop(0)
         new.pop(0)
     return '.'.join(new)+'.'
     return '.'.join(new)+'.'
 
 
-
 #########################################################################
 #########################################################################
 # load:
 # load:
 #   load a zone into the SQL database.
 #   load a zone into the SQL database.
 # input:
 # input:
 #   dbfile: the sqlite3 database fileanme
 #   dbfile: the sqlite3 database fileanme
 #   zone: the zone origin
 #   zone: the zone origin
-#   reader: an generator function producing an iterable set of
+#   reader: a generator function producing an iterable set of
 #           name/ttl/class/rrtype/rdata-text tuples
 #           name/ttl/class/rrtype/rdata-text tuples
 #########################################################################
 #########################################################################
 def load(dbfile, zone, reader):
 def load(dbfile, zone, reader):
+    # if the zone name doesn't contain the trailing dot, automatically add it.
+    if zone[-1] != '.':
+        zone += '.'
+
     conn, cur = open(dbfile)
     conn, cur = open(dbfile)
     old_zone_id = get_zoneid(zone, cur)
     old_zone_id = get_zoneid(zone, cur)
 
 
@@ -184,13 +187,13 @@ def load(dbfile, zone, reader):
                                 rdtype, sigtype, rdata)
                                 rdtype, sigtype, rdata)
                                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                             [new_zone_id, name, reverse_name(name), ttl,
                             [new_zone_id, name, reverse_name(name), ttl,
-                            rdtype, sigtype, rdata])
+                             rdtype, sigtype, rdata])
             else:
             else:
                 cur.execute("""INSERT INTO records
                 cur.execute("""INSERT INTO records
                                (zone_id, name, rname, ttl, rdtype, rdata)
                                (zone_id, name, rname, ttl, rdtype, rdata)
                                VALUES (?, ?, ?, ?, ?, ?)""",
                                VALUES (?, ?, ?, ?, ?, ?)""",
                             [new_zone_id, name, reverse_name(name), ttl,
                             [new_zone_id, name, reverse_name(name), ttl,
-                            rdtype, rdata])
+                             rdtype, rdata])
     except Exception as e:
     except Exception as e:
         fail = "Error while loading " + zone + ": " + e.args[0]
         fail = "Error while loading " + zone + ": " + e.args[0]
         raise Sqlite3DSError(fail)
         raise Sqlite3DSError(fail)
@@ -208,78 +211,3 @@ def load(dbfile, zone, reader):
 
 
     cur.close()
     cur.close()
     conn.close()
     conn.close()
-
-
-#########################################################################
-# temp sqlite3 datasource backend for axfr in. The code should be refectored 
-# later.
-#########################################################################
-class AXFRInDB:
-    def __init__(self, dbfile, zone_name):
-        self._dbfile = dbfile
-        self._zone_name = zone_name
-        # if the zone name doesn't contain the trailing dot, automatically
-        # add it.
-        if self._zone_name[-1] != '.':
-            self._zone_name += '.'
-        self._old_zone_id = None
-        self._new_zone_id = None
-
-    def prepare_axfrin(self):
-        self._conn, self._cur = open(self._dbfile)
-        self._old_zone_id = get_zoneid(self._zone_name, self._cur)
-
-        temp = str(random.randrange(100000))
-        self._cur.execute("INSERT INTO zones (name, rdclass) VALUES (?, 'IN')", [temp])
-        self._new_zone_id = self._cur.lastrowid
-
-
-    def insert_axfr_record(self, rrsets):
-        '''insert zone records to sqlite3 database'''
-
-        try:
-            for name, ttl, rdclass, rdtype, rdata in rrsets:
-                sigtype = ''
-                if rdtype.lower() == 'rrsig':
-                    sigtype = rdata.split()[0]
-
-                if rdtype.lower() == 'nsec3' or sigtype.lower() == 'nsec3':
-                    hash = name.split('.')[0]
-                    self._cur.execute("""INSERT INTO nsec3
-                                   (zone_id, hash, owner, ttl, rdtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, hash, name, ttl, rdtype, rdata])
-                elif rdtype.lower() == 'rrsig':
-                    self._cur.execute("""INSERT INTO records
-                                   (zone_id, name, rname, ttl,
-                                    rdtype, sigtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, name, reverse_name(name), ttl,
-                                rdtype, sigtype, rdata])
-                else:
-                    self._cur.execute("""INSERT INTO records
-                                   (zone_id, name, rname, ttl, rdtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, name, reverse_name(name), ttl,
-                                rdtype, rdata])
-        except Exception as e:
-            fail = "Error while loading " + self._zone_name + ": " + e.args[0]
-            raise Sqlite3DSError(fail)
-
-
-    def finish_axfrin(self):
-        '''commit changes and close sqlite3 database'''
-
-        if self._old_zone_id:
-            self._cur.execute("DELETE FROM zones WHERE id=?", [self._old_zone_id])
-            self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
-            self._conn.commit()
-            self._cur.execute("DELETE FROM records WHERE zone_id=?", [self._old_zone_id])
-            self._cur.execute("DELETE FROM nsec3 WHERE zone_id=?", [self._old_zone_id])
-            self._conn.commit()
-        else:
-            self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
-            self._conn.commit()
-
-        self._cur.close()
-        self._conn.close()