Browse Source

From GIST

adamcik 15 years ago
parent
commit
5284d464ef
1 changed files with 40 additions and 22 deletions
  1. 40 22
      manager.py

+ 40 - 22
manager.py

@@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
 from django.db import models, connection
 from django.db.models import sql, query
 
-INET_TERMS = {
+NET_TERMS = {
     'inet_lt': '<',
     'inet_lte': '<=',
     'inet_exact': '=',
@@ -17,35 +17,30 @@ INET_TERMS = {
     'inet_contains': '>>=',
 }
 
-class IPQuery(sql.Query):
+class NetQuery(sql.Query):
     query_terms = sql.Query.query_terms.copy()
-    query_terms.update(INET_TERMS)
+    query_terms.update(NET_TERMS)
 
     def add_filter(self, (filter_string, value), *args, **kwargs):
         if isinstance(value, IP):
             value = unicode(value)
-        return super(IPQuery, self).add_filter((filter_string, value), *args, **kwargs)
+        return super(NetQuery, self).add_filter((filter_string, value), *args, **kwargs)
 
-class IPWhere(sql.where.WhereNode):
+class NetWhere(sql.where.WhereNode):
     def make_atom(self, child, qn):
         table_alias, name, db_type, lookup_type, value_annot, params = child
 
-        if lookup_type in INET_TERMS:
-            return ('%s.%s %s inet %%s' % (table_alias, name, INET_TERMS[lookup_type]), params)
+        if db_type in ['cidr', 'inet'] and lookup_type in NET_TERMS:
+            return ('%s.%s %s inet %%s' % (table_alias, name, NET_TERMS[lookup_type]), params)
 
-        return super(IPWhere, self).make_atom(child, qn)
+        return super(NetWhere, self).make_atom(child, qn)
 
-class IPManger(models.Manager):
+class NetManger(models.Manager):
     def get_query_set(self):
-        q = IPQuery(self.model, connection, IPWhere)
+        q = NetQuery(self.model, connection, NetWhere)
         return query.QuerySet(self.model, q)
 
-class IPAddressField(models.IPAddressField):
-    __metaclass__ = models.SubfieldBase
-
-    def db_type(self):
-        return 'inet'
-
+class _NetAddressField(models.Field):
     def to_python(self, value):
         if not value:
             return None
@@ -64,12 +59,35 @@ class IPAddressField(models.IPAddressField):
         if lookup_type in INET_TERMS:
             return [value]
 
-        return super(IPAddressField, self).get_db_prep_lookup(lookup_type, value)
+        return super(_NetAddressField, self).get_db_prep_lookup(lookup_type, value)
 
-class Foo(models.Model):
-    ip = IPAddressField()
+class InetAddressField(_NetAddressField):
+    description = "Postgresql inet field"
+    __metaclass__ = models.SubfieldBase
+
+    def db_type(self):
+        return 'inet'
+
+class CidrAddressField(_NetAddressField):
+    description = "Postgresql cidr field"
+    __metaclass__ = models.SubfieldBase
+
+    def db_type(self):
+        return 'cidr'
 
-    objects = IPManger()
+class MACAddressField(models.Field):
+    description = "Postgresql macaddr field"
+
+    def __init__(self, *args, **kwargs):
+        kwargs['max_length'] = 17
+        super(MACAddressField, self).__init__(*args, **kwargs)
+
+    def db_type(self):
+        return 'macaddr'
+
+class Foo(models.Model):
+    inet = InetAddressField()
+    test = CidrAddressField()
+    mac = MACAddressField()
 
-    def __unicode__(self):
-        return unicode(self.ip)
+    objects = NetManger()