Browse Source

[1175] modify b10-stats-httpd_test.py and b10-stats_test.py

 - add function get_availaddr to get available address and port on the
   platform

 - add function is_ipv6enabled to check ipv6 enabled on the platform

 - add miscellaneous changes to refactor unittest
Naoki Kambe 13 years ago
parent
commit
290e89c515
2 changed files with 247 additions and 225 deletions
  1. 229 211
      src/bin/stats/tests/b10-stats-httpd_test.py
  2. 18 14
      src/bin/stats/tests/b10-stats_test.py

+ 229 - 211
src/bin/stats/tests/b10-stats-httpd_test.py

@@ -36,7 +36,7 @@ import xml.etree.ElementTree
 import isc
 import stats_httpd
 import stats
-from test_utils import BaseModules, ThreadingServerManager, MyStats, MyStatsHttpd, TIMEOUT_SEC
+from test_utils import BaseModules, ThreadingServerManager, MyStats, MyStatsHttpd, send_shutdown
 
 # set test name for logger
 isc.log.init("b10-stats-httpd_test")
@@ -58,35 +58,61 @@ DUMMY_DATA = {
         }
     }
 
+def get_availaddr(address='127.0.0.1'):
+    """returns tuple of address and port available on the
+    platform. default range of port is from 65535 to 50000"""
+    for port in range(65535, 50000, -1):
+        try:
+            if is_ipv6_enabled(address):
+                sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+            else :
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.bind((address, port))
+            sock.close()
+            return (address, port)
+        except socket.error:
+            pass
+
+def is_ipv6_enabled(address='::1', port=8000):
+    """checks IPv6 enabled on the platform"""
+    try:
+        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+        sock.bind((address, port))
+        sock.close()
+        return True
+    except socket.error:
+        return False
+
 class TestHttpHandler(unittest.TestCase):
     """Tests for HttpHandler class"""
-
     def setUp(self):
         self.base = BaseModules()
         self.stats_server = ThreadingServerManager(MyStats)
         self.stats = self.stats_server.server
         self.stats_server.run()
+        (self.address, self.port) = get_availaddr()
+        self.stats_httpd_server = ThreadingServerManager(MyStatsHttpd, (self.address, self.port))
+        self.stats_httpd = self.stats_httpd_server.server
+        self.stats_httpd_server.run()
+        self.client = http.client.HTTPConnection(self.address, self.port)
+        self.client._http_vsn_str = 'HTTP/1.0\n'
+        self.client.connect()
 
     def tearDown(self):
+        self.client.close()
+        self.stats_httpd_server.shutdown()
         self.stats_server.shutdown()
         self.base.shutdown()
 
     def test_do_GET(self):
-        (address, port) = ('127.0.0.1', 65450)
-        statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        self.stats_httpd = statshttpd_server.server
-        self.stats_httpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
         self.assertTrue(type(self.stats_httpd.httpd) is list)
-        self.assertEqual(len(self.stats_httpd.httpd), 0)
-        statshttpd_server.run()
-        client = http.client.HTTPConnection(address, port)
-        client._http_vsn_str = 'HTTP/1.0\n'
-        client.connect()
+        self.assertEqual(len(self.stats_httpd.httpd), 1)
+        self.assertEqual((self.address, self.port), self.stats_httpd.http_addrs[0])
 
         # URL is '/bind10/statistics/xml'
-        client.putrequest('GET', stats_httpd.XML_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XML_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.getheader("Content-type"), "text/xml")
         self.assertTrue(int(response.getheader("Content-Length")) > 0)
         self.assertEqual(response.status, 200)
@@ -100,9 +126,9 @@ class TestHttpHandler(unittest.TestCase):
                 self.assertIsNotNone(root.find(mod + '/' + item))
 
         # URL is '/bind10/statitics/xsd'
-        client.putrequest('GET', stats_httpd.XSD_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSD_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.getheader("Content-type"), "text/xml")
         self.assertTrue(int(response.getheader("Content-Length")) > 0)
         self.assertEqual(response.status, 200)
@@ -120,9 +146,9 @@ class TestHttpHandler(unittest.TestCase):
             self.assertTrue(elm.attrib['name'] in DUMMY_DATA)
 
         # URL is '/bind10/statitics/xsl'
-        client.putrequest('GET', stats_httpd.XSL_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSL_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.getheader("Content-type"), "text/xml")
         self.assertTrue(int(response.getheader("Content-Length")) > 0)
         self.assertEqual(response.status, 200)
