Browse Source

[master] Merge branch 'trac2595'

Jelte Jansen 12 years ago
parent
commit
09b1a2f927

+ 107 - 68
src/bin/bindctl/bindcmd.py

@@ -39,6 +39,7 @@ import csv
 import pwd
 import getpass
 import copy
+import errno
 
 try:
     from collections import OrderedDict
@@ -123,6 +124,11 @@ class BindCmdInterpreter(Cmd):
             self.csv_file_dir = pwd.getpwnam(getpass.getuser()).pw_dir + \
                 os.sep + '.bind10' + os.sep
 
+    def _print(self, *args):
+        '''Simple wrapper around calls to print that can be overridden in
+           unit tests.'''
+        print(*args)
+
     def _get_session_id(self):
         '''Generate one session id for the connection. '''
         rand = os.urandom(16)
@@ -150,19 +156,19 @@ WARNING: Python readline module isn't available, so the command line editor
                 return 1
 
             self.cmdloop()
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
         except FailToLogin as err:
             # error already printed when this was raised, ignoring
             return 1
         except KeyboardInterrupt:
-            print('\nExit from bindctl')
+            self._print('\nExit from bindctl')
             return 0
         except socket.error as err:
-            print('Failed to send request, the connection is closed')
+            self._print('Failed to send request, the connection is closed')
             return 1
         except http.client.CannotSendRequest:
-            print('Can not send request, the connection is busy')
+            self._print('Can not send request, the connection is busy')
             return 1
 
     def _get_saved_user_info(self, dir, file_name):
@@ -181,7 +187,8 @@ WARNING: Python readline module isn't available, so the command line editor
             for row in users_info:
                 users.append([row[0], row[1]])
         except (IOError, IndexError) as err:
-            print("Error reading saved username and password from %s%s: %s" % (dir, file_name, err))
+            self._print("Error reading saved username and password "
+                        "from %s%s: %s" % (dir, file_name, err))
         finally:
             if csvfile:
                 csvfile.close()
@@ -201,12 +208,48 @@ WARNING: Python readline module isn't available, so the command line editor
             writer.writerow([username, passwd])
             csvfile.close()
         except IOError as err:
-            print("Error saving user information:", err)
-            print("user info file name: %s%s" % (dir, file_name))
+            self._print("Error saving user information:", err)
+            self._print("user info file name: %s%s" % (dir, file_name))
             return False
 
         return True
 
