Browse Source

[master] Merge branch 'trac1217'

Conflicts:
	src/lib/datasrc/database.cc
	src/lib/datasrc/tests/database_unittest.cc
Jelte Jansen 13 years ago
parent
commit
46c4fc8c24

+ 8 - 1
src/lib/datasrc/client.h

@@ -215,11 +215,18 @@ public:
     ///
     /// \param name The name of zone apex to be traversed. It doesn't do
     ///     nearest match as findZone.
+    /// \param adjust_ttl If true, the iterator will treat RRs with the same
+    ///                   name and type but different TTL values to be of the
+    ///                   same RRset, and will adjust the TTL to the lowest
+    ///                   value found. If false, it will consider the RR to
+    ///                   belong to a different RRset.
     /// \return Pointer to the iterator.
-    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name) const {
+    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name,
+                                        bool adjust_ttl = true) const {
         // This is here to both document the parameter in doxygen (therefore it
         // needs a name) and avoid unused parameter warning.
         static_cast<void>(name);
+        static_cast<void>(adjust_ttl);
 
         isc_throw(isc::NotImplemented,
                   "Data source doesn't support iteration");

+ 21 - 10
src/lib/datasrc/database.cc

@@ -706,10 +706,12 @@ class DatabaseIterator : public ZoneIterator {
 public:
     DatabaseIterator(shared_ptr<DatabaseAccessor> accessor,
                      const Name& zone_name,
-                     const RRClass& rrclass) :
+                     const RRClass& rrclass,
+                     bool adjust_ttl) :
         accessor_(accessor),
         class_(rrclass),
-        ready_(true)
+        ready_(true),
+        adjust_ttl_(adjust_ttl)
     {
         // Get the zone
         const pair<bool, int> zone(accessor_->getZone(zone_name.toText()));
@@ -767,13 +769,17 @@ public:
         const RRType rtype(rtype_str);
         RRsetPtr rrset(new RRset(name, class_, rtype, RRTTL(ttl)));
         while (data_ready_ && name_ == name_str && rtype_str == rtype_) {
-            if (ttl_ != ttl) {
-                if (ttl < ttl_) {
-                    ttl_ = ttl;
-                    rrset->setTTL(RRTTL(ttl));
+            if (adjust_ttl_) {
+                if (ttl_ != ttl) {
+                    if (ttl < ttl_) {
+                        ttl_ = ttl;
+                        rrset->setTTL(RRTTL(ttl));
+                    }
+                    LOG_WARN(logger, DATASRC_DATABASE_ITERATE_TTL_MISMATCH).
+                        arg(name_).arg(class_).arg(rtype_).arg(rrset->getTTL());
                 }
-                LOG_WARN(logger, DATASRC_DATABASE_ITERATE_TTL_MISMATCH).
-                    arg(name_).arg(class_).arg(rtype_).arg(rrset->getTTL());
+            } else if (ttl_ != ttl) {
+                break;
             }
             rrset->addRdata(rdata::createRdata(rtype, class_, rdata_));
             getData();
@@ -806,15 +812,20 @@ private:
     bool ready_, data_ready_;
     // Data of the next row
     string name_, rtype_, rdata_, ttl_;
+    // Whether to modify differing TTL values, or treat a different TTL as
+    // a different RRset
+    bool adjust_ttl_;
 };
 
 }
 
 ZoneIteratorPtr
-DatabaseClient::getIterator(const isc::dns::Name& name) const {
+DatabaseClient::getIterator(const isc::dns::Name& name,
+                            bool adjust_ttl) const
+{
     ZoneIteratorPtr iterator = ZoneIteratorPtr(new DatabaseIterator(
                                                    accessor_->clone(), name,
-                                                   rrclass_));
+                                                   rrclass_, adjust_ttl));
     LOG_DEBUG(logger, DBG_TRACE_DETAILED, DATASRC_DATABASE_ITERATE).
         arg(name);
 

+ 7 - 1
src/lib/datasrc/database.h

@@ -863,9 +863,15 @@ public:
      * \exception Anything else the underlying DatabaseConnection might
      *     want to throw.
      * \param name The origin of the zone to iterate.
+     * \param adjust_ttl If true, the iterator will treat RRs with the same
+     *                   name and type but different TTL values to be of the
+     *                   same RRset, and will adjust the TTL to the lowest
+     *                   value found. If false, it will consider the RR to
+     *                   belong to a different RRset.
      * \return Shared pointer to the iterator (it will never be NULL)
      */
-    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name) const;
+    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name,
+                                        bool adjust_ttl = true) const;
 
     /// This implementation internally clones the accessor from the one
     /// used in the client and starts a separate transaction using the cloned

+ 5 - 1
src/lib/datasrc/memory_datasrc.cc

@@ -789,7 +789,11 @@ public:
 } // End of anonymous namespace
 
 ZoneIteratorPtr