@@ -147,114 +173,83 @@ class TestHttpHandler(unittest.TestCase):
                                 [ tds[0].text+'/'+item for item in DUMMY_DATA[tds[0].text].keys() ])
 
         # 302 redirect
-        client._http_vsn_str = 'HTTP/1.1'
-        client.putrequest('GET', '/')
-        client.putheader('Host', address)
-        client.endheaders()
-        response = client.getresponse()
+        self.client._http_vsn_str = 'HTTP/1.1'
+        self.client.putrequest('GET', '/')
+        self.client.putheader('Host', self.address)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 302)
         self.assertEqual(response.getheader('Location'),
-                         "http://%s:%d%s" % (address, port, stats_httpd.XML_URL_PATH))
+                         "http://%s:%d%s" % (self.address, self.port, stats_httpd.XML_URL_PATH))
 
         # # 404 NotFound
-        client._http_vsn_str = 'HTTP/1.0'
-        client.putrequest('GET', '/path/to/foo/bar')
-        client.endheaders()
-        response = client.getresponse()
+        self.client._http_vsn_str = 'HTTP/1.0'
+        self.client.putrequest('GET', '/path/to/foo/bar')
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 404)
 
-        client.close()
-        statshttpd_server.shutdown()
 
     def test_do_GET_failed1(self):
-        # failure case(connection with Stats is down)
-        (address, port) = ('127.0.0.1', 65451)
-        statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        statshttpd = statshttpd_server.server
-        statshttpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
-        statshttpd_server.run()
-        self.assertTrue(self.stats_server.server.running)
-        self.stats_server.shutdown()
-        self.assertFalse(self.stats_server.server.running)
-        statshttpd.cc_session.set_timeout(milliseconds=TIMEOUT_SEC/1000)
-        client = http.client.HTTPConnection(address, port)
-        client.connect()
+        # failure case(Stats is down)
+        self.assertTrue(self.stats.running)
+        send_shutdown("Stats") # Stats is down
+        self.assertFalse(self.stats.running)
+        self.stats_httpd.cc_session.set_timeout(milliseconds=100)
 
         # request XML
