Browse Source

add Message class constants and message specific exceptions to the Message class instead of the module

git-svn-id: svn://bind10.isc.org/svn/bind10/experiments/python-binding@2014 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 15 years ago
parent
commit
8b71838e3f

+ 1 - 1
src/bin/bind10/run_bind10.sh.in

@@ -23,7 +23,7 @@ BIND10_PATH=@abs_top_builddir@/src/bin/bind10
 PATH=@abs_top_builddir@/src/bin/msgq:@abs_top_builddir@/src/bin/auth:@abs_top_builddir@/src/bin/cfgmgr:@abs_top_builddir@/src/bin/cmdctl:@abs_top_builddir@/src/bin/xfrin:@abs_top_builddir@/src/bin/xfrout:$PATH
 export PATH
 
-PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/python/.libs
+PYTHONPATH=@abs_top_builddir@/src/lib/python:@abs_top_builddir@/src/lib/dns/python/.libs:@abs_top_builddir@/src/lib/xfr/.libs
 export PYTHONPATH
 
 B10_FROM_SOURCE=@abs_top_srcdir@

+ 2 - 2
src/bin/xfrin/xfrin.py.in

@@ -95,7 +95,7 @@ class XfrinConnection(asyncore.dispatcher):
     def _create_query(self, query_type):
         '''Create dns query message. '''
 
-        msg = Message(RENDER)
+        msg = Message(Message.RENDER)
         query_id = random.randint(1, 0xFFFF)
         self._query_id = query_id
         msg.set_qid(query_id)
@@ -238,7 +238,7 @@ class XfrinConnection(asyncore.dispatcher):
             data_len = self._get_request_response(2)
             msg_len = socket.htons(struct.unpack('H', data_len)[0])
             recvdata = self._get_request_response(msg_len)
-            msg = Message(PARSE)
+            msg = Message(Message.PARSE)
             msg.from_wire(recvdata)
             self._check_response_status(msg)
             

+ 6 - 5
src/bin/xfrout/xfrout.py.in

@@ -52,6 +52,7 @@ verbose_mode = False
 
 
 class XfroutException(Exception): pass
+class TmpException(Exception): pass
 
 class XfroutSession(BaseRequestHandler):
     def handle(self):
@@ -65,7 +66,7 @@ class XfroutSession(BaseRequestHandler):
         sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
         try:
             self.dns_xfrout_start(sock, msgdata)
-        except Exception as e:
+        except TmpException as e:
             if verbose_mode:
                 self.log_msg(str(e))
 
@@ -78,9 +79,9 @@ class XfroutSession(BaseRequestHandler):
         ''' parse query message to [socket,message]'''
         #TODO, need to add parseHeader() in case the message header is invalid 
         try:
-            msg = Message(PARSE)
+            msg = Message(Message.PARSE)
             msg.from_wire(mdata)
-        except Exception as err:
+        except TmpException as err:
             if verbose_mode:
                 self.log_msg(str(err))
             return Rcode.FORMERR(), None
@@ -179,7 +180,7 @@ class XfroutSession(BaseRequestHandler):
 
             if verbose_mode:
                 self.log_msg("transfer of '%s/IN': AXFR end" % zone_name)
-        except Exception as err:
+        except TmpException as err:
             if verbose_mode:
                 sys.stderr.write(str(err))
 
@@ -192,7 +193,7 @@ class XfroutSession(BaseRequestHandler):
         opcode = msg.get_opcode()
         rcode = msg.get_rcode()
         
-        msg.clear(RENDER)
+        msg.clear(Message.RENDER)
         msg.set_qid(qid)
         msg.set_opcode(opcode)
         msg.set_rcode(rcode)

+ 4 - 0
src/lib/dns/python/libdns_python_common.cc

@@ -46,3 +46,7 @@ readDataFromSequence(uint8_t *data, size_t len, PyObject* sequence)
 }
 
 
+void addClassVariable(PyTypeObject& c, const char* name, PyObject* obj)
+{
+    PyDict_SetItemString(c.tp_dict, name, obj);
+}

+ 2 - 0
src/lib/dns/python/libdns_python_common.h

@@ -30,3 +30,5 @@
 // head of the sequence (even if it fails it removes everything that
 // it successfully read)
 int readDataFromSequence(uint8_t *data, size_t len, PyObject* sequence);
+
+void addClassVariable(PyTypeObject& c, const char* name, PyObject* obj);

+ 30 - 41
src/lib/dns/python/message_python.cc

@@ -29,13 +29,6 @@ static PyObject* po_InvalidMessageUDPSize;
 static PyObject* po_DNSMessageBADVERS;
 
 //
