# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import unittest
import os
from pydnspp import *
from testutil import *

class EDNSTest(unittest.TestCase):
    """Tests for the pydnspp EDNS binding.

    An EDNS object is built from the fields of an OPT RR.  The RR class
    carries the UDP payload size, and the TTL carries the remaining EDNS
    parameters.  The TTL constants used below each exercise one of those
    parameters: the DO bit, an unsupported version, a non-zero extended
    RCODE, and a reserved (Z) bit.
    """

    def setUp(self):
        self.rrtype = RRType("OPT")
        # For OPT, the RR class value is interpreted as the UDP size
        # (see test_udpsize, which expects 4096 back).
        self.rrclass = RRClass(4096)
        # DO bit set; version 0; extended RCODE 0.
        self.rrttl_do_on = RRTTL(0x00008000)
        self.rrttl_do_off = RRTTL(0)
        # Version field set to 1, which the binding rejects (BADVERS).
        self.rrttl_badver = RRTTL(0x00018000)
        self.opt_rdata = Rdata(self.rrtype, self.rrclass, bytes())
        self.edns_base = self._make_edns(self.rrttl_do_off)

    def _make_edns(self, rrttl):
        """Return an EDNS built from a root-owned OPT RR with TTL *rrttl*.

        The name, class, type and RDATA are the common fixture values from
        setUp(); only the TTL varies between test cases.
        """
        return EDNS(Name.ROOT_NAME, self.rrclass, self.rrtype, rrttl,
                    self.opt_rdata)

    def test_badver_construct(self):
        """Constructing from a version number rejects bad versions/args."""
        self.assertRaises(InvalidParameter, EDNS, 1)
        self.assertRaises(TypeError, EDNS, 1, 2) # signature mismatch
        self.assertRaises(TypeError, EDNS, 256) # invalid arguments

    def test_dnssec_dobit(self):
        """The DO bit is read from the TTL and can be toggled explicitly."""
        edns = self._make_edns(self.rrttl_do_on)
        self.assertTrue(edns.get_dnssec_awareness())

        edns = self._make_edns(self.rrttl_do_off)
        self.assertFalse(edns.get_dnssec_awareness())

        edns = EDNS()
        self.assertFalse(edns.get_dnssec_awareness())
        edns.set_dnssec_awareness(True)
        self.assertTrue(edns.get_dnssec_awareness())
        edns.set_dnssec_awareness(False)
        self.assertFalse(edns.get_dnssec_awareness())

        # Only real booleans are accepted; truthy non-bools are rejected.
        self.assertRaises(TypeError, edns.set_dnssec_awareness, "wrong")
        self.assertRaises(TypeError, edns.set_dnssec_awareness, 1)

    def test_udpsize(self):
        """The UDP size comes from the RR class and is range-checked."""
        edns = self._make_edns(self.rrttl_do_on)
        self.assertEqual(4096, edns.get_udp_size())

        edns = EDNS()
        edns.set_udp_size(511)
        self.assertEqual(511, edns.get_udp_size())
        self.assertRaises(TypeError, edns.set_udp_size, "wrong")

        # Range check.  We need to do this at the binding level, so we need
        # explicit tests for it.
        edns.set_udp_size(0)
        self.assertEqual(0, edns.get_udp_size())
        edns.set_udp_size(65535)
        self.assertEqual(65535, edns.get_udp_size())
        self.assertRaises(ValueError, edns.set_udp_size, 0x10000)
        self.assertRaises(ValueError, edns.set_udp_size, -1)

    def test_get_version(self):
        """A default-constructed EDNS reports the supported version."""
        self.assertEqual(EDNS.SUPPORTED_VERSION, EDNS().get_version())

    def test_bad_wiredata(self):
        """Malformed OPT RR fields raise the matching exception."""
        # Wrong RR type.
        self.assertRaises(InvalidParameter, EDNS, Name.ROOT_NAME,
                          self.rrclass, RRType("A"),
                          self.rrttl_do_on, self.opt_rdata)
        # Owner name other than the root.
        self.assertRaises(DNSMessageFORMERR, EDNS, Name("example.com"),
                          self.rrclass, self.rrtype, self.rrttl_do_on,
                          self.opt_rdata)
        # Unsupported EDNS version.
        self.assertRaises(DNSMessageBADVERS, EDNS, Name.ROOT_NAME,
                          self.rrclass, self.rrtype, self.rrttl_badver,
                          self.opt_rdata)

    def test_to_text(self):
        """to_text() and str() render version, flags and UDP size."""
        edns = EDNS()
        edns.set_udp_size(4096)
        expected_str = "; EDNS: version: 0, flags:; udp: 4096\n"
        self.assertEqual(expected_str, edns.to_text())
        self.assertEqual(expected_str, str(edns))

        edns.set_dnssec_awareness(True)
        self.assertEqual("; EDNS: version: 0, flags: do; udp: 4096\n",
                         edns.to_text())

        # A non-zero high byte of the TTL does not appear in the text form.
        self.assertEqual("; EDNS: version: 0, flags: do; udp: 4096\n",
                         self._make_edns(RRTTL(0x01008000)).to_text())

        # A reserved low-order TTL bit does not appear in the text form.
        self.assertEqual("; EDNS: version: 0, flags: do; udp: 4096\n",
                         self._make_edns(RRTTL(0x00008001)).to_text())

    def test_towire_renderer(self):
        """Rendering into a MessageRenderer matches the reference wire data."""
        renderer = MessageRenderer()
        extrcode_noerror = Rcode.NOERROR.get_extended_code()
        extrcode_badvers = Rcode.BADVERS.get_extended_code()

        # to_wire() returns the number of RRs rendered (1 on success).
        self.assertEqual(1, self.edns_base.to_wire(renderer, extrcode_noerror))
        wiredata = read_wire_data("edns_toWire1.wire")
        self.assertEqual(wiredata, renderer.get_data())

        renderer.clear()
        self.edns_base.set_dnssec_awareness(True)
        self.assertEqual(1, self.edns_base.to_wire(renderer, extrcode_noerror))
        wiredata = read_wire_data("edns_toWire2.wire")
        self.assertEqual(wiredata, renderer.get_data())

        # The DO bit is re-set in each sub-case so each reads on its own.
        renderer.clear()
        self.edns_base.set_dnssec_awareness(True)
        self.assertEqual(1, self.edns_base.to_wire(renderer, extrcode_badvers))
        wiredata = read_wire_data("edns_toWire3.wire")
        self.assertEqual(wiredata, renderer.get_data())

        renderer.clear()
        self.edns_base.set_dnssec_awareness(True)
        self.edns_base.set_udp_size(511)
        self.assertEqual(1, self.edns_base.to_wire(renderer, extrcode_noerror))
        wiredata = read_wire_data("edns_toWire4.wire")
        self.assertEqual(wiredata, renderer.get_data())

        # The reserved TTL bit is not carried into the rendered data: the
        # output matches the plain DO-bit case (edns_toWire2.wire).
        renderer.clear()
        edns = self._make_edns(RRTTL(0x00008001))
        self.assertEqual(1, edns.to_wire(renderer, extrcode_noerror))
        wiredata = read_wire_data("edns_toWire2.wire")
        self.assertEqual(wiredata, renderer.get_data())

        # When the OPT RR does not fit the length limit, nothing is rendered.
        renderer.clear()
        renderer.set_length_limit(10)
        self.edns_base.set_dnssec_awareness(True)
        self.assertEqual(0, self.edns_base.to_wire(renderer, extrcode_noerror))
        self.assertEqual(0, renderer.get_length())

    def test_towire_buffer(self):
        """Rendering into a bytes buffer matches the reference wire data."""
        extrcode_noerror = Rcode.NOERROR.get_extended_code()

        # The bytes-buffer form hands back the rendered data, so the
        # result is rebound to obuffer.
        obuffer = bytes()
        obuffer = self.edns_base.to_wire(obuffer, extrcode_noerror)
        wiredata = read_wire_data("edns_toWire1.wire")
        self.assertEqual(wiredata, obuffer)

    def test_create_from_rr(self):
        """create_from_rr() returns (EDNS, extended RCODE) or raises."""
        (edns, extrcode) = EDNS.create_from_rr(Name.ROOT_NAME, self.rrclass,
                                               self.rrtype, self.rrttl_do_on,
                                               self.opt_rdata)
        self.assertEqual(EDNS.SUPPORTED_VERSION, edns.get_version())
        self.assertTrue(edns.get_dnssec_awareness())
        self.assertEqual(4096, edns.get_udp_size())
        self.assertEqual(0, extrcode)

        # The high byte of the TTL is reported as the extended RCODE.
        (edns, extrcode) = EDNS.create_from_rr(Name.ROOT_NAME, self.rrclass,
                                               self.rrtype, RRTTL(0x01008000),
                                               self.opt_rdata)
        self.assertEqual(1, extrcode)

        self.assertRaises(DNSMessageBADVERS, EDNS.create_from_rr,
                          Name.ROOT_NAME, self.rrclass, self.rrtype,
                          self.rrttl_badver, self.opt_rdata)

if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()