Browse Source

[1028] fixed leak in Message.get_section

JINMEI Tatuya 13 years ago
parent
commit
6060fcf2a3
2 changed files with 36 additions and 28 deletions
  1. 27 28
      src/lib/dns/python/message_python.cc
  2. 9 0
      src/lib/dns/python/tests/message_python_test.py

+ 27 - 28
src/lib/dns/python/message_python.cc

@@ -16,6 +16,7 @@
 #include <Python.h>
 
 #include <exceptions/exceptions.h>
+#include <util/python/pycppwrapper_util.h>
 #include <dns/message.h>
 #include <dns/rcode.h>
 #include <dns/tsig.h>
@@ -38,6 +39,7 @@ using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::python;
 using namespace isc::util;
+using namespace isc::util::python;
 
 // Import pydoc text
 #include "message_python_inc.cc"
@@ -451,6 +453,21 @@ Message_getQuestion(s_Message* self) {
     return (NULL);
 }
 
+class RRsetInserter {
+public:
+    RRsetInserter(PyObject* pylist) : pylist_(pylist) {}
+    void operator()(ConstRRsetPtr rrset) {
+        if (PyList_Append(pylist_,
+                          PyObjectContainer(createRRsetObject(*rrset)).get())
+            == -1) {
+            isc_throw(PyCPPWrapperException, "PyList_Append failed, "
+                      "probably due to short memory");
+        }
+    }
+private:
+    PyObject* pylist_;
+};
+
 PyObject*
 Message_getSection(s_Message* self, PyObject* args) {
     unsigned int section;
@@ -460,46 +477,28 @@ Message_getSection(s_Message* self, PyObject* args) {
                         "no valid type in get_section argument");
         return (NULL);
     }
-    RRsetIterator rrsi, rrsi_end;
+
     try {
-        rrsi = self->cppobj->beginSection(
-            static_cast<Message::Section>(section));
-        rrsi_end = self->cppobj->endSection(
-            static_cast<Message::Section>(section));
+        PyObjectContainer list_container(PyList_New(0));
+        const Message::Section msgsection =
+            static_cast<Message::Section>(section);
+        for_each(self->cppobj->beginSection(msgsection),
+                 self->cppobj->endSection(msgsection),
+                 RRsetInserter(list_container.get()));
+        return (list_container.release());
     } catch (const isc::OutOfRange& ex) {
         PyErr_SetString(PyExc_OverflowError, ex.what());
-        return (NULL);
     } catch (const InvalidMessageSection& ex) {
         PyErr_SetString(po_InvalidMessageSection, ex.what());
-        return (NULL);
-    } catch (...) {
-        PyErr_SetString(po_IscException,
-                        "Unexpected exception in getting section iterators");
-        return (NULL);
-    }
-
-    PyObject* list = PyList_New(0);
-    if (list == NULL) {
-        return (NULL);
-    }
-    try {
-        for (; rrsi != rrsi_end; ++rrsi) {
-            if (PyList_Append(list, createRRsetObject(**rrsi)) == -1) {
-                    Py_DECREF(list);
-                    return (NULL);
-            }
-        }
-        return (list);
     } catch (const exception& ex) {
         const string ex_what =
-            "Unexpected failure creating Question object: " +
+            "Unexpected failure in Message.get_section: " +
             string(ex.what());
         PyErr_SetString(po_IscException, ex_what.c_str());
     } catch (...) {
         PyErr_SetString(PyExc_SystemError,
-                        "Unexpected failure creating Question object");
+                        "Unexpected failure in Message.get_section");
     }
-    Py_DECREF(list);
     return (NULL);
 }
 

+ 9 - 0
src/lib/dns/python/tests/message_python_test.py

@@ -17,6 +17,7 @@
 # Tests for the message part of the pydnspp module
 #
 
+import sys
 import unittest
 import os
 from pydnspp import *
@@ -230,6 +231,14 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_ANSWER)))
         self.assertEqual(2, self.r.get_rr_count(Message.SECTION_ANSWER))
 
+        # We always make a new deep copy in get_section(), so the reference
+        # count of the returned list and its each item should be 1; otherwise
+        # they would leak.
+        self.assertEqual(1, sys.getrefcount(self.r.get_section(
+                    Message.SECTION_ANSWER)))
+        self.assertEqual(1, sys.getrefcount(self.r.get_section(
+                    Message.SECTION_ANSWER)[0]))
+
         self.assertFalse(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_AUTHORITY)))
         self.assertEqual(0, self.r.get_rr_count(Message.SECTION_AUTHORITY))
         self.r.add_rrset(Message.SECTION_AUTHORITY, self.rrset_a)