Browse Source

Merge branch 'master' into trac976

Stephen Morris 14 years ago
parent
commit
31a30f5485

+ 23 - 4
ChangeLog

@@ -1,16 +1,35 @@
-248.    [func]      stephen
+251.	[bug]*		jinmei
+	Make sure bindctl private files are non readable to anyone except
+	the owner or users in the same group.  Note that if BIND 10 is run
+	with changing the user, this change means that the file owner or
+	group will have to be adjusted.  Also note that this change is
+	only effective for a fresh install; if these files already exist,
+	their permissions must be adjusted by hand (if necessary).
+	(Trac870, git 461fc3cb6ebabc9f3fa5213749956467a14ebfd4)
+
+250.    [bug]           ocean
+        src/lib/util/encode, in some conditions, the DecodeNormalizer's
+        iterator may reach the end() and when later being dereferenced
+        it will cause crash on some platform.
+        (Trac838, git 83e33ec80c0c6485d8b116b13045b3488071770f)
+
+249.    [func]      	jerry
+	xfrout: add support for TSIG verification.
+	(Trac816, git 3b2040e2af2f8139c1c319a2cbc429035d93f217)
+
+248.    [func]      	stephen
 	Add file and stderr as destinations for logging.
 	(Trac555, git 38b3546867425bd64dbc5920111a843a3330646b)
 
-247.    [func]      jelte
+247.    [func]      	jelte
 	Upstream queries from the resolver now set EDNS0 buffer size.
 	(Trac834, git 48e10c2530fe52c9bde6197db07674a851aa0f5d)
 
-246.    [func]      stephen
+246.    [func]      	stephen
 	Implement logging using log4cplus (http://log4cplus.sourceforge.net)
 	(Trac899, git 31d3f525dc01638aecae460cb4bc2040c9e4df10)
 
-245.    [func]      vorner
+245.    [func]      	vorner
 	Authoritative server can now sign the answers using TSIG
 	(configured in tsig_keys/keys, list of strings like
 	"name:<base64-secret>:sha1-hmac"). It doesn't use them for

+ 1 - 1
src/bin/cfgmgr/plugins/tests/tsig_keys_test.py

@@ -86,7 +86,7 @@ class TSigKeysTest(unittest.TestCase):
         self.assertEqual("TSIG: Invalid TSIG key string: invalid.key",
                          tsig_keys.check({'keys': ['invalid.key']}))
         self.assertEqual(
-            "TSIG: attempt to decode a value not in base64 char set",
+            "TSIG: Unexpected end of input in BASE decoder",
             tsig_keys.check({'keys': ['invalid.key:123']}))
 
     def test_bad_format(self):

+ 3 - 2
src/bin/cmdctl/Makefile.am

@@ -40,12 +40,13 @@ b10-cmdctl: cmdctl.py
 
 if INSTALL_CONFIGURATIONS
 
-# TODO: permissions handled later
+# Below we intentionally use ${INSTALL} -m 640 instead of $(INSTALL_DATA)
+# because these file will contain sensitive information.
 install-data-local:
 	$(mkinstalldirs) $(DESTDIR)/@sysconfdir@/@PACKAGE@   
 	for f in $(CMDCTL_CONFIGURATIONS) ; do	\
 	  if test ! -f $(DESTDIR)$(sysconfdir)/@PACKAGE@/$$f; then	\
-	    $(INSTALL_DATA) $(srcdir)/$$f $(DESTDIR)$(sysconfdir)/@PACKAGE@/ ;	\
+	    ${INSTALL} -m 640 $(srcdir)/$$f $(DESTDIR)$(sysconfdir)/@PACKAGE@/ ;	\
 	  fi ;	\
 	done
 

+ 210 - 6
src/bin/xfrout/tests/xfrout_test.py.in

@@ -18,11 +18,14 @@
 
 import unittest
 import os
+from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.cc.session import *
 from pydnspp import *
 from xfrout import *
 import xfrout
 
+TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
+
 # our fake socket, where we can read and insert messages
 class MySocket():
     def __init__(self, family, type):
@@ -85,10 +88,36 @@ class TestXfroutSession(unittest.TestCase):
         msg.from_wire(self.mdata)
         return msg
 
+    def create_mock_tsig_ctx(self, error):
+        # This helper function creates a MockTSIGContext for a given key
+        # and TSIG error to be used as a result of verify (normally faked
+        # one)
+        mock_ctx = MockTSIGContext(TSIG_KEY)
+        mock_ctx.error = error
+        return mock_ctx
+
+    def message_has_tsig(self, msg):
+        return msg.get_tsig_record() is not None
+
+    def create_request_data_with_tsig(self):
+        msg = Message(Message.RENDER)
+        query_id = 0x1035
+        msg.set_qid(query_id)
+        msg.set_opcode(Opcode.QUERY())
+        msg.set_rcode(Rcode.NOERROR())
+        query_question = Question(Name("example.com."), RRClass.IN(), RRType.AXFR())
+        msg.add_question(query_question)
+
+        renderer = MessageRenderer()
+        tsig_ctx = MockTSIGContext(TSIG_KEY)
+        msg.to_wire(renderer, tsig_ctx)
+        reply_data = renderer.get_data()
+        return reply_data
+
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), self.log)
+        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), self.log, TSIGKeyRing())
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
@@ -96,6 +125,18 @@ class TestXfroutSession(unittest.TestCase):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
 
