Parcourir la source

[2595] Better handling of bad files in b10-cmdctl

Still a fatal error if it gets to that point, but
- exact error message
- basic checks (exists, isfile, isreadable) are done when updating configuration as well
Jelte Jansen il y a 12 ans
Parent
commit
2db3158111
3 fichiers modifiés avec 125 ajouts et 46 suppressions
  1. 31 10
      src/bin/cmdctl/cmdctl.py.in
  2. 11 0
      src/bin/cmdctl/cmdctl_messages.mes
  3. 83 36
      src/bin/cmdctl/tests/cmdctl_test.py

+ 31 - 10
src/bin/cmdctl/cmdctl.py.in

@@ -82,6 +82,25 @@ SPECFILE_LOCATION = SPECFILE_PATH + os.sep + "cmdctl.spec"
 class CmdctlException(Exception):
     pass
 
+class CmdctlFatalException(Exception):
+    """
+    Exception for fatal errors, which should not be caught anywhere but on
+    the highest level.
+    """
+    pass
+
+def check_file(file_name):
+    # TODO: Check contents of certificate file
+    if not os.path.exists(file_name):
+        raise CmdctlException("'%s' does not exist" % file_name)
+
+    if not os.path.isfile(file_name):
+        raise CmdctlException("'%s' is not a file" % file_name)
+
+    if not os.access(file_name, os.R_OK):
+        raise CmdctlException("'%s' is not readable" % file_name)
+
+
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
     '''https connection request handler.
     Currently only GET and POST are supported.  '''
@@ -153,7 +172,6 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         self.end_headers()
         self.wfile.write(json.dumps(reply).encode())
 
-
     def _handle_login(self):
         if self._is_user_logged_in():
             return http.client.OK, ["user has already login"]
@@ -282,8 +300,10 @@ class CommandControl():
                 # further check need to be done: eg. whether
                 # the private/certificate is valid.
                 path = new_config[key]
-                if not os.path.exists(path):
-                    errstr = "the file doesn't exist: " + path
+                try:
+                    check_file(path)
+                except CmdctlException as cce:
+                    errstr = str(cce)
             elif key == 'accounts_file':
                 errstr = self._accounts_file_check(new_config[key])
             else:
@@ -524,12 +544,8 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self.user_sessions[session_id] = time.time()
 
     def _check_key_and_cert(self, key, cert):
-        # TODO, check the content of key/certificate file
-        if not os.path.exists(key):
-            raise CmdctlException("key file '%s' doesn't exist " % key)
-
-        if not os.path.exists(cert):
-            raise CmdctlException("certificate file '%s' doesn't exist " % cert)
+        check_file(key)
+        check_file(cert);
 
     def _wrap_socket_in_ssl_context(self, sock, key, cert):
         try:
@@ -540,11 +556,14 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
                                       keyfile = key,
                                       ssl_version = ssl.PROTOCOL_SSLv23)
             return ssl_sock
-        except (ssl.SSLError, CmdctlException) as err :
+        except ssl.SSLError as err :
             logger.error(CMDCTL_SSL_SETUP_FAILURE_USER_DENIED, err)
             self.close_request(sock)
             # raise socket error to finish the request
             raise socket.error
+        except CmdctlException as cce:
+            logger.fatal(CMDCTL_SSL_SETUP_FAILURE_READING_CERT, cce)
+            raise CmdctlFatalException("Unable to accept SSL connections")
 
     def get_request(self):
         '''Get client request socket and wrap it in SSL context. '''
@@ -633,6 +652,8 @@ if __name__ == '__main__':
         logger.info(CMDCTL_STOPPED_BY_KEYBOARD)
     except CmdctlException as err:
         logger.fatal(CMDCTL_UNCAUGHT_EXCEPTION, err);
+    except CmdctlFatalException as fe:
+        logger.fatal(CMDCTL_FATAL_EXCEPTION, fe)
 
     if httpd:
         httpd.shutdown()

+ 11 - 0
src/bin/cmdctl/cmdctl_messages.mes

@@ -43,6 +43,9 @@ specific error is printed in the message.
 This debug message indicates that the given command has been sent to
 the given module.
 
+% CMDCTL_FATAL_EXCEPTION A fatal error occured: %1
+While running, b10-cmdctl encountered a fatal exception and it will shut down
+
 % CMDCTL_NO_SUCH_USER username not found in user database: %1
 A login attempt was made to b10-cmdctl, but the username was not known.
 Users can be added with the tool b10-cmdctl-usermgr.
@@ -58,6 +61,14 @@ with the tool b10-cmdctl-usermgr.
 This debug message indicates that the given command is being sent to
 the given module.
 
