#!@PYTHON@

# Copyright (C) 2010 Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# $Id$

import sys; sys.path.append ('@@PYTHONPATH@@')
import os
import signal
import isc
import asyncore
import struct
import threading
import socket
import random
from optparse import OptionParser, OptionValueError
from isc.config.ccsession import *
try:
    from bind10_dns import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrin process
    # must keep running, so we warn about it and move forward.
    sys.stderr.write('[b10-xfrin] failed to import DNS module: %s\n' % str(e))

# If B10_FROM_SOURCE is set in the environment, we use data files
# from a directory relative to that, otherwise we use the ones
# installed on the system
if "B10_FROM_SOURCE" in os.environ:
    SPECFILE_PATH = os.environ["B10_FROM_SOURCE"] + "/src/bin/xfrin"
else:
    PREFIX = "@prefix@"
    DATAROOTDIR = "@datarootdir@"
    SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX)
SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"

__version__ = 'BIND10'

# define xfrin rcode
XFRIN_OK = 0
XFRIN_FAIL = 1


def log_error(msg):
    '''Write msg to stderr, prefixed with the program tag.'''
    sys.stderr.write("[b10-xfrin] ")
    sys.stderr.write(str(msg))
    sys.stderr.write('\n')


class XfrinException(Exception):
    pass