+        # tsig signed query message
+        request_data = self.create_request_data_with_tsig()
+        # BADKEY
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOTAUTH")
+        self.assertTrue(self.xfrsess._tsig_ctx is not None)
+        # NOERROR
+        self.xfrsess._tsig_key_ring.add(TSIG_KEY)
+        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOERROR")
+        self.assertTrue(self.xfrsess._tsig_ctx is not None)
+
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
@@ -111,6 +152,14 @@ class TestXfroutSession(unittest.TestCase):
         get_msg = self.sock.read_msg()
         self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
 
+        # tsig signed message
+        msg = self.getmsg()
+        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
+        self.assertTrue(self.message_has_tsig(get_msg))
+
     def test_send_message(self):
         msg = self.getmsg()
         msg.make_response()
@@ -152,6 +201,14 @@ class TestXfroutSession(unittest.TestCase):
         get_msg = self.sock.read_msg()
         self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
 
+        # tsig signed message
+        msg = self.getmsg()
+        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+        self.xfrsess._reply_query_with_format_error(msg, self.sock)
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
+        self.assertTrue(self.message_has_tsig(get_msg))
+
     def test_create_rrset_from_db_record(self):
         rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(rrset.get_name().to_text(), "example.com.")
@@ -162,11 +219,16 @@ class TestXfroutSession(unittest.TestCase):
 
     def test_send_message_with_last_soa(self):
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
-
         msg = self.getmsg()
         msg.make_response()
-        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 0)
+
+        # packet number less than TSIG_SIGN_EVERY_NTH
+        packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
+                                                 0, packet_neet_not_sign)
         get_msg = self.sock.read_msg()
+        # tsig context is not exist
+        self.assertFalse(self.message_has_tsig(get_msg))
 
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
@@ -180,6 +242,42 @@ class TestXfroutSession(unittest.TestCase):
         rdata = answer.get_rdata()
         self.assertEqual(rdata[0].to_text(), self.soa_record[7])
 
+        # msg is the TSIG_SIGN_EVERY_NTH one
+        # sending the message with last soa together
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
+                                                 0, TSIG_SIGN_EVERY_NTH)
+        get_msg = self.sock.read_msg()
+        # tsig context is not exist
+        self.assertFalse(self.message_has_tsig(get_msg))
+
+    def test_send_message_with_last_soa_with_tsig(self):
+        # create tsig context
+        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+
+        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+        msg = self.getmsg()
+        msg.make_response()
+
+        # packet number less than TSIG_SIGN_EVERY_NTH
+        packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
+        # msg is not the TSIG_SIGN_EVERY_NTH one
+        # sending the message with last soa together
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
+                                                 0, packet_neet_not_sign)
+        get_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(get_msg))
+
+        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
+        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
+        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
+
+        # msg is the TSIG_SIGN_EVERY_NTH one
+        # sending the message with last soa together
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
+                                                 0, TSIG_SIGN_EVERY_NTH)
+        get_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(get_msg))
+
     def test_trigger_send_message_with_last_soa(self):
         rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
         rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
@@ -187,15 +285,21 @@ class TestXfroutSession(unittest.TestCase):
 
         msg = self.getmsg()
         msg.make_response()
-
         msg.add_rrset(Message.SECTION_ANSWER, rrset_a)
-        # give the function a value that is larger than MAX-len(rrset)
-        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 65520)
 
+        # length larger than MAX-len(rrset)
+        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
+        # packet number less than TSIG_SIGN_EVERY_NTH
+        packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
+
+        # give the function a value that is larger than MAX-len(rrset)
         # this should have triggered the sending of two messages
         # (1 with the rrset we added manually, and 1 that triggered
         # the sending in _with_last_soa)
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
+                                                 packet_neet_not_sign)
         get_msg = self.sock.read_msg()
+        self.assertFalse(self.message_has_tsig(get_msg))
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
@@ -208,6 +312,7 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rdata[0].to_text(), "192.0.2.1")
 
         get_msg = self.sock.read_msg()
+        self.assertFalse(self.message_has_tsig(get_msg))
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
         self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
@@ -223,6 +328,45 @@ class TestXfroutSession(unittest.TestCase):
         # and it should not have sent anything else
         self.assertEqual(0, len(self.sock.sendqueue))
 
