Parcourir la source

Fix handling of caps etc

Thomas Adamcik il y a 15 ans
Parent
commit
8505ff2005
1 fichiers modifiés avec 53 ajouts et 28 suppressions
  1. 53 28
      manager.py

+ 53 - 28
manager.py

@@ -6,17 +6,29 @@ from django.db import models, connection
 from django.db.models import sql, query
 
 NET_TERMS = {
-    'lt': '%s < %%s',
-    'lte': '%s <= %%s',
-    'exact': '%s = %%s',
-    'gte': '%s >= %%s',
-    'gt': '%s > %%s',
-    'net_contained': '%s << %%s',
-    'net_contained_or_equal': '%s <<= %%s',
-    'net_contains': '%s >> %%s',
-    'net_contains_or_equals': '%s >>= %%s',
+    'lt': '%s < %s',
+    'lte': '%s <= %s',
+    'exact': '%s = %s',
+    'gte': '%s >= %s',
+    'gt': '%s > %s',
+    'net_contained': '%s << %s',
+    'net_contained_or_equal': '%s <<= %s',
+    'net_contains': '%s >> %s',
+    'net_contains_or_equals': '%s >>= %s',
+
+    'contains': "%s LIKE %s",
+    'startswith': "%s LIKE %s",
+    'endswith': "%s LIKE %s",
+    'regex': '%s ~* %s',
 }
 
+NET_TERMS_TEXT_LOOKUPS = set([
+    'contains',
+    'startswith',
+    'endswith',
+    'regex',
+])
+
 NET_TERMS_MAPPING = {
     'iexact': 'exact',
     'icontains': 'contains',
@@ -25,6 +37,8 @@ NET_TERMS_MAPPING = {
     'iregex': 'regex',
 }
 
+# FIXME test with .extra() and QueryWrapper
+
 class NetQuery(sql.Query):
     query_terms = sql.Query.query_terms.copy()
     query_terms.update(NET_TERMS)
@@ -39,24 +53,38 @@ class NetQuery(sql.Query):
 class NetWhere(sql.where.WhereNode):
     def make_atom(self, child, qn):
         table_alias, name, db_type, lookup_type, value_annot, params = child
-        field_sql = '%s.%s' % (qn(table_alias), qn(name))
+        if table_alias:
+            field_sql = '%s.%s' % (qn(table_alias), qn(name))
+        else:
+            field_sql = qn(name)
 
         if db_type not in ['inet', 'cidr']:
             return super(NetWhere, self).make_atom(child, qn)
-        elif lookup_type in NET_TERMS_MAPPING:
+
+        if lookup_type == 'regex':
+            lhs = 'HOST(%s)' % field_sql
+            rhs = '%s'
+        elif lookup_type in NET_TERMS_TEXT_LOOKUPS:
+            lhs = 'UPPER(HOST(%s))' % field_sql
+            rhs = 'UPPER(%s)'
+        else:
+            lhs = field_sql
+            rhs = '%s'
+
+        if lookup_type in NET_TERMS_MAPPING:
             lookup_type = NET_TERMS_MAPPING[lookup_type]
             child = (table_alias, name, db_type, lookup_type, value_annot, params)
             return self.make_atom(child, qn)
         elif lookup_type in NET_TERMS:
-            return (NET_TERMS[lookup_type] % field_sql, params)
+            return (NET_TERMS[lookup_type] % (lhs, rhs), params)
         elif lookup_type == 'in':
-            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
+            return ('%s IN (%s)' % (lhs, ', '.join([rhs] * len(params))), params)
         elif lookup_type == 'range':
-            return ('%s BETWEEN %%s and %%s' % (field_sql), params)
+            return ('%s BETWEEN %s and %s' % (lhs, rhs, rhs), params)
         elif lookup_type == 'isnull':
-            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), params)
-        else:
-            return super(NetWhere, self).make_atom(child, qn)
+            return ('%s IS %sNULL' % (lhs, (not value_annot and 'NOT ' or '')), params)
+
+        raise ValueError('Invalid lookup type "%s"' % lookup_type)
 
 class NetManger(models.Manager):
     use_for_related_fields = True
@@ -93,14 +121,11 @@ class _NetAddressField(models.Field):
         if value is None:
             return value
 
-        if lookup_type in ['year', 'month', 'day']:
-            raise ValueError('Invalid lookup type "%s"' % lookup_type)
-
         if lookup_type in NET_TERMS_MAPPING:
             return self.get_db_prep_lookup(
                 NET_TERMS_MAPPING[lookup_type], value)
 
-        if lookup_type in NET_TERMS:
+        if lookup_type in NET_TERMS and lookup_type not in NET_TERMS_TEXT_LOOKUPS:
             return [self.get_db_prep_value(value)]
 
         return super(_NetAddressField, self).get_db_prep_lookup(
@@ -161,16 +186,16 @@ class InetTestModel(models.Model):
     ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE "foo_inettestmodel"."inet" <= %s', (u'10.0.0.1',))
 
     >>> InetTestModel.objects.filter(inet__startswith='10.').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet")::text LIKE %s ', (u'10.%',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE UPPER(HOST("foo_inettestmodel"."inet")) LIKE UPPER(%s)', (u'10.%',))
 
     >>> InetTestModel.objects.filter(inet__istartswith='10.').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet")::text LIKE %s ', (u'10.%',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE UPPER(HOST("foo_inettestmodel"."inet")) LIKE UPPER(%s)', (u'10.%',))
 
     >>> InetTestModel.objects.filter(inet__endswith='.1').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet")::text LIKE %s ', (u'%.1',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE UPPER(HOST("foo_inettestmodel"."inet")) LIKE UPPER(%s)', (u'%.1',))
 
     >>> InetTestModel.objects.filter(inet__iendswith='.1').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet")::text LIKE %s ', (u'%.1',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE UPPER(HOST("foo_inettestmodel"."inet")) LIKE UPPER(%s)', (u'%.1',))
 
     >>> InetTestModel.objects.filter(inet__range=('10.0.0.1', '10.0.0.10')).query.as_sql()
     ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE "foo_inettestmodel"."inet" BETWEEN %s and %s', (u'10.0.0.1', u'10.0.0.10'))
@@ -199,13 +224,13 @@ class InetTestModel(models.Model):
     >>> InetTestModel.objects.filter(inet__search='10').query.as_sql()
     Traceback (most recent call last):
         ...
-    NotImplementedError: Full-text search is not implemented for this database backend
+    ValueError: Invalid lookup type "search"
 
     >>> InetTestModel.objects.filter(inet__regex=u'10').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet") ~ %s ', (u'10',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet") ~* %s', (u'10',))
 
     >>> InetTestModel.objects.filter(inet__iregex=u'10').query.as_sql()
-    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet") ~ %s ', (u'10',))
+    ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE HOST("foo_inettestmodel"."inet") ~* %s', (u'10',))
 
     >>> InetTestModel.objects.filter(inet__net_contains_or_equals='10.0.0.1').query.as_sql()
     ('SELECT "foo_inettestmodel"."id", "foo_inettestmodel"."inet" FROM "foo_inettestmodel" WHERE "foo_inettestmodel"."inet" >>= %s', (u'10.0.0.1',))