+% CMDCTL_SSL_SETUP_FAILURE_READING_CERT failed to read certificate or key: %1
+The b10-cmdctl daemon is unable to read either the certificate file or
+the private key file, and is therefore unable to accept any SSL connections.
+This is a fatal error, as b10-cmdctl cannot be of any use in this state.
+The specific error is printed in the message.
+The administrator should solve the issue with the files, or recreate them
+with the b10-certgen tool.
+
 % CMDCTL_SSL_SETUP_FAILURE_USER_DENIED failed to create an SSL connection (user denied): %1
 The user was denied because the SSL connection could not successfully
 be set up. The specific error is given in the log message. Possible

+ 83 - 36
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,6 +17,7 @@
 import unittest
 import socket
 import tempfile
+import stat
 import sys
 from cmdctl import *
 import isc.log
@@ -36,7 +37,7 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
 
     def send_response(self, rcode):
         self.rcode = rcode
-    
+
     def end_headers(self):
         pass
 
@@ -51,13 +52,13 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
         super().do_POST()
         self.wfile.close()
         os.remove('tmp.file')
-    
+
 
 class FakeSecureHTTPServer(SecureHTTPServer):
     def __init__(self):
         self.user_sessions = {}
         self.cmdctl = FakeCommandControlForTestRequestHandler()
-        self._verbose = True 
+        self._verbose = True
         self._user_infos = {}
         self.idle_timeout = 1200
         self._lock = threading.Lock()
@@ -71,6 +72,17 @@ class FakeCommandControlForTestRequestHandler(CommandControl):
     def send_command(self, mod, cmd, param):
         return 0, {}
 
+# context to temporarily make a file unreadable
+class UnreadableFile:
+    def __init__(self, file_name):
+        self.file_name = file_name
+        self.orig_mode = os.stat(file_name).st_mode
+
+    def __enter__(self):
+        os.chmod(self.file_name, self.orig_mode & ~stat.S_IRUSR)
+
+    def __exit__(self, type, value, traceback):
+        os.chmod(self.file_name, self.orig_mode)
 
 class TestSecureHTTPRequestHandler(unittest.TestCase):
     def setUp(self):
@@ -97,7 +109,7 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.handler.path = '/abc'
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == None))
-        
+
         self.handler.path = '/abc/edf'
         mod, cmd = self.handler._parse_request_path()
         self.assertTrue((mod == 'abc') and (cmd == 'edf'))
@@ -125,20 +137,20 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
 
     def test_do_GET(self):
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)    
-        
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)
+
     def test_do_GET_1(self):
         self.handler.headers['cookie'] = 12345
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)    
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)
 
     def test_do_GET_2(self):
         self.handler.headers['cookie'] = 12345
         self.handler.server.user_sessions[12345] = time.time() + 1000000
         self.handler.path = '/how/are'
         self.handler.do_GET()
-        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)    
-    
+        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)
+
     def test_do_GET_3(self):
         self.handler.headers['cookie'] = 12346
         self.handler.server.user_sessions[12346] = time.time() + 1000000
@@ -146,8 +158,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         for path in path_vec:
             self.handler.path = '/' + path
             self.handler.do_GET()
-            self.assertEqual(self.handler.rcode, http.client.OK)    
-    
+            self.assertEqual(self.handler.rcode, http.client.OK)
+
     def test_user_logged_in(self):
         self.handler.server.user_sessions = {}
         self.handler.session_id = 12345
@@ -243,8 +255,8 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.assertEqual(http.client.BAD_REQUEST, rcode)
 
     def _gen_module_spec(self):