+    def test_trigger_send_message_with_last_soa_with_tsig(self):
+        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+        msg = self.getmsg()
+        msg.make_response()
+        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
+
+        # length larger than MAX-len(rrset)
+        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
+        # packet number less than TSIG_SIGN_EVERY_NTH
+        packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
+
+        # give the function a value that is larger than MAX-len(rrset)
+        # this should have triggered the sending of two messages
+        # (1 with the rrset we added manually, and 1 that triggered
+        # the sending in _with_last_soa)
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
+                                                 packet_neet_not_sign)
+        get_msg = self.sock.read_msg()
+        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
+        self.assertFalse(self.message_has_tsig(get_msg))
+        # the last packet should be tsig signed
+        get_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(get_msg))
+        # and it should not have sent anything else
+        self.assertEqual(0, len(self.sock.sendqueue))
+
+
+        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
+                                                 xfrout.TSIG_SIGN_EVERY_NTH)
+        get_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(get_msg))
+        # the last packet should be tsig signed
+        get_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(get_msg))
+        # and it should not have sent anything else
+        self.assertEqual(0, len(self.sock.sendqueue))
+
     def test_get_rrset_len(self):
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(82, get_rrset_len(rrset_soa))
@@ -313,6 +457,51 @@ class TestXfroutSession(unittest.TestCase):
         reply_msg = self.sock.read_msg()
         self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
 
+    def test_reply_xfrout_query_noerror_with_tsig(self):
+        rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None, '192.168.1.1')
+        global sqlite3_ds
+        global xfrout
+        def get_zone_soa(zonename, file):
+            return self.soa_record
+
+        def get_zone_datas(zone, file):
+            zone_rrsets = []
+            for i in range(0, 100):
+                zone_rrsets.insert(i, rrset_data)
+            return zone_rrsets
+
+        def get_rrset_len(rrset):
+            return 65520
+
+        sqlite3_ds.get_zone_soa = get_zone_soa
+        sqlite3_ds.get_zone_datas = get_zone_datas
+        xfrout.get_rrset_len = get_rrset_len
+
+        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
+
+        # tsig signed first package
+        reply_msg = self.sock.read_msg()
+        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
+        self.assertTrue(self.message_has_tsig(reply_msg))
+        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
+        for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
+            reply_msg = self.sock.read_msg()
+            self.assertFalse(self.message_has_tsig(reply_msg))
+        # TSIG_SIGN_EVERY_NTH packet has tsig
+        reply_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(reply_msg))
+
+        for i in range(0, 100 - TSIG_SIGN_EVERY_NTH):
+            reply_msg = self.sock.read_msg()
+            self.assertFalse(self.message_has_tsig(reply_msg))
+        # tsig signed last package
+        reply_msg = self.sock.read_msg()
+        self.assertTrue(self.message_has_tsig(reply_msg))
+
+        # and it should not have sent anything else
+        self.assertEqual(0, len(self.sock.sendqueue))
+
 class MyCCSession():
     def __init__(self):
         pass
@@ -347,8 +536,23 @@ class TestUnixSockServer(unittest.TestCase):
         self.assertEqual(recv_msg, send_msg)
 
     def test_updata_config_data(self):
+        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
+        tsig_key_list = [tsig_key_str]
+        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)
+        self.assertTrue(self.unix.tsig_key_ring is not None)
+
+        self.unix.update_config_data({'transfers_out':9, 'tsig_key_ring':tsig_key_list})
+        self.assertEqual(self.unix._max_transfers_out, 9)
+        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
+        self.unix.tsig_key_ring.remove(Name("example.com."))
+        self.assertEqual(self.unix.tsig_key_ring.size(), 0)
+
+        # bad tsig key
+        config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
+        self.assertRaises(None, self.unix.update_config_data(config_data))
+        self.assertEqual(self.unix.tsig_key_ring.size(), 0)
 
     def test_get_db_file(self):
         self.assertEqual(self.unix.get_db_file(), "initdb.file")

+ 90 - 20
src/bin/xfrout/xfrout.py.in

@@ -74,10 +74,12 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
 AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
 MAX_TRANSFERS_OUT = 10
 VERBOSE_MODE = False
-
+# tsig sign every N axfr packets.
+TSIG_SIGN_EVERY_NTH = 96
 
 XFROUT_MAX_MESSAGE_SIZE = 65535
 
+
 def get_rrset_len(rrset):
     """Returns the wire length of the given RRset"""
     bytes = bytearray()
@@ -86,15 +88,22 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
-    def __init__(self, sock_fd, request_data, server, log):
+    def __init__(self, sock_fd, request_data, server, log, tsig_key_ring):
         # The initializer for the superclass may call functions
         # that need _log to be set, so we set it first
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
         self._log = log
