Browse Source

[2595] Better handling of bad files in b10-cmdctl

Still a fatal error if it gets to that point, but
- exact error message
- basic checks (exists, isfile, isreadable) are done when updating configuration as well
Jelte Jansen 12 years ago
parent
commit
2db3158111
3 changed files with 125 additions and 46 deletions
  1. 31 10
      src/bin/cmdctl/cmdctl.py.in
  2. 11 0
      src/bin/cmdctl/cmdctl_messages.mes
  3. 83 36
      src/bin/cmdctl/tests/cmdctl_test.py

+ 31 - 10
src/bin/cmdctl/cmdctl.py.in

@@ -82,6 +82,25 @@ SPECFILE_LOCATION = SPECFILE_PATH + os.sep + "cmdctl.spec"
 class CmdctlException(Exception):
 class CmdctlException(Exception):
     pass
     pass
 
 
+class CmdctlFatalException(Exception):
+    """
+    Exception for fatal errors, which should not be caught anywhere but on
+    the highest level.
+    """
+    pass
+
+def check_file(file_name):
+    # TODO: Check contents of certificate file
+    if not os.path.exists(file_name):
+        raise CmdctlException("'%s' does not exist" % file_name)
+
+    if not os.path.isfile(file_name):
+        raise CmdctlException("'%s' is not a file" % file_name)
+
+    if not os.access(file_name, os.R_OK):
+        raise CmdctlException("'%s' is not readable" % file_name)
+
+
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
     '''https connection request handler.
     '''https connection request handler.
     Currently only GET and POST are supported.  '''
     Currently only GET and POST are supported.  '''
@@ -153,7 +172,6 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         self.end_headers()
         self.end_headers()
         self.wfile.write(json.dumps(reply).encode())
         self.wfile.write(json.dumps(reply).encode())
 
 
-
     def _handle_login(self):
     def _handle_login(self):
         if self._is_user_logged_in():
         if self._is_user_logged_in():
             return http.client.OK, ["user has already login"]
             return http.client.OK, ["user has already login"]
@@ -282,8 +300,10 @@ class CommandControl():
                 # further check need to be done: eg. whether
                 # further check need to be done: eg. whether
                 # the private/certificate is valid.
                 # the private/certificate is valid.
                 path = new_config[key]
                 path = new_config[key]
-                if not os.path.exists(path):
-                    errstr = "the file doesn't exist: " + path
+                try:
+                    check_file(path)
+                except CmdctlException as cce:
+                    errstr = str(cce)
             elif key == 'accounts_file':
             elif key == 'accounts_file':
                 errstr = self._accounts_file_check(new_config[key])
                 errstr = self._accounts_file_check(new_config[key])
             else:
             else:
@@ -524,12 +544,8 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self.user_sessions[session_id] = time.time()
         self.user_sessions[session_id] = time.time()
 
 
     def _check_key_and_cert(self, key, cert):
     def _check_key_and_cert(self, key, cert):
-        # TODO, check the content of key/certificate file
-        if not os.path.exists(key):
-            raise CmdctlException("key file '%s' doesn't exist " % key)
-
-        if not os.path.exists(cert):
-            raise CmdctlException("certificate file '%s' doesn't exist " % cert)
+        check_file(key)
+        check_file(cert);
 
 
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
         try:
         try:
@@ -540,11 +556,14 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
                                       keyfile = key,
                                       keyfile = key,
                                       ssl_version = ssl.PROTOCOL_SSLv23)
                                       ssl_version = ssl.PROTOCOL_SSLv23)
             return ssl_sock
             return ssl_sock
-        except (ssl.SSLError, CmdctlException) as err :
+        except ssl.SSLError as err :
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
             self.close_request(sock)
             self.close_request(sock)
             # raise socket error to finish the request
             # raise socket error to finish the request
             raise socket.error
             raise socket.error
