Browse Source

[trac1062] tests and some additional code

Jelte Jansen 13 years ago
parent
commit
71b0ae9ddb

+ 7 - 1
src/lib/datasrc/database.cc

@@ -74,6 +74,7 @@ DatabaseClient::Finder::find(const isc::dns::Name& name,
     connection_.searchForRecords(zone_id_, name.toText());
 
     isc::dns::RRsetPtr result_rrset;
+    ZoneFinder::Result result_status = NXRRSET;
 
     std::vector<std::string> columns;
     while (connection_.getNextRecord(columns)) {
@@ -96,6 +97,7 @@ DatabaseClient::Finder::find(const isc::dns::Name& name,
                                                                       getClass(),
                                                                       cur_type,
                                                                       cur_ttl));
+                result_status = SUCCESS;
             } else {
                 // We have existing data from earlier calls, do some checks
                 // and updates if necessary
@@ -120,11 +122,15 @@ DatabaseClient::Finder::find(const isc::dns::Name& name,
             result_rrset->addRdata(isc::dns::rdata::createRdata(cur_type,
                                                                 getClass(),
                                                                 columns[3]));
+            result_status = CNAME;
+        } else if (cur_type == isc::dns::RRType::RRSIG()) {
+            // if we have data already, check covered type
+            // if not, covered type must be CNAME or type requested
         }
     }
 
     if (result_rrset) {
-        return (FindResult(SUCCESS, result_rrset));
+        return (FindResult(result_status, result_rrset));
     } else if (records_found) {
         return (FindResult(NXRRSET, isc::dns::ConstRRsetPtr()));
     } else {

+ 6 - 2
src/lib/datasrc/database.h

@@ -78,7 +78,7 @@ public:
      * \param zone_id The zone to search in, as returned by getZone()
      * \param name The name of the records to find
      */
-    virtual void searchForRecords(int zone_id, const std::string& name) const = 0;
+    virtual void searchForRecords(int zone_id, const std::string& name) = 0;
 
     /**
      * \brief Retrieves the next record from the search started with searchForRecords()
@@ -93,7 +93,7 @@ public:
      *                and rdata). If there was no data, the vector is untouched.
      * \return true if there was a next record, false if there was not
      */
-    virtual bool getNextRecord(std::vector<std::string>& columns) const = 0;
+    virtual bool getNextRecord(std::vector<std::string>& columns) = 0;
 };
 
 /**
@@ -154,6 +154,10 @@ public:
         Finder(boost::shared_ptr<DatabaseConnection> connection, int zone_id);
         virtual isc::dns::Name getOrigin() const;
         virtual isc::dns::RRClass getClass() const;
+
+        /**
+         * \brief Find an RRset in the datasource
+         */
         virtual FindResult find(const isc::dns::Name& name,
                                 const isc::dns::RRType& type,
                                 isc::dns::RRsetList* target = NULL,

+ 2 - 2
src/lib/datasrc/sqlite3_connection.cc

@@ -320,7 +320,7 @@ SQLite3Connection::getZone(const isc::dns::Name& name) const {
 }
 
 void
-SQLite3Connection::searchForRecords(int zone_id, const std::string& name) const {
+SQLite3Connection::searchForRecords(int zone_id, const std::string& name) {
     sqlite3_reset(dbparameters_->q_any_);
     sqlite3_clear_bindings(dbparameters_->q_any_);
     sqlite3_bind_int(dbparameters_->q_any_, 1, zone_id);
@@ -341,7 +341,7 @@ convertToPlainChar(const unsigned char* ucp) {
 }
 
 bool
-SQLite3Connection::getNextRecord(std::vector<std::string>& columns) const {
+SQLite3Connection::getNextRecord(std::vector<std::string>& columns) {
     sqlite3_stmt* current_stmt = dbparameters_->q_any_;
     const int rc = sqlite3_step(current_stmt);
 

+ 2 - 2
src/lib/datasrc/sqlite3_connection.h

@@ -88,8 +88,8 @@ public:
      *     element and the zone id in the second if it was.
      */
     virtual std::pair<bool, int> getZone(const isc::dns::Name& name) const;
-    virtual void searchForRecords(int zone_id, const std::string& name) const;
-    virtual bool getNextRecord(std::vector<std::string>& columns) const;
+    virtual void searchForRecords(int zone_id, const std::string& name);
+    virtual bool getNextRecord(std::vector<std::string>& columns);
 private:
     /// \brief Private database data
     SQLite3Parameters* dbparameters_;

+ 117 - 2
src/lib/datasrc/tests/database_unittest.cc

@@ -18,6 +18,10 @@
 #include <exceptions/exceptions.h>
 
 #include <datasrc/database.h>
+#include <datasrc/zone.h>
+#include <datasrc/data_source.h>
+
+#include <map>
 
 using namespace isc::datasrc;
 using namespace std;
@@ -32,6 +36,8 @@ namespace {
  */
 class MockConnection : public DatabaseConnection {
 public:
+    MockConnection() { fillData(); }
+
     virtual std::pair<bool, int> getZone(const Name& name) const {
         if (name == Name("example.org")) {
             return (std::pair<bool, int>(true, 42));
@@ -39,8 +45,74 @@ public:
             return (std::pair<bool, int>(false, 0));
         }
     }
-    virtual void searchForRecords(int, const std::string&) const {};
-    virtual bool getNextRecord(std::vector<std::string>&) const { return false; };
+
+    virtual void searchForRecords(int zone_id, const std::string& name) {
+        // we're not aiming for efficiency in this test, simply
+        // copy the relevant vector from records
+        cur_record = 0;
+
+        if (zone_id == 42) {
+            if (records.count(name) > 0) {
+                cur_name = records.find(name)->second;
+            } else {
+                cur_name.clear();
+            }
+        } else {
+            cur_name.clear();
+        }
+    };
+
+    virtual bool getNextRecord(std::vector<std::string>& columns) {
+        if (cur_record < cur_name.size()) {
+            columns = cur_name[cur_record++];
+            return true;
+        } else {
+            return false;
+        }
+    };
+
+private:
+    std::map<std::string, std::vector< std::vector<std::string> > > records;
+    // used as internal index for getNextRecord()
+    size_t cur_record;
+    // used as temporary storage after searchForRecord() and during
+    // getNextRecord() calls, as well as during the building of the
+    // fake data
+    std::vector< std::vector<std::string> > cur_name;
+
+    void addRecord(const std::string& name,
+                   const std::string& type,
+                   const std::string& sigtype,
+                   const std::string& rdata) {
+        std::vector<std::string> columns;
+        columns.push_back(name);
+        columns.push_back(type);
+        columns.push_back(sigtype);
+        columns.push_back(rdata);
+        cur_name.push_back(columns);
+    }
+
+    void addCurName(const std::string& name) {
+        records[name] = cur_name;
+        cur_name.clear();
+    }
+
+    void fillData() {
+        addRecord("A", "3600", "", "192.0.2.1");
+        addRecord("AAAA", "3600", "", "2001:db8::1");
+        addRecord("AAAA", "3600", "", "2001:db8::2");
+        addCurName("www.example.org.");
+        addRecord("CNAME", "3600", "", "www.example.org.");
+        addCurName("cname.example.org.");
+
+        // also add some intentionally bad data
+        cur_name.push_back(std::vector<std::string>());
+        addCurName("emptyvector.example.org.");
+        addRecord("A", "3600", "", "192.0.2.1");
+        addRecord("CNAME", "3600", "", "www.example.org.");
+        addCurName("badcname.example.org.");
+        
+    }
 };
 
 class DatabaseClientTest : public ::testing::Test {
@@ -98,4 +170,47 @@ TEST_F(DatabaseClientTest, noConnException) {
                  isc::InvalidParameter);
 }
 
+TEST_F(DatabaseClientTest, find) {
+    DataSourceClient::FindResult zone(client_->findZone(Name("example.org")));
+    ASSERT_EQ(result::SUCCESS, zone.code);
+    shared_ptr<DatabaseClient::Finder> finder(
+        dynamic_pointer_cast<DatabaseClient::Finder>(zone.zone_finder));
+    EXPECT_EQ(42, finder->zone_id());
+    isc::dns::Name name("www.example.org.");
+
+    ZoneFinder::FindResult result1 = finder->find(name, isc::dns::RRType::A(),
+                                                  NULL, ZoneFinder::FIND_DEFAULT);
+    ASSERT_EQ(ZoneFinder::SUCCESS, result1.code);
+    EXPECT_EQ(1, result1.rrset->getRdataCount());
+    EXPECT_EQ(isc::dns::RRType::A(), result1.rrset->getType());
+
+    ZoneFinder::FindResult result2 = finder->find(name, isc::dns::RRType::AAAA(),
+                                                  NULL, ZoneFinder::FIND_DEFAULT);
+    ASSERT_EQ(ZoneFinder::SUCCESS, result2.code);
+    EXPECT_EQ(2, result2.rrset->getRdataCount());
+    EXPECT_EQ(isc::dns::RRType::AAAA(), result2.rrset->getType());
+
+    ZoneFinder::FindResult result3 = finder->find(name, isc::dns::RRType::TXT(),
+                                                  NULL, ZoneFinder::FIND_DEFAULT);
+    ASSERT_EQ(ZoneFinder::NXRRSET, result3.code);
+    EXPECT_EQ(isc::dns::ConstRRsetPtr(), result3.rrset);
+
+    ZoneFinder::FindResult result4 = finder->find(isc::dns::Name("cname.example.org."),
+                                                  isc::dns::RRType::A(),
+                                                  NULL, ZoneFinder::FIND_DEFAULT);
+    ASSERT_EQ(ZoneFinder::CNAME, result4.code);
+    EXPECT_EQ(1, result4.rrset->getRdataCount());
+    EXPECT_EQ(isc::dns::RRType::CNAME(), result4.rrset->getType());
+
+    EXPECT_THROW(finder->find(isc::dns::Name("emptyvector.example.org."),
+                                              isc::dns::RRType::A(),
+                                              NULL, ZoneFinder::FIND_DEFAULT),
+                 DataSourceError);
+    EXPECT_THROW(finder->find(isc::dns::Name("badcname.example.org."),
+                                              isc::dns::RRType::A(),
+                                              NULL, ZoneFinder::FIND_DEFAULT),
+                 DataSourceError);
+
+}
+
 }