Browse Source

We need tests. Lots of tests.
Not complete yet, but getting there


git-svn-id: svn://bind10.isc.org/svn/bind10/experiments/python-binding@1947 e5f2f494-b856-4b98-b285-d166d9295462

Jelte Jansen 15 years ago
parent
commit
2d205860f9

+ 15 - 0
src/lib/dns/python/TODO

@@ -5,6 +5,9 @@ add statics for RRClass::IN() (RRClass.IN()) etc.
 same for RRType? (xfrout.py.in line 256)
 
 __str__ for name, question, everything with to_text()
+rich compare for name (at least so we can have ==)
+
+should Name.downcase() return a ref to itself?
 
 All constructors based on buffers need an optional position
 argument (like question_python has now)
@@ -13,3 +16,15 @@ at question.to_wire(bytes) does not seem to work right (only return
 value seems correct, while i'd like in-place addition if possible)
 
 creating a render message and not setting opcode/rcode results in a segfault later (nullpointer)
+
+
+The function set wrapped is not complete; for instance, in
+MessageRenderer, we really only provide the high-level readout
+functions. Do we need access to the writers? (there is one set() right
+now).
+
+All constants are now added named in the base module, while they should
+be added as class constants. Dunno how though.
+
+
+segfault when comparing with bad type like int (at least for Name and Rcode, but probably for the rest too)

+ 1 - 1
src/lib/dns/python/libdns_python.cc

@@ -27,8 +27,8 @@
 
 //#include "buffer_python.cc"
 // order is important here! (TODO: document dependencies)
-#include "name_python.cc"
 #include "messagerenderer_python.cc"
+#include "name_python.cc"
 #include "rrclass_python.cc"
 #include "rrtype_python.cc"
 #include "rrttl_python.cc"

+ 1 - 3
src/lib/dns/python/message_python.cc

@@ -627,7 +627,7 @@ Opcode_RESERVED15(s_Opcode* self UNUSED_PARAM)
 static PyObject* 
 Opcode_richcmp(s_Opcode* self, s_Opcode* other, int op)
 {
-    bool c;
+    bool c = false;
 
     // Only equals and not equals here, unorderable type
     switch (op) {
@@ -653,8 +653,6 @@ Opcode_richcmp(s_Opcode* self, s_Opcode* other, int op)
         PyErr_SetString(PyExc_TypeError, "Unorderable type; Opcode");
         return NULL;
         break;
-    default:
-        assert(0);              // XXX: should trigger an exception
     }
     if (c)
         Py_RETURN_TRUE;

+ 39 - 33
src/lib/dns/python/name_python.cc

@@ -76,9 +76,9 @@ static PyObject* NameComparisonResult_getCommonLabels(s_NameComparisonResult* se
 static PyObject* NameComparisonResult_getRelation(s_NameComparisonResult* self);
 
 static PyMethodDef NameComparisonResult_methods[] = {
-    { "getOrder", (PyCFunction)NameComparisonResult_getOrder, METH_NOARGS, "Return the order" },
-    { "getCommonLabels", (PyCFunction)NameComparisonResult_getCommonLabels, METH_NOARGS, "Return the number of common labels" },
-    { "getRelation", (PyCFunction)NameComparisonResult_getRelation, METH_NOARGS, "Return the relation" },
+    { "get_order", (PyCFunction)NameComparisonResult_getOrder, METH_NOARGS, "Return the order" },
+    { "get_common_labels", (PyCFunction)NameComparisonResult_getCommonLabels, METH_NOARGS, "Return the number of common labels" },
+    { "get_relation", (PyCFunction)NameComparisonResult_getRelation, METH_NOARGS, "Return the relation" },
     { NULL, NULL, 0, NULL }
 };
 
@@ -176,7 +176,7 @@ typedef struct {
 static int Name_init(s_Name* self, PyObject* args);
 static void Name_destroy(s_Name* self);
 
-static PyObject* Name_toWire(s_Name* self);
+static PyObject* Name_toWire(s_Name* self, PyObject* args);
 static PyObject* Name_toText(s_Name* self);
 static PyObject* Name_getLabelCount(s_Name* self);
 static PyObject* Name_at(s_Name* self, PyObject* args);
@@ -197,7 +197,7 @@ static PyMethodDef Name_methods[] = {
     { "get_length", (PyCFunction)Name_getLength, METH_NOARGS, "Return the length" },
     { "get_labelcount", (PyCFunction)Name_getLabelCount, METH_NOARGS, "Return the number of labels" },
     { "to_text", (PyCFunction)Name_toText, METH_NOARGS, "Return the string representation" },
-    { "to_wire", (PyCFunction)Name_toWire, METH_NOARGS, "Return the wire format" },
+    { "to_wire", (PyCFunction)Name_toWire, METH_VARARGS, "Return the wire format" },
     { "compare", (PyCFunction)Name_compare, METH_VARARGS, "Compare" },
     { "equals", (PyCFunction)Name_equals, METH_VARARGS, "Equals" },
     { "split", (PyCFunction)Name_split, METH_VARARGS, "split" },
@@ -305,15 +305,17 @@ Name_init(s_Name* self, PyObject* args)
     }
     PyErr_Clear();
 
-    const char* b;
+    PyObject* bytes_obj;
+    const char* bytes;
     Py_ssize_t len;
-    unsigned int position;
+    unsigned int position = 0;
 
     /* fromWire */
-    if (PyArg_ParseTuple(args, "y#I|O!", &b, &len, &position,
-                         &PyBool_Type, &downcase)) {
+    if (PyArg_ParseTuple(args, "O|IO!", &bytes_obj, &position,
+                         &PyBool_Type, &downcase) &&
+                         PyObject_AsCharBuffer(bytes_obj, &bytes, &len) != -1) {
         try {
-            InputBuffer buffer(b, len);
+            InputBuffer buffer(bytes, len);
 
             buffer.setPosition(position);
             self->name = new Name(buffer, downcase == Py_True);
@@ -322,23 +324,9 @@ Name_init(s_Name* self, PyObject* args)
             PyErr_SetString(po_InvalidBufferPosition,
                             "InvalidBufferPosition");
             return -1;
-        } catch (TooLongName) {
-            PyErr_SetString(po_TooLongName, "TooLongName");
-            return -1;
-        } catch (BadLabelType) {
-            PyErr_SetString(po_BadLabelType, "BadLabelType");
-            return -1;
         } catch (DNSMessageFORMERR) {
             PyErr_SetString(po_DNSMessageFORMERR, "DNSMessageFORMERR");
             return -1;
-        } catch (IncompleteName) {
-            PyErr_SetString(po_IncompleteName, "IncompleteName");
-            return -1;
-#ifdef CATCHMEMERR
-        } catch (std::bad_alloc) {
-            PyErr_NoMemory();
-            return -1;
-#endif
         } catch (...) {
             PyErr_SetString(po_IscException, "Unexpected?!");
             return -1;
@@ -348,7 +336,7 @@ Name_init(s_Name* self, PyObject* args)
 
     PyErr_Clear();
     PyErr_SetString(PyExc_TypeError,
-                    "fromText and fromWire Name constructors don't match");
+                    "No valid types in Name constructor (should be string or sequence and offset");
     return -1;
 }
 
@@ -394,15 +382,33 @@ Name_toText(s_Name* self)
     return Py_BuildValue("s", self->name->toText().c_str());
 }
 
-// XX TODO: renderer and direct versions
 static PyObject*
-Name_toWire(s_Name* self)
+Name_toWire(s_Name* self, PyObject* args)
 {
-    OutputBuffer buffer(255);
-
-    self->name->toWire(buffer);
-    return Py_BuildValue("y#", buffer.getData(),
-                         (Py_ssize_t) buffer.getLength());
+    PyObject* bytes;
+    s_MessageRenderer* mr;
+    
+    if (PyArg_ParseTuple(args, "O", &bytes) && PySequence_Check(bytes)) {
+        PyObject* bytes_o = bytes;
+        
+        OutputBuffer buffer(255);
+        self->name->toWire(buffer);
+        PyObject* n = PyBytes_FromStringAndSize((const char*) buffer.getData(), buffer.getLength());
+        PyObject* result = PySequence_InPlaceConcat(bytes_o, n);
+        // We need to release the object we temporarily created here
+        // to prevent memory leak
+        Py_DECREF(n);
+        return result;
+    } else if (PyArg_ParseTuple(args, "O!", &messagerenderer_type, (PyObject**) &mr)) {
+        self->name->toWire(*mr->messagerenderer);
+        // If we return NULL it is seen as an error, so use this for
+        // None returns
+        Py_RETURN_NONE;
+    }
+    PyErr_Clear();
+    PyErr_SetString(PyExc_TypeError,
+                    "toWire argument must be a sequence object or a MessageRenderer");
+    return NULL;
 }
 
 static PyObject*
@@ -525,7 +531,7 @@ Name_concatenate(s_Name* self, PyObject* args)
         try {
             ret->name = new Name(self->name->concatenate(*other->name));
         } catch (isc::dns::TooLongName tln) {
-            PyErr_SetString(PyExc_IndexError, tln.what());
+            PyErr_SetString(po_TooLongName, tln.what());
             ret->name = NULL;
         }
         if (ret->name == NULL) {

+ 2 - 9
src/lib/dns/python/rrclass_python.cc

@@ -165,15 +165,8 @@ RRClass_init(s_RRClass* self, PyObject* args)
             PyErr_Clear();
             return 0;
         }
-    } catch (IncompleteRRClass icc) {
-        // Ok so one of our functions has thrown a C++ exception.
-        // We need to translate that to a Python Exception
-        // First clear any existing error that was set
-        PyErr_Clear();
-        // Now set our own exception
-        PyErr_SetString(po_InvalidRRClass, icc.what());
-        // And return negative
-        return -1;
+    /* Incomplete is never thrown, a type error would have already been raised
+     * when we try to read the 2 bytes above */
     } catch (InvalidRRClass ic) {
         PyErr_Clear();
         PyErr_SetString(po_InvalidRRClass, ic.what());

+ 5 - 0
src/lib/dns/python/tests/Makefile.am

@@ -1,6 +1,11 @@
 PYTESTS = message_python_test.py
+PYTESTS += messagerenderer_python_test.py
+PYTESTS += name_python_test.py
 PYTESTS += question_python_test.py
+PYTESTS += rdata_python_test.py
+PYTESTS += rrclass_python_test.py
 PYTESTS += rrset_python_test.py
+PYTESTS += rrttl_python_test.py
 PYTESTS += rrtype_python_test.py
 
 EXTRA_DIST = $(PYTESTS)

+ 269 - 13
src/lib/dns/python/tests/message_python_test.py

@@ -21,6 +21,251 @@ import unittest
 import os
 from libdns_python import *
 
+
+class MessageFlagTest(unittest.TestCase):
+    def test_init(self):
+        self.assertRaises(NotImplementedError, MessageFlag)
+
+    def test_get_bit(self):
+        self.assertEqual(0x8000, MessageFlag.QR().get_bit())
+        self.assertEqual(0x0400, MessageFlag.AA().get_bit())
+        self.assertEqual(0x0200, MessageFlag.TC().get_bit())
+        self.assertEqual(0x0100, MessageFlag.RD().get_bit())
+        self.assertEqual(0x0080, MessageFlag.RA().get_bit())
+        self.assertEqual(0x0020, MessageFlag.AD().get_bit())
+        self.assertEqual(0x0010, MessageFlag.CD().get_bit())
+
+class OpcodeTest(unittest.TestCase):
+    def test_init(self):
+        self.assertRaises(NotImplementedError, Opcode)
+
+    def test_get_code(self):
+        self.assertEqual(0, Opcode.QUERY().get_code())
+        self.assertEqual(1, Opcode.IQUERY().get_code())
+        self.assertEqual(2, Opcode.STATUS().get_code())
+        self.assertEqual(3, Opcode.RESERVED3().get_code())
+        self.assertEqual(4, Opcode.NOTIFY().get_code())
+        self.assertEqual(5, Opcode.UPDATE().get_code())
+        self.assertEqual(6, Opcode.RESERVED6().get_code())
+        self.assertEqual(7, Opcode.RESERVED7().get_code())
+        self.assertEqual(8, Opcode.RESERVED8().get_code())
+        self.assertEqual(9, Opcode.RESERVED9().get_code())
+        self.assertEqual(10, Opcode.RESERVED10().get_code())
+        self.assertEqual(11, Opcode.RESERVED11().get_code())
+        self.assertEqual(12, Opcode.RESERVED12().get_code())
+        self.assertEqual(13, Opcode.RESERVED13().get_code())
+        self.assertEqual(14, Opcode.RESERVED14().get_code())
+        self.assertEqual(15, Opcode.RESERVED15().get_code())
+
+    def test_to_text(self):
+        self.assertEqual("QUERY", Opcode.QUERY().to_text())
+        self.assertEqual("QUERY", Opcode.QUERY().__str__())
+        self.assertEqual("IQUERY", Opcode.IQUERY().to_text())
+        self.assertEqual("STATUS", Opcode.STATUS().to_text())
+        self.assertEqual("RESERVED3", Opcode.RESERVED3().to_text())
+        self.assertEqual("NOTIFY", Opcode.NOTIFY().to_text())
+        self.assertEqual("UPDATE", Opcode.UPDATE().to_text())
+        self.assertEqual("RESERVED6", Opcode.RESERVED6().to_text())
+        self.assertEqual("RESERVED7", Opcode.RESERVED7().to_text())
+        self.assertEqual("RESERVED8", Opcode.RESERVED8().to_text())
+        self.assertEqual("RESERVED9", Opcode.RESERVED9().to_text())
+        self.assertEqual("RESERVED10", Opcode.RESERVED10().to_text())
+        self.assertEqual("RESERVED11", Opcode.RESERVED11().to_text())
+        self.assertEqual("RESERVED12", Opcode.RESERVED12().to_text())
+        self.assertEqual("RESERVED13", Opcode.RESERVED13().to_text())
+        self.assertEqual("RESERVED14", Opcode.RESERVED14().to_text())
+        self.assertEqual("RESERVED15", Opcode.RESERVED15().to_text())
+
+    def test_richcmp(self):
+        o1 = Opcode.QUERY()
+        o2 = Opcode.NOTIFY()
+        o3 = Opcode.NOTIFY()
+        self.assertTrue(o2 == o3)
+        self.assertTrue(o1 != o2)
+        # can't use assertRaises here...
+        try:
+            o1 < o2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            o1 <= o2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            o1 > o2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            o1 >= o2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+
+class RcodeTest(unittest.TestCase):
+    def test_init(self):
+        self.assertRaises(TypeError, Rcode, "wrong")
+        self.assertRaises(OverflowError, Rcode, 65536)
+
+    def test_get_code(self):
+        self.assertEqual(0, Rcode.NOERROR().get_code())
+        self.assertEqual(1, Rcode.FORMERR().get_code())
+        self.assertEqual(2, Rcode.SERVFAIL().get_code())
+        self.assertEqual(3, Rcode.NXDOMAIN().get_code())
+        self.assertEqual(4, Rcode.NOTIMP().get_code())
+        self.assertEqual(5, Rcode.REFUSED().get_code())
+        self.assertEqual(6, Rcode.YXDOMAIN().get_code())
+        self.assertEqual(7, Rcode.YXRRSET().get_code())
+        self.assertEqual(8, Rcode.NXRRSET().get_code())
+        self.assertEqual(9, Rcode.NOTAUTH().get_code())
+        self.assertEqual(10, Rcode.NOTZONE().get_code())
+        self.assertEqual(11, Rcode.RESERVED11().get_code())
+        self.assertEqual(12, Rcode.RESERVED12().get_code())
+        self.assertEqual(13, Rcode.RESERVED13().get_code())
+        self.assertEqual(14, Rcode.RESERVED14().get_code())
+        self.assertEqual(15, Rcode.RESERVED15().get_code())
+
+    def test_to_text(self):
+        self.assertEqual("NOERROR", Rcode(0).to_text())
+        self.assertEqual("NOERROR", Rcode(0).__str__())
+        self.assertEqual("FORMERR", Rcode(1).to_text())
+        self.assertEqual("SERVFAIL", Rcode(2).to_text())
+        self.assertEqual("NXDOMAIN", Rcode(3).to_text())
+        self.assertEqual("NOTIMP", Rcode(4).to_text())
+        self.assertEqual("REFUSED", Rcode(5).to_text())
+        self.assertEqual("YXDOMAIN", Rcode(6).to_text())
+        self.assertEqual("YXRRSET", Rcode(7).to_text())
+        self.assertEqual("NXRRSET", Rcode(8).to_text())
+        self.assertEqual("NOTAUTH", Rcode(9).to_text())
+        self.assertEqual("NOTZONE", Rcode(10).to_text())
+        self.assertEqual("RESERVED11", Rcode(11).to_text())
+        self.assertEqual("RESERVED12", Rcode(12).to_text())
+        self.assertEqual("RESERVED13", Rcode(13).to_text())
+        self.assertEqual("RESERVED14", Rcode(14).to_text())
+        self.assertEqual("RESERVED15", Rcode(15).to_text())
+        
+    def test_richcmp(self):
+        r1 = Rcode.NOERROR()
+        r2 = Rcode.FORMERR()
+        r3 = Rcode.FORMERR()
+        self.assertTrue(r2 == r3)
+        self.assertTrue(r1 != r2)
+        # can't use assertRaises here...
+        try:
+            r1 < r2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            r1 <= r2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            r1 > r2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            r1 >= r2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+
+class SectionTest(unittest.TestCase):
+
+    def test_init(self):
+        self.assertRaises(NotImplementedError, Section)
+
+    def test_get_code(self):
+        self.assertEqual(0, Section.QUESTION().get_code())
+        self.assertEqual(1, Section.ANSWER().get_code())
+        self.assertEqual(2, Section.AUTHORITY().get_code())
+        self.assertEqual(3, Section.ADDITIONAL().get_code())
+
+    def test_richcmp(self):
+        s1 = Section.QUESTION()
+        s2 = Section.ANSWER()
+        s3 = Section.ANSWER()
+        self.assertTrue(s2 == s3)
+        self.assertTrue(s1 != s2)
+        # can't use assertRaises here...
+        try:
+            s1 < s2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            s1 <= s2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            s1 > s2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        try:
+            s1 >= s2
+        except Exception as err:
+            self.assertEqual(TypeError, type(err))
+        
+
+class MessageTest(unittest.TestCase):
+
+    def setUp(self):
+        self.p = Message(PARSE)
+        self.r = Message(RENDER)
+        
+    def test_init(self):
+        self.assertRaises(TypeError, Message, 3)
+        self.assertRaises(TypeError, Message, "wrong")
+
+    def test_get_header_flag(self):
+        self.assertRaises(TypeError, self.p.get_header_flag, "wrong")
+        self.assertFalse(self.p.get_header_flag(MessageFlag.AA()))
+
+    def test_set_header_flag(self):
+        self.assertRaises(TypeError, self.r.set_header_flag, "wrong")
+        self.assertRaises(TypeError, self.r.clear_header_flag, "wrong")
+
+        self.assertFalse(self.r.get_header_flag(MessageFlag.AA()))
+        self.r.set_header_flag(MessageFlag.AA())
+        self.assertTrue(self.r.get_header_flag(MessageFlag.AA()))
+        self.r.clear_header_flag(MessageFlag.AA())
+        self.assertFalse(self.r.get_header_flag(MessageFlag.AA()))
+
+    def test_set_DNSSEC_supported(self):
+        self.assertRaises(TypeError, self.r.set_dnssec_supported, "wrong")
+
+        self.assertFalse(self.r.is_dnssec_supported())
+        self.r.set_dnssec_supported(True)
+        self.assertTrue(self.r.is_dnssec_supported())
+        self.r.set_dnssec_supported(False)
+        self.assertFalse(self.r.is_dnssec_supported())
+
+    def test_set_udp_size(self):
+        self.assertRaises(TypeError, self.r.set_udp_size, "wrong")
+
+    def test_set_qid(self):
+        self.assertRaises(TypeError, self.r.set_qid, "wrong")
+
+    def test_set_rcode(self):
+        self.assertRaises(TypeError, self.r.set_rcode, "wrong")
+
+    def test_set_opcode(self):
+        self.assertRaises(TypeError, self.r.set_opcode, "wrong")
+
+    def test_get_section(self):
+        self.assertRaises(TypeError, self.r.get_section, "wrong")
+
+    def test_add_rrset(self):
+        self.assertRaises(TypeError, self.r.add_rrset, "wrong")
+
+    def test_clear(self):
+        self.assertEqual(None, self.r.clear(PARSE))
+        self.assertEqual(None, self.r.clear(RENDER))
+        self.assertRaises(TypeError, self.r.clear, "wrong")
+        self.assertRaises(TypeError, self.r.clear, 3)
+
+    def test_to_wire(self):
+        self.assertRaises(TypeError, self.r.to_wire, 1)
+
+    def test_from_wire(self):
+        self.assertRaises(TypeError, self.r.from_wire, 1)
+
+# helper functions for tests taken from c++ unittests
 if "TESTDATA_PATH" in os.environ:
     testdata_path = os.environ["TESTDATA_PATH"]
 else:
@@ -44,8 +289,9 @@ def factoryFromFile(message, file):
     message.from_wire(data)
     pass
 
-class MessageTest(unittest.TestCase):
+class ConvertedUnittests(unittest.TestCase):
     
+    # tests below based on c++ unit tests
     def test_RcodeConstruct(self):
         # normal cases
         self.assertEqual(0, Rcode(0).get_code())
@@ -220,7 +466,7 @@ class MessageTest(unittest.TestCase):
                           message_parse,
                           "message_fromWire9")
     
-    def test_toWire(self):
+    def test_to_text_and_wire(self):
         message_render = Message(RENDER)
         message_render.set_qid(0x1035)
         message_render.set_opcode(Opcode.QUERY())
@@ -228,25 +474,35 @@ class MessageTest(unittest.TestCase):
         message_render.set_header_flag(MessageFlag.QR())
         message_render.set_header_flag(MessageFlag.RD())
         message_render.set_header_flag(MessageFlag.AA())
-        #message_render.addQuestion(Question(Name("test.example.com"), RRClass.IN(),
-                                            #RRType.A()))
+        message_render.add_question(Question(Name("test.example.com"), RRClass("IN"), RRType("A")))
         rrset = RRset(Name("test.example.com"), RRClass("IN"),
                                             RRType("A"), RRTTL(3600))
-        #rrset.add_rdata(in.A("192.0.2.1"))
-        #rrset.addRdata(in.A("192.0.2.2"))
-        #message_render.addRRset(Section.ANSWER(), rrset)
+        rrset.add_rdata(Rdata(RRType("A"), RRClass("IN"), "192.0.2.1"))
+        rrset.add_rdata(Rdata(RRType("A"), RRClass("IN"), "192.0.2.2"))
+        message_render.add_rrset(Section.ANSWER(), rrset)
     
-        #self.assertEqual(1, message_render.get_rr_count(Section.QUESTION()))
-        #self.assertEqual(2, message_render.get_rr_count(Section.ANSWER()))
+        self.assertEqual(1, message_render.get_rr_count(Section.QUESTION()))
+        self.assertEqual(2, message_render.get_rr_count(Section.ANSWER()))
         self.assertEqual(0, message_render.get_rr_count(Section.AUTHORITY()))
         self.assertEqual(0, message_render.get_rr_count(Section.ADDITIONAL()))
 
         renderer = MessageRenderer()
         message_render.to_wire(renderer)
-        #vector<unsigned char> data;
-        #UnitTestUtil.readWireData("testdata/message_toWire1", data)
-        #EXPECT_PRED_FORMAT4(UnitTestUtil.matchWireData, obuffer.getData(),
-                            #obuffer.getLength(), &data[0], data.size())
+        self.assertEqual(b'\x105\x85\x00\x00\x01\x00\x02\x00\x00\x00\x00\x04test\x07example\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x02',
+                         renderer.get_data())
+        msg_str =\
+""";; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 4149
+;; flags: qr aa rd ; QUESTION: 1, ANSWER: 2, AUTHORITY: 0, ADDITIONAL: 0
+
+;; QUESTION SECTION:
+;test.example.com. IN A
+
+;; ANSWER SECTION:
+test.example.com. 3600 IN A 192.0.2.1
+test.example.com. 3600 IN A 192.0.2.2
+"""
+        self.assertEqual(msg_str, message_render.to_text())
+        self.assertEqual(msg_str, message_render.__str__())
 
 if __name__ == '__main__':
     unittest.main()

+ 107 - 0
src/lib/dns/python/tests/messagerenderer_python_test.py

@@ -0,0 +1,107 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# Tests for the messagerenderer part of the libdns_python module
+#
+
+import unittest
+import os
+from libdns_python import *
+
+class MessageRendererTest(unittest.TestCase):
+
+    def setUp(self):
+        name = Name("example.com")
+        c = RRClass("IN")
+        t = RRType("A")
+        ttl = RRTTL("3600")
+        
+        message = Message(RENDER)
+        message.set_qid(123)
+        message.set_opcode(Opcode.QUERY())
+        message.add_question(Question(name, c, t))
+
+        self.message1 = message
+        message = Message(RENDER)
+        message.set_qid(123)
+        message.set_header_flag(MessageFlag.AA())
+        message.set_header_flag(MessageFlag.QR())
+        message.set_opcode(Opcode.QUERY())
+        message.set_rcode(Rcode.NOERROR())
+        message.add_question(Question(name, c, t))
+        rrset = RRset(name, c, t, ttl)
+        rrset.add_rdata(Rdata(t, c, "192.0.2.98"))
+        rrset.add_rdata(Rdata(t, c, "192.0.2.99"))
+        message.add_rrset(Section.AUTHORITY(), rrset)
+        self.message2 = message
+
+        #message = Message(RENDER)
+        #message.set_qid(123)
+        #message.set_header_flag(MessageFlag.AA())
+        #message.set_header_flag(MessageFlag.QR())
+        #message.set_opcode(Opcode.QUERY())
+        #message.set_rcode(Rcode.NOERROR())
+        #message.add_question(Question(name, c, t))
+        #rrset = RRset(name, c, t, ttl)
+        #for i in range(1, 99):
+        #    rrset.add_rdata(Rdata(t, c, "192.0.2." + str(i)))
+        #message.add_rrset(Section.AUTHORITY(), rrset)
+        #self.message3 = message
+
+        self.renderer1 = MessageRenderer()
+        self.renderer2 = MessageRenderer()
+        self.renderer3 = MessageRenderer()
+        self.renderer3.set_length_limit(50)
+        self.message1.to_wire(self.renderer1)
+        self.message2.to_wire(self.renderer2)
+        self.message2.to_wire(self.renderer3)
+        
+    
+    def test_messagerenderer_get_data(self):
+        data1 = b'\x00{\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\x01\x00\x01'
+        self.assertEqual(data1, self.renderer1.get_data())
+        data2 = b'\x00{\x84\x00\x00\x01\x00\x00\x00\x02\x00\x00\x07example\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02b\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02c'
+        self.assertEqual(data2, self.renderer2.get_data())
+        
+    def test_messagerenderer_get_length(self):
+        self.assertEqual(29, self.renderer1.get_length())
+        self.assertEqual(61, self.renderer2.get_length())
+        self.assertEqual(45, self.renderer3.get_length())
+
+    def test_messagerenderer_is_truncated(self):
+        self.assertFalse(self.renderer1.is_truncated())
+        self.assertFalse(self.renderer2.is_truncated())
+        self.assertTrue(self.renderer3.is_truncated())
+
+    def test_messagerenderer_get_length_limit(self):
+        self.assertEqual(512, self.renderer1.get_length_limit())
+        self.assertEqual(512, self.renderer2.get_length_limit())
+        self.assertEqual(50, self.renderer3.get_length_limit())
+
+    def test_messagerenderer_set_truncated(self):
+        self.assertFalse(self.renderer1.is_truncated())
+        self.renderer1.set_truncated()
+        self.assertTrue(self.renderer1.is_truncated())
+
+    def test_messagerenderer_set_length_limit(self):
+        renderer = MessageRenderer()
+        self.assertEqual(512, renderer.get_length_limit())
+        renderer.set_length_limit(1024)
+        self.assertEqual(1024, renderer.get_length_limit())
+        self.assertRaises(TypeError, renderer.set_length_limit, "wrong")
+
+if __name__ == '__main__':
+    unittest.main()

+ 178 - 0
src/lib/dns/python/tests/name_python_test.py

@@ -0,0 +1,178 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# Tests for the messagerenderer part of the libdns_python module
+#
+
+import unittest
+import os
+from libdns_python import *
+
+class NameComparisonTest(unittest.TestCase):
+    def setUp(self):
+        self.name1 = Name("aaaa.example.com")
+        self.name2 = Name("bbbb.example.com")
+        self.name3 = Name("cccc.example.com")
+        self.name4 = Name("aaaa.example.com")
+        self.name5 = Name("something.completely.different")
+
+        self.ncr12 = self.name1.compare(self.name2)
+        self.ncr13 = self.name1.compare(self.name3)
+        self.ncr23 = self.name2.compare(self.name3)
+        self.ncr21 = self.name2.compare(self.name1)
+        self.ncr32 = self.name3.compare(self.name2)
+        self.ncr31 = self.name3.compare(self.name1)
+        self.ncr14 = self.name1.compare(self.name4)
+        self.ncr15 = self.name1.compare(self.name5)
+
+    def test_init(self):
+        self.assertRaises(NotImplementedError, NameComparisonResult)
+        
+    def test_get_order(self):
+        self.assertEqual(-1, self.ncr12.get_order())
+        self.assertEqual(-2, self.ncr13.get_order())
+        self.assertEqual(-1, self.ncr23.get_order())
+        self.assertEqual(1, self.ncr21.get_order())
+        self.assertEqual(1, self.ncr32.get_order())
+        self.assertEqual(2, self.ncr31.get_order())
+        self.assertEqual(0, self.ncr14.get_order())
+
+    def test_get_common_labels(self):
+        self.assertEqual(3, self.ncr12.get_common_labels())
+        self.assertEqual(1, self.ncr15.get_common_labels())
+
+    def test_get_relation(self):
+        self.assertEqual("COMMONANCESTOR", NameRelation[self.ncr12.get_relation()])
+        self.assertEqual("COMMONANCESTOR", NameRelation[self.ncr15.get_relation()])
+
+class NameTest(unittest.TestCase):
+    def setUp(self):
+        self.name1 = Name("example.com")
+        self.name2 = Name(".")
+        self.name3 = Name("something.completely.different")
+        self.name4 = Name("EXAMPLE.com")
+        self.name5 = Name("*.example.com")
+
+    def test_init(self):
+        self.assertRaises(EmptyLabel, Name, "example..com")
+        self.assertRaises(TooLongLabel, Name, "a"*64 + ".example.com")
+        self.assertRaises(BadLabelType, Name, "\[asdf.example.com")
+        self.assertRaises(BadEscape, Name, "\\999")
+        self.assertRaises(TooLongName, Name, "example."*32 + "com")
+        self.assertRaises(IncompleteName, Name, "\\")
+        self.assertRaises(TypeError, Name, 1)
+
+        b = bytearray()
+        self.name1.to_wire(b)
+        self.assertEqual(self.name1, Name(b))
+        self.assertEqual(self.name1, Name(b, 0))
+        self.assertRaises(InvalidBufferPosition, Name, b, 100)
+        b = bytearray()
+        b += b'\x07example'*32 + b'\x03com\x00'
+        # no TooLong for from wire?
+        self.assertRaises(DNSMessageFORMERR, Name, b, 0)
+
+    def test_at(self):
+        self.assertEqual(7, self.name1.at(0))
+        self.assertEqual(101, self.name1.at(1))
+        self.assertRaises(IndexError, self.name1.at, 100)
+        self.assertRaises(TypeError, self.name1.at, "wrong")
+
+    def test_get_length(self):
+        self.assertEqual(13, self.name1.get_length())
+        self.assertEqual(1, self.name2.get_length())
+        self.assertEqual(32, self.name3.get_length())
+
+    def test_get_labelcount(self):
+        self.assertEqual(3, self.name1.get_labelcount())
+        self.assertEqual(1, self.name2.get_labelcount())
+        self.assertEqual(4, self.name3.get_labelcount())
+
+    def test_to_text(self):
+        self.assertEqual("example.com.", self.name1.to_text())
+        self.assertEqual(".", self.name2.to_text())
+        self.assertEqual("something.completely.different.", self.name3.to_text())
+
+    def test_to_wire(self):
+        b1 = bytearray()
+        self.name1.to_wire(b1)
+        self.assertEqual(bytearray(b'\x07example\x03com\x00'), b1)
+        b2 = bytearray()
+        self.name2.to_wire(b2)
+        self.assertEqual(bytearray(b'\x00'), b2)
+
+        mr = MessageRenderer()
+        self.name1.to_wire(mr)
+        self.assertEqual(b'\x07example\x03com\x00', mr.get_data())
+
+        self.assertRaises(TypeError, self.name1.to_wire, "wrong")
+        self.assertRaises(TypeError, self.name1.to_wire, 1)
+
+    def test_compare(self):
+        # tested in comparison class above
+        pass
+
+    def test_equals(self):
+        self.assertFalse(self.name1.equals(self.name2))
+        self.assertFalse(self.name1.equals(self.name3))
+        self.assertTrue(self.name1.equals(self.name4))
+        #TODO: == not yet defined
+        #self.assertEqual(self.name1, self.name2)
+
+    def test_split(self):
+        s = self.name1.split(1,1)
+        self.assertEqual("com.", s.to_text())
+        s = self.name1.split(0,1)
+        self.assertEqual("example.", s.to_text())
+        s = self.name3.split(1,2)
+        self.assertEqual("completely.different.", s.to_text())
+        self.assertRaises(TypeError, self.name1.split, "wrong", 1)
+        self.assertRaises(TypeError, self.name1.split, 1, "wrong")
+        # TODO: this test will fail when new split(int) is added
+        self.assertRaises(TypeError, self.name1.split, 1)
+        self.assertRaises(IndexError, self.name1.split, 123, 1)
+        self.assertRaises(IndexError, self.name1.split, 1, 123)
+
+    def test_reverse(self):
+        self.assertEqual("com.example.", self.name1.reverse().to_text())
+        self.assertEqual(".", self.name2.reverse().to_text())
+
+    def test_concatenate(self):
+        self.assertEqual("example.com.", self.name1.concatenate(self.name2).to_text())
+        self.assertEqual("example.com.example.com.", self.name1.concatenate(self.name1).to_text())
+        self.assertRaises(TypeError, self.name1.concatenate, "wrong")
+        self.assertRaises(TooLongName, self.name1.concatenate, Name("example."*31))
+        
+
+    def test_downcase(self):
+        self.assertEqual("EXAMPLE.com.", self.name4.to_text())
+        self.name4.downcase()
+        self.assertEqual("example.com.", self.name4.to_text())
+
+    def test_is_wildcard(self):
+        self.assertFalse(self.name1.is_wildcard())
+        self.assertTrue(self.name5.is_wildcard())
+
+    def test_richcmp(self):
+        self.assertTrue(self.name1 > self.name2)
+        self.assertFalse(self.name1 < self.name2)
+        self.assertTrue(self.name1 == self.name4)
+        self.assertTrue(self.name1 <= self.name4)
+        self.assertTrue(self.name1 >= self.name4)
+        self.assertFalse(self.name1 <= self.name2)
+
+if __name__ == '__main__':
+    unittest.main()

+ 6 - 0
src/lib/dns/python/tests/question_python_test.py

@@ -51,6 +51,10 @@ class QuestionTest(unittest.TestCase):
         self.test_question1 = Question(self.example_name1, RRClass("IN"), RRType("NS"))
         self.test_question2 = Question(self.example_name2, RRClass("CH"), RRType("A"))
 
+    def test_init(self):
+        self.assertRaises(TypeError, Question, "wrong")
+
+    # tests below based on cpp unit tests
     def test_QuestionTest_fromWire(self):
         
         q = question_from_wire("question_fromWire")
@@ -79,6 +83,7 @@ class QuestionTest(unittest.TestCase):
     def test_QuestionTest_to_text(self):
     
         self.assertEqual("foo.example.com. IN NS\n", self.test_question1.to_text())
+        self.assertEqual("foo.example.com. IN NS\n", self.test_question1.__str__())
         self.assertEqual("bar.example.com. CH A\n", self.test_question2.to_text())
     
     
@@ -96,6 +101,7 @@ class QuestionTest(unittest.TestCase):
         self.test_question2.to_wire(renderer)
         wiredata = read_wire_data("question_toWire2")
         self.assertEqual(renderer.get_data(), wiredata)
+        self.assertRaises(TypeError, self.test_question1.to_wire, 1)
     
 
 if __name__ == '__main__':

+ 61 - 0
src/lib/dns/python/tests/rdata_python_test.py

@@ -0,0 +1,61 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# Tests for the rdata part of the libdns_python module
+#
+
+import unittest
+import os
+from libdns_python import *
+
+class RdataTest(unittest.TestCase):
+    def setUp(self):
+        c = RRClass("IN")
+        t = RRType("A")
+        self.rdata1 = Rdata(t, c, "192.0.2.98")
+        self.rdata2 = Rdata(t, c, "192.0.2.99")
+        t = RRType("TXT")
+        self.rdata3 = Rdata(t, c, "asdfasdfasdf")
+        self.rdata4 = Rdata(t, c, "foo")
+        
+    def test_init(self):
+        self.assertRaises(TypeError, Rdata, "wrong", RRClass("IN"), "192.0.2.99")
+        self.assertRaises(TypeError, Rdata, RRType("A"), "wrong", "192.0.2.99")
+        self.assertRaises(TypeError, Rdata, RRType("A"), RRClass("IN"), 1)
+
+    def test_rdata_to_wire(self):
+        b = bytearray()
+        self.rdata1.to_wire(b)
+        self.assertEqual(b'\xc0\x00\x02b', b)
+        b = bytearray()
+        self.rdata2.to_wire(b)
+        self.assertEqual(b'\xc0\x00\x02c', b)
+        b = bytearray()
+        self.rdata3.to_wire(b)
+        self.assertEqual(b'\x0casdfasdfasdf', b)
+        b = bytearray()
+        self.rdata4.to_wire(b)
+        self.assertEqual(b'\x03foo', b)
+        self.assertRaises(TypeError, self.rdata1.to_wire, 1)
+
+    def test_rdata_to_text(self):
+        self.assertEqual("192.0.2.98", self.rdata1.to_text())
+        self.assertEqual("192.0.2.99", self.rdata2.to_text())
+        self.assertEqual("\"asdfasdfasdf\"", self.rdata3.to_text())
+        self.assertEqual("\"foo\"", self.rdata4.to_text())
+
+if __name__ == '__main__':
+    unittest.main()

+ 68 - 0
src/lib/dns/python/tests/rrclass_python_test.py

@@ -0,0 +1,68 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# Tests for the rrclass part of the libdns_python module
+#
+
+import unittest
+import os
+from libdns_python import *
+
+class RRClassTest(unittest.TestCase):
+    def setUp(self):
+        self.c1 = RRClass("IN")
+        self.c2 = RRClass("CH")
+
+    def test_init(self):
+        self.assertRaises(InvalidRRClass, RRClass, "wrong")
+        self.assertRaises(TypeError, RRClass, Exception())
+        b = bytearray(1)
+        b[0] = 123
+        self.assertRaises(TypeError, RRClass, b)
+        self.assertRaises(InvalidRRClass, RRClass, 65536)
+        self.assertEqual(self.c1, RRClass(1))
+        b = bytearray()
+        self.c1.to_wire(b)
+        self.assertEqual(self.c1, RRClass(b))
+        
+    def test_rrclass_to_text(self):
+        self.assertEqual("IN", self.c1.to_text())
+        self.assertEqual("IN", self.c1.__str__())
+        self.assertEqual("CH", self.c2.to_text())
+
+    def test_rrclass_to_wire(self):
+        b = bytearray()
+        self.c1.to_wire(b)
+        self.assertEqual(b'\x00\x01', b)
+        b = bytearray()
+        self.c2.to_wire(b)
+        self.assertEqual(b'\x00\x03', b)
+
+        mr = MessageRenderer()
+        self.c1.to_wire(mr)
+        self.assertEqual(b'\x00\x01', mr.get_data())
+
+        self.assertRaises(TypeError, self.c1.to_wire, "wrong")
+
+    def test_richcmp(self):
+        self.assertTrue(self.c1 != self.c2)
+        self.assertTrue(self.c1 < self.c2)
+        self.assertTrue(self.c1 <= self.c2)
+        self.assertFalse(self.c1 > self.c2)
+        self.assertFalse(self.c1 >= self.c2)
+
+if __name__ == '__main__':
+    unittest.main()

+ 50 - 0
src/lib/dns/python/tests/rrttl_python_test.py

@@ -0,0 +1,50 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# Tests for the rrttl part of the libdns_python module
+#
+
+import unittest
+import os
+from libdns_python import *
+
+class RdataTest(unittest.TestCase):
+    def setUp(self):
+        self.t1 = RRTTL(1)
+        self.t2 = RRTTL(3600)
+        
+    def test_init(self):
+        self.assertRaises(InvalidRRTTL, RRTTL, "wrong")
+        self.assertRaises(TypeError, RRTTL, Exception())
+        b = bytearray(1)
+        b[0] = 123
+        self.assertRaises(TypeError, RRTTL, b)
+        self.assertRaises(InvalidRRTTL, RRTTL, "4294967296")
+        
+    def test_rdata_to_text(self):
+        self.assertEqual("1", self.t1.to_text())
+        self.assertEqual("3600", self.t2.to_text())
+
+    def test_rdata_to_wire(self):
+        b = bytearray()
+        self.t1.to_wire(b)
+        self.assertEqual(b'\x00\x00\x00\x01', b)
+        b = bytearray()
+        self.t2.to_wire(b)
+        self.assertEqual(b'\x00\x00\x0e\x10', b)
+
+if __name__ == '__main__':
+    unittest.main()