Browse Source

Cleaned up the API quite a bit

Jeremy Stretch 7 years ago
parent
commit
3395b51086
3 changed files with 70 additions and 71 deletions
  1. 44 49
      netbox/extras/api/views.py
  2. 22 18
      netbox/extras/reports.py
  3. 4 4
      netbox/extras/views.py

+ 44 - 49
netbox/extras/api/views.py

@@ -95,74 +95,69 @@ class ReportViewSet(ViewSet):
     exclude_from_schema = True
     lookup_value_regex = '[^/]+'  # Allow dots
 
-    def list(self, request):
+    def _retrieve_report(self, pk):
+
+        # Read the PK as "<module>.<report>"
+        if '.' not in pk:
+            raise Http404
+        module_name, report_name = pk.split('.', 1)
 
-        # Compile all reports
+        # Raise a 404 on an invalid Report module/name
+        report = get_report(module_name, report_name)
+        if report is None:
+            raise Http404
+
+        return report
+
+    def list(self, request):
+        """
+        Compile all reports and their related results (if any). Result data is deferred in the list view.
+        """
         report_list = []
+
+        # Iterate through all available Reports.
         for module_name, reports in get_reports():
-            for report_name, report_cls in reports:
-                data = {
-                    'module': module_name,
-                    'name': report_name,
-                    'description': report_cls.description,
-                    'test_methods': report_cls().test_methods,
-                    'result': None,
-                }
-                try:
-                    result = ReportResult.objects.defer('data').get(report='{}.{}'.format(module_name, report_name))
-                    data['result'] = result
-                except ReportResult.DoesNotExist:
-                    pass
-                report_list.append(data)
-
-        serializer = serializers.ReportSerializer(report_list, many=True, context={'request': request})
+            for report in reports:
+
+                # Attach the relevant ReportResult (if any) to each Report.
+                report.result = ReportResult.objects.filter(report=report.full_name).defer('data').first()
+                report_list.append(report)
+
+        serializer = serializers.ReportSerializer(report_list, many=True)
 
         return Response(serializer.data)
 
     def retrieve(self, request, pk):
+        """
+        Retrieve a single Report identified as "<module>.<report>".
+        """
 
-        # Retrieve report by <module>.<report>
-        if '.' not in pk:
-            raise Http404
-        module_name, report_name = pk.split('.', 1)
-        report_cls = get_report(module_name, report_name)
-        data = {
-            'module': module_name,
-            'name': report_name,
-            'description': report_cls.description,
-            'test_methods': report_cls().test_methods,
-            'result': None,
-        }
-
-        # Attach report result
-        try:
-            result = ReportResult.objects.get(report='{}.{}'.format(module_name, report_name))
-            data['result'] = result
-        except ReportResult.DoesNotExist:
-            pass
+        # Retrieve the Report and ReportResult, if any.
+        report = self._retrieve_report(pk)
+        report.result = ReportResult.objects.filter(report=report.full_name).first()
 
-        serializer = serializers.ReportDetailSerializer(data)
+        serializer = serializers.ReportDetailSerializer(report)
 
         return Response(serializer.data)
 
     @detail_route()
     def run(self, request, pk):
+        """
+        Run a Report and create a new ReportResult, overwriting any previous result for the Report.
+        """
 
-        # Retrieve report by <module>.<report>
-        if '.' not in pk:
-            raise Http404
-        module_name, report_name = pk.split('.', 1)
-        report_cls = get_report(module_name, report_name)
-
-        # Run the report
-        report = report_cls()
+        # Retrieve and run the Report.
+        report = self._retrieve_report(pk)
         result = report.run()
 
-        # Save the ReportResult
+        # Delete the old ReportResult (if any) and save the new one.
         ReportResult.objects.filter(report=pk).delete()
-        ReportResult(report=pk, failed=report.failed, data=result).save()
+        report.result = ReportResult(report=pk, failed=report.failed, data=result)
+        report.result.save()
 
-        return Response('Report completed.')
+        serializer = serializers.ReportDetailSerializer(report)
+
+        return Response(serializer.data)
 
 
 class RecentActivityViewSet(ReadOnlyModelViewSet):

+ 22 - 18
netbox/extras/reports.py

@@ -6,7 +6,7 @@ import pkgutil
 from django.utils import timezone
 
 from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_LEVEL_CODES, LOG_SUCCESS, LOG_WARNING
-import reports as user_reports
+import reports as custom_reports
 
 
 def is_report(obj):
@@ -23,7 +23,8 @@ def get_report(module_name, report_name):
     Return a specific report from within a module.
     """
     module = importlib.import_module('reports.{}'.format(module_name))
-    return getattr(module, report_name)
+    report = getattr(module, report_name, None)
+    return report()
 
 
 def get_reports():
@@ -31,27 +32,18 @@ def get_reports():
     Compile a list of all reports available across all modules in the reports path. Returns a list of tuples:
 
     [
-        (module_name, (
-            (report_name, report_class),
-            (report_name, report_class)
-        ),
-        (module_name, (
-            (report_name, report_class),
-            (report_name, report_class)
-        )
+        (module_name, (report_class, report_class, report_class, ...)),
+        (module_name, (report_class, report_class, report_class, ...)),
+        ...
     ]
     """
     module_list = []
 
-    # Iterate through all modules within the reports path
-    for importer, module_name, is_pkg in pkgutil.walk_packages(user_reports.__path__):
+    # Iterate through all modules within the reports path. These are the user-defined files in which reports are
+    # defined.
+    for importer, module_name, is_pkg in pkgutil.walk_packages(custom_reports.__path__):
         module = importlib.import_module('reports.{}'.format(module_name))
-        report_list = []
-
-        # Iterate through all Report classes within the module
-        for report_name, report_class in inspect.getmembers(module, is_report):
-            report_list.append((report_name, report_class))
-
+        report_list = [cls() for _, cls in inspect.getmembers(module, is_report)]
         module_list.append((module_name, report_list))
 
     return module_list
@@ -105,6 +97,18 @@ class Report(object):
             raise Exception("A report must contain at least one test method.")
         self.test_methods = test_methods
 
+    @property
+    def module(self):
+        return self.__module__.rsplit('.', 1)[1]
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    @property
+    def full_name(self):
+        return '.'.join([self.module, self.name])
+
     def _log(self, obj, message, level=LOG_DEFAULT):
         """
         Log a message from a test method. Do not call this method directly; use one of the log_* wrappers below.

+ 4 - 4
netbox/extras/views.py

@@ -58,11 +58,11 @@ class ReportListView(View):
         foo = []
         for module, report_list in reports:
             module_reports = []
-            for report_name, report_class in report_list:
+            for report in report_list:
                 module_reports.append({
-                    'name': report_name,
-                    'description': report_class.description,
-                    'results': results.get('{}.{}'.format(module, report_name), None)
+                    'name': report.name,
+                    'description': report.description,
+                    'results': results.get(report.full_name, None)
                 })
             foo.append((module, module_reports))