xfrin.py.in 16 KB


  1. #!@PYTHON@
  2. # Copyright (C) 2010 Internet Systems Consortium.
  3. #
  4. # Permission to use, copy, modify, and distribute this software for any
  5. # purpose with or without fee is hereby granted, provided that the above
  6. # copyright notice and this permission notice appear in all copies.
  7. #
  8. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  9. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  10. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  11. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  12. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  13. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  14. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  15. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  16. # $Id$
  17. import sys; sys.path.append ('@@PYTHONPATH@@')
  18. import os
  19. import signal
  20. import isc
  21. import asyncore
  22. import struct
  23. import threading
  24. import socket
  25. import random
  26. from optparse import OptionParser, OptionValueError
  27. from isc.config.ccsession import *
  28. try:
  29. from bind10_dns import *
  30. except ImportError as e:
  31. # C++ loadable module may not be installed; even so the xfrin process
  32. # must keep running, so we warn about it and move forward.
  33. sys.stderr.write('[b10-xfrin] failed to import DNS module: %s\n' % str(e))
  # Locate the directory holding xfrin.spec.  When running from a source
  # tree (B10_FROM_SOURCE set in the environment) it lives under that tree;
  # otherwise use the installed data directory that configure substitutes
  # into the @...@ placeholders below.
  if "B10_FROM_SOURCE" not in os.environ:
      PREFIX = "@prefix@"
      DATAROOTDIR = "@datarootdir@"
      # Expand the autoconf-style variables that may remain in @datadir@;
      # ${datarootdir} must be expanded first because it may itself
      # contain ${prefix}.
      SPECFILE_PATH = ("@datadir@/@PACKAGE@"
                       .replace("${datarootdir}", DATAROOTDIR)
                       .replace("${prefix}", PREFIX))
  else:
      SPECFILE_PATH = os.environ["B10_FROM_SOURCE"] + "/src/bin/xfrin"

  SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"

  __version__ = 'BIND10'

  # xfrin return code meaning "ok / continue with the transfer"
  XFRIN_OK = 0
  def log_error(msg):
      '''Write msg to stderr on one line, tagged with "[b10-xfrin] ".'''
      for piece in ("[b10-xfrin] ", msg, '\n'):
          sys.stderr.write(str(piece))
  class XfrinException(Exception):
      '''Error raised for failed transfers and invalid xfrin commands.'''
  class XfrinConnection(asyncore.dispatcher):
      '''Do xfrin in this class.

      One instance performs a single AXFR of one zone over one TCP
      connection to the master.  The socket is put in blocking mode and
      reads are driven by asyncore.loop() with count=1, one read per call.
      '''

      def __init__(self,
                   zone_name, db_file, shutdown_event, master_addr,
                   port = 53, verbose = False, idle_timeout = 60):
          '''Create the connection object (does not connect yet).

          zone_name: name of the zone to transfer.
          db_file: specify the data source file.
          shutdown_event: threading.Event; once set, an in-progress
              transfer is aborted between messages.
          master_addr: textual IPv4 or IPv6 address of the master.
          port: TCP port of the master.
          verbose: when true, log_msg() writes progress to stdout.
          idle_timeout: max idle time (seconds, passed to asyncore.loop)
              for reading data from the socket.
          '''
          # Use a private asyncore map.  With the default (global) map, every
          # transfer thread's asyncore.loop() would also service the other
          # threads' sockets, letting one thread consume another's data and
          # mutating a shared dict from several threads at once.
          sock_map = {}
          asyncore.dispatcher.__init__(self, map = sock_map)
          self._sock_map = sock_map
          # check_addr() validates anything containing ':' as IPv6, so pick
          # the socket family the same way instead of always using AF_INET.
          if ':' in str(master_addr):
              family = socket.AF_INET6
          else:
              family = socket.AF_INET
          self.create_socket(family, socket.SOCK_STREAM)
          self._zone_name = zone_name
          self._db_file = db_file
          self._soa_rr_count = 0
          self._idle_timeout = idle_timeout
          self.setblocking(1)
          self._shutdown_event = shutdown_event
          self._verbose = verbose
          self._master_addr = master_addr
          self._port = port
          # Receive state, filled in by handle_read(); initialized here so
          # no attribute lookup ever falls through to the socket object.
          self._query_id = 0
          self._need_recv_size = 0
          self._recvd_data = b''
          self._recvd_size = 0
          self._recv_time_out = True

      def connect_to_master(self):
          '''Connect to master in TCP.

          Returns True on success; on socket.error logs the failure and
          returns False.
          '''
          try:
              self.connect((self._master_addr, self._port))
              return True
          except socket.error as e:
              self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr, self._port, str(e)))
              return False

      def _create_query(self, query_type):
          '''Create a DNS query message for this zone with a random ID.

          The ID is remembered in self._query_id so responses can be
          matched against it.
          '''
          msg = message(message_mode.RENDER)
          query_id = random.randint(1, 0xFFFF)
          self._query_id = query_id
          msg.set_qid(query_id)
          msg.set_opcode(op_code.QUERY())
          msg.set_rcode(rcode.NOERROR())
          query_question = question(name(self._zone_name), rr_class.IN(), query_type)
          msg.add_question(query_question)
          return msg

      def _send_data(self, data):
          '''Send all of data, looping over partial sends.

          Raises XfrinException if the connection is closed mid-send.
          '''
          size = len(data)
          total_count = 0
          while total_count < size:
              count = self.send(data[total_count:])
              if count == 0:
                  # asyncore's send() returns 0 after handling a disconnect;
                  # looping again would spin forever on a dead socket.
                  raise XfrinException('connection closed by master')
              total_count += count

      def _send_query(self, query_type):
          '''Send query message over TCP.'''
          msg = self._create_query(query_type)
          obuf = output_buffer(0)
          render = message_render(obuf)
          msg.to_wire(render)
          # DNS over TCP prefixes each message with its length as a 16-bit
          # unsigned integer in network byte order (RFC 1035, 4.2.2).
          header_len = struct.pack('!H', obuf.get_length())
          self._send_data(header_len)
          self._send_data(obuf.get_data())

      def _get_request_response(self, size):
          '''Read exactly size bytes from the master and return them.

          Raises XfrinException if nothing arrives within idle_timeout
          seconds or if the master closes the connection early.
          '''
          chunks = []
          recv_size = 0
          while recv_size < size:
              self._recv_time_out = True
              self._need_recv_size = size - recv_size
              asyncore.loop(self._idle_timeout, map = self._sock_map,
                            count = 1)
              # handle_read() clears the flag; if it is still set, no data
              # arrived before the timeout.
              if self._recv_time_out:
                  raise XfrinException('receive data from socket time out.')
              if self._recvd_size == 0:
                  # asyncore returns b'' (and closes the socket) on EOF.
                  raise XfrinException('connection closed by master')
              recv_size += self._recvd_size
              chunks.append(self._recvd_data)
          return b''.join(chunks)

      def _check_soa_serial(self):
          '''Compare the SOA serial with the master's before transferring.

          Intended semantics: if the master's serial is not bigger than the
          local one, the transfer should be skipped.  The comparison is not
          implemented yet: the master's SOA is queried and its response is
          read and discarded, then XFRIN_OK is always returned.
          '''
          self._send_query(rr_type.SOA())
          data_len = self._get_request_response(2)
          # The length prefix is raw network-order bytes, not ASCII digits.
          msg_len = struct.unpack('!H', data_len)[0]
          self._get_request_response(msg_len)
          #TODO, need select soa record from data source then compare the two
          #serial, current just return OK, since this function hasn't been used now
          return XFRIN_OK

      def do_xfrin(self, check_soa, ixfr_first = False):
          '''Do xfr by sending xfr request and parsing response.

          check_soa: when true, query the master's SOA first
              (see _check_soa_serial).
          ixfr_first: currently unused; only AXFR is performed.

          Transfer failures are logged, not raised.  Returns the result of
          the SOA check (XFRIN_OK when it is skipped).  The connection is
          always closed before returning.
          '''
          ret = XFRIN_OK
          # Build the log prefix before anything that can raise, so every
          # error handler below can use it.
          logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
          try:
              if check_soa:
                  ret = self._check_soa_serial()
              if ret == XFRIN_OK:
                  self.log_msg(logstr + 'started')
                  self._send_query(rr_type.AXFR())
                  isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                                              self._handle_xfrin_response)
                  self.log_msg(logstr + 'succeeded')
          except XfrinException as e:
              self.log_msg(e)
              self.log_msg(logstr + 'failed')
              #TODO, recover data source.
          except (isc.datasrc.sqlite3_ds.Sqlite3DSError, socket.error) as e:
              self.log_msg(e)
              self.log_msg(logstr + 'failed')
          finally:
              self.close()
          return ret

      def _check_response_status(self, msg):
          '''Check validation of xfr response.

          Raises XfrinException on a non-NOERROR rcode, a missing QR flag,
          a mismatched query ID, an empty answer section or more than one
          question.
          '''
          #TODO, check more?
          msg_rcode = msg.get_rcode()
          if msg_rcode != rcode.NOERROR():
              raise XfrinException('error response: %s' % msg_rcode.to_text())
          if not msg.get_header_flag(message_flag.QR()):
              raise XfrinException('response is not a response ')
          if msg.get_qid() != self._query_id:
              raise XfrinException('bad query id')
          if msg.get_rr_count(section.ANSWER()) == 0:
              raise XfrinException('answer section is empty')
          if msg.get_rr_count(section.QUESTION()) > 1:
              raise XfrinException('query section count greater than 1')

      def _handle_answer_section(self, rrset_iter):
          '''Return a generator for the response in one tcp package to a zone transfer.

          Yields (name, ttl, class, type, rdata_text) tuples, one per rdata.
          '''
          while not rrset_iter.is_last():
              rrset = rrset_iter.get_rrset()
              rrset_iter.next()
              rrset_name = rrset.get_name().to_text()
              rrset_ttl = int(rrset.get_ttl().to_text())
              rrset_class = rrset.get_class().to_text()
              rrset_type = rrset.get_type().to_text()
              rdata_iter = rrset.get_rdata_iterator()
              rdata_iter.first()
              while not rdata_iter.is_last():
                  # Count the soa record count
                  if rrset.get_type() == rr_type.SOA():
                      self._soa_rr_count += 1
                      # XXX: the current DNS message parser can't preserve the
                      # RR order or separate the beginning and ending SOA RRs.
                      # As a short term workaround, we simply ignore the second
                      # SOA, and ignore the erroneous case where the transfer
                      # session doesn't end with an SOA.
                      if (self._soa_rr_count == 2):
                          # Avoid inserting soa record twice
                          break
                  rdata_text = rdata_iter.get_current().to_text()
                  yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
                         rdata_text)
                  rdata_iter.next()

      def _handle_xfrin_response(self):
          '''Return a generator for the response to a zone transfer.

          Reads length-prefixed DNS messages until the second SOA has been
          seen, yielding the answer-section RRs of each.  Raises
          XfrinException if a response fails validation or if the shutdown
          event is set between messages.
          '''
          while True:
              data_len = self._get_request_response(2)
              msg_len = struct.unpack('!H', data_len)[0]
              recvdata = self._get_request_response(msg_len)
              msg = message(message_mode.PARSE)
              msg.from_wire(input_buffer(recvdata))
              self._check_response_status(msg)
              rrset_iter = section_iter(msg, section.ANSWER())
              for rr in self._handle_answer_section(rrset_iter):
                  yield rr
              if self._soa_rr_count == 2:
                  break
              if self._shutdown_event.is_set():
                  raise XfrinException('xfrin is forced to stop')

      def handle_read(self):
          '''Read query's response from socket (called by asyncore.loop).'''
          self._recvd_data = self.recv(self._need_recv_size)
          self._recvd_size = len(self._recvd_data)
          self._recv_time_out = False

      def writable(self):
          '''Ignore the writable socket.'''
          return False

      def log_info(self, msg, type='info'):
          # Override asyncore's logging hook: log nothing.
          pass

      def log_msg(self, msg):
          '''Write msg to stdout, tagged, when verbose mode is on.'''
          if self._verbose:
              sys.stdout.write('[b10-xfrin] ')
              sys.stdout.write(str(msg))
              sys.stdout.write('\n')
  def process_xfrin(xfrin_recorder, zone_name, db_file,
                    shutdown_event, master_addr, port, check_soa, verbose):
      '''Thread body for one zone transfer.

      Marks zone_name as in progress in xfrin_recorder, connects to the
      master and runs the transfer.  The zone is always unmarked afterwards,
      even if the transfer raises, so later transfers of the zone (and the
      transfer quota) are not blocked by a dead thread.
      '''
      port = int(port)
      xfrin_recorder.increment(zone_name)
      try:
          conn = XfrinConnection(zone_name, db_file, shutdown_event,
                                 master_addr, port, verbose)
          if conn.connect_to_master():
              # do_xfrin() closes the connection itself.
              conn.do_xfrin(check_soa)
          else:
              # Release the socket created for the failed connection.
              conn.close()
      finally:
          xfrin_recorder.decrement(zone_name)
  class XfrinRecorder():
      '''Lock-protected record of the zones whose transfer is in progress.

      A zone appears once per increment(); decrement() removes one entry.
      '''
      def __init__(self):
          self._lock = threading.Lock()
          self._zones = []

      def increment(self, zone_name):
          '''Record that a transfer of zone_name has started.'''
          with self._lock:
              self._zones.append(zone_name)

      def decrement(self, zone_name):
          '''Record that a transfer of zone_name has finished (no-op if absent).'''
          with self._lock:
              if zone_name in self._zones:
                  self._zones.remove(zone_name)

      def xfrin_in_progress(self, zone_name):
          '''Return True if a transfer of zone_name is recorded.'''
          with self._lock:
              return zone_name in self._zones

      def count(self):
          '''Return the number of transfers currently recorded.'''
          with self._lock:
              return len(self._zones)
  class Xfrin():
      '''The xfrin daemon: handles commands from the command channel session
      and starts one thread per requested zone transfer.'''
      def __init__(self, verbose = False):
          self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
          self._cc.start()
          self._max_transfers_in = 10
          self.recorder = XfrinRecorder()
          self._shutdown_event = threading.Event()
          self._verbose = verbose

      def config_handler(self, new_config):
          # TODO, process new config data
          return create_answer(0)

      def shutdown(self):
          '''Shut down the xfrin process.

          Sets the shutdown event (running transfers stop between
          messages) and joins every thread other than the calling one.
          '''
          self._shutdown_event.set()
          current = threading.current_thread()
          for th in threading.enumerate():
              if th is current:
                  continue
              th.join()

      def command_handler(self, command, args):
          '''Dispatch a command; returns a create_answer() result.

          'shutdown' stops the main loop; 'retransfer' starts a transfer
          without the SOA check; 'refresh' starts one with it.  Parameter
          errors are reported as an error answer.
          '''
          answer = create_answer(0)
          cmd = command
          try:
              if cmd == 'shutdown':
                  self._shutdown_event.set()
              elif cmd == 'retransfer':
                  zone_name, master, port, db_file = self._parse_cmd_params(args)
                  ret = self.xfrin_start(zone_name, db_file, master, port, False)
                  answer = create_answer(ret[0], ret[1])
              elif cmd == 'refresh':
                  zone_name, master, port, db_file = self._parse_cmd_params(args)
                  ret = self.xfrin_start(zone_name, db_file, master, port)
                  answer = create_answer(ret[0], ret[1])
          except XfrinException as err:
              answer = create_answer(1, str(err))
          return answer

      def _parse_cmd_params(self, args):
          '''Validate a transfer command's arguments.

          Returns (zone_name, master, port, db_file).  Raises XfrinException
          for any missing or malformed argument, which command_handler
          turns into an error answer.
          '''
          # A command may arrive without arguments; treat that as an empty
          # dict so the caller gets an XfrinException rather than an
          # AttributeError that command_handler does not catch.
          if not args:
              args = {}
          zone_name = args.get('zone_name')
          if not zone_name:
              raise XfrinException('zone name should be provided')
          master = args.get('master')
          if not master:
              raise XfrinException('master address should be provided')
          check_addr(master)
          port = 53
          port_str = args.get('port')
          if port_str:
              try:
                  port = int(port_str)
              except (TypeError, ValueError):
                  # Non-numeric port: report it instead of crashing the
                  # command loop with an uncaught ValueError.
                  raise XfrinException('invalid port number: %s' % port_str)
              check_port(port)
          db_file = args.get('db_file')
          if not db_file:
              #TODO, the db file path should be got in auth server's configuration
              db_file = '@@LOCALSTATEDIR@@/@PACKAGE@/zone.sqlite3'
          return (zone_name, master, port, db_file)

      def startup(self):
          '''Process commands until the shutdown event is set.'''
          while not self._shutdown_event.is_set():
              self._cc.check_command()

      def xfrin_start(self, zone_name, db_file, master_addr,
                      port = 53,
                      check_soa = True):
          '''Start a transfer of zone_name in a new thread.

          Returns (0, message) when started, or (1, reason) if the DNS
          module is unavailable, the transfer quota is reached, or the
          zone is already being transferred.
          '''
          if "bind10_dns" not in sys.modules:
              return (1, "xfrin failed, can't load dns message python library: 'bind10_dns'")
          # check max_transfer_in, else return quota error
          if self.recorder.count() >= self._max_transfers_in:
              return (1, 'xfrin quota error')
          if self.recorder.xfrin_in_progress(zone_name):
              return (1, 'zone xfrin is in progress')
          xfrin_thread = threading.Thread(target = process_xfrin,
                                          args = (self.recorder,
                                                  zone_name,
                                                  db_file,
                                                  self._shutdown_event,
                                                  master_addr,
                                                  port, check_soa, self._verbose))
          xfrin_thread.start()
          return (0, 'zone xfrin is started')
  # The running Xfrin instance; assigned by the __main__ block so that the
  # signal handler can reach it.  None until then.
  xfrind = None

  def signal_handler(signal, frame):
      '''On SIGTERM/SIGINT: shut down the daemon, if one exists, and exit(0).'''
      daemon = xfrind
      if daemon:
          daemon.shutdown()
      sys.exit(0)
  def set_signal_handler():
      '''Install signal_handler for SIGTERM and then SIGINT.'''
      for signum in (signal.SIGTERM, signal.SIGINT):
          signal.signal(signum, signal_handler)
  def check_port(value):
      '''Raise XfrinException unless value is within 0-65535 inclusive.'''
      out_of_range = value < 0 or value > 65535
      if out_of_range:
          raise XfrinException('requires a port number (0-65535)')
  def check_addr(ipstr):
      '''Raise XfrinException unless ipstr is a valid IP address literal.

      A string containing ':' is validated as IPv6, anything else as IPv4.
      '''
      try:
          ip_family = socket.AF_INET6 if ':' in ipstr else socket.AF_INET
          socket.inet_pton(ip_family, ipstr)
      except (socket.error, TypeError, ValueError) as err:
          # Catch only conversion failures (malformed text, non-str input);
          # the old bare except also swallowed KeyboardInterrupt/SystemExit.
          raise XfrinException("%s invalid ip address" % ipstr) from err
  def set_cmd_options(parser):
      '''Register b10-xfrin's command-line options on an OptionParser.'''
      verbose_help = "display more about what is going on"
      parser.add_option("-v", "--verbose",
                        action="store_true",
                        dest="verbose",
                        help=verbose_help)
  if __name__ == '__main__':
      # Parse options, install signal handlers and run the command loop;
      # report any failure on stderr.
      try:
          parser = OptionParser(version=__version__)
          set_cmd_options(parser)
          (options, args) = parser.parse_args()
          set_signal_handler()
          xfrind = Xfrin(verbose=options.verbose)
          xfrind.startup()
      except KeyboardInterrupt:
          log_error("exit b10-xfrin")
      except isc.cc.session.SessionError as err:
          log_error(str(err))
          log_error('Error happened! is the command channel daemon running?')
      except Exception as err:
          log_error(str(err))

      # After the loop returns or a handled error, stop and join transfers.
      if xfrind:
          xfrind.shutdown()