from netaddr import IPNetwork

from django import VERSION
from django.db import models, connection
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
from django.db.models.fields import DateTimeField
from django.db.models import sql, query
from django.db.models.query_utils import QueryWrapper
from django.utils import tree

import datetime

NET_OPERATORS = DatabaseWrapper.operators.copy()

for operator in ['contains', 'startswith', 'endswith']:
    NET_OPERATORS[operator] = 'ILIKE %s'
    NET_OPERATORS['i%s' % operator] = 'ILIKE %s'

NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
NET_OPERATORS['net_contained'] = '<< %s'
NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
NET_OPERATORS['net_contains'] = '>> %s'
NET_OPERATORS['net_contains_or_equals'] = '>>= %s'

NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']


class NetQuery(sql.Query):
    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_OPERATORS)


class NetWhere(sql.where.WhereNode):


    def _prepare_data(self, data):
        """
        Special form of WhereNode._prepare_data() that does not automatically consume the
        __iter__ method of IPNetwork objects.  This is used in Django >= 1.6
        """

        if not isinstance(data, (list, tuple)):
            return data
        obj, lookup_type, value = data
        if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
            # Consume any generators immediately, so that we can determine
            # emptiness and transform any non-empty values correctly.
            value = list(value)


        # The "value_annotation" parameter is used to pass auxilliary information
        # about the value(s) to the query construction. Specifically, datetime
        # and empty values need special handling. Other types could be used
        # here in the future (using Python types is suggested for consistency).
        if (isinstance(value, datetime.datetime)
            or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
            value_annotation = datetime.datetime
        elif hasattr(value, 'value_annotation'):
            value_annotation = value.value_annotation
        else:
            value_annotation = bool(value)

        if hasattr(obj, "prepare"):
            value = obj.prepare(lookup_type, value)
        return (obj, lookup_type, value_annotation, value)


    if VERSION[:2] < (1, 6):
        def add(self, data, connector):
            """
            Special form of WhereNode.add() that does not automatically consume the
            __iter__ method of IPNetwork objects.
            """
            if not isinstance(data, (list, tuple)):
                # Need to bypass WhereNode
                tree.Node.add(self, data, connector)
                return

            obj, lookup_type, value = data
            if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
                # Consume any generators immediately, so that we can determine
                # emptiness and transform any non-empty values correctly.
                value = list(value)

            # The "value_annotation" parameter is used to pass auxilliary information
            # about the value(s) to the query construction. Specifically, datetime
            # and empty values need special handling. Other types could be used
            # here in the future (using Python types is suggested for consistency).
            if isinstance(value, datetime.datetime):
                value_annotation = datetime.datetime
            elif hasattr(value, 'value_annotation'):
                value_annotation = value.value_annotation
            else:
                value_annotation = bool(value)

            if hasattr(obj, "prepare"):
                value = obj.prepare(lookup_type, value)

            # Need to bypass WhereNode
            tree.Node.add(self,
                (obj, lookup_type, value_annotation, value), connector)

    def make_atom(self, child, qn, conn):
        lvalue, lookup_type, value_annot, params_or_value = child

        if hasattr(lvalue, 'process'):
            try:
                lvalue, params = lvalue.process(lookup_type, params_or_value,
                                                connection)
            except sql.where.EmptyShortCircuit:
                raise query.EmptyResultSet
        else:
            return super(NetWhere, self).make_atom(child, qn, conn)

        table_alias, name, db_type = lvalue

        if db_type not in ['inet', 'cidr']:
            return super(NetWhere, self).make_atom(child, qn, conn)

        if table_alias:
            field_sql = '%s.%s' % (qn(table_alias), qn(name))
        else:
            field_sql = qn(name)

        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
            if db_type == 'inet':
                field_sql = 'HOST(%s)' % field_sql
            else:
                field_sql = 'TEXT(%s)' % field_sql

        if isinstance(params, QueryWrapper):
            extra, params = params.data
        else:
            extra = ''

        if isinstance(params, basestring):
            params = (params,)

        if lookup_type in NET_OPERATORS:
            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
                    params)
        elif lookup_type == 'in':
            if not value_annot:
                raise sql.datastructures.EmptyResultSet
            if extra:
                return ('%s IN %s' % (field_sql, extra), params)
            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
                    len(params))), params)
        elif lookup_type == 'range':
            return ('%s BETWEEN %%s and %%s' % field_sql, params)
        elif lookup_type == 'isnull':
            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
                    '')), params)

        raise ValueError('Invalid lookup type "%s"' % lookup_type)


class NetManager(models.Manager):
    use_for_related_fields = True

    def get_query_set(self):
        q = NetQuery(self.model, NetWhere)
        return query.QuerySet(self.model, q)