Browse Source

Refactored SecretViewSet

Jeremy Stretch 8 years ago
parent
commit
07a2b136b8
1 changed files with 29 additions and 36 deletions
  1. 29 36
      netbox/secrets/api/views.py

+ 29 - 36
netbox/secrets/api/views.py

@@ -8,7 +8,7 @@ from rest_framework.authentication import BasicAuthentication, SessionAuthentica
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.renderers import JSONRenderer
 from rest_framework.response import Response
-from rest_framework.viewsets import ViewSet, ModelViewSet
+from rest_framework.viewsets import GenericViewSet, ModelViewSet, ViewSet
 
 from extras.api.renderers import FormlessBrowsableAPIRenderer, FreeRADIUSClientsRenderer
 from secrets.exceptions import InvalidSessionKey
@@ -50,35 +50,38 @@ class SecretViewSet(WritableSerializerMixin, ModelViewSet):
     filter_class = SecretFilter
     # DRF's BrowsableAPIRenderer can't support passing the secret key as a header, so we disable it.
     renderer_classes = [FormlessBrowsableAPIRenderer, JSONRenderer, FreeRADIUSClientsRenderer]
-    # Enabled BasicAuthentication for testing (until we have TokenAuthentication implemented)
-    authentication_classes = [BasicAuthentication, SessionAuthentication]
-    permission_classes = [IsAuthenticated]
 
-    def _read_session_key(self, request):
+    master_key = None
 
-        # Check for a session key provided as a cookie or header
-        if 'session_key' in request.COOKIES:
-            return base64.b64decode(request.COOKIES['session_key'])
-        elif 'HTTP_X_SESSION_KEY' in request.META:
-            return base64.b64decode(request.META['HTTP_X_SESSION_KEY'])
-        return None
+    def initial(self, request, *args, **kwargs):
 
-    def retrieve(self, request, *args, **kwargs):
+        super(SecretViewSet, self).initial(request, *args, **kwargs)
 
-        secret = self.get_object()
-        session_key = self._read_session_key(request)
+        # Read session key from HTTP cookie or header if it has been provided. The session key must be provided in order
+        # to encrypt/decrypt secrets.
+        if 'session_key' in request.COOKIES:
+            session_key = base64.b64decode(request.COOKIES['session_key'])
+        elif 'HTTP_X_SESSION_KEY' in request.META:
+            session_key = base64.b64decode(request.META['HTTP_X_SESSION_KEY'])
+        else:
+            session_key = None
 
-        # Retrieve session key cipher (if any) for the current user
+        # Attempt to retrieve the master key for encryption/decryption if a session key has been provided.
         if session_key is not None:
             try:
                 sk = SessionKey.objects.get(userkey__user=request.user)
-                master_key = sk.get_master_key(session_key)
-                secret.decrypt(master_key)
-            except SessionKey.DoesNotExist:
-                return HttpResponseBadRequest("No active session key for current user.")
-            except InvalidSessionKey:
+                self.master_key = sk.get_master_key(session_key)
+            except (SessionKey.DoesNotExist, InvalidSessionKey):
                 return HttpResponseBadRequest("Invalid session key.")
 
+    def retrieve(self, request, *args, **kwargs):
+
+        secret = self.get_object()
+
+        # Attempt to decrypt the secret if the master key is known
+        if self.master_key is not None:
+            secret.decrypt(self.master_key)
+
         serializer = self.get_serializer(secret)
         return Response(serializer.data)
 
@@ -86,29 +89,19 @@ class SecretViewSet(WritableSerializerMixin, ModelViewSet):
 
         queryset = self.filter_queryset(self.get_queryset())
 
-        # Attempt to retrieve the master key for decryption
-        session_key = self._read_session_key(request)
-        master_key = None
-        if session_key is not None:
-            try:
-                sk = SessionKey.objects.get(user=request.user)
-                master_key = sk.get_master_key(session_key)
-            except SessionKey.DoesNotExist:
-                return HttpResponseBadRequest("No active session key for current user.")
-            except InvalidSessionKey:
-                return HttpResponseBadRequest("Invalid session key.")
-
-        # Pagination
         page = self.paginate_queryset(queryset)
         if page is not None:
-            secrets = []
-            if master_key is not None:
+
+            # Attempt to decrypt all secrets if the master key is known
+            if self.master_key is not None:
+                secrets = []
                 for secret in page:
-                    secret.decrypt(master_key)
+                    secret.decrypt(self.master_key)
                     secrets.append(secret)
                 serializer = self.get_serializer(secrets, many=True)
             else:
                 serializer = self.get_serializer(page, many=True)
+
             return self.get_paginated_response(serializer.data)
 
         serializer = self.get_serializer(queryset, many=True)