Browse Source

[2595] Address more review comments

- small updates to doc and comments
- changed __try_login to _try_login
- replaced direct calls to print() with self._print(), to override in tests
  (and more easily check output)
- added some more test cases
- made cmdctl tests fail is environment vars aren't set
Jelte Jansen 12 years ago
parent
commit
64dd4df0a8

+ 78 - 62
src/bin/bindctl/bindcmd.py

@@ -124,6 +124,11 @@ class BindCmdInterpreter(Cmd):
             self.csv_file_dir = pwd.getpwnam(getpass.getuser()).pw_dir + \
             self.csv_file_dir = pwd.getpwnam(getpass.getuser()).pw_dir + \
                 os.sep + '.bind10' + os.sep
                 os.sep + '.bind10' + os.sep
 
 
+    def _print(self, *args):
+        '''Simple wrapper around calls to print that can be overridden in
+           unit tests.'''
+        print(*args)
+
     def _get_session_id(self):
     def _get_session_id(self):
         '''Generate one session id for the connection. '''
         '''Generate one session id for the connection. '''
         rand = os.urandom(16)
         rand = os.urandom(16)
@@ -151,19 +156,19 @@ WARNING: Python readline module isn't available, so the command line editor
                 return 1
                 return 1
 
 
             self.cmdloop()
             self.cmdloop()
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
             return 0
         except FailToLogin as err:
         except FailToLogin as err:
             # error already printed when this was raised, ignoring
             # error already printed when this was raised, ignoring
             return 1
             return 1
         except KeyboardInterrupt:
         except KeyboardInterrupt:
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
             return 0
         except socket.error as err:
         except socket.error as err:
-            print('Failed to send request, the connection is closed')
+            self._print('Failed to send request, the connection is closed')
             return 1
             return 1
         except http.client.CannotSendRequest:
         except http.client.CannotSendRequest:
-            print('Can not send request, the connection is busy')
+            self._print('Can not send request, the connection is busy')
             return 1
             return 1
 
 
     def _get_saved_user_info(self, dir, file_name):
     def _get_saved_user_info(self, dir, file_name):
@@ -182,7 +187,8 @@ WARNING: Python readline module isn't available, so the command line editor
             for row in users_info:
             for row in users_info:
                 users.append([row[0], row[1]])
                 users.append([row[0], row[1]])
         except (IOError, IndexError) as err:
         except (IOError, IndexError) as err:
-            print("Error reading saved username and password from %s%s: %s" % (dir, file_name, err))
+            self._print("Error reading saved username and password "
+                        "from %s%s: %s" % (dir, file_name, err))
         finally:
         finally:
             if csvfile:
             if csvfile:
                 csvfile.close()
                 csvfile.close()
@@ -202,20 +208,22 @@ WARNING: Python readline module isn't available, so the command line editor
             writer.writerow([username, passwd])
             writer.writerow([username, passwd])
             csvfile.close()
             csvfile.close()
         except IOError as err:
         except IOError as err:
-            print("Error saving user information:", err)
-            print("user info file name: %s%s" % (dir, file_name))
+            self._print("Error saving user information:", err)
+            self._print("user info file name: %s%s" % (dir, file_name))
             return False
             return False
 
 
         return True
         return True
 
 
-    def __try_login(self, username, password):
+    def _try_login(self, username, password):
         '''
         '''
-        Attempts to log in to bindctl by sending a POST with
+        Attempts to log in to cmdctl by sending a POST with
         the given username and password.
         the given username and password.
         On success of the POST (mind, not the login, only the network
         On success of the POST (mind, not the login, only the network
         operation), returns a tuple (response, data).
         operation), returns a tuple (response, data).
         On failure, raises a FailToLogin exception, and prints some
         On failure, raises a FailToLogin exception, and prints some
         information on the failure.
         information on the failure.
+        This call is essentially 'private', but made 'protected' for
+        easier testing.
         '''
         '''
         param = {'username': username, 'password' : password}
         param = {'username': username, 'password' : password}
         try:
         try:
@@ -223,11 +231,12 @@ WARNING: Python readline module isn't available, so the command line editor
             data = response.read().decode()
             data = response.read().decode()
             return (response, data)
             return (response, data)
         except socket.error as err:
         except socket.error as err:
-            print("Socket error while sending login information: ", err)
+            self._print("Socket error while sending login information: ", err)
             if err.errno == errno.ECONNRESET:
             if err.errno == errno.ECONNRESET:
