#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# The search path must be extended before any 'isc' package is imported.
import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
from isc.datasrc import sqlite3_ds
from socketserver import *
import os
from isc.config.ccsession import *
from isc.cc import SessionError, SessionTimeout
from isc.notify import notify_out
import isc.util.process
import socket
import select
import errno
from optparse import OptionParser, OptionValueError
from isc.util import socketserver_mixin

from xfrout_messages import *

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

try:
    from libutil_io_python import *
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
    # (The logger object is bound to 'logger'; there is no name 'log'.)
    logger.error(XFROUT_IMPORT, str(e))

from isc.acl.acl import ACCEPT, REJECT, DROP
from isc.acl.dns import REQUEST_LOADER

isc.util.process.rename()


def init_paths():
    """Set the module-level path globals used by xfrout.

    Defines SPECFILE_PATH, AUTH_SPECFILE_PATH and UNIX_SOCKET_FILE.

    When the B10_FROM_BUILD environment variable is set, the spec file
    directories are taken from that build tree, and the socket file is
    placed under B10_FROM_SOURCE_LOCALSTATEDIR if that variable is set,
    otherwise directly under B10_FROM_BUILD.

    Otherwise the configure-time substitutions are used, and
    BIND10_XFROUT_SOCKET_FILE, if set, overrides the socket file location.
    """
    global SPECFILE_PATH
    global AUTH_SPECFILE_PATH
    global UNIX_SOCKET_FILE
    if "B10_FROM_BUILD" in os.environ:
        SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/xfrout"
        AUTH_SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/auth"
        if "B10_FROM_SOURCE_LOCALSTATEDIR" in os.environ:
            UNIX_SOCKET_FILE = os.environ["B10_FROM_SOURCE_LOCALSTATEDIR"] + \
                "/auth_xfrout_conn"
        else:
            UNIX_SOCKET_FILE = os.environ["B10_FROM_BUILD"] + "/auth_xfrout_conn"
    else:
        PREFIX = "@prefix@"
        DATAROOTDIR = "@datarootdir@"
        # The substituted datadir may still contain literal ${datarootdir}
        # and ${prefix} references; expand them by hand.
        SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX)
        AUTH_SPECFILE_PATH = SPECFILE_PATH
        if "BIND10_XFROUT_SOCKET_FILE" in os.environ:
            UNIX_SOCKET_FILE = os.environ["BIND10_XFROUT_SOCKET_FILE"]
        else:
            UNIX_SOCKET_FILE = "@@LOCALSTATEDIR@@/auth_xfrout_conn"


init_paths()

SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
MAX_TRANSFERS_OUT = 10
VERBOSE_MODE = False
# tsig sign every N axfr packets.
TSIG_SIGN_EVERY_NTH = 96

# Upper bound, in bytes, on a single rendered response message; it is also
# passed to the renderer as its length limit.
XFROUT_MAX_MESSAGE_SIZE = 65535


def get_rrset_len(rrset):
    """Returns the wire length of the given RRset"""
    wire = bytearray()
    rrset.to_wire(wire)
    return len(wire)


