Parcourir la source

[master] Merge branch 'trac2595'

Jelte Jansen il y a 12 ans
Parent
commit
09b1a2f927

+ 107 - 68
src/bin/bindctl/bindcmd.py

@@ -39,6 +39,7 @@ import csv
 import pwd
 import pwd
 import getpass
 import getpass
 import copy
 import copy
+import errno
 
 
 try:
 try:
     from collections import OrderedDict
     from collections import OrderedDict
@@ -123,6 +124,11 @@ class BindCmdInterpreter(Cmd):
             self.csv_file_dir = pwd.getpwnam(getpass.getuser()).pw_dir + \
             self.csv_file_dir = pwd.getpwnam(getpass.getuser()).pw_dir + \
                 os.sep + '.bind10' + os.sep
                 os.sep + '.bind10' + os.sep
 
 
+    def _print(self, *args):
+        '''Simple wrapper around calls to print that can be overridden in
+           unit tests.'''
+        print(*args)
+
     def _get_session_id(self):
     def _get_session_id(self):
         '''Generate one session id for the connection. '''
         '''Generate one session id for the connection. '''
         rand = os.urandom(16)
         rand = os.urandom(16)
@@ -150,19 +156,19 @@ WARNING: Python readline module isn't available, so the command line editor
                 return 1
                 return 1
 
 
             self.cmdloop()
             self.cmdloop()
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
             return 0
         except FailToLogin as err:
         except FailToLogin as err:
             # error already printed when this was raised, ignoring
             # error already printed when this was raised, ignoring
             return 1
             return 1
         except KeyboardInterrupt:
         except KeyboardInterrupt:
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
             return 0
         except socket.error as err:
         except socket.error as err:
-            print('Failed to send request, the connection is closed')
+            self._print('Failed to send request, the connection is closed')
             return 1
             return 1
         except http.client.CannotSendRequest:
         except http.client.CannotSendRequest:
-            print('Can not send request, the connection is busy')
+            self._print('Can not send request, the connection is busy')
             return 1
             return 1
 
 
     def _get_saved_user_info(self, dir, file_name):
     def _get_saved_user_info(self, dir, file_name):
@@ -181,7 +187,8 @@ WARNING: Python readline module isn't available, so the command line editor
             for row in users_info:
             for row in users_info:
                 users.append([row[0], row[1]])
                 users.append([row[0], row[1]])
         except (IOError, IndexError) as err:
         except (IOError, IndexError) as err:
-            print("Error reading saved username and password from %s%s: %s" % (dir, file_name, err))
+            self._print("Error reading saved username and password "
+                        "from %s%s: %s" % (dir, file_name, err))
         finally:
         finally:
             if csvfile:
             if csvfile:
                 csvfile.close()
                 csvfile.close()
@@ -201,12 +208,48 @@ WARNING: Python readline module isn't available, so the command line editor
             writer.writerow([username, passwd])
             writer.writerow([username, passwd])
             csvfile.close()
             csvfile.close()
         except IOError as err:
         except IOError as err:
-            print("Error saving user information:", err)
-            print("user info file name: %s%s" % (dir, file_name))
+            self._print("Error saving user information:", err)
+            self._print("user info file name: %s%s" % (dir, file_name))
             return False
             return False
 
 
         return True
         return True
 
 
+    def __print_check_ssl_msg(self):
+        self._print("Please check the logs of b10-cmdctl, there may "
+                    "be a problem accepting SSL connections, such "
+                    "as a permission problem on the server "
+                    "certificate file.")
+
+    def _try_login(self, username, password):
+        '''
+        Attempts to log in to cmdctl by sending a POST with
+        the given username and password.
+        On success of the POST (mind, not the login, only the network
+        operation), returns a tuple (response, data).
+        On failure, raises a FailToLogin exception, and prints some
+        information on the failure.
+        This call is essentially 'private', but made 'protected' for
+        easier testing.
+        '''
+        param = {'username': username, 'password' : password}
+        try:
+            response = self.send_POST('/login', param)
+            data = response.read().decode()
+            # return here (will raise error after try block)
+            return (response, data)
+        except ssl.SSLError as err:
+            self._print("SSL error while sending login information: ", err)
+            if err.errno == ssl.SSL_ERROR_EOF:
+                self.__print_check_ssl_msg()
+        except socket.error as err:
+            self._print("Socket error while sending login information: ", err)
+            # An SSL setup error can also bubble up as a plain CONNRESET...
+            # (on some systems it usually does)
+            if err.errno == errno.ECONNRESET:
+                self.__print_check_ssl_msg()
+            pass
+        raise FailToLogin()
+
     def login_to_cmdctl(self):
     def login_to_cmdctl(self):
         '''Login to cmdctl with the username and password given by
         '''Login to cmdctl with the username and password given by
         the user. After the login is sucessful, the username and
         the user. After the login is sucessful, the username and
@@ -217,41 +260,30 @@ WARNING: Python readline module isn't available, so the command line editor
         # Look at existing username/password combinations and try to log in
         # Look at existing username/password combinations and try to log in
         users = self._get_saved_user_info(self.csv_file_dir, CSV_FILE_NAME)
         users = self._get_saved_user_info(self.csv_file_dir, CSV_FILE_NAME)
         for row in users:
         for row in users:
