#!@PYTHON@

# Copyright (C) 2010 Internet Systems Consortium.
# Copyright (C) 2010 CZ NIC
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import sys; sys.path.append ('@@PYTHONPATH@@')
import os
import signal
import isc
import asyncore
import struct
import threading
import socket
import random
from optparse import OptionParser, OptionValueError
from isc.config.ccsession import *
from isc.notify import notify_out
import isc.util.process
import isc.net.parse
try:
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrin process
    # must keep running, so we warn about it and move forward.
    sys.stderr.write('[b10-xfrin] failed to import DNS module: %s\n' % str(e))

isc.util.process.rename()

# If B10_FROM_BUILD is set in the environment, we use data files
# from a directory relative to that, otherwise we use the ones
# installed on the system
if "B10_FROM_BUILD" in os.environ:
    SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/xfrin"
    AUTH_SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/auth"
else:
    PREFIX = "@prefix@"
    DATAROOTDIR = "@datarootdir@"
    SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX)
    AUTH_SPECFILE_PATH = SPECFILE_PATH
SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + "/auth.spec"

XFROUT_MODULE_NAME = 'Xfrout'
ZONE_MANAGER_MODULE_NAME = 'Zonemgr'
REFRESH_FROM_ZONEMGR = 'refresh_from_zonemgr'
ZONE_XFRIN_FAILED = 'zone_xfrin_failed'
__version__ = 'BIND10'
# define xfrin rcode
XFRIN_OK = 0
XFRIN_FAIL = 1

DEFAULT_MASTER_PORT = '53'
DEFAULT_MASTER = '127.0.0.1'

def log_error(msg):
    '''Write msg to stderr with the b10-xfrin prefix.'''
    sys.stderr.write("[b10-xfrin] %s\n" % str(msg))

class XfrinException(Exception):
    pass

