from netaddr import IPNetwork

from django import VERSION
from django.db import models, connection
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
from django.db.models.fields import DateTimeField, Field
from django.db.models import sql, query, Lookup
from django.db.models.query_utils import QueryWrapper
from django.utils import tree

import datetime

# Lookup-name -> SQL operator template map for network fields, derived from
# the stock PostgreSQL backend operators.  Text-style lookups are forced to be
# case-insensitive: contains/startswith/endswith and their ``i`` variants all
# become ILIKE, ``iexact`` reuses the ``exact`` template and ``regex`` reuses
# the ``iregex`` template.
NET_OPERATORS = DatabaseWrapper.operators.copy()

for operator in ('contains', 'startswith', 'endswith'):
    # Both the plain and the ``i``-prefixed spelling map to the same ILIKE.
    NET_OPERATORS.update(dict.fromkeys((operator, 'i' + operator), 'ILIKE %s'))

NET_OPERATORS.update(
    iexact=NET_OPERATORS['exact'],
    regex=NET_OPERATORS['iregex'],
)

# Operator templates that perform textual matching.
# NOTE(review): presumably consumed elsewhere in the package to detect
# text-style comparisons; not referenced in this chunk — confirm.
NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']


class NetContainsOrEqual(Lookup):
    """``<field>__net_contains_or_equals=<value>`` lookup.

    Matches rows whose left-hand network contains or equals the right-hand
    value, compiled to PostgreSQL's ``>>=`` network operator.
    """
    lookup_name = 'net_contains_or_equals'

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` for ``lhs >>= rhs``."""
        lhs_sql, lhs_args = self.process_lhs(qn, connection)
        rhs_sql, rhs_args = self.process_rhs(qn, connection)
        sql_text = '%s >>= %s' % (lhs_sql, rhs_sql)
        return sql_text, lhs_args + rhs_args

# Registered on the base Field class, so the lookup is available on every
# model field, not only network fields.
Field.register_lookup(NetContainsOrEqual)

class NetContains(Lookup):
    """``<field>__net_contains=<value>`` lookup.

    Matches rows whose left-hand network strictly contains the right-hand
    value, compiled to PostgreSQL's ``>>`` network operator.
    """
    lookup_name = 'net_contains'

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` for ``lhs >> rhs``."""
        lhs_sql, lhs_args = self.process_lhs(qn, connection)
        rhs_sql, rhs_args = self.process_rhs(qn, connection)
        sql_text = '%s >> %s' % (lhs_sql, rhs_sql)
        return sql_text, lhs_args + rhs_args

# Registered on the base Field class, so the lookup is available on every
# model field, not only network fields.
Field.register_lookup(NetContains)

class NetContained(Lookup):
    """``<field>__net_contained=<value>`` lookup.

    Matches rows whose left-hand network is strictly contained by the
    right-hand value, compiled to PostgreSQL's ``<<`` network operator.
    """
    lookup_name = 'net_contained'

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` for ``lhs << rhs``."""
        lhs_sql, lhs_args = self.process_lhs(qn, connection)
        rhs_sql, rhs_args = self.process_rhs(qn, connection)
        sql_text = '%s << %s' % (lhs_sql, rhs_sql)
        return sql_text, lhs_args + rhs_args

# Registered on the base Field class, so the lookup is available on every
# model field, not only network fields.
Field.register_lookup(NetContained)


class NetContainedOrEqual(Lookup):
    """``<field>__net_contained_or_equal=<value>`` lookup.

    Matches rows whose left-hand network is contained by or equal to the
    right-hand value, compiled to PostgreSQL's ``<<=`` network operator.
    """
    lookup_name = 'net_contained_or_equal'

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` for ``lhs <<= rhs``."""
        lhs_sql, lhs_args = self.process_lhs(qn, connection)
        rhs_sql, rhs_args = self.process_rhs(qn, connection)
        sql_text = '%s <<= %s' % (lhs_sql, rhs_sql)
        return sql_text, lhs_args + rhs_args

# Registered on the base Field class, so the lookup is available on every
# model field, not only network fields.
Field.register_lookup(NetContainedOrEqual)


class NetManager(models.Manager):
    """Manager for models with network fields.

    ``use_for_related_fields`` asks (pre-2.0) Django to use this manager for
    related-object access too.  Both the legacy ``get_query_set`` spelling
    and the modern ``get_queryset`` are defined explicitly; each one resolves
    straight to the parent manager's ``get_queryset``.
    """
    use_for_related_fields = True

    def get_query_set(self):
        # Legacy (pre-Django 1.6) name.  Deliberately goes to the parent's
        # get_queryset rather than self.get_queryset().
        parent = super(NetManager, self)
        return parent.get_queryset()

    def get_queryset(self):
        parent = super(NetManager, self)
        return parent.get_queryset()