Browse Source

[master] Merge branch 'trac2710'

Conflicts:
	src/bin/cmdctl/tests/cmdctl_test.py
Jelte Jansen 12 years ago
parent
commit
16e8be506f
2 changed files with 115 additions and 3 deletions
  1. 17 3
      src/bin/cmdctl/cmdctl.py.in
  2. 98 0
      src/bin/cmdctl/tests/cmdctl_test.py

+ 17 - 3
src/bin/cmdctl/cmdctl.py.in

@@ -507,12 +507,25 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self._verbose = verbose
         self._lock = threading.Lock()
         self._user_infos = {}
-        self._accounts_file = None
+        self.__accounts_file = None
+        self.__accounts_file_mtime = 0
 
     def _create_user_info(self, accounts_file):
         '''Read all user's name and its' salt, hashed password
         from accounts file.'''
-        if (self._accounts_file == accounts_file) and (len(self._user_infos) > 0):
+
+        # If the file does not exist, set accounts to empty, and return
+        if not os.path.exists(accounts_file):
+            self._user_infos = {}
+            self.__accounts_file = None
+            self.__accounts_file_mtime = 0
+            return
+
+        # If the filename hasn't changed, and the file itself
+        # has neither, do nothing
+        accounts_file_mtime = os.stat(accounts_file).st_mtime
+        if self.__accounts_file == accounts_file and\
+           accounts_file_mtime <= self.__accounts_file_mtime:
             return
 
         with self._lock:
@@ -530,7 +543,8 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
                 if csvfile:
                     csvfile.close()
 
-        self._accounts_file = accounts_file
+        self.__accounts_file = accounts_file
+        self.__accounts_file_mtime = accounts_file_mtime
         if len(self._user_infos) == 0:
             logger.error(CMDCTL_NO_USER_ENTRIES_READ)
 

+ 98 - 0
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,6 +17,7 @@
 import unittest
 import socket
 import tempfile
+import time
 import stat
 import sys
 from cmdctl import *
@@ -71,6 +72,26 @@ class UnreadableFile:
     def __exit__(self, type, value, traceback):
         os.chmod(self.file_name, self.orig_mode)
 
+class TmpTextFile:
+    """
+    Context class for temporarily creating a text file with some
+    lines of content.
+
+    The file is automatically deleted if the context is left, so
+    make sure to not use the path of an existing file!
+    """
+    def __init__(self, path, contents):
+        self.__path = path
+        self.__contents = contents
+
+    def __enter__(self):
+        with open(self.__path, 'w') as f:
+            f.write("\n".join(self.__contents) + "\n")
+
+    def __exit__(self, type, value, traceback):
+        os.unlink(self.__path)
+
+
 class TestSecureHTTPRequestHandler(unittest.TestCase):
     def setUp(self):
         self.old_stdout = sys.stdout
@@ -534,6 +555,83 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.assertIn('6f0c73bd33101a5ec0294b3ca39fec90ef4717fe',
                       self.server.get_user_info('root'))
 
+        # When the file is not changed calling _create_user_info() again
+        # should have no effect. In order to test this, we overwrite the
+        # user-infos that were just set and make sure it isn't touched by
+        # the call (so make sure it isn't set to some empty value)
+        fake_users_val = { 'notinfile': [] }
+        self.server._user_infos = fake_users_val
+        self.server._create_user_info(SRC_FILE_PATH + 'cmdctl-accounts.csv')
+        self.assertEqual(fake_users_val, self.server._user_infos)
+
+    def test_create_user_info_changing_file_time(self):
+        self.assertEqual(0, len(self.server._user_infos))
+
+        # Create a file
+        accounts_file = BUILD_FILE_PATH + 'new_file.csv'
+        with TmpTextFile(accounts_file, ['root,foo,bar']):
+            self.server._create_user_info(accounts_file)
+            self.assertEqual(1, len(self.server._user_infos))
+            self.assertTrue('root' in self.server._user_infos)
+
+            # Make sure re-reading is a noop if file was not modified
+            fake_users_val = { 'notinfile': [] }
+            self.server._user_infos = fake_users_val
+            self.server._create_user_info(accounts_file)
+            self.assertEqual(fake_users_val, self.server._user_infos)
+
+        # Yes sleep sucks, but in this case we need it to check for
+        # a changed mtime, not for some thread to do its work
+        # (do we run these tests on systems with 1+ secs mtimes?)
+        time.sleep(0.1)
+        # create the file again, this time read should not be a noop
+        with TmpTextFile(accounts_file, ['otherroot,foo,bar']):
+            self.server._create_user_info(accounts_file)
+            self.assertEqual(1, len(self.server._user_infos))
+            self.assertTrue('otherroot' in self.server._user_infos)
+
+    def test_create_user_info_changing_file_name(self):
+        """
+        Check that the accounts file is re-read if the file name is different
+        """
+        self.assertEqual(0, len(self.server._user_infos))
+
+        # Create two files
+        accounts_file1 = BUILD_FILE_PATH + 'new_file.csv'
+        accounts_file2 = BUILD_FILE_PATH + 'new_file2.csv'
+        with TmpTextFile(accounts_file2, ['otherroot,foo,bar']):
+            with TmpTextFile(accounts_file1, ['root,foo,bar']):
+                self.server._create_user_info(accounts_file1)
+                self.assertEqual(1, len(self.server._user_infos))
+                self.assertTrue('root' in self.server._user_infos)
+
+                # Make sure re-reading is a noop if file was not modified
+                fake_users_val = { 'notinfile': [] }
+                self.server._user_infos = fake_users_val
+                self.server._create_user_info(accounts_file1)
+                self.assertEqual(fake_users_val, self.server._user_infos)
+
+                # But a different file should be read
+                self.server._create_user_info(accounts_file2)
+                self.assertEqual(1, len(self.server._user_infos))
+                self.assertTrue('otherroot' in self.server._user_infos)
+
+    def test_create_user_info_nonexistent_file(self):
+        # Even if there was data initially, if set to a nonexistent
+        # file it should result in no users
+        accounts_file = BUILD_FILE_PATH + 'new_file.csv'
+        self.assertFalse(os.path.exists(accounts_file))
+        fake_users_val = { 'notinfile': [] }
+        self.server._user_infos = fake_users_val
+        self.server._create_user_info(accounts_file)
+        self.assertEqual({}, self.server._user_infos)
+
+        # Should it now be created it should be read
+        with TmpTextFile(accounts_file, ['root,foo,bar']):
+            self.server._create_user_info(accounts_file)
+            self.assertEqual(1, len(self.server._user_infos))
+            self.assertTrue('root' in self.server._user_infos)
+
     def test_check_file(self):
         # Just some file that we know exists
         file_name = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'