class XfroutSession():
    '''Process one xfrout request handed over by the auth server.

    Constructing an instance immediately processes the request: __init__()
    stores its arguments and then calls handle(), which closes sock_fd
    when it is done.
    '''

    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote, acl):
        '''Store the request context and process the request.

        sock_fd: file descriptor the response is written to.
        request_data: wire-format query message.
        server: object providing get_db_file(), the transfer counter
            methods and the _shutdown_event attribute used below.
        tsig_key_ring: key ring used to build a TSIG context.
        remote: remote address as returned by getpeername(); elements 0
            and 1 are used in log messages.
        acl: ACL object whose execute() yields ACCEPT, REJECT or DROP.
        '''
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0
        self._remote = remote
        self._acl = acl
        self.handle()

    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Return a TSIGContext for the key name and algorithm of tsig_record.'''
        return TSIGContext(tsig_record.get_name(),
                           tsig_record.get_rdata().get_algorithm(),
                           tsig_key_ring)

    def handle(self):
        ''' Handle a xfrout query, send xfrout response '''
        try:
            self.dns_xfrout_start(self._sock_fd, self._request_data)
            #TODO, avoid catching all exceptions
        except Exception as e:
            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
        finally:
            # Release the descriptor no matter how processing ended.
            os.close(self._sock_fd)

    def _check_request_tsig(self, msg, request_data):
        ''' If request has a tsig record, perform tsig related checks

        Returns Rcode.NOTAUTH() when verification fails, otherwise
        Rcode.NOERROR(). As a side effect the TSIG context and the length
        of the TSIG record are remembered for signing the response.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is not None:
            self._tsig_len = tsig_record.get_length()
            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
            if tsig_error != TSIGError.NOERROR:
                return Rcode.NOTAUTH()

        return Rcode.NOERROR()

    def _parse_query_message(self, mdata):
        ''' parse query message to [socket,message]

        Returns a (rcode, message) pair:
          (None, None)           the query was dropped by the ACL
          (REFUSED, msg)         the query was rejected by the ACL
          (FORMERR, None)        parsing or checking raised an exception
          (rcode, msg)           otherwise, rcode being the TSIG check result
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)

            # TSIG related checks
            rcode = self._check_request_tsig(msg, mdata)

            if rcode == Rcode.NOERROR():
                # ACL checks
                acl_result = self._acl.execute(
                    isc.acl.dns.RequestContext(self._remote))
                if acl_result == DROP:
                    logger.info(XFROUT_QUERY_DROPPED,
                                self._get_query_zone_name(msg),
                                self._get_query_zone_class(msg),
                                self._remote[0], self._remote[1])
                    return None, None
                elif acl_result == REJECT:
                    logger.info(XFROUT_QUERY_REJECTED,
                                self._get_query_zone_name(msg),
                                self._get_query_zone_class(msg),
                                self._remote[0], self._remote[1])
                    return Rcode.REFUSED(), msg

        except Exception as err:
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR(), None

        return rcode, msg

    def _get_query_zone_name(self, msg):
        '''Return the name of the first question of msg as text.'''
        question = msg.get_question()[0]
        return question.get_name().to_text()

    def _get_query_zone_class(self, msg):
        '''Return the class of the first question of msg as text.'''
        question = msg.get_question()[0]
        return question.get_class().to_text()

    def _send_data(self, sock_fd, data):
        '''Write all of data to sock_fd, retrying after partial writes.'''
        size = len(data)
        total_count = 0
        while total_count < size:
            total_count += os.write(sock_fd, data[total_count:])

    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg and send it, preceded by its 2-byte length.

        When tsig_ctx is not None it is passed to the renderer call so the
        message is rendered with that TSIG context.
        '''
        render = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

        # Length prefix in network byte order, the same format that
        # UnixSockServer._receive_query_message() unpacks.
        header_len = struct.pack('!H', render.get_length())
        self._send_data(sock_fd, header_len)
        self._send_data(sock_fd, render.get_data())

    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
        '''Turn msg into a response carrying rcode_ and send it.'''
        if not msg:
            return # query message is invalid. send nothing back.

        msg.make_response()
        msg.set_rcode(rcode_)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _zone_has_soa(self, zone):
        '''Judge if the zone has an SOA record.'''
        # In some sense, the SOA defines a zone.
        # If the current name server has authority for the
        # specific zone, we need to judge if the zone has an SOA record;
        # if not, we consider the zone has incomplete data, so xfrout can't
        # serve for it.
        if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
            return True

        return False

    def _zone_exist(self, zonename):
        '''Judge if the zone is configured by config manager.'''
        # Currently, if we find the zone in datasource successfully, we
        # consider the zone is configured, and the current name server has
        # authority for the specific zone.
        # TODO: should get zone's configuration from cfgmgr or other place
        # in future.
        return sqlite3_ds.zone_exist(zonename, self._server.get_db_file())

    def _check_xfrout_available(self, zone_name):
        '''Check if xfr request can be responsed.
           TODO, Get zone's configuration from cfgmgr or some other place
           eg. check allow_transfer setting,

           Returns Rcode.NOERROR() only after a transfer slot has been
           taken with increase_transfers_counter(); the caller is then
           responsible for releasing it.
        '''
        # If the current name server does not have authority for the
        # zone, xfrout can't serve for it, return rcode NOTAUTH.
        if not self._zone_exist(zone_name):
            return Rcode.NOTAUTH()

        # If we are an authoritative name server for the zone, but fail
        # to find the zone's SOA record in datasource, xfrout can't
        # provide zone transfer for it.
        if not self._zone_has_soa(zone_name):
            return Rcode.SERVFAIL()

        #TODO, check allow_transfer
        if not self._server.increase_transfers_counter():
            return Rcode.REFUSED()

        return Rcode.NOERROR()

    def dns_xfrout_start(self, sock_fd, msg_query):
        '''Parse the query, run the availability checks and send the zone.

        Error conditions are answered with the corresponding rcode; a
        query dropped by the ACL gets no answer at all.
        '''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None: # Dropped by ACL
            return
        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
        elif rcode_ != Rcode.NOERROR():
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR())

        zone_name = self._get_query_zone_name(msg)
        zone_class_str = self._get_query_zone_class(msg)
        # TODO: should we not also include class in the check?
        rcode_ = self._check_xfrout_available(zone_name)

        if rcode_ != Rcode.NOERROR():
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
                        zone_class_str, rcode_.to_text())
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        # A transfer slot was taken by _check_xfrout_available(); the
        # 'finally' guarantees it is given back however the transfer ends.
        try:
            try:
                logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name,
                            zone_class_str)
                self._reply_xfrout_query(msg, sock_fd, zone_name)
            except Exception as err:
                logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name,
                             zone_class_str, str(err))
            logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name, zone_class_str)
        finally:
            self._server.decrease_transfers_counter()
        return

    def _clear_message(self, msg):
        '''Reset msg for rendering a follow-up response message.

        The qid, opcode and rcode are carried over and the AA and QR
        header flags are set again. Returns msg.
        '''
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        msg.clear(Message.RENDER)
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.set_header_flag(Message.HEADERFLAG_QR)
        return msg

    def _create_rrset_from_db_record(self, record):
        '''Create one rrset from one record of datasource, if the schema of record is changed,
        This function should be updated first.

        Fields used: record[2] owner name, record[4] TTL, record[5] RR
        type, record[7:] rdata fields (joined with single spaces). The
        class is always "IN".
        '''
        rrtype_ = RRType(record[5])
        rdata_ = Rdata(rrtype_, RRClass("IN"), " ".join(record[7:]))
        rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_,
                       RRTTL(int(record[4])))
        rrset_.add_rdata(rdata_)
        return rrset_

    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len,
                                    count_since_last_tsig_sign):
        '''Add the SOA record to the end of message. If it can't be
        added, a new message should be created to send out the last soa .
        '''
        rrset_len = get_rrset_len(rrset_soa)

        if (count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH and
            message_upper_len + rrset_len >= XFROUT_MAX_MESSAGE_SIZE):
            # If tsig context exist, sign the packet with serial number TSIG_SIGN_EVERY_NTH
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)
        elif (count_since_last_tsig_sign != TSIG_SIGN_EVERY_NTH and
              message_upper_len + rrset_len + self._tsig_len >= XFROUT_MAX_MESSAGE_SIZE):
            # TSIG space was not reserved for the pending message, so it
            # has to be accounted for here before appending the SOA.
            self._send_message(sock_fd, msg)
            msg = self._clear_message(msg)

        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _reply_xfrout_query(self, msg, sock_fd, zone_name):
        '''Send the whole zone: SOA, all non-SOA records, then SOA again.

        Records are packed into messages using their uncompressed wire
        length as an upper bound. Returns early, without the trailing
        SOA, if the server's shutdown event gets set.
        '''
        #TODO, there should be a better way to insert rrset.
        # Starting at the threshold makes the first message a signed one.
        count_since_last_tsig_sign = TSIG_SIGN_EVERY_NTH
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)
        soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
        rrset_soa = self._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        message_upper_len = get_rrset_len(rrset_soa) + self._tsig_len

        # TODO: RRType.SOA() ?
        soa_type = RRType("SOA")
        for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
            if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
                logger.info(XFROUT_STOPPING)
                return
            if RRType(rr_data[5]) == soa_type: #ignore soa record
                continue

            rrset_ = self._create_rrset_from_db_record(rr_data)

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
            rrset_len = get_rrset_len(rrset_)
            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
                msg.add_rrset(Message.SECTION_ANSWER, rrset_)
                message_upper_len += rrset_len
                continue

            # If tsig context exist, sign every N packets
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                count_since_last_tsig_sign = 0
                self._send_message(sock_fd, msg, self._tsig_ctx)
            else:
                self._send_message(sock_fd, msg)

            count_since_last_tsig_sign += 1
            msg = self._clear_message(msg)
            msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message

            # Reserve tsig space for signed packet
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                message_upper_len = rrset_len + self._tsig_len
            else:
                message_upper_len = rrset_len

        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
                                         count_since_last_tsig_sign)