+    def __print_check_ssl_msg(self):
+        self._print("Please check the logs of b10-cmdctl, there may "
+                    "be a problem accepting SSL connections, such "
+                    "as a permission problem on the server "
+                    "certificate file.")
+
+    def _try_login(self, username, password):
+        '''
+        Attempts to log in to cmdctl by sending a POST with
+        the given username and password.
+        On success of the POST (mind, not the login, only the network
+        operation), returns a tuple (response, data).
+        On failure, raises a FailToLogin exception, and prints some
+        information on the failure.
+        This call is essentially 'private', but made 'protected' for
+        easier testing.
+        '''
+        param = {'username': username, 'password' : password}
+        try:
+            response = self.send_POST('/login', param)
+            data = response.read().decode()
+            # return here (will raise error after try block)
+            return (response, data)
+        except ssl.SSLError as err:
+            self._print("SSL error while sending login information: ", err)
+            if err.errno == ssl.SSL_ERROR_EOF:
+                self.__print_check_ssl_msg()
+        except socket.error as err:
+            self._print("Socket error while sending login information: ", err)
+            # An SSL setup error can also bubble up as a plain CONNRESET...
+            # (on some systems it usually does)
+            if err.errno == errno.ECONNRESET:
+                self.__print_check_ssl_msg()
+            pass
+        raise FailToLogin()
+
     def login_to_cmdctl(self):
         '''Login to cmdctl with the username and password given by
         the user. After the login is sucessful, the username and
@@ -217,41 +260,30 @@ WARNING: Python readline module isn't available, so the command line editor
         # Look at existing username/password combinations and try to log in
         users = self._get_saved_user_info(self.csv_file_dir, CSV_FILE_NAME)
         for row in users:
-            param = {'username': row[0], 'password' : row[1]}
-            try:
-                response = self.send_POST('/login', param)
-                data = response.read().decode()
-            except socket.error as err:
-                print("Socket error while sending login information:", err)
-                raise FailToLogin()
+            response, data = self._try_login(row[0], row[1])
 
             if response.status == http.client.OK:
                 # Is interactive?
                 if sys.stdin.isatty():
-                    print(data + ' login as ' + row[0])
+                    self._print(data + ' login as ' + row[0])
                 return True
 
         # No valid logins were found, prompt the user for a username/password
         count = 0
-        print('No stored password file found, please see sections '
+        self._print('No stored password file found, please see sections '
               '"Configuration specification for b10-cmdctl" and "bindctl '
               'command-line options" of the BIND 10 guide.')
         while True:
             count = count + 1
             if count > 3:
-                print("Too many authentication failures")
+                self._print("Too many authentication failures")
                 return False
 
             username = input("Username: ")
             passwd = getpass.getpass()
-            param = {'username': username, 'password' : passwd}
-            try:
-                response = self.send_POST('/login', param)
-                data = response.read().decode()
-                print(data)
-            except socket.error as err:
-                print("Socket error while sending login information:", err)
-                raise FailToLogin()
+
+            response, data = self._try_login(username, passwd)
+            self._print(data)
 
             if response.status == http.client.OK:
                 self._save_user_info(username, passwd, self.csv_file_dir,
@@ -449,25 +481,26 @@ WARNING: Python readline module isn't available, so the command line editor
         pass
 
     def do_help(self, name):
-        print(CONST_BINDCTL_HELP)
+        self._print(CONST_BINDCTL_HELP)
         for k in self.modules.values():
             n = k.get_name()
             if len(n) >= CONST_BINDCTL_HELP_INDENT_WIDTH:
-                print("    %s" % n)
-                print(textwrap.fill(k.get_desc(),
-                      initial_indent="            ",
-                      subsequent_indent="    " +
-                      " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                      width=70))
+                self._print("    %s" % n)
+                self._print(textwrap.fill(k.get_desc(),
+                            initial_indent="            ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
             else:
-                print(textwrap.fill("%s%s%s" %
-                    (k.get_name(),
-                     " "*(CONST_BINDCTL_HELP_INDENT_WIDTH - len(k.get_name())),
-                     k.get_desc()),
-                    initial_indent="    ",
-                    subsequent_indent="    " +
-                    " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
-                    width=70))
+                self._print(textwrap.fill("%s%s%s" %
+                            (k.get_name(),
+                            " "*(CONST_BINDCTL_HELP_INDENT_WIDTH -
+                                 len(k.get_name())),
+                            k.get_desc()),
+                            initial_indent="    ",
+                            subsequent_indent="    " +
+                            " " * CONST_BINDCTL_HELP_INDENT_WIDTH,
+                            width=70))
 
     def onecmd(self, line):
         if line == 'EOF' or line.lower() == "quit":
@@ -642,20 +675,20 @@ WARNING: Python readline module isn't available, so the command line editor
             self._validate_cmd(cmd)
             self._handle_cmd(cmd)
         except (IOError, http.client.HTTPException) as err:
-            print('Error: ', err)
+            self._print('Error: ', err)
         except BindCtlException as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
             self._print_correct_usage(err)
         except isc.cc.data.DataTypeError as err:
-            print("Error! ", err)
+            self._print("Error! ", err)
         except isc.cc.data.DataTypeError as dte:
-            print("Error: " + str(dte))
+            self._print("Error: " + str(dte))
         except isc.cc.data.DataNotFoundError as dnfe:
-            print("Error: " + str(dnfe))
+            self._print("Error: " + str(dnfe))
         except isc.cc.data.DataAlreadyPresentError as dape:
-            print("Error: " + str(dape))
+            self._print("Error: " + str(dape))
         except KeyError as ke:
-            print("Error: missing " + str(ke))
+            self._print("Error: missing " + str(ke))
 
     def _print_correct_usage(self, ept):
         if isinstance(ept, CmdUnknownModuleSyntaxError):
@@ -704,7 +737,8 @@ WARNING: Python readline module isn't available, so the command line editor
             module_name = identifier.split('/')[1]
             if module_name != "" and (self.config_data is None or \
                not self.config_data.have_specification(module_name)):
-                print("Error: Module '" + module_name + "' unknown or not running")
+                self._print("Error: Module '" + module_name +
+                            "' unknown or not running")
                 return
 
         if cmd.command == "show":
@@ -718,7 +752,9 @@ WARNING: Python readline module isn't available, so the command line editor
                     #identifier
                     identifier += cmd.params['argument']
                 else:
-                    print("Error: unknown argument " + cmd.params['argument'] + ", or multiple identifiers given")
+                    self._print("Error: unknown argument " +
+                                cmd.params['argument'] +
+                                ", or multiple identifiers given")
                     return
             values = self.config_data.get_value_maps(identifier, show_all)
             for value_map in values:
@@ -746,13 +782,14 @@ WARNING: Python readline module isn't available, so the command line editor
                     line += "(default)"
                 if value_map['modified']:
                     line += "(modified)"
-                print(line)
+                self._print(line)
         elif cmd.command == "show_json":
             if identifier == "":
-                print("Need at least the module to show the configuration in JSON format")
+                self._print("Need at least the module to show the "
+                            "configuration in JSON format")
             else:
                 data, default = self.config_data.get_value(identifier)
-                print(json.dumps(data))
+                self._print(json.dumps(data))
         elif cmd.command == "add":
             self.config_data.add_value(identifier,
                                        cmd.params.get('value_or_name'),
@@ -764,7 +801,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 self.config_data.remove_value(identifier, None)
         elif cmd.command == "set":
             if 'identifier' not in cmd.params:
-                print("Error: missing identifier or value")
+                self._print("Error: missing identifier or value")
             else:
                 parsed_value = None
                 try:
@@ -781,9 +818,9 @@ WARNING: Python readline module isn't available, so the command line editor
             try:
                 self.config_data.commit()
             except isc.config.ModuleCCSessionError as mcse:
-                print(str(mcse))
+                self._print(str(mcse))
         elif cmd.command == "diff":
-            print(self.config_data.get_local_changes())
+            self._print(self.config_data.get_local_changes())
         elif cmd.command == "go":
             self.go(identifier)
 
@@ -803,7 +840,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # check if exists, if not, revert and error
         v,d = self.config_data.get_value(new_location)
         if v is None:
-            print("Error: " + identifier + " not found")
+            self._print("Error: " + identifier + " not found")
             return
 
         self.location = new_location
@@ -818,7 +855,7 @@ WARNING: Python readline module isn't available, so the command line editor
                 with open(command.params['filename']) as command_file:
                     commands = command_file.readlines()
             except IOError as ioe:
-                print("Error: " + str(ioe))
+                self._print("Error: " + str(ioe))
                 return
         elif command_sets.has_command_set(command.command):
             commands = command_sets.get_commands(command.command)
@@ -836,7 +873,7 @@ WARNING: Python readline module isn't available, so the command line editor
     def __show_execute_commands(self, commands):
         '''Prints the command list without executing them'''
         for line in commands:
-            print(line.strip())
+            self._print(line.strip())
 
     def __apply_execute_commands(self, commands):
         '''Applies the configuration commands from the given iterator.
