Browse Source

Merge pull request #10 from ekohl/master

Update code with better support for newer Django.

This merge also has lots of cleanups and improvements to the tests which now uses tox. Thanks to Ewoud Kohl van Wijngaarden (@ekohl) for giving this project some much needed love :-)
Thomas Adamcik 12 years ago
parent
commit
88769096e4
16 changed files with 173 additions and 68 deletions
  1. 5 0
      .gitignore
  2. 1 0
      MANIFEST.in
  3. 10 0
      manage.py
  4. 1 1
      netfields/__init__.py
  5. 16 18
      netfields/fields.py
  6. 11 11
      netfields/forms.py
  7. 12 15
      netfields/managers.py
  8. 12 9
      tests/models.py
  9. 8 6
      tests/tests.py
  10. 3 0
      requirements.txt
  11. 42 0
      setup.py
  12. 0 0
      tests/__init__.py
  13. 0 8
      tests/settings.py
  14. 0 0
      tests/urls.py
  15. 12 0
      testsettings.py
  16. 40 0
      tox.ini

+ 5 - 0
.gitignore

@@ -1 +1,6 @@
 *.pyc
 *.pyc
+.tox/
+MANIFEST
+build/
+dist/
+*.egg-info/

+ 1 - 0
MANIFEST.in

@@ -0,0 +1 @@
+include README.rst

+ 10 - 0
manage.py

@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+import os
+import sys
+
+if __name__ == "__main__":
+    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "testsettings")
+
+    from django.core.management import execute_from_command_line
+
+    execute_from_command_line(sys.argv)

+ 1 - 1
netfields/__init__.py

@@ -1,3 +1,3 @@
 from netfields.managers import NetManager
 from netfields.managers import NetManager
 from netfields.fields import (InetAddressField, CidrAddressField,
 from netfields.fields import (InetAddressField, CidrAddressField,
-        MACAddressField)
+                              MACAddressField)

+ 16 - 18
netfields/fields.py

@@ -5,6 +5,7 @@ from django.db import models
 from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS
 from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS
 from netfields.forms import NetAddressFormField, MACAddressFormField
 from netfields.forms import NetAddressFormField, MACAddressFormField
 
 
+
 class _NetAddressField(models.Field):
 class _NetAddressField(models.Field):
     empty_strings_allowed = False
     empty_strings_allowed = False
 
 
@@ -14,41 +15,38 @@ class _NetAddressField(models.Field):
 
 
     def to_python(self, value):
     def to_python(self, value):
         if not value:
         if not value:
-            value = None
-
-        if value is None:
             return value
             return value
 
 
         return IP(value)
         return IP(value)
 
 
     def get_prep_lookup(self, lookup_type, value):
     def get_prep_lookup(self, lookup_type, value):
