Browse Source

tighten validation with exceptions, and add more documentation

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac351@3040 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 14 years ago
parent
commit
ea71921698

+ 17 - 2
src/lib/dns/message.cc

@@ -259,6 +259,9 @@ Message::setQid(qid_t qid) {
 
 const Rcode&
 Message::getRcode() const {
+    if (impl_->rcode_ == NULL) {
+        isc_throw(InvalidMessageOperation, "getRcode attempted before set");
+    }
     return (*impl_->rcode_);
 }
 
@@ -273,6 +276,9 @@ Message::setRcode(const Rcode& rcode) {
 
 const Opcode&
 Message::getOpcode() const {
+    if (impl_->opcode_ == NULL) {
+        isc_throw(InvalidMessageOperation, "getOpcode attempted before set");
+    }
     return (*impl_->opcode_);
 }
 
@@ -424,11 +430,11 @@ Message::toWire(MessageRenderer& renderer) {
     }
     if (impl_->rcode_ == NULL) {
         isc_throw(InvalidMessageOperation,
-                  "Message rendering attempted without Rcode");
+                  "Message rendering attempted without Rcode set");
     }
     if (impl_->opcode_ == NULL) {
         isc_throw(InvalidMessageOperation,
-                  "Message rendering attempted without Opcode");
+                  "Message rendering attempted without Opcode set");
     }
 
     // reserve room for the header
