Browse Source

[1183] add initial use of iterator context for getRecords()

(to replace searchForRecords->getNextRecord combo)
Jelte Jansen 13 years ago
parent
commit
171088e69f

+ 10 - 2
src/lib/datasrc/database.cc

@@ -178,12 +178,20 @@ DatabaseClient::Finder::getRRset(const isc::dns::Name& name,
                                  bool want_ns)
 {
     RRsigStore sig_store;
-    database_->searchForRecords(zone_id_, name.toText());
     bool records_found = false;
     isc::dns::RRsetPtr result_rrset;
 
+    // Request the context
+    DatabaseAccessor::IteratorContextPtr
+        context(database_->getRecords(name, zone_id_));
+    // It must not return NULL, that's a bug of the implementation
+    if (context == DatabaseAccessor::IteratorContextPtr()) {
+        isc_throw(isc::Unexpected, "Iterator context null at " +
+                  name.toText());
+    }
+
     std::string columns[DatabaseAccessor::COLUMN_COUNT];
-    while (database_->getNextRecord(columns, DatabaseAccessor::COLUMN_COUNT)) {
+    while (context->getNext(columns, DatabaseAccessor::COLUMN_COUNT)) {
         if (!records_found) {
             records_found = true;
         }

+ 32 - 0
src/lib/datasrc/database.h

@@ -125,6 +125,38 @@ public:
     typedef boost::shared_ptr<IteratorContext> IteratorContextPtr;
 
     /**
+     * \brief Creates an iterator context for a specific name.
+     *
+     * This should create a new iterator context to be used by
+     * DatabaseConnection's ZoneIterator. It can be created based on the name
+     * or the ID (returned from getZone()), what is more comfortable for the
+     * database implementation. Both are provided (and are guaranteed to match,
+     * the DatabaseClient first looks up the zone ID and then calls this).
+     *
+     * The default implementation throws isc::NotImplemented, to allow
+     * "minimal" implementations of the connection not supporting optional
+     * functionality.
+     *
+     * \param name The name to search for.
+     * \param id The ID of the zone, returned from getZone().
+     * \return Newly created iterator context. Must not be NULL.
+     */
+    virtual IteratorContextPtr getRecords(const isc::dns::Name& name,
+                                          int id) const
+    {
+        /*
+         * This is a compromise. We need to document the parameters in doxygen,
+         * so they need a name, but then it complains about unused parameter.
+         * This is a NOP that "uses" the parameters.
+         */
+        static_cast<void>(name);
+        static_cast<void>(id);
+
+        isc_throw(isc::NotImplemented,
+                  "This database datasource can't be iterated");
+    }
+
+    /**
      * \brief Creates an iterator context for the whole zone.
      *
      * This should create a new iterator context to be used by

+ 56 - 3
src/lib/datasrc/sqlite3_accessor.cc

@@ -368,7 +368,10 @@ convertToPlainChar(const unsigned char* ucp,
 // it is, just provide data from it.
 class SQLite3Database::Context : public DatabaseAccessor::IteratorContext {
 public:
+    // Construct an iterator for all records. When constructed this
+    // way, the getNext() call will copy all fields
     Context(const boost::shared_ptr<const SQLite3Database>& database, int id) :
+        iterator_type_(ITT_ALL),
         database_(database),
         statement(NULL)
     {
@@ -379,6 +382,30 @@ public:
                       " to SQL statement (iterate)");
         }
     }
+
+    // Construct an iterator for records with a specific name. When constructed
+    // this way, the getNext() call will copy all fields except name
+    Context(const boost::shared_ptr<const SQLite3Database>& database, int id,
+            const isc::dns::Name& name) :
+        iterator_type_(ITT_NAME),
+        database_(database),
+        statement(NULL)
+    {
+        // We create the statement now and then just keep getting data from it
+        // TODO move to private and clean up error
+        statement = prepare(database->dbparameters_->db_, q_any_str);
+        if (sqlite3_bind_int(statement, 1, id) != SQLITE_OK) {
+            isc_throw(SQLite3Error, "Could not bind " << id <<
+                      " to SQL statement");
+        }
+        if (sqlite3_bind_text(statement, 2, name.toText().c_str(), -1,
+                              SQLITE_TRANSIENT) != SQLITE_OK) {
+            sqlite3_finalize(statement);
+            isc_throw(SQLite3Error, "Could not bind " << id <<
+                      " to SQL statement");
+        }
+    }
+
     bool getNext(std::string data[], size_t size) {
         if (size != COLUMN_COUNT) {
             isc_throw(DataSourceError, "getNext received size of " << size <<
@@ -387,9 +414,14 @@ public:
         // If there's another row, get it
         int rc(sqlite3_step(statement));
         if (rc == SQLITE_ROW) {
-            for (size_t i(0); i < size; ++ i) {
-                data[i] = convertToPlainChar(sqlite3_column_text(statement, i),
-                                             database_->dbparameters_);
+            // For both types, we copy the first four columns
+            copyColumn(data, TYPE_COLUMN);
+            copyColumn(data, TTL_COLUMN);
+            copyColumn(data, SIGTYPE_COLUMN);
+            copyColumn(data, RDATA_COLUMN);
+            // Only copy Name if we are iterating over every record
+            if (iterator_type_ == ITT_ALL) {
+                copyColumn(data, NAME_COLUMN);
             }
             return (true);
         } else if (rc != SQLITE_DONE) {
@@ -399,6 +431,7 @@ public:
         }
         return (false);
     }
+
     virtual ~Context() {
         if (statement) {
             sqlite3_finalize(statement);
@@ -406,11 +439,31 @@ public:
     }
 
 private:
+    // Depending on which constructor is called, behaviour is slightly
+    // different. We keep track of what to do with the iterator type
+    // See description of getNext() and the constructors
+    enum IteratorType {
+        ITT_ALL,
+        ITT_NAME
+    };
+
+    void copyColumn(std::string data[], int column) {
+        data[column] = convertToPlainChar(sqlite3_column_text(statement,
+                                                              column),
+                                          database_->dbparameters_);
+    }
+
+    IteratorType iterator_type_;
     boost::shared_ptr<const SQLite3Database> database_;
     sqlite3_stmt *statement;
 };
 
 DatabaseAccessor::IteratorContextPtr
+SQLite3Database::getRecords(const isc::dns::Name& name, int id) const {
+    return (IteratorContextPtr(new Context(shared_from_this(), id, name)));
+}
+
+DatabaseAccessor::IteratorContextPtr
 SQLite3Database::getAllRecords(const isc::dns::Name&, int id) const {
     return (IteratorContextPtr(new Context(shared_from_this(), id)));
 }

+ 6 - 1
src/lib/datasrc/sqlite3_accessor.h

@@ -91,9 +91,14 @@ public:
      */
     virtual std::pair<bool, int> getZone(const isc::dns::Name& name) const;
 
+    /// \brief Implementation of DatabaseAbstraction::getRecords
+    virtual IteratorContextPtr getRecords(const isc::dns::Name& name,
+                                          int id) const;
+
     /// \brief Implementation of DatabaseAbstraction::getAllRecords
     virtual IteratorContextPtr getAllRecords(const isc::dns::Name&,
-                                                  int id) const;
+                                             int id) const;
+
     /**
      * \brief Start a new search for the given name in the given zone.
      *

+ 75 - 0
src/lib/datasrc/tests/database_unittest.cc

@@ -87,6 +87,65 @@ public:
         fillData();
     }
 private:
+    class MockNameIteratorContext : public IteratorContext {
+    public:
+        MockNameIteratorContext(const MockAccessor& mock_accessor, int zone_id,
+                                const isc::dns::Name& name) :
+            searched_name_(name.toText()), cur_record_(0)
+        {
+            // 'hardcoded' name to trigger exceptions (for testing
+            // the error handling of find() (the other on is below in
+            // if the name is "exceptiononsearch" it'll raise an exception here
+            if (searched_name_ == "dsexception.in.search.") {
+                isc_throw(DataSourceError, "datasource exception on search");
+            } else if (searched_name_ == "iscexception.in.search.") {
+                isc_throw(isc::Exception, "isc exception on search");
+            } else if (searched_name_ == "basicexception.in.search.") {
+                throw std::exception();
+            }
+
+            // we're not aiming for efficiency in this test, simply
+            // copy the relevant vector from records
+            if (zone_id == 42) {
+                if (mock_accessor.records.count(searched_name_) > 0) {
+                    cur_name = mock_accessor.records.find(searched_name_)->second;
+                } else {
+                    cur_name.clear();
+                }
+            } else {
+                cur_name.clear();
+            }
+        }
+
+        virtual bool getNext(std::string columns[], size_t column_count) {
+            if (searched_name_ == "dsexception.in.getnext.") {
+                isc_throw(DataSourceError, "datasource exception on getnextrecord");
+            } else if (searched_name_ == "iscexception.in.getnext.") {
+                isc_throw(isc::Exception, "isc exception on getnextrecord");
+            } else if (searched_name_ == "basicexception.in.getnext.") {
+                throw std::exception();
+            }
+
+            if (column_count != DatabaseAccessor::COLUMN_COUNT) {
+                isc_throw(DataSourceError, "Wrong column count in getNextRecord");
+            }
+            if (cur_record_ < cur_name.size()) {
+                for (size_t i = 0; i < column_count; ++i) {
+                    columns[i] = cur_name[cur_record_][i];
+                }
+                cur_record_++;
+                return (true);
+            } else {
+                return (false);
+            }
+        }
+
+    private:
+        const std::string searched_name_;
+        int cur_record_;
+        std::vector< std::vector<std::string> > cur_name;
+    };
+
     class MockIteratorContext : public IteratorContext {
     private:
         int step;
@@ -194,6 +253,20 @@ public:
         }
     }
 
+    virtual IteratorContextPtr getRecords(const Name& name, int id) const {
+        if (id == 42) {
+            return (IteratorContextPtr(new MockNameIteratorContext(*this, id, name)));
+        } else if (id == 13) {
+            return (IteratorContextPtr());
+        } else if (id == 0) {
+            return (IteratorContextPtr(new EmptyIteratorContext()));
+        } else if (id == -1) {
+            return (IteratorContextPtr(new BadIteratorContext()));
+        } else {
+            isc_throw(isc::Unexpected, "Unknown zone ID");
+        }
+    }
+
     virtual void searchForRecords(int zone_id, const std::string& name) {
         search_running_ = true;
 
@@ -448,6 +521,8 @@ private:
 // This tests the default getAllRecords behaviour, throwing NotImplemented
 TEST(DatabaseConnectionTest, getAllRecords) {
     // The parameters don't matter
+    EXPECT_THROW(NopAccessor().getRecords(Name("."), 1),
+                 isc::NotImplemented);
     EXPECT_THROW(NopAccessor().getAllRecords(Name("."), 1),
                  isc::NotImplemented);
 }

+ 58 - 59
src/lib/datasrc/tests/sqlite3_accessor_unittest.cc

@@ -181,92 +181,91 @@ TEST_F(SQLite3Access, getRecords) {
     const size_t column_count = DatabaseAccessor::COLUMN_COUNT;
     std::string columns[column_count];
 
+    // TODO: can't do this anymore
     // without search, getNext() should return false
-    EXPECT_FALSE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "", "", "", "", "");
+    //EXPECT_FALSE(context->getNext(columns, column_count));
+    //checkRecordRow(columns, "", "", "", "", "");
 
-    db->searchForRecords(zone_id, "foo.bar.");
-    EXPECT_FALSE(db->getNextRecord(columns, column_count));
+    DatabaseAccessor::IteratorContextPtr
+        context(db->getRecords(Name("foo.bar"), 1));
+    ASSERT_NE(DatabaseAccessor::IteratorContextPtr(),
+              context);
+    EXPECT_FALSE(context->getNext(columns, column_count));
     checkRecordRow(columns, "", "", "", "", "");
 
-    db->searchForRecords(zone_id, "");
-    EXPECT_FALSE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "", "", "", "", "");
+    // TODO can't pass incomplete name anymore
+    //context = db->getRecords(Name(""), zone_id);
+    //EXPECT_FALSE(context->getNext(columns, column_count));
+    //checkRecordRow(columns, "", "", "", "", "");
 
     // Should error on a bad number of columns
-    EXPECT_THROW(db->getNextRecord(columns, 4), DataSourceError);
-    EXPECT_THROW(db->getNextRecord(columns, 6), DataSourceError);
+    EXPECT_THROW(context->getNext(columns, 4), DataSourceError);
+    EXPECT_THROW(context->getNext(columns, 6), DataSourceError);
 
     // now try some real searches
-    db->searchForRecords(zone_id, "foo.example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+    context = db->getRecords(Name("foo.example.com."), zone_id);
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "CNAME", "3600", "",
-                   "cnametest.example.org.", "foo.example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "cnametest.example.org.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "CNAME",
                    "CNAME 5 3 3600 20100322084538 20100220084538 33495 "
-                   "example.com. FAKEFAKEFAKEFAKE", "foo.example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "NSEC", "7200", "",
-                   "mail.example.com. CNAME RRSIG NSEC", "foo.example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "mail.example.com. CNAME RRSIG NSEC", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "7200", "NSEC",
                    "NSEC 5 3 7200 20100322084538 20100220084538 33495 "
-                   "example.com. FAKEFAKEFAKEFAKE", "foo.example.com.");
-    EXPECT_FALSE(db->getNextRecord(columns, column_count));
+                   "example.com. FAKEFAKEFAKEFAKE", "");
+    EXPECT_FALSE(context->getNext(columns, column_count));
     // with no more records, the array should not have been modified
     checkRecordRow(columns, "RRSIG", "7200", "NSEC",
                    "NSEC 5 3 7200 20100322084538 20100220084538 33495 "
-                   "example.com. FAKEFAKEFAKEFAKE", "foo.example.com.");
+                   "example.com. FAKEFAKEFAKEFAKE", "");
 
-    db->searchForRecords(zone_id, "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+    context = db->getRecords(Name("example.com."), zone_id);
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "SOA", "3600", "",
                    "master.example.com. admin.example.com. "
-                   "1234 3600 1800 2419200 7200", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "1234 3600 1800 2419200 7200", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "SOA",
                    "SOA 5 2 3600 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "NS", "1200", "", "dns01.example.com.",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "NS", "3600", "", "dns02.example.com.",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "NS", "1800", "", "dns03.example.com.",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
+    checkRecordRow(columns, "NS", "1200", "", "dns01.example.com.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
+    checkRecordRow(columns, "NS", "3600", "", "dns02.example.com.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
+    checkRecordRow(columns, "NS", "1800", "", "dns03.example.com.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "NS",
                    "NS 5 2 3600 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
-    checkRecordRow(columns, "MX", "3600", "", "10 mail.example.com.",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
+    checkRecordRow(columns, "MX", "3600", "", "10 mail.example.com.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "MX", "3600", "",
-                   "20 mail.subzone.example.com.", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "20 mail.subzone.example.com.", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "MX",
                    "MX 5 2 3600 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "NSEC", "7200", "",
-                   "cname-ext.example.com. NS SOA MX RRSIG NSEC DNSKEY",
-                   "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "cname-ext.example.com. NS SOA MX RRSIG NSEC DNSKEY", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "7200", "NSEC",
                    "NSEC 5 2 7200 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "DNSKEY", "3600", "",
                    "256 3 5 AwEAAcOUBllYc1hf7ND9uDy+Yz1BF3sI0m4q NGV7W"
                    "cTD0WEiuV7IjXgHE36fCmS9QsUxSSOV o1I/FMxI2PJVqTYHkX"
                    "FBS7AzLGsQYMU7UjBZ SotBJ6Imt5pXMu+lEDNy8TOUzG3xm7g"
-                   "0qcbW YF6qCEfvZoBtAqi5Rk7Mlrqs8agxYyMx", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "0qcbW YF6qCEfvZoBtAqi5Rk7Mlrqs8agxYyMx", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "DNSKEY", "3600", "",
                    "257 3 5 AwEAAe5WFbxdCPq2jZrZhlMj7oJdff3W7syJ tbvzg"
                    "62tRx0gkoCDoBI9DPjlOQG0UAbj+xUV 4HQZJStJaZ+fHU5AwV"
@@ -275,20 +274,20 @@ TEST_F(SQLite3Access, getRecords) {
                    "qiODyNZYQ+ZrLmF0KIJ2yPN3iO6Zq 23TaOrVTjB7d1a/h31OD"
                    "fiHAxFHrkY3t3D5J R9Nsl/7fdRmSznwtcSDgLXBoFEYmw6p86"
                    "Acv RyoYNcL1SXjaKVLG5jyU3UR+LcGZT5t/0xGf oIK/aKwEN"
-                   "rsjcKZZj660b1M=", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "rsjcKZZj660b1M=", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "DNSKEY",
                    "DNSKEY 5 2 3600 20100322084538 20100220084538 "
-                   "4456 example.com. FAKEFAKEFAKEFAKE", "example.com.");
-    ASSERT_TRUE(db->getNextRecord(columns, column_count));
+                   "4456 example.com. FAKEFAKEFAKEFAKE", "");
+    ASSERT_TRUE(context->getNext(columns, column_count));
     checkRecordRow(columns, "RRSIG", "3600", "DNSKEY",
                    "DNSKEY 5 2 3600 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE", "example.com.");
-    EXPECT_FALSE(db->getNextRecord(columns, column_count));
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
+    EXPECT_FALSE(context->getNext(columns, column_count));
     // getnextrecord returning false should mean array is not altered
     checkRecordRow(columns, "RRSIG", "3600", "DNSKEY",
                    "DNSKEY 5 2 3600 20100322084538 20100220084538 "
-                   "33495 example.com. FAKEFAKEFAKEFAKE", "example.com.");
+                   "33495 example.com. FAKEFAKEFAKEFAKE", "");
 }
 
 } // end anonymous namespace