Browse Source

Merge pull request #12 from digitalocean/rpc-refactor

RPC Refactor
jeremystretch 8 years ago
parent
commit
74add9f519
3 changed files with 89 additions and 75 deletions
  1. 2 2
      netbox/extras/management/commands/run_inventory.py
  2. 87 72
      netbox/extras/rpc.py
  3. 0 1
      requirements.txt

+ 2 - 2
netbox/extras/management/commands/run_inventory.py

@@ -1,6 +1,6 @@
-from Exscript.protocols.Exception import LoginFailure
 from getpass import getpass
 from ncclient.transport.errors import AuthenticationError
+from paramiko import AuthenticationException
 
 from django.conf import settings
 from django.core.management.base import BaseCommand, CommandError
@@ -96,7 +96,7 @@ class Command(BaseCommand):
                     inventory = rpc_client.get_inventory()
             except KeyboardInterrupt:
                 raise
-            except (AuthenticationError, LoginFailure):
+            except (AuthenticationError, AuthenticationException):
                 self.stdout.write("Authentication error!")
                 continue
             except Exception as e:

+ 87 - 72
netbox/extras/rpc.py

@@ -1,9 +1,8 @@
-from Exscript import Account
-from Exscript.protocols import SSH2
 from ncclient import manager
 import paramiko
 import re
 import xmltodict
+import time
 
 
 CONNECT_TIMEOUT = 5  # seconds
@@ -54,6 +53,56 @@ class RPCClient(object):
         raise NotImplementedError("Feature not implemented for this platform.")
 
 
+class SSHClient(RPCClient):
+    def __enter__(self):
+
+        self.ssh = paramiko.SSHClient()
+        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        try:
+            self.ssh.connect(
+                self.host,
+                username=self.username,
+                password=self.password,
+                timeout=CONNECT_TIMEOUT,
+                allow_agent=False,
+                look_for_keys=False,
+            )
+        except paramiko.AuthenticationException:
+            # Try default credentials if the configured creds don't work
+            try:
+                default_creds = self.default_credentials
+                if default_creds.get('username') and default_creds.get('password'):
+                    self.ssh.connect(
+                        self.host,
+                        username=default_creds['username'],
+                        password=default_creds['password'],
+                        timeout=CONNECT_TIMEOUT,
+                        allow_agent=False,
+                        look_for_keys=False,
+                    )
+                else:
+                    raise ValueError('default_credentials are incomplete.')
+            except AttributeError:
+                raise paramiko.AuthenticationException
+
+        self.session = self.ssh.invoke_shell()
+        self.session.recv(1000)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.ssh.close()
+
+    def _send(self, cmd, pause=1):
+        self.session.send('{}\n'.format(cmd))
+        data = ''
+        time.sleep(pause)
+        while self.session.recv_ready():
+            data += self.session.recv(4096).decode()
+            if not data:
+                break
+        return data
+
+
 class JunosNC(RPCClient):
     """
     NETCONF client for Juniper Junos devices
@@ -130,95 +179,61 @@ class JunosNC(RPCClient):
         return result
 
 
-class IOSSSH(RPCClient):
+class IOSSSH(SSHClient):
     """
     SSH client for Cisco IOS devices
     """
 
-    def __enter__(self):
-
-        # Initiate a connection to the device
-        self.ssh = SSH2(connect_timeout=CONNECT_TIMEOUT)
-        self.ssh.connect(self.host)
-        self.ssh.authenticate(Account(self.username, self.password))
-
-        # Disable terminal paging
-        self.ssh.execute("terminal length 0")
-
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-
-        # Close the connection to the device
-        self.ssh.send("exit\r")
-        self.ssh.close()
-
     def get_inventory(self):
+        def version():
+
+            def parse(cmd_out, rex):
+                for i in cmd_out:
+                    match = re.search(rex, i)
+                    if match:
+                        return match.groups()[0]
+
+            sh_ver = self._send('show version').split('\r\n')
+            return {
+                'serial': parse(sh_ver, 'Processor board ID ([^\s]+)'),
+                'description': parse(sh_ver, 'cisco ([^\s]+)')
+            }
 
-        result = dict()
-
-        # Gather chassis data
-        try:
-            self.ssh.execute("show version")
-            show_version = self.ssh.response
-            serial = re.search("Processor board ID ([^\s]+)", show_version).groups()[0]
-            description = re.search("\r\n\r\ncisco ([^\s]+)", show_version).groups()[0]
-        except:
-            raise RuntimeError("Failed to glean chassis info from device.")
-        result['chassis'] = {
-            'serial': serial,
-            'description': description,
-        }
-
-        # Gather modules
-        result['modules'] = []
-        try:
-            self.ssh.execute("show inventory")
-            show_inventory = self.ssh.response
-            # Split modules on double line
-            modules_raw = show_inventory.strip().split('\r\n\r\n')
-            for module_raw in modules_raw:
+        def modules(chassis_serial=None):
+            cmd = self._send('show inventory').split('\r\n\r\n')
+            for i in cmd:
+                i_fmt = i.replace('\r\n', ' ')
                 try:
-                    m_name = re.search('NAME: "([^"]+)"', module_raw).group(1)
-                    m_pid = re.search('PID: ([^\s]+)', module_raw).group(1)
-                    m_serial = re.search('SN: ([^\s]+)', module_raw).group(1)
+                    m_name = re.search('NAME: "([^"]+)"', i_fmt).group(1)
+                    m_pid = re.search('PID: ([^\s]+)', i_fmt).group(1)
+                    m_serial = re.search('SN: ([^\s]+)', i_fmt).group(1)
                     # Omit built-in modules and those with no PID
-                    if m_serial != result['chassis']['serial'] and m_pid.lower() != 'unspecified':
-                        result['modules'].append({
+                    if m_serial != chassis_serial and m_pid.lower() != 'unspecified':
+                        yield {
                             'name': m_name,
                             'part_id': m_pid,
                             'serial': m_serial,
-                        })
+                        }
                 except AttributeError:
                     continue
-        except:
-            raise RuntimeError("Failed to glean module info from device.")
 
-        return result
+        self._send('term length 0')
+        sh_version = version()
+
+        return {
+            'chassis': sh_version,
+            'modules': list(modules(chassis_serial=sh_version.get('serial')))
+        }
 
 
-class OpengearSSH(RPCClient):
+class OpengearSSH(SSHClient):
     """
     SSH client for Opengear devices
     """
-
-    def __enter__(self):
-
-        # Initiate a connection to the device
-        self.ssh = paramiko.SSHClient()
-        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        try:
-            self.ssh.connect(self.host, username=self.username, password=self.password, timeout=CONNECT_TIMEOUT)
-        except paramiko.AuthenticationException:
-            # Try default Opengear credentials if the configured creds don't work
-            self.ssh.connect(self.host, username='root', password='default')
-
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-
-        # Close the connection to the device
-        self.ssh.close()
+    default_credentials = {
+        'username': 'root',
+        'password': 'default',
+    }
 
     def get_inventory(self):
 

+ 0 - 1
requirements.txt

@@ -5,7 +5,6 @@ django-filter==0.13.0
 django-rest-swagger==0.3.7
 django-tables2==1.2.1
 djangorestframework==3.3.3
-Exscript==2.1.503
 Markdown==2.6.6
 ncclient==0.4.7
 netaddr==0.7.18