Browse Source

[1028] fixed leak in Message.get_question()

JINMEI Tatuya 13 years ago
parent
commit
5cb4d41cf6
2 changed files with 37 additions and 43 deletions
  1. 30 42
      src/lib/dns/python/message_python.cc
  2. 7 1
      src/lib/dns/python/tests/message_python_test.py

+ 30 - 42
src/lib/dns/python/message_python.cc

@@ -412,62 +412,50 @@ Message_getRRCount(s_Message* self, PyObject* args) {
 }
 
 // TODO use direct iterators for these? (or simply lists for now?)
+template <typename ItemType, typename CreatorParamType>
+class SectionInserter {
+    typedef PyObject* (*creator_t)(const CreatorParamType&);
+public:
+    SectionInserter(PyObject* pylist, creator_t creator) :
+        pylist_(pylist), creator_(creator)
+    {}
+    void operator()(ItemType item) {
+        if (PyList_Append(pylist_, PyObjectContainer(creator_(*item)).get())
+            == -1) {
+            isc_throw(PyCPPWrapperException, "PyList_Append failed, "
+                      "probably due to short memory");
+        }
+    }
+private:
+    PyObject* pylist_;
+    creator_t creator_;
+};
+
+typedef SectionInserter<ConstQuestionPtr, Question> QuestionInserter;
+typedef SectionInserter<ConstRRsetPtr, RRset> RRsetInserter;
+
 PyObject*
 Message_getQuestion(s_Message* self) {
-    QuestionIterator qi, qi_end;
     try {
-        qi = self->cppobj->beginQuestion();
-        qi_end = self->cppobj->endQuestion();
+        PyObjectContainer list_container(PyList_New(0));
+        for_each(self->cppobj->beginQuestion(),
+                 self->cppobj->endQuestion(),
+                 QuestionInserter(list_container.get(), createQuestionObject));
+        return (list_container.release());
     } catch (const InvalidMessageSection& ex) {
         PyErr_SetString(po_InvalidMessageSection, ex.what());
-        return (NULL);
-    } catch (...) {
-        PyErr_SetString(po_IscException,
-                        "Unexpected exception in getting section iterators");
-        return (NULL);
-    }
-
-    PyObject* list = PyList_New(0);
-    if (list == NULL) {
-        return (NULL);
-    }
-
-    try {
-        for (; qi != qi_end; ++qi) {
-            if (PyList_Append(list, createQuestionObject(**qi)) == -1) {
-                Py_DECREF(list);
-                return (NULL);
-            }
-        }
-        return (list);
     } catch (const exception& ex) {
         const string ex_what =
-            "Unexpected failure getting Question section: " +
+            "Unexpected failure in Message.get_question: " +
             string(ex.what());
         PyErr_SetString(po_IscException, ex_what.c_str());
     } catch (...) {
         PyErr_SetString(PyExc_SystemError,
-                        "Unexpected failure getting Question section");
+                        "Unexpected failure in Message.get_question");
     }
-    Py_DECREF(list);
     return (NULL);
 }
 
-class RRsetInserter {
-public:
-    RRsetInserter(PyObject* pylist) : pylist_(pylist) {}
-    void operator()(ConstRRsetPtr rrset) {
-        if (PyList_Append(pylist_,
-                          PyObjectContainer(createRRsetObject(*rrset)).get())
-            == -1) {
-            isc_throw(PyCPPWrapperException, "PyList_Append failed, "
-                      "probably due to short memory");
-        }
-    }
-private:
-    PyObject* pylist_;
-};
-
 PyObject*
 Message_getSection(s_Message* self, PyObject* args) {
     unsigned int section;
@@ -484,7 +472,7 @@ Message_getSection(s_Message* self, PyObject* args) {
             static_cast<Message::Section>(section);
         for_each(self->cppobj->beginSection(msgsection),
                  self->cppobj->endSection(msgsection),
-                 RRsetInserter(list_container.get()));
+                 RRsetInserter(list_container.get(), createRRsetObject));
         return (list_container.release());
     } catch (const isc::OutOfRange& ex) {
         PyErr_SetString(PyExc_OverflowError, ex.what());

+ 7 - 1
src/lib/dns/python/tests/message_python_test.py

@@ -251,7 +251,7 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_ADDITIONAL)))
         self.assertEqual(2, self.r.get_rr_count(Message.SECTION_ADDITIONAL))
 
-    def test_add_question(self):
+    def test_add_and_get_question(self):
         self.assertRaises(TypeError, self.r.add_question, "wrong", "wrong")
         q = Question(Name("example.com"), RRClass("IN"), RRType("A"))
         qs = [q]
@@ -261,6 +261,12 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(qs, self.r.get_question()))
         self.assertEqual(1, self.r.get_rr_count(Message.SECTION_QUESTION))
 
+        # We always make a new deep copy in get_section(), so the reference
+        # count of the returned list and its each item should be 1; otherwise
+        # they would leak.
+        self.assertEqual(1, sys.getrefcount(self.r.get_question()))
+        self.assertEqual(1, sys.getrefcount(self.r.get_question()[0]))
+
     def test_add_rrset(self):
         self.assertRaises(TypeError, self.r.add_rrset, "wrong")
         self.assertRaises(TypeError, self.r.add_rrset)