class XfrinConnection(asyncore.dispatcher):
    '''Do xfrin in this class. '''

    def __init__(self,
                 zone_name, db_file, shutdown_event, master_addr,
                 port = 53, verbose = False, idle_timeout = 60):
        ''' zone_name: name of the zone to transfer.
            db_file: specify the data source file.
            shutdown_event: threading.Event; when set, an ongoing transfer
                is aborted between messages.
            master_addr, port: TCP endpoint of the master server.
            verbose: when true, log_msg() writes progress to stdout.
            idle_timeout: max idle time for read data from socket.
        '''

        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self._zone_name = zone_name
        self._db_file = db_file
        self._soa_rr_count = 0
        self._idle_timeout = idle_timeout
        self.setblocking(1)
        self._shutdown_event = shutdown_event
        self._verbose = verbose
        self._master_addr = master_addr
        self._port = port

    def connect_to_master(self):
        '''Connect to master in TCP. Return True on success, False (after
        logging the reason) on socket error.'''

        try:
            self.connect((self._master_addr, self._port))
            return True
        except socket.error as e:
            self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr,
                                                            self._port, str(e)))
            return False

    def _create_query(self, query_type):
        '''Create dns query message. '''

        msg = message(message_mode.RENDER)
        query_id = random.randint(1, 0xFFFF)
        # Remembered so _check_response_status() can match the reply.
        self._query_id = query_id
        msg.set_qid(query_id)
        msg.set_opcode(op_code.QUERY())
        msg.set_rcode(rcode.NOERROR())
        query_question = question(name(self._zone_name), rr_class.IN(), query_type)
        msg.add_question(query_question)
        return msg

    def _send_data(self, data):
        '''Send all of data, looping until every byte has been written.'''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = self.send(data[total_count:])
            total_count += count

    def _send_query(self, query_type):
        '''Send query message over TCP. '''

        msg = self._create_query(query_type)
        obuf = output_buffer(0)
        render = message_render(obuf)
        msg.to_wire(render)
        # DNS over TCP prefixes each message with its length as a 16-bit
        # unsigned integer in network byte order (RFC 1035, 4.2.2).
        header_len = struct.pack('!H', obuf.get_length())

        self._send_data(header_len)
        self._send_data(obuf.get_data())

    def _get_request_response(self, size):
        '''Read exactly `size` bytes from the master and return them.

        Each round runs one asyncore poll; handle_read() stores what it
        received in _recvd_data/_recvd_size and clears _recv_time_out.

        Raises XfrinException if no data arrives within idle_timeout, or if
        the master closes the connection before `size` bytes have arrived.
        '''
        recv_size = 0
        data = b''
        while recv_size < size:
            self._recv_time_out = True
            self._need_recv_size = size - recv_size
            asyncore.loop(self._idle_timeout, count = 1)
            if self._recv_time_out:
                raise XfrinException('receive data from socket time out.')

            if self._recvd_size == 0:
                # asyncore's recv() returns b'' once the peer has closed the
                # connection; without this check the closure would only
                # surface on the next round as a misleading time-out.
                raise XfrinException('connection closed by master')

            recv_size += self._recvd_size
            data += self._recvd_data

        return data

    def _check_soa_serial(self):
        ''' Query the master's SOA record before the transfer.

        Intended to compare the master's SOA serial with the local one and
        skip the transfer when the master is not newer. The comparison is
        not implemented yet: the SOA response is read and discarded and
        XFRIN_OK is always returned.
        '''

        self._send_query(rr_type.SOA())
        data_len = self._get_request_response(2)
        # The 2-byte prefix is the message length in network byte order;
        # it must be decoded with struct (int() cannot parse raw bytes).
        msg_len = struct.unpack('!H', data_len)[0]
        soa_reply = self._get_request_response(msg_len)

        #TODO, need select soa record from data source then compare the two
        #serial, current just return OK, since this function hasn't been used now
        return XFRIN_OK

    def do_xfrin(self, check_soa, ixfr_first = False):
        '''Do xfr by sending xfr request and parsing response.

        check_soa: when true, query the master's SOA before the AXFR.
        ixfr_first: currently unused.
        Returns XFRIN_OK on success, XFRIN_FAIL if the transfer failed.
        The connection is always closed before returning.
        '''

        # Assigned before the try block so the failure handlers below can
        # always use it, even when the SOA check itself raises.
        logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
        ret = XFRIN_OK
        try:
            if check_soa:
                ret = self._check_soa_serial()

            if ret == XFRIN_OK:
                self.log_msg(logstr + 'started')
                self._send_query(rr_type.AXFR())
                isc.auth.sqlite3_ds.load(self._db_file, self._zone_name,
                                         self._handle_xfrin_response)

                self.log_msg(logstr + 'succeeded')

        except XfrinException as e:
            self.log_msg(e)
            self.log_msg(logstr + 'failed')
            ret = XFRIN_FAIL
            #TODO, recover data source.
        except isc.auth.sqlite3_ds.Sqlite3DSError as e:
            self.log_msg(e)
            self.log_msg(logstr + 'failed')
            ret = XFRIN_FAIL
        finally:
            self.close()

        return ret

    def _check_response_status(self, msg):
        '''Check validation of xfr response; raise XfrinException if the
        message is not an acceptable reply to our query. '''

        #TODO, check more?
        msg_rcode = msg.get_rcode()
        if msg_rcode != rcode.NOERROR():
            raise XfrinException('error response: %s' % msg_rcode.to_text())

        if not msg.get_header_flag(message_flag.QR()):
            raise XfrinException('response is not a response ')

        if msg.get_qid() != self._query_id:
            raise XfrinException('bad query id')

        if msg.get_rr_count(section.ANSWER()) == 0:
            raise XfrinException('answer section is empty')

        if msg.get_rr_count(section.QUESTION()) > 1:
            raise XfrinException('query section count greater than 1')

    def _handle_answer_section(self, rrset_iter):
        '''Return a generator for the response in one tcp package to a zone
        transfer. Each item is (name, ttl, class, type, rdata) as text,
        except ttl which is an int.'''

        while not rrset_iter.is_last():
            rrset = rrset_iter.get_rrset()
            rrset_iter.next()
            rrset_name = rrset.get_name().to_text()
            rrset_ttl = int(rrset.get_ttl().to_text())
            rrset_class = rrset.get_class().to_text()
            rrset_type = rrset.get_type().to_text()

            rdata_iter = rrset.get_rdata_iterator()
            rdata_iter.first()
            while not rdata_iter.is_last():
                # Count the soa record count
                if rrset.get_type() == rr_type.SOA():
                    self._soa_rr_count += 1

                    # XXX: the current DNS message parser can't preserve the
                    # RR order or separate the beginning and ending SOA RRs.
                    # As a short term workaround, we simply ignore the second
                    # SOA, and ignore the erroneous case where the transfer
                    # session doesn't end with an SOA.
                    if (self._soa_rr_count == 2):
                        # Avoid inserting soa record twice
                        break

                rdata_text = rdata_iter.get_current().to_text()
                yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
                       rdata_text)
                rdata_iter.next()

    def _handle_xfrin_response(self):
        '''Return a generator for the response to a zone transfer. '''
        while True:
            data_len = self._get_request_response(2)
            # Length prefix is a 16-bit unsigned int in network byte order.
            msg_len = struct.unpack('!H', data_len)[0]
            recvdata = self._get_request_response(msg_len)
            msg = message(message_mode.PARSE)
            msg.from_wire(input_buffer(recvdata))
            self._check_response_status(msg)

            rrset_iter = section_iter(msg, section.ANSWER())
            for rr in self._handle_answer_section(rrset_iter):
                yield rr

            # The second SOA marks the end of the transfer.
            if self._soa_rr_count == 2:
                break

            if self._shutdown_event.is_set():
                raise XfrinException('xfrin is forced to stop')

    def handle_read(self):
        '''Read query's response from socket. '''

        self._recvd_data = self.recv(self._need_recv_size)
        self._recvd_size = len(self._recvd_data)
        self._recv_time_out = False

    def writable(self):
        '''Ignore the writable socket. '''

        return False

    def log_info(self, msg, type='info'):
        # Overwrite the log function, log nothing
        pass

    def log_msg(self, msg):
        '''Write msg to stdout when verbose mode is enabled.'''
        if self._verbose:
            sys.stdout.write('[b10-xfrin] ')
            sys.stdout.write(str(msg))
            sys.stdout.write('\n')


