Browse Source

cleanup: eliminated the need for isc.auth.sqlite3_ds.AXFRInDB by directly calling sqlite3_ds.load() from the xfrin module.

Other cleanups:
- cosmetic: removed redundant blank lines and white spaces after EOL
- grammar fix in comments
- catch Sqlite3DSError explicitly (but I suspect the exception handling
  in the xfrin module is naive overall, which should be fixed)


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1684 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
613b2880fb
2 changed files with 24 additions and 105 deletions
  1. 17 26
      src/bin/xfrin/xfrin.py.in
  2. 7 79
      src/lib/python/isc/auth/sqlite3_ds.py

+ 17 - 26
src/bin/xfrin/xfrin.py.in

@@ -85,7 +85,6 @@ class XfrinConnection(asyncore.dispatcher):
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self._zone_name = zone_name
         self._db_file = db_file
-        self._axfrin_db = isc.auth.sqlite3_ds.AXFRInDB(self._db_file, self._zone_name)
         self._soa_rr_count = 0
         self._idle_timeout = idle_timeout
         self.setblocking(1)
@@ -112,7 +111,6 @@ class XfrinConnection(asyncore.dispatcher):
             count = self.send(data[total_count:])
             total_count += count
 
-
     def _send_query(self, query_type):
         '''Send query message over TCP. '''
         msg = self._create_query(query_type)
@@ -123,7 +121,6 @@ class XfrinConnection(asyncore.dispatcher):
 
         self._send_data(header_len)
         self._send_data(obuf.get_data())
-
     
     def _get_request_response(self, size):
         recv_size = 0
@@ -140,14 +137,12 @@ class XfrinConnection(asyncore.dispatcher):
 
         return data
 
-
     def handle_read(self):
         '''Read query's response from socket. '''
         self._recvd_data = self.recv(self._need_recv_size)
         self._recvd_size = len(self._recvd_data)
         self._recv_time_out = False
 
-
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
@@ -169,9 +164,9 @@ class XfrinConnection(asyncore.dispatcher):
             
             self.log_msg('transfer of \'%s\': AXFR started' % self._zone_name)
             if ret == XFRIN_OK:    
-                self._axfrin_db.prepare_axfrin()
                 self._send_query(rr_type.AXFR())
-                ret = self._handle_xfrin_response()
+                isc.auth.sqlite3_ds.load(self._db_file, self._zone_name,
+                                         self._handle_xfrin_response)
 
             endmsg = 'succeeded' if ret == XFRIN_OK else 'failed'
             self.log_msg('transfer of \'%s\' AXFR %s' % (self._zone_name,
@@ -179,11 +174,11 @@ class XfrinConnection(asyncore.dispatcher):
         except XfrinException as e:
             self.log_msg(e)
             self.log_msg('Error happened during xfrin!')
-            #TODO, recover data source. 
+            #TODO, recover data source.
+        except isc.auth.sqlite3_ds.Sqlite3DSError as e:
+            self.log_msg(e)
         finally:
            self.close()
-           if ret == XFRIN_OK:
-               self._axfrin_db.finish_axfrin()
 
         return ret
     
@@ -204,7 +199,6 @@ class XfrinConnection(asyncore.dispatcher):
         if msg.get_rr_count(section.QUESTION()) > 1:
             raise XfrinException('query section count greater than 1')
 
-
     def _handle_answer_section(self, rrset_iter):
         while not rrset_iter.is_last():
             rrset = rrset_iter.get_rrset()
@@ -231,11 +225,10 @@ class XfrinConnection(asyncore.dispatcher):
                         break
 
                 rdata_text = rdata_iter.get_current().to_text()
-                rr_data = (rrset_name, rrset_ttl, rrset_class, rrset_type, rdata_text)
-                self._axfrin_db.insert_axfr_record([rr_data]) 
+                yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
+                       rdata_text)
                 rdata_iter.next()
 
-
     def _handle_xfrin_response(self):
         while True:
             data_len = self._get_request_response(2)
@@ -246,23 +239,21 @@ class XfrinConnection(asyncore.dispatcher):
             self._check_response_status(msg)
             
             rrset_iter = section_iter(msg, section.ANSWER())
-            self._handle_answer_section(rrset_iter)
+            for rr in self._handle_answer_section(rrset_iter):
+                yield rr
+
             if self._soa_rr_count == 2:
-                return XFRIN_OK
+                break
             
             if self._shutdown_event.is_set():
                 #Check if xfrin process is shutdown.
                 #TODO, xfrin may be blocked in one loop. 
                 raise XfrinException('xfrin is forced to stop')
 
-        return XFRIN_OK
-
-
     def writable(self):
         '''Ignore the writable socket. '''
         return False
 
-
     def log_info(self, msg, type='info'):
         # Overwrite the log function, log nothing
         pass
@@ -276,7 +267,7 @@ class XfrinConnection(asyncore.dispatcher):
 
 def process_xfrin(xfrin_recorder, zone_name, db_file, 
                   shutdown_event, master_addr, port, check_soa, verbose):
-    xfrin_recorder.increment(zone_name)
+    xfrin_recorder.increment(name)
     try:
         conn = XfrinConnection(zone_name, db_file, shutdown_event, 
                            master_addr, int(port), check_soa, verbose)
@@ -405,12 +396,12 @@ class Xfrin():
         if self.recorder.xfrin_in_progress(zone_name):
             return (1, 'zone xfrin is in progress')
 
-        xfrin_thread = threading.Thread(target = process_xfrin, 
-                                        args = (self.recorder, 
-                                                zone_name, 
-                                                db_file, 
+        xfrin_thread = threading.Thread(target = process_xfrin,
+                                        args = (self.recorder,
+                                                zone_name,
+                                                db_file,
                                                 self._shutdown_event,
-                                                master_addr, 
+                                                master_addr,
                                                 port, check_soa, self._verbose))
                                                 
         xfrin_thread.start()

+ 7 - 79
src/lib/python/isc/auth/sqlite3_ds.py

@@ -148,17 +148,20 @@ def reverse_name(name):
         new.pop(0)
     return '.'.join(new)+'.'
 
-
 #########################################################################
 # load:
 #   load a zone into the SQL database.
 # input:
 #   dbfile: the sqlite3 database fileanme
 #   zone: the zone origin
-#   reader: an generator function producing an iterable set of
+#   reader: a generator function producing an iterable set of
 #           name/ttl/class/rrtype/rdata-text tuples
 #########################################################################
 def load(dbfile, zone, reader):
+    # if the zone name doesn't contain the trailing dot, automatically add it.
+    if zone[-1] != '.':
+        zone += '.'
+
     conn, cur = open(dbfile)
     old_zone_id = get_zoneid(zone, cur)
 
@@ -184,13 +187,13 @@ def load(dbfile, zone, reader):
                                 rdtype, sigtype, rdata)
                                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                             [new_zone_id, name, reverse_name(name), ttl,
-                            rdtype, sigtype, rdata])
+                             rdtype, sigtype, rdata])
             else:
                 cur.execute("""INSERT INTO records
                                (zone_id, name, rname, ttl, rdtype, rdata)
                                VALUES (?, ?, ?, ?, ?, ?)""",
                             [new_zone_id, name, reverse_name(name), ttl,
-                            rdtype, rdata])
+                             rdtype, rdata])
     except Exception as e:
         fail = "Error while loading " + zone + ": " + e.args[0]
         raise Sqlite3DSError(fail)
