Browse Source

Corrected tests and improved validation

Jeremy Stretch 8 years ago
parent
commit
117da337c7
2 changed files with 27 additions and 21 deletions
  1. 24 18
      netbox/extras/api/customfields.py
  2. 3 3
      netbox/extras/tests/test_customfields.py

+ 24 - 18
netbox/extras/api/customfields.py

@@ -18,25 +18,29 @@ class CustomFieldsSerializer(serializers.BaseSerializer):
 
     def to_internal_value(self, data):
 
-        parent_content_type = ContentType.objects.get_for_model(self.parent.Meta.model)
+        content_type = ContentType.objects.get_for_model(self.parent.Meta.model)
+        custom_fields = {field.name: field for field in CustomField.objects.filter(obj_type=content_type)}
 
-        for custom_field, value in data.items():
+        for field_name, value in data.items():
 
             # Validate custom field name
-            try:
-                cf = CustomField.objects.get(name=custom_field)
-            except CustomField.DoesNotExist:
-                raise ValidationError(u"Unknown custom field: {}".format(custom_field))
-
-            # Validate custom field content type
-            if parent_content_type not in cf.obj_type.all():
-                raise ValidationError(u"Invalid custom field for {} objects".format(parent_content_type))
+            if field_name not in custom_fields:
+                raise ValidationError(u"Invalid custom field for {} objects: {}".format(content_type, field_name))
 
             # Validate selected choice
+            cf = custom_fields[field_name]
             if cf.type == CF_TYPE_SELECT:
                 valid_choices = [c.pk for c in cf.choices.all()]
                 if value not in valid_choices:
-                    raise ValidationError(u"Invalid choice ({}) for field {}".format(value, custom_field))
+                    raise ValidationError(u"Invalid choice ({}) for field {}".format(value, field_name))
+
+        # Check for missing required fields
+        missing_fields = []
+        for field_name, field in custom_fields.items():
+            if field.required and field_name not in data:
+                missing_fields.append(field_name)
+        if missing_fields:
+            raise ValidationError(u"Missing required fields: {}".format(u", ".join(missing_fields)))
 
         return data
 
@@ -45,7 +49,7 @@ class CustomFieldModelSerializer(serializers.ModelSerializer):
     """
     Extends ModelSerializer to render any CustomFields and their values associated with an object.
     """
-    custom_fields = CustomFieldsSerializer()
+    custom_fields = CustomFieldsSerializer(required=False)
 
     def __init__(self, *args, **kwargs):
 
@@ -86,29 +90,31 @@ class CustomFieldModelSerializer(serializers.ModelSerializer):
 
     def create(self, validated_data):
 
-        custom_fields = validated_data.pop('custom_fields')
+        custom_fields = validated_data.pop('custom_fields', None)
 
         with transaction.atomic():
 
             instance = super(CustomFieldModelSerializer, self).create(validated_data)
 
             # Save custom fields
-            self._save_custom_fields(instance, custom_fields)
-            instance.custom_fields = custom_fields
+            if custom_fields is not None:
+                self._save_custom_fields(instance, custom_fields)
+                instance.custom_fields = custom_fields
 
         return instance
 
     def update(self, instance, validated_data):
 
-        custom_fields = validated_data.pop('custom_fields')
+        custom_fields = validated_data.pop('custom_fields', None)
 
         with transaction.atomic():
 
             instance = super(CustomFieldModelSerializer, self).update(instance, validated_data)
 
             # Save custom fields
-            self._save_custom_fields(instance, custom_fields)
-            instance.custom_fields = custom_fields
+            if custom_fields is not None:
+                self._save_custom_fields(instance, custom_fields)
+                instance.custom_fields = custom_fields
 
         return instance
 

+ 3 - 3
netbox/extras/tests/test_customfields.py

@@ -243,7 +243,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase):
             'name': 'Test Site 1',
             'slug': 'test-site-1',
             'custom_fields': {
-                'is_magic': False,
+                'is_magic': 0,
             }
         }
 
@@ -261,7 +261,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase):
             'name': 'Test Site 1',
             'slug': 'test-site-1',
             'custom_fields': {
-                'magic_date': date(2017, 4, 25),
+                'magic_date': '2017-04-25',
             }
         }
 
@@ -271,7 +271,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase):
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertEqual(response.data['custom_fields'].get('magic_date'), data['custom_fields']['magic_date'])
         cfv = self.site.custom_field_values.get(field=self.cf_date)
-        self.assertEqual(cfv.value, data['custom_fields']['magic_date'])
+        self.assertEqual(cfv.value.isoformat(), data['custom_fields']['magic_date'])
 
     def test_set_custom_field_url(self):