def process_xfrin(xfrin_recorder, zone_name, db_file,
                  shutdown_event, master_addr, port, check_soa, verbose):
    '''Run one zone transfer; used as the target of an xfrin thread.

    The zone is registered with xfrin_recorder for the duration of the
    transfer and is always unregistered on exit, even if the transfer
    raises, so later transfers of the same zone are not refused forever.
    '''
    port = int(port)
    xfrin_recorder.increment(zone_name)
    try:
        conn = XfrinConnection(zone_name, db_file, shutdown_event,
                               master_addr, port, verbose)
        if conn.connect_to_master():
            # do_xfrin() closes the connection itself.
            conn.do_xfrin(check_soa)
        else:
            # Release the socket (and its asyncore map entry) when we never
            # got as far as do_xfrin().
            conn.close()
    finally:
        xfrin_recorder.decrement(zone_name)


class XfrinRecorder():
    '''Lock-protected list of the zones whose transfer is in progress.'''

    def __init__(self):
        self._lock = threading.Lock()
        self._zones = []

    def increment(self, zone_name):
        '''Record that a transfer of zone_name has started.'''
        with self._lock:
            self._zones.append(zone_name)

    def decrement(self, zone_name):
        '''Record that a transfer of zone_name has finished (no-op if the
        zone is not recorded).'''
        with self._lock:
            if zone_name in self._zones:
                self._zones.remove(zone_name)

    def xfrin_in_progress(self, zone_name):
        '''Return True if a transfer of zone_name is recorded.'''
        with self._lock:
            return zone_name in self._zones

    def count(self):
        '''Return the number of recorded transfers.'''
        with self._lock:
            return len(self._zones)


