Browse Source

unit tests

git-svn-id: svn://bind10.isc.org/svn/bind10/experiments/python-binding@2088 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 15 years ago
parent
commit
384309beb2

+ 2 - 6
src/lib/dns/python/message_python.cc

@@ -1095,7 +1095,7 @@ Rcode_BADVERS(s_Rcode* self UNUSED_PARAM)
 static PyObject* 
 Rcode_richcmp(s_Rcode* self, s_Rcode* other, int op)
 {
-    bool c;
+    bool c = false;
 
     // Check for null and if the types match. If different type,
     // simply return False
@@ -1129,8 +1129,6 @@ Rcode_richcmp(s_Rcode* self, s_Rcode* other, int op)
         PyErr_SetString(PyExc_TypeError, "Unorderable type; Rcode");
         return NULL;
         break;
-    default:
-        assert(0);              // XXX: should trigger an exception
     }
     if (c)
         Py_RETURN_TRUE;
@@ -1310,7 +1308,7 @@ Section_ADDITIONAL(s_Section* self UNUSED_PARAM)
 static PyObject* 
 Section_richcmp(s_Section* self, s_Section* other, int op)
 {
-    bool c;
+    bool c = false;
 
     // Check for null and if the types match. If different type,
     // simply return False
@@ -1344,8 +1342,6 @@ Section_richcmp(s_Section* self, s_Section* other, int op)
         PyErr_SetString(PyExc_TypeError, "Unorderable type; Section");
         return NULL;
         break;
-    default:
-        assert(0);              // XXX: should trigger an exception
     }
     if (c)
         Py_RETURN_TRUE;

+ 6 - 7
src/lib/dns/python/rrttl_python.cc