-        if value is None:
+        if not value:
-            return value
+            return None
 
 
         if (lookup_type in NET_OPERATORS and
         if (lookup_type in NET_OPERATORS and
                 NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
                 NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
-            return self.get_db_prep_value(value)
+            return self.get_prep_value(value)
 
 
         return super(_NetAddressField, self).get_prep_lookup(
         return super(_NetAddressField, self).get_prep_lookup(
             lookup_type, value)
             lookup_type, value)
 
 
-
+    def get_prep_value(self, value):
-    def get_db_prep_value(self, value):
+        if not value:
-        if value is None:
+            return None
-            return value
 
 
         return unicode(self.to_python(value))
         return unicode(self.to_python(value))
 
 
-    def get_db_prep_lookup(self, lookup_type, value):
+    def get_db_prep_lookup(self, lookup_type, value, connection,
-        if value is None:
+            prepared=False):
-            return value
+        if not value:
+            return []
 
 
         if (lookup_type in NET_OPERATORS and
         if (lookup_type in NET_OPERATORS and
                 NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
                 NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
-            return [self.get_db_prep_value(value)]
+            return [value] if prepared else [self.get_prep_value(value)]
 
 
         return super(_NetAddressField, self).get_db_prep_lookup(
         return super(_NetAddressField, self).get_db_prep_lookup(
-            lookup_type, value)
+            lookup_type, value, connection=connection, prepared=prepared)
 
 
     def formfield(self, **kwargs):
     def formfield(self, **kwargs):
         defaults = {'form_class': NetAddressFormField}
         defaults = {'form_class': NetAddressFormField}
@@ -61,7 +59,7 @@ class InetAddressField(_NetAddressField):
     max_length = 39
     max_length = 39
     __metaclass__ = models.SubfieldBase
     __metaclass__ = models.SubfieldBase
 
 
-    def db_type(self):
+    def db_type(self, connection):
         return 'inet'
         return 'inet'
 
 
 
 
@@ -70,7 +68,7 @@ class CidrAddressField(_NetAddressField):
     max_length = 43
     max_length = 43
     __metaclass__ = models.SubfieldBase
     __metaclass__ = models.SubfieldBase
 
 
-    def db_type(self):
+    def db_type(self, connection):
         return 'cidr'
         return 'cidr'
 
 
 
 
@@ -81,7 +79,7 @@ class MACAddressField(models.Field):
         kwargs['max_length'] = 17
         kwargs['max_length'] = 17
         super(MACAddressField, self).__init__(*args, **kwargs)
         super(MACAddressField, self).__init__(*args, **kwargs)
 
 
-    def db_type(self):
+    def db_type(self, connection):
         return 'macaddr'
         return 'macaddr'
 
 
     def formfield(self, **kwargs):
     def formfield(self, **kwargs):

+ 11 - 11
netfields/forms.py

@@ -5,12 +5,14 @@ from django import forms
 from django.utils.encoding import force_unicode
 from django.utils.encoding import force_unicode
 from django.utils.safestring import mark_safe
 from django.utils.safestring import mark_safe
 
 
+
 class NetInput(forms.Widget):
 class NetInput(forms.Widget):
     input_type = 'text'
     input_type = 'text'
 
 
     def render(self, name, value, attrs=None):
     def render(self, name, value, attrs=None):
         # Default forms.Widget compares value != '' which breaks IP...
         # Default forms.Widget compares value != '' which breaks IP...
-        if value is None: value = ''
+        if value is None:
+            value = ''
         final_attrs = self.build_attrs(attrs, type=self.input_type, name=name)
         final_attrs = self.build_attrs(attrs, type=self.input_type, name=name)
         if value:
         if value:
             final_attrs['value'] = force_unicode(value)
             final_attrs['value'] = force_unicode(value)
@@ -26,20 +28,18 @@ class NetAddressFormField(forms.Field):
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super(NetAddressFormField, self).__init__(*args, **kwargs)
         super(NetAddressFormField, self).__init__(*args, **kwargs)
 
 
-    def clean(self, value):
+    def to_python(self, value):
-        super(NetAddressFormField, self).clean(value)
+        if not value:
-
-        if value in (None, ''):
             return None
             return None
+
         if isinstance(value, IP):
         if isinstance(value, IP):
             return value
             return value
-        try:
-            return IP(value)
-        except ValueError, e:
-            raise forms.ValidationError(e)
 
 
+        return self.python_type(value)
+
+
+MAC_RE = re.compile(r'^(([A-F0-9]{2}:){5}[A-F0-9]{2})$')
 
 
-mac_re = re.compile(r'^(([A-F0-9]{2}:){5}[A-F0-9]{2})$')
 
 
 class MACAddressFormField(forms.RegexField):
 class MACAddressFormField(forms.RegexField):
     default_error_messages = {
     default_error_messages = {
@@ -47,4 +47,4 @@ class MACAddressFormField(forms.RegexField):
     }
     }
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        super(MACAddressFormField, self).__init__(mac_re, *args, **kwargs)
+        super(MACAddressFormField, self).__init__(MAC_RE, *args, **kwargs)

+ 12 - 15
netfields/managers.py

@@ -25,13 +25,6 @@ class NetQuery(sql.Query):
     query_terms = sql.Query.query_terms.copy()
     query_terms = sql.Query.query_terms.copy()
     query_terms.update(NET_OPERATORS)
     query_terms.update(NET_OPERATORS)
 
 
-    def add_filter(self, (filter_string, value), *args, **kwargs):
-        # IP(...) == '' fails so make sure to force to string while we can
-        if isinstance(value, IP):
-            value = unicode(value)
-        return super(NetQuery, self).add_filter(
-            (filter_string, value), *args, **kwargs)
-
 
 
 class NetWhere(sql.where.WhereNode):
 class NetWhere(sql.where.WhereNode):
     def make_atom(self, child, qn, conn):
     def make_atom(self, child, qn, conn):
@@ -39,9 +32,10 @@ class NetWhere(sql.where.WhereNode):
 
 
         if hasattr(lvalue, 'process'):
         if hasattr(lvalue, 'process'):
             try:
             try:
-                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
+                lvalue, params = lvalue.process(lookup_type, params_or_value,
-            except EmptyShortCircuit:
+                                                connection)
-                raise EmptyResultSet
+            except sql.where.EmptyShortCircuit:
+                raise query.EmptyResultSet
         else:
         else:
             return super(NetWhere, self).make_atom(child, qn, conn)
             return super(NetWhere, self).make_atom(child, qn, conn)
 
 
@@ -57,9 +51,9 @@ class NetWhere(sql.where.WhereNode):
 
 
         if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
         if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
             if db_type == 'inet':
             if db_type == 'inet':
-                field_sql  = 'HOST(%s)' % field_sql
+                field_sql = 'HOST(%s)' % field_sql
             else:
             else:
-                field_sql  = 'TEXT(%s)' % field_sql
+                field_sql = 'TEXT(%s)' % field_sql
 
 
         if isinstance(params, QueryWrapper):
         if isinstance(params, QueryWrapper):
             extra, params = params.data
             extra, params = params.data
@@ -70,17 +64,20 @@ class NetWhere(sql.where.WhereNode):
             params = (params,)
             params = (params,)
 
 
         if lookup_type in NET_OPERATORS:
         if lookup_type in NET_OPERATORS:
-            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params)
+            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
+                    params)
         elif lookup_type == 'in':
         elif lookup_type == 'in':
             if not value_annot:
             if not value_annot:
                 raise sql.datastructures.EmptyResultSet
                 raise sql.datastructures.EmptyResultSet
             if extra:
             if extra:
                 return ('%s IN %s' % (field_sql, extra), params)
                 return ('%s IN %s' % (field_sql, extra), params)
-            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
+            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
+                    len(params))), params)
         elif lookup_type == 'range':
         elif lookup_type == 'range':
             return ('%s BETWEEN %%s and %%s' % field_sql, params)
             return ('%s BETWEEN %%s and %%s' % field_sql, params)
         elif lookup_type == 'isnull':
         elif lookup_type == 'isnull':