-        client.putrequest('GET', stats_httpd.XML_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XML_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
         # request XSD
-        client.putrequest('GET', stats_httpd.XSD_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSD_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
         # request XSL
-        client.putrequest('GET', stats_httpd.XSL_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSL_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
-        client.close()
-        statshttpd_server.shutdown()
-
     def test_do_GET_failed2(self):
-        # failure case(connection with Stats is down)
-        (address, port) = ('127.0.0.1', 65452)
-        statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        self.stats_httpd = statshttpd_server.server
-        self.stats_httpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
-        statshttpd_server.run()
+        # failure case(Stats replies an error)
         self.stats.mccs.set_command_handler(
             lambda cmd, args: \
                 isc.config.ccsession.create_answer(1, "I have an error.")
             )
-        client = http.client.HTTPConnection(address, port)
-        client.connect()
 
         # request XML
-        client.putrequest('GET', stats_httpd.XML_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XML_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
         # request XSD
-        client.putrequest('GET', stats_httpd.XSD_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSD_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
         # request XSL
-        client.putrequest('GET', stats_httpd.XSL_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('GET', stats_httpd.XSL_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 500)
 
-        client.close()
-        statshttpd_server.shutdown()
-
     def test_do_HEAD(self):
-        (address, port) = ('127.0.0.1', 65453)
-        statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        self.stats_httpd = statshttpd_server.server
-        self.stats_httpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
-        statshttpd_server.run()
-        client = http.client.HTTPConnection(address, port)
-        client.connect()
-        client.putrequest('HEAD', stats_httpd.XML_URL_PATH)
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('HEAD', stats_httpd.XML_URL_PATH)
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 200)
 
-        client.putrequest('HEAD', '/path/to/foo/bar')
-        client.endheaders()
-        response = client.getresponse()
+        self.client.putrequest('HEAD', '/path/to/foo/bar')
+        self.client.endheaders()
+        response = self.client.getresponse()
         self.assertEqual(response.status, 404)
-        client.close()
-        statshttpd_server.shutdown()
 
 class TestHttpServerError(unittest.TestCase):
     """Tests for HttpServerError exception"""
@@ -273,9 +268,12 @@ class TestHttpServer(unittest.TestCase):
         self.base.shutdown()
 
     def test_httpserver(self):
-        statshttpd = stats_httpd.StatsHttpd()
-        self.assertEqual(type(statshttpd.httpd), list)
-        self.assertEqual(len(statshttpd.httpd), 0)
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.assertEqual(type(self.stats_httpd.httpd), list)
+        self.assertEqual(len(self.stats_httpd.httpd), 1)
+        for httpd in self.stats_httpd.httpd:
+            self.assertTrue(isinstance(httpd, stats_httpd.HttpServer))
+        self.stats_httpd.stop()
 
 class TestStatsHttpdError(unittest.TestCase):
     """Tests for StatsHttpdError exception"""
@@ -292,28 +290,20 @@ class TestStatsHttpd(unittest.TestCase):
     def setUp(self):
         self.base = BaseModules()
         self.stats_server = ThreadingServerManager(MyStats)
-        self.stats = self.stats_server.server
         self.stats_server.run()
-        self.stats_httpd = stats_httpd.StatsHttpd()
-
         # checking IPv6 enabled on this platform
-        self.ipv6_enabled = True
-        try:
-            sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-            sock.bind(("::1",8000))
-            sock.close()
-        except socket.error:
-            self.ipv6_enabled = False
+        self.ipv6_enabled = is_ipv6_enabled()
 
     def tearDown(self):
-        self.stats_httpd.stop()
         self.stats_server.shutdown()
         self.base.shutdown()
 
     def test_init(self):
+        server_address = get_availaddr()
+        self.stats_httpd = MyStatsHttpd(server_address)
         self.assertEqual(self.stats_httpd.running, False)
         self.assertEqual(self.stats_httpd.poll_intval, 0.5)
-        self.assertEqual(self.stats_httpd.httpd, [])
+        self.assertNotEqual(len(self.stats_httpd.httpd), 0)
         self.assertEqual(type(self.stats_httpd.mccs), isc.config.ModuleCCSession)
         self.assertEqual(type(self.stats_httpd.cc_session), isc.cc.Session)
         self.assertEqual(len(self.stats_httpd.config), 2)
@@ -321,144 +311,164 @@ class TestStatsHttpd(unittest.TestCase):
         self.assertEqual(len(self.stats_httpd.config['listen_on']), 1)
         self.assertTrue('address' in self.stats_httpd.config['listen_on'][0])
         self.assertTrue('port' in self.stats_httpd.config['listen_on'][0])
-        self.assertTrue(('127.0.0.1', 8000) in set(self.stats_httpd.http_addrs))
+        self.assertTrue(server_address in set(self.stats_httpd.http_addrs))
+        self.stats_httpd.stop()
 
     def test_openclose_mccs(self):
-        statshttpd = stats_httpd.StatsHttpd()
-        statshttpd.close_mccs()
-        self.assertEqual(statshttpd.mccs, None)
-        statshttpd.open_mccs()
-        self.assertIsNotNone(statshttpd.mccs)
-        statshttpd.mccs = None
-        self.assertEqual(statshttpd.mccs, None)
-        self.assertEqual(statshttpd.close_mccs(), None)
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.stats_httpd.close_mccs()
+        self.assertEqual(self.stats_httpd.mccs, None)
+        self.stats_httpd.open_mccs()
+        self.assertIsNotNone(self.stats_httpd.mccs)
+        self.stats_httpd.mccs = None
+        self.assertEqual(self.stats_httpd.mccs, None)
+        self.assertEqual(self.stats_httpd.close_mccs(), None)
+        self.stats_httpd.stop()
 
     def test_mccs(self):
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
         self.assertIsNotNone(self.stats_httpd.mccs.get_socket())
         self.assertTrue(
             isinstance(self.stats_httpd.mccs.get_socket(), socket.socket))
         self.assertTrue(
             isinstance(self.stats_httpd.cc_session, isc.cc.session.Session))
-        self.statistics_spec = self.stats_httpd.get_stats_spec()
+        statistics_spec = self.stats_httpd.get_stats_spec()
         for mod in DUMMY_DATA:
-            self.assertTrue(mod in self.statistics_spec)
-            for cfg in self.statistics_spec[mod]:
+            self.assertTrue(mod in statistics_spec)
+            for cfg in statistics_spec[mod]:
                 self.assertTrue('item_name' in cfg)
                 self.assertTrue(cfg['item_name'] in DUMMY_DATA[mod])
-            self.assertTrue(len(self.statistics_spec[mod]), len(DUMMY_DATA[mod]))
+            self.assertTrue(len(statistics_spec[mod]), len(DUMMY_DATA[mod]))
         self.stats_httpd.close_mccs()
         self.assertIsNone(self.stats_httpd.mccs)
+        self.stats_httpd.stop()
 
     def test_httpd(self):
         # dual stack (addresses is ipv4 and ipv6)
         if self.ipv6_enabled:
-            self.assertTrue(('127.0.0.1', 8000) in set(self.stats_httpd.http_addrs))
-            self.stats_httpd.http_addrs = [ ('::1', 8000), ('127.0.0.1', 8000) ]
-            self.assertTrue(
-                stats_httpd.HttpServer.address_family in set([socket.AF_INET, socket.AF_INET6]))
-            self.stats_httpd.open_httpd()
+            server_addresses = (get_availaddr('::1'), get_availaddr())
+            self.stats_httpd = MyStatsHttpd(*server_addresses)
             for ht in self.stats_httpd.httpd:
+                self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+                self.assertTrue(ht.address_family in set([socket.AF_INET, socket.AF_INET6]))
                 self.assertTrue(isinstance(ht.socket, socket.socket))
-            self.stats_httpd.close_httpd()
+            self.stats_httpd.stop()
 
         # dual stack (address is ipv6)
         if self.ipv6_enabled:
-            self.stats_httpd.http_addrs = [ ('::1', 8000) ]
-            self.stats_httpd.open_httpd()
+            server_addresses = get_availaddr('::1')
+            self.stats_httpd = MyStatsHttpd(server_addresses)
             for ht in self.stats_httpd.httpd:
+                self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+                self.assertEqual(ht.address_family, socket.AF_INET6)
                 self.assertTrue(isinstance(ht.socket, socket.socket))
-            self.stats_httpd.close_httpd()
+            self.stats_httpd.stop()
 
         # dual stack (address is ipv4)
         if self.ipv6_enabled:
-            self.stats_httpd.http_addrs = [ ('127.0.0.1', 8000) ]
-            self.stats_httpd.open_httpd()
+            server_addresses = get_availaddr()
+            self.stats_httpd = MyStatsHttpd(server_addresses)
             for ht in self.stats_httpd.httpd:
+                self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+                self.assertEqual(ht.address_family, socket.AF_INET)
                 self.assertTrue(isinstance(ht.socket, socket.socket))
-            self.stats_httpd.close_httpd()
+            self.stats_httpd.stop()
 
         # only-ipv4 single stack
         if not self.ipv6_enabled:
-            self.stats_httpd.http_addrs = [ ('127.0.0.1', 8000) ]
-            self.stats_httpd.open_httpd()
+            server_addresses = get_availaddr()
+            self.stats_httpd = MyStatsHttpd(server_addresses)
             for ht in self.stats_httpd.httpd:
+                self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+                self.assertEqual(ht.address_family, socket.AF_INET)
                 self.assertTrue(isinstance(ht.socket, socket.socket))
-            self.stats_httpd.close_httpd()
+            self.stats_httpd.stop()
 
-        # only-ipv4 single stack (force set ipv6 )
-        if not self.ipv6_enabled:
-            self.stats_httpd.http_addrs = [ ('::1', 8000) ]
-            self.assertRaises(stats_httpd.HttpServerError,
-                              self.stats_httpd.open_httpd)
+        # any address (IPv4)
+        server_addresses = get_availaddr(address='0.0.0.0')
+        self.stats_httpd = MyStatsHttpd(server_addresses)
+        for ht in self.stats_httpd.httpd:
+            self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+            self.assertEqual(ht.address_family,socket.AF_INET)
+            self.assertTrue(isinstance(ht.socket, socket.socket))
+        self.stats_httpd.stop()
+
+        # any address (IPv6)
+        if self.ipv6_enabled:
+            server_addresses = get_availaddr(address='::')
+            self.stats_httpd = MyStatsHttpd(server_addresses)
+            for ht in self.stats_httpd.httpd:
+                self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+                self.assertEqual(ht.address_family,socket.AF_INET6)
+                self.assertTrue(isinstance(ht.socket, socket.socket))
+            self.stats_httpd.stop()
 
         # hostname
-        self.stats_httpd.http_addrs = [ ('localhost', 8000) ]
-        self.stats_httpd.open_httpd()
+        server_addresses = get_availaddr(address='localhost')
+        self.stats_httpd = MyStatsHttpd(server_addresses)
         for ht in self.stats_httpd.httpd:
+            self.assertTrue(isinstance(ht, stats_httpd.HttpServer))
+            self.assertTrue(ht.address_family in set([socket.AF_INET, socket.AF_INET6]))
             self.assertTrue(isinstance(ht.socket, socket.socket))
-        self.stats_httpd.close_httpd()
+        self.stats_httpd.stop()
 
-        self.stats_httpd.http_addrs = [ ('my.host.domain', 8000) ]
-        self.assertRaises(stats_httpd.HttpServerError, self.stats_httpd.open_httpd)
-        self.assertEqual(type(self.stats_httpd.httpd), list)
-        self.assertEqual(len(self.stats_httpd.httpd), 0)
-        self.stats_httpd.close_httpd()
+        # nonexistent hostname
+        self.assertRaises(stats_httpd.HttpServerError, MyStatsHttpd, ('my.host.domain', 8000))
 
         # over flow of port number
-        self.stats_httpd.http_addrs = [ ('', 80000) ]
-        self.assertRaises(stats_httpd.HttpServerError, self.stats_httpd.open_httpd)
+        self.assertRaises(stats_httpd.HttpServerError, MyStatsHttpd, ('127.0.0.1', 80000))
 
         # negative
-        self.stats_httpd.http_addrs = [ ('', -8000) ]
-        self.assertRaises(stats_httpd.HttpServerError, self.stats_httpd.open_httpd)
+        self.assertRaises(stats_httpd.HttpServerError, MyStatsHttpd, ('127.0.0.1', -8000))
 
         # alphabet
-        self.stats_httpd.http_addrs = [ ('', 'ABCDE') ]
-        self.assertRaises(stats_httpd.HttpServerError, self.stats_httpd.open_httpd)
+        self.assertRaises(stats_httpd.HttpServerError, MyStatsHttpd, ('127.0.0.1', 'ABCDE'))
 
         # Address already in use
-        self.statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        self.statshttpd_server.server.load_config({'listen_on' : [{ 'address': '127.0.0.1', 'port' : 65454 }]})
-        self.statshttpd_server.run()
-        self.stats_httpd.load_config({'listen_on' : [{ 'address': '127.0.0.1', 'port' : 65454 }]})
-        self.assertRaises(stats_httpd.HttpServerError, self.stats_httpd.open_httpd)
-        self.statshttpd_server.shutdown()
+        server_addresses = get_availaddr()
+        self.stats_httpd_server = ThreadingServerManager(MyStatsHttpd, server_addresses)
+        self.stats_httpd_server.run()
+        self.assertRaises(stats_httpd.HttpServerError, MyStatsHttpd, server_addresses)
+        self.stats_httpd_server.shutdown()
 
     def test_running(self):
+        self.stats_httpd_server = ThreadingServerManager(MyStatsHttpd, get_availaddr())
+        self.stats_httpd = self.stats_httpd_server.server
         self.assertFalse(self.stats_httpd.running)
-        self.statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        self.stats_httpd = self.statshttpd_server.server
-        self.stats_httpd.load_config({'listen_on' : [{ 'address': '127.0.0.1', 'port' : 65455 }]})
-        self.statshttpd_server.run()
+        self.stats_httpd_server.run()
         self.assertTrue(self.stats_httpd.running)
-        self.statshttpd_server.shutdown()
+        send_shutdown("StatsHttpd")
         self.assertFalse(self.stats_httpd.running)
+        self.stats_httpd_server.shutdown()
 
         # failure case
-        self.stats_httpd = stats_httpd.StatsHttpd()
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
         self.stats_httpd.cc_session.close()
-        self.assertRaises(
-            isc.cc.session.SessionError, self.stats_httpd.start)
+        self.assertRaises(ValueError, self.stats_httpd.start)
+        self.stats_httpd.stop()
 
-    def test_select_failure(self):
+    def test_select_failure1(self):
         def raise_select_except(*args):
             raise select.error('dummy error')
-        def raise_select_except_with_errno(*args):
+        orig_select = stats_httpd.select.select
+        stats_httpd.select.select = raise_select_except
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.assertRaises(select.error, self.stats_httpd.start)
+        self.stats_httpd.stop()
+        stats_httpd.select.select = orig_select
+
+    def test_select_failure2(self):
+        def raise_select_except(*args):
             raise select.error(errno.EINTR)
-        (address, port) = ('127.0.0.1', 65456)
+        orig_select = stats_httpd.select.select
         stats_httpd.select.select = raise_select_except
-        statshttpd = stats_httpd.StatsHttpd()
-        statshttpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
-        self.assertRaises(select.error, statshttpd.start)
-        statshttpd.stop()
-        stats_httpd.select.select = raise_select_except_with_errno
-        statshttpd_server = ThreadingServerManager(MyStatsHttpd)
-        statshttpd = statshttpd_server.server
-        statshttpd.load_config({'listen_on' : [{ 'address': address, 'port' : port }]})
-        statshttpd_server.run()
-        statshttpd_server.shutdown()
+        self.stats_httpd_server = ThreadingServerManager(MyStatsHttpd, get_availaddr())
+        self.stats_httpd_server.run()
+        self.stats_httpd_server.shutdown()
+        stats_httpd.select.select = orig_select
 
     def test_open_template(self):
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
         # successful conditions
         tmpl = self.stats_httpd.open_template(stats_httpd.XML_TEMPLATE_LOCATION)
         self.assertTrue(isinstance(tmpl, string.Template))
@@ -490,8 +500,10 @@ class TestStatsHttpd(unittest.TestCase):
         self.assertRaises(
             IOError,
             self.stats_httpd.open_template, '/path/to/foo/bar')
+        self.stats_httpd.stop()
 
     def test_commands(self):
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
         self.assertEqual(self.stats_httpd.command_handler("status", None),
                          isc.config.ccsession.create_answer(
                 0, "Stats Httpd is up. (PID " + str(os.getpid()) + ")"))
@@ -504,75 +516,81 @@ class TestStatsHttpd(unittest.TestCase):
             self.stats_httpd.command_handler("__UNKNOWN_COMMAND__", None),
             isc.config.ccsession.create_answer(
                 1, "Unknown command: __UNKNOWN_COMMAND__"))
+        self.stats_httpd.stop()
 
     def test_config(self):
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
         self.assertEqual(
             self.stats_httpd.config_handler(dict(_UNKNOWN_KEY_=None)),
             isc.config.ccsession.create_answer(
-                1, "Unknown known config: _UNKNOWN_KEY_"))
+                1, "unknown item _UNKNOWN_KEY_"))
 
+        addresses = get_availaddr()
         self.assertEqual(
             self.stats_httpd.config_handler(
-                dict(listen_on=[dict(address="127.0.0.1",port=8000)])),
+                dict(listen_on=[dict(address=addresses[0],port=addresses[1])])),
             isc.config.ccsession.create_answer(0))
         self.assertTrue("listen_on" in self.stats_httpd.config)
         for addr in self.stats_httpd.config["listen_on"]:
             self.assertTrue("address" in addr)
             self.assertTrue("port" in addr)
-            self.assertTrue(addr["address"] == "127.0.0.1")
-            self.assertTrue(addr["port"] == 8000)
+            self.assertTrue(addr["address"] == addresses[0])
+            self.assertTrue(addr["port"] == addresses[1])
 
         if self.ipv6_enabled:
+            addresses = get_availaddr("::1")
             self.assertEqual(
                 self.stats_httpd.config_handler(
-                    dict(listen_on=[dict(address="::1",port=8000)])),
+                dict(listen_on=[dict(address=addresses[0],port=addresses[1])])),
                 isc.config.ccsession.create_answer(0))
             self.assertTrue("listen_on" in self.stats_httpd.config)
             for addr in self.stats_httpd.config["listen_on"]:
                 self.assertTrue("address" in addr)
                 self.assertTrue("port" in addr)