class XfrinConnection(asyncore.dispatcher):
    '''Do xfrin in this class. '''

    def __init__(self,
                 sock_map, zone_name, rrclass, db_file, shutdown_event,
                 master_addrinfo, verbose = False, idle_timeout = 60):
        '''Create the TCP socket used to transfer one zone.

        sock_map: asyncore socket map this connection registers itself in.
        zone_name: name of the zone to transfer (used in the query question).
        rrclass: RRClass of the zone.
        db_file: specify the data source file.
        shutdown_event: event checked between response messages; when it is
                        set the transfer is aborted.
        master_addrinfo: (family, socktype, sockaddr) tuple, as returned by
                         build_addr_info().
        verbose: when true, log_msg() writes progress to stdout.
        idle_timeout: max idle time for read data from socket.
        '''
        asyncore.dispatcher.__init__(self, map=sock_map)
        self.create_socket(master_addrinfo[0], master_addrinfo[1])
        self._zone_name = zone_name
        self._sock_map = sock_map
        self._rrclass = rrclass
        self._db_file = db_file
        self._soa_rr_count = 0
        self._idle_timeout = idle_timeout
        self.setblocking(1)
        self._shutdown_event = shutdown_event
        self._verbose = verbose
        self._master_address = master_addrinfo[2]

    def connect_to_master(self):
        '''Connect to master in TCP.

        Returns True on success; on socket.error the failure is logged and
        False is returned.'''
        try:
            self.connect(self._master_address)
            return True
        except socket.error as e:
            self.log_msg('Failed to connect:(%s), %s' % (self._master_address,
                                                         str(e)))
            return False

    def _create_query(self, query_type):
        '''Create dns query message.

        A random query ID is chosen and remembered in self._query_id so
        that _check_response_header() can match the response against it.
        '''
        msg = Message(Message.RENDER)
        query_id = random.randint(0, 0xFFFF)
        self._query_id = query_id
        msg.set_qid(query_id)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        query_question = Question(Name(self._zone_name), self._rrclass, query_type)
        msg.add_question(query_question)
        return msg

    def _send_data(self, data):
        '''Send all of data over the connection, looping over partial sends.

        Raises XfrinException if nothing could be sent. asyncore's
        dispatcher.send() reports a disconnected peer by returning 0 (after
        closing the socket) instead of raising, so without this check the
        loop would never terminate. (It also returns 0 on EWOULDBLOCK, which
        should not happen on this blocking socket.)
        '''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = self.send(data[total_count:])
            if count == 0:
                raise XfrinException('failed to send data to master, '
                                     'connection may have been closed')
            total_count += count

    def _send_query(self, query_type):
        '''Send query message over TCP, preceded by the 2-byte length
        prefix required for DNS over TCP. '''
        msg = self._create_query(query_type)
        render = MessageRenderer()
        msg.to_wire(render)
        # 16-bit length in network byte order.
        header_len = struct.pack('!H', render.get_length())

        self._send_data(header_len)
        self._send_data(render.get_data())

    def _asyncore_loop(self):
        '''
        This method is a trivial wrapper for asyncore.loop().  It's extracted from
        _get_request_response so that we can test the rest of the code without
        involving actual communication with a remote server.'''
        asyncore.loop(self._idle_timeout, map=self._sock_map, count=1)

    def _get_request_response(self, size):
        '''Read exactly size bytes from the master and return them.

        Raises XfrinException if no data arrives within idle_timeout.
        '''
        recv_size = 0
        data = b''
        while recv_size < size:
            # handle_read() clears this flag when data has been received.
            self._recv_time_out = True
            self._need_recv_size = size - recv_size
            self._asyncore_loop()
            if self._recv_time_out:
                raise XfrinException('receive data from socket time out.')

            recv_size += self._recvd_size
            data += self._recvd_data

        return data

    def _recv_response_data(self):
        '''Read one length-prefixed DNS message from the TCP stream and
        return its wire-format data (without the 2-byte length prefix).'''
        data_len = self._get_request_response(2)
        # The length prefix is a 16-bit integer in network byte order.
        msg_len = struct.unpack('!H', data_len)[0]
        return self._get_request_response(msg_len)

    def _check_soa_serial(self):
        '''Query the master for the zone's SOA and validate the response.

        The intent is to compare the master's SOA serial with the local one
        and skip the transfer when the master's is not newer, but that
        comparison is not implemented yet: after minimal validation of the
        response header this always returns XFRIN_OK.
        Raises XfrinException if the response times out or fails validation.
        '''
        self._send_query(RRType("SOA"))
        soa_response = self._recv_response_data()
        msg = Message(Message.PARSE)
        msg.from_wire(soa_response)

        # perform some minimal level validation.  It's an open issue how
        # strict we should be (see the comment in _check_response_header())
        self._check_response_header(msg)

        # TODO, need select soa record from data source then compare the two
        # serial, current just return OK, since this function hasn't been used
        # now.
        return XFRIN_OK

    def do_xfrin(self, check_soa, ixfr_first = False):
        '''Do xfr by sending xfr request and parsing response.

        check_soa: when true, send an SOA query and validate its response
                   before the AXFR.
        ixfr_first: currently unused.

        Returns XFRIN_OK on success, or XFRIN_FAIL if the transfer failed
        with one of the handled errors. The connection is closed before
        returning (or before any other exception propagates).
        '''
        try:
            ret = XFRIN_OK
            if check_soa:
                logstr = 'SOA check for \'%s\' ' % self._zone_name
                ret = self._check_soa_serial()

            logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
            if ret == XFRIN_OK:
                self.log_msg(logstr + 'started')
                # TODO: .AXFR() RRType.AXFR()
                self._send_query(RRType(252))
                isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                                            self._handle_xfrin_response)

                self.log_msg(logstr + 'succeeded')

        except (XfrinException, isc.datasrc.sqlite3_ds.Sqlite3DSError,
                UserWarning) as e:
            # TODO, recover data source (after an XfrinException).
            # XXX: UserWarning is an exception from our C++ library via the
            # Boost.Python binding.  It would be better to have more
            # specific exceptions, but at this moment this is the finest
            # granularity.
            self.log_msg(e)
            self.log_msg(logstr + 'failed')
            ret = XFRIN_FAIL
        finally:
            self.close()

        return ret

    def _check_response_header(self, msg):
        '''Perform minimal validation on responses'''

        # It's not clear how strict we should be about response validation.
        # BIND 9 ignores some cases where it would normally be considered a
        # bogus response.  For example, it accepts a response even if its
        # opcode doesn't match that of the corresponding request.
        # According to an original developer of BIND 9 some of the missing
        # checks are deliberate to be kind to old implementations that would
        # cause interoperability trouble with stricter checks.

        msg_rcode = msg.get_rcode()
        if msg_rcode != Rcode.NOERROR():
            raise XfrinException('error response: %s' % msg_rcode.to_text())

        if not msg.get_header_flag(Message.HEADERFLAG_QR):
            raise XfrinException('response is not a response ')

        if msg.get_qid() != self._query_id:
            raise XfrinException('bad query id')

    def _check_response_status(self, msg):
        '''Check validation of xfr response.

        In addition to the header checks, the answer section must be
        non-empty and the question section must hold at most one entry.
        '''

        self._check_response_header(msg)

        if msg.get_rr_count(Message.SECTION_ANSWER) == 0:
            raise XfrinException('answer section is empty')

        if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
            raise XfrinException('query section count greater than 1')

    def _handle_answer_section(self, answer_section):
        '''Return a generator for the response in one TCP message to a
        zone transfer.

        Yields (name, ttl, class, type, rdata) tuples, all text except the
        integer TTL, and counts SOA records in self._soa_rr_count.
        '''

        for rrset in answer_section:
            rrset_name = rrset.get_name().to_text()
            rrset_ttl = int(rrset.get_ttl().to_text())
            rrset_class = rrset.get_class().to_text()
            rrset_type = rrset.get_type().to_text()
            is_soa = rrset.get_type() == RRType("SOA")

            for rdata in rrset.get_rdata():
                # Count the soa record count
                if is_soa:
                    self._soa_rr_count += 1

                    # XXX: the current DNS message parser can't preserve the
                    # RR order or separate the beginning and ending SOA RRs.
                    # As a short term workaround, we simply ignore the second
                    # SOA, and ignore the erroneous case where the transfer
                    # session doesn't end with an SOA.
                    if self._soa_rr_count == 2:
                        # Avoid inserting soa record twice
                        break

                rdata_text = rdata.to_text()
                yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
                       rdata_text)

    def _handle_xfrin_response(self):
        '''Return a generator for the response to a zone transfer.

        Reads response messages until the second SOA has been seen.
        Raises XfrinException if a response is invalid, a read times out,
        or the shutdown event is set between messages.
        '''
        while True:
            recvdata = self._recv_response_data()
            msg = Message(Message.PARSE)
            msg.from_wire(recvdata)
            self._check_response_status(msg)

            answer_section = msg.get_section(Message.SECTION_ANSWER)
            for rr in self._handle_answer_section(answer_section):
                yield rr

            if self._soa_rr_count == 2:
                break

            if self._shutdown_event.is_set():
                raise XfrinException('xfrin is forced to stop')

    def handle_read(self):
        '''Read query's response from socket.

        asyncore callback: stores the data read and clears the timeout flag
        so that _get_request_response() knows data arrived. '''
        self._recvd_data = self.recv(self._need_recv_size)
        self._recvd_size = len(self._recvd_data)
        self._recv_time_out = False

    def writable(self):
        '''Ignore the writable socket (only poll for readability). '''
        return False

    def log_info(self, msg, type='info'):
        # Overwrite the log function, log nothing
        pass

    def log_msg(self, msg):
        '''Write msg to stdout when running in verbose mode.'''
        if self._verbose:
            sys.stdout.write('[b10-xfrin] %s\n' % str(msg))