-            param = {'username': row[0], 'password' : row[1]}
-            try:
-                response = self.send_POST('/login', param)
-                data = response.read().decode()
-            except socket.error as err:
-                print("Socket error while sending login information:", err)
-                raise FailToLogin()
+            response, data = self._try_login(row[0], row[1])
 
 
             if response.status == http.client.OK:
             if response.status == http.client.OK:
                 # Is interactive?
                 # Is interactive?
                 if sys.stdin.isatty():
                 if sys.stdin.isatty():
-                    print(data + ' login as ' + row[0])
+                    self._print(data + ' login as ' + row[0])
                 return True
                 return True
 
 
         # No valid logins were found, prompt the user for a username/password
         # No valid logins were found, prompt the user for a username/password
         count = 0
         count = 0
-        print('No stored password file found, please see sections '
+        self._print('No stored password file found, please see sections '
               '"Configuration specification for b10-cmdctl" and "bindctl '
               '"Configuration specification for b10-cmdctl" and "bindctl '
               'command-line options" of the BIND 10 guide.')
               'command-line options" of the BIND 10 guide.')
         while True:
         while True:
             count = count + 1
             count = count + 1
             if count > 3:
             if count > 3:
-                print("Too many authentication failures")
+                self._print("Too many authentication failures")
                 return False
                 return False
 
 
             username = input("Username: ")
             username = input("Username: ")
             passwd = getpass.getpass()
             passwd = getpass.getpass()