-                self.assertTrue(addr["address"] == "::1")
-                self.assertTrue(addr["port"] == 8000)
+                self.assertTrue(addr["address"] == addresses[0])
+                self.assertTrue(addr["port"] == addresses[1])
 
+        addresses = get_availaddr()
         self.assertEqual(
             self.stats_httpd.config_handler(
-                        dict(listen_on=[dict(address="127.0.0.1",port=54321)])),
+                dict(listen_on=[dict(address=addresses[0],port=addresses[1])])),
             isc.config.ccsession.create_answer(0))
         self.assertTrue("listen_on" in self.stats_httpd.config)
         for addr in self.stats_httpd.config["listen_on"]:
             self.assertTrue("address" in addr)
             self.assertTrue("port" in addr)
-            self.assertTrue(addr["address"] == "127.0.0.1")
-            self.assertTrue(addr["port"] == 54321)
+            self.assertTrue(addr["address"] == addresses[0])
+            self.assertTrue(addr["port"] == addresses[1])
         (ret, arg) = isc.config.ccsession.parse_answer(
             self.stats_httpd.config_handler(
                 dict(listen_on=[dict(address="1.2.3.4",port=543210)]))
             )
         self.assertEqual(ret, 1)
+        self.stats_httpd.stop()
 
     def test_xml_handler(self):