+        self._tsig_key_ring = tsig_key_ring
+        self._tsig_ctx = None
+        self._tsig_len = 0
         self.handle()
 
+    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
+        return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
+                           tsig_key_ring)
+
     def handle(self):
         ''' Handle a xfrout query, send xfrout response '''
         try:
@@ -105,17 +114,33 @@ class XfroutSession():
 
         os.close(self._sock_fd)
 
+    def _check_request_tsig(self, msg, request_data):
+        ''' If request has a tsig record, perform tsig related checks '''
+        tsig_record = msg.get_tsig_record()
+        if tsig_record is not None:
+            self._tsig_len = tsig_record.get_length()
+            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
+            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
+            if tsig_error != TSIGError.NOERROR:
+                return Rcode.NOTAUTH()
+
+        return Rcode.NOERROR()
+
     def _parse_query_message(self, mdata):
         ''' parse query message to [socket,message]'''
         #TODO, need to add parseHeader() in case the message header is invalid
         try:
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
+
+            # TSIG related checks
+            rcode = self._check_request_tsig(msg, mdata)
+
         except Exception as err:
             self._log.log_message("error", str(err))
             return Rcode.FORMERR(), None
 
-        return Rcode.NOERROR(), msg
+        return rcode, msg
 
     def _get_query_zone_name(self, msg):
         question = msg.get_question()[0]
@@ -130,13 +155,20 @@ class XfroutSession():
             total_count += count
 
 
-    def _send_message(self, sock_fd, msg):
+    def _send_message(self, sock_fd, msg, tsig_ctx=None):
         render = MessageRenderer()
         # As defined in RFC5936 section3.4, perform case-preserving name
         # compression for AXFR message.
         render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
         render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
-        msg.to_wire(render)
+
+        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
+        # we should remove the if statement and use a universal interface later.
+        if tsig_ctx is not None:
+            msg.to_wire(render, tsig_ctx)
+        else:
+            msg.to_wire(render)
+
         header_len = struct.pack('H', socket.htons(render.get_length()))
         self._send_data(sock_fd, header_len)
         self._send_data(sock_fd, render.get_data())
@@ -145,7 +177,7 @@ class XfroutSession():
     def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
         msg.make_response()
         msg.set_rcode(rcode_)
-        self._send_message(sock_fd, msg)
+        self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
     def _reply_query_with_format_error(self, msg, sock_fd):
@@ -155,7 +187,7 @@ class XfroutSession():
 
         msg.make_response()
         msg.set_rcode(Rcode.FORMERR())
-        self._send_message(sock_fd, msg)
+        self._send_message(sock_fd, msg, self._tsig_ctx)
 
     def _zone_has_soa(self, zone):
         '''Judge if the zone has an SOA record.'''
@@ -204,7 +236,9 @@ class XfroutSession():
     def dns_xfrout_start(self, sock_fd, msg_query):
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
-        if rcode_ != Rcode.NOERROR():
+        if rcode_ == Rcode.NOTAUTH():
+            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
+        elif rcode_ != Rcode.NOERROR():
             return self._reply_query_with_format_error(msg, sock_fd)
 
         zone_name = self._get_query_zone_name(msg)
@@ -248,37 +282,43 @@ class XfroutSession():
         rrset_.add_rdata(rdata_)
         return rrset_
 
-    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len):
+    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len,
+                                    count_since_last_tsig_sign):
         '''Add the SOA record to the end of message. If it can't be
         added, a new message should be created to send out the last soa .
         '''
         rrset_len = get_rrset_len(rrset_soa)
 
-        if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
-            msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
-        else:
+        if (count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH and
+            message_upper_len + rrset_len >= XFROUT_MAX_MESSAGE_SIZE):
+            # If tsig context exist, sign the packet with serial number TSIG_SIGN_EVERY_NTH
+            self._send_message(sock_fd, msg, self._tsig_ctx)
+            msg = self._clear_message(msg)
+        elif (count_since_last_tsig_sign != TSIG_SIGN_EVERY_NTH and
+              message_upper_len + rrset_len + self._tsig_len >= XFROUT_MAX_MESSAGE_SIZE):
             self._send_message(sock_fd, msg)
             msg = self._clear_message(msg)
-            msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
 
-        self._send_message(sock_fd, msg)
+        # If tsig context exist, sign the last packet
+        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
+        self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
     def _reply_xfrout_query(self, msg, sock_fd, zone_name):
         #TODO, there should be a better way to insert rrset.
+        count_since_last_tsig_sign = TSIG_SIGN_EVERY_NTH
         msg.make_response()
         msg.set_header_flag(Message.HEADERFLAG_AA)
         soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
         rrset_soa = self._create_rrset_from_db_record(soa_record)
         msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
 