class Xfrin():
    '''The xfrin daemon: handles commands from the command channel and
    starts one thread per zone transfer.'''

    def __init__(self, verbose = False):
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler)
        self._cc.start()
        self._max_transfers_in = 10
        self.recorder = XfrinRecorder()
        self._shutdown_event = threading.Event()
        self._verbose = verbose

    def config_handler(self, new_config):
        # TODO, process new config data
        return create_answer(0)

    def shutdown(self):
        ''' shutdown the xfrin process. the thread which is doing xfrin should be
        terminated.
        '''
        self._shutdown_event.set()
        current = threading.current_thread()
        for th in threading.enumerate():
            if th is current:
                continue
            th.join()

    def command_handler(self, command, args):
        '''Dispatch a command-channel command and return its answer.'''
        answer = create_answer(0)
        cmd = command
        try:
            if cmd == 'shutdown':
                self._shutdown_event.set()

            elif cmd == 'retransfer':
                zone_name, master, port, db_file = self._parse_cmd_params(args)
                ret = self.xfrin_start(zone_name, db_file, master, port, False)
                answer = create_answer(ret[0], ret[1])

            elif cmd == 'refresh':
                zone_name, master, port, db_file = self._parse_cmd_params(args)
                ret = self.xfrin_start(zone_name, db_file, master, port)
                answer = create_answer(ret[0], ret[1])

        except XfrinException as err:
            answer = create_answer(1, str(err))

        return answer

    def _parse_cmd_params(self, args):
        '''Extract (zone_name, master, port, db_file) from command args.

        Raises XfrinException for missing or malformed values so that
        command_handler() turns them into an error answer instead of the
        exception escaping the command loop.
        '''
        if not args:
            # A command sent without arguments; report the missing zone
            # name below rather than failing with AttributeError.
            args = {}

        zone_name = args.get('zone_name')
        if not zone_name:
            raise XfrinException('zone name should be provided')

        master = args.get('master')
        if not master:
            raise XfrinException('master address should be provided')

        check_addr(master)
        port = 53
        port_str = args.get('port')
        if port_str:
            try:
                port = int(port_str)
            except (TypeError, ValueError):
                raise XfrinException('requires a port number (0-65535)')
            check_port(port)

        db_file = args.get('db_file')
        if not db_file:
            #TODO, the db file path should be got in auth server's configuration
            db_file = '@@LOCALSTATEDIR@@/@PACKAGE@/zone.sqlite3'

        return (zone_name, master, port, db_file)

    def startup(self):
        '''Process commands until a shutdown is requested.'''
        while not self._shutdown_event.is_set():
            self._cc.check_command()

    def xfrin_start(self, zone_name, db_file, master_addr, port = 53,
                    check_soa = True):
        '''Start a transfer thread for zone_name; return (rcode, text).'''
        if "bind10_dns" not in sys.modules:
            return (1, "xfrin failed, can't load dns message python library: 'bind10_dns'")

        # check max_transfer_in, else return quota error
        if self.recorder.count() >= self._max_transfers_in:
            return (1, 'xfrin quota error')

        if self.recorder.xfrin_in_progress(zone_name):
            return (1, 'zone xfrin is in progress')

        xfrin_thread = threading.Thread(target = process_xfrin,
                                        args = (self.recorder,
                                                zone_name,
                                                db_file,
                                                self._shutdown_event,
                                                master_addr,
                                                port, check_soa, self._verbose))

        xfrin_thread.start()
        return (0, 'zone xfrin is started')


xfrind = None


def signal_handler(signal, frame):
    '''On SIGTERM/SIGINT, shut the daemon down and exit.'''
    if xfrind:
        xfrind.shutdown()
    sys.exit(0)


def set_signal_handler():
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)


def check_port(value):
    '''Raise XfrinException unless value is within 0-65535.'''
    if (value < 0) or (value > 65535):
        raise XfrinException('requires a port number (0-65535)')


def check_addr(ipstr):
    '''Raise XfrinException unless ipstr parses as an IPv4 address (or an
    IPv6 address, when it contains a colon).'''
    ip_family = socket.AF_INET
    if (ipstr.find(':') != -1):
        ip_family = socket.AF_INET6

    try:
        socket.inet_pton(ip_family, ipstr)
    except (socket.error, ValueError, TypeError):
        raise XfrinException("%s invalid ip address" % ipstr)


def set_cmd_options(parser):
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
            help="display more about what is going on")


if __name__ == '__main__':
    try:
        parser = OptionParser(version = __version__)
        set_cmd_options(parser)
        (options, args) = parser.parse_args()

        set_signal_handler()
        xfrind = Xfrin(verbose = options.verbose)
        xfrind.startup()
    except KeyboardInterrupt:
        log_error("exit b10-xfrin")
    except isc.cc.session.SessionError as e:
        log_error(str(e))
        log_error('Error happened! is the command channel daemon running?')
    except Exception as e:
        log_error(str(e))

    if xfrind:
        xfrind.shutdown()