Browse Source

[1258] added a python wrapper for Message.from_wire(PRESERVE_ORDER)

JINMEI Tatuya 13 years ago
parent
commit
b0b0da67c9

+ 26 - 19
src/lib/dns/python/message_python.cc

@@ -649,27 +649,34 @@ PyObject*
 Message_fromWire(s_Message* self, PyObject* args) {
     const char* b;
     Py_ssize_t len;
-    if (!PyArg_ParseTuple(args, "y#", &b, &len)) {
-        return (NULL);
-    }
+    Message::ParseOptions options = Message::PARSE_DEFAULT;
+    if (PyArg_ParseTuple(args, "y#", &b, &len) ||
+        PyArg_ParseTuple(args, "y#I", &b, &len, &options)) {
+        // We need to clear the error in case the first call to ParseTuple
+        // fails.
+        PyErr_Clear();
 
-    InputBuffer inbuf(b, len);
-    try {
-        self->cppobj->fromWire(inbuf);
-        Py_RETURN_NONE;
-    } catch (const InvalidMessageOperation& imo) {
-        PyErr_SetString(po_InvalidMessageOperation, imo.what());
-        return (NULL);
-    } catch (const DNSMessageFORMERR& dmfe) {
-        PyErr_SetString(po_DNSMessageFORMERR, dmfe.what());
-        return (NULL);
-    } catch (const DNSMessageBADVERS& dmfe) {
-        PyErr_SetString(po_DNSMessageBADVERS, dmfe.what());
-        return (NULL);
-    } catch (const MessageTooShort& mts) {
-        PyErr_SetString(po_MessageTooShort, mts.what());
-        return (NULL);
+        InputBuffer inbuf(b, len);
+        try {
+            self->cppobj->fromWire(inbuf, options);
+            Py_RETURN_NONE;
+        } catch (const InvalidMessageOperation& imo) {
+            PyErr_SetString(po_InvalidMessageOperation, imo.what());
+            return (NULL);
+        } catch (const DNSMessageFORMERR& dmfe) {
+            PyErr_SetString(po_DNSMessageFORMERR, dmfe.what());
+            return (NULL);
+        } catch (const DNSMessageBADVERS& dmfe) {
+            PyErr_SetString(po_DNSMessageBADVERS, dmfe.what());
+            return (NULL);
+        } catch (const MessageTooShort& mts) {
+            PyErr_SetString(po_MessageTooShort, mts.what());
+            return (NULL);
+        }
     }
+
+    PyErr_SetString(PyExc_TypeError, "Invalid arguments to Message.from_wire");
+    return (NULL);
 }
 
 } // end of unnamed namespace

+ 4 - 0
src/lib/dns/python/pydnspp.cc

@@ -106,6 +106,10 @@ initModulePart_Message(PyObject* mod) {
         installClassVariable(message_type, "RENDER",
                              Py_BuildValue("I", Message::RENDER));
 
+        // Parse options
+        installClassVariable(message_type, "PRESERVE_ORDER",
+                             Py_BuildValue("I", Message::PRESERVE_ORDER));
+
         // Header flags
         installClassVariable(message_type, "HEADERFLAG_QR",
                              Py_BuildValue("I", Message::HEADERFLAG_QR));

+ 49 - 2
src/lib/dns/python/tests/message_python_test.py

@@ -29,9 +29,12 @@ if "TESTDATA_PATH" in os.environ:
 else:
     testdata_path = "../tests/testdata"
 
-def factoryFromFile(message, file):
+def factoryFromFile(message, file, parse_options=None):
     data = read_wire_data(file)
-    message.from_wire(data)
+    if parse_options is None:
+        message.from_wire(data)
+    else:
+        message.from_wire(data, parse_options)
     return data
 
 # we don't have direct comparison for rrsets right now (should we?
@@ -466,6 +469,50 @@ test.example.com. 3600 IN A 192.0.2.2
         self.assertEqual("192.0.2.2", rdata[1].to_text())
         self.assertEqual(2, len(rdata))
 
+    def test_from_wire_combind_rrs(self):
+        factoryFromFile(self.p, "message_fromWire19.wire")
+        rrset = self.p.get_section(Message.SECTION_ANSWER)[0]
+        self.assertEqual(RRType("A"), rrset.get_type())
+        self.assertEqual(2, len(rrset.get_rdata()))
+
+        rrset = self.p.get_section(Message.SECTION_ANSWER)[1]
+        self.assertEqual(RRType("AAAA"), rrset.get_type())
+        self.assertEqual(1, len(rrset.get_rdata()))
+
+    def check_preserve_rrs(self, message, section):
+        rrset = message.get_section(section)[0]
+        self.assertEqual(RRType("A"), rrset.get_type())
+        rdata = rrset.get_rdata()
+        self.assertEqual(1, len(rdata))
+        self.assertEqual('192.0.2.1', rdata[0].to_text())
+
+        rrset = message.get_section(section)[1]
+        self.assertEqual(RRType("AAAA"), rrset.get_type())
+        rdata = rrset.get_rdata()
+        self.assertEqual(1, len(rdata))
+        self.assertEqual('2001:db8::1', rdata[0].to_text())
+
+        rrset = message.get_section(section)[2]
+        self.assertEqual(RRType("A"), rrset.get_type())
+        rdata = rrset.get_rdata()
+        self.assertEqual(1, len(rdata))
+        self.assertEqual('192.0.2.2', rdata[0].to_text())
+
+    def test_from_wire_preserve_answer(self):
+        factoryFromFile(self.p, "message_fromWire19.wire",
+                        Message.PRESERVE_ORDER)
+        self.check_preserve_rrs(self.p, Message.SECTION_ANSWER)
+
+    def test_from_wire_preserve_authority(self):
+        factoryFromFile(self.p, "message_fromWire20.wire",
+                        Message.PRESERVE_ORDER)
+        self.check_preserve_rrs(self.p, Message.SECTION_AUTHORITY)
+
+    def test_from_wire_preserve_additional(self):
+        factoryFromFile(self.p, "message_fromWire21.wire",
+                        Message.PRESERVE_ORDER)
+        self.check_preserve_rrs(self.p, Message.SECTION_ADDITIONAL)
+
     def test_EDNS0ExtCode(self):
         # Extended Rcode = BADVERS
         message_parse = Message(Message.PARSE)