@@ -857,18 +894,19 @@ WARNING: Python readline module isn't available, so the command line editor
             for line in commands:
                 line = line.strip()
                 if verbose:
-                    print(line)
+                    self._print(line)
                 if line.startswith('#') or len(line) == 0:
                     continue
                 elif line.startswith('!'):
                     if re.match('^!echo ', line, re.I) and len(line) > 6:
-                        print(line[6:])
+                        self._print(line[6:])
                     elif re.match('^!verbose\s+on\s*$', line, re.I):
                         verbose = True
                     elif re.match('^!verbose\s+off$', line, re.I):
                         verbose = False
                     else:
-                        print("Warning: ignoring unknown directive: " + line)
+                        self._print("Warning: ignoring unknown directive: " +
+                                    line)
                 else:
                     cmd = BindCmdParser(line)
                     self._validate_cmd(cmd)
@@ -879,12 +917,12 @@ WARNING: Python readline module isn't available, so the command line editor
                 isc.cc.data.DataNotFoundError,
                 isc.cc.data.DataAlreadyPresentError,
                 KeyError) as err:
-            print('Error: ', err)
-            print()
-            print('Depending on the contents of the script, and which')
-            print('commands it has called, there can be committed and')
-            print('local changes. It is advised to check your settings,')
-            print('and revert local changes with "config revert".')
+            self._print('Error: ', err)
+            self._print()
+            self._print('Depending on the contents of the script, and which')
+            self._print('commands it has called, there can be committed and')
+            self._print('local changes. It is advised to check your settings')
+            self._print(', and revert local changes with "config revert".')
 
     def apply_cmd(self, cmd):
         '''Handles a general module command'''