class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
    '''The unix domain socket server which accept xfr query sent from auth server.'''

    def __init__(self, sock_file, handle_class, shutdown_event, config_data, cc):
        '''Bind to sock_file and apply the initial configuration.

        A stale sock_file is removed first; if it is in use by another
        process, or cannot be removed, the process exits.
        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
        # shutdown() writes to _write_sock; handle_request() selects on
        # _read_sock so that it wakes up.
        self._write_sock, self._read_sock = socket.socketpair()
        self._common_init()
        self.update_config_data(config_data)
        self._cc = cc

    def _common_init(self):
        '''Set up the lock, the transfer counter and the default ACL.'''
        self._lock = threading.Lock()
        self._transfers_counter = 0
        self._acl = REQUEST_LOADER.load('[{"action": "ACCEPT"}]')

    def _recv_exactly(self, sock, size):
        '''Read exactly size bytes from sock.

        Returns the bytes read, or None if the peer closed the connection
        before size bytes arrived.
        '''
        chunks = []
        remaining = size
        while remaining > 0:
            data = sock.recv(remaining)
            if not data:
                return None
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def _receive_query_message(self, sock):
        ''' receive request message from sock

        The message is preceded by its length as a 2-byte integer in
        network byte order. Returns the message bytes, or None when the
        connection is closed before a complete message was received.
        '''
        # receive data length; a stream socket may deliver it in pieces
        data_len = self._recv_exactly(sock, 2)
        if data_len is None:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        return self._recv_exactly(sock, msg_len)

    def handle_request(self):
        ''' Enable server handle a request until shutdown or auth is closed.'''
        try:
            request, client_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        # NOTE(review): process_request() returns normally when recv_fd()
        # fails, after which this loop selects on the same connection
        # again -- confirm a closed peer cannot make it spin.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    # Interrupted by a signal; just try again.
                    continue
                logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
                break

            # self.server._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                self.process_request(request)
            except Exception as pre:
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
                break

    def _handle_request_noblock(self):
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
        td = threading.Thread(target=self.handle_request)
        td.daemon = True
        td.start()

    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request."""
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running.
            if sock_fd == FD_COMM_ERROR:
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
            return

        # receive request msg
        request_data = self._receive_query_message(request)
        if not request_data:
            # No session will be created for this descriptor, so nobody
            # else is going to close it.
            os.close(sock_fd)
            return

        t = threading.Thread(target=self.finish_request,
                             args=(sock_fd, request_data))
        if self.daemon_threads:
            t.daemon = True
        t.start()

    def _guess_remote(self, sock_fd):
        """
           Guess remote address and port of the socket. The sock_fd must be a
           socket
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            family = socket.AF_INET6
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            family = socket.AF_INET
        # fromfd() duplicates the descriptor; close the duplicate once the
        # peer name is known. sock_fd itself stays open.
        sock = socket.fromfd(sock_fd, family, socket.SOCK_STREAM)
        try:
            return sock.getpeername()
        finally:
            sock.close()

    def finish_request(self, sock_fd, request_data):
        '''Finish one request by instantiating RequestHandlerClass.'''
        self.RequestHandlerClass(sock_fd, request_data, self,
                                 self.tsig_key_ring,
                                 self._guess_remote(sock_fd), self._acl)

    def _remove_unused_sock_file(self, sock_file):
        '''Try to remove the socket file. If the file is being used
        by one running xfrout process, exit from python.
        If it's not a socket file or nobody is listening
        , it will be removed. If it can't be removed, exit from python. '''
        if self._sock_file_in_use(sock_file):
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
            sys.exit(0)
        else:
            if not os.path.exists(sock_file):
                return

            try:
                os.unlink(sock_file)
            except OSError as err:
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
                sys.exit(0)

    def _sock_file_in_use(self, sock_file):
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
        return True, or else return False. '''
        try:
            sock = socket.socket(socket.AF_UNIX)
        except socket.error:
            return False

        # The connection is only a probe; never keep it open.
        try:
            sock.connect(sock_file)
        except socket.error:
            return False
        else:
            return True
        finally:
            sock.close()

    def shutdown(self):
        '''Wake up the request threads, stop serving, remove the socket file.'''
        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
        super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
        try:
            os.unlink(self._sock_file)
        except Exception as e:
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(e))

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        Reads the 'query_acl' (only if present), 'transfers_out' and
        'tsig_key_ring' items of new_config.
        '''
        logger.info(XFROUT_NEW_CONFIG)
        if 'query_acl' in new_config:
            self._acl = REQUEST_LOADER.load(new_config['query_acl'])
        # 'with' releases the lock even when applying the config raises.
        with self._lock:
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        logger.info(XFROUT_NEW_CONFIG_DONE)

    def set_tsig_key_ring(self, key_list):
        """Set the tsig_key_ring , given a TSIG key string list representation.

        Key strings rejected with InvalidParameter are logged and skipped.
        """

        # XXX add values to configure zones/tsig options
        self.tsig_key_ring = TSIGKeyRing()
        # If key string list is empty, create a empty tsig_key_ring
        if not key_list:
            return

        for key_item in key_list:
            try:
                self.tsig_key_ring.add(TSIGKey(key_item))
            except InvalidParameter:
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))

    def get_db_file(self):
        '''Return the database file configured for the "Auth" module.'''
        db_file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
        return db_file

    def increase_transfers_counter(self):
        '''Return False, if counter + 1 > max_transfers_out, or else
        return True
        '''
        with self._lock:
            if self._transfers_counter < self._max_transfers_out:
                self._transfers_counter += 1
                return True
            return False

    def decrease_transfers_counter(self):
        '''Give back one transfer slot.'''
        with self._lock:
            self._transfers_counter -= 1