-        message_upper_len = get_rrset_len(rrset_soa)
+        message_upper_len = get_rrset_len(rrset_soa) + self._tsig_len
 
         for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
             if  self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
                 self._log.log_message("info", "xfrout process is being shutdown")
                 return
-
             # TODO: RRType.SOA() ?
             if RRType(rr_data[5]) == RRType("SOA"): #ignore soa record
                 continue
@@ -294,12 +334,25 @@ class XfroutSession():
                 message_upper_len += rrset_len
                 continue
 
-            self._send_message(sock_fd, msg)
+            # If tsig context exist, sign every N packets
+            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
+                count_since_last_tsig_sign = 0
+                self._send_message(sock_fd, msg, self._tsig_ctx)
+            else:
+                self._send_message(sock_fd, msg)
+
+            count_since_last_tsig_sign += 1
             msg = self._clear_message(msg)
             msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
-            message_upper_len = rrset_len
 
-        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
+            # Reserve tsig space for signed packet
+            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
+                message_upper_len = rrset_len + self._tsig_len
+            else:
+                message_upper_len = rrset_len
+
+        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
+                                         count_since_last_tsig_sign)
 
 class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
@@ -403,7 +456,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
 
     def finish_request(self, sock_fd, request_data):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(sock_fd, request_data, self, self._log)
+        self.RequestHandlerClass(sock_fd, request_data, self, self._log, self.tsig_key_ring)
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
@@ -449,10 +502,27 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._log.log_message('info', 'update config data start.')
         self._lock.acquire()
         self._max_transfers_out = new_config.get('transfers_out')
+        self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
         self._log.log_message('info', 'max transfer out : %d', self._max_transfers_out)
         self._lock.release()
         self._log.log_message('info', 'update config data complete.')
 
+    def set_tsig_key_ring(self, key_list):
+        """Set the tsig_key_ring , given a TSIG key string list representation. """
+
+        # XXX add values to configure zones/tsig options
+        self.tsig_key_ring = TSIGKeyRing()
+        # If key string list is empty, create a empty tsig_key_ring
+        if not key_list:
+            return
+
+        for key_item in key_list:
+            try:
+                self.tsig_key_ring.add(TSIGKey(key_item))
+            except InvalidParameter as ipe:
+                errmsg = "bad TSIG key string: " + str(key_item)
+                self._log.log_message('error', '%s' % errmsg)
+
     def get_db_file(self):
         file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
         # this too should be unnecessary, but currently the

+ 12 - 0
src/bin/xfrout/xfrout.spec.pre.in

@@ -37,6 +37,18 @@
     	 "item_type": "integer",
          "item_optional": false,
     	 "item_default": 1048576