+        except CmdctlException as cce:
+            logger.fatal(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
+            raise CmdctlFatalException("Unable to accept SSL connections")
 
 
     def get_request(self):
     def get_request(self):
         '''Get client request socket and wrap it in SSL context. '''
         '''Get client request socket and wrap it in SSL context. '''
@@ -633,6 +652,8 @@ if __name__ == '__main__':
         logger.info(CMDCTL_STOPPED_BY_KEYBOARD)
         logger.info(CMDCTL_STOPPED_BY_KEYBOARD)
     except CmdctlException as err:
     except CmdctlException as err:
         logger.fatal(CMDCTL_UNCAUGHT_EXCEPTION, err);
         logger.fatal(CMDCTL_UNCAUGHT_EXCEPTION, err);
+    except CmdctlFatalException as fe:
+        logger.fatal(CMDCTL_FATAL_EXCEPTION, fe)
 
 
     if httpd:
     if httpd:
         httpd.shutdown()
         httpd.shutdown()

+ 11 - 0
src/bin/cmdctl/cmdctl_messages.mes

@@ -43,6 +43,9 @@ specific error is printed in the message.
 This debug message indicates that the given command has been sent to
 This debug message indicates that the given command has been sent to
 the given module.
 the given module.
 
 
+% CMDCTL_FATAL_EXCEPTION A fatal error occured: %1
+While running, b10-cmdctl encountered a fatal exception and it will shut down
+
 % CMDCTL_NO_SUCH_USER username not found in user database: %1
 % CMDCTL_NO_SUCH_USER username not found in user database: %1
 A login attempt was made to b10-cmdctl, but the username was not known.
 A login attempt was made to b10-cmdctl, but the username was not known.
 Users can be added with the tool b10-cmdctl-usermgr.
 Users can be added with the tool b10-cmdctl-usermgr.
@@ -58,6 +61,14 @@ with the tool b10-cmdctl-usermgr.
 This debug message indicates that the given command is being sent to
 This debug message indicates that the given command is being sent to
 the given module.
 the given module.
 
 
+% CMDCTL_SSL_SETUP_FAILURE_READING_CERT failed to read certificate or key: %1
+The b10-cmdctl daemon is unable to read either the certificate file or
+the private key file, and is therefore unable to accept any SSL connections.
+This is a fatal error, as b10-cmdctl cannot be of any use in this state.
+The specific error is printed in the message.
+The administrator should solve the issue with the files, or recreate them
+with the b10-certgen tool.
+
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 The user was denied because the SSL connection could not successfully
 The user was denied because the SSL connection could not successfully
 be set up. The specific error is given in the log message. Possible
 be set up. The specific error is given in the log message. Possible

+ 83 - 36
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,6 +17,7 @@
 import unittest
 import unittest
 import socket
 import socket
 import tempfile
 import tempfile
+import stat
 import sys
 import sys
 from cmdctl import *
 from cmdctl import *
 import isc.log
 import isc.log
@@ -36,7 +37,7 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 
 
     def send_response(self, rcode):
     def send_response(self, rcode):
         self.rcode = rcode
         self.rcode = rcode
-    
+
     def end_headers(self):
     def end_headers(self):
         pass
         pass
 
 
@@ -51,13 +52,13 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
         super().do_POST()
         super().do_POST()
         self.wfile.close()
         self.wfile.close()
         os.remove('tmp.file')
         os.remove('tmp.file')
-    
+
 
 
 class FakeSecureHTTPServer(SecureHTTPServer):
 class FakeSecureHTTPServer(SecureHTTPServer):
     def __init__(self):
     def __init__(self):
         self.user_sessions = {}
         self.user_sessions = {}
         self.cmdctl = FakeCommandControlForTestRequestHandler()
         self.cmdctl = FakeCommandControlForTestRequestHandler()
-        self._verbose = True 
+        self._verbose = True
         self._user_infos = {}
         self._user_infos = {}
         self.idle_timeout = 1200
         self.idle_timeout = 1200
         self._lock = threading.Lock()
         self._lock = threading.Lock()
@@ -71,6 +72,17 @@ class FakeCommandControlForTestRequestHandler(CommandControl):
     def send_command(self, mod, cmd, param):
     def send_command(self, mod, cmd, param):
         return 0, {}
         return 0, {}
 
 
+# context to temporarily make a file unreadable
+class UnreadableFile:
+    def __init__(self, file_name):
+        self.file_name = file_name
+        self.orig_mode = os.stat(file_name).st_mode
+
+    def __enter__(self):
+        os.chmod(self.file_name, self.orig_mode & ~stat.S_IRUSR)
+
+    def __exit__(self, type, value, traceback):
+        os.chmod(self.file_name, self.orig_mode)
 
 
 class TestSecureHTTPRequestHandler(unittest.TestCase):
 class TestSecureHTTPRequestHandler(unittest.TestCase):
     def setUp(self):
     def setUp(self):
@@ -97,7 +109,7 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.handler.path = '/abc'
         self.handler.path = '/abc'
         mod, cmd = self.handler._parse_request_path()
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == None))
         self.assertTrue((mod == 'abc') and (cmd == None))