class XfroutServer:
    '''The xfrout process: command channel session, listener and notifier.'''

    def __init__(self):
        '''Open the command channel session and start listener and notifier.'''
        self._unix_socket_server = None
        self._listen_sock_file = UNIX_SOCKET_FILE
        self._shutdown_event = threading.Event()
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler,
                                              None, True)
        self._config_data = self._cc.get_full_config()
        self._cc.start()
        self._cc.add_remote_config(AUTH_SPECFILE_LOCATION)
        self._start_xfr_query_listener()
        self._start_notifier()

    def _start_xfr_query_listener(self):
        '''Start a new thread to accept xfr query. '''
        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
                                                  XfroutSession,
                                                  self._shutdown_event,
                                                  self._config_data,
                                                  self._cc)
        listener = threading.Thread(target=self._unix_socket_server.serve_forever)
        listener.start()

    def _start_notifier(self):
        '''Create the NotifyOut object for the database file and start it.'''
        datasrc = self._unix_socket_server.get_db_file()
        self._notifier = notify_out.NotifyOut(datasrc)
        self._notifier.dispatcher()

    def send_notify(self, zone_name, zone_class):
        '''Ask the notifier to send notifies for the given zone.'''
        self._notifier.send_notify(zone_name, zone_class)

    def config_handler(self, new_config):
        '''Update config data. TODO. Do error check

        Unknown keys are reported in the answer but do not prevent the
        known keys from being applied.
        '''
        answer = create_answer(0)
        for key in new_config:
            if key not in self._config_data:
                answer = create_answer(1, "Unknown config data: " + str(key))
                continue
            self._config_data[key] = new_config[key]

        if self._unix_socket_server:
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as e:
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(e))

        return answer

    def shutdown(self):
        ''' shutdown the xfrout process. The thread which is doing zone transfer-out should be
        terminated.
        '''

        global xfrout_server
        xfrout_server = None #Avoid shutdown is called twice
        self._shutdown_event.set()
        self._notifier.shutdown()
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()

        # Wait for all threads to terminate
        this_thread = threading.current_thread()
        for th in threading.enumerate():
            if th is this_thread:
                continue
            th.join()

    def command_handler(self, cmd, args):
        '''Handle the "shutdown" and zone-new-data-ready commands.

        Returns the answer created with create_answer().
        '''
        if cmd == "shutdown":
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
            self.shutdown()
            answer = create_answer(0)

        elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
            # A command without arguments is answered as a bad parameter
            # instead of failing on None.
            params = args if args is not None else {}
            zone_name = params.get('zone_name')
            zone_class = params.get('zone_class')
            if zone_name and zone_class:
                logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
                self.send_notify(zone_name, zone_class)
                answer = create_answer(0)
            else:
                answer = create_answer(1, "Bad command parameter:" + str(args))

        else:
            answer = create_answer(1, "Unknown command:" + str(cmd))

        return answer

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules. '''
        while not self._shutdown_event.is_set():
            self._cc.check_command(False)


# The running server, if any; XfroutServer.shutdown() resets it to None.
xfrout_server = None


def signal_handler(signal, frame):
    '''Shut the running server down and exit.'''
    if xfrout_server:
        xfrout_server.shutdown()
        sys.exit(0)


def set_signal_handler():
    '''Install signal_handler for SIGTERM and SIGINT.'''
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)


def set_cmd_options(parser):
    '''Register the command line options on parser.'''
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
                      help="display more about what is going on")


if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
        VERBOSE_MODE = options.verbose

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
        logger.info(XFROUT_STOPPED_BY_KEYBOARD)
    except SessionError as e:
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
    except SessionTimeout:
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)

    if xfrout_server:
        xfrout_server.shutdown()