Browse Source

Merge branch 'master' of https://github.com/crustymonkey/django-postgresql-netfields

Thomas Adamcik 14 years ago
parent
commit
8e66bcc89d
1 changed files with 17 additions and 6 deletions
  1. 17 6
      netfields/managers.py

+ 17 - 6
netfields/managers.py

@@ -1,5 +1,6 @@
 from IPy import IP
 
+from types import StringTypes
 from django.db import models, connection
 from django.db.models import sql, query
 from django.db.models.query_utils import QueryWrapper
@@ -33,11 +34,18 @@ class NetQuery(sql.Query):
 
 
 class NetWhere(sql.where.WhereNode):
-    def make_atom(self, child, qn):
-        table_alias, name, db_type, lookup_type, value_annot, params = child
+    def make_atom(self, child, qn , conn):
+        if isinstance(child[0] , sql.where.Constraint):
+            c = child[0]
+            table_alias = c.alias
+            name = c.col
+            field = c.field
+            lookup_type , value_annot , params = child[1:]
+        else:
+            table_alias, name, db_type, lookup_type, value_annot, params = child
 
-        if db_type not in ['inet', 'cidr']:
-            return super(NetWhere, self).make_atom(child, qn)
+        if field.db_type() not in ['inet', 'cidr']:
+            return super(NetWhere, self).make_atom(child, qn , conn)
 
         if table_alias:
             field_sql = '%s.%s' % (qn(table_alias), qn(name))
@@ -45,7 +53,7 @@ class NetWhere(sql.where.WhereNode):
             field_sql = qn(name)
 
         if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
-            if db_type == 'inet':
+            if field.db_type() == 'inet':
                 field_sql  = 'HOST(%s)' % field_sql
             else:
                 field_sql  = 'TEXT(%s)' % field_sql
@@ -55,6 +63,9 @@ class NetWhere(sql.where.WhereNode):
         else:
             extra = ''
 
+        if type(params) in StringTypes:
+            params = (params,)
+
         if lookup_type in NET_OPERATORS:
             return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params)
         elif lookup_type == 'in':
@@ -75,5 +86,5 @@ class NetManager(models.Manager):
     use_for_related_fields = True
 
     def get_query_set(self):
-        q = NetQuery(self.model, connection, NetWhere)
+        q = NetQuery(self.model, NetWhere)
         return query.QuerySet(self.model, q)