tests.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. from django.core.exceptions import ValidationError
  2. from netaddr import IPAddress, IPNetwork, EUI, AddrFormatError
  3. from django import VERSION
  4. from django.db import IntegrityError
  5. from django.forms import ModelForm
  6. from django.test import TestCase
  7. from netfields.models import (CidrTestModel, InetTestModel, NullCidrTestModel,
  8. NullInetTestModel, UniqueInetTestModel,
  9. UniqueCidrTestModel, MACTestModel)
  10. from netfields.mac import mac_unix_common
  11. class BaseSqlTestCase(object):
  12. select = 'SELECT "table"."id", "table"."field" FROM "table" '
  13. def assertSqlEquals(self, qs, sql):
  14. sql = sql.replace('"table"', '"%s"' % self.table)
  15. self.assertEqual(qs.query.get_compiler(qs.db).as_sql()[0], sql)
  16. def assertSqlRaises(self, qs, error):
  17. self.assertRaises(error, qs.query.get_compiler(qs.db).as_sql)
  18. def test_init_with_blank(self):
  19. self.model()
  20. def test_isnull_true_lookup(self):
  21. self.assertSqlEquals(self.qs.filter(field__isnull=True),
  22. self.select + 'WHERE "table"."field" IS NULL')
  23. def test_isnull_false_lookup(self):
  24. self.assertSqlEquals(self.qs.filter(field__isnull=False),
  25. self.select + 'WHERE "table"."field" IS NOT NULL')
  26. def test_save(self):
  27. self.model(field=self.value1).save()
  28. def test_equals_lookup(self):
  29. self.assertSqlEquals(self.qs.filter(field=self.value1),
  30. self.select + 'WHERE "table"."field" = %s ')
  31. def test_exact_lookup(self):
  32. self.assertSqlEquals(self.qs.filter(field__exact=self.value1),
  33. self.select + 'WHERE "table"."field" = %s ')
  34. def test_in_lookup(self):
  35. self.assertSqlEquals(self.qs.filter(field__in=[self.value1, self.value2]),
  36. self.select + 'WHERE "table"."field" IN (%s, %s)')
  37. def test_gt_lookup(self):
  38. self.assertSqlEquals(self.qs.filter(field__gt=self.value1),
  39. self.select + 'WHERE "table"."field" > %s ')
  40. def test_gte_lookup(self):
  41. self.assertSqlEquals(self.qs.filter(field__gte=self.value1),
  42. self.select + 'WHERE "table"."field" >= %s ')
  43. def test_lt_lookup(self):
  44. self.assertSqlEquals(self.qs.filter(field__lt=self.value1),
  45. self.select + 'WHERE "table"."field" < %s ')
  46. def test_lte_lookup(self):
  47. self.assertSqlEquals(self.qs.filter(field__lte=self.value1),
  48. self.select + 'WHERE "table"."field" <= %s ')
  49. def test_range_lookup(self):
  50. self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)),
  51. self.select + 'WHERE "table"."field" BETWEEN %s and %s')
  52. class BaseInetTestCase(BaseSqlTestCase):
  53. def test_save_object(self):
  54. self.model(field=self.value1).save()
  55. def test_init_with_text_fails(self):
  56. self.assertRaises(ValidationError, self.model, field='abc')
  57. def test_iexact_lookup(self):
  58. self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
  59. self.select + 'WHERE "table"."field" = %s ')
  60. def test_search_lookup_fails(self):
  61. self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError)
  62. def test_year_lookup_fails(self):
  63. self.assertSqlRaises(self.qs.filter(field__year=1), ValueError)
  64. def test_month_lookup_fails(self):
  65. self.assertSqlRaises(self.qs.filter(field__month=1), ValueError)
  66. def test_day_lookup_fails(self):
  67. self.assertSqlRaises(self.qs.filter(field__day=1), ValueError)
  68. def test_net_contained(self):
  69. self.assertSqlEquals(self.qs.filter(field__net_contained='10.0.0.1/24'),
  70. self.select + 'WHERE "table"."field" << %s ')
  71. def test_net_contained_or_equals(self):
  72. self.assertSqlEquals(self.qs.filter(field__net_contained_or_equal='10.0.0.1/24'),
  73. self.select + 'WHERE "table"."field" <<= %s ')
  74. class BaseInetFieldTestCase(BaseInetTestCase):
  75. value1 = '10.0.0.1'
  76. value2 = '10.0.0.2'
  77. value3 = '10.0.0.10'
  78. def test_startswith_lookup(self):
  79. self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
  80. self.select + 'WHERE HOST("table"."field") ILIKE %s ')
  81. def test_istartswith_lookup(self):
  82. self.assertSqlEquals(self.qs.filter(field__istartswith='10.'),
  83. self.select + 'WHERE HOST("table"."field") ILIKE %s ')
  84. def test_endswith_lookup(self):
  85. self.assertSqlEquals(self.qs.filter(field__endswith='.1'),
  86. self.select + 'WHERE HOST("table"."field") ILIKE %s ')
  87. def test_iendswith_lookup(self):
  88. self.assertSqlEquals(self.qs.filter(field__iendswith='.1'),
  89. self.select + 'WHERE HOST("table"."field") ILIKE %s ')
  90. def test_regex_lookup(self):
  91. self.assertSqlEquals(self.qs.filter(field__regex='10'),
  92. self.select + 'WHERE HOST("table"."field") ~* %s ')
  93. def test_iregex_lookup(self):
  94. self.assertSqlEquals(self.qs.filter(field__iregex='10'),
  95. self.select + 'WHERE HOST("table"."field") ~* %s ')
  96. class BaseCidrFieldTestCase(BaseInetTestCase):
  97. value1 = '10.0.0.1/32'
  98. value2 = '10.0.0.2/24'
  99. value3 = '10.0.0.10/16'
  100. def test_startswith_lookup(self):
  101. self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
  102. self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
  103. def test_istartswith_lookup(self):
  104. self.assertSqlEquals(self.qs.filter(field__istartswith='10.'),
  105. self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
  106. def test_endswith_lookup(self):
  107. self.assertSqlEquals(self.qs.filter(field__endswith='.1'),
  108. self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
  109. def test_iendswith_lookup(self):
  110. self.assertSqlEquals(self.qs.filter(field__iendswith='.1'),
  111. self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
  112. def test_regex_lookup(self):
  113. self.assertSqlEquals(self.qs.filter(field__regex='10'),
  114. self.select + 'WHERE TEXT("table"."field") ~* %s ')
  115. def test_iregex_lookup(self):
  116. self.assertSqlEquals(self.qs.filter(field__iregex='10'),
  117. self.select + 'WHERE TEXT("table"."field") ~* %s ')
  118. def test_net_contains_lookup(self):
  119. self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'),
  120. self.select + 'WHERE "table"."field" >> %s ')
  121. def test_net_contains_or_equals(self):
  122. self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'),
  123. self.select + 'WHERE "table"."field" >>= %s ')
  124. class TestInetField(BaseInetFieldTestCase, TestCase):
  125. def setUp(self):
  126. self.model = InetTestModel
  127. self.qs = self.model.objects.all()
  128. self.table = 'inet'
  129. def test_save_blank_fails(self):
  130. self.assertRaises(IntegrityError, self.model(field='').save)
  131. def test_save_none_fails(self):
  132. self.assertRaises(IntegrityError, self.model(field=None).save)
  133. def test_save_nothing_fails(self):
  134. self.assertRaises(IntegrityError, self.model().save)
  135. class TestInetFieldNullable(BaseInetFieldTestCase, TestCase):
  136. def setUp(self):
  137. self.model = NullInetTestModel
  138. self.qs = self.model.objects.all()
  139. self.table = 'nullinet'
  140. def test_save_blank(self):
  141. self.model().save()
  142. def test_save_none(self):
  143. self.model(field=None).save()
  144. def test_save_nothing_fails(self):
  145. self.model().save()
  146. class TestInetFieldUnique(BaseInetFieldTestCase, TestCase):
  147. def setUp(self):
  148. self.model = UniqueInetTestModel
  149. self.qs = self.model.objects.all()
  150. self.table = 'uniqueinet'
  151. def test_save_nonunique(self):
  152. self.model(field='1.2.3.4').save()
  153. self.assertRaises(IntegrityError, self.model(field='1.2.3.4').save)
  154. class TestCidrField(BaseCidrFieldTestCase, TestCase):
  155. def setUp(self):
  156. self.model = CidrTestModel
  157. self.qs = self.model.objects.all()
  158. self.table = 'cidr'
  159. def test_save_blank_fails(self):
  160. self.assertRaises(IntegrityError, self.model(field='').save)
  161. def test_save_none_fails(self):
  162. self.assertRaises(IntegrityError, self.model(field=None).save)
  163. def test_save_nothing_fails(self):
  164. self.assertRaises(IntegrityError, self.model().save)
  165. class TestCidrFieldNullable(BaseCidrFieldTestCase, TestCase):
  166. def setUp(self):
  167. self.model = NullCidrTestModel
  168. self.qs = self.model.objects.all()
  169. self.table = 'nullcidr'
  170. def test_save_blank(self):
  171. self.model().save()
  172. def test_save_none(self):
  173. self.model(field=None).save()
  174. def test_save_nothing_fails(self):
  175. self.model().save()
  176. class TestCidrFieldUnique(BaseCidrFieldTestCase, TestCase):
  177. def setUp(self):
  178. self.model = UniqueCidrTestModel
  179. self.qs = self.model.objects.all()
  180. self.table = 'uniquecidr'
  181. def test_save_nonunique(self):
  182. self.model(field='1.2.3.0/24').save()
  183. self.assertRaises(IntegrityError, self.model(field='1.2.3.0/24').save)
  184. class InetAddressTestModelForm(ModelForm):
  185. class Meta:
  186. model = InetTestModel
  187. class TestInetAddressFormField(TestCase):
  188. form_class = InetAddressTestModelForm
  189. def test_form_ipv4_valid(self):
  190. form = self.form_class({'field': '10.0.0.1'})
  191. self.assertTrue(form.is_valid())
  192. self.assertEqual(form.cleaned_data['field'], IPAddress('10.0.0.1'))
  193. def test_form_ipv4_invalid(self):
  194. form = self.form_class({'field': '10.0.0.1.2'})
  195. self.assertFalse(form.is_valid())
  196. def test_form_ipv6(self):
  197. form = self.form_class({'field': '2001:0:1::2'})
  198. self.assertTrue(form.is_valid())
  199. self.assertEqual(form.cleaned_data['field'], IPAddress('2001:0:1::2'))
  200. def test_form_ipv6_invalid(self):
  201. form = self.form_class({'field': '2001:0::1::2'})
  202. self.assertFalse(form.is_valid())
  203. class UniqueInetAddressTestModelForm(ModelForm):
  204. class Meta:
  205. model = UniqueInetTestModel
  206. class TestUniqueInetAddressFormField(TestInetAddressFormField):
  207. form_class = UniqueInetAddressTestModelForm
  208. class CidrAddressTestModelForm(ModelForm):
  209. class Meta:
  210. model = CidrTestModel
  211. class TestCidrAddressFormField(TestCase):
  212. form_class = CidrAddressTestModelForm
  213. def test_form_ipv4_valid(self):
  214. form = self.form_class({'field': '10.0.0.1/24'})
  215. self.assertTrue(form.is_valid())
  216. self.assertEqual(form.cleaned_data['field'], IPNetwork('10.0.0.1/24'))
  217. def test_form_ipv4_invalid(self):
  218. form = self.form_class({'field': '10.0.0.1.2/32'})
  219. self.assertFalse(form.is_valid())
  220. def test_form_ipv6(self):
  221. form = self.form_class({'field': '2001:0:1::2/64'})
  222. self.assertTrue(form.is_valid())
  223. self.assertEqual(form.cleaned_data['field'], IPNetwork('2001:0:1::2/64'))
  224. def test_form_ipv6_invalid(self):
  225. form = self.form_class({'field': '2001:0::1::2/128'})
  226. self.assertFalse(form.is_valid())
  227. class UniqueCidrAddressTestModelForm(ModelForm):
  228. class Meta:
  229. model = UniqueCidrTestModel
  230. class TestUniqueCidrAddressFormField(TestCidrAddressFormField):
  231. form_class = UniqueCidrAddressTestModelForm
  232. class BaseMacTestCase(BaseSqlTestCase):
  233. value1 = '00:aa:2b:c3:dd:44'
  234. value2 = '00:aa:2b:c3:dd:45'
  235. value3 = '00:aa:2b:c3:dd:ff'
  236. def test_save_object(self):
  237. self.model(field=EUI(self.value1)).save()
  238. def test_iexact_lookup(self):
  239. self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
  240. self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s) ')
  241. def test_startswith_lookup(self):
  242. self.assertSqlEquals(self.qs.filter(field__startswith='00:'),
  243. self.select + 'WHERE "table"."field"::text LIKE %s ')
  244. def test_istartswith_lookup(self):
  245. self.assertSqlEquals(self.qs.filter(field__istartswith='00:'),
  246. self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
  247. def test_endswith_lookup(self):
  248. self.assertSqlEquals(self.qs.filter(field__endswith=':ff'),
  249. self.select + 'WHERE "table"."field"::text LIKE %s ')
  250. def test_iendswith_lookup(self):
  251. self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'),
  252. self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
  253. def test_regex_lookup(self):
  254. if VERSION[:2] < (1, 6):
  255. self.assertSqlEquals(self.qs.filter(field__regex='00'),
  256. self.select + 'WHERE "table"."field" ~ %s ')
  257. else:
  258. self.assertSqlEquals(self.qs.filter(field__regex='00'),
  259. self.select + 'WHERE "table"."field"::text ~ %s ')
  260. def test_iregex_lookup(self):
  261. if VERSION[:2] < (1, 6):
  262. self.assertSqlEquals(self.qs.filter(field__iregex='00'),
  263. self.select + 'WHERE "table"."field" ~* %s ')
  264. else:
  265. self.assertSqlEquals(self.qs.filter(field__iregex='00'),
  266. self.select + 'WHERE "table"."field"::text ~* %s ')
  267. class TestMacAddressField(BaseMacTestCase, TestCase):
  268. def setUp(self):
  269. self.model = MACTestModel
  270. self.qs = self.model.objects.all()
  271. self.table = 'mac'
  272. def test_save_blank(self):
  273. self.model().save()
  274. def test_save_none(self):
  275. self.model(field=None).save()
  276. def test_save_nothing_fails(self):
  277. self.model().save()
  278. def test_invalid_fails(self):
  279. self.assertRaises(ValidationError, self.model(field='foobar').save)
  280. class MacAddressTestModelForm(ModelForm):
  281. class Meta:
  282. model = MACTestModel
  283. class TestMacAddressFormField(TestCase):
  284. def setUp(self):
  285. self.mac = EUI('00:aa:2b:c3:dd:44', dialect=mac_unix_common)
  286. def test_unix(self):
  287. form = MacAddressTestModelForm({'field': '0:AA:2b:c3:dd:44'})
  288. self.assertTrue(form.is_valid())
  289. self.assertEqual(form.cleaned_data['field'], self.mac)
  290. def test_unix_common(self):
  291. form = MacAddressTestModelForm({'field': '00:aa:2b:c3:dd:44'})
  292. self.assertTrue(form.is_valid())
  293. self.assertEqual(form.cleaned_data['field'], self.mac)
  294. def test_eui48(self):
  295. form = MacAddressTestModelForm({'field': '00-AA-2B-C3-DD-44'})
  296. self.assertTrue(form.is_valid())
  297. self.assertEqual(form.cleaned_data['field'], self.mac)
  298. def test_cisco(self):
  299. form = MacAddressTestModelForm({'field': '00aa.2bc3.dd44'})
  300. self.assertTrue(form.is_valid())
  301. self.assertEqual(form.cleaned_data['field'], self.mac)
  302. def test_24bit_colon(self):
  303. form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
  304. self.assertTrue(form.is_valid())
  305. self.assertEqual(form.cleaned_data['field'], self.mac)
  306. def test_24bit_hyphen(self):
  307. form = MacAddressTestModelForm({'field': '00aa2b-c3dd44'})
  308. self.assertTrue(form.is_valid())
  309. self.assertEqual(form.cleaned_data['field'], self.mac)
  310. def test_bare(self):
  311. form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
  312. self.assertTrue(form.is_valid())
  313. self.assertEqual(form.cleaned_data['field'], self.mac)
  314. def test_invalid(self):
  315. form = MacAddressTestModelForm({'field': 'notvalid'})
  316. self.assertFalse(form.is_valid())