Browse Source

1. Refactor the function get_param_name_by_position of class CommandInfo, add add some unittest for it.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@990 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 15 years ago
parent
commit
5379b40c81

+ 2 - 1
src/bin/bindctl/TODO

@@ -1,7 +1,8 @@
 1. Refactor the code for bindctl.
 2. Update man page for bindctl provided by jreed.
 3. Add more unit tests.
-4. Need Review:
+4. Need Review(When command line syntax is changed later, the following 
+functions should be updated first.):
         bindcmd.py:
             apply_config_cmd()
             _validate_cmd()

+ 7 - 1
src/bin/bindctl/bindcmd.py

@@ -206,6 +206,12 @@ class BindCmdInterpreter(Cmd):
         self.add_module_info(module)
 
     def _validate_cmd(self, cmd):
+        '''validate the parameters and merge some parameters together,
+        merge algorithm is based on the command line syntax, later, if
+        a better command line syntax come out, this function should be 
+        updated first.        
+        '''
+
         if not cmd.module in self.modules:
             raise CmdUnknownModuleSyntaxError(cmd.module)
         
@@ -254,7 +260,7 @@ class BindCmdInterpreter(Cmd):
                             raise CmdUnknownParamSyntaxError(cmd.module, cmd.command, cmd.params[name])
                     else:
                         # replace the numbered items by named items
-                        param_name = command_info.get_param_name_by_position(name+1, param_count)
+                        param_name = command_info.get_param_name_by_position(name, param_count)
                         cmd.params[param_name] = cmd.params[name]
                         del cmd.params[name]
                         

+ 24 - 36
src/bin/bindctl/moduleinfo.py

@@ -90,47 +90,35 @@ class CommandInfo:
                 if not self.params[name].is_optional]        
         
     def get_param_name_by_position(self, pos, param_count):
-        # count mandatories back from the last
-        # from the last mandatory; see the number of mandatories before it
-        # and compare that to the number of positional arguments left to do
-        # if the number of lefts is higher than the number of mandatories,
-        # use the first optional. Otherwise, use the first unhandled mandatory
-        # (and update the location accordingly?)
-        # (can this be done in all cases? this is certainly not the most efficient method;
-        # one way to make the whole of this more consistent is to always set mandatories first, but
-        # that would make some commands less nice to use ("config set value location" instead of "config set location value")
+        '''
+        Find a proper parameter name for the position 'pos':
+        If param_count is equal to the count of mandatory parameters of command,
+        and there is some optional parameter, find the first mandatory parameter 
+        from the position 'pos' to the end. Else, return the name on position pos.
+        (This function will be changed if bindctl command line syntax is changed
+        in the future. )
+        '''
         if type(pos) != int:
             raise KeyError(str(pos) + " is not an integer")
+
         else:
-            if param_count == len(self.params) - 1:
-                i = 0
-                for k in self.params.keys():
-                    if i == pos:
-                        return k
-                    i += 1
+            params = self.params.copy()
+            del params['help']
+            count = len(params)
+            if (pos >= count):
                 raise KeyError(str(pos) + " out of range")
-            elif param_count <= len(self.params):
-                mandatory_count = 0
-                for k in self.params.keys():
-                    if not self.params[k].is_optional:
-                        mandatory_count += 1
-                if param_count == mandatory_count:
-                    # return the first mandatory from pos
-                    i = 0
-                    for k in self.params.keys():
-                        if i >= pos and not self.params[k].is_optional:
-                            return k
-                        i += 1
-                    raise KeyError(str(pos) + " out of range")
-                else:
-                    i = 0
-                    for k in self.params.keys():
-                        if i == pos:
-                            return k
-                        i += 1
-                    raise KeyError(str(pos) + " out of range")
+
+            mandatory_count = len(self.get_mandatory_param_names())
+            param_names = list(params.keys())
+            if (param_count == mandatory_count) and (param_count < count):
+                while pos < count:
+                    if not params[param_names[pos]].is_optional:
+                        return param_names[pos]
+                    pos += 1
+                
+                raise KeyError(str(pos) + "parameters have error")
             else:
-                raise KeyError("Too many parameters")
+                return param_names[pos]
 
 
     def command_help(self):

+ 18 - 1
src/bin/bindctl/unittest/bindctl_test.py

@@ -158,7 +158,24 @@ class TestCmdSyntax(unittest.TestCase):
     def testCmdUnknownParamSyntaxError(self):
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone load zone_d='cn'")
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone reload_all zone_name = 'cn'")  
-        
+       
+class TestModuleInfo(unittest.TestCase):
+
+    def test_get_param_name_by_position(self):
+        cmd = CommandInfo('command')
+        cmd.add_param(ParamInfo('name'))
+        cmd.add_param(ParamInfo('age'))
+        cmd.add_param(ParamInfo('data', optional = True))
+        cmd.add_param(ParamInfo('sex'))
+        self.assertEqual('name', cmd.get_param_name_by_position(0, 2))
+        self.assertEqual('age', cmd.get_param_name_by_position(1, 2))
+        self.assertEqual('sex', cmd.get_param_name_by_position(2, 3))
+        self.assertEqual('data', cmd.get_param_name_by_position(2, 4))
+        self.assertEqual('data', cmd.get_param_name_by_position(2, 4))
+        
+        self.assertRaises(KeyError, cmd.get_param_name_by_position, 4, 4)
+
+
     
 class TestNameSequence(unittest.TestCase):
     """