Browse Source

added tests and python binding for get/setEDNS

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac311@2795 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 14 years ago
parent
commit
d1741785dd

+ 14 - 4
src/lib/dns/message.h

@@ -646,9 +646,10 @@ public:
     /// returned.
     const Rcode& getRcode() const;
 
-    /// \brief Return the Response Code of the message.
+    /// \brief Set the Response Code of the message.
     ///
     /// Only allowed in the \c RENDER mode.
+    ///
     /// If the specified code is an EDNS extended RCODE, an EDNS OPT RR will be
     /// included in the message.
     void setRcode(const Rcode& rcode);
@@ -661,12 +662,21 @@ public:
     /// Only allowed in the \c RENDER mode.
     void setOpcode(const Opcode& opcode);
 
-    /// \brief TBD
+    /// \brief Return, if any, the EDNS associated with the message.
+    ///
+    /// This method never throws an exception.
+    ///
+    /// \return A shared pointer to the EDNS.  This will be a null shared
+    /// pointer if the message is not associated with EDNS.
     ConstEDNSPtr getEDNS() const;
 
-    /// \brief TBD
+    /// \brief Set EDNS for the message.
     ///
-    /// Only allowed in the \c RENDER mode.
+    /// Only allowed in the \c RENDER mode; otherwise an exception of class
+    /// \c InvalidMessageOperation will be thrown.
+    ///
+    /// \param edns A shared pointer to an \c EDNS object to be set in
+    /// \c Message.
     void setEDNS(ConstEDNSPtr edns);
 
     /// \brief Returns the number of RRs contained in the given section.

+ 8 - 1
src/lib/dns/python/libdns_python.cc

@@ -44,6 +44,9 @@
 static PyObject* po_IscException;
 static PyObject* po_InvalidParameter;
 
+// For our own isc::dns::Exception
+static PyObject* po_DNSMessageBADVERS;
+
 // order is important here!
 #include <dns/python/messagerenderer_python.cc>
 #include <dns/python/name_python.cc>           // needs Messagerenderer
@@ -54,8 +57,8 @@ static PyObject* po_InvalidParameter;
 #include <dns/python/rrset_python.cc>          // needs Rdata, RRTTL
 #include <dns/python/question_python.cc>       // needs RRClass, RRType, RRTTL,
                                                // Name
-#include <dns/python/message_python.cc>        // needs RRset, Question
 #include <dns/python/edns_python.cc>           // needs Messagerenderer, Rcode
+#include <dns/python/message_python.cc>        // needs RRset, Question, EDNS
 
 //
 // Definition of the module
@@ -92,6 +95,10 @@ PyInit_libdns_python(void) {
                                              NULL, NULL);
     PyModule_AddObject(mod, "InvalidParameter", po_InvalidParameter);
 