def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
                  shutdown_event, master_addrinfo, check_soa, verbose):
    '''Transfer one zone and publish the result; runs in its own thread.

    The zone is recorded in xfrin_recorder for the duration of the call and
    is always removed again, even if an exception escapes.
    '''
    xfrin_recorder.increment(zone_name)
    try:
        sock_map = {}
        conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
                               shutdown_event, master_addrinfo, verbose)
        ret = XFRIN_FAIL
        if conn.connect_to_master():
            # do_xfrin() closes the connection itself.
            ret = conn.do_xfrin(check_soa)
        else:
            # Nobody else closes the socket on this path.
            conn.close()

        # Publish the zone transfer result news, so zonemgr can reset the
        # zone timer, and xfrout can notify the zone's slaves if the result
        # is success.
        server.publish_xfrin_news(zone_name, rrclass, ret)
    finally:
        xfrin_recorder.decrement(zone_name)


class XfrinRecorder:
    '''Thread-safe record of the zones currently being transferred.'''

    def __init__(self):
        self._lock = threading.Lock()
        self._zones = []

    def increment(self, zone_name):
        with self._lock:
            self._zones.append(zone_name)

    def decrement(self, zone_name):
        with self._lock:
            if zone_name in self._zones:
                self._zones.remove(zone_name)

    def xfrin_in_progress(self, zone_name):
        with self._lock:
            return zone_name in self._zones

    def count(self):
        with self._lock:
            return len(self._zones)