+       },
+       {
+         "item_name": "tsig_key_ring",
+         "item_type": "list",
+         "item_optional": true,
+         "item_default": [],
+         "list_item_spec" :
+         {
+             "item_name": "tsig_key",
+             "item_type": "string",
+             "item_optional": true
+         }
        }
       ],
       "commands": [

+ 157 - 109
src/lib/log/compiler/message.cc

@@ -35,6 +35,8 @@
 
 #include <log/logger.h>
 
+#include <boost/foreach.hpp>
+
 using namespace std;
 using namespace isc::log;
 using namespace isc::util;
@@ -78,10 +80,11 @@ version() {
 void
 usage() {
     cout <<
-        "Usage: message [-h] [-v] <message-file>\n" <<
+        "Usage: message [-h] [-v] [-p] <message-file>\n" <<
         "\n" <<
         "-h       Print this message and exit\n" <<
         "-v       Print the program version and exit\n" <<
+        "-p       Output python source instead of C++ ones\n" <<
         "\n" <<
         "<message-file> is the name of the input message file.\n";
 }
@@ -237,6 +240,44 @@ writeClosingNamespace(ostream& output, const vector<string>& ns) {
     }
 }
 
+/// \breif Write python file
+///
+/// Writes the python file containing the symbol definitions as module level
+/// constants. These are objects which register themself at creation time,
+/// so they can be replaced by dictionary later.
+///
+/// \param file Name of the message file. The source code is written to a file
+///     file of the same name but with a .py suffix.
+/// \param dictionary The dictionary holding the message definitions.
+///
+/// \note We don't use the namespace as in C++. We don't need it, because
+///     python file/module works as implicit namespace as well.
+
+void
+writePythonFile(const string& file, MessageDictionary& dictionary) {
+    Filename message_file(file);
+    Filename python_file(Filename(message_file.name()).useAsDefault(".py"));
+
+    // Open the file for writing
+    ofstream pyfile(python_file.fullName().c_str());
+
+    // Write the comment and imports
+    pyfile <<
+        "# File created from " << message_file.fullName() << " on " <<
+            currentTime() << "\n" <<
+        "\n" <<
+        "import isc.log.message\n" <<
+        "\n";
+
+    vector<string> idents(sortedIdentifiers(dictionary));
+    BOOST_FOREACH(const string& ident, idents) {
+        pyfile << ident << " = isc.log.message.create(\"" <<
+            ident << "\", \"" << quoteString(dictionary.getText(ident)) <<
+            "\")\n";
+    }
+
+    pyfile.close();
+}
 
 /// \brief Write Header File
 ///
@@ -264,52 +305,46 @@ writeHeaderFile(const string& file, const vector<string>& ns_components,
     // Open the output file for writing
     ofstream hfile(header_file.fullName().c_str());
 
-    try {
-        if (hfile.fail()) {
-            throw MessageException(MSG_OPENOUT, header_file.fullName(),
-                strerror(errno));
-        }
-
-        // Write the header preamble.  If there is an error, we'll pick it up
-        // after the last write.
-
-        hfile <<
-            "// File created from " << message_file.fullName() << " on " <<
-                currentTime() << "\n" <<
-             "\n" <<
-             "#ifndef " << sentinel_text << "\n" <<
-             "#define "  << sentinel_text << "\n" <<
-             "\n" <<
-             "#include <log/message_types.h>\n" <<
-             "\n";
-
-        // Write the message identifiers, bounded by a namespace declaration
-        writeOpeningNamespace(hfile, ns_components);
-
-        vector<string> idents = sortedIdentifiers(dictionary);
-        for (vector<string>::const_iterator j = idents.begin();
-            j != idents.end(); ++j) {
-            hfile << "extern const isc::log::MessageID " << *j << ";\n";
-        }
-        hfile << "\n";
+    if (hfile.fail()) {
+        throw MessageException(MSG_OPENOUT, header_file.fullName(),
+            strerror(errno));
+    }
 
-        writeClosingNamespace(hfile, ns_components);
+    // Write the header preamble.  If there is an error, we'll pick it up
+    // after the last write.
+
+    hfile <<
+        "// File created from " << message_file.fullName() << " on " <<
+            currentTime() << "\n" <<
+         "\n" <<
+         "#ifndef " << sentinel_text << "\n" <<
+         "#define "  << sentinel_text << "\n" <<
+         "\n" <<
+         "#include <log/message_types.h>\n" <<
+         "\n";
+
+    // Write the message identifiers, bounded by a namespace declaration
+    writeOpeningNamespace(hfile, ns_components);
+
+    vector<string> idents = sortedIdentifiers(dictionary);
+    for (vector<string>::const_iterator j = idents.begin();
+        j != idents.end(); ++j) {
+        hfile << "extern const isc::log::MessageID " << *j << ";\n";
+    }
+    hfile << "\n";
 
-        // ... and finally the postamble
-        hfile << "#endif // " << sentinel_text << "\n";
+    writeClosingNamespace(hfile, ns_components);
 
-        // Report errors (if any) and exit
-        if (hfile.fail()) {
-            throw MessageException(MSG_WRITERR, header_file.fullName(),
-                strerror(errno));
-        }
+    // ... and finally the postamble
+    hfile << "#endif // " << sentinel_text << "\n";
 
-        hfile.close();
-    }
-    catch (MessageException&) {
-        hfile.close();
-        throw;
+    // Report errors (if any) and exit
+    if (hfile.fail()) {
+        throw MessageException(MSG_WRITERR, header_file.fullName(),
+            strerror(errno));
     }
+
+    hfile.close();
 }
 
 
@@ -357,76 +392,71 @@ writeProgramFile(const string& file, const vector<string>& ns_components,
 
     // Open the output file for writing
     ofstream ccfile(program_file.fullName().c_str());
-    try {
-        if (ccfile.fail()) {
-            throw MessageException(MSG_OPENOUT, program_file.fullName(),
-                strerror(errno));
-        }
 
-        // Write the preamble.  If there is an error, we'll pick it up after
-        // the last write.
+    if (ccfile.fail()) {
+        throw MessageException(MSG_OPENOUT, program_file.fullName(),
+            strerror(errno));
+    }
 
-        ccfile <<
-            "// File created from " << message_file.fullName() << " on " <<
-                currentTime() << "\n" <<
-             "\n" <<
-             "#include <cstddef>\n" <<
-             "#include <log/message_types.h>\n" <<
-             "#include <log/message_initializer.h>\n" <<
-             "\n";
+    // Write the preamble.  If there is an error, we'll pick it up after
+    // the last write.
 
-        // Declare the message symbols themselves.
+    ccfile <<
+        "// File created from " << message_file.fullName() << " on " <<
+            currentTime() << "\n" <<
+         "\n" <<
+         "#include <cstddef>\n" <<
+         "#include <log/message_types.h>\n" <<
+         "#include <log/message_initializer.h>\n" <<
+         "\n";
 
-        writeOpeningNamespace(ccfile, ns_components);
+    // Declare the message symbols themselves.
 
-        vector<string> idents = sortedIdentifiers(dictionary);
-        for (vector<string>::const_iterator j = idents.begin();
-            j != idents.end(); ++j) {
-            ccfile << "extern const isc::log::MessageID " << *j <<
-                " = \"" << *j << "\";\n";
-        }
-        ccfile << "\n";
+    writeOpeningNamespace(ccfile, ns_components);
 
-        writeClosingNamespace(ccfile, ns_components);
+    vector<string> idents = sortedIdentifiers(dictionary);
+    for (vector<string>::const_iterator j = idents.begin();
+        j != idents.end(); ++j) {
+        ccfile << "extern const isc::log::MessageID " << *j <<
+            " = \"" << *j << "\";\n";
+    }
+    ccfile << "\n";
 
-        // Now the code for the message initialization.
+    writeClosingNamespace(ccfile, ns_components);
 
-        ccfile <<
-             "namespace {\n" <<
-             "\n" <<
-             "const char* values[] = {\n";
+    // Now the code for the message initialization.
 
-        // Output the identifiers and the associated text.
-        idents = sortedIdentifiers(dictionary);
-        for (vector<string>::const_iterator i = idents.begin();
-            i != idents.end(); ++i) {
-                ccfile << "    \"" << *i << "\", \"" <<
-                    quoteString(dictionary.getText(*i)) << "\",\n";
-        }
+    ccfile <<
+         "namespace {\n" <<
+         "\n" <<
+         "const char* values[] = {\n";
 
+    // Output the identifiers and the associated text.
+    idents = sortedIdentifiers(dictionary);
+    for (vector<string>::const_iterator i = idents.begin();
+        i != idents.end(); ++i) {
+            ccfile << "    \"" << *i << "\", \"" <<
+                quoteString(dictionary.getText(*i)) << "\",\n";
+    }
 
-        // ... and the postamble
-        ccfile <<
-            "    NULL\n" <<
-            "};\n" <<
-            "\n" <<
-            "const isc::log::MessageInitializer initializer(values);\n" <<
-            "\n" <<
-            "} // Anonymous namespace\n" <<
-            "\n";
-
-        // Report errors (if any) and exit
-        if (ccfile.fail()) {
-            throw MessageException(MSG_WRITERR, program_file.fullName(),
-                strerror(errno));
-        }
 
-        ccfile.close();
-    }
-    catch (MessageException&) {
-        ccfile.close();
-        throw;
+    // ... and the postamble
+    ccfile <<
+        "    NULL\n" <<
+        "};\n" <<
+        "\n" <<
+        "const isc::log::MessageInitializer initializer(values);\n" <<
+        "\n" <<
+        "} // Anonymous namespace\n" <<
+        "\n";
+
+    // Report errors (if any) and exit
+    if (ccfile.fail()) {
+        throw MessageException(MSG_WRITERR, program_file.fullName(),
+            strerror(errno));
     }
+
+    ccfile.close();
 }
 
 
@@ -466,13 +496,19 @@ warnDuplicates(MessageReader& reader) {
 int
 main(int argc, char* argv[]) {
 
-    const char* soptions = "hv";               // Short options
+    const char* soptions = "hvp";               // Short options
 
     optind = 1;             // Ensure we start a new scan
     int  opt;               // Value of the option
 
+    bool doPython = false;
+
     while ((opt = getopt(argc, argv, soptions)) != -1) {
         switch (opt) {
+            case 'p':
+                doPython = true;
+                break;
+
             case 'h':
                 usage();
                 return 0;
@@ -508,15 +544,27 @@ main(int argc, char* argv[]) {
         MessageReader reader(&dictionary);
         reader.readFile(message_file);
 
-        // Get the namespace into which the message definitions will be put and
-        // split it into components.
-        vector<string> ns_components = splitNamespace(reader.getNamespace());
-
-        // Write the header file.
-        writeHeaderFile(message_file, ns_components, dictionary);
-
-        // Write the file that defines the message symbols and text
-        writeProgramFile(message_file, ns_components, dictionary);
+        if (doPython) {
+            // Warn in case of ignored directives
+            if (!reader.getNamespace().empty()) {
+                cerr << "Python mode, ignoring the $NAMESPACE directive" <<
+                    endl;
+            }
+
+            // Write the whole python file
+            writePythonFile(message_file, dictionary);
+        } else {
+            // Get the namespace into which the message definitions will be put and
+            // split it into components.
+            vector<string> ns_components =
+                splitNamespace(reader.getNamespace());
+
+            // Write the header file.
+            writeHeaderFile(message_file, ns_components, dictionary);
+
+            // Write the file that defines the message symbols and text
+            writeProgramFile(message_file, ns_components, dictionary);
+        }
 
         // Finally, warn of any duplicates encountered.
         warnDuplicates(reader);

+ 29 - 6
src/lib/util/encode/base_n.cc

@@ -160,19 +160,42 @@ public:
         base_zero_code_(base_zero_code),
         base_(base), base_beginpad_(base_beginpad), base_end_(base_end),
         in_pad_(false)
-    {}
+    {
+        // Skip beginning spaces, if any.  We need do it here because
+        // otherwise the first call to operator*() would be confused.
+        skipSpaces();
+    }
     DecodeNormalizer& operator++() {
         ++base_;
-        while (base_ != base_end_ && isspace(*base_)) {
-            ++base_;
-        }
+        skipSpaces();
         if (base_ == base_beginpad_) {
             in_pad_ = true;
         }
         return (*this);
     }
+    void skipSpaces() {
+        // If (char is signed and) *base_ < 0, on Windows platform with Visual
+        // Studio compiler it may trigger _ASSERTE((unsigned)(c + 1) <= 256);
+        // so make sure that the parameter of isspace() is larger than 0.
+        // We don't simply cast it to unsigned char to avoid confusing the
+        // isspace() implementation with a possible extension for values
+        // larger than 127.  Also note the check is not ">= 0"; for systems
+        // where char is unsigned that would always be true and would possibly
+        // trigger a compiler warning that could stop the build.
+        while (base_ != base_end_ && *base_ > 0 && isspace(*base_)) {
+            ++base_;
+        }
+    }
     const char& operator*() const {
-        if (in_pad_ && *base_ == BASE_PADDING_CHAR) {
+        if (base_ == base_end_) {
+            // binary_from_baseX calls this operator when it needs more bits
+            // even if the internal iterator (base_) has reached its end
+            // (if that happens it means the input is an incomplete baseX
+            // string and should be rejected).  So this is the only point
+            // we can catch and reject this type of invalid input.
+            isc_throw(BadValue, "Unexpected end of input in BASE decoder");
+        }
+        if (in_pad_) {
             return (base_zero_code_);
         } else {
             return (*base_);
@@ -268,7 +291,7 @@ BaseNTransformer<BitsPerChunk, BaseZeroCode, Encoder, Decoder>::decode(
                 isc_throw(BadValue, "Too many " << algorithm
                           << " padding characters: " << input);
             }
-        } else if (!isspace(ch)) {
+        } else if (ch < 0 || !isspace(ch)) {
             break;
         }
         ++srit;

+ 6 - 1
src/lib/util/tests/base32hex_unittest.cc

@@ -66,7 +66,7 @@ decodeCheck(const string& input_string, vector<uint8_t>& output,
             const string& expected)
 {
     decodeBase32Hex(input_string, output);
-    EXPECT_EQ(expected, string(&output[0], &output[0] + output.size()));
+    EXPECT_EQ(expected, string(output.begin(), output.end()));
 }
 
 TEST_F(Base32HexTest, decode) {
@@ -79,6 +79,11 @@ TEST_F(Base32HexTest, decode) {
     // whitespace should be allowed
     decodeCheck("CP NM\tUOG=", decoded_data, "foob");
     decodeCheck("CPNMU===\n", decoded_data, "foo");
+    decodeCheck("  CP NM\tUOG=", decoded_data, "foob");
+    decodeCheck(" ", decoded_data, "");
+
+    // Incomplete input
+    EXPECT_THROW(decodeBase32Hex("CPNMUOJ", decoded_data), BadValue);
 
     // invalid number of padding characters
     EXPECT_THROW(decodeBase32Hex("CPNMU0==", decoded_data), BadValue);

+ 7 - 1
src/lib/util/tests/base64_unittest.cc

@@ -52,7 +52,7 @@ decodeCheck(const string& input_string, vector<uint8_t>& output,
             const string& expected)
 {
     decodeBase64(input_string, output);
-    EXPECT_EQ(expected, string(&output[0], &output[0] + output.size()));
+    EXPECT_EQ(expected, string(output.begin(), output.end()));
 }
 
 TEST_F(Base64Test, decode) {
@@ -66,6 +66,12 @@ TEST_F(Base64Test, decode) {
     decodeCheck("Zm 9v\tYmF\ny", decoded_data, "foobar");
     decodeCheck("Zm9vYg==", decoded_data, "foob");
     decodeCheck("Zm9vYmE=\n", decoded_data, "fooba");
+    decodeCheck(" Zm9vYmE=\n", decoded_data, "fooba");
+    decodeCheck(" ", decoded_data, "");
+    decodeCheck("\n\t", decoded_data, "");
+
+    // incomplete input
+    EXPECT_THROW(decodeBase64("Zm9vYmF", decoded_data), BadValue);
 
     // only up to 2 padding characters are allowed
     EXPECT_THROW(decodeBase64("A===", decoded_data), BadValue);