@@ -898,6 +936,7 @@ WARNING: Python readline module isn't available, so the command line editor
         # The reply is a string containing JSON data,
         # parse it, then prettyprint
         if data != "" and data != "{}":
-            print(json.dumps(json.loads(data), sort_keys=True, indent=4))
+            self._print(json.dumps(json.loads(data), sort_keys=True,
+                                   indent=4))
 
 

+ 114 - 11
src/bin/bindctl/tests/bindctl_test.py

@@ -18,11 +18,14 @@ import unittest
 import isc.cc.data
 import os
 import io
+import errno
 import sys
 import socket
+import ssl
 import http.client
 import pwd
 import getpass
+import re
 from optparse import OptionParser
 from isc.config.config_data import ConfigData, MultiConfigData
 from isc.config.module_spec import ModuleSpec
@@ -335,6 +338,8 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.add_module_info(mod_info)
         self.tool.config_data = FakeCCSession()
         self.stdout_backup = sys.stdout
+        self.printed_messages = []
+        self.tool._print = self.store_print
 
     def test_precmd(self):
         def update_all_modules_info():
@@ -347,6 +352,111 @@ class TestConfigCommands(unittest.TestCase):
         precmd('EOF')
         self.assertRaises(socket.error, precmd, 'continue')
 
+    def store_print(self, *args):
+        '''Method to override _print in BindCmdInterpreter.
+           Instead of printing the values, appends the argument tuple
+           to the list in self.printed_messages'''
+        self.printed_messages.append(" ".join(map(str, args)))
+
+    def __check_printed_message(self, expected_message, printed_message):
+        self.assertIsNotNone(re.match(expected_message, printed_message),
+                             "Printed message '" + printed_message +
+                             "' does not match '" + expected_message + "'")
+
+    def __check_printed_messages(self, expected_messages):
+        '''Helper test function to check the printed messages against a list
+           of regexps'''
+        self.assertEqual(len(expected_messages), len(self.printed_messages))
+        for _ in map(self.__check_printed_message,
+                     expected_messages,
+                     self.printed_messages):
+            pass
+
+    def test_try_login(self):
+        # Make sure __try_login raises the correct exception
+        # upon failure of either send_POST or the read() on the
+        # response
+
+        orig_send_POST = self.tool.send_POST
+        expected_printed_messages = []
+        try:
+            def send_POST_raiseImmediately(self, params):
+                raise socket.error("test error")
+
+            self.tool.send_POST = send_POST_raiseImmediately
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  test error')
+            self.__check_printed_messages(expected_printed_messages)
+
+            def create_send_POST_raiseOnRead(exception):
+                '''Create a replacement send_POST() method that raises
+                   the given exception when read() is called on the value
+                   returned from send_POST()'''
+                def send_POST_raiseOnRead(self, params):
+                    class MyResponse:
+                        def read(self):
+                            raise exception
+                    return MyResponse()
+                return send_POST_raiseOnRead
+
+            # basic socket error
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(socket.error("read error"))
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  read error')
+            self.__check_printed_messages(expected_printed_messages)
+
+            # connection reset
+            exc = socket.error("connection reset")
+            exc.errno = errno.ECONNRESET
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'Socket error while sending login information:  '
+                'connection reset')
+            expected_printed_messages.append(
+                'Please check the logs of b10-cmdctl, there may be a '
+                'problem accepting SSL connections, such as a permission '
+                'problem on the server certificate file.'
+            )
+            self.__check_printed_messages(expected_printed_messages)
+
+            # 'normal' SSL error
+            exc = ssl.SSLError()
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'SSL error while sending login information:  .*')
+            self.__check_printed_messages(expected_printed_messages)
+
+            # 'EOF' SSL error
+            exc = ssl.SSLError()
+            exc.errno = ssl.SSL_ERROR_EOF
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(exc)
+            self.assertRaises(FailToLogin, self.tool._try_login, "foo", "bar")
+            expected_printed_messages.append(
+                'SSL error while sending login information: .*')
+            expected_printed_messages.append(
+                'Please check the logs of b10-cmdctl, there may be a '
+                'problem accepting SSL connections, such as a permission '
+                'problem on the server certificate file.'
+            )
+            self.__check_printed_messages(expected_printed_messages)
+
+            # any other exception should be passed through
+            self.tool.send_POST =\
+                create_send_POST_raiseOnRead(ImportError())
+            self.assertRaises(ImportError, self.tool._try_login, "foo", "bar")
+            self.__check_printed_messages(expected_printed_messages)
+
+        finally:
+            self.tool.send_POST = orig_send_POST
+
     def test_run(self):
         def login_to_cmdctl():
             return True