@@ -690,6 +696,15 @@ struct SectionFormatter {
 
 string
 Message::toText() const {
+    if (impl_->rcode_ == NULL) {
+        isc_throw(InvalidMessageOperation,
+                  "Message::toText() attempted without Rcode set");
+    }
+    if (impl_->opcode_ == NULL) {
+        isc_throw(InvalidMessageOperation,
+                  "Message::toText() attempted without Opcode set");
+    }
+
     string s;
 
     s += ";; ->>HEADER<<- opcode: " + impl_->opcode_->toText();

+ 19 - 0
src/lib/dns/message.h

@@ -378,6 +378,11 @@ public:
     /// included).  In the \c PARSE mode, if the received message contains
     /// an EDNS OPT RR, the corresponding extended code is identified and
     /// returned.
+    ///
+    /// The message must have been properly parsed (in the case of the
+    /// \c PARSE mode) or an \c Rcode has been set (in the case of the
+    /// \c RENDER mode) beforehand.  Otherwise, an exception of class
+    /// \c InvalidMessageOperation will be thrown.
     const Rcode& getRcode() const;
 
     /// \brief Return the Response Code of the message.
@@ -388,6 +393,11 @@ public:
     void setRcode(const Rcode& rcode);
 
     /// \brief Return the OPCODE given in the header section of the message.
+    ///
+    /// The message must have been properly parsed (in the case of the
+    /// \c PARSE mode) or an \c Opcode has been set (in the case of the
+    /// \c RENDER mode) beforehand.  Otherwise, an exception of class
+    /// \c InvalidMessageOperation will be thrown.
     const Opcode& getOpcode() const;
 
     /// \brief Set the OPCODE of the header section of the message.
@@ -466,10 +476,19 @@ public:
     void makeResponse();
 
     /// \brief Convert the Message to a string.
+    ///
+    /// At least \c Opcode and \c Rcode must be validly set in the \c Message
+    /// (as a result of parse in the \c PARSE mode or by explicitly setting
+    /// in the \c RENDER mode);  otherwise, an exception of
+    /// class \c InvalidMessageOperation will be thrown.
     std::string toText() const;
 
     /// \brief Render the message in wire formant into a \c MessageRenderer
     /// object.
+    ///
+    /// This \c Message must be in the \c RENDER mode and both \c Opcode and
+    /// \c Rcode must have been set beforehand; otherwise, an exception of
+    /// class \c InvalidMessageOperation will be thrown.
     void toWire(MessageRenderer& renderer);
 
     /// \brief Parse the header section of the \c Message.

+ 28 - 6
src/lib/dns/python/message_python.cc

@@ -745,12 +745,18 @@ Message_getRcode(s_Message* self) {
 
     rcode = static_cast<s_Rcode*>(rcode_type.tp_alloc(&rcode_type, 0));
     if (rcode != NULL) {
-        rcode->rcode = new Rcode(self->message->getRcode());
-        if (rcode->rcode == NULL)
-          {
+        rcode->rcode = NULL;
+        try {
+            rcode->rcode = new Rcode(self->message->getRcode());
+        } catch (const InvalidMessageOperation& imo) {
+            PyErr_SetString(po_InvalidMessageOperation, imo.what());
+        } catch (...) {
+            PyErr_SetString(po_IscException, "Unexpected exception");
+        }
+        if (rcode->rcode == NULL) {
             Py_DECREF(rcode);
             return (NULL);
-          }
+        }
     }
 
     return (rcode);
@@ -777,7 +783,14 @@ Message_getOpcode(s_Message* self) {
 
     opcode = static_cast<s_Opcode*>(opcode_type.tp_alloc(&opcode_type, 0));
     if (opcode != NULL) {
-        opcode->opcode = new Opcode(self->message->getOpcode());
+        opcode->opcode = NULL;
+        try {
+            opcode->opcode = new Opcode(self->message->getOpcode());
+        } catch (const InvalidMessageOperation& imo) {
+            PyErr_SetString(po_InvalidMessageOperation, imo.what());
+        } catch (...) {
+            PyErr_SetString(po_IscException, "Unexpected exception");
+        }
         if (opcode->opcode == NULL) {
             Py_DECREF(opcode);
             return (NULL);
@@ -937,7 +950,16 @@ Message_makeResponse(s_Message* self) {
 static PyObject*
 Message_toText(s_Message* self) {
     // Py_BuildValue makes python objects from native data
-    return (Py_BuildValue("s", self->message->toText().c_str()));
+    try {
+        return (Py_BuildValue("s", self->message->toText().c_str()));
+    } catch (const InvalidMessageOperation& imo) {
+        PyErr_Clear();
+        PyErr_SetString(po_InvalidMessageOperation, imo.what());
+        return (NULL);
+    } catch (...) {
+        PyErr_SetString(po_IscException, "Unexpected exception");
+        return (NULL);
+    }
 }
 
 static PyObject*

+ 21 - 0
src/lib/dns/python/tests/message_python_test.py

@@ -198,6 +198,7 @@ class MessageTest(unittest.TestCase):
         self.assertRaises(InvalidMessageOperation,
                           self.p.set_rcode, rcode)
         
+        self.assertRaises(InvalidMessageOperation, self.p.get_rcode)
 
     def test_set_opcode(self):
         self.assertRaises(TypeError, self.r.set_opcode, "wrong")
@@ -209,6 +210,8 @@ class MessageTest(unittest.TestCase):
         self.assertRaises(InvalidMessageOperation,
                           self.p.set_opcode, opcode)
 
+        self.assertRaises(InvalidMessageOperation, self.p.get_opcode)
+
     def test_get_section(self):
         self.assertRaises(TypeError, self.r.get_section, "wrong")
 
@@ -273,6 +276,16 @@ class MessageTest(unittest.TestCase):
         self.assertEqual(b'\x105\x85\x00\x00\x01\x00\x02\x00\x00\x00\x00\x04test\x07example\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\x00\x02\x02',
                          renderer.get_data())
 
+    def test_to_wire_without_opcode(self):
+        self.r.set_rcode(Rcode.NOERROR())
+        self.assertRaises(InvalidMessageOperation, self.r.to_wire,
+                          MessageRenderer())
+
+    def test_to_wire_without_rcode(self):
+        self.r.set_opcode(Opcode.QUERY())
+        self.assertRaises(InvalidMessageOperation, self.r.to_wire,
+                          MessageRenderer())
+
     def test_to_text(self):
         message_render = create_message()
         
@@ -290,6 +303,14 @@ test.example.com. 3600 IN A 192.0.2.2
         self.assertEqual(msg_str, message_render.to_text())
         self.assertEqual(msg_str, str(message_render))
 
+    def test_to_text_without_opcode(self):
+        self.r.set_rcode(Rcode.NOERROR())
+        self.assertRaises(InvalidMessageOperation, self.r.to_text)
+
+    def test_to_text_without_rcode(self):
+        self.r.set_opcode(Opcode.QUERY())
+        self.assertRaises(InvalidMessageOperation, self.r.to_text)
+
     def test_from_wire(self):
         self.assertRaises(TypeError, self.r.from_wire, 1)
         self.assertRaises(InvalidMessageOperation,

+ 41 - 0
src/lib/dns/tests/message_unittest.cc

@@ -111,6 +111,22 @@ TEST_F(MessageTest, fromWire) {
     EXPECT_TRUE(it->isLast());
 }
 
+TEST_F(MessageTest, opcode) {    // for get/setOpcode
+    EXPECT_THROW(message_parse.setOpcode(Opcode::NOTIFY()),
+                 InvalidMessageOperation);
+    message_render.setOpcode(Opcode::UPDATE());
+    EXPECT_EQ(Opcode::UPDATE(), message_render.getOpcode());
+    EXPECT_THROW(message_parse.getOpcode(), InvalidMessageOperation);
+}
+
+TEST_F(MessageTest, rcode) {    // for get/setRcode
+    EXPECT_THROW(message_parse.setRcode(Rcode::BADVERS()),
+                 InvalidMessageOperation);
+    message_render.setRcode(Rcode::BADVERS());
+    EXPECT_EQ(Rcode::BADVERS(), message_render.getRcode());
+    EXPECT_THROW(message_parse.getRcode(), InvalidMessageOperation);
+}
+
 TEST_F(MessageTest, GetEDNS0DOBit) {
     // Without EDNS0, DNSSEC is considered to be unsupported.
     factoryFromFile(message_parse, "message_fromWire1");
@@ -246,4 +262,29 @@ TEST_F(MessageTest, toWire) {
     EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, obuffer.getData(),
                         obuffer.getLength(), &data[0], data.size());
 }
+
+TEST_F(MessageTest, toWireInParseMode) {
+    // toWire() isn't allowed in the parse mode.
+    EXPECT_THROW(message_parse.toWire(renderer), InvalidMessageOperation);
+}
+
+TEST_F(MessageTest, toWireWithoutOpcode) {
+    message_render.setRcode(Rcode::NOERROR());
+    EXPECT_THROW(message_render.toWire(renderer), InvalidMessageOperation);
+}
+
+TEST_F(MessageTest, toWireWithoutRcode) {
+    message_render.setOpcode(Opcode::QUERY());
+    EXPECT_THROW(message_render.toWire(renderer), InvalidMessageOperation);
+}
+
+TEST_F(MessageTest, toTextWithoutOpcode) {
+    message_render.setRcode(Rcode::NOERROR());
+    EXPECT_THROW(message_render.toText(), InvalidMessageOperation);
+}
+
+TEST_F(MessageTest, toTextWithoutRcode) {
+    message_render.setOpcode(Opcode::QUERY());
+    EXPECT_THROW(message_render.toText(), InvalidMessageOperation);
+}
 }