Browse Source

[2641] Test /users-exist API call

Mukund Sivaraman 12 years ago
parent
commit
1e653cb517
1 changed files with 38 additions and 15 deletions
  1. 38 15
      src/bin/cmdctl/tests/cmdctl_test.py

+ 38 - 15
src/bin/cmdctl/tests/cmdctl_test.py

@@ -41,19 +41,6 @@ class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
     def end_headers(self):
     def end_headers(self):
         pass
         pass
 
 
-    def do_GET(self):
-        self.wfile = open('tmp.file', 'wb')
-        super().do_GET()
-        self.wfile.close()
-        os.remove('tmp.file')
-
-    def do_POST(self):
-        self.wfile = open("tmp.file", 'wb')
-        super().do_POST()
-        self.wfile.close()
-        os.remove('tmp.file')
-
-
 class FakeSecureHTTPServer(SecureHTTPServer):
 class FakeSecureHTTPServer(SecureHTTPServer):
     def __init__(self):
     def __init__(self):
         self.user_sessions = {}
         self.user_sessions = {}
@@ -93,13 +80,16 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         self.handler.server.user_sessions = {}
         self.handler.server.user_sessions = {}
         self.handler.server._user_infos = {}
         self.handler.server._user_infos = {}
         self.handler.headers = {}
         self.handler.headers = {}
-        self.handler.rfile = open("check.tmp", 'w+b')
+        self.handler.rfile = open('input.tmp', 'w+b')
+        self.handler.wfile = open('output.tmp', 'w+b')
 
 
     def tearDown(self):
     def tearDown(self):
         sys.stdout.close()
         sys.stdout.close()
         sys.stdout = self.old_stdout
         sys.stdout = self.old_stdout
+        self.handler.wfile.close()
+        os.remove('output.tmp')
         self.handler.rfile.close()
         self.handler.rfile.close()
-        os.remove('check.tmp')
+        os.remove('input.tmp')
 
 
     def test_is_session_valid(self):
     def test_is_session_valid(self):
         self.assertIsNone(self.handler.session_id)
         self.assertIsNone(self.handler.session_id)
@@ -300,6 +290,39 @@ class TestSecureHTTPRequestHandler(unittest.TestCase):
         rcode, reply = self.handler._handle_post_request()
         rcode, reply = self.handler._handle_post_request()
         self.assertEqual(http.client.BAD_REQUEST, rcode)
         self.assertEqual(http.client.BAD_REQUEST, rcode)
 
 
+    def test_handle_users_exist(self):
+        orig_get_num_users = self.handler.server.get_num_users
+        try:
+            def create_get_num_users(n):
+                '''Create a replacement get_num_users() method.'''
+                def my_get_num_users():
+                    return n
+                return my_get_num_users
+
+            # Check case where get_num_users() returns 0
+            self.handler.server.get_num_users = create_get_num_users(0)
+            self.handler.headers['cookie'] = 12345
+            self.handler.path = '/users-exist'
+            self.handler.do_POST()
+            self.assertEqual(self.handler.rcode, http.client.OK)
+            self.handler.wfile.seek(0, 0)
+            d = self.handler.wfile.read()
+            self.assertFalse(json.loads(d.decode()))
+
+            # Clear the output
+            self.handler.wfile.seek(0, 0)
+            self.handler.wfile.truncate()
+
+            # Check case where get_num_users() returns > 0
+            self.handler.server.get_num_users = create_get_num_users(4)
+            self.handler.do_POST()
+            self.assertEqual(self.handler.rcode, http.client.OK)
+            self.handler.wfile.seek(0, 0)
+            d = self.handler.wfile.read()
+            self.assertTrue(json.loads(d.decode()))
+        finally:
+            self.handler.server.get_num_users = orig_get_num_users
+
 class MyCommandControl(CommandControl):
 class MyCommandControl(CommandControl):
     def _get_modules_specification(self):
     def _get_modules_specification(self):
         return {}
         return {}