123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- # Copyright (C) 2010 Internet Systems Consortium.
- #
- # Permission to use, copy, modify, and distribute this software for any
- # purpose with or without fee is hereby granted, provided that the above
- # copyright notice and this permission notice appear in all copies.
- #
- # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
- # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
- # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
- # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
- # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
- # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
- # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
- '''Tests for the XfroutSession and UnixSockServer classes '''
- import unittest
- import os
- from isc.cc.session import *
- from bind10_dns import *
- from xfrout import *
- # our fake socket, where we can read and insert messages
- class MySocket():
- def __init__(self, family, type):
- self.family = family
- self.type = type
- self.sendqueue = bytearray()
- def connect(self, to):
- pass
- def close(self):
- pass
- def send(self, data):
- self.sendqueue.extend(data);
- return len(data)
- def readsent(self):
- result = self.sendqueue[:]
- del self.sendqueue[:]
- return result
-
- def read_msg(self):
- sent_data = self.readsent()
- get_msg = message(message_mode.PARSE)
- get_msg.from_wire(input_buffer(bytes(sent_data[2:])))
- return get_msg
-
- def clear_send(self):
- del self.sendqueue[:]
- # We subclass the Session class we're testing here, only
- # to override the __init__() method, which wants a socket,
- class MyXfroutSession(XfroutSession):
- def handle(self):
- pass
-
- class Dbserver:
- def __init__(self):
- self._shutdown_event = threading.Event()
- def get_db_file(self):
- return None
- def decrease_transfers_counter(self):
- pass
- class TestXfroutSession(unittest.TestCase):
- def getmsg(self):
- msg = message(message_mode.PARSE)
- msg.from_wire(input_buffer(self.mdata))
- return msg
- def setUp(self):
- request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
- self.xfrsess = MyXfroutSession(request, None, None)
- self.xfrsess.server = Dbserver()
- self.mdata = b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01'
- self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
- self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
- def test_parse_query_message(self):
- [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
- self.assertEqual(get_rcode.to_text(), "NOERROR")
- def test_get_query_zone_name(self):
- msg = self.getmsg()
- self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
-
- def test_send_data(self):
- self.xfrsess._send_data(self.sock, self.mdata)
- senddata = self.sock.readsent()
- self.assertEqual(senddata, self.mdata)
- def test_reply_xfrout_query_with_error_rcode(self):
- msg = self.getmsg()
- self.xfrsess._reply_query_with_error_rcode(msg, self.sock, rcode(3))
- get_msg = self.sock.read_msg()
- self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
-
- def test_clear_message(self):
- msg = self.getmsg()
- qid = msg.get_qid()
- opcode = msg.get_opcode()
- rcode = msg.get_rcode()
- self.xfrsess._clear_message(msg)
- self.assertEqual(msg.get_qid(), qid)
- self.assertEqual(msg.get_opcode(), opcode)
- self.assertEqual(msg.get_rcode(), rcode)
- self.assertTrue(msg.get_header_flag(message_flag.AA()))
- def test_reply_query_with_format_error(self):
-
- msg = self.getmsg()
- self.xfrsess._reply_query_with_format_error(msg, self.sock)
- get_msg = self.sock.read_msg()
- self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
- def test_create_rrset_from_db_record(self):
- rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
- self.assertEqual(rrset.get_name().to_text(), "example.com.")
- self.assertEqual(rrset.get_class(), rr_class.IN())
- self.assertEqual(rrset.get_type().to_text(), "SOA")
- rdata_iter = rrset.get_rdata_iterator()
- rdata_iter.first()
- self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
- def test_send_message_with_last_soa(self):
- rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
- msg = self.getmsg()
- msg.make_response()
- self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa)
- get_msg = self.sock.read_msg()
- self.assertEqual(get_msg.get_rr_count(section.QUESTION()), 1)
- self.assertEqual(get_msg.get_rr_count(section.ANSWER()), 1)
- self.assertEqual(get_msg.get_rr_count(section.AUTHORITY()), 0)
- answer_rrset_iter = section_iter(get_msg, section.ANSWER())
- answer = answer_rrset_iter.get_rrset()
- self.assertEqual(answer.get_name().to_text(), "example.com.")
- self.assertEqual(answer.get_class(), rr_class.IN())
- self.assertEqual(answer.get_type().to_text(), "SOA")
- rdata_iter = answer.get_rdata_iterator()
- rdata_iter.first()
- self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
- def test_get_message_len(self):
- msg = self.getmsg()
- msg.make_response()
- self.assertEqual(self.xfrsess._get_message_len(msg), 29)
- def test_zone_is_empty(self):
- global sqlite3_ds
- def mydb1(zone, file):
- return True
- sqlite3_ds.get_zone_soa = mydb1
- self.assertEqual(self.xfrsess._zone_is_empty(""), False)
- def mydb2(zone, file):
- return False
- sqlite3_ds.get_zone_soa = mydb2
- self.assertEqual(self.xfrsess._zone_is_empty(""), True)
- def test_zone_exist(self):
- global sqlite3_ds
- def zone_soa(zone, file):
- return zone
- sqlite3_ds.get_zone_soa = zone_soa
- self.assertEqual(self.xfrsess._zone_exist(True), True)
- self.assertEqual(self.xfrsess._zone_exist(False), False)
-
- def test_check_xfrout_available(self):
- def zone_exist(zone):
- return zone
- self.xfrsess._zone_exist = zone_exist
- self.xfrsess._zone_is_empty = zone_exist
- self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
- self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
- def zone_empty(zone):
- return not zone
- self.xfrsess._zone_is_empty = zone_empty
- def false_func():
- return False
- self.xfrsess.server.increase_transfers_counter = false_func
- self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
- def true_func():
- return True
- self.xfrsess.server.increase_transfers_counter = true_func
- self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
- def test_dns_xfrout_start_formerror(self):
- # formerror
- self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
- sent_data = self.sock.readsent()
- self.assertEqual(len(sent_data), 0)
-
- def default(self, param):
- return "example.com"
- def test_dns_xfrout_start_notauth(self):
- self.xfrsess._get_query_zone_name = self.default
- def notauth(formpara):
- return rcode.NOTAUTH()
- self.xfrsess._check_xfrout_available = notauth
- self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
- get_msg = self.sock.read_msg()
- self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
-
- def test_dns_xfrout_start_noerror(self):
- self.xfrsess._get_query_zone_name = self.default
- def noerror(form):
- return rcode.NOERROR()
- self.xfrsess._check_xfrout_available = noerror
-
- def myreply(msg, sock, zonename):
- self.sock.send(b"success")
-
- self.xfrsess._reply_xfrout_query = myreply
- self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
- self.assertEqual(self.sock.readsent(), b"success")
-
- def test_reply_xfrout_query_noerror(self):
- global sqlite3_ds
- def get_zone_soa(zonename, file):
- return self.soa_record
- def get_zone_datas(zone, file):
- return [self.soa_record]
- sqlite3_ds.get_zone_soa = get_zone_soa
- sqlite3_ds.get_zone_datas = get_zone_datas
- self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
- reply_msg = self.sock.read_msg()
- self.assertEqual(reply_msg.get_rr_count(section.ANSWER()), 2)
- # set event
- self.xfrsess.server._shutdown_event.set()
- self.assertRaises(XfroutException, self.xfrsess._reply_xfrout_query, self.getmsg(), self.sock, "example.com.")
- class MyUnixSockServer(UnixSockServer):
- def __init__(self):
- self._lock = threading.Lock()
- self._transfers_counter = 0
- self._shutdown_event = threading.Event()
- self._db_file = "initdb.file"
- self._max_transfers_out = 10
- class TestUnixSockServer(unittest.TestCase):
- def setUp(self):
- self.unix = MyUnixSockServer()
-
- def test_updata_config_data(self):
- self.unix.update_config_data({'transfers_out':10, 'db_file':"db.file"})
- self.assertEqual(self.unix._max_transfers_out, 10)
- self.assertEqual(self.unix._db_file, "db.file")
- def test_get_db_file(self):
- self.assertEqual(self.unix.get_db_file(), "initdb.file")
- def test_increase_transfers_counter(self):
- self.unix._max_transfers_out = 10
- count = self.unix._transfers_counter
- self.assertEqual(self.unix.increase_transfers_counter(), True)
- self.assertEqual(count + 1, self.unix._transfers_counter)
- self.unix._max_transfers_out = 0
- count = self.unix._transfers_counter
- self.assertEqual(self.unix.increase_transfers_counter(), False)
- self.assertEqual(count, self.unix._transfers_counter)
-
- def test_decrease_transfers_counter(self):
- count = self.unix._transfers_counter
- self.unix.decrease_transfers_counter()
- self.assertEqual(count - 1, self.unix._transfers_counter)
- if __name__== "__main__":
- unittest.main()
|