-        spec = { 'commands': [ 
-                  { 'command_name' :'command', 
+        spec = { 'commands': [
+                  { 'command_name' :'command',
                     'command_args': [ {
                             'item_name' : 'param1',
                             'item_type' : 'integer',
@@ -253,9 +265,9 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
                            } ],
                     'command_description' : 'cmd description'
                   }
-                ] 
+                ]
                }
-        
+
         return spec
 
     def test_handle_post_request_2(self):
@@ -296,7 +308,7 @@ class MyCommandControl(CommandControl):
         self._module_name = 'Cmdctl'
         self._cmdctl_config_data = config.get_full_config()
 
-    def _handle_msg_from_msgq(self): 
+    def _handle_msg_from_msgq(self):
         pass
 
 class TestCommandControl(unittest.TestCase):
@@ -305,7 +317,7 @@ class TestCommandControl(unittest.TestCase):
         self.old_stdout = sys.stdout
         sys.stdout = open(os.devnull, 'w')
         self.cmdctl = MyCommandControl(None, True)
-   
+
     def tearDown(self):
         sys.stdout.close()
         sys.stdout = self.old_stdout
@@ -320,7 +332,7 @@ class TestCommandControl(unittest.TestCase):
         old_env = os.environ
         if 'B10_FROM_SOURCE' in os.environ:
             del os.environ['B10_FROM_SOURCE']
-        self.cmdctl.get_cmdctl_config_data() 
+        self.cmdctl.get_cmdctl_config_data()
         self._check_config(self.cmdctl)
         os.environ = old_env
 
@@ -328,7 +340,7 @@ class TestCommandControl(unittest.TestCase):
         os.environ['B10_FROM_SOURCE'] = '../'
         self._check_config(self.cmdctl)
         os.environ = old_env
-    
+
     def test_parse_command_result(self):
         self.assertEqual({}, self.cmdctl._parse_command_result(1, {'error' : 1}))
         self.assertEqual({'a': 1}, self.cmdctl._parse_command_result(0, {'a' : 1}))
@@ -391,13 +403,13 @@ class TestCommandControl(unittest.TestCase):
         os.environ = old_env
 
         answer = self.cmdctl.config_handler({'key_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
         answer = self.cmdctl.config_handler({'cert_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, "the file doesn't exist: /user/non-exist_folder")
+        self._check_answer(answer, 1, "'/user/non-exist_folder' does not exist")
 
         answer = self.cmdctl.config_handler({'accounts_file': '/user/non-exist_folder'})
-        self._check_answer(answer, 1, 
+        self._check_answer(answer, 1,
                 "Invalid accounts file: [Errno 2] No such file or directory: '/user/non-exist_folder'")
 
         # Test with invalid accounts file
@@ -409,7 +421,7 @@ class TestCommandControl(unittest.TestCase):
         answer = self.cmdctl.config_handler({'accounts_file': file_name})
         self._check_answer(answer, 1, "Invalid accounts file: list index out of range")
         os.remove(file_name)
-    
+
     def test_send_command(self):
         rcode, value = self.cmdctl.send_command('Cmdctl', 'print_settings', None)
         self.assertEqual(rcode, 0)
@@ -424,7 +436,7 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.old_stderr = sys.stderr
         sys.stdout = open(os.devnull, 'w')
         sys.stderr = sys.stdout
-        self.server = MySecureHTTPServer(('localhost', 8080), 
+        self.server = MySecureHTTPServer(('localhost', 8080),
                                          MySecureHTTPRequestHandler,
                                          MyCommandControl, verbose=True)
 
@@ -458,32 +470,67 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.assertEqual(1, len(self.server._user_infos))
         self.assertTrue('root' in self.server._user_infos)
 
+    def test_check_file(self):
+        # Just some file that we know exists
+        file_name = SRC_FILE_PATH + 'cmdctl-keyfile.pem'
+        check_file(file_name)
+        with UnreadableFile(file_name):
+            self.assertRaises(CmdctlException, check_file, file_name)
+        self.assertRaises(CmdctlException, check_file, '/local/not-exist')
+        self.assertRaises(CmdctlException, check_file, '/')
+
+
     def test_check_key_and_cert(self):
+        keyfile = SRC_FILE_PATH + 'cmdctl-keyfile.pem'
+        certfile = SRC_FILE_PATH + 'cmdctl-certfile.pem'
+
+        # no exists
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/local/not-exist')
         self.assertRaises(CmdctlException, self.server._check_key_and_cert,
-                         '/local/not-exist', 'cmdctl-keyfile.pem')
+                         '/local/not-exist', certfile)
 
-        self.server._check_key_and_cert(SRC_FILE_PATH + 'cmdctl-keyfile.pem',
-                                        SRC_FILE_PATH + 'cmdctl-certfile.pem')
+        # not a file
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                          keyfile, '/')
+        self.assertRaises(CmdctlException, self.server._check_key_and_cert,
+                         '/', certfile)
+
+        # no read permission
+        with UnreadableFile(certfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        with UnreadableFile(keyfile):
+            self.assertRaises(CmdctlException,
+                              self.server._check_key_and_cert,
+                              keyfile, certfile)
+
+        # All OK (also happens to check the context code above works)
+        self.server._check_key_and_cert(keyfile, certfile)
 
     def test_wrap_sock_in_ssl_context(self):
         sock = socket.socket()
-        self.assertRaises(socket.error, 
+
+        # Test exception is Fatal here (all specific cases are tested
+        # in test_check_key_and_cert())
+        self.assertRaises(CmdctlFatalException,
                           self.server._wrap_socket_in_ssl_context,
-                          sock, 
-                          '../cmdctl-keyfile',
-                          '../cmdctl-certfile')
+                          sock,
+                          'no_such_file', 'no_such_file')
 
         sock1 = socket.socket()
-        self.server._wrap_socket_in_ssl_context(sock1, 
+        self.server._wrap_socket_in_ssl_context(sock1,
                           SRC_FILE_PATH + 'cmdctl-keyfile.pem',
                           SRC_FILE_PATH + 'cmdctl-certfile.pem')
 
 class TestFuncNotInClass(unittest.TestCase):
     def test_check_port(self):
-        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)        
-        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)        
-        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)        
+        self.assertRaises(OptionValueError, check_port, None, 'port', -1, None)
+        self.assertRaises(OptionValueError, check_port, None, 'port', 65536, None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', 'a.b.d', None)
+        self.assertRaises(OptionValueError, check_addr, None, 'ipstr', '1::0:a.b', None)
 
 
 if __name__== "__main__":