Parcourir la source

1. Refactor the function get_param_name_by_position of class CommandInfo, add add some unittest for it.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@990 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang il y a 15 ans
Parent
commit
5379b40c81

+ 2 - 1
src/bin/bindctl/TODO

@@ -1,7 +1,8 @@
 1. Refactor the code for bindctl.
 1. Refactor the code for bindctl.
 2. Update man page for bindctl provided by jreed.
 2. Update man page for bindctl provided by jreed.
 3. Add more unit tests.
 3. Add more unit tests.
-4. Need Review:
+4. Need Review(When command line syntax is changed later, the following 
+functions should be updated first.):
         bindcmd.py:
         bindcmd.py:
             apply_config_cmd()
             apply_config_cmd()
             _validate_cmd()
             _validate_cmd()

+ 7 - 1
src/bin/bindctl/bindcmd.py

@@ -206,6 +206,12 @@ class BindCmdInterpreter(Cmd):
         self.add_module_info(module)
         self.add_module_info(module)
 
 
     def _validate_cmd(self, cmd):
     def _validate_cmd(self, cmd):
+        '''validate the parameters and merge some parameters together,
+        merge algorithm is based on the command line syntax, later, if
+        a better command line syntax come out, this function should be 
+        updated first.        
+        '''
+
         if not cmd.module in self.modules:
         if not cmd.module in self.modules:
             raise CmdUnknownModuleSyntaxError(cmd.module)
             raise CmdUnknownModuleSyntaxError(cmd.module)
         
         
@@ -254,7 +260,7 @@ class BindCmdInterpreter(Cmd):
                             raise CmdUnknownParamSyntaxError(cmd.module, cmd.command, cmd.params[name])
                             raise CmdUnknownParamSyntaxError(cmd.module, cmd.command, cmd.params[name])
                     else:
                     else:
                         # replace the numbered items by named items
                         # replace the numbered items by named items
-                        param_name = command_info.get_param_name_by_position(name+1, param_count)
+                        param_name = command_info.get_param_name_by_position(name, param_count)
                         cmd.params[param_name] = cmd.params[name]
                         cmd.params[param_name] = cmd.params[name]
                         del cmd.params[name]
                         del cmd.params[name]
                         
                         

+ 24 - 36
src/bin/bindctl/moduleinfo.py

@@ -90,47 +90,35 @@ class CommandInfo:
                 if not self.params[name].is_optional]        
                 if not self.params[name].is_optional]        
         
         
     def get_param_name_by_position(self, pos, param_count):
     def get_param_name_by_position(self, pos, param_count):
-        # count mandatories back from the last
-        # from the last mandatory; see the number of mandatories before it
-        # and compare that to the number of positional arguments left to do
-        # if the number of lefts is higher than the number of mandatories,
-        # use the first optional. Otherwise, use the first unhandled mandatory
-        # (and update the location accordingly?)
-        # (can this be done in all cases? this is certainly not the most efficient method;
-        # one way to make the whole of this more consistent is to always set mandatories first, but
-        # that would make some commands less nice to use ("config set value location" instead of "config set location value")
+        '''
+        Find a proper parameter name for the position 'pos':
+        If param_count is equal to the count of mandatory parameters of command,
+        and there is some optional parameter, find the first mandatory parameter 
+        from the position 'pos' to the end. Else, return the name on position pos.
+        (This function will be changed if bindctl command line syntax is changed
+        in the future. )
+        '''
         if type(pos) != int:
         if type(pos) != int:
             raise KeyError(str(pos) + " is not an integer")
             raise KeyError(str(pos) + " is not an integer")
+
         else:
         else:
-            if param_count == len(self.params) - 1:
-                i = 0
-                for k in self.params.keys():
-                    if i == pos:
-                        return k
-                    i += 1
+            params = self.params.copy()
+            del params['help']
+            count = len(params)
+            if (pos >= count):
                 raise KeyError(str(pos) + " out of range")
                 raise KeyError(str(pos) + " out of range")
-            elif param_count <= len(self.params):
-                mandatory_count = 0
-                for k in self.params.keys():
-                    if not self.params[k].is_optional:
-                        mandatory_count += 1
-                if param_count == mandatory_count:
-                    # return the first mandatory from pos
-                    i = 0
-                    for k in self.params.keys():
-                        if i >= pos and not self.params[k].is_optional:
-                            return k
-                        i += 1
-                    raise KeyError(str(pos) + " out of range")
-                else:
-                    i = 0
-                    for k in self.params.keys():
-                        if i == pos:
-                            return k
-                        i += 1
-                    raise KeyError(str(pos) + " out of range")
+
+            mandatory_count = len(self.get_mandatory_param_names())
+            param_names = list(params.keys())
+            if (param_count == mandatory_count) and (param_count < count):
+                while pos < count:
+                    if not params[param_names[pos]].is_optional:
+                        return param_names[pos]
+                    pos += 1
+                
+                raise KeyError(str(pos) + "parameters have error")
             else:
             else:
-                raise KeyError("Too many parameters")
+                return param_names[pos]
 
 
 
 
     def command_help(self):
     def command_help(self):

+ 18 - 1
src/bin/bindctl/unittest/bindctl_test.py

@@ -158,7 +158,24 @@ class TestCmdSyntax(unittest.TestCase):
     def testCmdUnknownParamSyntaxError(self):
     def testCmdUnknownParamSyntaxError(self):
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone load zone_d='cn'")
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone load zone_d='cn'")
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone reload_all zone_name = 'cn'")  
         self.my_assert_raise(CmdUnknownParamSyntaxError, "zone reload_all zone_name = 'cn'")  
-        
+       
+class TestModuleInfo(unittest.TestCase):
+
+    def test_get_param_name_by_position(self):
+        cmd = CommandInfo('command')
+        cmd.add_param(ParamInfo('name'))
+        cmd.add_param(ParamInfo('age'))
+        cmd.add_param(ParamInfo('data', optional = True))
+        cmd.add_param(ParamInfo('sex'))
+        self.assertEqual('name', cmd.get_param_name_by_position(0, 2))
+        self.assertEqual('age', cmd.get_param_name_by_position(1, 2))
+        self.assertEqual('sex', cmd.get_param_name_by_position(2, 3))
+        self.assertEqual('data', cmd.get_param_name_by_position(2, 4))
+        self.assertEqual('data', cmd.get_param_name_by_position(2, 4))
+        
+        self.assertRaises(KeyError, cmd.get_param_name_by_position, 4, 4)
+
+
     
     
 class TestNameSequence(unittest.TestCase):
 class TestNameSequence(unittest.TestCase):
     """
     """