-        
+
         self.handler.path = '/abc/edf'
         self.handler.path = '/abc/edf'
         mod, cmd = self.handler._parse_request_path()
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
@@ -125,20 +137,20 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
 
 
     def test_do_GET(self):
     def test_do_GET(self):
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)    
-        
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)
+
     def test_do_GET_1(self):
     def test_do_GET_1(self):
         self.handler.headers['cookie'] = 12345
         self.handler.headers['cookie'] = 12345
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)    
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)
 
 
     def test_do_GET_2(self):
     def test_do_GET_2(self):
         self.handler.headers['cookie'] = 12345
         self.handler.headers['cookie'] = 12345
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.path = '/how/are'
         self.handler.path = '/how/are'
         self.handler.do_GET()
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)    
-    
+        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)
+
     def test_do_GET_3(self):
     def test_do_GET_3(self):
         self.handler.headers['cookie'] = 12346
         self.handler.headers['cookie'] = 12346
         self.handler.server.user_sessions[12346] = time.time() + 1000000
         self.handler.server.user_sessions[12346] = time.time() + 1000000
@@ -146,8 +158,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         for path in path_vec:
         for path in path_vec:
             self.handler.path = '/' + path
             self.handler.path = '/' + path
             self.handler.do_GET()
             self.handler.do_GET()
-            self.assertEqual(self.handler.rcode, http.client.OK)    
-    
+            self.assertEqual(self.handler.rcode, http.client.OK)
+
     def test_user_logged_in(self):
     def test_user_logged_in(self):
         self.handler.server.user_sessions = {}
         self.handler.server.user_sessions = {}
         self.handler.session_id = 12345
         self.handler.session_id = 12345
@@ -243,8 +255,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.assertEqual(http.client.BAD_REQUEST, rcode)
         self.assertEqual(http.client.BAD_REQUEST, rcode)
 
 
     def _gen_module_spec(self):
     def _gen_module_spec(self):
-        spec = { 'commands': [ 
-                  { 'command_name' :'command', 
+        spec = { 'commands': [
+                  { 'command_name' :'command',
                     'command_args': [ {
                     'command_args': [ {
                             'item_name' : 'param1',
                             'item_name' : 'param1',
                             'item_type' : 'integer',
                             'item_type' : 'integer',
@@ -253,9 +265,9 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
                            } ],
                            } ],
                     'command_description' : 'cmd description'
                     'command_description' : 'cmd description'
                   }
                   }
-                ] 
+                ]
                }
                }
-        
+
         return spec
         return spec
 
 
     def test_handle_post_request_2(self):
     def test_handle_post_request_2(self):
@@ -296,7 +308,7 @@ class MyCommandControl(CommandControl):
         self._module_name = 'Cmdctl'
         self._module_name = 'Cmdctl'
         self._cmdctl_config_data = config.get_full_config()
         self._cmdctl_config_data = config.get_full_config()
 
 
-    def _handle_msg_from_msgq(self): 
+    def _handle_msg_from_msgq(self):
         pass
         pass
 
 
 class TestCommandControl(unittest.TestCase):
 class TestCommandControl(unittest.TestCase):
@@ -305,7 +317,7 @@ class TestCommandControl(unittest.TestCase):
         self.old_stdout = sys.stdout
         self.old_stdout = sys.stdout
         sys.stdout = open(os.devnull, 'w')
         sys.stdout = open(os.devnull, 'w')
         self.cmdctl = MyCommandControl(None, True)
         self.cmdctl = MyCommandControl(None, True)
-   
+
     def tearDown(self):
     def tearDown(self):
         sys.stdout.close()
         sys.stdout.close()
         sys.stdout = self.old_stdout
         sys.stdout = self.old_stdout
@@ -320,7 +332,7 @@ class TestCommandControl(unittest.TestCase):
         old_env = os.environ
         old_env = os.environ
         if 'B10_FROM_SOURCE' in os.environ:
         if 'B10_FROM_SOURCE' in os.environ:
             del os.environ['B10_FROM_SOURCE']
             del os.environ['B10_FROM_SOURCE']
-        self.cmdctl.get_cmdctl_config_data() 
+        self.cmdctl.get_cmdctl_config_data()
         self._check_config(self.cmdctl)
         self._check_config(self.cmdctl)
         os.environ = old_env
         os.environ = old_env
 
 