-        orig_get_stats_data = stats_httpd.StatsHttpd.get_stats_data
-        stats_httpd.StatsHttpd.get_stats_data = lambda x: \
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.stats_httpd.get_stats_data = lambda: \
             { 'Dummy' : { 'foo':'bar' } }
-        xml_body1 = stats_httpd.StatsHttpd().open_template(
+        xml_body1 = self.stats_httpd.open_template(
             stats_httpd.XML_TEMPLATE_LOCATION).substitute(
             xml_string='<Dummy><foo>bar</foo></Dummy>',
             xsd_namespace=stats_httpd.XSD_NAMESPACE,
             xsd_url_path=stats_httpd.XSD_URL_PATH,
             xsl_url_path=stats_httpd.XSL_URL_PATH)
-        xml_body2 = stats_httpd.StatsHttpd().xml_handler()
+        xml_body2 = self.stats_httpd.xml_handler()
         self.assertEqual(type(xml_body1), str)
         self.assertEqual(type(xml_body2), str)
         self.assertEqual(xml_body1, xml_body2)
-        stats_httpd.StatsHttpd.get_stats_data = lambda x: \
+        self.stats_httpd.get_stats_data = lambda: \
             { 'Dummy' : {'bar':'foo'} }
-        xml_body2 = stats_httpd.StatsHttpd().xml_handler()
+        xml_body2 = self.stats_httpd.xml_handler()
         self.assertNotEqual(xml_body1, xml_body2)
-        stats_httpd.StatsHttpd.get_stats_data = orig_get_stats_data
+        self.stats_httpd.stop()
 
     def test_xsd_handler(self):
-        orig_get_stats_spec = stats_httpd.StatsHttpd.get_stats_spec
-        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.stats_httpd.get_stats_spec = lambda: \
             { "Dummy" :
                   [{
                         "item_name": "foo",
@@ -583,7 +601,7 @@ class TestStatsHttpd(unittest.TestCase):
                         "item_title": "Foo"
                         }]
               }
-        xsd_body1 = stats_httpd.StatsHttpd().open_template(
+        xsd_body1 = self.stats_httpd.open_template(
             stats_httpd.XSD_TEMPLATE_LOCATION).substitute(
             xsd_string=\
                 '<all><element name="Dummy"><complexType><all>' \
@@ -593,11 +611,11 @@ class TestStatsHttpd(unittest.TestCase):
                 + '</annotation></element></all>' \
                 + '</complexType></element></all>',
             xsd_namespace=stats_httpd.XSD_NAMESPACE)
-        xsd_body2 = stats_httpd.StatsHttpd().xsd_handler()
+        xsd_body2 = self.stats_httpd.xsd_handler()
         self.assertEqual(type(xsd_body1), str)
         self.assertEqual(type(xsd_body2), str)
         self.assertEqual(xsd_body1, xsd_body2)
-        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+        self.stats_httpd.get_stats_spec = lambda: \
             { "Dummy" :
                   [{
                         "item_name": "bar",
@@ -608,13 +626,13 @@ class TestStatsHttpd(unittest.TestCase):
                         "item_title": "bar"
                         }]
               }
-        xsd_body2 = stats_httpd.StatsHttpd().xsd_handler()
+        xsd_body2 = self.stats_httpd.xsd_handler()
         self.assertNotEqual(xsd_body1, xsd_body2)
-        stats_httpd.StatsHttpd.get_stats_spec = orig_get_stats_spec
+        self.stats_httpd.stop()
 
     def test_xsl_handler(self):
-        orig_get_stats_spec = stats_httpd.StatsHttpd.get_stats_spec
-        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+        self.stats_httpd = MyStatsHttpd(get_availaddr())
+        self.stats_httpd.get_stats_spec = lambda: \
             { "Dummy" :
                   [{
                         "item_name": "foo",
@@ -625,7 +643,7 @@ class TestStatsHttpd(unittest.TestCase):
                         "item_title": "Foo"
                         }]
               }
-        xsl_body1 = stats_httpd.StatsHttpd().open_template(
+        xsl_body1 = self.stats_httpd.open_template(
             stats_httpd.XSL_TEMPLATE_LOCATION).substitute(
             xsl_string='<xsl:template match="*"><tr>' \
                 + '<td>Dummy</td>' \
@@ -633,11 +651,11 @@ class TestStatsHttpd(unittest.TestCase):
                 + '<td><xsl:value-of select="Dummy/foo" /></td>' \
                 + '</tr></xsl:template>',
             xsd_namespace=stats_httpd.XSD_NAMESPACE)
-        xsl_body2 = stats_httpd.StatsHttpd().xsl_handler()
+        xsl_body2 = self.stats_httpd.xsl_handler()
         self.assertEqual(type(xsl_body1), str)
         self.assertEqual(type(xsl_body2), str)
         self.assertEqual(xsl_body1, xsl_body2)
-        stats_httpd.StatsHttpd.get_stats_spec = lambda x: \
+        self.stats_httpd.get_stats_spec = lambda: \
             { "Dummy" :
                   [{
                         "item_name": "bar",
@@ -648,9 +666,9 @@ class TestStatsHttpd(unittest.TestCase):
                         "item_title": "bar"
                         }]
               }
-        xsl_body2 = stats_httpd.StatsHttpd().xsl_handler()
+        xsl_body2 = self.stats_httpd.xsl_handler()
         self.assertNotEqual(xsl_body1, xsl_body2)
-        stats_httpd.StatsHttpd.get_stats_spec = orig_get_stats_spec
+        self.stats_httpd.stop()
 
     def test_for_without_B10_FROM_SOURCE(self):
         # just lets it go through the code without B10_FROM_SOURCE env

+ 18 - 14
src/bin/stats/tests/b10-stats_test.py

@@ -30,7 +30,7 @@ import imp
 
 import stats
 import isc.cc.session
-from test_utils import BaseModules, ThreadingServerManager, MyStats, send_command, TIMEOUT_SEC
+from test_utils import BaseModules, ThreadingServerManager, MyStats, send_command, send_shutdown
 
 # set test name for logger
 isc.log.init("b10-stats_test")
@@ -189,27 +189,31 @@ class TestStats(unittest.TestCase):
 
     def test_start(self):
         # start without err
-        statsserver = ThreadingServerManager(MyStats)
-        statsd = statsserver.server
-        self.assertFalse(statsd.running)
-        statsserver.run()
-        self.assertTrue(statsd.running)
-        statsserver.shutdown()
-        self.assertFalse(statsd.running)
+        self.stats_server = ThreadingServerManager(MyStats)
+        self.stats = self.stats_server.server
+        self.assertFalse(self.stats.running)
+        self.stats_server.run()
+        self.assertTrue(self.stats.running)
+        send_shutdown("Stats")
+        self.assertFalse(self.stats.running)
+        self.stats_server.shutdown()
 
         # start with err
-        statsd = stats.Stats()
-        statsd.update_statistics_data = lambda x,**y: ['an error']
-        self.assertRaises(stats.StatsError, statsd.start)
+        self.stats = stats.Stats()
+        self.stats.update_statistics_data = lambda x,**y: ['an error']
+        self.assertRaises(stats.StatsError, self.stats.start)
 
     def test_handlers(self):
+        self.stats_server = ThreadingServerManager(MyStats)
+        self.stats = self.stats_server.server
+        self.stats_server.run()
         # config_handler
         self.assertEqual(self.stats.config_handler({'foo':'bar'}),
                          isc.config.create_answer(0))
 
         # command_handler
-        statsserver = ThreadingServerManager(MyStats)
-        statsserver.run()
+        self.base.boss.server._started.wait()
+        self.base.boss.server._started.clear()
         self.assertEqual(
             send_command(
                 'show', 'Stats',
@@ -279,7 +283,7 @@ class TestStats(unittest.TestCase):
             send_command('__UNKNOWN__', 'Stats'),
             (1, "Unknown command: '__UNKNOWN__'"))
 
-        statsserver.shutdown()
+        self.stats_server.shutdown()
 
     def test_update_modules(self):
         self.assertEqual(len(self.stats.modules), 0)