-// Constants
-//
-static PyObject* po_MessagePARSE;
-static PyObject* po_MessageRENDER;
-static PyObject* po_MessageDefaultMaxUDPSize;
-
-//
 // Definition of the classes
 //
 
@@ -1621,8 +1614,13 @@ Message_setQid(s_Message* self, PyObject* args)
     if (!PyArg_ParseTuple(args, "I", &id)) {
         return NULL;
     }
-    self->message->setQid(id);
-    Py_RETURN_NONE;
+    try {
+        self->message->setQid(id);
+        Py_RETURN_NONE;
+    } catch (InvalidMessageOperation imo) {
+        PyErr_SetString(po_InvalidMessageOperation, imo.what());
+        return NULL;
+    }
 }
 
 static PyObject*
@@ -1874,41 +1872,11 @@ Message_fromWire(s_Message* self, PyObject* args)
     }
 }
 
-// end of Message
-
-
 // Module Initialization, all statics are initialized here
 bool
 initModulePart_Message(PyObject* mod)
 {
-    // Add the exceptions to the module
-    po_MessageTooShort = PyErr_NewException("libdns_python.MessageTooShort", NULL, NULL);
-    Py_INCREF(po_MessageTooShort);
-    PyModule_AddObject(mod, "MessageTooShort", po_MessageTooShort);
-    po_InvalidMessageSection = PyErr_NewException("libdns_python.InvalidMessageSection", NULL, NULL);
-    Py_INCREF(po_InvalidMessageSection);
-    PyModule_AddObject(mod, "InvalidMessageSection", po_InvalidMessageSection);
-    po_InvalidMessageOperation = PyErr_NewException("libdns_python.InvalidMessageOperation", NULL, NULL);
-    Py_INCREF(po_InvalidMessageOperation);
-    PyModule_AddObject(mod, "InvalidMessageOperation", po_InvalidMessageOperation);
-    po_InvalidMessageUDPSize = PyErr_NewException("libdns_python.InvalidMessageUDPSize", NULL, NULL);
-    Py_INCREF(po_InvalidMessageUDPSize);
-    PyModule_AddObject(mod, "InvalidMessageUDPSize", po_InvalidMessageUDPSize);
-    po_DNSMessageBADVERS = PyErr_NewException("libdns_python.DNSMessageBADVERS", NULL, NULL);
-    Py_INCREF(po_DNSMessageBADVERS);
-    PyModule_AddObject(mod, "DNSMessageBADVERS", po_DNSMessageBADVERS);
-
-    // Constants. These should probably go into the Message class, but need to find out how first
-    po_MessagePARSE = Py_BuildValue("I", Message::PARSE);
-    Py_INCREF(po_MessagePARSE);
-    PyModule_AddObject(mod, "PARSE", po_MessagePARSE);
-    po_MessageRENDER = Py_BuildValue("I", Message::RENDER);
-    Py_INCREF(po_MessageRENDER);
-    PyModule_AddObject(mod, "RENDER", po_MessageRENDER);
-    po_MessageDefaultMaxUDPSize = Py_BuildValue("I", Message::DEFAULT_MAX_UDPSIZE);
-    Py_INCREF(po_MessageDefaultMaxUDPSize);
-    PyModule_AddObject(mod, "DEFAULT_MAX_UDPSIZE", po_MessageDefaultMaxUDPSize);
-
+    
     /* add methods to class */
     if (PyType_Ready(&messageflag_type) < 0) {
         return false;
@@ -1943,10 +1911,31 @@ initModulePart_Message(PyObject* mod)
     if (PyType_Ready(&message_type) < 0) {
         return false;
     }
+    
+    /* Class variables
+     * These are added to the tp_dict of the type object
+     */
+    //PyDict_SetItemString(message_type.tp_dict, "PARSE", Py_BuildValue("I", Message::PARSE));
+    addClassVariable(message_type, "PARSE", Py_BuildValue("I", Message::PARSE));
+    addClassVariable(message_type, "RENDER", Py_BuildValue("I", Message::RENDER));
+    addClassVariable(message_type, "DEFAULT_MAX_UDPSIZE", Py_BuildValue("I", Message::DEFAULT_MAX_UDPSIZE));
+
+    /* Class-specific exceptions */
+    po_MessageTooShort = PyErr_NewException("libdns_python.Message.MessageTooShort", NULL, NULL);
+    addClassVariable(message_type, "MessageTooShort", po_MessageTooShort);
+    po_InvalidMessageSection = PyErr_NewException("libdns_python.Message.InvalidMessageSection", NULL, NULL);
+    addClassVariable(message_type, "InvalidMessageSection", po_InvalidMessageSection);
+    po_InvalidMessageOperation = PyErr_NewException("libdns_python.Message.InvalidMessageOperation", NULL, NULL);
+    addClassVariable(message_type, "InvalidMessageOperation", po_InvalidMessageOperation);
+    po_InvalidMessageUDPSize = PyErr_NewException("libdns_python.Message.InvalidMessageUDPSize", NULL, NULL);
+    addClassVariable(message_type, "InvalidMessageUDPSize", po_InvalidMessageUDPSize);
+    po_DNSMessageBADVERS = PyErr_NewException("libdns_python.Message.DNSMessageBADVERS", NULL, NULL);
+    addClassVariable(message_type, "DNSMessageBADVERS", po_DNSMessageBADVERS);
+
     Py_INCREF(&message_type);
     PyModule_AddObject(mod, "Message",
                        (PyObject*) &message_type);
-    
+
 
     return true;
 }

+ 30 - 30
src/lib/dns/python/tests/message_python_test.py

@@ -205,8 +205,8 @@ class SectionTest(unittest.TestCase):
 class MessageTest(unittest.TestCase):
 
     def setUp(self):
-        self.p = Message(PARSE)
-        self.r = Message(RENDER)
+        self.p = Message(Message.PARSE)
+        self.r = Message(Message.RENDER)
         
     def test_init(self):
         self.assertRaises(TypeError, Message, 3)
@@ -254,8 +254,8 @@ class MessageTest(unittest.TestCase):
         self.assertRaises(TypeError, self.r.add_rrset, "wrong")
 
     def test_clear(self):
-        self.assertEqual(None, self.r.clear(PARSE))
-        self.assertEqual(None, self.r.clear(RENDER))
+        self.assertEqual(None, self.r.clear(Message.PARSE))
+        self.assertEqual(None, self.r.clear(Message.RENDER))
         self.assertRaises(TypeError, self.r.clear, "wrong")
         self.assertRaises(TypeError, self.r.clear, 3)
 
@@ -343,25 +343,25 @@ class ConvertedUnittests(unittest.TestCase):
         self.assertEqual(2, len(rdata))
     
     def test_GetEDNS0DOBit(self):
-        message_parse = Message(PARSE)
+        message_parse = Message(Message.PARSE)
         ## Without EDNS0, DNSSEC is considered to be unsupported.
         factoryFromFile(message_parse, "message_fromWire1")
         self.assertFalse(message_parse.is_dnssec_supported())
     
         ## If DO bit is on, DNSSEC is considered to be supported.
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire2")
         self.assertTrue(message_parse.is_dnssec_supported())
     
         ## If DO bit is off, DNSSEC is considered to be unsupported.
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire3")
         self.assertFalse(message_parse.is_dnssec_supported())
     
     def test_SetEDNS0DOBit(self):
         # By default, it's false, and we can enable/disable it.
-        message_parse = Message(PARSE)
-        message_render = Message(RENDER)
+        message_parse = Message(Message.PARSE)
+        message_render = Message(Message.RENDER)
         self.assertFalse(message_render.is_dnssec_supported())
         message_render.set_dnssec_supported(True)
         self.assertTrue(message_render.is_dnssec_supported())
@@ -369,7 +369,7 @@ class ConvertedUnittests(unittest.TestCase):
         self.assertFalse(message_render.is_dnssec_supported())
     
         ## A message in the parse mode doesn't allow this flag to be set.
-        self.assertRaises(InvalidMessageOperation,
+        self.assertRaises(Message.InvalidMessageOperation,
                           message_parse.set_dnssec_supported,
                           True)
         ## Once converted to the render mode, it works as above
@@ -382,25 +382,25 @@ class ConvertedUnittests(unittest.TestCase):
     
     def test_GetEDNS0UDPSize(self):
         # Without EDNS0, the default max UDP size is used.
-        message_parse = Message(PARSE)
+        message_parse = Message(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire1")
-        self.assertEqual(DEFAULT_MAX_UDPSIZE, message_parse.get_udp_size())
+        self.assertEqual(Message.DEFAULT_MAX_UDPSIZE, message_parse.get_udp_size())
     
         ## If the size specified in EDNS0 > default max, use it.
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire2")
         self.assertEqual(4096, message_parse.get_udp_size())
     
         ## If the size specified in EDNS0 < default max, keep using the default.
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire8")
-        self.assertEqual(DEFAULT_MAX_UDPSIZE, message_parse.get_udp_size())
+        self.assertEqual(Message.DEFAULT_MAX_UDPSIZE, message_parse.get_udp_size())
     
     def test_SetEDNS0UDPSize(self):
         # The default size if unspecified
-        message_render = Message(RENDER)
-        message_parse = Message(PARSE)
-        self.assertEqual(DEFAULT_MAX_UDPSIZE, message_render.get_udp_size())
+        message_render = Message(Message.RENDER)
+        message_parse = Message(Message.PARSE)
+        self.assertEqual(Message.DEFAULT_MAX_UDPSIZE, message_render.get_udp_size())
         # A common buffer size with EDNS, should succeed
         message_render.set_udp_size(4096)
         self.assertEqual(4096, message_render.get_udp_size())
@@ -408,31 +408,31 @@ class ConvertedUnittests(unittest.TestCase):
         message_render.set_udp_size(0xffff)
         self.assertEqual(0xffff, message_render.get_udp_size())
         # Too small is value is rejected
-        self.assertRaises(InvalidMessageUDPSize, message_render.set_udp_size, 511)
+        self.assertRaises(Message.InvalidMessageUDPSize, message_render.set_udp_size, 511)
     
         # A message in the parse mode doesn't allow the set operation.
-        self.assertRaises(InvalidMessageOperation, message_parse.set_udp_size, 4096)
+        self.assertRaises(Message.InvalidMessageOperation, message_parse.set_udp_size, 4096)
         ## Once converted to the render mode, it works as above.
         message_parse.make_response()
         message_parse.set_udp_size(4096)
         self.assertEqual(4096, message_parse.get_udp_size())
         message_parse.set_udp_size(0xffff)
         self.assertEqual(0xffff, message_parse.get_udp_size())
-        self.assertRaises(InvalidMessageUDPSize, message_parse.set_udp_size, 511)
+        self.assertRaises(Message.InvalidMessageUDPSize, message_parse.set_udp_size, 511)
     
     def test_EDNS0ExtCode(self):
         # Extended Rcode = BADVERS
-        message_parse = Message(PARSE)
+        message_parse = Message(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire10")
         self.assertEqual(Rcode.BADVERS(), message_parse.get_rcode())
     
         # Maximum extended Rcode
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire11")
         self.assertEqual(0xfff, message_parse.get_rcode().get_code())
     
     def test_BadEDNS0(self):
-        message_parse = Message(PARSE)
+        message_parse = Message(Message.PARSE)
         # OPT RR in the answer section
         self.assertRaises(DNSMessageFORMERR,
                           factoryFromFile,
@@ -440,14 +440,14 @@ class ConvertedUnittests(unittest.TestCase):
                           "message_fromWire4")
 
         # multiple OPT RRs (in the additional section)
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         self.assertRaises(DNSMessageFORMERR,
                           factoryFromFile,
                           message_parse,
                           "message_fromWire5")
 
         ## OPT RR of a non root name
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         self.assertRaises(DNSMessageFORMERR,
                           factoryFromFile,
                           message_parse,
@@ -456,18 +456,18 @@ class ConvertedUnittests(unittest.TestCase):
         # Compressed owner name of OPT RR points to a root name.
         # Not necessarily bogus, but very unusual and mostly pathological.
         # We accept it, but is it okay?
-        message_parse.clear(PARSE)
+        message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire7")
 
         # Unsupported Version
-        message_parse.clear(PARSE)
-        self.assertRaises(DNSMessageBADVERS,
+        message_parse.clear(Message.PARSE)
+        self.assertRaises(Message.DNSMessageBADVERS,
                           factoryFromFile,
                           message_parse,
                           "message_fromWire9")
     
     def test_to_text_and_wire(self):
-        message_render = Message(RENDER)
+        message_render = Message(Message.RENDER)
         message_render.set_qid(0x1035)
         message_render.set_opcode(Opcode.QUERY())
         message_render.set_rcode(Rcode.NOERROR())

+ 3 - 3
src/lib/dns/python/tests/messagerenderer_python_test.py

@@ -29,13 +29,13 @@ class MessageRendererTest(unittest.TestCase):
         t = RRType("A")
         ttl = RRTTL("3600")
         
-        message = Message(RENDER)
+        message = Message(Message.RENDER)
         message.set_qid(123)
         message.set_opcode(Opcode.QUERY())
         message.add_question(Question(name, c, t))
 
         self.message1 = message
-        message = Message(RENDER)
+        message = Message(Message.RENDER)
         message.set_qid(123)
         message.set_header_flag(MessageFlag.AA())
         message.set_header_flag(MessageFlag.QR())
@@ -48,7 +48,7 @@ class MessageRendererTest(unittest.TestCase):
         message.add_rrset(Section.AUTHORITY(), rrset)
         self.message2 = message
 
-        #message = Message(RENDER)
+        #message = Message(Message.RENDER)
         #message.set_qid(123)
         #message.set_header_flag(MessageFlag.AA())
         #message.set_header_flag(MessageFlag.QR())