Browse Source

[2710] Check file mtime and name, and reread if changed

Jelte Jansen 12 years ago
parent
commit
0a55e2f325
2 changed files with 44 additions and 2 deletions
  1. 12 1
      src/bin/cmdctl/cmdctl.py.in
  2. 32 1
      src/bin/cmdctl/tests/cmdctl_test.py

+ 12 - 1
src/bin/cmdctl/cmdctl.py.in

@@ -503,11 +503,21 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
         self._lock = threading.Lock()
         self._user_infos = {}
         self._accounts_file = None
+        self.__accounts_file_mtime = 0
 
     def _create_user_info(self, accounts_file):
         '''Read all user's name and its' salt, hashed password
         from accounts file.'''
-        if (self._accounts_file == accounts_file) and (len(self._user_infos) > 0):
+
+        # If the file does not exist, do nothing
+        if not os.path.exists(accounts_file):
+            return
+
+        # If the filename hasn't changed, and the file itself
+        # has neither, do nothing
+        accounts_file_mtime = os.stat(accounts_file).st_mtime
+        if self._accounts_file == accounts_file and\
+           accounts_file_mtime <= self.__accounts_file_mtime:
             return
 
         with self._lock:
@@ -526,6 +536,7 @@ class SecureHTTPServer(socketserver_mixin.NoPollMixIn,
                     csvfile.close()
 
         self._accounts_file = accounts_file
+        self.__accounts_file_mtime = accounts_file_mtime
         if len(self._user_infos) == 0:
             logger.error(CMDCTL_NO_USER_ENTRIES_READ)
 

+ 32 - 1
src/bin/cmdctl/tests/cmdctl_test.py

@@ -17,6 +17,7 @@
 import unittest
 import socket
 import tempfile
+import time
 import stat
 import sys
 from cmdctl import *
@@ -500,7 +501,7 @@ class TestSecureHTTPServer(unittest.TestCase):
         self.server._create_user_info(SRC_FILE_PATH + 'cmdctl-accounts.csv')
         self.assertEqual(fake_users_val, self.server._user_infos)
 
-    def test_create_user_info_changing_file(self):
+    def test_create_user_info_changing_file_time(self):
         self.assertEqual(0, len(self.server._user_infos))
         self.assertFalse('root' in self.server._user_infos)
 
@@ -517,12 +518,42 @@ class TestSecureHTTPServer(unittest.TestCase):
             self.server._create_user_info(accounts_file)
             self.assertEqual(fake_users_val, self.server._user_infos)
 
+        # Yes sleep sucks, but in this case we need it to check for
+        # a changed mtime, not for some thread to do its work
+        time.sleep(1.1)
         # create the file again, this time read should not be a noop
         with TmpTextFile(accounts_file, ['otherroot,foo,bar']):
             self.server._create_user_info(accounts_file)
             self.assertEqual(1, len(self.server._user_infos))
             self.assertTrue('otherroot' in self.server._user_infos)
 
+    def test_create_user_info_changing_file_name(self):
+        """
+        Check that the accounts file is re-read if the file name is different
+        """
+        self.assertEqual(0, len(self.server._user_infos))
+        self.assertFalse('root' in self.server._user_infos)
+
+        # Create two files
+        accounts_file1 = BUILD_FILE_PATH + 'new_file.csv'
+        accounts_file2 = BUILD_FILE_PATH + 'new_file2.csv'
+        with TmpTextFile(accounts_file2, ['otherroot,foo,bar']):
+            with TmpTextFile(accounts_file1, ['root,foo,bar']):
+                self.server._create_user_info(accounts_file1)
+                self.assertEqual(1, len(self.server._user_infos))
+                self.assertTrue('root' in self.server._user_infos)
+
+                # Make sure re-reading is a noop if file was not modified
+                fake_users_val = { 'notinfile': [] }
+                self.server._user_infos = fake_users_val
+                self.server._create_user_info(accounts_file1)
+                self.assertEqual(fake_users_val, self.server._user_infos)
+
+                # But a different file should be read
+                self.server._create_user_info(accounts_file2)
+                self.assertEqual(1, len(self.server._user_infos))
+                self.assertTrue('otherroot' in self.server._user_infos)
+
     def test_check_file(self):
         # Just some file that we know exists
         file_name = BUILD_FILE_PATH + 'cmdctl-keyfile.pem'