-                print("Please check the logs of b10-cmdctl, there may be a "
-                      "problem accepting SSL connections, such as a "
-                      "permission problem on the server certificate file.")
+                self._print("Please check the logs of b10-cmdctl, there may "
+                            "be a problem accepting SSL connections, such "
+                            "as a permission problem on the server "
+                            "certificate file.")
             raise FailToLogin()
             raise FailToLogin()
 
 
     def login_to_cmdctl(self):
     def login_to_cmdctl(self):
@@ -240,30 +249,30 @@ WARNING: Python readline module isn't available, so the command line editor
         # Look at existing username/password combinations and try to log in
         # Look at existing username/password combinations and try to log in
         users = self._get_saved_user_info(self.csv_file_dir, CSV_FILE_NAME)
         users = self._get_saved_user_info(self.csv_file_dir, CSV_FILE_NAME)
         for row in users:
         for row in users:
-            response, data = self.__try_login(row[0], row[1])
+            response, data = self._try_login(row[0], row[1])
 
 
             if response.status == http.client.OK:
             if response.status == http.client.OK:
                 # Is interactive?
                 # Is interactive?
                 if sys.stdin.isatty():
                 if sys.stdin.isatty():
-                    print(data + ' login as ' + row[0])
+                    self._print(data + ' login as ' + row[0])
                 return True
                 return True
 
 
         # No valid logins were found, prompt the user for a username/password
         # No valid logins were found, prompt the user for a username/password
         count = 0
         count = 0
-        print('No stored password file found, please see sections '
+        self._print('No stored password file found, please see sections '
               '"Configuration specification for b10-cmdctl" and "bindctl '
               '"Configuration specification for b10-cmdctl" and "bindctl '
               'command-line options" of the BIND 10 guide.')
               'command-line options" of the BIND 10 guide.')
         while True:
         while True:
             count = count + 1
             count = count + 1
             if count > 3:
             if count > 3:
-                print("Too many authentication failures")
+                self._print("Too many authentication failures")
                 return False
                 return False
 
 
             username = input("Username: ")
             username = input("Username: ")
             passwd = getpass.getpass()
             passwd = getpass.getpass()
 
 