@@ -328,7 +340,7 @@ class TestCommandControl(unittest.TestCase):
         os.environ['B10_FROM_SOURCE'] = '../'
         os.environ['B10_FROM_SOURCE'] = '../'
         self._check_config(self.cmdctl)
         self._check_config(self.cmdctl)
         os.environ = old_env
         os.environ = old_env
-    
+
     def test_parse_command_result(self):
     def test_parse_command_result(self):
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
@@ -391,13 +403,13 @@ class TestCommandControl(unittest.TestCase):
         os.environ = old_env
         os.environ = old_env
 
 
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
 
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
 
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, 
+        self._check_answer(answer, 1,
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
 
 
         # Test with invalid accounts file
         # Test with invalid accounts file
@@ -409,7 +421,7 @@ class TestCommandControl(unittest.TestCase):
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         os.remove(file_name)
         os.remove(file_name)
-    
+
     def test_send_command(self):
     def test_send_command(self):
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         self.assertEqual(rcode, 0)
         self.assertEqual(rcode, 0)
@@ -424,7 +436,7 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.old_stderr = sys.stderr
         self.old_stderr = sys.stderr
         sys.stdout = open(os.devnull, 'w')
         sys.stdout = open(os.devnull, 'w')
         sys.stderr = sys.stdout
         sys.stderr = sys.stdout
-        self.server = MySecureHTTPServer(('localhost', 8080), 
+        self.server = MySecureHTTPServer(('localhost', 8080),
                                          MySecureHTTPRequestHandler,
                                          MySecureHTTPRequestHandler,
                                          MyCommandControl, verbose=True)
                                          MyCommandControl, verbose=True)
 
 
@@ -458,32 +470,67 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.assertEqual(1, len(self.server._user_infos))
         self.assertEqual(1, len(self.server._user_infos))
         self.assertTrue('root' in self.server._user_infos)
         self.assertTrue('root' in self.server._user_infos)
 
 
+    def test_check_file(self):
+        # Just some file that we know exists
+        file_name = SRC_FILE_PATH + 'cmdctl-keyfile.pem'
+        check_file(file_name)
+        with UnreadableFile(file_name):
+            self.assertRaises(CmdctlException, check_file, file_name)
+        self.assertRaises(CmdctlException, check_file, '/local/not-exist')
+        self.assertRaises(CmdctlException, check_file, '/')
+
+
     def test_check_key_and_cert(self):
     def test_check_key_and_cert(self):
+        keyfile = SRC_FILE_PATH + 'cmdctl-keyfile.pem'
+        certfile = SRC_FILE_PATH + 'cmdctl-certfile.pem'
+
+        # no exists
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/local/not-exist')
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
-                         '/local/not-exist', 'cmdctl-keyfile.pem')
+                         '/local/not-exist', certfile)
 
 
-        self.server._check_key_and_cert(SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                                        SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        # not a file
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/')
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                         '/', certfile)
+
+        # no read permission
+        with UnreadableFile(certfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        with UnreadableFile(keyfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        # All OK (also happens to check the context code above works)
+        self.server._check_key_and_cert(keyfile, certfile)
 
 
     def test_wrap_sock_in_ssl_context(self):
     def test_wrap_sock_in_ssl_context(self):
         sock = socket.socket()
         sock = socket.socket()
-        self.assertRaises(socket.error, 
+
+        # Test exception is Fatal here (all specific cases are tested
+        # in test_check_key_and_cert())
+        self.assertRaises(CmdctlFatalException,
                           self.server._wrap_socket_in_ssl_context,
                           self.server._wrap_socket_in_ssl_context,
-                          sock, 
-                          '../cmdctl-keyfile',
-                          '../cmdctl-certfile')
+                          sock,
+                          'no_such_file', 'no_such_file')
 
 
         sock1 = socket.socket()
         sock1 = socket.socket()
-        self.server._wrap_socket_in_ssl_context(sock1, 
+        self.server._wrap_socket_in_ssl_context(sock1,
                           SRC_FILE_PATH + 'cmdctl-keyfile.pem',
                           SRC_FILE_PATH + 'cmdctl-keyfile.pem',
                           SRC_FILE_PATH + 'cmdctl-certfile.pem')
                           SRC_FILE_PATH + 'cmdctl-certfile.pem')
 
 
 class TestFuncNotInClass(unittest.TestCase):
 class TestFuncNotInClass(unittest.TestCase):
     def test_check_port(self):
     def test_check_port(self):
-        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)        
-        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)        
+        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)
+        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)
 
 
 
 
 if __name__== "__main__":
 if __name__== "__main__":