class Xfrin:
    def __init__(self, verbose = False):
        self._max_transfers_in = 10
        #TODO, this is the temp way to set the zone's master.
        self._master_addr = DEFAULT_MASTER
        self._master_port = DEFAULT_MASTER_PORT
        self._cc_setup()
        self.recorder = XfrinRecorder()
        self._shutdown_event = threading.Event()
        self._verbose = verbose

    def _cc_setup(self):
        '''This method is used only as part of initialization, but is
        implemented separately for convenience of unit tests; by letting
        the test code override this method we can test most of this class
        without requiring a command channel.'''
        # Create one session for sending command to other modules, because the
        # listening session will block the send operation.
        self._send_cc_session = isc.cc.Session()
        self._module_cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler)
        self._module_cc.start()
        config_data = self._module_cc.get_full_config()
        self._max_transfers_in = config_data.get("transfers_in")
        self._master_addr = config_data.get('master_addr') or self._master_addr
        self._master_port = config_data.get('master_port') or self._master_port

    def _cc_check_command(self):
        '''This is a straightforward wrapper for cc.check_command,
        but provided as a separate method for the convenience
        of unit tests.'''
        self._module_cc.check_command(False)

    def config_handler(self, new_config):
        '''Apply a configuration update; returns a config answer.

        master_addr/master_port are validated together and only stored if
        both parse.'''
        self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
        if ('master_addr' in new_config) or ('master_port' in new_config):
            # User should change the port and address together.
            try:
                addr = new_config.get('master_addr') or self._master_addr
                port = new_config.get('master_port') or self._master_port
                isc.net.parse.addr_parse(addr)
                isc.net.parse.port_parse(port)
                self._master_addr = addr
                self._master_port = port
            except ValueError:
                errmsg = "bad format for zone's master: " + str(new_config)
                log_error(errmsg)
                return create_answer(1, errmsg)

        return create_answer(0)

    def shutdown(self):
        ''' shutdown the xfrin process. the thread which is doing xfrin should be
        terminated.
        '''
        self._shutdown_event.set()
        main_thread = threading.current_thread()
        for th in threading.enumerate():
            if th is main_thread:
                continue
            th.join()

    def command_handler(self, command, args):
        '''Handle a command from the command channel; returns an answer.'''
        answer = create_answer(0)
        try:
            if command == 'shutdown':
                self._shutdown_event.set()
            elif command == 'notify' or command == REFRESH_FROM_ZONEMGR:
                # Xfrin receives the refresh/notify command from zone manager.
                # notify command maybe has the parameters which
                # specify the notifyfrom address and port, according the RFC1996, zone
                # transfer should starts first from the notifyfrom, but now, let 'TODO' it.
                (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                master_addr = build_addr_info(self._master_addr, self._master_port)
                ret = self.xfrin_start(zone_name,
                                       rrclass,
                                       self._get_db_file(),
                                       master_addr,
                                       True)
                answer = create_answer(ret[0], ret[1])

            elif command == 'retransfer' or command == 'refresh':
                # Xfrin receives the retransfer/refresh from cmdctl(sent by bindctl).
                # If the command has specified master address, do transfer from the
                # master address, or else do transfer from the configured masters.
                (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                master_addr = self._parse_master_and_port(args)
                db_file = args.get('db_file') or self._get_db_file()
                ret = self.xfrin_start(zone_name,
                                       rrclass,
                                       db_file,
                                       master_addr,
                                       (False if command == 'retransfer' else True))
                answer = create_answer(ret[0], ret[1])

            else:
                answer = create_answer(1, 'unknown command: ' + command)
        except XfrinException as err:
            log_error('error happened for command: %s, %s' % (command, str(err)) )
            answer = create_answer(1, str(err))
        return answer

    def _parse_zone_name_and_class(self, args):
        '''Return (zone_name, RRClass) from command args; the class
        defaults to IN. Raises XfrinException on missing name or bad class.'''
        zone_name = args.get('zone_name')
        if not zone_name:
            raise XfrinException('zone name should be provided')

        rrclass = args.get('zone_class')
        if not rrclass:
            rrclass = RRClass.IN()
        else:
            try:
                rrclass = RRClass(rrclass)
            except InvalidRRClass as e:
                raise XfrinException('invalid RRClass: ' + rrclass)

        return zone_name, rrclass

    def _parse_master_and_port(self, args):
        '''Build master addrinfo from command args, falling back to the
        configured master address/port.'''
        port = args.get('port') or self._master_port
        master = args.get('master') or self._master_addr
        return build_addr_info(master, port)

    def _get_db_file(self):
        #TODO, the db file path should be got in auth server's configuration
        # if we need access to this configuration more often, we
        # should add it on start, and not remove it here
        # (or, if we have writable ds, we might not need this in
        # the first place)
        self._module_cc.add_remote_config(AUTH_SPECFILE_LOCATION)
        db_file, is_default = self._module_cc.get_remote_config_value("Auth", "database_file")
        if is_default and "B10_FROM_BUILD" in os.environ:
            # this too should be unnecessary, but currently the
            # 'from build' override isn't stored in the config
            # (and we don't have writable datasources yet)
            db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
        self._module_cc.remove_remote_config(AUTH_SPECFILE_LOCATION)
        return db_file

    def _send_cc_command(self, msg, module_name):
        '''Send msg to module_name over the sending session and wait for the
        answer. A SessionTimeout while waiting is ignored; socket.error is
        left to the caller.'''
        seq = self._send_cc_session.group_sendmsg(msg, module_name)
        try:
            answer, env = self._send_cc_session.group_recvmsg(False, seq)
        except isc.cc.session.SessionTimeout:
            pass        # for now we just ignore the failure

    def publish_xfrin_news(self, zone_name, zone_class, xfr_result):
        '''Send command to xfrout/zone manager module.
        If xfrin has finished successfully for one zone, tell the good
        news(command: zone_new_data_ready) to zone manager and xfrout.
        if xfrin failed, just tell the bad news to zone manager, so that
        it can reset the refresh timer for that zone. '''
        param = {'zone_name': zone_name, 'zone_class': zone_class.to_text()}
        if xfr_result == XFRIN_OK:
            msg = create_command(notify_out.ZONE_NEW_DATA_READY_CMD, param)
            # catch the exception, in case msgq has been killed.
            try:
                self._send_cc_command(msg, XFROUT_MODULE_NAME)
                self._send_cc_command(msg, ZONE_MANAGER_MODULE_NAME)
            except socket.error as err:
                log_error("Fail to send message to %s and %s, msgq may has been killed"
                          % (XFROUT_MODULE_NAME, ZONE_MANAGER_MODULE_NAME))
        else:
            msg = create_command(ZONE_XFRIN_FAILED, param)
            # catch the exception, in case msgq has been killed.
            try:
                self._send_cc_command(msg, ZONE_MANAGER_MODULE_NAME)
            except socket.error as err:
                log_error("Fail to send message to %s, msgq may has been killed"
                          % ZONE_MANAGER_MODULE_NAME)

    def startup(self):
        '''Process commands until the shutdown event is set.'''
        while not self._shutdown_event.is_set():
            self._cc_check_command()

    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
                    check_soa = True):
        '''Start a transfer thread for the zone; returns (rcode, message).'''
        if "pydnspp" not in sys.modules:
            return (1, "xfrin failed, can't load dns message python library: 'pydnspp'")

        # check max_transfer_in, else return quota error
        # NOTE(review): the zone is only added to the recorder inside the
        # new thread (process_xfrin), so a request arriving before that
        # thread runs can pass both checks below.
        if self.recorder.count() >= self._max_transfers_in:
            return (1, 'xfrin quota error')

        if self.recorder.xfrin_in_progress(zone_name):
            return (1, 'zone xfrin is in progress')

        xfrin_thread = threading.Thread(target = process_xfrin,
                                        args = (self,
                                                self.recorder,
                                                zone_name, rrclass,
                                                db_file,
                                                self._shutdown_event,
                                                master_addrinfo, check_soa,
                                                self._verbose))

        xfrin_thread.start()
        return (0, 'zone xfrin is started')

xfrind = None

def signal_handler(signal, frame):
    if xfrind:
        xfrind.shutdown()
    sys.exit(0)

def set_signal_handler():
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

def build_addr_info(addrstr, portstr):
    """
    Return tuple (family, socktype, sockaddr) for given address and port.
    IPv4 and IPv6 are the only supported addresses now, so sockaddr will be
    (address, port). The socktype is socket.SOCK_STREAM for now.
    """
    try:
        port = isc.net.parse.port_parse(portstr)
        addr = isc.net.parse.addr_parse(addrstr)
        return (addr.family, socket.SOCK_STREAM, (addrstr, port))
    except ValueError as err:
        raise XfrinException("failed to resolve master address/port=%s/%s: %s" %
                             (addrstr, portstr, str(err)))

def set_cmd_options(parser):
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
            help="display more about what is going on")

def main(xfrin_class, use_signal = True):
    """The main loop of the Xfrin daemon.

    @param xfrin_class: A class of the Xfrin object.  This is normally Xfrin,
    but can be a subclass of it for customization.
    @param use_signal: True if this process should catch signals.  This is
    normally True, but may be disabled when this function is called in a
    testing context."""
    global xfrind

    try:
        parser = OptionParser(version = __version__)
        set_cmd_options(parser)
        (options, args) = parser.parse_args()

        if use_signal:
            set_signal_handler()
        xfrind = xfrin_class(verbose = options.verbose)
        xfrind.startup()
    except KeyboardInterrupt:
        log_error("exit b10-xfrin")
    except isc.cc.session.SessionError as e:
        log_error(str(e))
        log_error('Error happened! is the command channel daemon running?')
    except Exception as e:
        log_error(str(e))

    if xfrind:
        xfrind.shutdown()

if __name__ == '__main__':
    main(Xfrin)