-            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), params)
+            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
+                    '')), params)
 
 
         raise ValueError('Invalid lookup type "%s"' % lookup_type)
         raise ValueError('Invalid lookup type "%s"' % lookup_type)
 
 

+ 12 - 9
tests/models.py

@@ -1,41 +1,44 @@
-from IPy import IP
+from django.db.models import Model
 
 
-from django.db import models, connection, DEFAULT_DB_ALIAS
+from netfields import InetAddressField, CidrAddressField, MACAddressField, \
+        NetManager
 
 
-from netfields import *
 
 
-class InetTestModel(models.Model):
+class InetTestModel(Model):
     field = InetAddressField()
     field = InetAddressField()
     objects = NetManager()
     objects = NetManager()
 
 
     class Meta:
     class Meta:
         db_table = 'inet'
         db_table = 'inet'
 
 
-class NullInetTestModel(models.Model):
+
+class NullInetTestModel(Model):
     field = InetAddressField(null=True)
     field = InetAddressField(null=True)
     objects = NetManager()
     objects = NetManager()
 
 
     class Meta:
     class Meta:
         db_table = 'nullinet'
         db_table = 'nullinet'
 
 
-class CidrTestModel(models.Model):
+
+class CidrTestModel(Model):
     field = CidrAddressField()
     field = CidrAddressField()
     objects = NetManager()
     objects = NetManager()
 
 
     class Meta:
     class Meta:
         db_table = 'cidr'
         db_table = 'cidr'
 
 
-class NullCidrTestModel(models.Model):
+
+class NullCidrTestModel(Model):
     field = CidrAddressField(null=True)
     field = CidrAddressField(null=True)
     objects = NetManager()
     objects = NetManager()
 
 
     class Meta:
     class Meta:
         db_table = 'nullcidr'
         db_table = 'nullcidr'
 
 
-class MACTestModel(models.Model):
+
+class MACTestModel(Model):
     mac = MACAddressField(null=True)
     mac = MACAddressField(null=True)
     objects = NetManager()
     objects = NetManager()
 
 
     class Meta:
     class Meta:
         db_table = 'mac'
         db_table = 'mac'
-

+ 8 - 6
tests/tests.py

@@ -1,9 +1,11 @@
-import unittest
 from IPy import IP
 from IPy import IP
 
 
 from django.db import IntegrityError
 from django.db import IntegrityError
+from django.test import TestCase
+
+from netfields.models import (CidrTestModel, InetTestModel, NullCidrTestModel,
+                              NullInetTestModel)
 
 
-from models import *
 
 
 class BaseTestCase(object):
 class BaseTestCase(object):
     select = 'SELECT "table"."id", "table"."field" FROM "table" '
     select = 'SELECT "table"."id", "table"."field" FROM "table" '
@@ -152,7 +154,7 @@ class BaseCidrFieldTestCase(BaseTestCase):
             self.select + 'WHERE TEXT("table"."field") ~* %s ')
             self.select + 'WHERE TEXT("table"."field") ~* %s ')
 
 
 
 
-class TestInetField(BaseInetFieldTestCase, unittest.TestCase):
+class TestInetField(BaseInetFieldTestCase, TestCase):
     def setUp(self):
     def setUp(self):
         self.model = InetTestModel
         self.model = InetTestModel
         self.qs = self.model.objects.all()
         self.qs = self.model.objects.all()
