Parcourir la source

[2378] Detect class mismatch with copy mode

Michal 'vorner' Vaner il y a 12 ans
Parent
commit
c1ba429f5e
2 fichiers modifiés avec 54 ajouts et 3 suppressions
  1. 43 3
      src/lib/datasrc/tests/zone_loader_unittest.cc
  2. 11 0
      src/lib/datasrc/zone_loader.cc

+ 43 - 3
src/lib/datasrc/tests/zone_loader_unittest.cc

@@ -42,7 +42,8 @@ class MockClient : public DataSourceClient {
 public:
 public:
     MockClient() :
     MockClient() :
         commit_called_(false),
         commit_called_(false),
-        missing_zone_(false)
+        missing_zone_(false),
+        rrclass_(RRClass::IN())
     {}
     {}
     virtual FindResult findZone(const Name&) const {
     virtual FindResult findZone(const Name&) const {
         isc_throw(isc::NotImplemented, "Method not used in tests");
         isc_throw(isc::NotImplemented, "Method not used in tests");
@@ -67,6 +68,8 @@ public:
     bool commit_called_;
     bool commit_called_;
     // If set to true, getUpdater returns NULL
     // If set to true, getUpdater returns NULL
     bool missing_zone_;
     bool missing_zone_;
+    // The pretended class of the client. Usualy IN, but can be overriden.
+    RRClass rrclass_;
 };
 };
 
 
 // The updater isn't really correct according to the API. For example,
 // The updater isn't really correct according to the API. For example,
@@ -77,10 +80,11 @@ public:
 class Updater : public ZoneUpdater {
 class Updater : public ZoneUpdater {
 public:
 public:
     Updater(MockClient* client) :
     Updater(MockClient* client) :
-        client_(client)
+        client_(client),
+        finder_(client_->rrclass_)
     {}
     {}
     virtual ZoneFinder& getFinder() {
     virtual ZoneFinder& getFinder() {
-        isc_throw(isc::NotImplemented, "Method not used in tests");
+        return (finder_);
     }
     }
     virtual void addRRset(const isc::dns::AbstractRRset& rrset) {
     virtual void addRRset(const isc::dns::AbstractRRset& rrset) {
         if (client_->commit_called_) {
         if (client_->commit_called_) {
@@ -96,6 +100,34 @@ public:
     }
     }
 private:
 private:
     MockClient* client_;
     MockClient* client_;
+    class Finder : public ZoneFinder {
+    public:
+        Finder(const RRClass& rrclass) :
+            class_(rrclass)
+        {}
+        virtual RRClass getClass() const {
+            return (class_);
+        }
+        virtual Name getOrigin() const {
+            isc_throw(isc::NotImplemented, "Method not used in tests");
+        }
+        virtual shared_ptr<Context> find(const Name&, const RRType&,
+                                         const FindOptions)
+        {
+            isc_throw(isc::NotImplemented, "Method not used in tests");
+        }
+        virtual shared_ptr<Context> findAll(const Name&,
+                                            vector<ConstRRsetPtr>&,
+                                            const FindOptions)
+        {
+            isc_throw(isc::NotImplemented, "Method not used in tests");
+        }
+        virtual FindNSEC3Result findNSEC3(const Name&, bool) {
+            isc_throw(isc::NotImplemented, "Method not used in tests");
+        }
+    private:
+        const RRClass class_;
+    } finder_;
 };
 };
 
 
 ZoneUpdaterPtr
 ZoneUpdaterPtr
@@ -242,6 +274,14 @@ TEST_F(ZoneLoaderTest, copyMissingSource) {
                             source_client_), DataSourceError);
                             source_client_), DataSourceError);
 }
 }
 
 
+// The class of the source and destination are different
+TEST_F(ZoneLoaderTest, classMismatch) {
+    destination_client_.rrclass_ = RRClass::CH();
+    prepareSource(Name::ROOT_NAME(), "root.zone");
+    EXPECT_THROW(ZoneLoader(destination_client_, Name::ROOT_NAME(),
+                            source_client_), isc::InvalidParameter);
+}
+
 // Load an unsigned zone, all at once
 // Load an unsigned zone, all at once
 TEST_F(ZoneLoaderTest, loadUnsigned) {
 TEST_F(ZoneLoaderTest, loadUnsigned) {
     ZoneLoader loader(destination_client_, Name::ROOT_NAME(),
     ZoneLoader loader(destination_client_, Name::ROOT_NAME(),

+ 11 - 0
src/lib/datasrc/zone_loader.cc

@@ -43,6 +43,17 @@ ZoneLoader::ZoneLoader(DataSourceClient& destination, const Name& zone_name,
         isc_throw(DataSourceError, "Zone " << zone_name << " not found in "
         isc_throw(DataSourceError, "Zone " << zone_name << " not found in "
                   "destination data source, can't fill it with data");
                   "destination data source, can't fill it with data");
     }
     }
+    // The dereference of zone_finder is safe, if we can get iterator, we can
+    // get a finder.
+    //
+    // TODO: We probably need a getClass on the data source itself.
+    if (source.findZone(zone_name).zone_finder->getClass() !=
+        updater_->getFinder().getClass()) {
+        isc_throw(isc::InvalidParameter,
+                  "Source and destination class mismatch");
+    }
+}
+
 }
 }
 
 
 namespace {
 namespace {