
[master] Merge branch 'trac1028'

JINMEI Tatuya 13 years ago
commit a72886e643

+ 115 - 6
src/bin/xfrin/tests/xfrin_test.py

@@ -16,6 +16,7 @@
 import unittest
 import shutil
 import socket
+import sys
 import io
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from xfrin import *
@@ -216,8 +217,8 @@ class MockXfrin(Xfrin):
                                  request_type, check_soa)
 
 class MockXfrinConnection(XfrinConnection):
-    def __init__(self, sock_map, zone_name, rrclass, shutdown_event,
-                 master_addr):
+    def __init__(self, sock_map, zone_name, rrclass, datasrc_client,
+                 shutdown_event, master_addr, tsig_key=None):
         super().__init__(sock_map, zone_name, rrclass, MockDataSourceClient(),
                          shutdown_event, master_addr)
         self.query_data = b''
@@ -300,8 +301,9 @@ class TestXfrinState(unittest.TestCase):
     def setUp(self):
         self.sock_map = {}
         self.conn = MockXfrinConnection(self.sock_map, TEST_ZONE_NAME,
-                                        TEST_RRCLASS, threading.Event(),
+                                        TEST_RRCLASS, None, threading.Event(),
                                         TEST_MASTER_IPV4_ADDRINFO)
+        self.conn.init_socket()
         self.begin_soa = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(),
                                RRTTL(3600))
         self.begin_soa.add_rdata(Rdata(RRType.SOA(), TEST_RRCLASS,
@@ -585,8 +587,9 @@ class TestXfrinConnection(unittest.TestCase):
             os.remove(TEST_DB_FILE)
         self.sock_map = {}
         self.conn = MockXfrinConnection(self.sock_map, TEST_ZONE_NAME,
-                                        TEST_RRCLASS, threading.Event(),
+                                        TEST_RRCLASS, None, threading.Event(),
                                         TEST_MASTER_IPV4_ADDRINFO)
+        self.conn.init_socket()
         self.soa_response_params = {
             'questions': [example_soa_question],
             'bad_qid': False,
@@ -720,14 +723,16 @@ class TestAXFR(TestXfrinConnection):
         # to confirm an AF_INET6 socket has been created.  A naive application
         # tends to assume it's IPv4 only and hardcode AF_INET.  This test
         # uncovers such a bug.
-        c = MockXfrinConnection({}, TEST_ZONE_NAME, TEST_RRCLASS,
+        c = MockXfrinConnection({}, TEST_ZONE_NAME, TEST_RRCLASS, None,
                                 threading.Event(), TEST_MASTER_IPV6_ADDRINFO)
+        c.init_socket()
         c.bind(('::', 0))
         c.close()
 
     def test_init_chclass(self):
-        c = MockXfrinConnection({}, TEST_ZONE_NAME, RRClass.CH(),
+        c = MockXfrinConnection({}, TEST_ZONE_NAME, RRClass.CH(), None,
                                 threading.Event(), TEST_MASTER_IPV4_ADDRINFO)
+        c.init_socket()
         axfrmsg = c._create_query(RRType.AXFR())
         self.assertEqual(axfrmsg.get_question()[0].get_class(),
                          RRClass.CH())
@@ -1679,6 +1684,110 @@ class TestXfrinRecorder(unittest.TestCase):
         self.recorder.decrement(TEST_ZONE_NAME)
         self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), False)
 
+class TestXfrinProcess(unittest.TestCase):
+    def setUp(self):
+        self.unlocked = False
+        self.conn_closed = False
+        self.do_raise_on_close = False
+        self.do_raise_on_connect = False
+        self.do_raise_on_publish = False
+        self.master = (socket.AF_INET, socket.SOCK_STREAM,
+                       (TEST_MASTER_IPV4_ADDRESS, TEST_MASTER_PORT))
+
+    def tearDown(self):
+        # Whatever happens, the lock acquired in xfrin_recorder.increment
+        # must always be released.  We check this condition for all test
+        # cases.
+        self.assertTrue(self.unlocked)
+
+        # Same for the connection
+        self.assertTrue(self.conn_closed)
+
+    def increment(self, zone_name):
+        '''Fake method of xfrin_recorder.increment.
+
+        '''
+        self.unlocked = False
+
+    def decrement(self, zone_name):
+        '''Fake method of xfrin_recorder.decrement.
+
+        '''
+        self.unlocked = True
+
+    def publish_xfrin_news(self, zone_name, rrclass, ret):
+        '''Fake method of server.publish_xfrin_news.
+
+        '''
+        if self.do_raise_on_publish:
+            raise XfrinTestException('Emulated exception in publish')
+
+    def connect_to_master(self, conn):
+        self.sock_fd = conn.fileno()
+        if self.do_raise_on_connect:
+            raise XfrinTestException('Emulated exception in connect')
+        return True
+
+    def conn_close(self, conn):
+        self.conn_closed = True
+        XfrinConnection.close(conn)
+        if self.do_raise_on_close:
+            raise XfrinTestException('Emulated exception in close')
+
+    def create_xfrinconn(self, sock_map, zone_name, rrclass, datasrc_client,
+                         shutdown_event, master_addrinfo, tsig_key):
+        conn = MockXfrinConnection(sock_map, zone_name, rrclass,
+                                   datasrc_client, shutdown_event,
+                                   master_addrinfo, tsig_key)
+
+        # An awkward check that would specifically identify an old bug
+        # where initialization of XfrinConnection._tsig_ctx_creator caused
+        # a self reference and subsequently led to a reference leak.
+        orig_ref = sys.getrefcount(conn)
+        conn._tsig_ctx_creator = None
+        self.assertEqual(orig_ref, sys.getrefcount(conn))
+
+        # Replace some methods of the connection with our internal ones for
+        # the convenience of the tests
+        conn.connect_to_master = lambda : self.connect_to_master(conn)
+        conn.do_xfrin = lambda x, y : XFRIN_OK
+        conn.close = lambda : self.conn_close(conn)
+
+        return conn
+
+    def test_process_xfrin_normal(self):
+        # Normal, successful case.  We only check that things are cleaned up
+        # at tearDown time.
+        process_xfrin(self, self, TEST_ZONE_NAME, TEST_RRCLASS, None, None,
+                      self.master,  False, None, RRType.AXFR(),
+                      self.create_xfrinconn)
+
+    def test_process_xfrin_exception_on_connect(self):
+        # connect_to_master() will raise an exception.  Things must still be
+        # cleaned up.
+        self.do_raise_on_connect = True
+        process_xfrin(self, self, TEST_ZONE_NAME, TEST_RRCLASS, None, None,
+                      self.master,  False, None, RRType.AXFR(),
+                      self.create_xfrinconn)
+
+    def test_process_xfrin_exception_on_close(self):
+        # connect() will result in an exception, and even the cleanup close()
+        # will fail with an exception.  This would quite likely be a bug,
+        # but we deal with that case anyway.
+        self.do_raise_on_connect = True
+        self.do_raise_on_close = True
+        process_xfrin(self, self, TEST_ZONE_NAME, TEST_RRCLASS, None, None,
+                      self.master,  False, None, RRType.AXFR(),
+                      self.create_xfrinconn)
+
+    def test_process_xfrin_exception_on_publish(self):
+        # XFR succeeds but notifying the zonemgr fails with an exception.
+        # Everything must still be cleaned up.
+        self.do_raise_on_publish = True
+        process_xfrin(self, self, TEST_ZONE_NAME, TEST_RRCLASS, None, None,
+                      self.master,  False, None, RRType.AXFR(),
+                      self.create_xfrinconn)
+
 class TestXfrin(unittest.TestCase):
     def setUp(self):
         # redirect output
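
The getrefcount() check in create_xfrinconn() above targets a subtle pitfall: when a constructor stores a bound method of the instance on the instance itself (as the old _tsig_ctx_creator initialization did), the attribute keeps a reference back to the object, so dropping the attribute changes the object's reference count. The following standalone sketch illustrates the difference; the class and attribute names are made up for illustration and are not part of the commit.

    import sys

    class WithBoundMethod:
        def __init__(self):
            # The bound method object keeps a reference back to self.
            self.creator = self.make

        def make(self, key):
            return key

    class WithLambda:
        def __init__(self):
            # A plain lambda does not capture self, so no self-reference.
            self.creator = lambda key: key

    a, b = WithBoundMethod(), WithLambda()
    ref_a, ref_b = sys.getrefcount(a), sys.getrefcount(b)
    a.creator = None
    b.creator = None
    print(sys.getrefcount(a) < ref_a)    # True: a self-reference was dropped
    print(sys.getrefcount(b) == ref_b)   # True: the count is unchanged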

+ 84 - 69
src/bin/xfrin/xfrin.py.in

@@ -323,6 +323,7 @@ class XfrinFirstData(XfrinState):
                  conn.zone_str())
             # We are now going to add RRs to the new zone.  We need to create
             # a Diff object.  It will be used throughout the XFR session.
+            # DISABLE FOR DEBUG
             conn._diff = Diff(conn._datasrc_client, conn._zone_name, True)
             self.set_xfrstate(conn, XfrinAXFR())
         return False
@@ -468,21 +469,27 @@ class XfrinConnection(asyncore.dispatcher):
         # Data source handler
         self._datasrc_client = datasrc_client
 
-        self.create_socket(master_addrinfo[0], master_addrinfo[1])
         self._sock_map = sock_map
         self._soa_rr_count = 0
         self._idle_timeout = idle_timeout
-        self.setblocking(1)
         self._shutdown_event = shutdown_event
-        self._master_address = master_addrinfo[2]
+        self._master_addrinfo = master_addrinfo
         self._tsig_key = tsig_key
         self._tsig_ctx = None
         # tsig_ctx_creator is introduced to allow tests to use a mock class
         # for easier testing (in the normal case we always use the default)
-        self._tsig_ctx_creator = self.__create_tsig_ctx
+        self._tsig_ctx_creator = lambda key : TSIGContext(key)
 
-    def __create_tsig_ctx(self, key):
-        return TSIGContext(key)
+    def init_socket(self):
+        '''Initialize the underlying socket.
+
+        This is essentially a part of __init__() and is expected to be
+        called immediately after the constructor.  It's separated from
+        the constructor because otherwise we might not be able to close
+        it if the constructor raises an exception after opening the socket.
+        '''
+        self.create_socket(self._master_addrinfo[0], self._master_addrinfo[1])
+        self.setblocking(1)
 
     def __set_xfrstate(self, new_state):
         self.__state = new_state
@@ -498,10 +505,11 @@ class XfrinConnection(asyncore.dispatcher):
         '''Connect to master in TCP.'''
 
         try:
-            self.connect(self._master_address)
+            self.connect(self._master_addrinfo[2])
             return True
         except socket.error as e:
-            logger.error(XFRIN_CONNECT_MASTER, self._master_address, str(e))
+            logger.error(XFRIN_CONNECT_MASTER, self._master_addrinfo[2],
+                         str(e))
             return False
 
     def _get_zone_soa(self):
@@ -697,7 +705,6 @@ class XfrinConnection(asyncore.dispatcher):
             # (if not yet - possible in case of xfr-level exception) as soon
             # as possible
             self._diff = None
-            self.close()
 
         return ret
 
@@ -730,33 +737,6 @@ class XfrinConnection(asyncore.dispatcher):
         if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
             raise XfrinException('query section count greater than 1')
 
-    def _handle_answer_section(self, answer_section):
-        '''Return a generator for the reponse in one tcp package to a zone transfer.'''
-
-        for rrset in answer_section:
-            rrset_name = rrset.get_name().to_text()
-            rrset_ttl = int(rrset.get_ttl().to_text())
-            rrset_class = rrset.get_class().to_text()
-            rrset_type = rrset.get_type().to_text()
-
-            for rdata in rrset.get_rdata():
-                # Count the soa record count
-                if rrset.get_type() == RRType.SOA():
-                    self._soa_rr_count += 1
-
-                    # XXX: the current DNS message parser can't preserve the
-                    # RR order or separete the beginning and ending SOA RRs.
-                    # As a short term workaround, we simply ignore the second
-                    # SOA, and ignore the erroneous case where the transfer
-                    # session doesn't end with an SOA.
-                    if (self._soa_rr_count == 2):
-                        # Avoid inserting soa record twice
-                        break
-
-                rdata_text = rdata.to_text()
-                yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
-                       rdata_text)
-
     def _handle_xfrin_responses(self):
         read_next_msg = True
         while read_next_msg:
@@ -794,47 +774,82 @@ class XfrinConnection(asyncore.dispatcher):
 
         return False
 
-    def log_info(self, msg, type='info'):
-        # Overwrite the log function, log nothing
-        pass
-
-def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
-                  shutdown_event, master_addrinfo, check_soa, tsig_key,
-                  request_type):
-    xfrin_recorder.increment(zone_name)
-
-    # Create a data source client used in this XFR session.  Right now we
-    # still assume an sqlite3-based data source, and use both the old and new
-    # data source APIs.  We also need to use a mock client for tests.
-    # For a temporary workaround to deal with these situations, we skip the
-    # creation when the given file is none (the test case).  Eventually
-    # this code will be much cleaner.
-    datasrc_client = None
-    if db_file is not None:
-        # temporary hardcoded sqlite initialization. Once we decide on
-        # the config specification, we need to update this (TODO)
-        # this may depend on #1207, or any followup ticket created for #1207
-        datasrc_type = "sqlite3"
-        datasrc_config = "{ \"database_file\": \"" + db_file + "\"}"
-        datasrc_client = DataSourceClient(datasrc_type, datasrc_config)
-
-    # Create a TCP connection for the XFR session and perform the operation.
-    sock_map = {}
-    conn = XfrinConnection(sock_map, zone_name, rrclass, datasrc_client,
-                           shutdown_event, master_addrinfo, tsig_key)
-    # XXX: We still need _db_file for temporary workaround in _create_query().
-    # This should be removed when we eliminate the need for the workaround.
-    conn._db_file = db_file
+def __process_xfrin(server, zone_name, rrclass, db_file,
+                    shutdown_event, master_addrinfo, check_soa, tsig_key,
+                    request_type, conn_class=XfrinConnection):
+    conn = None
+    exception = None
     ret = XFRIN_FAIL
-    if conn.connect_to_master():
-        ret = conn.do_xfrin(check_soa, request_type)
+    try:
+        # Create a data source client used in this XFR session.  Right now we
+        # still assume an sqlite3-based data source, and use both the old and
+        # new data source APIs.  We also need to use a mock client for tests.
+        # As a temporary workaround for these situations, we skip the
+        # creation when the given file is None (the test case).  Eventually
+        # this code will be much cleaner.
+        datasrc_client = None
+        if db_file is not None:
+            # temporary hardcoded sqlite initialization. Once we decide on
+            # the config specification, we need to update this (TODO)
+            # this may depend on #1207, or any followup ticket created for #1207
+            datasrc_type = "sqlite3"
+            datasrc_config = "{ \"database_file\": \"" + db_file + "\"}"
+            datasrc_client = DataSourceClient(datasrc_type, datasrc_config)
+
+        # Create a TCP connection for the XFR session and perform the operation
+        sock_map = {}
+        conn = conn_class(sock_map, zone_name, rrclass, datasrc_client,
+                          shutdown_event, master_addrinfo, tsig_key)
+        conn.init_socket()
+        # XXX: We still need _db_file for temporary workaround in _create_query().
+        # This should be removed when we eliminate the need for the workaround.
+        conn._db_file = db_file
+        if conn.connect_to_master():
+            ret = conn.do_xfrin(check_soa, request_type)
+    except Exception as ex:
+        # If an exception happens, just remember it here so that we can
+        # re-raise it after cleaning things up.  We don't log it here because
+        # we want to eliminate even the smallest possibility of having an
+        # exception in the logging itself.
+        exception = ex
+
+    # asyncore.dispatcher requires an explicit close() unless its lifetime,
+    # from creation to destruction, is contained within asyncore.loop, which
+    # is not the case for us.  We always close() here, whether or not do_xfrin
+    # succeeds, and even when we see an unexpected exception.
+    if conn is not None:
+        conn.close()
 
     # Publish the zone transfer result news, so zonemgr can reset the
     # zone timer, and xfrout can notify the zone's slaves if the result
     # is success.
     server.publish_xfrin_news(zone_name, rrclass, ret)
+
+    if exception is not None:
+        raise exception
+
+def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
+                  shutdown_event, master_addrinfo, check_soa, tsig_key,
+                  request_type, conn_class=XfrinConnection):
+    # Even though it should be rare, the main part of the xfrin session can
+    # raise an exception.  To make sure the lock in xfrin_recorder is
+    # released in any case, we delegate the main part to the helper function
+    # in a try block, catch any exception, then release the lock.
+    xfrin_recorder.increment(zone_name)
+    exception = None
+    try:
+        __process_xfrin(server, zone_name, rrclass, db_file,
+                        shutdown_event, master_addrinfo, check_soa, tsig_key,
+                        request_type, conn_class)
+    except Exception as ex:
+        # don't log it until we complete decrement().
+        exception = ex
     xfrin_recorder.decrement(zone_name)
 
+    if exception is not None:
+        typestr = "AXFR" if request_type == RRType.AXFR() else "IXFR"
+        logger.error(XFRIN_XFR_PROCESS_FAILURE, typestr, zone_name.to_text(),
+                     str(rrclass), str(exception))
 
 class XfrinRecorder:
     def __init__(self):
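
Two related patterns in the rewritten xfrin.py.in above are worth spelling out: init_socket() is split out of the constructor so that a connection object always exists and can be closed even when socket setup fails, and __process_xfrin()/process_xfrin() capture any exception, perform the mandatory cleanup (close the connection, notify the server, release the recorder lock), and only then re-raise or log. A minimal sketch of the same capture-cleanup-report idea, using hypothetical names (run_session, open_conn, do_transfer, recorder) that are not part of the commit:

    def run_session(recorder, zone, open_conn, do_transfer):
        # The lock taken here must be released no matter what happens below.
        recorder.increment(zone)
        exception = None
        try:
            conn = None
            try:
                conn = open_conn()      # constructor, then socket setup
                do_transfer(conn)
            except Exception as ex:
                # Remember the error; clean up first, report afterwards.
                exception = ex
            # asyncore-style dispatchers need an explicit close().
            if conn is not None:
                conn.close()
        except Exception as ex:
            # Covers failures in the cleanup itself (e.g. a failing close()).
            exception = ex
        recorder.decrement(zone)
        if exception is not None:
            # The real code logs the failure at this point; re-raising keeps
            # the sketch short.
            raise exception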

+ 15 - 0
src/bin/xfrin/xfrin_messages.mes

@@ -29,6 +29,21 @@ this can only happen for AXFR.
 The XFR transfer for the given zone has failed due to a protocol error.
 The error is shown in the log message.
 
+% XFRIN_XFR_PROCESS_FAILURE %1 transfer of zone %2/%3 failed: %4
+An XFR session failed outside the main protocol handling.  This
+includes an error at the data source level during the initialization
+phase, an unexpected failure while setting up the network connection
+to the master server, or an even more unexpected failure due to
+unlikely events such as memory allocation failure.  Details of the
+error are shown in the log message.  In general, these errors are not
+expected in normal operation and indicate an installation error or a
+program bug.  The session handler thread tries to clean up all
+intermediate resources even on these errors, but the cleanup may be
+incomplete.  So, if this log message appears repeatedly, system
+resource consumption should be checked, and you may even want to
+disable the corresponding transfers.  You may also want to file a bug
+report if this message appears frequently.
+
 % XFRIN_XFR_TRANSFER_STARTED %1 transfer of zone %2 started
 A connection to the master server has been made, the serial value in
 the SOA record has been checked, and a zone transfer has been started.

+ 53 - 60
src/lib/dns/python/message_python.cc

@@ -16,6 +16,7 @@
 #include <Python.h>
 
 #include <exceptions/exceptions.h>
+#include <util/python/pycppwrapper_util.h>
 #include <dns/message.h>
 #include <dns/rcode.h>
 #include <dns/tsig.h>
@@ -38,6 +39,7 @@ using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::python;
 using namespace isc::util;
+using namespace isc::util::python;
 
 // Import pydoc text
 #include "message_python_inc.cc"
@@ -64,8 +66,8 @@ PyObject* Message_setEDNS(s_Message* self, PyObject* args);
 PyObject* Message_getTSIGRecord(s_Message* self);
 PyObject* Message_getRRCount(s_Message* self, PyObject* args);
 // use direct iterators for these? (or simply lists for now?)
-PyObject* Message_getQuestion(s_Message* self);
-PyObject* Message_getSection(s_Message* self, PyObject* args);
+PyObject* Message_getQuestion(PyObject* self, PyObject*);
+PyObject* Message_getSection(PyObject* self, PyObject* args);
 //static PyObject* Message_beginQuestion(s_Message* self, PyObject* args);
 //static PyObject* Message_endQuestion(s_Message* self, PyObject* args);
 //static PyObject* Message_beginSection(s_Message* self, PyObject* args);
@@ -127,10 +129,10 @@ PyMethodDef Message_methods[] = {
     },
     { "get_rr_count", reinterpret_cast<PyCFunction>(Message_getRRCount), METH_VARARGS,
       "Returns the number of RRs contained in the given section." },
-    { "get_question", reinterpret_cast<PyCFunction>(Message_getQuestion), METH_NOARGS,
+    { "get_question", Message_getQuestion, METH_NOARGS,
       "Returns a list of all Question objects in the message "
       "(should be either 0 or 1)" },
-    { "get_section", reinterpret_cast<PyCFunction>(Message_getSection), METH_VARARGS,
+    { "get_section", Message_getSection, METH_VARARGS,
       "Returns a list of all RRset objects in the given section of the message\n"
       "The argument must be of type Section" },
     { "add_question", reinterpret_cast<PyCFunction>(Message_addQuestion), METH_VARARGS,
@@ -409,50 +411,59 @@ Message_getRRCount(s_Message* self, PyObject* args) {
     }
 }
 
+// This is a templated helper class used by both getQuestion and getSection
+// to build a Python list of message section items.
+template <typename ItemType, typename CreatorParamType>
+class SectionInserter {
+    typedef PyObject* (*creator_t)(const CreatorParamType&);
+public:
+    SectionInserter(PyObject* pylist, creator_t creator) :
+        pylist_(pylist), creator_(creator)
+    {}
+    void operator()(ItemType item) {
+        if (PyList_Append(pylist_, PyObjectContainer(creator_(*item)).get())
+            == -1) {
+            isc_throw(PyCPPWrapperException, "PyList_Append failed, "
+                      "probably due to short memory");
+        }
+    }
+private:
+    PyObject* pylist_;
+    creator_t creator_;
+};
+
+typedef SectionInserter<ConstQuestionPtr, Question> QuestionInserter;
+typedef SectionInserter<ConstRRsetPtr, RRset> RRsetInserter;
+
 // TODO use direct iterators for these? (or simply lists for now?)
 PyObject*
-Message_getQuestion(s_Message* self) {
-    QuestionIterator qi, qi_end;
+Message_getQuestion(PyObject* po_self, PyObject*) {
+    const s_Message* const self = static_cast<s_Message*>(po_self);
+
     try {
-        qi = self->cppobj->beginQuestion();
-        qi_end = self->cppobj->endQuestion();
+        PyObjectContainer list_container(PyList_New(0));
+        for_each(self->cppobj->beginQuestion(),
+                 self->cppobj->endQuestion(),
+                 QuestionInserter(list_container.get(), createQuestionObject));
+        return (list_container.release());
     } catch (const InvalidMessageSection& ex) {
         PyErr_SetString(po_InvalidMessageSection, ex.what());
-        return (NULL);
-    } catch (...) {
-        PyErr_SetString(po_IscException,
-                        "Unexpected exception in getting section iterators");
-        return (NULL);
-    }
-
-    PyObject* list = PyList_New(0);
-    if (list == NULL) {
-        return (NULL);
-    }
-
-    try {
-        for (; qi != qi_end; ++qi) {
-            if (PyList_Append(list, createQuestionObject(**qi)) == -1) {
-                Py_DECREF(list);
-                return (NULL);
-            }
-        }
-        return (list);
     } catch (const exception& ex) {
         const string ex_what =
-            "Unexpected failure getting Question section: " +
+            "Unexpected failure in Message.get_question: " +
             string(ex.what());
         PyErr_SetString(po_IscException, ex_what.c_str());
     } catch (...) {
         PyErr_SetString(PyExc_SystemError,
-                        "Unexpected failure getting Question section");
+                        "Unexpected failure in Message.get_question");
     }
-    Py_DECREF(list);
     return (NULL);
 }
 
 PyObject*
-Message_getSection(s_Message* self, PyObject* args) {
+Message_getSection(PyObject* po_self, PyObject* args) {
+    const s_Message* const self = static_cast<s_Message*>(po_self);
+
     unsigned int section;
     if (!PyArg_ParseTuple(args, "I", &section)) {
         PyErr_Clear();
@@ -460,46 +471,28 @@ Message_getSection(s_Message* self, PyObject* args) {
                         "no valid type in get_section argument");
         return (NULL);
     }
-    RRsetIterator rrsi, rrsi_end;
+
     try {
-        rrsi = self->cppobj->beginSection(
-            static_cast<Message::Section>(section));
-        rrsi_end = self->cppobj->endSection(
-            static_cast<Message::Section>(section));
+        PyObjectContainer list_container(PyList_New(0));
+        const Message::Section msgsection =
+            static_cast<Message::Section>(section);
+        for_each(self->cppobj->beginSection(msgsection),
+                 self->cppobj->endSection(msgsection),
+                 RRsetInserter(list_container.get(), createRRsetObject));
+        return (list_container.release());
     } catch (const isc::OutOfRange& ex) {
         PyErr_SetString(PyExc_OverflowError, ex.what());
-        return (NULL);
     } catch (const InvalidMessageSection& ex) {
         PyErr_SetString(po_InvalidMessageSection, ex.what());
-        return (NULL);
-    } catch (...) {
-        PyErr_SetString(po_IscException,
-                        "Unexpected exception in getting section iterators");
-        return (NULL);
-    }
-
-    PyObject* list = PyList_New(0);
-    if (list == NULL) {
-        return (NULL);
-    }
-    try {
-        for (; rrsi != rrsi_end; ++rrsi) {
-            if (PyList_Append(list, createRRsetObject(**rrsi)) == -1) {
-                    Py_DECREF(list);
-                    return (NULL);
-            }
-        }
-        return (list);
     } catch (const exception& ex) {
         const string ex_what =
-            "Unexpected failure creating Question object: " +
+            "Unexpected failure in Message.get_section: " +
             string(ex.what());
         PyErr_SetString(po_IscException, ex_what.c_str());
     } catch (...) {
         PyErr_SetString(PyExc_SystemError,
-                        "Unexpected failure creating Question object");
+                        "Unexpected failure in Message.get_section");
     }
-    Py_DECREF(list);
     return (NULL);
 }
 

+ 18 - 15
src/lib/dns/python/rrset_python.cc

@@ -63,7 +63,7 @@ PyObject* RRset_toText(s_RRset* self);
 PyObject* RRset_str(PyObject* self);
 PyObject* RRset_toWire(s_RRset* self, PyObject* args);
 PyObject* RRset_addRdata(s_RRset* self, PyObject* args);
-PyObject* RRset_getRdata(s_RRset* self);
+PyObject* RRset_getRdata(PyObject* po_self, PyObject*);
 PyObject* RRset_removeRRsig(s_RRset* self);
 
 // TODO: iterator?
@@ -94,7 +94,7 @@ PyMethodDef RRset_methods[] = {
       "returned" },
     { "add_rdata", reinterpret_cast<PyCFunction>(RRset_addRdata), METH_VARARGS,
       "Adds the rdata for one RR to the RRset.\nTakes an Rdata object as an argument" },
-    { "get_rdata", reinterpret_cast<PyCFunction>(RRset_getRdata), METH_NOARGS,
+    { "get_rdata", RRset_getRdata, METH_NOARGS,
       "Returns a List containing all Rdata elements" },
     { "remove_rrsig", reinterpret_cast<PyCFunction>(RRset_removeRRsig), METH_NOARGS,
       "Clears the list of RRsigs for this RRset" },
@@ -291,22 +291,26 @@ RRset_addRdata(s_RRset* self, PyObject* args) {
 }
 
 PyObject*
-RRset_getRdata(s_RRset* self) {
-    PyObject* list = PyList_New(0);
-
-    RdataIteratorPtr it = self->cppobj->getRdataIterator();
+RRset_getRdata(PyObject* po_self, PyObject*) {
+    const s_RRset* const self = static_cast<s_RRset*>(po_self);
 
     try {
-        for (; !it->isLast(); it->next()) {
-            const rdata::Rdata *rd = &it->getCurrent();
-            if (PyList_Append(list,
-                    createRdataObject(createRdata(self->cppobj->getType(),
-                                      self->cppobj->getClass(), *rd))) == -1) {
-                Py_DECREF(list);
-                return (NULL);
+        PyObjectContainer list_container(PyList_New(0));
+
+        for (RdataIteratorPtr it = self->cppobj->getRdataIterator();
+             !it->isLast(); it->next()) {
+            if (PyList_Append(list_container.get(),
+                              PyObjectContainer(
+                                  createRdataObject(
+                                      createRdata(self->cppobj->getType(),
+                                                  self->cppobj->getClass(),
+                                                  it->getCurrent()))).get())
+                == -1) {
+                isc_throw(PyCPPWrapperException, "PyList_Append failed, "
+                          "probably due to short memory");
             }
         }
-        return (list);
+        return (list_container.release());
     } catch (const exception& ex) {
         const string ex_what =
             "Unexpected failure getting rrset Rdata: " +
@@ -316,7 +320,6 @@ RRset_getRdata(s_RRset* self) {
         PyErr_SetString(PyExc_SystemError,
                         "Unexpected failure getting rrset Rdata");
     }
-    Py_DECREF(list);
     return (NULL);
 }
 

+ 16 - 1
src/lib/dns/python/tests/message_python_test.py

@@ -17,6 +17,7 @@
 # Tests for the message part of the pydnspp module
 #
 
+import sys
 import unittest
 import os
 from pydnspp import *
@@ -230,6 +231,14 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_ANSWER)))
         self.assertEqual(2, self.r.get_rr_count(Message.SECTION_ANSWER))
 
+        # We always make a new deep copy in get_section(), so the reference
+        # count of the returned list and each of its items should be 1;
+        # otherwise they would leak.
+        self.assertEqual(1, sys.getrefcount(self.r.get_section(
+                    Message.SECTION_ANSWER)))
+        self.assertEqual(1, sys.getrefcount(self.r.get_section(
+                    Message.SECTION_ANSWER)[0]))
+
         self.assertFalse(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_AUTHORITY)))
         self.assertEqual(0, self.r.get_rr_count(Message.SECTION_AUTHORITY))
         self.r.add_rrset(Message.SECTION_AUTHORITY, self.rrset_a)
@@ -242,7 +251,7 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(section_rrset, self.r.get_section(Message.SECTION_ADDITIONAL)))
         self.assertEqual(2, self.r.get_rr_count(Message.SECTION_ADDITIONAL))
 
-    def test_add_question(self):
+    def test_add_and_get_question(self):
         self.assertRaises(TypeError, self.r.add_question, "wrong", "wrong")
         q = Question(Name("example.com"), RRClass("IN"), RRType("A"))
         qs = [q]
@@ -252,6 +261,12 @@ class MessageTest(unittest.TestCase):
         self.assertTrue(compare_rrset_list(qs, self.r.get_question()))
         self.assertEqual(1, self.r.get_rr_count(Message.SECTION_QUESTION))
 
+        # We always make a new deep copy in get_question(), so the reference
+        # count of the returned list and each of its items should be 1;
+        # otherwise they would leak.
+        self.assertEqual(1, sys.getrefcount(self.r.get_question()))
+        self.assertEqual(1, sys.getrefcount(self.r.get_question()[0]))
+
     def test_add_rrset(self):
         self.assertRaises(TypeError, self.r.add_rrset, "wrong")
         self.assertRaises(TypeError, self.r.add_rrset)
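
The new assertions above lean on a CPython detail: when a freshly created object is passed directly to sys.getrefcount(), the only reference to it is the call's argument, so a result of 1 means the binding code released every temporary reference it took while building the list. A quick standalone illustration with plain Python objects (not pydnspp types):

    import sys

    def make_list():
        # Stands in for get_question()/get_section()/get_rdata(): a brand-new
        # list holding brand-new items.
        return [object(), object()]

    print(sys.getrefcount(make_list()))     # 1: only the argument refers to it
    print(sys.getrefcount(make_list()[0]))  # 1: the temporary list is gone by now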

+ 7 - 0
src/lib/dns/python/tests/rrset_python_test.py

@@ -17,6 +17,7 @@
 # Tests for the rrtype part of the pydnspp module
 #
 
+import sys
 import unittest
 import os
 from pydnspp import *
@@ -110,6 +111,12 @@ class TestModuleSpec(unittest.TestCase):
                 ]
         self.assertEqual(rdata, self.rrset_a.get_rdata())
         self.assertEqual([], self.rrset_a_empty.get_rdata())
+
+        # We always make a new deep copy in get_rdata(), so the reference
+        # count of the returned list and each of its items should be 1;
+        # otherwise they would leak.
+        self.assertEqual(1, sys.getrefcount(self.rrset_a.get_rdata()))
+        self.assertEqual(1, sys.getrefcount(self.rrset_a.get_rdata()[0]))
         
 if __name__ == '__main__':
     unittest.main()

+ 8 - 8
src/lib/python/isc/datasrc/finder_python.cc

@@ -268,16 +268,16 @@ PyTypeObject zonefinder_type = {
 
 PyObject*
 createZoneFinderObject(isc::datasrc::ZoneFinderPtr source, PyObject* base_obj) {
-    s_ZoneFinder* py_zi = static_cast<s_ZoneFinder*>(
+    s_ZoneFinder* py_zf = static_cast<s_ZoneFinder*>(
         zonefinder_type.tp_alloc(&zonefinder_type, 0));
-    if (py_zi != NULL) {
-        py_zi->cppobj = source;
-        py_zi->base_obj = base_obj;
-    }
-    if (base_obj != NULL) {
-        Py_INCREF(base_obj);
+    if (py_zf != NULL) {
+        py_zf->cppobj = source;
+        py_zf->base_obj = base_obj;
+        if (base_obj != NULL) {
+            Py_INCREF(base_obj);
+        }
     }
-    return (py_zi);
+    return (py_zf);
 }
 
 } // namespace python

+ 3 - 3
src/lib/python/isc/datasrc/iterator_python.cc

@@ -204,9 +204,9 @@ createZoneIteratorObject(isc::datasrc::ZoneIteratorPtr source,
     if (py_zi != NULL) {
         py_zi->cppobj = source;
         py_zi->base_obj = base_obj;
-    }
-    if (base_obj != NULL) {
-        Py_INCREF(base_obj);
+        if (base_obj != NULL) {
+            Py_INCREF(base_obj);
+        }
     }
     return (py_zi);
 }

+ 18 - 0
src/lib/python/isc/datasrc/tests/datasrc_test.py

@@ -20,6 +20,7 @@ import isc.dns
 import unittest
 import os
 import shutil
+import sys
 import json
 
 TESTDATA_PATH = os.environ['TESTDATA_PATH'] + os.sep
@@ -494,6 +495,23 @@ class DataSrcUpdater(unittest.TestCase):
                          dsc.get_updater(isc.dns.Name("notexistent.example"),
                                          True))
 
+    def test_client_reference(self):
+        # Temporarily create various objects using factory methods of the
+        # client.  The created objects won't be stored anywhere and will be
+        # released immediately.  Their creation shouldn't affect the
+        # reference count of the base client.
+        dsc = isc.datasrc.DataSourceClient("sqlite3", WRITE_ZONE_DB_CONFIG)
+        orig_ref = sys.getrefcount(dsc)
+
+        dsc.find_zone(isc.dns.Name("example.com"))
+        self.assertEqual(orig_ref, sys.getrefcount(dsc))
+
+        dsc.get_iterator(isc.dns.Name("example.com."))
+        self.assertEqual(orig_ref, sys.getrefcount(dsc))
+
+        dsc.get_updater(isc.dns.Name("example.com"), True)
+        self.assertEqual(orig_ref, sys.getrefcount(dsc))
+
 if __name__ == "__main__":
     isc.log.init("bind10")
     unittest.main()
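
test_client_reference() above checks the other direction of the wrapper fixes: objects produced by the client's factory methods hold a reference to the client (the base_obj/Py_INCREF handling corrected in finder_python.cc, iterator_python.cc and updater_python.cc), and releasing such an object must give that reference back. A rough Python analogue, with made-up Client and Iterator classes standing in for the C++ wrappers:

    import sys

    class Iterator:
        def __init__(self, client):
            # Mirrors base_obj: keep the creator alive while we exist.
            self._client = client

    class Client:
        def get_iterator(self):
            return Iterator(self)

    client = Client()
    orig_ref = sys.getrefcount(client)

    # Create an iterator and drop it right away; its reference to the client
    # must be released, so the client's count is back to the original value.
    client.get_iterator()
    print(sys.getrefcount(client) == orig_ref)   # True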

+ 8 - 7
src/lib/python/isc/datasrc/updater_python.cc

@@ -270,15 +270,16 @@ PyObject*
 createZoneUpdaterObject(isc::datasrc::ZoneUpdaterPtr source,
                         PyObject* base_obj)
 {
-    s_ZoneUpdater* py_zi = static_cast<s_ZoneUpdater*>(
+    s_ZoneUpdater* py_zu = static_cast<s_ZoneUpdater*>(
         zoneupdater_type.tp_alloc(&zoneupdater_type, 0));
-    if (py_zi != NULL) {
-        py_zi->cppobj = source;
-    }
-    if (base_obj != NULL) {
-        Py_INCREF(base_obj);
+    if (py_zu != NULL) {
+        py_zu->cppobj = source;
+        py_zu->base_obj = base_obj;
+        if (base_obj != NULL) {
+            Py_INCREF(base_obj);
+        }
     }
-    return (py_zi);
+    return (py_zu);
 }
 
 } // namespace python