Browse Source

1. Add unittests and help information for cmdctl and bindctl.
2. Add login idle timeout for cmdctl, default idle time is 1200 seconds.
3. Refactor some code for cmdctl.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@961 e5f2f494-b856-4b98-b285-d166d9295462

Likun Zhang 15 years ago
parent
commit
5281fe9e81

+ 4 - 1
configure.ac

@@ -172,8 +172,9 @@ AC_CONFIG_FILES([Makefile
                ])
 AC_OUTPUT([src/bin/cfgmgr/b10-cfgmgr.py
            src/bin/cfgmgr/run_b10-cfgmgr.sh
-           src/bin/cmdctl/b10-cmdctl.py
+           src/bin/cmdctl/cmdctl.py
            src/bin/cmdctl/run_b10-cmdctl.sh
+           src/bin/cmdctl/unittest/cmdctl_test
            src/bin/bind10/bind10.py
            src/bin/bind10/bind10_test
            src/bin/bind10/run_bind10.sh
@@ -190,6 +191,8 @@ AC_OUTPUT([src/bin/cfgmgr/b10-cfgmgr.py
            chmod +x src/bin/cfgmgr/run_b10-cfgmgr.sh
            chmod +x src/bin/cmdctl/run_b10-cmdctl.sh
            chmod +x src/bin/bind10/run_bind10.sh
+           chmod +x src/bin/cmdctl/unittest/cmdctl_test
+           chmod +x src/bin/bindctl/unittest/bindctl_test
            chmod +x src/bin/bindctl/bindctl
            chmod +x src/bin/msgq/run_msgq.sh
            chmod +x src/bin/msgq/msgq_test

+ 12 - 0
src/bin/bindctl/TODO

@@ -1,3 +1,15 @@
 1. Refactor the code for bindctl.
 2. Update man page for bindctl provided by jreed.
 3. Add more unit tests.
+4. Need Review:
+        bindcmd.py:
+            apply_config_cmd()
+            _validate_cmd()
+            complete()
+            
+        cmdparse.py:
+            _parse_params
+
+        moduleinfo.py:
+            get_param_name_by_position
+

+ 43 - 42
src/bin/bindctl/bindcmd.py

@@ -61,47 +61,52 @@ class BindCmdInterpreter(Cmd):
         self.modules = OrderedDict()
         self.add_module_info(ModuleInfo("help", desc = "Get help for bindctl"))
         self.server_port = server_port
-        self.connect_to_cmd_ctrld()
+        self._connect_to_cmd_ctrld()
         self.session_id = self._get_session_id()
 
-    def connect_to_cmd_ctrld(self):
+    def _connect_to_cmd_ctrld(self):
+        '''Connect to cmdctl in SSL context. '''
         try:
             self.conn = http.client.HTTPSConnection(self.server_port, cert_file='bindctl.pem')
         except  Exception as e:
-            print(e)
-            print("can't connect to %s, please make sure cmd-ctrld is running" % self.server_port)
+            print(e, "can't connect to %s, please make sure cmd-ctrld is running" %
+                  self.server_port)
 
     def _get_session_id(self):
+        '''Generate one session id for the connection. '''
         rand = os.urandom(16)
         now = time.time()
         ip = socket.gethostbyname(socket.gethostname())
         session_id = sha1(("%s%s%s" %(rand, now, ip)).encode())
-        session_id = session_id.hexdigest()
-        return session_id
+        digest = session_id.hexdigest()
+        return digest
     
     def run(self):
+        '''Parse commands inputted from user and send them to cmdctl. '''
         try:
-            ret = self.login()
-            if not ret:
+            if not self.login_to_cmdctl():
                 return False
 
             # Get all module information from cmd-ctrld
             self.config_data = isc.config.UIModuleCCSession(self)
-            self.update_commands()
+            self._update_commands()
             self.cmdloop()
         except KeyboardInterrupt:
             return True
 
-    def login(self):
+    def login_to_cmdctl(self):
+        '''Login to cmdctl with the username and password inputted 
+        from user. After login sucessfully, the username and password
+        will be saved in 'default_user.csv', when login next time, 
+        username and password saved in 'default_user.csv' will be used
+        first.
+        '''
         csvfile = None
         bsuccess = False
         try:
             csvfile = open('default_user.csv')
             users = csv.reader(csvfile)
             for row in users:
-                if (len(row) < 2):
-                    continue
-
                 param = {'username': row[0], 'password' : row[1]}
                 response = self.send_POST('/login', param)
                 data = response.read().decode()
@@ -120,10 +125,13 @@ class BindCmdInterpreter(Cmd):
                 return True
 
         count = 0
-        csvfile = None
         print("[TEMP MESSAGE]: username :root  password :bind10")
-        while count < 3:
+        while True:
             count = count + 1
+            if count > 3:
+                print("Too many authentication failures")
+                return False
+
             username = input("Username:")
             passwd = getpass.getpass()
             param = {'username': username, 'password' : passwd}
@@ -135,26 +143,23 @@ class BindCmdInterpreter(Cmd):
                 csvfile = open('default_user.csv', 'w')
                 writer = csv.writer(csvfile)
                 writer.writerow([username, passwd])
-                bsuccess = True
-                break
-
-            if count == 3:
-                print("Too many authentication failures")
-                break
+                csvfile.close()
+                return True
 
-        if csvfile:
-            csvfile.close()
-        return bsuccess
 
-    def update_commands(self):
+    def _update_commands(self):
+        '''Get all commands of modules. '''
         cmd_spec = self.send_GET('/command_spec')
-        if (len(cmd_spec) == 0):
-            print('can\'t get any command specification')
+        if not cmd_spec:
+            return
+
         for module_name in cmd_spec.keys():
-            if cmd_spec[module_name]:
-                self.prepare_module_commands(module_name, cmd_spec[module_name])
+            self._prepare_module_commands(module_name, cmd_spec[module_name])
 
     def send_GET(self, url, body = None):
+        '''Send GET request to cmdctl, session id is send with the name
+        'cookie' in header.
+        '''
         headers = {"cookie" : self.session_id}
         self.conn.request('GET', url, body, headers)
         res = self.conn.getresponse()
@@ -162,11 +167,12 @@ class BindCmdInterpreter(Cmd):
         if reply_msg:
            return json.loads(reply_msg.decode())
         else:
-            return None
+            return {}
        
 
     def send_POST(self, url, post_param = None): 
-        '''
+        '''Send GET request to cmdctl, session id is send with the name
+        'cookie' in header.
         Format: /module_name/command_name
         parameters of command is encoded as a map
         '''
@@ -183,13 +189,12 @@ class BindCmdInterpreter(Cmd):
         self.prompt = self.location + self.prompt_end
         return stop
 
-    def prepare_module_commands(self, module_name, module_commands):
+    def _prepare_module_commands(self, module_name, module_commands):
         module = ModuleInfo(name = module_name,
                             desc = "same here")
         for command in module_commands:
             cmd = CommandInfo(name = command["command_name"],
-                              desc = command["command_description"],
-                              need_inst_param = False)
+                              desc = command["command_description"])
             for arg in command["command_args"]:
                 param = ParamInfo(name = arg["item_name"],
                                   type = arg["item_type"],
@@ -200,7 +205,7 @@ class BindCmdInterpreter(Cmd):
             module.add_command(cmd)
         self.add_module_info(module)
 
-    def validate_cmd(self, cmd):
+    def _validate_cmd(self, cmd):
         if not cmd.module in self.modules:
             raise CmdUnknownModuleSyntaxError(cmd.module)
         
@@ -225,7 +230,6 @@ class BindCmdInterpreter(Cmd):
                                              list(params.keys())[0])
         elif params:
             param_name = None
-            index = 0
             param_count = len(params)
             for name in params:
                 # either the name of the parameter must be known, or
@@ -250,18 +254,17 @@ class BindCmdInterpreter(Cmd):
                             raise CmdUnknownParamSyntaxError(cmd.module, cmd.command, cmd.params[name])
                     else:
                         # replace the numbered items by named items
-                        param_name = command_info.get_param_name_by_position(name+1, index, param_count)
+                        param_name = command_info.get_param_name_by_position(name+1, param_count)
                         cmd.params[param_name] = cmd.params[name]
                         del cmd.params[name]
                         
                 elif not name in all_params:
                     raise CmdUnknownParamSyntaxError(cmd.module, cmd.command, name)
+
             param_nr = 0
             for name in manda_params:
                 if not name in params and not param_nr in params:
                     raise CmdMissParamSyntaxError(cmd.module, cmd.command, name)
-                
-                param_nr += 1
                 param_nr += 1
 
     def _handle_cmd(self, cmd):
@@ -385,7 +388,7 @@ class BindCmdInterpreter(Cmd):
     def _parse_cmd(self, line):
         try:
             cmd = BindCmdParse(line)
-            self.validate_cmd(cmd)
+            self._validate_cmd(cmd)
             self._handle_cmd(cmd)
         except BindCtlException as e:
             print("Error! ", e)
@@ -497,5 +500,3 @@ class BindCmdInterpreter(Cmd):
         print("received reply:", data)
 
 
-
-

+ 49 - 21
src/bin/bindctl/bindctl.py

@@ -18,69 +18,97 @@ from moduleinfo  import *
 from bindcmd import *
 import isc
 import pprint
+from optparse import OptionParser, OptionValueError
 
+__version__ = 'Bindctl'
 
 def prepare_config_commands(tool):
     module = ModuleInfo(name = "config", desc = "Configuration commands")
-    cmd = CommandInfo(name = "show", desc = "Show configuration", need_inst_param = False)
+    cmd = CommandInfo(name = "show", desc = "Show configuration")
     param = ParamInfo(name = "identifier", type = "string", optional=True)
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "add", desc = "Add entry to configuration list", need_inst_param = False)
+    cmd = CommandInfo(name = "add", desc = "Add entry to configuration list")
     param = ParamInfo(name = "identifier", type = "string", optional=True)
     cmd.add_param(param)
     param = ParamInfo(name = "value", type = "string", optional=False)
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "remove", desc = "Remove entry from configuration list", need_inst_param = False)
+    cmd = CommandInfo(name = "remove", desc = "Remove entry from configuration list")
     param = ParamInfo(name = "identifier", type = "string", optional=True)
     cmd.add_param(param)
     param = ParamInfo(name = "value", type = "string", optional=False)
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "set", desc = "Set a configuration value", need_inst_param = False)
+    cmd = CommandInfo(name = "set", desc = "Set a configuration value")
     param = ParamInfo(name = "identifier", type = "string", optional=True)
     cmd.add_param(param)
     param = ParamInfo(name = "value", type = "string", optional=False)
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "unset", desc = "Unset a configuration value", need_inst_param = False)
+    cmd = CommandInfo(name = "unset", desc = "Unset a configuration value")
     param = ParamInfo(name = "identifier", type = "string", optional=False)
     cmd.add_param(param)
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "diff", desc = "Show all local changes", need_inst_param = False)
+    cmd = CommandInfo(name = "diff", desc = "Show all local changes")
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "revert", desc = "Revert all local changes", need_inst_param = False)
+    cmd = CommandInfo(name = "revert", desc = "Revert all local changes")
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "commit", desc = "Commit all local changes", need_inst_param = False)
+    cmd = CommandInfo(name = "commit", desc = "Commit all local changes")
     module.add_command(cmd)
 
-    cmd = CommandInfo(name = "go", desc = "Go to a specific configuration part", need_inst_param = False)
+    cmd = CommandInfo(name = "go", desc = "Go to a specific configuration part")
     param = ParamInfo(name = "identifier", type="string", optional=False)
     cmd.add_param(param)
     module.add_command(cmd)
 
     tool.add_module_info(module)
 
+def check_port(option, opt_str, value, parser):
+    if (value < 0) or (value > 65535):
+        raise OptionValueError('%s requires a port number (0-65535)' % opt_str)
+    parser.values.port = value
+
+def check_addr(option, opt_str, value, parser):
+    ipstr = value
+    ip_family = socket.AF_INET
+    if (ipstr.find(':') != -1):
+        ip_family = socket.AF_INET6
+
+    try:
+        socket.inet_pton(ip_family, ipstr)
+    except:
+        raise OptionValueError("%s invalid ip address" % ipstr)
+
+    parser.values.addr = value
+
+def set_bindctl_options(parser):
+    parser.add_option('-p', '--port', dest = 'port', type = 'int',
+            action = 'callback', callback=check_port,
+            default = '8080', help = 'port for cmdctl of bind10')
+
+    parser.add_option('-a', '--address', dest = 'addr', type = 'string',
+            action = 'callback', callback=check_addr,
+            default = '127.0.0.1', help = 'IP address for cmdctl of bind10')
+
+
 if __name__ == '__main__':
-    tool = BindCmdInterpreter("localhost:8080")
-    prepare_config_commands(tool)
-    tool.run()
-# TODO: put below back, was removed to see errors
-#if __name__ == '__main__':
-    #try:
-        #tool = BindCmdInterpreter("localhost:8080")
-        #prepare_config_commands(tool)
-        #tool.run()
-    #except Exception as e:
-        #print(e)
-        #print("Failed to connect with b10-cmdctl module, is it running?")
+    try:
+        parser = OptionParser(version = __version__)
+        set_bindctl_options(parser)
+        (options, args) = parser.parse_args()
+        server_addr = options.addr + ':' + str(options.port)
+        tool = BindCmdInterpreter(server_addr)
+        prepare_config_commands(tool)
+        tool.run()
+    except Exception as e:
+        print(e, "\nFailed to connect with b10-cmdctl module, is it running?")
 
 

+ 5 - 4
src/bin/bindctl/cmdparse.py

@@ -34,10 +34,10 @@ PARAM_PATTERN = re.compile(param_name_str + param_value_str + next_params_str)
 NAME_PATTERN = re.compile("^\s*(?P<name>[\w]+)(?P<blank>\s*)(?P<others>.*)$")
 
 class BindCmdParse:
-    """ This class will parse the command line user input into three parts:
-    module name, command, parameters.
-    The first two parts are strings and parameter is one hash.
-    The parameter part is optional.
+    """ This class will parse the command line usr input into three part
+    module name, command, parameters
+    the first two parts are strings and parameter is one hash, 
+    parameters part is optional
     
     Example: zone reload, zone_name=example.com 
     module == zone
@@ -52,6 +52,7 @@ class BindCmdParse:
         self._parse_cmd(cmd)
 
     def _parse_cmd(self, text_str):    
+        '''Parse command line. '''
         # Get module name
         groups = NAME_PATTERN.match(text_str)
         if not groups:

+ 10 - 41
src/bin/bindctl/moduleinfo.py

@@ -51,10 +51,8 @@ class CommandInfo:
     more parameters
     """
 
-    def __init__(self, name, desc = "", need_inst_param = True):
+    def __init__(self, name, desc = ""):
         self.name = name
-        # Wether command needs parameter "instance_name" 
-        self.need_inst_param = need_inst_param 
         self.desc = desc
         self.params = OrderedDict()        
         # Set default parameter "help"
@@ -91,7 +89,7 @@ class CommandInfo:
         return [name for name in all_names 
                 if not self.params[name].is_optional]        
         
-    def get_param_name_by_position(self, pos, index, param_count):
+    def get_param_name_by_position(self, pos, param_count):
         # count mandatories back from the last
         # from the last mandatory; see the number of mandatories before it
         # and compare that to the number of positional arguments left to do
@@ -101,7 +99,9 @@ class CommandInfo:
         # (can this be done in all cases? this is certainly not the most efficient method;
         # one way to make the whole of this more consistent is to always set mandatories first, but
         # that would make some commands less nice to use ("config set value location" instead of "config set location value")
-        if type(pos) == int:
+        if type(pos) != int:
+            raise KeyError(str(pos) + " is not an integer")
+        else:
             if param_count == len(self.params) - 1:
                 i = 0
                 for k in self.params.keys():
@@ -131,14 +131,9 @@ class CommandInfo:
                     raise KeyError(str(pos) + " out of range")
             else:
                 raise KeyError("Too many parameters")
-        else:
-            raise KeyError(str(pos) + " is not an integer")
-    
 
-    def need_instance_param(self):
-        return self.need_inst_param
 
-    def command_help(self, inst_name, inst_type, inst_desc):
+    def command_help(self):
         print("Command ", self)
         print("\t\thelp (Get help for command)")
                 
@@ -166,65 +161,39 @@ class CommandInfo:
 
 class ModuleInfo:
     """Define the information of one module, include module name, 
-    module supporting commands, instance name and the value type of instance name
+    module supporting commands.
     """    
     
-    def __init__(self, name, inst_name = "", inst_type = STRING_TYPE, 
-                 inst_desc = "", desc = ""):
+    def __init__(self, name, desc = ""):
         self.name = name
-        self.inst_name = inst_name
-        self.inst_type = inst_type
-        self.inst_desc = inst_desc
         self.desc = desc
         self.commands = OrderedDict()         
         self.add_command(CommandInfo(name = "help", 
-                                     desc = "Get help for module",
-                                     need_inst_param = False))
+                                     desc = "Get help for module"))
         
     def __str__(self):
         return str("%s \t%s" % (self.name, self.desc))
         
     def add_command(self, command_info):        
         self.commands[command_info.name] = command_info
-        if command_info.need_instance_param():
-            command_info.add_param(ParamInfo(name = self.inst_name, 
-                                             type = self.inst_type,
-                                             desc = self.inst_desc))
-
         
     def has_command_with_name(self, command_name):
         return command_name in self.commands
         
-
     def get_command_with_name(self, command_name):
         return self.commands[command_name]
         
-        
     def get_commands(self):
         return list(self.commands.values())
         
-    
     def get_command_names(self):
         return list(self.commands.keys())
-        
-    
-    def get_instance_param_name(self):
-        return self.inst_name
-        
-        
-    def get_instance_param_type(self):
-        return self.inst_type
-        
 
     def module_help(self):
         print("Module ", self, "\nAvailable commands:")
         for k in self.commands.keys():
             print("\t", self.commands[k])
             
-            
     def command_help(self, command):
-        self.commands[command].command_help(self.inst_name, 
-                                            self.inst_type,
-                                            self.inst_desc)
-    
+        self.commands[command].command_help()    
 

+ 7 - 14
src/bin/bindctl/unittest/bindctl_test.py

@@ -85,16 +85,6 @@ class TestCmdLex(unittest.TestCase):
         self.my_assert_raise(CmdCommandNameFormatError, "zone z-d ")
         self.my_assert_raise(CmdCommandNameFormatError, "zone zdd/")
         self.my_assert_raise(CmdCommandNameFormatError, "zone zdd/ \"")
-        
-
-    def testCmdParamFormatError(self): 
-        self.my_assert_raise(CmdParamFormatError, "zone load load")
-        self.my_assert_raise(CmdParamFormatError, "zone load load=")
-        self.my_assert_raise(CmdParamFormatError, "zone load load==dd")
-        self.my_assert_raise(CmdParamFormatError, "zone load , zone_name=dd zone_file=d" )
-        self.my_assert_raise(CmdParamFormatError, "zone load zone_name=dd zone_file" )
-        self.my_assert_raise(CmdParamFormatError, "zone zdd \"")
-        
 
 class TestCmdSyntax(unittest.TestCase):
     
@@ -103,18 +93,21 @@ class TestCmdSyntax(unittest.TestCase):
         
         tool = bindcmd.BindCmdInterpreter()        
         zone_file_param = ParamInfo(name = "zone_file")
+        zone_name = ParamInfo(name = 'zone_name')
         load_cmd = CommandInfo(name = "load")
         load_cmd.add_param(zone_file_param)
+        load_cmd.add_param(zone_name)
         
         param_master = ParamInfo(name = "master", optional = True)                                 
         param_allow_update = ParamInfo(name = "allow_update", optional = True)                                           
         set_cmd = CommandInfo(name = "set")
         set_cmd.add_param(param_master)
         set_cmd.add_param(param_allow_update)
+        set_cmd.add_param(zone_name)
         
-        reload_all_cmd = CommandInfo(name = "reload_all", need_inst_param = False)        
+        reload_all_cmd = CommandInfo(name = "reload_all")        
         
-        zone_module = ModuleInfo(name = "zone", inst_name = "zone_name")                             
+        zone_module = ModuleInfo(name = "zone")                             
         zone_module.add_command(load_cmd)
         zone_module.add_command(set_cmd)
         zone_module.add_command(reload_all_cmd)
@@ -129,12 +122,12 @@ class TestCmdSyntax(unittest.TestCase):
         
     def no_assert_raise(self, cmd_line):
         cmd = cmdparse.BindCmdParse(cmd_line)
-        self.bindcmd.validate_cmd(cmd) 
+        self.bindcmd._validate_cmd(cmd) 
         
         
     def my_assert_raise(self, exception_type, cmd_line):
         cmd = cmdparse.BindCmdParse(cmd_line)
-        self.assertRaises(exception_type, self.bindcmd.validate_cmd, cmd)  
+        self.assertRaises(exception_type, self.bindcmd._validate_cmd, cmd)  
         
         
     def testValidateSuccess(self):

+ 3 - 3
src/bin/cmdctl/Makefile.am

@@ -5,10 +5,10 @@ pkglibexec_SCRIPTS = b10-cmdctl
 b10_cmdctldir = $(DESTDIR)$(pkgdatadir)
 b10_cmdctl_DATA = passwd.csv b10-cmdctl.pem
 
-CLEANFILES=	b10-cmdctl
+CLEANFILES=	cmdctl.py
 
 # TODO: does this need $$(DESTDIR) also?
 # this is done here since configure.ac AC_OUTPUT doesn't expand exec_prefix
-b10-cmdctl: b10-cmdctl.py
-	$(SED) "s|@@PYTHONPATH@@|@pyexecdir@|" b10-cmdctl.py >$@
+b10-cmdctl: cmdctl.py
+	$(SED) "s|@@PYTHONPATH@@|@pyexecdir@|" cmdctl.py >$@
 	chmod a+x $@

+ 119 - 45
src/bin/cmdctl/b10-cmdctl.py.in

@@ -38,12 +38,16 @@ import pprint
 import select
 import csv
 import random
+import time
+import signal
+from optparse import OptionParser, OptionValueError
 from hashlib import sha1
 try:
     import threading
 except ImportError:
     import dummy_threading as threading
 
+__version__ = 'BIND10'
 URL_PATTERN = re.compile('/([\w]+)(?:/([\w]+))?/?')
 
 # If B10_FROM_SOURCE is set in the environment, we use data files
@@ -92,7 +96,16 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         return self.session_id 
 
     def _is_user_logged_in(self):
-        return self.session_id in self.server.user_sessions           
+        login_time = self.server.user_sessions.get(self.session_id)
+        if not login_time:
+            return False
+        
+        idle_time = time.time() - login_time
+        if idle_time > self.server.idle_timeout:
+            return False
+        # Update idle time
+        self.server.user_sessions[self.session_id] = time.time()
+        return True
 
     def _parse_request_path(self):
         '''Parse the url, the legal url should like /ldh or /ldh/ldh '''
@@ -103,6 +116,7 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
             return (groups.group(1), groups.group(2))
 
     def do_POST(self):
+        '''Process POST request. '''
         '''Process user login and send command to proper module  
         The client should send its session id in header with 
         the name 'cookie'
@@ -112,8 +126,10 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
         if self._is_session_valid():
             if self.path == '/login':
                 rcode, reply = self._handle_login()
-            else:
+            elif self._is_user_logged_in():
                 rcode, reply = self._handle_post_request()
+            else:
+                rcode, reply = http.client.UNAUTHORIZED, ["please login"]
         else:
             rcode, reply = http.client.BAD_REQUEST, ["session isn't valid"]
       
@@ -127,19 +143,22 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
             return http.client.OK, ["user has already login"]
         is_user_valid, error_info = self._check_user_name_and_pwd()
         if is_user_valid:
-            self.server.user_sessions.append(self.session_id)
+            self.server.save_user_session_id(self.session_id)
             return http.client.OK, ["login success "]
         else:
             return http.client.UNAUTHORIZED, error_info
 
     def _check_user_name_and_pwd(self):
+        '''Check user name and its password '''
         length = self.headers.get('Content-Length')
         if not length:
             return False, ["invalid username or password"]     
-        user_info = json.loads((self.rfile.read(int(length))).decode())
-        if not user_info:
-             return False, ["invalid username or password"]                
-        
+
+        try:
+            user_info = json.loads((self.rfile.read(int(length))).decode())
+        except:
+            return False, ["invalid username or password"]                
+
         user_name = user_info.get('username')
         if not user_name:
             return False, ["need user name"]
@@ -158,21 +177,24 @@ class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
    
 
     def _handle_post_request(self):
+        '''Handle all the post request from client. '''
         mod, cmd = self._parse_request_path()
+        if (not mod) or (not cmd):
+            return http.client.BAD_REQUEST, ['malformed url']
+
         param = None
         len = self.headers.get('Content-Length')
-        rcode = http.client.OK
-        reply = None
         if len:
-            post_str = str(self.rfile.read(int(len)).decode())
-            print("command parameter:%s" % post_str)
-            param = json.loads(post_str)
-            # TODO, need return some proper return code. 
-            # currently always OK.
-        reply = self.server.send_command_to_module(mod, cmd, param)
-        print('b10-cmdctl finish send message \'%s\' to module %s' % (cmd, mod))            
+            try:
+                post_str = str(self.rfile.read(int(len)).decode())
+                param = json.loads(post_str)
+            except:
+                pass
 
-        return rcode, reply
+        reply = self.server.send_command_to_module(mod, cmd, param)
+        print('b10-cmdctl finish send message \'%s\' to module %s' % (cmd, mod))
+        # TODO, need set proper rcode
+        return http.client.OK, reply
             
    
 class CommandControl():
@@ -192,9 +214,11 @@ class CommandControl():
         return self.send_command('ConfigManager', 'get_commands_spec')
 
     def get_config_data(self):
+        '''Get config data for all modules from configmanager '''
         return self.send_command('ConfigManager', 'get_config')
 
     def update_config_data(self, module_name, command_name):
+        '''Get lastest config data for all modules from configmanager '''
         if module_name == 'ConfigManager' and command_name == 'set_config':
             self.config_data = self.get_config_data()
 
@@ -202,7 +226,7 @@ class CommandControl():
         return self.send_command('ConfigManager', 'get_module_spec')
 
     def handle_recv_msg(self):
-        # Handle received message, if 'shutdown' is received, return False
+        '''Handle received message, if 'shutdown' is received, return False'''
         (message, env) = self.cc.group_recvmsg(True)
         while message:
             if 'commands_update' in message:
@@ -217,44 +241,44 @@ class CommandControl():
         return True
     
     def send_command(self, module_name, command_name, params = None):       
+        '''Send the command from bindctl to proper module. '''
         content = [command_name]
         if params:
             content.append(params)
 
-        msg = {'command' : content}
+        reply = {}
         print('b10-cmdctl send command \'%s\' to %s' %(command_name, module_name))
         try:
+            msg = {'command' : content}
             self.cc.group_sendmsg(msg, module_name)
-            #TODO, it may be blocked, msqg need to add a new interface
-            # wait in timeout.
+            #TODO, it may be blocked, msqg need to add a new interface waiting in timeout.
             answer, env = self.cc.group_recvmsg(False)
-            if answer and 'result' in answer.keys() and type(answer['result']) == list:
-                # TODO: with the new cc implementation, replace "1" by 1
-                if answer['result'][0] == 1:
+            if answer and 'result' in answer.keys() and type(answer['result']) == list:                
+                if answer['result'][0] != 0:
                     # todo: exception
-                    print("Error: " + str(answer['result'][1]))
-                    return {}
+                    print("Error: " + str(answer['result'][1]))                    
                 else:
                     self.update_config_data(module_name, command_name)
-                    if len(answer['result']) > 1:
-                        return answer['result'][1]
-                    return {}
+                    if (len(answer['result']) > 1):
+                        reply = answer['result'][1]
             else:
                 print("Error: unexpected answer from %s" % module_name)
                 print(answer)
         except Exception as e:
-            print(e)
-            print('b10-cmdctl fail send command \'%s\' to %s' % (command_name, module_name))
-        return {}
+            print(e, ':b10-cmdctl fail send command \'%s\' to %s' % (command_name, module_name))
+        
+        return reply
 
 
 class SecureHTTPServer(http.server.HTTPServer):
     '''Make the server address can be reused.'''
     allow_reuse_address = True
 
-    def __init__(self, server_address, RequestHandlerClass):
+    def __init__(self, server_address, RequestHandlerClass, idle_timeout = 1200):
+        '''idle_timeout: the max idle time for login'''
         http.server.HTTPServer.__init__(self, server_address, RequestHandlerClass)
-        self.user_sessions = []
+        self.user_sessions = {}
+        self.idle_timeout = idle_timeout
         self.cmdctrl = CommandControl()
         self.__is_shut_down = threading.Event()
         self.__serving = False
@@ -262,23 +286,25 @@ class SecureHTTPServer(http.server.HTTPServer):
         self._read_user_info()
 
     def _read_user_info(self):
-        # Get all username and password information
+        '''Read all user's name and its' password from csv file.'''
         csvfile = None
         try:
             csvfile = open(USER_INFO_FILE)
             reader = csv.reader(csvfile)
             for row in reader:
                 self.user_infos[row[0]] = [row[1], row[2]]
-                
         except Exception as e:
             print("Fail to read user information ", e)                
-            exit(1)
         finally:
             if csvfile:
                 csvfile.close()
         
+    def save_user_session_id(self, session_id):
+        # Record user's id and login time.
+        self.user_sessions[session_id] = time.time()
         
     def get_request(self):
+        '''Get client request socket and wrap it in SSL context. '''
         newsocket, fromaddr = self.socket.accept()
         try:
             connstream = ssl.wrap_socket(newsocket,
@@ -298,18 +324,18 @@ class SecureHTTPServer(http.server.HTTPServer):
         '''Currently only support the following three url GET request '''
         rcode, reply = http.client.NO_CONTENT, []        
         if not module:
-            rcode = http.client.OK
             if id == 'command_spec':
-                reply = self.cmdctrl.command_spec
+               rcode, reply = http.client.OK, self.cmdctrl.command_spec
             elif id == 'config_data':
-                reply = self.cmdctrl.config_data
+               rcode, reply = http.client.OK, self.cmdctrl.config_data
             elif id == 'config_spec':
-                reply = self.cmdctrl.config_spec
+               rcode, reply = http.client.OK, self.cmdctrl.config_spec
         
         return rcode, reply 
 
         
     def serve_forever(self, poll_interval = 0.5):
+        '''Start cmdctl as one tcp server. '''
         self.__serving = True
         self.__is_shut_down.clear()
         while self.__serving:
@@ -330,21 +356,69 @@ class SecureHTTPServer(http.server.HTTPServer):
     def send_command_to_module(self, module_name, command_name, params):
         return self.cmdctrl.send_command(module_name, command_name, params)
 
+httpd = None
+
+def signal_handler(signal, frame):
+    if httpd:
+        httpd.shutdown()
+    sys.exit(0)
 
-def run(server_class = SecureHTTPServer, addr = 'localhost', port = 8080):
+def set_signal_handler():
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+def run(addr = 'localhost', port = 8080, idle_timeout = 1200):
     ''' Start cmdctl as one https server. '''
     print("b10-cmdctl module is starting on :%s port:%d" %(addr, port))
-    httpd = server_class((addr, port), SecureHTTPRequestHandler)
+    httpd = SecureHTTPServer((addr, port), SecureHTTPRequestHandler, idle_timeout)
     httpd.serve_forever()
 
+def check_port(option, opt_str, value, parser):
+    if (value < 0) or (value > 65535):
+        raise OptionValueError('%s requires a port number (0-65535)' % opt_str)
+    parser.values.port = value
+
+def check_addr(option, opt_str, value, parser):
+    ipstr = value
+    ip_family = socket.AF_INET
+    if (ipstr.find(':') != -1):
+        ip_family = socket.AF_INET6
+
+    try:
+        socket.inet_pton(ip_family, ipstr)
+    except:
+        raise OptionValueError("%s invalid ip address" % ipstr)
+
+    parser.values.addr = value
+
+def set_cmd_options(parser):
+    parser.add_option('-p', '--port', dest = 'port', type = 'int',
+            action = 'callback', callback=check_port,
+            default = '8080', help = 'port cmdctl will use')
+
+    parser.add_option('-a', '--address', dest = 'addr', type = 'string',
+            action = 'callback', callback=check_addr,
+            default = '127.0.0.1', help = 'IP address cmdctl will use')
+
+    parser.add_option('-i', '--idle-timeout', dest = 'idle_timeout', type = 'int',
+            default = '1200', help = 'login idle time out')
+
 
 if __name__ == '__main__':
     try:
-        run()
+        parser = OptionParser(version = __version__)
+        set_cmd_options(parser)
+        (options, args) = parser.parse_args()
+        set_signal_handler()
+        run(options.addr, options.port, options.idle_timeout)
     except isc.cc.SessionError as se:
         print("[b10-cmdctl] Error creating b10-cmdctl, "
                 "is the command channel daemon running?")        
     except KeyboardInterrupt:
         print("exit http server")
-        
+
+    if httpd:
+        httpd.shutdown()
+
+
 

+ 12 - 0
src/bin/cmdctl/unittest/cmdctl_test.in

@@ -0,0 +1,12 @@
+#! /bin/sh
+
+PYTHON_EXEC=${PYTHON_EXEC:-@PYTHON@}
+export PYTHON_EXEC
+
+BINDCTL_TEST_PATH=@abs_top_srcdir@/src/bin/cmdctl/unittest
+PYTHONPATH=@abs_top_srcdir@/src/bin/cmdctl
+export PYTHONPATH
+
+cd ${BINDCTL_TEST_PATH}
+exec ${PYTHON_EXEC} -O cmdctl_test.py $*
+

+ 251 - 0
src/bin/cmdctl/unittest/cmdctl_test.py

@@ -0,0 +1,251 @@
+# Copyright (C) 2009  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+
+import unittest
+import socket
+from cmdctl import *
+
+# Rewrite the class for unittest.
+class MySecureHTTPRequestHandler(SecureHTTPRequestHandler):
+    def __init__(self):
+        pass
+
+    def send_response(self, rcode):
+        self.rcode = rcode
+    
+    def end_headers(self):
+        pass
+
+    def do_GET(self):
+        self.wfile = open('tmp.file', 'wb')
+        super().do_GET()
+        self.wfile.close()
+        os.remove('tmp.file')
+
+    def do_POST(self):
+        self.wfile = open("tmp.file", 'wb')
+        super().do_POST()
+        self.wfile.close()
+        os.remove('tmp.file')
+    
+
+class MySecureHTTPServer(SecureHTTPServer):
+    def __init__(self):
+        self.user_sessions = {}
+        self.idle_timeout = 1200
+        self.cmdctrl = MyCommandControl()
+
+class MyCommandControl():
+    def __init__(self):
+        self.command_spec = []
+        self.config_spec = []
+        self.config_data = []
+
+    def send_command(self, mod, cmd, param):
+        pass
+
+
+class TestSecureHTTPRequestHandler(unittest.TestCase):
+    def setUp(self):
+        self.handler = MySecureHTTPRequestHandler()
+        self.handler.server = MySecureHTTPServer()
+        self.handler.server.user_sessions = {}
+        self.handler.server.user_infos = {}
+        self.handler.headers = {}
+
+    def test_parse_request_path(self):
+        self.handler.path = ''
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == None) and (cmd == None))
+
+        self.handler.path = '/abc'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == 'abc') and (cmd == None))
+        
+        self.handler.path = '/abc/edf'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == 'abc') and (cmd == 'edf'))
+
+        self.handler.path = '/abc/edf/ghi'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == 'abc') and (cmd == 'edf'))
+
+    def test_parse_request_path_1(self):
+        self.handler.path = '/ab*c'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == 'ab') and cmd == None)
+
+        self.handler.path = '/abc/ed*fdd/ddd'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == 'abc') and cmd == 'ed')
+
+        self.handler.path = '/-*/edfdd/ddd'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == None) and (cmd == None))
+
+        self.handler.path = '/-*/edfdd/ddd'
+        mod, cmd = self.handler._parse_request_path()
+        self.assertTrue((mod == None) and (cmd == None))
+
+    def test_do_GET(self):
+        self.handler.do_GET()
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)    
+        
+    def test_do_GET_1(self):
+        self.handler.headers['cookie'] = 12345
+        self.handler.do_GET()
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)    
+
+    def test_do_GET_2(self):
+        self.handler.headers['cookie'] = 12345
+        self.handler.server.user_sessions[12345] = time.time() + 1000000
+        self.handler.path = '/how/are'
+        self.handler.do_GET()
+        self.assertEqual(self.handler.rcode, http.client.NO_CONTENT)    
+    
+    def test_do_GET_3(self):
+        self.handler.headers['cookie'] = 12346
+        self.handler.server.user_sessions[12346] = time.time() + 1000000
+        path_vec = ['command_spec', 'config_data', 'config_spec']
+        for path in path_vec:
+            self.handler.path = '/' + path
+            self.handler.do_GET()
+            self.assertEqual(self.handler.rcode, http.client.OK)    
+    
+    def test_user_logged_in(self):
+        self.handler.server.user_sessions = {}
+        self.handler.session_id = 12345
+        self.assertTrue(self.handler._is_user_logged_in() == False)
+
+        self.handler.server.user_sessions[12345] = time.time()
+        self.assertTrue(self.handler._is_user_logged_in())
+
+        self.handler.server.user_sessions[12345] = time.time() - 1500
+        self.handler.idle_timeout = 1200
+        self.assertTrue(self.handler._is_user_logged_in() == False)
+
+    def test_check_user_name_and_pwd(self):
+        self.handler.headers = {}
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['invalid username or password'])
+
+    def test_check_user_name_and_pwd_1(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        user_info = {'username':'root', 'password':'abc123'}
+        len = self.handler.rfile.write(json.dumps(user_info).encode())
+        self.handler.headers['Content-Length'] = len
+        self.handler.rfile.seek(0, 0)
+
+        self.handler.server.user_infos['root'] = ['aa', 'aaa']
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['password doesn\'t match'])
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+    def test_check_user_name_and_pwd_2(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        user_info = {'username':'root', 'password':'abc123'}
+        len = self.handler.rfile.write(json.dumps(user_info).encode())
+        self.handler.headers['Content-Length'] = len - 1
+        self.handler.rfile.seek(0, 0)
+
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['invalid username or password'])
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+    def test_check_user_name_and_pwd_3(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        user_info = {'usernae':'root', 'password':'abc123'}
+        len = self.handler.rfile.write(json.dumps(user_info).encode())
+        self.handler.headers['Content-Length'] = len
+        self.handler.rfile.seek(0, 0)
+
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['need user name'])
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+    def test_check_user_name_and_pwd_4(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        user_info = {'username':'root', 'pssword':'abc123'}
+        len = self.handler.rfile.write(json.dumps(user_info).encode())
+        self.handler.headers['Content-Length'] = len
+        self.handler.rfile.seek(0, 0)
+
+        self.handler.server.user_infos['root'] = ['aa', 'aaa']
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['need password'])
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+    def test_check_user_name_and_pwd_5(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        user_info = {'username':'root', 'password':'abc123'}
+        len = self.handler.rfile.write(json.dumps(user_info).encode())
+        self.handler.headers['Content-Length'] = len
+        self.handler.rfile.seek(0, 0)
+
+        ret, msg = self.handler._check_user_name_and_pwd()
+        self.assertTrue(ret == False)
+        self.assertEqual(msg, ['user doesn\'t exist'])
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+    def test_do_POST(self):
+        self.handler.headers = {}
+        self.handler.do_POST()
+        self.assertEqual(self.handler.rcode, http.client.BAD_REQUEST)
+
+    def test_do_POST_1(self):
+        self.handler.headers = {}
+        self.handler.headers['cookie'] = 12345
+        self.handler.path = '/'
+        self.handler.do_POST()
+        self.assertEqual(self.handler.rcode, http.client.UNAUTHORIZED)
+
+    def test_handle_post_request(self):
+        self.handler.path = '/cfgmgr/revert'
+        self.handler.headers = {}
+        rcode, reply = self.handler._handle_post_request()
+        self.assertEqual(http.client.OK, rcode)
+
+    def test_handle_post_request_1(self):
+        self.handler.path = '/*d/revert'
+        self.handler.headers = {}
+        rcode, reply = self.handler._handle_post_request()
+        self.assertEqual(http.client.BAD_REQUEST, rcode)
+
+    def test_handle_post_request_2(self):
+        self.handler.rfile = open("check.tmp", 'w+b')
+        params = {123:'param data'}
+        len = self.handler.rfile.write(json.dumps(params).encode())
+        self.handler.headers['Content-Length'] = len
+        self.handler.rfile.seek(0, 0)
+        self.handler.rfile.close()
+        os.remove('check.tmp')
+
+        self.handler.path = '/d/revert'
+        rcode, reply = self.handler._handle_post_request()
+        self.assertEqual(http.client.OK, rcode)
+
+if __name__== "__main__":
+    unittest.main()