-            param = {'username': username, 'password' : passwd}
-            try:
-                response = self.send_POST('/login', param)
-                data = response.read().decode()
-                print(data)
-            except socket.error as err:
-                print("Socket error while sending login information:", err)
-                raise FailToLogin()
+
+            response, data = self._try_login(username, passwd)
+            self._print(data)
 
 
             if response.status == http.client.OK:
             if response.status == http.client.OK:
                 self._save_user_info(username, passwd, self.csv_file_dir,
                 self._save_user_info(username, passwd, self.csv_file_dir,
@@ -449,25 +481,26 @@ WARNING: Python readline module isn't available, so the command line editor
         pass
         pass
 
 
     def do_help(self, name):
     def do_help(self, name):
-        print(CONST_BINDCTL_HELP)
+        self._print(CONST_BINDCTL_HELP)
         for k in self.modules.values():
         for k in self.modules.values():
             n = k.get_name()
             n = k.get_name()
             if len(n) >= CONST_BINDCTL_HELP_INDENT_WIDTH:
             if len(n) >= CONST_BINDCTL_HELP_INDENT_WIDTH:
-                print("    %s" % n)
-                print(textwrap.fill(k.get_desc(),
-                      initial_indent="            ",
-                      subsequent_indent="    " +
-                      " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                      width=70))
+                self._print("    %s" % n)
+                self._print(textwrap.fill(k.get_desc(),
+                            initial_indent="            ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
             else:
             else:
-                print(textwrap.fill("%s%s%s" %
-                    (k.get_name(),
-                     " "*(CONST_BINDCTL_HELP_INDENT_WIDTH - len(k.get_name())),
-                     k.get_desc()),
-                    initial_indent="    ",
-                    subsequent_indent="    " +
-                    " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                    width=70))
+                self._print(textwrap.fill("%s%s%s" %
+                            (k.get_name(),
+                            " "*(CONST_BINDCTL_HELP_INDENT_WIDTH -
+                                 len(k.get_name())),
+                            k.get_desc()),
+                            initial_indent="    ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
 
 
     def onecmd(self, line):
     def onecmd(self, line):
         if line == 'EOF' or line.lower() == "quit":
         if line == 'EOF' or line.lower() == "quit":
@@ -642,20 +675,20 @@ WARNING: Python readline module isn't available, so the command line editor
             self._validate_cmd(cmd)
             self._validate_cmd(cmd)
             self._handle_cmd(cmd)
             self._handle_cmd(cmd)
         except (IOError, http.client.HTTPException) as err:
         except (IOError, http.client.HTTPException) as err:
-            print('Error: ', err)
+            self._print('Error: ', err)
         except BindCtlException as err:
         except BindCtlException as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
             self._print_correct_usage(err)
             self._print_correct_usage(err)
         except isc.cc.data.DataTypeError as err:
         except isc.cc.data.DataTypeError as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
         except isc.cc.data.DataTypeError as dte:
         except isc.cc.data.DataTypeError as dte:
-            print("Error: " + str(dte))
+            self._print("Error: " + str(dte))
         except isc.cc.data.DataNotFoundError as dnfe:
         except isc.cc.data.DataNotFoundError as dnfe:
-            print("Error: " + str(dnfe))
+            self._print("Error: " + str(dnfe))
         except isc.cc.data.DataAlreadyPresentError as dape:
         except isc.cc.data.DataAlreadyPresentError as dape:
-            print("Error: " + str(dape))
+            self._print("Error: " + str(dape))
         except KeyError as ke:
         except KeyError as ke:
-            print("Error: missing " + str(ke))
+            self._print("Error: missing " + str(ke))
 
 
     def _print_correct_usage(self, ept):
     def _print_correct_usage(self, ept):
         if isinstance(ept, CmdUnknownModuleSyntaxError):
         if isinstance(ept, CmdUnknownModuleSyntaxError):
@@ -704,7 +737,8 @@ WARNING: Python readline module isn't available, so the command line editor
             module_name = identifier.split('/')[1]
             module_name = identifier.split('/')[1]
             if module_name != "" and (self.config_data is None or \
             if module_name != "" and (self.config_data is None or \
                not self.config_data.have_specification(module_name)):
                not self.config_data.have_specification(module_name)):
-                print("Error: Module '" + module_name + "' unknown or not running")
+                self._print("Error: Module '" + module_name +
+                            "' unknown or not running")
                 return
                 return
 
 
         if cmd.command == "show":
         if cmd.command == "show":
@@ -718,7 +752,9 @@ WARNING: Python readline module isn't available, so the command line editor
                     #identifier
                     #identifier
                     identifier += cmd.params['argument']
                     identifier += cmd.params['argument']
                 else:
                 else:
-                    print("Error: unknown argument " + cmd.params['argument'] + ", or multiple identifiers given")
+                    self._print("Error: unknown argument " +
+                                cmd.params['argument'] +
+                                ", or multiple identifiers given")
                     return
                     return
             values = self.config_data.get_value_maps(identifier, show_all)
             values = self.config_data.get_value_maps(identifier, show_all)
             for value_map in values:
             for value_map in values:
@@ -746,13 +782,14 @@ WARNING: Python readline module isn't available, so the command line editor
                     line += "(default)"
                     line += "(default)"
                 if value_map['modified']:
                 if value_map['modified']:
                     line += "(modified)"
                     line += "(modified)"
-                print(line)
+                self._print(line)
         elif cmd.command == "show_json":
         elif cmd.command == "show_json":
             if identifier == "":
             if identifier == "":
-                print("Need at least the module to show the configuration in JSON format")
+                self._print("Need at least the module to show the "
+                            "configuration in JSON format")
             else:
             else:
                 data, default = self.config_data.get_value(identifier)
                 data, default = self.config_data.get_value(identifier)
-                print(json.dumps(data))
+                self._print(json.dumps(data))
         elif cmd.command == "add":
         elif cmd.command == "add":
             self.config_data.add_value(identifier,
             self.config_data.add_value(identifier,
                                        cmd.params.get('value_or_name'),
                                        cmd.params.get('value_or_name'),
@@ -764,7 +801,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 self.config_data.remove_value(identifier, None)
                 self.config_data.remove_value(identifier, None)
         elif cmd.command == "set":
         elif cmd.command == "set":
             if 'identifier' not in cmd.params:
             if 'identifier' not in cmd.params:
-                print("Error: missing identifier or value")
+                self._print("Error: missing identifier or value")
             else:
             else:
                 parsed_value = None
                 parsed_value = None
                 try:
                 try:
@@ -781,9 +818,9 @@ WARNING: Python readline module isn't available, so the command line editor
             try:
             try:
                 self.config_data.commit()
                 self.config_data.commit()
             except isc.config.ModuleCCSessionError as mcse:
             except isc.config.ModuleCCSessionError as mcse:
-                print(str(mcse))
+                self._print(str(mcse))
         elif cmd.command == "diff":
         elif cmd.command == "diff":
-            print(self.config_data.get_local_changes())
+            self._print(self.config_data.get_local_changes())
         elif cmd.command == "go":
         elif cmd.command == "go":
             self.go(identifier)
             self.go(identifier)
 
 
@@ -803,7 +840,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # check if exists, if not, revert and error
         # check if exists, if not, revert and error
         v,d = self.config_data.get_value(new_location)
         v,d = self.config_data.get_value(new_location)
         if v is None:
         if v is None:
-            print("Error: " + identifier + " not found")
+            self._print("Error: " + identifier + " not found")
             return
             return
 
 
         self.location = new_location
         self.location = new_location
@@ -818,7 +855,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 with open(command.params['filename']) as command_file:
                 with open(command.params['filename']) as command_file:
                     commands = command_file.readlines()
                     commands = command_file.readlines()
             except IOError as ioe:
             except IOError as ioe:
-                print("Error: " + str(ioe))
+                self._print("Error: " + str(ioe))
                 return
                 return
         elif command_sets.has_command_set(command.command):
         elif command_sets.has_command_set(command.command):
             commands = command_sets.get_commands(command.command)
             commands = command_sets.get_commands(command.command)
@@ -836,7 +873,7 @@ WARNING: Python readline module isn't available, so the command line editor
     def __show_execute_commands(self, commands):
     def __show_execute_commands(self, commands):
         '''Prints the command list without executing them'''
         '''Prints the command list without executing them'''
         for line in commands:
         for line in commands:
-            print(line.strip())
+            self._print(line.strip())
 
 
     def __apply_execute_commands(self, commands):
     def __apply_execute_commands(self, commands):
         '''Applies the configuration commands from the given iterator.
         '''Applies the configuration commands from the given iterator.
@@ -857,18 +894,19 @@ WARNING: Python readline module isn't available, so the command line editor
             for line in commands:
             for line in commands:
                 line = line.strip()
                 line = line.strip()
                 if verbose:
                 if verbose:
-                    print(line)
+                    self._print(line)
                 if line.startswith('#') or len(line) == 0:
                 if line.startswith('#') or len(line) == 0:
                     continue
                     continue
                 elif line.startswith('!'):
                 elif line.startswith('!'):
                     if re.match('^!echo ', line, re.I) and len(line) > 6:
                     if re.match('^!echo ', line, re.I) and len(line) > 6:
-                        print(line[6:])
+                        self._print(line[6:])
                     elif re.match('^!verbose\s+on\s*$', line, re.I):
                     elif re.match('^!verbose\s+on\s*$', line, re.I):
                         verbose = True
                         verbose = True
                     elif re.match('^!verbose\s+off$', line, re.I):
                     elif re.match('^!verbose\s+off$', line, re.I):
                         verbose = False
                         verbose = False
                     else:
                     else:
-                        print("Warning: ignoring unknown directive: " + line)
+                        self._print("Warning: ignoring unknown directive: " +
+                                    line)
                 else:
                 else:
                     cmd = BindCmdParser(line)
                     cmd = BindCmdParser(line)
                     self._validate_cmd(cmd)
                     self._validate_cmd(cmd)
@@ -879,12 +917,12 @@ WARNING: Python readline module isn't available, so the command line editor
                 isc.cc.data.DataNotFoundError,
                 isc.cc.data.DataNotFoundError,
                 isc.cc.data.DataAlreadyPresentError,
                 isc.cc.data.DataAlreadyPresentError,
                 KeyError) as err:
                 KeyError) as err:
-            print('Error: ', err)
-            print()
-            print('Depending on the contents of the script, and which')
-            print('commands it has called, there can be committed and')
-            print('local changes. It is advised to check your settings,')
-            print('and revert local changes with "config revert".')
+            self._print('Error: ', err)
+            self._print()
+            self._print('Depending on the contents of the script, and which')
+            self._print('commands it has called, there can be committed and')
+            self._print('local changes. It is advised to check your settings')
+            self._print(', and revert local changes with "config revert".')
 
 
     def apply_cmd(self, cmd):
     def apply_cmd(self, cmd):
         '''Handles a general module command'''
         '''Handles a general module command'''
@@ -898,6 +936,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # The reply is a string containing JSON data,
         # The reply is a string containing JSON data,
         # parse it, then prettyprint
         # parse it, then prettyprint
         if data != "" and data != "{}":
         if data != "" and data != "{}":
-            print(json.dumps(json.loads(data), sort_keys=True, indent=4))
+            self._print(json.dumps(json.loads(data), sort_keys=True,
+                                   indent=4))
 
 
 
 

+ 114 - 11
src/bin/bindctl/tests/bindctl_test.py

@@ -18,11 +18,14 @@ import unittest
 import isc.cc.data
 import isc.cc.data
 import os
 import os
 import io
 import io
+import errno
 import sys
 import sys
 import socket
 import socket
+import ssl
 import http.client
 import http.client
 import pwd
 import pwd
 import getpass
 import getpass
+import re
 from optparse import OptionParser
 from optparse import OptionParser
 from isc.config.config_data import ConfigData, MultiConfigData
 from isc.config.config_data import ConfigData, MultiConfigData
 from isc.config.module_spec import ModuleSpec
 from isc.config.module_spec import ModuleSpec
@@ -335,6 +338,8 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.add_module_info(mod_info)
         self.tool.add_module_info(mod_info)
         self.tool.config_data = FakeCCSession()
         self.tool.config_data = FakeCCSession()
         self.stdout_backup = sys.stdout
         self.stdout_backup = sys.stdout
+        self.printed_messages = []
+        self.tool._print = self.store_print
 
 
     def test_precmd(self):
     def test_precmd(self):
         def update_all_modules_info():
         def update_all_modules_info():
@@ -347,6 +352,111 @@ class TestConfigCommands(unittest.TestCase):
         precmd('EOF')
         precmd('EOF')
         self.assertRaises(socket.error, precmd, 'continue')
         self.assertRaises(socket.error, precmd, 'continue')
 
 
+    def store_print(self, *args):
+        '''Method to override _print in BindCmdInterpreter.
+           Instead of printing the values, appends the argument tuple
+           to the list in self.printed_messages'''
+        self.printed_messages.append(" ".join(map(str, args)))
+
+    def __check_printed_message(self, expected_message, printed_message):
+        self.assertIsNotNone(re.match(expected_message, printed_message),
+                             "Printed message '" + printed_message +
+                             "' does not match '" + expected_message + "'")
+
+    def __check_printed_messages(self, expected_messages):
+        '''Helper test function to check the printed messages against a list
+           of regexps'''
+        self.assertEqual(len(expected_messages), len(self.printed_messages))
+        for _ in map(self.__check_printed_message,
+                     expected_messages,
+                     self.printed_messages):
+            pass
+
+    def test_try_login(self):
+        # Make sure __try_login raises the correct exception
+        # upon failure of either send_POST or the read() on the
+        # response
+
+        orig_send_POST = self.tool.send_POST
+        expected_printed_messages = []
+        try:
+            def send_POST_raiseImmediately(self, params):
+                raise socket.error("test error")
+
+            self.tool.send_POST = send_POST_raiseImmediately
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  test error')
+            self.__check_printed_messages(expected_printed_messages)
+
+            def create_send_POST_raiseOnRead(exception):
+                '''Create a replacement send_POST() method that raises
+                   the given exception when read() is called on the value
+                   returned from send_POST()'''
+                def send_POST_raiseOnRead(self, params):
+                    class MyResponse:
+                        def read(self):
+                            raise exception
+                    return MyResponse()
+                return send_POST_raiseOnRead
+
+            # basic socket error
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(socket.error("read error"))
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  read error')
+            self.__check_printed_messages(expected_printed_messages)
+
+            # connection reset
+            exc = socket.error("connection reset")
+            exc.errno = errno.ECONNRESET
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  '
+                'connection reset')
+            expected_printed_messages.append(
+                'Please check the logs of b10-cmdctl, there may be a '
+                'problem accepting SSL connections, such as a permission '
+                'problem on the server certificate file.'
+            )
+            self.__check_printed_messages(expected_printed_messages)
+
+            # 'normal' SSL error
+            exc = ssl.SSLError()
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'SSL error while sending login information:  .*')
+            self.__check_printed_messages(expected_printed_messages)
+
+            # 'EOF' SSL error
+            exc = ssl.SSLError()
+            exc.errno = ssl.SSL_ERROR_EOF
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'SSL error while sending login information: .*')
+            expected_printed_messages.append(
+                'Please check the logs of b10-cmdctl, there may be a '
+                'problem accepting SSL connections, such as a permission '
+                'problem on the server certificate file.'
+            )
+            self.__check_printed_messages(expected_printed_messages)
+
+            # any other exception should be passed through
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(ImportError())
+            self.assertRaises(ImportError, self.tool._try_login, "foo", "bar")
+            self.__check_printed_messages(expected_printed_messages)
+
+        finally:
+            self.tool.send_POST = orig_send_POST
+
     def test_run(self):
     def test_run(self):
         def login_to_cmdctl():
         def login_to_cmdctl():
             return True
             return True
@@ -360,29 +470,22 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.conn.sock = FakeSocket()
         self.tool.conn.sock = FakeSocket()
         self.tool.conn.sock.close()
         self.tool.conn.sock.close()
 
 
-        # validate log message for socket.err
-        socket_err_output = io.StringIO()
-        sys.stdout = socket_err_output
         self.assertEqual(1, self.tool.run())
         self.assertEqual(1, self.tool.run())
 
 
         # First few lines may be some kind of heading, or a warning that
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Failed to send request, the connection is closed",
         self.assertIn("Failed to send request, the connection is closed",
-                      socket_err_output.getvalue())
-
-        socket_err_output.close()
+                      self.printed_messages)
+        self.assertEqual(1, len(self.printed_messages))
 
 
         # validate log message for http.client.CannotSendRequest
         # validate log message for http.client.CannotSendRequest
-        cannot_send_output = io.StringIO()
-        sys.stdout = cannot_send_output
         self.assertEqual(1, self.tool.run())
         self.assertEqual(1, self.tool.run())
 
 
         # First few lines may be some kind of heading, or a warning that
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Can not send request, the connection is busy",
         self.assertIn("Can not send request, the connection is busy",
-                      cannot_send_output.getvalue())
-
-        cannot_send_output.close()
+                      self.printed_messages)
+        self.assertEqual(2, len(self.printed_messages))
 
 
     def test_apply_cfg_command_int(self):
     def test_apply_cfg_command_int(self):
         self.tool.location = '/'
         self.tool.location = '/'

+ 33 - 20
src/bin/cmdctl/cmdctl.py.in

@@ -82,6 +82,18 @@ SPECFILE_LOCATION = SPECFILE_PATH + os.sep + "cmdctl.spec"
 class CmdctlException(Exception):
 class CmdctlException(Exception):
     pass
     pass
 
 
+def check_file(file_name):
+    # TODO: Check contents of certificate file
+    if not os.path.exists(file_name):
+        raise CmdctlException("'%s' does not exist" % file_name)
+
+    if not os.path.isfile(file_name):
+        raise CmdctlException("'%s' is not a file" % file_name)
+
+    if not os.access(file_name, os.R_OK):
+        raise CmdctlException("'%s' is not readable" % file_name)
+
+
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
     '''https connection request handler.
     '''https connection request handler.
     Currently only GET and POST are supported.  '''
     Currently only GET and POST are supported.  '''
@@ -153,7 +165,6 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         self.end_headers()
         self.end_headers()
         self.wfile.write(json.dumps(reply).encode())
         self.wfile.write(json.dumps(reply).encode())
 
 
-
     def _handle_login(self):
     def _handle_login(self):
         if self._is_user_logged_in():
         if self._is_user_logged_in():
             return http.client.OK, ["user has already login"]
             return http.client.OK, ["user has already login"]
@@ -278,12 +289,14 @@ class CommandControl():
             if key == 'version':
             if key == 'version':
                 continue
                 continue
             elif key in ['key_file', 'cert_file']:
             elif key in ['key_file', 'cert_file']:
-                #TODO, only check whether the file exist,
-                # further check need to be done: eg. whether
-                # the private/certificate is valid.
+                # TODO: we only check whether the file exist, is a
+                # file, and is readable; but further check need to be done:
+                # eg. whether the private/certificate is valid.
                 path = new_config[key]
                 path = new_config[key]
-                if not os.path.exists(path):
-                    errstr = "the file doesn't exist: " + path
+                try:
+                    check_file(path)
+                except CmdctlException as cce:
+                    errstr = str(cce)
             elif key == 'accounts_file':
             elif key == 'accounts_file':
                 errstr = self._accounts_file_check(new_config[key])
                 errstr = self._accounts_file_check(new_config[key])
             else:
             else:
@@ -524,27 +537,27 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self.user_sessions[session_id] = time.time()
         self.user_sessions[session_id] = time.time()
 
 
     def _check_key_and_cert(self, key, cert):
     def _check_key_and_cert(self, key, cert):
-        # TODO, check the content of key/certificate file
-        if not os.path.exists(key):
-            raise CmdctlException("key file '%s' doesn't exist " % key)
-
-        if not os.path.exists(cert):
-            raise CmdctlException("certificate file '%s' doesn't exist " % cert)
+        check_file(key)
+        check_file(cert);
 
 
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
         try:
         try:
             self._check_key_and_cert(key, cert)
             self._check_key_and_cert(key, cert)
             ssl_sock = ssl.wrap_socket(sock,
             ssl_sock = ssl.wrap_socket(sock,
-                                      server_side = True,
-                                      certfile = cert,
-                                      keyfile = key,
-                                      ssl_version = ssl.PROTOCOL_SSLv23)
+                                       server_side=True,
+                                       certfile=cert,
+                                       keyfile=key,
+                                       ssl_version=ssl.PROTOCOL_SSLv23)
+            # Return here (if control leaves this blocks it will raise an
+            # error)
             return ssl_sock
             return ssl_sock
-        except (ssl.SSLError, CmdctlException) as err :
+        except ssl.SSLError as err:
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
-            self.close_request(sock)
-            # raise socket error to finish the request
-            raise socket.error
+        except (CmdctlException, IOError) as cce:
+            logger.error(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
+        self.close_request(sock)
+        # raise socket error to finish the request
+        raise socket.error
 
 
     def get_request(self):
     def get_request(self):
         '''Get client request socket and wrap it in SSL context. '''
         '''Get client request socket and wrap it in SSL context. '''

+ 7 - 0
src/bin/cmdctl/cmdctl_messages.mes

@@ -58,6 +58,13 @@ with the tool b10-cmdctl-usermgr.
 This debug message indicates that the given command is being sent to
 This debug message indicates that the given command is being sent to
 the given module.
 the given module.
 
 
+% CMDCTL_SSL_SETUP_FAILURE_READING_CERT failed to read certificate or key: %1
+The b10-cmdctl daemon is unable to read either the certificate file or
+the private key file, and is therefore unable to accept any SSL connections.
+The specific error is printed in the message.
+The administrator should solve the issue with the files, or recreate them
+with the b10-certgen tool.
+
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 The user was denied because the SSL connection could not successfully
 The user was denied because the SSL connection could not successfully
 be set up. The specific error is given in the log message. Possible
 be set up. The specific error is given in the log message. Possible

+ 1 - 1
src/bin/cmdctl/tests/Makefile.am

@@ -25,7 +25,7 @@ endif
 	echo Running test: $$pytest ; \
 	echo Running test: $$pytest ; \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	PYTHONPATH=$(COMMON_PYTHON_PATH):$(abs_top_builddir)/src/bin/cmdctl \
 	PYTHONPATH=$(COMMON_PYTHON_PATH):$(abs_top_builddir)/src/bin/cmdctl \
-	CMDCTL_SPEC_PATH=$(abs_top_builddir)/src/bin/cmdctl \
+	CMDCTL_BUILD_PATH=$(abs_top_builddir)/src/bin/cmdctl \
 	CMDCTL_SRC_PATH=$(abs_top_srcdir)/src/bin/cmdctl \
 	CMDCTL_SRC_PATH=$(abs_top_srcdir)/src/bin/cmdctl \
 	B10_LOCKFILE_DIR_FROM_BUILD=$(abs_top_builddir) \
 	B10_LOCKFILE_DIR_FROM_BUILD=$(abs_top_builddir) \
 	$(PYCOVERAGE_RUN) $(abs_srcdir)/$$pytest || exit ; \
 	$(PYCOVERAGE_RUN) $(abs_srcdir)/$$pytest || exit ; \

+ 115 - 45
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,17 +17,18 @@
 import unittest
 import unittest
 import socket
 import socket
 import tempfile
 import tempfile
+import stat
 import sys
 import sys
 from cmdctl import *
 from cmdctl import *
 import isc.log
 import isc.log
 
 
-SPEC_FILE_PATH = '..' + os.sep
-if 'CMDCTL_SPEC_PATH' in os.environ:
-    SPEC_FILE_PATH = os.environ['CMDCTL_SPEC_PATH'] + os.sep
+assert 'CMDCTL_SRC_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
 
 
-SRC_FILE_PATH = '..' + os.sep
-if 'CMDCTL_SRC_PATH' in os.environ:
-    SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
+assert 'CMDCTL_BUILD_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+BUILD_FILE_PATH = os.environ['CMDCTL_BUILD_PATH'] + os.sep
 
 
 # Rewrite the class for unittest.
 # Rewrite the class for unittest.
 class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
@@ -36,7 +37,7 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 
 
     def send_response(self, rcode):
     def send_response(self, rcode):
         self.rcode = rcode
         self.rcode = rcode
-    
+
     def end_headers(self):
     def end_headers(self):
         pass
         pass
 
 
@@ -51,13 +52,13 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
         super().do_POST()
         super().do_POST()
         self.wfile.close()
         self.wfile.close()
         os.remove('tmp.file')
         os.remove('tmp.file')
-    
+
 
 
 class FakeSecureHTTPServer(SecureHTTPServer):
 class FakeSecureHTTPServer(SecureHTTPServer):
     def __init__(self):
     def __init__(self):
         self.user_sessions = {}
         self.user_sessions = {}
         self.cmdctl = FakeCommandControlForTestRequestHandler()
         self.cmdctl = FakeCommandControlForTestRequestHandler()
-        self._verbose = True 
+        self._verbose = True
         self._user_infos = {}
         self._user_infos = {}
         self.idle_timeout = 1200
         self.idle_timeout = 1200
         self._lock = threading.Lock()
         self._lock = threading.Lock()
@@ -71,6 +72,17 @@ class FakeCommandControlForTestRequestHandler(CommandControl):
     def send_command(self, mod, cmd, param):
     def send_command(self, mod, cmd, param):
         return 0, {}
         return 0, {}
 
 
+# context to temporarily make a file unreadable
+class UnreadableFile:
+    def __init__(self, file_name):
+        self.file_name = file_name
+        self.orig_mode = os.stat(file_name).st_mode
+
+    def __enter__(self):
+        os.chmod(self.file_name, self.orig_mode & ~stat.S_IRUSR)
+
+    def __exit__(self, type, value, traceback):
+        os.chmod(self.file_name, self.orig_mode)
 
 
 class TestSecureHTTPRequestHandler(unittest.TestCase):
 class TestSecureHTTPRequestHandler(unittest.TestCase):
     def setUp(self):
     def setUp(self):
@@ -97,7 +109,7 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.handler.path = '/abc'
         self.handler.path = '/abc'
         mod, cmd = self.handler._parse_request_path()
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == None))
         self.assertTrue((mod == 'abc') and (cmd == None))
-        
+
         self.handler.path = '/abc/edf'
         self.handler.path = '/abc/edf'
         mod, cmd = self.handler._parse_request_path()
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
@@ -125,20 +137,20 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
 
 
     def test_do_GET(self):
     def test_do_GET(self):
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)    
-        
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)
+
     def test_do_GET_1(self):
     def test_do_GET_1(self):
         self.handler.headers['cookie'] = 12345
         self.handler.headers['cookie'] = 12345
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)    
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)
 
 
     def test_do_GET_2(self):
     def test_do_GET_2(self):
         self.handler.headers['cookie'] = 12345
         self.handler.headers['cookie'] = 12345
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.path = '/how/are'
         self.handler.path = '/how/are'
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)    
-    
+        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)
+
     def test_do_GET_3(self):
     def test_do_GET_3(self):
         self.handler.headers['cookie'] = 12346
         self.handler.headers['cookie'] = 12346
         self.handler.server.user_sessions[12346] = time.time() + 1000000
         self.handler.server.user_sessions[12346] = time.time() + 1000000