@@ -208,78 +211,3 @@ def load(dbfile, zone, reader):
 
     cur.close()
     conn.close()
-
-
-#########################################################################
-# temp sqlite3 datasource backend for axfr in. The code should be refectored 
-# later.
-#########################################################################
-class AXFRInDB:
-    def __init__(self, dbfile, zone_name):
-        self._dbfile = dbfile
-        self._zone_name = zone_name
-        # if the zone name doesn't contain the trailing dot, automatically
-        # add it.
-        if self._zone_name[-1] != '.':
-            self._zone_name += '.'
-        self._old_zone_id = None
-        self._new_zone_id = None
-
-    def prepare_axfrin(self):
-        self._conn, self._cur = open(self._dbfile)
-        self._old_zone_id = get_zoneid(self._zone_name, self._cur)
-
-        temp = str(random.randrange(100000))
-        self._cur.execute("INSERT INTO zones (name, rdclass) VALUES (?, 'IN')", [temp])
-        self._new_zone_id = self._cur.lastrowid
-
-
-    def insert_axfr_record(self, rrsets):
-        '''insert zone records to sqlite3 database'''
-
-        try:
-            for name, ttl, rdclass, rdtype, rdata in rrsets:
-                sigtype = ''
-                if rdtype.lower() == 'rrsig':
-                    sigtype = rdata.split()[0]
-
-                if rdtype.lower() == 'nsec3' or sigtype.lower() == 'nsec3':
-                    hash = name.split('.')[0]
-                    self._cur.execute("""INSERT INTO nsec3
-                                   (zone_id, hash, owner, ttl, rdtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, hash, name, ttl, rdtype, rdata])
-                elif rdtype.lower() == 'rrsig':
-                    self._cur.execute("""INSERT INTO records
-                                   (zone_id, name, rname, ttl,
-                                    rdtype, sigtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, name, reverse_name(name), ttl,
-                                rdtype, sigtype, rdata])
-                else:
-                    self._cur.execute("""INSERT INTO records
-                                   (zone_id, name, rname, ttl, rdtype, rdata)
-                                   VALUES (?, ?, ?, ?, ?, ?)""",
-                                [self._new_zone_id, name, reverse_name(name), ttl,
-                                rdtype, rdata])
-        except Exception as e:
-            fail = "Error while loading " + self._zone_name + ": " + e.args[0]
-            raise Sqlite3DSError(fail)
-
-
-    def finish_axfrin(self):
-        '''commit changes and close sqlite3 database'''
-
-        if self._old_zone_id:
-            self._cur.execute("DELETE FROM zones WHERE id=?", [self._old_zone_id])
-            self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
-            self._conn.commit()
-            self._cur.execute("DELETE FROM records WHERE zone_id=?", [self._old_zone_id])
-            self._cur.execute("DELETE FROM nsec3 WHERE zone_id=?", [self._old_zone_id])
-            self._conn.commit()
-        else:
-            self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
-            self._conn.commit()
-
-        self._cur.close()
-        self._conn.close()