@@ -360,29 +470,22 @@ class TestConfigCommands(unittest.TestCase):
         self.tool.conn.sock = FakeSocket()
         self.tool.conn.sock.close()
 
-        # validate log message for socket.err
-        socket_err_output = io.StringIO()
-        sys.stdout = socket_err_output
         self.assertEqual(1, self.tool.run())
 
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Failed to send request, the connection is closed",
-                      socket_err_output.getvalue())
-
-        socket_err_output.close()
+                      self.printed_messages)
+        self.assertEqual(1, len(self.printed_messages))
 
         # validate log message for http.client.CannotSendRequest
-        cannot_send_output = io.StringIO()
-        sys.stdout = cannot_send_output
         self.assertEqual(1, self.tool.run())
 
         # First few lines may be some kind of heading, or a warning that
         # Python readline is unavailable, so we do a sub-string check.
         self.assertIn("Can not send request, the connection is busy",
-                      cannot_send_output.getvalue())
-
-        cannot_send_output.close()
+                      self.printed_messages)
+        self.assertEqual(2, len(self.printed_messages))
 
     def test_apply_cfg_command_int(self):
         self.tool.location = '/'

+ 33 - 20
src/bin/cmdctl/cmdctl.py.in

@@ -82,6 +82,18 @@ SPECFILE_LOCATION = SPECFILE_PATH + os.sep + "cmdctl.spec"
 class CmdctlException(Exception):
     pass
 
+def check_file(file_name):
+    # TODO: Check contents of certificate file
+    if not os.path.exists(file_name):
+        raise CmdctlException("'%s' does not exist" % file_name)
+
+    if not os.path.isfile(file_name):
+        raise CmdctlException("'%s' is not a file" % file_name)
+
+    if not os.access(file_name, os.R_OK):
+        raise CmdctlException("'%s' is not readable" % file_name)
+
+
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
     '''https connection request handler.
     Currently only GET and POST are supported.  '''
@@ -153,7 +165,6 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         self.end_headers()
         self.wfile.write(json.dumps(reply).encode())
 
-
     def _handle_login(self):
         if self._is_user_logged_in():
             return http.client.OK, ["user has already login"]
@@ -278,12 +289,14 @@ class CommandControl():
             if key == 'version':
                 continue
             elif key in ['key_file', 'cert_file']:
-                #TODO, only check whether the file exist,
-                # further check need to be done: eg. whether
-                # the private/certificate is valid.
+                # TODO: we only check whether the file exist, is a
+                # file, and is readable; but further check need to be done:
+                # eg. whether the private/certificate is valid.
                 path = new_config[key]
-                if not os.path.exists(path):
-                    errstr = "the file doesn't exist: " + path
+                try:
+                    check_file(path)
+                except CmdctlException as cce:
+                    errstr = str(cce)
             elif key == 'accounts_file':
                 errstr = self._accounts_file_check(new_config[key])
             else:
@@ -524,27 +537,27 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self.user_sessions[session_id] = time.time()
 
     def _check_key_and_cert(self, key, cert):
-        # TODO, check the content of key/certificate file
-        if not os.path.exists(key):
-            raise CmdctlException("key file '%s' doesn't exist " % key)
-
-        if not os.path.exists(cert):
-            raise CmdctlException("certificate file '%s' doesn't exist " % cert)
+        check_file(key)
+        check_file(cert);
 
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
         try:
             self._check_key_and_cert(key, cert)
             ssl_sock = ssl.wrap_socket(sock,
-                                      server_side = True,
-                                      certfile = cert,
-                                      keyfile = key,
-                                      ssl_version = ssl.PROTOCOL_SSLv23)
+                                       server_side=True,
+                                       certfile=cert,
+                                       keyfile=key,
+                                       ssl_version=ssl.PROTOCOL_SSLv23)
+            # Return here (if control leaves this blocks it will raise an
+            # error)
             return ssl_sock
-        except (ssl.SSLError, CmdctlException) as err :
+        except ssl.SSLError as err:
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
-            self.close_request(sock)
-            # raise socket error to finish the request
-            raise socket.error
+        except (CmdctlException, IOError) as cce:
+            logger.error(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
+        self.close_request(sock)
+        # raise socket error to finish the request
+        raise socket.error
 
     def get_request(self):
         '''Get client request socket and wrap it in SSL context. '''

+ 7 - 0
src/bin/cmdctl/cmdctl_messages.mes

@@ -58,6 +58,13 @@ with the tool b10-cmdctl-usermgr.
 This debug message indicates that the given command is being sent to
 the given module.
 
+% CMDCTL_SSL_SETUP_FAILURE_READING_CERT failed to read certificate or key: %1
+The b10-cmdctl daemon is unable to read either the certificate file or
+the private key file, and is therefore unable to accept any SSL connections.
+The specific error is printed in the message.
+The administrator should solve the issue with the files, or recreate them
+with the b10-certgen tool.
+
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 The user was denied because the SSL connection could not successfully
 be set up. The specific error is given in the log message. Possible

+ 1 - 1
src/bin/cmdctl/tests/Makefile.am

@@ -25,7 +25,7 @@ endif
 	echo Running test: $$pytest ; \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	PYTHONPATH=$(COMMON_PYTHON_PATH):$(abs_top_builddir)/src/bin/cmdctl \
-	CMDCTL_SPEC_PATH=$(abs_top_builddir)/src/bin/cmdctl \
+	CMDCTL_BUILD_PATH=$(abs_top_builddir)/src/bin/cmdctl \
 	CMDCTL_SRC_PATH=$(abs_top_srcdir)/src/bin/cmdctl \
 	B10_LOCKFILE_DIR_FROM_BUILD=$(abs_top_builddir) \
 	$(PYCOVERAGE_RUN) $(abs_srcdir)/$$pytest || exit ; \

+ 115 - 45
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,17 +17,18 @@
 import unittest
 import socket
 import tempfile
+import stat
 import sys
 from cmdctl import *
 import isc.log
 
-SPEC_FILE_PATH = '..' + os.sep
-if 'CMDCTL_SPEC_PATH' in os.environ:
-    SPEC_FILE_PATH = os.environ['CMDCTL_SPEC_PATH'] + os.sep
+assert 'CMDCTL_SRC_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
 
-SRC_FILE_PATH = '..' + os.sep
-if 'CMDCTL_SRC_PATH' in os.environ:
-    SRC_FILE_PATH = os.environ['CMDCTL_SRC_PATH'] + os.sep
+assert 'CMDCTL_BUILD_PATH' in os.environ,\
+       "Please run this test with 'make check'"
+BUILD_FILE_PATH = os.environ['CMDCTL_BUILD_PATH'] + os.sep
 
 # Rewrite the class for unittest.
 class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
@@ -36,7 +37,7 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 
     def send_response(self, rcode):
         self.rcode = rcode
-    
+
     def end_headers(self):
         pass
 
@@ -51,13 +52,13 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
         super().do_POST()
         self.wfile.close()
         os.remove('tmp.file')
-    
+
 
 class FakeSecureHTTPServer(SecureHTTPServer):
     def __init__(self):
         self.user_sessions = {}
         self.cmdctl = FakeCommandControlForTestRequestHandler()
-        self._verbose = True 
+        self._verbose = True
         self._user_infos = {}
         self.idle_timeout = 1200
         self._lock = threading.Lock()
@@ -71,6 +72,17 @@ class FakeCommandControlForTestRequestHandler(CommandControl):
     def send_command(self, mod, cmd, param):
         return 0, {}
 
+# context to temporarily make a file unreadable
+class UnreadableFile:
+    def __init__(self, file_name):
+        self.file_name = file_name
+        self.orig_mode = os.stat(file_name).st_mode
+
+    def __enter__(self):
+        os.chmod(self.file_name, self.orig_mode & ~stat.S_IRUSR)
+
+    def __exit__(self, type, value, traceback):
+        os.chmod(self.file_name, self.orig_mode)
 
 class TestSecureHTTPRequestHandler(unittest.TestCase):
     def setUp(self):
@@ -97,7 +109,7 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.handler.path = '/abc'
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == None))
-        
+
         self.handler.path = '/abc/edf'
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
@@ -125,20 +137,20 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
 
     def test_do_GET(self):
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)    
-        
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)
+
     def test_do_GET_1(self):
         self.handler.headers['cookie'] = 12345
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)    
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)
 
     def test_do_GET_2(self):
         self.handler.headers['cookie'] = 12345
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.path = '/how/are'
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)    
-    
+        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)
+
     def test_do_GET_3(self):
         self.handler.headers['cookie'] = 12346
         self.handler.server.user_sessions[12346] = time.time() + 1000000