@@ -146,8 +158,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         for path in path_vec:
         for path in path_vec:
             self.handler.path = '/' + path
             self.handler.path = '/' + path
             self.handler.do_GET()
             self.handler.do_GET()
-            self.assertEqual(self.handler.rcode, http.client.OK)    
-    
+            self.assertEqual(self.handler.rcode, http.client.OK)
+
     def test_user_logged_in(self):
     def test_user_logged_in(self):
         self.handler.server.user_sessions = {}
         self.handler.server.user_sessions = {}
         self.handler.session_id = 12345
         self.handler.session_id = 12345
@@ -243,8 +255,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.assertEqual(http.client.BAD_REQUEST, rcode)
         self.assertEqual(http.client.BAD_REQUEST, rcode)
 
 
     def _gen_module_spec(self):
     def _gen_module_spec(self):
-        spec = { 'commands': [ 
-                  { 'command_name' :'command', 
+        spec = { 'commands': [
+                  { 'command_name' :'command',
                     'command_args': [ {
                     'command_args': [ {
                             'item_name' : 'param1',
                             'item_name' : 'param1',
                             'item_type' : 'integer',
                             'item_type' : 'integer',
@@ -253,9 +265,9 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
                            } ],
                            } ],
                     'command_description' : 'cmd description'
                     'command_description' : 'cmd description'
                   }
                   }
-                ] 
+                ]
                }
                }