+    po_DNSMessageBADVERS = PyErr_NewException(
+        "libdns_python.DNSMessageBADVERS", NULL, NULL);
+    PyModule_AddObject(mod, "DNSMessageBADVERS", po_DNSMessageBADVERS);
+
     // for each part included above, we call its specific initializer
 
     if (!initModulePart_Name(mod)) {

+ 43 - 3
src/lib/dns/python/message_python.cc

@@ -26,7 +26,6 @@ static PyObject* po_MessageTooShort;
 static PyObject* po_InvalidMessageSection;
 static PyObject* po_InvalidMessageOperation;
 static PyObject* po_InvalidMessageUDPSize;
-static PyObject* po_DNSMessageBADVERS;
 
 //
 // Definition of the classes
@@ -987,6 +986,8 @@ static PyObject* Message_getRcode(s_Message* self);
 static PyObject* Message_setRcode(s_Message* self, PyObject* args);
 static PyObject* Message_getOpcode(s_Message* self);
 static PyObject* Message_setOpcode(s_Message* self, PyObject* args);
+static PyObject* Message_getEDNS(s_Message* self);
+static PyObject* Message_setEDNS(s_Message* self, PyObject* args);
 static PyObject* Message_getRRCount(s_Message* self, PyObject* args);
 // use direct iterators for these? (or simply lists for now?)
 static PyObject* Message_getQuestion(s_Message* self);
@@ -1041,6 +1042,12 @@ static PyMethodDef Message_methods[] = {
       "Sets the message opcode (an Opcode object).\n"
       "If the message is not in RENDER mode, an "
       "InvalidMessageOperation is raised."},
+    { "get_edns", reinterpret_cast<PyCFunction>(Message_getEDNS), METH_NOARGS,
+      "Return, if any, the EDNS associated with the message."
+    },
+    { "set_edns", reinterpret_cast<PyCFunction>(Message_setEDNS), METH_VARARGS,
+      "Set EDNS for the message."
+    },
     { "get_rr_count", reinterpret_cast<PyCFunction>(Message_getRRCount), METH_VARARGS,
       "Returns the number of RRs contained in the given section." },
     { "get_question", reinterpret_cast<PyCFunction>(Message_getQuestion), METH_NOARGS,
@@ -1310,6 +1317,41 @@ Message_setOpcode(s_Message* self, PyObject* args) {
 }
 
 static PyObject*
+Message_getEDNS(s_Message* self) {
+    s_EDNS* edns;
+    EDNS* edns_body;
+    ConstEDNSPtr src = self->message->getEDNS();
+
+    if (!src) {
+        Py_RETURN_NONE;
+    }
+    if ((edns_body = new(nothrow) EDNS(*src)) == NULL) {
+        return (PyErr_NoMemory());
+    }
+    edns = static_cast<s_EDNS*>(opcode_type.tp_alloc(&edns_type, 0));
+    if (edns != NULL) {
+        edns->edns = edns_body;
+    }
+
+    return (edns);
+}
+
+static PyObject*
+Message_setEDNS(s_Message* self, PyObject* args) {
+    s_EDNS* edns;
+    if (!PyArg_ParseTuple(args, "O!", &edns_type, &edns)) {
+        return (NULL);
+    }
+    try {
+        self->message->setEDNS(EDNSPtr(new EDNS(*edns->edns)));
+        Py_RETURN_NONE;
+    } catch (const InvalidMessageOperation& imo) {
+        PyErr_SetString(po_InvalidMessageOperation, imo.what());
+        return (NULL);
+    }
+}
+
+static PyObject*
 Message_getRRCount(s_Message* self, PyObject* args) {
     s_Section *section;
     if (!PyArg_ParseTuple(args, "O!", &section_type, &section)) {
@@ -1559,8 +1601,6 @@ initModulePart_Message(PyObject* mod) {
     PyModule_AddObject(mod, "InvalidMessageOperation", po_InvalidMessageOperation);
     po_InvalidMessageUDPSize = PyErr_NewException("libdns_python.InvalidMessageUDPSize", NULL, NULL);
     PyModule_AddObject(mod, "InvalidMessageUDPSize", po_InvalidMessageUDPSize);
-    po_DNSMessageBADVERS = PyErr_NewException("libdns_python.DNSMessageBADVERS", NULL, NULL);
-    PyModule_AddObject(mod, "DNSMessageBADVERS", po_DNSMessageBADVERS);
 
     Py_INCREF(&message_type);
     PyModule_AddObject(mod, "Message",

+ 18 - 0
src/lib/dns/python/tests/message_python_test.py

@@ -331,6 +331,24 @@ class MessageTest(unittest.TestCase):
         self.assertRaises(InvalidMessageOperation,
                           self.p.set_opcode, opcode)
 
+    def test_get_edns(self):
+        self.assertEqual(None, self.p.get_edns())
+
+        message_parse = Message(Message.PARSE)
+        factoryFromFile(message_parse, "message_fromWire10")
+        edns = message_parse.get_edns()
+        self.assertEqual(0, edns.get_version())
+        self.assertEqual(4096, edns.get_udp_size())
+        self.assertTrue(edns.is_dnssec_supported())
+
+    def test_set_edns(self):
+        self.assertRaises(InvalidMessageOperation, self.p.set_edns, EDNS())
+
+        edns = EDNS()
+        edns.set_udp_size(1024)
+        self.r.set_edns(edns)
+        self.assertEqual(1024, self.r.get_edns().get_udp_size())
+
     def test_get_section(self):
         self.assertRaises(TypeError, self.r.get_section, "wrong")
 

+ 20 - 0
src/lib/dns/tests/message_unittest.cc

@@ -17,6 +17,7 @@
 #include <exceptions/exceptions.h>
 
 #include <dns/buffer.h>
+#include <dns/edns.h>
 #include <dns/exceptions.h>
 #include <dns/message.h>
 #include <dns/messagerenderer.h>
@@ -93,6 +94,25 @@ TEST_F(MessageTest, RcodeToText) {
     EXPECT_EQ("4095", Rcode(Rcode(0xfff)).toText());
 }
 
+TEST_F(MessageTest, getEDNS) {
+    EXPECT_FALSE(message_parse.getEDNS()); // by default EDNS isn't set
+
+    factoryFromFile(message_parse, "message_fromWire10");
+    EXPECT_TRUE(message_parse.getEDNS());
+    EXPECT_EQ(0, message_parse.getEDNS()->getVersion());
+    EXPECT_EQ(4096, message_parse.getEDNS()->getUDPSize());
+    EXPECT_TRUE(message_parse.getEDNS()->isDNSSECSupported());
+}
+
+TEST_F(MessageTest, setEDNS) {
+    // setEDNS() isn't allowed in the parse mode
+    EXPECT_THROW(message_parse.setEDNS(EDNSPtr(new EDNS())),
+                 InvalidMessageOperation);
+
+    EDNSPtr edns = EDNSPtr(new EDNS());
+    message_render.setEDNS(edns);
+    EXPECT_EQ(edns, message_render.getEDNS());
+}
 
 TEST_F(MessageTest, fromWire) {
     factoryFromFile(message_parse, "message_fromWire1");