@@ -146,8 +158,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         for path in path_vec:
             self.handler.path = '/' + path
             self.handler.do_GET()
-            self.assertEqual(self.handler.rcode, http.client.OK)    
-    
+            self.assertEqual(self.handler.rcode, http.client.OK)
+
     def test_user_logged_in(self):
         self.handler.server.user_sessions = {}
         self.handler.session_id = 12345
@@ -243,8 +255,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.assertEqual(http.client.BAD_REQUEST, rcode)
 
     def _gen_module_spec(self):
-        spec = { 'commands': [ 
-                  { 'command_name' :'command', 
+        spec = { 'commands': [
+                  { 'command_name' :'command',
                     'command_args': [ {
                             'item_name' : 'param1',
                             'item_type' : 'integer',
@@ -253,9 +265,9 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
                            } ],
                     'command_description' : 'cmd description'
                   }
-                ] 
+                ]
                }
-        
+
         return spec
 
     def test_handle_post_request_2(self):
@@ -290,13 +302,13 @@ class MyCommandControl(CommandControl):
         return {}
 
     def _setup_session(self):
-        spec_file = SPEC_FILE_PATH + 'cmdctl.spec'
+        spec_file = BUILD_FILE_PATH + 'cmdctl.spec'
         module_spec = isc.config.module_spec_from_file(spec_file)
         config = isc.config.config_data.ConfigData(module_spec)
         self._module_name = 'Cmdctl'
         self._cmdctl_config_data = config.get_full_config()
 
-    def _handle_msg_from_msgq(self): 
+    def _handle_msg_from_msgq(self):
         pass
 
 class TestCommandControl(unittest.TestCase):
@@ -305,7 +317,7 @@ class TestCommandControl(unittest.TestCase):
         self.old_stdout = sys.stdout
         sys.stdout = open(os.devnull, 'w')
         self.cmdctl = MyCommandControl(None, True)
-   
+
     def tearDown(self):
         sys.stdout.close()
         sys.stdout = self.old_stdout
@@ -320,7 +332,7 @@ class TestCommandControl(unittest.TestCase):
         old_env = os.environ
         if 'B10_FROM_SOURCE' in os.environ:
             del os.environ['B10_FROM_SOURCE']
-        self.cmdctl.get_cmdctl_config_data() 
+        self.cmdctl.get_cmdctl_config_data()
         self._check_config(self.cmdctl)
         os.environ = old_env
 
@@ -328,7 +340,7 @@ class TestCommandControl(unittest.TestCase):
         os.environ['B10_FROM_SOURCE'] = '../'
         self._check_config(self.cmdctl)
         os.environ = old_env
-    
+
     def test_parse_command_result(self):
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
@@ -391,13 +403,13 @@ class TestCommandControl(unittest.TestCase):
         os.environ = old_env
 
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, 
+        self._check_answer(answer, 1,
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
 
         # Test with invalid accounts file
@@ -409,7 +421,7 @@ class TestCommandControl(unittest.TestCase):
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         os.remove(file_name)
-    
+
     def test_send_command(self):
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         self.assertEqual(rcode, 0)
@@ -424,7 +436,7 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.old_stderr = sys.stderr
         sys.stdout = open(os.devnull, 'w')
         sys.stderr = sys.stdout
-        self.server = MySecureHTTPServer(('localhost', 8080), 
+        self.server = MySecureHTTPServer(('localhost', 8080),
                                          MySecureHTTPRequestHandler,
                                          MyCommandControl, verbose=True)
 
@@ -458,32 +470,90 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.assertEqual(1, len(self.server._user_infos))
         self.assertTrue('root' in self.server._user_infos)
 
+    def test_check_file(self):
+        # Just some file that we know exists
+        file_name = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'
+        check_file(file_name)
+        with UnreadableFile(file_name):
+            self.assertRaises(CmdctlException, check_file, file_name)
+        self.assertRaises(CmdctlException, check_file, '/local/not-exist')
+        self.assertRaises(CmdctlException, check_file, '/')
+
+
     def test_check_key_and_cert(self):
+        keyfile = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'
+        certfile = BUILD_FILE_PATH + 'cmdctl-certfile.pem'
+
+        # no exists
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/local/not-exist')
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                         '/local/not-exist', certfile)
+
+        # not a file
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/')
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
-                         '/local/not-exist', 'cmdctl-keyfile.pem')
+                         '/', certfile)
 
-        self.server._check_key_and_cert(SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                                        SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        # no read permission
+        with UnreadableFile(certfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        with UnreadableFile(keyfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        # All OK (also happens to check the context code above works)
+        self.server._check_key_and_cert(keyfile, certfile)
 
     def test_wrap_sock_in_ssl_context(self):
         sock = socket.socket()
-        self.assertRaises(socket.error, 
+
+        # Bad files should result in a socket.error raised by our own
+        # code in the basic file checks
+        self.assertRaises(socket.error,
                           self.server._wrap_socket_in_ssl_context,
-                          sock, 
-                          '../cmdctl-keyfile',
-                          '../cmdctl-certfile')
+                          sock,
+                          'no_such_file', 'no_such_file')
 
+        # Using a non-certificate file would cause an SSLError, which
+        # is caught by our code which then raises a basic socket.error
+        self.assertRaises(socket.error,
+                          self.server._wrap_socket_in_ssl_context,
+                          sock,
+                          BUILD_FILE_PATH + 'cmdctl.py',
+                          BUILD_FILE_PATH + 'cmdctl-certfile.pem')
+
+        # Should succeed
         sock1 = socket.socket()
-        self.server._wrap_socket_in_ssl_context(sock1, 
-                          SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                          SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        ssl_sock = self.server._wrap_socket_in_ssl_context(sock1,
+                                   BUILD_FILE_PATH + 'cmdctl-keyfile.pem',
+                                   BUILD_FILE_PATH + 'cmdctl-certfile.pem')
+        self.assertIsInstance(ssl_sock, ssl.SSLSocket)
+
+        # wrap_socket can also raise IOError, which should be caught and
+        # handled like the other errors.
+        # Force this by temporarily disabling our own file checks
+        orig_check_func = self.server._check_key_and_cert
+        try:
+            self.server._check_key_and_cert = lambda x,y: None
+            self.assertRaises(socket.error,
+                              self.server._wrap_socket_in_ssl_context,
+                              sock,
+                              'no_such_file', 'no_such_file')
+        finally:
+            self.server._check_key_and_cert = orig_check_func
 
 class TestFuncNotInClass(unittest.TestCase):
     def test_check_port(self):
-        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)        
-        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)        
+        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)
+        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)
 
 
 if __name__== "__main__":