@@ -164,12 +164,13 @@ RRTTL_init(s_RRTTL* self, PyObject* args)
             self->rrttl = new RRTTL(i);
             return 0;
         } else if (PyArg_ParseTuple(args, "O", &bytes) && PySequence_Check(bytes)) {
-            uint8_t data[2];
-            int result = readDataFromSequence(data, 2, bytes);
+            Py_ssize_t size = PySequence_Size(bytes);
+            uint8_t data[size];
+            int result = readDataFromSequence(data, size, bytes);
             if (result != 0) {
                 return result;
             }
-            InputBuffer ib(data, 2);
+            InputBuffer ib(data, size);
             self->rrttl = new RRTTL(ib);
             PyErr_Clear();
             return 0;
@@ -180,7 +181,7 @@ RRTTL_init(s_RRTTL* self, PyObject* args)
         // First clear any existing error that was set
         PyErr_Clear();
         // Now set our own exception
-        PyErr_SetString(po_InvalidRRTTL, icc.what());
+        PyErr_SetString(po_IncompleteRRTTL, icc.what());
         // And return negative
         return -1;
     } catch (InvalidRRTTL ic) {
@@ -255,7 +256,7 @@ RRTTL_getValue(s_RRTTL* self)
 static PyObject* 
 RRTTL_richcmp(s_RRTTL* self, s_RRTTL* other, int op)
 {
-    bool c;
+    bool c = false;
 
     // Check for null and if the types match. If different type,
     // simply return False
@@ -286,8 +287,6 @@ RRTTL_richcmp(s_RRTTL* self, s_RRTTL* other, int op)
         c = *other->rrttl < *self->rrttl ||
             *self->rrttl == *other->rrttl;
         break;
-    default:
-        assert(0);              // XXX: should trigger an exception
     }
     if (c)
         Py_RETURN_TRUE;

+ 5 - 4
src/lib/dns/python/rrtype_python.cc

@@ -167,12 +167,13 @@ RRType_init(s_RRType* self, PyObject* args)
             self->rrtype = new RRType(i);
             return 0;
         } else if (PyArg_ParseTuple(args, "O", &bytes) && PySequence_Check(bytes)) {
-            uint8_t data[2];
-            int result = readDataFromSequence(data, 2, bytes);
+            Py_ssize_t size = PySequence_Size(bytes);
+            uint8_t data[size];
+            int result = readDataFromSequence(data, size, bytes);
             if (result != 0) {
                 return result;
             }
-            InputBuffer ib(data, 2);
+            InputBuffer ib(data, size);
             self->rrtype = new RRType(ib);
             PyErr_Clear();
             return 0;
@@ -183,7 +184,7 @@ RRType_init(s_RRType* self, PyObject* args)
         // First clear any existing error that was set
         PyErr_Clear();
         // Now set our own exception
-        PyErr_SetString(po_InvalidRRType, icc.what());
+        PyErr_SetString(po_IncompleteRRType, icc.what());
         // And return negative
         return -1;
     } catch (InvalidRRType ic) {

+ 25 - 0
src/lib/dns/python/tests/message_python_test.py

@@ -82,6 +82,8 @@ class OpcodeTest(unittest.TestCase):
         o3 = Opcode.NOTIFY()
         self.assertTrue(o2 == o3)
         self.assertTrue(o1 != o2)
+        self.assertFalse(o1 == 1)
+        self.assertFalse(o1 == o2)
         # can't use assertRaises here...
         try:
             o1 < o2
@@ -148,6 +150,8 @@ class RcodeTest(unittest.TestCase):
         r3 = Rcode.FORMERR()
         self.assertTrue(r2 == r3)
         self.assertTrue(r1 != r2)
+        self.assertFalse(r1 == r2)
+        self.assertFalse(r1 != 1)
         # can't use assertRaises here...
         try:
             r1 < r2
@@ -183,6 +187,8 @@ class SectionTest(unittest.TestCase):
         s3 = Section.ANSWER()
         self.assertTrue(s2 == s3)
         self.assertTrue(s1 != s2)
+        self.assertFalse(s1 == s2)
+        self.assertFalse(s1 == 1)
         # can't use assertRaises here...
         try:
             s1 < s2
@@ -226,6 +232,11 @@ class MessageTest(unittest.TestCase):
         self.r.clear_header_flag(MessageFlag.AA())
         self.assertFalse(self.r.get_header_flag(MessageFlag.AA()))
 
+        self.assertRaises(InvalidMessageOperation,
+                          self.p.set_header_flag, MessageFlag.AA())
+        self.assertRaises(InvalidMessageOperation,
+                          self.p.clear_header_flag, MessageFlag.AA())
+
     def test_set_DNSSEC_supported(self):
         self.assertRaises(TypeError, self.r.set_dnssec_supported, "wrong")
 
@@ -240,6 +251,8 @@ class MessageTest(unittest.TestCase):
 
     def test_set_qid(self):
         self.assertRaises(TypeError, self.r.set_qid, "wrong")
+        self.assertRaises(InvalidMessageOperation,
+                          self.p.set_qid, 123)
 
     def test_set_rcode(self):
         self.assertRaises(TypeError, self.r.set_rcode, "wrong")
@@ -250,6 +263,12 @@ class MessageTest(unittest.TestCase):
     def test_get_section(self):
         self.assertRaises(TypeError, self.r.get_section, "wrong")
 
+    def test_get_rr_count(self):
+        self.assertRaises(TypeError, self.r.get_rr_count, "wrong")
+
+    def test_add_question(self):
+        self.assertRaises(TypeError, self.r.add_question, "wrong", "wrong")
+
     def test_add_rrset(self):
         self.assertRaises(TypeError, self.r.add_rrset, "wrong")
 
@@ -261,9 +280,15 @@ class MessageTest(unittest.TestCase):
 
     def test_to_wire(self):
         self.assertRaises(TypeError, self.r.to_wire, 1)
+        self.assertRaises(InvalidMessageOperation,
+                          self.p.to_wire, MessageRenderer())
 
     def test_from_wire(self):
         self.assertRaises(TypeError, self.r.from_wire, 1)
+        self.assertRaises(InvalidMessageOperation,
+                          Message.from_wire, self.r, bytes())
+        self.assertRaises(MessageTooShort,
+                          Message.from_wire, self.p, bytes())
 
 # helper functions for tests taken from c++ unittests
 if "TESTDATA_PATH" in os.environ:

+ 12 - 0
src/lib/dns/python/tests/rrset_python_test.py

@@ -35,6 +35,9 @@ class TestModuleSpec(unittest.TestCase):
         self.rrset_a.add_rdata(Rdata(RRType("A"), RRClass("IN"), "192.0.2.1"));
         self.rrset_a.add_rdata(Rdata(RRType("A"), RRClass("IN"), "192.0.2.2"));
 
+    def test_init(self):
+        self.assertRaises(TypeError, RRset)
+
     def test_get_rdata_count(self):
         for i in range(0, self.MAX_RDATA_COUNT):
             self.assertEqual(i, self.rrset_a_empty.get_rdata_count())
@@ -63,10 +66,12 @@ class TestModuleSpec(unittest.TestCase):
         self.assertEqual(RRTTL(86400), self.rrset_a.get_ttl());
         self.rrset_a.set_ttl(RRTTL(0));
         self.assertEqual(RRTTL(0), self.rrset_a.get_ttl());
+        self.assertRaises(TypeError, self.rrset_a.set_ttl, 1)
 
     def test_set_name(self):
         self.rrset_a.set_name(self.test_nsname);
         self.assertEqual(self.test_nsname, self.rrset_a.get_name());
+        self.assertRaises(TypeError, self.rrset_a.set_name, 1)
 
     def test_add_rdata(self):
         # no iterator to read out yet (TODO: add addition test once implemented)
@@ -78,6 +83,12 @@ class TestModuleSpec(unittest.TestCase):
         self.assertEqual("test.example.com. 3600 IN A 192.0.2.1\n"
                          "test.example.com. 3600 IN A 192.0.2.2\n",
                          self.rrset_a.to_text());
+        self.assertEqual("test.example.com. 3600 IN A 192.0.2.1\n"
+                         "test.example.com. 3600 IN A 192.0.2.2\n",
+                         self.rrset_a.__str__());
+
+        #rrset_empty = RRset(self.test_name, RRClass("IN"), RRType("A"), RRTTL(3600))
+        self.assertRaises(EmptyRRset, self.rrset_a_empty.to_text)
 
     def test_to_wire_buffer(self):
         exp_buffer = bytearray(b'\x04test\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x01\x04test\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x02')
@@ -86,6 +97,7 @@ class TestModuleSpec(unittest.TestCase):
         self.assertEqual(exp_buffer, buffer)
 
         self.assertRaises(EmptyRRset, self.rrset_a_empty.to_wire, buffer);
+        self.assertRaises(TypeError, self.rrset_a.to_wire, 1)
 
     def test_to_wire_renderer(self):
         exp_buffer = bytearray(b'\x04test\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x01\xc0\x00\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x02')

+ 21 - 1
src/lib/dns/python/tests/rrttl_python_test.py

@@ -31,11 +31,18 @@ class RdataTest(unittest.TestCase):
         self.assertRaises(TypeError, RRTTL, Exception())
         b = bytearray(1)
         b[0] = 123
-        self.assertRaises(TypeError, RRTTL, b)
+        self.assertRaises(IncompleteRRTTL, RRTTL, b)
         self.assertRaises(InvalidRRTTL, RRTTL, "4294967296")
+        b = bytearray(4)
+        b[0] = 0
+        b[1] = 0
+        b[2] = 0
+        b[3] = 15
+        self.assertEqual(15, RRTTL(b).get_value())
         
     def test_rdata_to_text(self):
         self.assertEqual("1", self.t1.to_text())
+        self.assertEqual("1", self.t1.__str__())
         self.assertEqual("3600", self.t2.to_text())
 
     def test_rdata_to_wire(self):
@@ -45,6 +52,19 @@ class RdataTest(unittest.TestCase):
         b = bytearray()
         self.t2.to_wire(b)
         self.assertEqual(b'\x00\x00\x0e\x10', b)
+        mr = MessageRenderer()
+        self.t2.to_wire(mr)
+        self.assertEqual(b'\x00\x00\x0e\x10', mr.get_data())
+        self.assertRaises(TypeError, self.t1.to_wire, 1)
+
+    def test_rdata_richcmp(self):
+        self.assertTrue(self.t1 == RRTTL(1))
+        self.assertFalse(self.t1 != RRTTL(1))
+        self.assertFalse(self.t1 == 1)
+        self.assertTrue(self.t1 < self.t2)
+        self.assertTrue(self.t1 <= self.t2)
+        self.assertFalse(self.t1 > self.t2)
+        self.assertFalse(self.t1 >= self.t2)
 
 if __name__ == '__main__':
     unittest.main()

+ 18 - 0
src/lib/dns/python/tests/rrtype_python_test.py

@@ -30,6 +30,15 @@ class TestModuleSpec(unittest.TestCase):
     rrtype_max = RRType(0xffff);
     wiredata = bytearray(b'\x00\x01\x00\x80\x08\x00\x80\x00\xff\xff');
 
+
+    def test_init(self):
+        self.assertRaises(InvalidRRType, RRType, 65537)
+        b = bytearray(b'\x00\x01')
+        self.assertEqual(RRType("A"), RRType(b))
+        b = bytearray(b'\x01')
+        self.assertRaises(IncompleteRRType, RRType, b)
+        self.assertRaises(TypeError, RRType, Exception)
+    
     def test_from_text(self):
         self.assertEqual("A", RRType("A").to_text())
         self.assertEqual("NS", RRType("NS").to_text());
@@ -54,6 +63,7 @@ class TestModuleSpec(unittest.TestCase):
 
     def test_to_text(self):
         self.assertEqual("A", RRType(1).to_text());
+        self.assertEqual("A", RRType(1).__str__());
         self.assertEqual("TYPE65000", RRType(65000).to_text());
 
     def test_to_wire_buffer(self):
@@ -76,6 +86,9 @@ class TestModuleSpec(unittest.TestCase):
 
         self.assertEqual(self.wiredata, mr.get_data())
 
+    def test_to_wire_bad(self):
+        self.assertRaises(TypeError, self.rrtype_1.to_wire, "wrong")
+
     def test_compare(self):
         self.assertTrue(RRType(1) == RRType("A"));
         #self.assertTrue(RRType(1).equals(RRType("A")));
@@ -83,7 +96,12 @@ class TestModuleSpec(unittest.TestCase):
         #self.assertTrue(RRType(0).nequals(RRType("A")));
     
         self.assertTrue(RRType("A") < RRType("NS"));
+        self.assertTrue(RRType("A") <= RRType("NS"));
         self.assertTrue(RRType(100) < RRType(65535));
+        self.assertFalse(RRType(100) > RRType(65535));
+        self.assertFalse(RRType(100) >= RRType(65535));
+
+        self.assertFalse(self.rrtype_1 == 1)
         
 if __name__ == '__main__':
     unittest.main()