-InMemoryClient::getIterator(const Name& name) const {
+InMemoryClient::getIterator(const Name& name, bool) const {
+    // note: adjust_ttl argument is ignored, as the RRsets are already
+    // individually stored, and hence cannot have different TTLs anymore at
+    // this point
+
     ZoneTable::FindResult result(impl_->zone_table.findZone(name));
     if (result.code != result::SUCCESS) {
         isc_throw(DataSourceError, "No such zone: " + name.toText());

+ 2 - 1
src/lib/datasrc/memory_datasrc.h

@@ -272,7 +272,8 @@ public:
     virtual FindResult findZone(const isc::dns::Name& name) const;
 
     /// \brief Implementation of the getIterator method
-    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name) const;
+    virtual ZoneIteratorPtr getIterator(const isc::dns::Name& name,
+                                        bool adjust_ttl = true) const;
 
     /// In-memory data source is read-only, so this derived method will
     /// result in a NotImplemented exception.

+ 0 - 1
src/lib/datasrc/sqlite3_accessor.cc

@@ -472,7 +472,6 @@ public:
         accessor_(accessor),
         statement_(NULL),
         name_(name)
-
     {
         // We create the statement now and then just keep getting data from it
         statement_ = prepare(accessor->dbparameters_->db_,

+ 77 - 1
src/lib/datasrc/tests/database_unittest.cc

@@ -440,10 +440,22 @@ private:
                     data[DatabaseAccessor::TTL_COLUMN] = "300";
                     data[DatabaseAccessor::RDATA_COLUMN] = "2001:db8::2";
                     return (true);
+                case 6:
+                    data[DatabaseAccessor::NAME_COLUMN] = "ttldiff.example.org";
+                    data[DatabaseAccessor::TYPE_COLUMN] = "A";
+                    data[DatabaseAccessor::TTL_COLUMN] = "300";
+                    data[DatabaseAccessor::RDATA_COLUMN] = "192.0.2.1";
+                    return (true);
+                case 7:
+                    data[DatabaseAccessor::NAME_COLUMN] = "ttldiff.example.org";
+                    data[DatabaseAccessor::TYPE_COLUMN] = "A";
+                    data[DatabaseAccessor::TTL_COLUMN] = "600";
+                    data[DatabaseAccessor::RDATA_COLUMN] = "192.0.2.2";
+                    return (true);
                 default:
                     ADD_FAILURE() <<
                         "Request past the end of iterator context";
-                case 6:
+                case 8:
                     return (false);
             }
         }
@@ -1060,6 +1072,16 @@ TYPED_TEST(DatabaseClientTest, iterator) {
     this->expected_rdatas_.push_back("2001:db8::2");
     checkRRset(rrset, Name("x.example.org"), this->qclass_, RRType::AAAA(),
                RRTTL(300), this->expected_rdatas_);
+
+    rrset = it->getNextRRset();
+    ASSERT_NE(ConstRRsetPtr(), rrset);
+    this->expected_rdatas_.clear();
+    this->expected_rdatas_.push_back("192.0.2.1");
+    this->expected_rdatas_.push_back("192.0.2.2");
+    checkRRset(rrset, Name("ttldiff.example.org"), this->qclass_, RRType::A(),
+               RRTTL(300), this->expected_rdatas_);
+
+    EXPECT_EQ(ConstRRsetPtr(), it->getNextRRset());
 }
 
 // This has inconsistent TTL in the set (the rest, like nonsense in
@@ -1200,6 +1222,60 @@ doFindTest(ZoneFinder& finder,
     }
 }
 
+// When asking for an RRset where RRs somehow have different TTLs, it should 
+// convert to the lowest one.
+TEST_F(MockDatabaseClientTest, ttldiff) {
+    ZoneIteratorPtr it(this->client_->getIterator(Name("example.org")));
+    // Walk through the full iterator, we should see 1 rrset with name
+    // ttldiff1.example.org., and two rdatas. Same for ttldiff2
+    Name name("ttldiff.example.org.");
+    bool found = false;
+    //bool found2 = false;
+    ConstRRsetPtr rrset = it->getNextRRset();
+    while(rrset != ConstRRsetPtr()) {
+        if (rrset->getName() == name) {
+            ASSERT_FALSE(found);
+            ASSERT_EQ(2, rrset->getRdataCount());
+            ASSERT_EQ(RRTTL(300), rrset->getTTL());
+            found = true;
+        }
+        rrset = it->getNextRRset();
+    }
+    ASSERT_TRUE(found);
+}
+
+// Unless we ask for individual RRs in our iterator request. In that case
+// every RR should go into its own 'rrset'
+TEST_F(MockDatabaseClientTest, ttldiff_no_adjust_ttl) {
+    ZoneIteratorPtr it(this->client_->getIterator(Name("example.org"), false));
+
+    // Walk through the full iterator, we should see 1 rrset with name
+    // ttldiff1.example.org., and two rdatas. Same for ttldiff2
+    Name name("ttldiff.example.org.");
+    int found1 = false;
+    int found2 = false;
+    ConstRRsetPtr rrset = it->getNextRRset();
+    while(rrset != ConstRRsetPtr()) {
+        if (rrset->getName() == name) {
+            ASSERT_EQ(1, rrset->getRdataCount());
+            // We should find 1 'rrset' with TTL 300 and one with TTL 600
+            if (rrset->getTTL() == RRTTL(300)) {
+                ASSERT_FALSE(found1);
+                found1 = true;
+            } else if (rrset->getTTL() == RRTTL(600)) {
+                ASSERT_FALSE(found2);
+                found2 = true;
+            } else {
+                FAIL() << "Found unexpected TTL: " <<
+                          rrset->getTTL().toText();
+            }
+        }
+        rrset = it->getNextRRset();
+    }
+    ASSERT_TRUE(found1);
+    ASSERT_TRUE(found2);
+}
+
 TYPED_TEST(DatabaseClientTest, find) {
     shared_ptr<DatabaseClient::Finder> finder(this->getFinder());
 

+ 6 - 1
src/lib/python/isc/datasrc/client_inc.cc

@@ -89,7 +89,7 @@ None\n\
 ";
 
 const char* const DataSourceClient_getIterator_doc = "\
-get_iterator(name) -> ZoneIterator\n\
+get_iterator(name, adjust_ttl=True) -> ZoneIterator\n\
 \n\
 Returns an iterator to the given zone.\n\
 \n\
@@ -111,6 +111,11 @@ anything else.\n\
 Parameters:\n\
   isc.dns.Name The name of zone apex to be traversed. It doesn't do\n\
                nearest match as find_zone.\n\
+  adjust_ttl   If True, the iterator will treat RRs with the same\n\
+               name and type but different TTL values to be of the\n\
+               same RRset, and will adjust the TTL to the lowest\n\
+               value found. If false, it will consider the RR to\n\
+               belong to a different RRset.\n\
 \n\
 Return Value(s): Pointer to the iterator.\n\
 ";

+ 19 - 3
src/lib/python/isc/datasrc/client_python.cc

@@ -83,11 +83,27 @@ DataSourceClient_findZone(PyObject* po_self, PyObject* args) {
 PyObject*
 DataSourceClient_getIterator(PyObject* po_self, PyObject* args) {
     s_DataSourceClient* const self = static_cast<s_DataSourceClient*>(po_self);
-    PyObject *name_obj;
-    if (PyArg_ParseTuple(args, "O!", &name_type, &name_obj)) {
+    PyObject* name_obj;
+    PyObject* adjust_ttl_obj = NULL;
+    if (PyArg_ParseTuple(args, "O!|O", &name_type, &name_obj,
+                         &adjust_ttl_obj)) {
         try {
+            bool adjust_ttl = true;
+            if (adjust_ttl_obj != NULL) {
+                // store result in local var so we can explicitely check for
+                // -1 error return value
+                int adjust_ttl_no = PyObject_Not(adjust_ttl_obj);
+                if (adjust_ttl_no == 1) {
+                    adjust_ttl = false;
+                } else if (adjust_ttl_no == -1) {
+                    PyErr_SetString(getDataSourceException("Error"),
+                                    "Error getting value of adjust_ttl");
+                    return (NULL);
+                }
+            }
             return (createZoneIteratorObject(
-                self->cppobj->getInstance().getIterator(PyName_ToName(name_obj)),
+                self->cppobj->getInstance().getIterator(PyName_ToName(name_obj),
+                                                        adjust_ttl),
                 po_self));
         } catch (const isc::NotImplemented& ne) {
             PyErr_SetString(getDataSourceException("NotImplemented"),

+ 27 - 2
src/lib/python/isc/datasrc/tests/datasrc_test.py

@@ -63,7 +63,7 @@ def check_for_rrset(expected_rrsets, rrset):
 
 class DataSrcClient(unittest.TestCase):
 
-    def test_constructors(self):
+    def test_(self):
         # can't construct directly
         self.assertRaises(TypeError, isc.datasrc.ZoneIterator)
 
@@ -87,7 +87,7 @@ class DataSrcClient(unittest.TestCase):
 
         # for RRSIGS, the TTL's are currently modified. This test should
         # start failing when we fix that.
-        rrs = dsc.get_iterator(isc.dns.Name("sql1.example.com."))
+        rrs = dsc.get_iterator(isc.dns.Name("sql1.example.com."), False)
 
         # we do not know the order in which they are returned by the iterator
         # but we do want to check them, so we put all records into one list
@@ -137,6 +137,13 @@ class DataSrcClient(unittest.TestCase):
                   ])
         # For RRSIGS, we can't add the fake data through the API, so we
         # simply pass no rdata at all (which is skipped by the check later)
+        
+        # Since we passed adjust_ttl = False to get_iterator, we get several
+        # sets of RRSIGs, one for each TTL
+        add_rrset(expected_rrset_list, name, rrclass,
+                  isc.dns.RRType.RRSIG(), isc.dns.RRTTL(3600), None)
+        add_rrset(expected_rrset_list, name, rrclass,
+                  isc.dns.RRType.RRSIG(), isc.dns.RRTTL(7200), None)
         add_rrset(expected_rrset_list, name, rrclass,
                   isc.dns.RRType.RRSIG(), isc.dns.RRTTL(3600), None)
         add_rrset(expected_rrset_list, name, rrclass,
@@ -158,6 +165,8 @@ class DataSrcClient(unittest.TestCase):
                   ])
         add_rrset(expected_rrset_list, name, rrclass,
                   isc.dns.RRType.RRSIG(), isc.dns.RRTTL(3600), None)
+        add_rrset(expected_rrset_list, name, rrclass,
+                  isc.dns.RRType.RRSIG(), isc.dns.RRTTL(7200), None)
 
         # rrs is an iterator, but also has direct get_next_rrset(), use
         # the latter one here
@@ -179,10 +188,26 @@ class DataSrcClient(unittest.TestCase):
         # instead of failing?
         self.assertRaises(isc.datasrc.Error, rrs.get_next_rrset)
 
+        # Without the adjust_ttl argument, it should return 55 RRsets
+        dsc = isc.datasrc.DataSourceClient("sqlite3", READ_ZONE_DB_CONFIG)
         rrets = dsc.get_iterator(isc.dns.Name("example.com"))
         # there are more than 80 RRs in this zone... let's just count them
         # (already did a full check of the smaller zone above)
         self.assertEqual(55, len(list(rrets)))
+
+        # same test, but now with explicit True argument for adjust_ttl
+        dsc = isc.datasrc.DataSourceClient("sqlite3", READ_ZONE_DB_CONFIG)
+        rrets = dsc.get_iterator(isc.dns.Name("example.com"), True)
+        # there are more than 80 RRs in this zone... let's just count them
+        # (already did a full check of the smaller zone above)
+        self.assertEqual(55, len(list(rrets)))
+
+        # Count should be 71 if we request individual rrsets for differing ttls
+        dsc = isc.datasrc.DataSourceClient("sqlite3", READ_ZONE_DB_CONFIG)
+        rrets = dsc.get_iterator(isc.dns.Name("example.com"), False)
+        # there are more than 80 RRs in this zone... let's just count them
+        # (already did a full check of the smaller zone above)
+        self.assertEqual(71, len(list(rrets)))
         # TODO should we catch this (iterating past end) and just return None
         # instead of failing?
         self.assertRaises(isc.datasrc.Error, rrs.get_next_rrset)