-        
+
         return spec
         return spec
 
 
     def test_handle_post_request_2(self):
     def test_handle_post_request_2(self):
@@ -290,13 +302,13 @@ class MyCommandControl(CommandControl):
         return {}
         return {}
 
 
     def _setup_session(self):
     def _setup_session(self):
-        spec_file = SPEC_FILE_PATH + 'cmdctl.spec'
+        spec_file = BUILD_FILE_PATH + 'cmdctl.spec'
         module_spec = isc.config.module_spec_from_file(spec_file)
         module_spec = isc.config.module_spec_from_file(spec_file)
         config = isc.config.config_data.ConfigData(module_spec)
         config = isc.config.config_data.ConfigData(module_spec)
         self._module_name = 'Cmdctl'
         self._module_name = 'Cmdctl'
         self._cmdctl_config_data = config.get_full_config()
         self._cmdctl_config_data = config.get_full_config()
 
 
-    def _handle_msg_from_msgq(self): 
+    def _handle_msg_from_msgq(self):
         pass
         pass
 
 
 class TestCommandControl(unittest.TestCase):
 class TestCommandControl(unittest.TestCase):
@@ -305,7 +317,7 @@ class TestCommandControl(unittest.TestCase):
         self.old_stdout = sys.stdout
         self.old_stdout = sys.stdout
         sys.stdout = open(os.devnull, 'w')
         sys.stdout = open(os.devnull, 'w')
         self.cmdctl = MyCommandControl(None, True)
         self.cmdctl = MyCommandControl(None, True)