@@ -168,7 +170,7 @@ class TestInetField(BaseInetFieldTestCase, unittest.TestCase):
         self.assertRaises(IntegrityError, self.model().save)
         self.assertRaises(IntegrityError, self.model().save)
 
 
 
 
-class TestInetFieldNullable(BaseInetFieldTestCase, unittest.TestCase):
+class TestInetFieldNullable(BaseInetFieldTestCase, TestCase):
     def setUp(self):
     def setUp(self):
         self.model = NullInetTestModel
         self.model = NullInetTestModel
         self.qs = self.model.objects.all()
         self.qs = self.model.objects.all()
@@ -184,7 +186,7 @@ class TestInetFieldNullable(BaseInetFieldTestCase, unittest.TestCase):
         self.model().save()
         self.model().save()
 
 
 
 
-class TestCidrField(BaseCidrFieldTestCase, unittest.TestCase):
+class TestCidrField(BaseCidrFieldTestCase, TestCase):
     def setUp(self):
     def setUp(self):
         self.model = CidrTestModel
         self.model = CidrTestModel
         self.qs = self.model.objects.all()
         self.qs = self.model.objects.all()
@@ -200,7 +202,7 @@ class TestCidrField(BaseCidrFieldTestCase, unittest.TestCase):
         self.assertRaises(IntegrityError, self.model().save)
         self.assertRaises(IntegrityError, self.model().save)
 
 
 
 
-class TestCidrFieldNullable(BaseCidrFieldTestCase, unittest.TestCase):
+class TestCidrFieldNullable(BaseCidrFieldTestCase, TestCase):
     def setUp(self):
     def setUp(self):
         self.model = NullCidrTestModel
         self.model = NullCidrTestModel
         self.qs = self.model.objects.all()
         self.qs = self.model.objects.all()

+ 3 - 0
requirements.txt

@@ -0,0 +1,3 @@
+IPy
+django>=1.3
+psycopg2

+ 42 - 0
setup.py

@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from distutils.core import setup
+from setuptools import find_packages
+
+import os
+
+def get_long_description():
+    path = os.path.join(os.path.dirname(__file__), 'README.rst')
+    with open(path) as f:
+        return f.read()
+
+setup(
+    name='django-postgresql-netfields',
+    version='0.1',
+    license='BSD',
+    description='Django PostgreSQL netfields implementation',
+    long_description=get_long_description(),
+    url='https://github.com/adamcik/django-postgresql-netfields',
+
+    author=u'Thomas Admacik',
+    author_email='adamcik@samfundet.no',
+
+    packages=find_packages(),
+    include_package_data=True,
+    zip_safe=False,
+    install_requires=[
+        'IPy',
+        'django>=1.3',
+    ],
+
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Environment :: Web Environment',
+        'Framework :: Django',
+        'Intended Audience :: Developers',
+        'License :: OSI Approved :: BSD License',
+        'Operating System :: OS Independent',
+        'Programming Language :: Python',
+        'Topic :: Utilities',
+    ],
+)

+ 0 - 0
tests/__init__.py


+ 0 - 8
tests/settings.py

@@ -1,8 +0,0 @@
-DATABASES = {
-    'default': {
-        'ENGINE': 'django.db.backends.sqlite3',
-        'NAME': 'netfields',
-    }
-}
-
-INSTALLED_APPS = ('tests',)

+ 0 - 0
tests/urls.py


+ 12 - 0
testsettings.py

@@ -0,0 +1,12 @@
+DATABASES = {
+    'default': {
+        'ENGINE': 'django.db.backends.postgresql_psycopg2',
+        'NAME': 'netfields',
+    }
+}
+
+INSTALLED_APPS = (
+    'netfields',
+)
+
+SECRET_KEY = "notimportant"

+ 40 - 0
tox.ini

@@ -0,0 +1,40 @@
+[tox]
+envlist=
+    py26-django13,
+    py27-django13,
+    py26-django14,
+    py27-django14,
+
+[testenv]
+commands=
+    python manage.py test
+
+# Build configurations...
+
+[testenv:py26-django13]
+basepython=python2.6
+deps=
+    IPy
+    django==1.3
+    psycopg2==2.4.1
+
+[testenv:py27-django13]
+basepython=python2.7
+deps=
+    IPy
+    django==1.3
+    psycopg2==2.4.1
+
+[testenv:py26-django14]
+basepython=python2.6
+deps=
+    IPy
+    django==1.4
+    psycopg2
+
+[testenv:py27-django14]
+basepython=python2.7
+deps=
+    IPy
+    django==1.4
+    psycopg2