Browse Source

Refine the code for b10-cmdctl, make the code robust.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@632 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 15 years ago
parent
commit
3342ceb634
2 changed files with 150 additions and 89 deletions
  1. 149 88
      src/bin/cmdctl/b10-cmdctl.py.in
  2. 1 1
      src/bin/cmdctl/passwd.csv

+ 149 - 88
src/bin/cmdctl/b10-cmdctl.py.in

@@ -1,7 +1,3 @@
-#!@PYTHON@
-
-import sys; sys.path.append ('@@PYTHONPATH@@')
-
 # Copyright (C) 2010  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
@@ -17,7 +13,19 @@ import sys; sys.path.append ('@@PYTHONPATH@@')
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+''' cmdctl module is the configuration entry point for all commands from bindctl
+or some other web tools client of bind10. cmdctl is pure https server which provi-
+des RESTful API. When command client connecting with cmdctl, it should first login 
+with legal username and password. 
+    When cmdctl starting up, it will collect command specification and 
+configuration specification/data of other available modules from configmanager, then
+wait for receiving request from client, parse the request and resend the request to
+the proper module. When getting the request result from the module, send back the 
+resut to client.
+'''
+#!@PYTHON@
 
+import sys; sys.path.append ('@@PYTHONPATH@@')
 import http.server
 import urllib.parse
 import json
@@ -29,102 +37,127 @@ import select
 import csv
 import random
 from hashlib import sha1
-
 try:
     import threading
 except ImportError:
     import dummy_threading as threading
 
 URL_PATTERN = re.compile('/([\w]+)(?:/([\w]+))?/?')
+USER_INFO_FILE = "passwd.csv"
+CERTIFICATE_FILE = 'b10-cmdctl.pem'
         
 class SecureHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
+    '''https connection request handler.
+    Currently only GET and POST are supported.
     '''