-            response, data = self.__try_login(username, passwd)
-            print(data)
+            response, data = self._try_login(username, passwd)
+            self._print(data)
 
 
             if response.status == http.client.OK:
             if response.status == http.client.OK:
                 self._save_user_info(username, passwd, self.csv_file_dir,
                 self._save_user_info(username, passwd, self.csv_file_dir,
@@ -461,25 +470,26 @@ WARNING: Python readline module isn't available, so the command line editor
         pass
         pass
 
 
     def do_help(self, name):
     def do_help(self, name):
-        print(CONST_BINDCTL_HELP)
+        self._print(CONST_BINDCTL_HELP)
         for k in self.modules.values():
         for k in self.modules.values():
             n = k.get_name()
             n = k.get_name()
             if len(n) >= CONST_BINDCTL_HELP_INDENT_WIDTH:
             if len(n) >= CONST_BINDCTL_HELP_INDENT_WIDTH:
-                print("    %s" % n)
-                print(textwrap.fill(k.get_desc(),
-                      initial_indent="            ",
-                      subsequent_indent="    " +
-                      " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                      width=70))
+                self._print("    %s" % n)
+                self._print(textwrap.fill(k.get_desc(),
+                            initial_indent="            ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
             else:
             else:
-                print(textwrap.fill("%s%s%s" %
-                    (k.get_name(),
-                     " "*(CONST_BINDCTL_HELP_INDENT_WIDTH - len(k.get_name())),
-                     k.get_desc()),
-                    initial_indent="    ",
-                    subsequent_indent="    " +
-                    " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                    width=70))
+                self._print(textwrap.fill("%s%s%s" %
+                            (k.get_name(),
+                            " "*(CONST_BINDCTL_HELP_INDENT_WIDTH -
+                                 len(k.get_name())),
+                            k.get_desc()),
+                            initial_indent="    ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
 
 
     def onecmd(self, line):
     def onecmd(self, line):
         if line == 'EOF' or line.lower() == "quit":
         if line == 'EOF' or line.lower() == "quit":
@@ -654,20 +664,20 @@ WARNING: Python readline module isn't available, so the command line editor
             self._validate_cmd(cmd)
             self._validate_cmd(cmd)
             self._handle_cmd(cmd)
             self._handle_cmd(cmd)
         except (IOError, http.client.HTTPException) as err:
         except (IOError, http.client.HTTPException) as err:
-            print('Error: ', err)
+            self._print('Error: ', err)
         except BindCtlException as err:
         except BindCtlException as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
             self._print_correct_usage(err)
             self._print_correct_usage(err)
         except isc.cc.data.DataTypeError as err:
         except isc.cc.data.DataTypeError as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
         except isc.cc.data.DataTypeError as dte:
         except isc.cc.data.DataTypeError as dte:
-            print("Error: " + str(dte))
+            self._print("Error: " + str(dte))
         except isc.cc.data.DataNotFoundError as dnfe:
         except isc.cc.data.DataNotFoundError as dnfe:
-            print("Error: " + str(dnfe))
+            self._print("Error: " + str(dnfe))
         except isc.cc.data.DataAlreadyPresentError as dape:
         except isc.cc.data.DataAlreadyPresentError as dape:
-            print("Error: " + str(dape))
+            self._print("Error: " + str(dape))
         except KeyError as ke:
         except KeyError as ke:
-            print("Error: missing " + str(ke))
+            self._print("Error: missing " + str(ke))
 
 
     def _print_correct_usage(self, ept):
     def _print_correct_usage(self, ept):
         if isinstance(ept, CmdUnknownModuleSyntaxError):
         if isinstance(ept, CmdUnknownModuleSyntaxError):
@@ -716,7 +726,8 @@ WARNING: Python readline module isn't available, so the command line editor
             module_name = identifier.split('/')[1]
             module_name = identifier.split('/')[1]
             if module_name != "" and (self.config_data is None or \
             if module_name != "" and (self.config_data is None or \
                not self.config_data.have_specification(module_name)):
                not self.config_data.have_specification(module_name)):
-                print("Error: Module '" + module_name + "' unknown or not running")
+                self._print("Error: Module '" + module_name +
+                            "' unknown or not running")
                 return
                 return
 
 
         if cmd.command == "show":
         if cmd.command == "show":
@@ -730,7 +741,9 @@ WARNING: Python readline module isn't available, so the command line editor
                     #identifier
                     #identifier
                     identifier += cmd.params['argument']
                     identifier += cmd.params['argument']
                 else:
                 else:
-                    print("Error: unknown argument " + cmd.params['argument'] + ", or multiple identifiers given")
+                    self._print("Error: unknown argument " +
+                                cmd.params['argument'] +
+                                ", or multiple identifiers given")
                     return
                     return
             values = self.config_data.get_value_maps(identifier, show_all)
             values = self.config_data.get_value_maps(identifier, show_all)
             for value_map in values:
             for value_map in values:
@@ -758,13 +771,14 @@ WARNING: Python readline module isn't available, so the command line editor
                     line += "(default)"
                     line += "(default)"
                 if value_map['modified']:
                 if value_map['modified']:
                     line += "(modified)"
                     line += "(modified)"
-                print(line)
+                self._print(line)
         elif cmd.command == "show_json":
         elif cmd.command == "show_json":
             if identifier == "":
             if identifier == "":
-                print("Need at least the module to show the configuration in JSON format")
+                self._print("Need at least the module to show the "
+                            "configuration in JSON format")
             else:
             else:
                 data, default = self.config_data.get_value(identifier)
                 data, default = self.config_data.get_value(identifier)
-                print(json.dumps(data))
+                self._print(json.dumps(data))
         elif cmd.command == "add":
         elif cmd.command == "add":
             self.config_data.add_value(identifier,
             self.config_data.add_value(identifier,
                                        cmd.params.get('value_or_name'),
                                        cmd.params.get('value_or_name'),
@@ -776,7 +790,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 self.config_data.remove_value(identifier, None)
                 self.config_data.remove_value(identifier, None)
         elif cmd.command == "set":
         elif cmd.command == "set":
             if 'identifier' not in cmd.params:
             if 'identifier' not in cmd.params:
-                print("Error: missing identifier or value")
+                self._print("Error: missing identifier or value")
             else:
             else:
                 parsed_value = None
                 parsed_value = None
                 try:
                 try:
@@ -793,9 +807,9 @@ WARNING: Python readline module isn't available, so the command line editor
             try:
             try:
                 self.config_data.commit()
                 self.config_data.commit()
             except isc.config.ModuleCCSessionError as mcse:
             except isc.config.ModuleCCSessionError as mcse:
-                print(str(mcse))
+                self._print(str(mcse))
         elif cmd.command == "diff":
         elif cmd.command == "diff":
-            print(self.config_data.get_local_changes())
+            self._print(self.config_data.get_local_changes())
         elif cmd.command == "go":
         elif cmd.command == "go":
             self.go(identifier)
             self.go(identifier)
 
 
@@ -815,7 +829,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # check if exists, if not, revert and error
         # check if exists, if not, revert and error
         v,d = self.config_data.get_value(new_location)
         v,d = self.config_data.get_value(new_location)
         if v is None:
         if v is None:
-            print("Error: " + identifier + " not found")
+            self._print("Error: " + identifier + " not found")
             return
             return
 
 
         self.location = new_location
         self.location = new_location
@@ -830,7 +844,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 with open(command.params['filename']) as command_file:
                 with open(command.params['filename']) as command_file:
                     commands = command_file.readlines()
                     commands = command_file.readlines()
             except IOError as ioe:
             except IOError as ioe:
-                print("Error: " + str(ioe))
+                self._print("Error: " + str(ioe))
                 return
                 return
         elif command_sets.has_command_set(command.command):
         elif command_sets.has_command_set(command.command):
             commands = command_sets.get_commands(command.command)
             commands = command_sets.get_commands(command.command)
@@ -848,7 +862,7 @@ WARNING: Python readline module isn't available, so the command line editor
     def __show_execute_commands(self, commands):
     def __show_execute_commands(self, commands):
         '''Prints the command list without executing them'''
         '''Prints the command list without executing them'''
         for line in commands:
         for line in commands:
-            print(line.strip())
+            self._print(line.strip())
 
 
     def __apply_execute_commands(self, commands):
     def __apply_execute_commands(self, commands):
         '''Applies the configuration commands from the given iterator.
         '''Applies the configuration commands from the given iterator.
@@ -869,18 +883,19 @@ WARNING: Python readline module isn't available, so the command line editor
             for line in commands:
             for line in commands:
                 line = line.strip()
                 line = line.strip()
                 if verbose:
                 if verbose:
-                    print(line)
+                    self._print(line)
                 if line.startswith('#') or len(line) == 0:
                 if line.startswith('#') or len(line) == 0:
                     continue
                     continue
                 elif line.startswith('!'):
                 elif line.startswith('!'):
                     if re.match('^!echo ', line, re.I) and len(line) > 6:
                     if re.match('^!echo ', line, re.I) and len(line) > 6:
-                        print(line[6:])
+                        self._print(line[6:])
                     elif re.match('^!verbose\s+on\s*$', line, re.I):
                     elif re.match('^!verbose\s+on\s*$', line, re.I):
                         verbose = True
                         verbose = True
                     elif re.match('^!verbose\s+off$', line, re.I):
                     elif re.match('^!verbose\s+off$', line, re.I):
                         verbose = False
                         verbose = False
                     else:
                     else:
-                        print("Warning: ignoring unknown directive: " + line)
+                        self._print("Warning: ignoring unknown directive: " +
+                                    line)
                 else:
                 else:
                     cmd = BindCmdParser(line)
                     cmd = BindCmdParser(line)
                     self._validate_cmd(cmd)
                     self._validate_cmd(cmd)
@@ -891,12 +906,12 @@ WARNING: Python readline module isn't available, so the command line editor
                 isc.cc.data.DataNotFoundError,
                 isc.cc.data.DataNotFoundError,
                 isc.cc.data.DataAlreadyPresentError,
                 isc.cc.data.DataAlreadyPresentError,
                 KeyError) as err:
                 KeyError) as err:
-            print('Error: ', err)
-            print()
-            print('Depending on the contents of the script, and which')
-            print('commands it has called, there can be committed and')
-            print('local changes. It is advised to check your settings,')
-            print('and revert local changes with "config revert".')
+            self._print('Error: ', err)
+            self._print()
+            self._print('Depending on the contents of the script, and which')
+            self._print('commands it has called, there can be committed and')
+            self._print('local changes. It is advised to check your settings')
+            self._print(', and revert local changes with "config revert".')
 
 
     def apply_cmd(self, cmd):
     def apply_cmd(self, cmd):
         '''Handles a general module command'''
         '''Handles a general module command'''
@@ -910,6 +925,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # The reply is a string containing JSON data,
         # The reply is a string containing JSON data,
         # parse it, then prettyprint
         # parse it, then prettyprint
         if data != "" and data != "{}":
         if data != "" and data != "{}":
-            print(json.dumps(json.loads(data), sort_keys=True, indent=4))
+            self._print(json.dumps(json.loads(data), sort_keys=True,
+                                   indent=4))
 
 
 
 

+ 41 - 25
src/bin/bindctl/tests/bindctl_test.py

@@ -335,6 +335,8 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.add_module_info(mod_info)
         self.tool.add_module_info(mod_info)
         self.tool.config_data = FakeCCSession()
         self.tool.config_data = FakeCCSession()
         self.stdout_backup = sys.stdout
         self.stdout_backup = sys.stdout
+        self.printed_messages = []
+        self.tool._print = self.store_print
 
 
     def test_precmd(self):
     def test_precmd(self):
         def update_all_modules_info():
         def update_all_modules_info():
@@ -347,6 +349,12 @@ class TestConfigCommands(unittest.TestCase):
         precmd('EOF')
         precmd('EOF')
         self.assertRaises(socket.error, precmd, 'continue')
         self.assertRaises(socket.error, precmd, 'continue')
 
 
+    def store_print(self, *args):
+        '''Method to override _print in BindCmdInterpreter.
+           Instead of printing the values, appends the argument tuple
+           to the list in self.printed_messages'''
+        self.printed_messages.append(" ".join(map(str, args)))
+
     def test_try_login(self):
     def test_try_login(self):
         # Make sure __try_login raises the correct exception
         # Make sure __try_login raises the correct exception
         # upon failure of either send_POST or the read() on the
         # upon failure of either send_POST or the read() on the
@@ -358,20 +366,35 @@ class TestConfigCommands(unittest.TestCase):
                 raise socket.error("test error")
                 raise socket.error("test error")
 
 
             self.tool.send_POST = send_POST_raiseImmediately
             self.tool.send_POST = send_POST_raiseImmediately
-            self.assertRaises(FailToLogin,
-                              self.tool._BindCmdInterpreter__try_login,
-                              "foo", "bar")
-
-            def send_POST_raiseOnRead(self, params):
-                class MyResponse:
-                    def read(self):
-                        raise socket.error("read error")
-                return MyResponse()
-
-            self.tool.send_POST = send_POST_raiseOnRead
-            self.assertRaises(FailToLogin,
-                              self.tool._BindCmdInterpreter__try_login,
-                              "foo", "bar")
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            self.assertIn('Socket error while sending login information:  test error',
+                          self.printed_messages)
+            self.assertEqual(1, len(self.printed_messages))
+
+            def create_send_POST_raiseOnRead(exception):
+                '''Create a replacement send_POST() method that raises
+                   the given exception when read() is called on the value
+                   returned from send_POST()'''
+                def send_POST_raiseOnRead(self, params):
+                    class MyResponse:
+                        def read(self):
+                            raise exception
+                    return MyResponse()
+                return send_POST_raiseOnRead
+
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(socket.error("read error"))
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            self.assertIn('Socket error while sending login information:  read error',
+                          self.printed_messages)
+            self.assertEqual(2, len(self.printed_messages))
+
+            # any other exception should be passed through
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(ImportError())
+            self.assertRaises(ImportError, self.tool._try_login, "foo", "bar")
+            self.assertEqual(2, len(self.printed_messages))
+
         finally:
         finally:
             self.tool.send_POST = orig_send_POST
             self.tool.send_POST = orig_send_POST
 
 
@@ -388,29 +411,22 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.conn.sock = FakeSocket()
         self.tool.conn.sock = FakeSocket()
         self.tool.conn.sock.close()
         self.tool.conn.sock.close()
 
 
-        # validate log message for socket.err
-        socket_err_output = io.StringIO()
-        sys.stdout = socket_err_output
         self.assertEqual(1, self.tool.run())
         self.assertEqual(1, self.tool.run())
 
 
         # First few lines may be some kind of heading, or a warning that
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Failed to send request, the connection is closed",
         self.assertIn("Failed to send request, the connection is closed",
-                      socket_err_output.getvalue())
-
-        socket_err_output.close()
+                      self.printed_messages)
+        self.assertEqual(1, len(self.printed_messages))
 
 
         # validate log message for http.client.CannotSendRequest
         # validate log message for http.client.CannotSendRequest
-        cannot_send_output = io.StringIO()
-        sys.stdout = cannot_send_output
         self.assertEqual(1, self.tool.run())
         self.assertEqual(1, self.tool.run())
 
 
         # First few lines may be some kind of heading, or a warning that
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Can not send request, the connection is busy",
         self.assertIn("Can not send request, the connection is busy",
-                      cannot_send_output.getvalue())
-
-        cannot_send_output.close()
+                      self.printed_messages)
+        self.assertEqual(2, len(self.printed_messages))
 
 
     def test_apply_cfg_command_int(self):
     def test_apply_cfg_command_int(self):
         self.tool.location = '/'
         self.tool.location = '/'

+ 8 - 9
src/bin/cmdctl/cmdctl.py.in

@@ -289,10 +289,9 @@ class CommandControl():
             if key == 'version':
             if key == 'version':
                 continue
                 continue
             elif key in ['key_file', 'cert_file']:
             elif key in ['key_file', 'cert_file']:
-                # TODO, only check whether the file exist, is a
-                # file, and is readable
-                # further check need to be done: eg. whether
-                # the private/certificate is valid.
+                # TODO: we only check whether the file exist, is a
+                # file, and is readable; but further check need to be done:
+                # eg. whether the private/certificate is valid.
                 path = new_config[key]
                 path = new_config[key]
                 try:
                 try:
                     check_file(path)
                     check_file(path)
@@ -549,16 +548,16 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
                                       certfile = cert,
                                       certfile = cert,
                                       keyfile = key,
                                       keyfile = key,
                                       ssl_version = ssl.PROTOCOL_SSLv23)
                                       ssl_version = ssl.PROTOCOL_SSLv23)
+            # Return here (if control leaves this blocks it will raise an
+            # error)
             return ssl_sock
             return ssl_sock
         except ssl.SSLError as err:
         except ssl.SSLError as err:
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
-            self.close_request(sock)
-            # raise socket error to finish the request
-            raise socket.error
         except (CmdctlException, IOError) as cce:
         except (CmdctlException, IOError) as cce:
             logger.error(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
             logger.error(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
-            # raise socket error to finish the request
-            raise socket.error
+        self.close_request(sock)
+        # raise socket error to finish the request
+        raise socket.error
 
 
     def get_request(self):
     def get_request(self):
         '''Get client request socket and wrap it in SSL context. '''
         '''Get client request socket and wrap it in SSL context. '''

+ 18 - 6
src/bin/cmdctl/tests/cmdctl_test.py

@@ -22,13 +22,13 @@ import sys
 from cmdctl import *
 from cmdctl import *
 import isc.log
 import isc.log
 
 
-SRC_FILE_PATH = '..' + os.sep
-if 'CMDCTL_SRC_PATH' in os.environ:
-    SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
+assert 'CMDCTL_SRC_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
 
 
-BUILD_FILE_PATH = '..' + os.sep
-if 'CMDCTL_BUILD_PATH' in os.environ:
-    BUILD_FILE_PATH = os.environ['CMDCTL_BUILD_PATH'] + os.sep
+assert 'CMDCTL_BUILD_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+BUILD_FILE_PATH = os.environ['CMDCTL_BUILD_PATH'] + os.sep
 
 
 # Rewrite the class for unittest.
 # Rewrite the class for unittest.
 class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
@@ -535,6 +535,18 @@ class TestSecureHTTPServer(unittest.TestCase):
                                    BUILD_FILE_PATH + 'cmdctl-certfile.pem')
                                    BUILD_FILE_PATH + 'cmdctl-certfile.pem')
         self.assertIsInstance(ssl_sock, ssl.SSLSocket)
         self.assertIsInstance(ssl_sock, ssl.SSLSocket)
 
 
+        # wrap_socket can also raise IOError, which should be caught and
+        # handled like the other errors.
+        # Force this by temporarily disabling our own file checks
+        orig_check_func = self.server._check_key_and_cert
+        try:
+            self.server._check_key_and_cert = lambda x,y: None
+            self.assertRaises(socket.error,
+                              self.server._wrap_socket_in_ssl_context,
+                              sock,
+                              'no_such_file', 'no_such_file')
+        finally:
+            self.server._check_key_and_cert = orig_check_func
 
 
 class TestFuncNotInClass(unittest.TestCase):
 class TestFuncNotInClass(unittest.TestCase):
     def test_check_port(self):
     def test_check_port(self):