-   
+
     def tearDown(self):
     def tearDown(self):
         sys.stdout.close()
         sys.stdout.close()
         sys.stdout = self.old_stdout
         sys.stdout = self.old_stdout
@@ -320,7 +332,7 @@ class TestCommandControl(unittest.TestCase):
         old_env = os.environ
         old_env = os.environ
         if 'B10_FROM_SOURCE' in os.environ:
         if 'B10_FROM_SOURCE' in os.environ:
             del os.environ['B10_FROM_SOURCE']
             del os.environ['B10_FROM_SOURCE']
-        self.cmdctl.get_cmdctl_config_data() 
+        self.cmdctl.get_cmdctl_config_data()
         self._check_config(self.cmdctl)
         self._check_config(self.cmdctl)
         os.environ = old_env
         os.environ = old_env
 
 
@@ -328,7 +340,7 @@ class TestCommandControl(unittest.TestCase):
         os.environ['B10_FROM_SOURCE'] = '../'
         os.environ['B10_FROM_SOURCE'] = '../'
         self._check_config(self.cmdctl)
         self._check_config(self.cmdctl)
         os.environ = old_env
         os.environ = old_env
-    
+
     def test_parse_command_result(self):
     def test_parse_command_result(self):
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
@@ -391,13 +403,13 @@ class TestCommandControl(unittest.TestCase):
         os.environ = old_env
         os.environ = old_env
 
 
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
 
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
 
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, 
+        self._check_answer(answer, 1,
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
 
 
         # Test with invalid accounts file
         # Test with invalid accounts file
@@ -409,7 +421,7 @@ class TestCommandControl(unittest.TestCase):
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         os.remove(file_name)
         os.remove(file_name)
-    
+
     def test_send_command(self):
     def test_send_command(self):
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         self.assertEqual(rcode, 0)
         self.assertEqual(rcode, 0)
@@ -424,7 +436,7 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.old_stderr = sys.stderr
         self.old_stderr = sys.stderr
         sys.stdout = open(os.devnull, 'w')
         sys.stdout = open(os.devnull, 'w')
         sys.stderr = sys.stdout
         sys.stderr = sys.stdout
-        self.server = MySecureHTTPServer(('localhost', 8080), 
+        self.server = MySecureHTTPServer(('localhost', 8080),
                                          MySecureHTTPRequestHandler,
                                          MySecureHTTPRequestHandler,
                                          MyCommandControl, verbose=True)
                                          MyCommandControl, verbose=True)
 
 
@@ -458,32 +470,90 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.assertEqual(1, len(self.server._user_infos))
         self.assertEqual(1, len(self.server._user_infos))
         self.assertTrue('root' in self.server._user_infos)
         self.assertTrue('root' in self.server._user_infos)
 
 
+    def test_check_file(self):
+        # Just some file that we know exists
+        file_name = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'
+        check_file(file_name)
+        with UnreadableFile(file_name):
+            self.assertRaises(CmdctlException, check_file, file_name)
+        self.assertRaises(CmdctlException, check_file, '/local/not-exist')
+        self.assertRaises(CmdctlException, check_file, '/')
+
+
     def test_check_key_and_cert(self):
     def test_check_key_and_cert(self):
+        keyfile = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'
+        certfile = BUILD_FILE_PATH + 'cmdctl-certfile.pem'
+
+        # no exists
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/local/not-exist')
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                         '/local/not-exist', certfile)
+
+        # not a file
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/')
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
-                         '/local/not-exist', 'cmdctl-keyfile.pem')
+                         '/', certfile)
 
 
-        self.server._check_key_and_cert(SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                                        SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        # no read permission
+        with UnreadableFile(certfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        with UnreadableFile(keyfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        # All OK (also happens to check the context code above works)
+        self.server._check_key_and_cert(keyfile, certfile)
 
 
     def test_wrap_sock_in_ssl_context(self):
     def test_wrap_sock_in_ssl_context(self):
         sock = socket.socket()
         sock = socket.socket()
-        self.assertRaises(socket.error, 
+
+        # Bad files should result in a socket.error raised by our own
+        # code in the basic file checks
+        self.assertRaises(socket.error,
                           self.server._wrap_socket_in_ssl_context,
                           self.server._wrap_socket_in_ssl_context,
-                          sock, 
-                          '../cmdctl-keyfile',
-                          '../cmdctl-certfile')
+                          sock,
+                          'no_such_file', 'no_such_file')
 
 
+        # Using a non-certificate file would cause an SSLError, which
+        # is caught by our code which then raises a basic socket.error
+        self.assertRaises(socket.error,
+                          self.server._wrap_socket_in_ssl_context,
+                          sock,
+                          BUILD_FILE_PATH + 'cmdctl.py',
+                          BUILD_FILE_PATH + 'cmdctl-certfile.pem')
+
+        # Should succeed
         sock1 = socket.socket()
         sock1 = socket.socket()
-        self.server._wrap_socket_in_ssl_context(sock1, 
-                          SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                          SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        ssl_sock = self.server._wrap_socket_in_ssl_context(sock1,
+                                   BUILD_FILE_PATH + 'cmdctl-keyfile.pem',
+                                   BUILD_FILE_PATH + 'cmdctl-certfile.pem')
+        self.assertIsInstance(ssl_sock, ssl.SSLSocket)
+
+        # wrap_socket can also raise IOError, which should be caught and
+        # handled like the other errors.
+        # Force this by temporarily disabling our own file checks
+        orig_check_func = self.server._check_key_and_cert
+        try:
+            self.server._check_key_and_cert = lambda x,y: None
+            self.assertRaises(socket.error,
+                              self.server._wrap_socket_in_ssl_context,
+                              sock,
+                              'no_such_file', 'no_such_file')
+        finally:
+            self.server._check_key_and_cert = orig_check_func
 
 
 class TestFuncNotInClass(unittest.TestCase):
 class TestFuncNotInClass(unittest.TestCase):
     def test_check_port(self):
     def test_check_port(self):
-        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)        
-        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)        
+        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)
+        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)
 
 
 
 
 if __name__== "__main__":
 if __name__== "__main__":