-    Process GET and POST
-    '''
+
+    def _check_username(self, name):
+        if self.server.user_infos.get(name):
+            return True
+        return False
         
-    def parse_path(self, path):
+    
+    def _check_password(self, name, password):
+        if not password:
+            return False
+        else:
+            info = self.server.user_infos.get(name)            
+            datastr = (password + info[1]).encode()
+            return sha1(datastr).hexdigest() == info[0]        
+
+    def _process_user_login(self):
+        # check username and password, if pass, record client's session id
+        rcode, reply = http.client.UNAUTHORIZED, []        
+        length = self.headers.get('content-length')
+        if not length:
+            reply = ["invalid username or password"]     
+        else:
+            user_info = json.loads((self.rfile.read(int(length))).decode())
+            name = user_info.get('username')
+            passwd = user_info.get('password')
+            if not user_info:
+                reply = ["invalid username or password"]                
+            elif not self._check_username(name):
+                reply = ["username doesn't exist"]    
+            elif not self._check_password(name, passwd):
+                reply = ["invalid password"]
+            else:
+                senid = self.headers.get('cookie')
+                if not senid:
+                    reply = ["need session id from client"]                
+                else:
+                    self.server.user_sessions.append(senid)
+                    rcode, reply = http.client.OK, ["login success "]
+    
+        return rcode, reply
+        
+    
+    def _parse_request_path(self, path):
+        '''Parse the url, the legal url should like /ldh or /ldh/ldh '''
         groups = URL_PATTERN.match(path) 
         if not groups:
-            return (NULL, NULL)
+            return (None, None)
 
         return (groups.group(1), groups.group(2))
 
-    def check_username(self,name):
-        reader = csv.reader(open("passwd.csv"), delimiter="\t", quoting=csv.QUOTE_MINIMAL)
-        for row in reader:
-            if name==row[0]:
-                self.user=[row[0],row[1],row[2]]
-                return 1
-        return 0
-
-
-    def check(self):
-        length = self.headers.get('content-length')
-        nbytes = int(length)
-        user_info = json.loads((self.rfile.read(nbytes)).decode())
-
-        if not user_info:
-            return ["error: invalid username or password"], http.client.UNAUTHORIZED
-        if not self.check_username(user_info['username']):
-            return ["error:the username doesn't exists"], http.client.UNAUTHORIZED
-
-        if sha1((user_info['password'] + self.user[2]).encode()).hexdigest() != self.user[1] :
-            return ["error:the username and passwd did not match!"], http.client.UNAUTHORIZED
-        else :
-            id = self.headers.get('cookie')
-            self.server.session[id]['username'] = user_info['username']
-            return ["login success !"], http.client.OK
-
-
     def do_GET(self):
-        id = self.headers.get('cookie')
-        if id not in self.server.session:
-            self.server.session[id]={}
-        
-        reply_value = []
-        if "username" not in self.server.session[id]:
-            reply_value = ["please post username and passwd"]
+        ''' The client should send its session id in header with 
+        the name 'cookie'
+        '''
+        rcode, reply = 200, []        
+        senid = self.headers.get('cookie')
+        if not senid:
+            rcode = http.client.BAD_REQUEST
         else:
-            identifier, module = self.parse_path(self.path)
-            if identifier != None:
-                data = self.server.get_reply_data_for_GET(identifier, module) 
-                if data:
-                    reply_value = data
-
-        self.send_response(200)
+            if senid not in self.server.user_sessions:
+                rcode, reply = http.client.UNAUTHORIZED, ["please login"]
+            else:
+                identifier, module = self._parse_request_path(self.path)   
+                rcode, reply = self.server.get_reply_data_for_GET(identifier, module) 
+                    
+        self.send_response(rcode)
         self.end_headers()
-        self.wfile.write(json.dumps(reply_value).encode())
+        self.wfile.write(json.dumps(reply).encode())
 
         
     def do_POST(self):
+        '''Process user login and send command to proper module  
+        The client should send its session id in header with 
+        the name 'cookie'
+        '''
+        rcode, reply = http.client.OK, []
         id = self.headers.get('cookie')
-        if id not in self.server.session:
-            self.server.session[id] = {}
-
-        reply_msg = []
-        rcode = 200
-        if self.path == '/login':
-            reply_msg, rcode = self.check()
-        elif "username" not in self.server.session[id]:
-            reply_msg, rcode = ["please login!"], http.client.UNAUTHORIZED
-        else:
-            mod, cmd = self.parse_path(self.path)
-            param = None
-            len = self.headers.get('Content-Length')
-            if len:
-                post_str = str(self.rfile.read(int(len)).decode())
-                print("command parameter:%s" % post_str)
-                param = json.loads(post_str)
-
-            reply_msg = self.server.send_command_to_module(mod, cmd, param)
-            print('b10-cmdctl finish send message \'%s\' to module %s' % (cmd, mod))
-        
-        #TODO, set proper rcode
+        if not id:
+            rcode = http.client.BAD_REQUEST
+        else:        
+            if self.path == '/login':
+                rcode, reply = self._process_user_login()
+            elif id not in self.server.user_sessions:
+                rcode, reply = http.client.UNAUTHORIZED, ["please login"]           
+            else:
+                mod, cmd = self._parse_request_path(self.path)
+                param = None
+                len = self.headers.get('Content-Length')
+                if len:
+                    post_str = str(self.rfile.read(int(len)).decode())
+                    print("command parameter:%s" % post_str)
+                    param = json.loads(post_str)
+    
+                # TODO, need return some proper return code. 
+                # currently always OK.
+                reply = self.server.send_command_to_module(mod, cmd, param)
+                print('b10-cmdctl finish send message \'%s\' to module %s' % (cmd, mod))            
+       
         self.send_response(rcode)
         self.end_headers()
-        self.wfile.write(json.dumps(reply_msg).encode())
+        self.wfile.write(json.dumps(reply).encode())
+
 
 class CommandControl():
+    '''Get all modules' config data/specification from configmanager.
+    receive command from client and resend it to proper module.
+    '''
+
     def __init__(self):
         self.cc = ISC.CC.Session()
         self.cc.group_subscribe('Cmd-Ctrld')
@@ -161,7 +194,7 @@ class CommandControl():
         
         return True
     
-    def send_command(self, module_name, command_name, params = None):
+    def send_command(self, module_name, command_name, params = None):       
         content = [command_name]
         if params:
             content.append(params)
@@ -170,10 +203,12 @@ class CommandControl():
         print('b10-cmdctl send command \'%s\' to %s' %(command_name, module_name))
         try:
             self.cc.group_sendmsg(msg, module_name)
+            #TODO, it may be blocked, msqg need to add a new interface
+            # wait in timeout.
             answer, env = self.cc.group_recvmsg(False)
             if answer and 'result' in answer.keys() and type(answer['result']) == list:
                 # TODO: with the new cc implementation, replace "1" by 1
-                if answer['result'][0] == "1":
+                if answer['result'][0] == 1:
                     # todo: exception
                     print("Error: " + str(answer['result'][1]))
                     return {}
@@ -196,35 +231,59 @@ class SecureHTTPServer(http.server.HTTPServer):
 
     def __init__(self, server_address, RequestHandlerClass):
         http.server.HTTPServer.__init__(self, server_address, RequestHandlerClass)
-        self.session = {}
+        self.user_sessions = []
         self.cmdctrl = CommandControl()
         self.__is_shut_down = threading.Event()
         self.__serving = False
+        self.user_infos = {}
+        self._read_user_info()
 
+    def _read_user_info(self):
+        # Get all username and password information
+        csvfile = None
+        try:
+            csvfile = open(USER_INFO_FILE)
+            reader = csv.reader(csvfile)
+            for row in reader:
+                self.user_infos[row[0]] = [row[1], row[2]]
+                
+        except Exception as e:
+            print("Fail to read user information ", e)                
+            exit(1)
+        finally:
+            if csvfile:
+                csvfile.close()
+        
+        
     def get_request(self):
         newsocket, fromaddr = self.socket.accept()
         try:
             connstream = ssl.wrap_socket(newsocket,
                                      server_side = True,
-                                     certfile = 'b10-cmdctl.pem',
-                                     keyfile = 'b10-cmdctl.pem',
+                                     certfile = CERTIFICATE_FILE,
+                                     keyfile = CERTIFICATE_FILE,
                                      ssl_version = ssl.PROTOCOL_SSLv23)
             return (connstream, fromaddr)
         except ssl.SSLError as e :
-            print("error happen***********")
-            print(e)
+            print("cmdctl: deny client's invalid connection", e)
+            self.close_request(newsocket)
+            # raise socket error to finish the request
+            raise socket.error
             
     
-    def get_reply_data_for_GET(self, id, module_name):
-        if module_name is None:
+    def get_reply_data_for_GET(self, id, module):
+        '''Currently only support the following three url GET request '''
+        rcode, reply = http.client.NO_CONTENT, []        
+        if not module:
             if id == 'command_spec':
-                return self.cmdctrl.command_spec
+                rcode, reply = http.client.OK, self.cmdctrl.command_spec
             elif id == 'config_data':
-                return self.cmdctrl.config_data
+                rcode, reply = http.client.OK, self.cmdctrl.config_data
             elif id == 'config_spec':
-                return self.cmdctrl.config_spec
-            else:
-                return None
+                rcode, reply = http.client.OK, self.cmdctrl.config_spec
+        
+        return rcode, reply     
+            
 
     def serve_forever(self, poll_interval = 0.5):
         self.__serving = True
@@ -247,7 +306,9 @@ class SecureHTTPServer(http.server.HTTPServer):
     def send_command_to_module(self, module_name, command_name, params):
         return self.cmdctrl.send_command(module_name, command_name, params)
 
+
 def run(server_class = SecureHTTPServer, addr = 'localhost', port = 8080):
+    ''' Start cmdctl as one https server. '''
     print("b10-cmdctl module is starting on :%s port:%d" %(addr, port))
     httpd = server_class((addr, port), SecureHTTPRequestHandler)
     httpd.serve_forever()

+ 1 - 1
src/bin/cmdctl/passwd.csv

@@ -1 +1 @@
-root	e0da4e422d3f42173edfee8a0fab11f9c5f4f2bb	-R%zdgw/L@E}n1WTQTw*p=3i3=C~cbsvy=s'mWEO=m)IQN]|e4-/u?rC)5cDIBLn
+root,6f0c73bd33101a5ec0294b3ca39fec90ef4717fe,"{5hV&$^(]!uV,3H>E~=f`I;,HgMl""`Eyao4^0l|Nlz|%R9Y0v)#/t'u@CzJ$U)"