Parcourir la source

- made AuthSrv construction exception-safe
- fixed memory leak for Datasrc* stored in the MetaDataSrc vector.
there are several possible ways to do this, but I chose to using
boost::shared_ptr. expect for portability issues this seems to be the
cleanest solution, and, regarding portability, we already heavily rely on
boost anyway, so we should revisit the whole design if/when we seriously
consider binary portability.


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1118 e5f2f494-b856-4b98-b285-d166d9295462

JINMEI Tatuya il y a 15 ans
Parent
commit
3e248e341b

+ 63 - 21
src/bin/auth/auth_srv.cc

@@ -21,6 +21,7 @@
 #include <netdb.h>
 #include <stdlib.h>
 
+#include <cassert>
 #include <iostream>
 
 #include <dns/buffer.h>
@@ -33,6 +34,9 @@
 #include <config/ccsession.h>
 
 #include <auth/query.h>
+#include <auth/data_source.h>
+#include <auth/data_source_static.h>
+#include <auth/data_source_sqlite3.h>
 
 #include <cc/data.h>
 
@@ -49,11 +53,27 @@ using namespace isc::dns::rdata;
 using namespace isc::data;
 using namespace isc::config;
 
-AuthSrv::AuthSrv(int port) :
-    data_src(NULL), sock(-1)
+namespace {
+// This is a helper class to make construction of the AuthSrv class
+// exception safe.
+class AuthSocket {
+private:
+    // prohibit copy
+    AuthSocket(const AuthSocket& source);
+    AuthSocket& operator=(const AuthSocket& source);
+public:
+    AuthSocket(int port);
+    ~AuthSocket();
+    int getFD() const { return (fd_); }
+private:
+    int fd_;
+};
+
+AuthSocket::AuthSocket(int port) :
+    fd_(-1)
 {
-    int s = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
-    if (s < 0) {
+    fd_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+    if (fd_ < 0) {
         throw FatalError("failed to open socket");
     }
 
@@ -67,34 +87,56 @@ AuthSrv::AuthSrv(int port) :
     sin.sin_len = sa_len;
 #endif
 
-    if (bind(s, (struct sockaddr *)&sin, sa_len) < 0) {
-        close(s);
+    if (bind(fd_, (const struct sockaddr *)&sin, sa_len) < 0) {
+        close(fd_);
         throw FatalError("could not bind socket");
     }
+}
 
-    sock = s;
-
-    // XXX: the following code is not exception-safe.  Will address in the
-    // next phase.
-
-    data_src = new(MetaDataSrc);
+AuthSocket::~AuthSocket() {
+    assert(fd_ >= 0);
+    close(fd_);
+}
+}
 
+class AuthSrvImpl {
+private:
+    // prohibit copy
+    AuthSrvImpl(const AuthSrvImpl& source);
+    AuthSrvImpl& operator=(const AuthSrvImpl& source);
+public:
+    AuthSrvImpl(int port);
+    AuthSocket sock;
+    std::string _db_file;
+    isc::auth::MetaDataSrc data_sources;
+};
+
+AuthSrvImpl::AuthSrvImpl(int port) :
+    sock(port)
+{
     // add static data source
-    data_src->addDataSrc(new StaticDataSrc);
+    data_sources.addDataSrc(ConstDataSrcPtr(new StaticDataSrc));
 
     // add SQL data source
     Sqlite3DataSrc* sd = new Sqlite3DataSrc;
     sd->init();
-    data_src->addDataSrc(sd);
+    data_sources.addDataSrc(ConstDataSrcPtr(sd));
+}
+
+AuthSrv::AuthSrv(int port)
+{
+    impl_ = new AuthSrvImpl(port);
 }
 
 AuthSrv::~AuthSrv()
 {
-    if (sock >= 0) {
-        close(sock);
-    }
+    delete impl_;
+}
 
-    delete data_src;
+int
+AuthSrv::getSocket() const
+{
+    return (impl_->sock.getFD());
 }
 
 void
@@ -103,7 +145,7 @@ AuthSrv::processMessage()
     struct sockaddr_storage ss;
     socklen_t sa_len = sizeof(ss);
     struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
-    int s = sock;
+    const int s = impl_->sock.getFD();
     char recvbuf[4096];
     int cc;
 
@@ -134,7 +176,7 @@ AuthSrv::processMessage()
         msg.setUDPSize(sizeof(recvbuf));
 
         Query query(msg, dnssec_ok);
-        data_src->doQuery(query);
+        impl_->data_sources.doQuery(query);
 
         OutputBuffer obuffer(remote_bufsize);
         MessageRenderer renderer(obuffer);
@@ -150,7 +192,7 @@ void
 AuthSrv::setDbFile(const std::string& db_file)
 {
     cout << "Change data source file, call our data source's function to now read " << db_file << endl;
-    _db_file = db_file;
+    impl_->_db_file = db_file;
 }
 
 ElementPtr

+ 14 - 7
src/bin/auth/auth_srv.h

@@ -20,23 +20,30 @@
 #include <string>
 
 #include <cc/data.h>
-#include <auth/data_source_static.h>
-#include <auth/data_source_sqlite3.h>
+
+class AuthSrvImpl;
 
 class AuthSrv {
+    ///
+    /// \name Constructors, Assignment Operator and Destructor.
+    ///
+    /// Note: The copy constructor and the assignment operator are intentionally
+    /// defined as private.
+    //@{
+private:
+    AuthSrv(const AuthSrv& source);
+    AuthSrv& operator=(const AuthSrv& source);
 public:
     explicit AuthSrv(int port);
     ~AuthSrv();
-    int getSocket() { return (sock); }
+    //@}
+    int getSocket() const;
     void processMessage();
     void serve(std::string zone_name);
     void setDbFile(const std::string& db_file);
     isc::data::ElementPtr updateConfig(isc::data::ElementPtr config);
 private:
-    std::string _db_file;
-
-    isc::auth::MetaDataSrc* data_src;
-    int sock;
+    AuthSrvImpl* impl_;
 };
 
 #endif // __AUTH_SRV_H

+ 1 - 1
src/bin/auth/main.cc

@@ -48,7 +48,7 @@ const int DNSPORT = 5300;
 /* need global var for config/command handlers.
  * todo: turn this around, and put handlers in the authserver
  * class itself? */
-AuthSrv auth = AuthSrv(DNSPORT);
+AuthSrv auth(DNSPORT);
 
 static void
 usage() {

+ 7 - 6
src/lib/auth/cpp/data_source.cc

@@ -645,24 +645,25 @@ DataSrc::findReferral(const Query& q, const Name& qname, const RRClass& qclass,
 }
 
 void
-MetaDataSrc::addDataSrc(DataSrc* ds)
+MetaDataSrc::addDataSrc(ConstDataSrcPtr data_src)
 {
-    if (getClass() != RRClass::ANY() && ds->getClass() != getClass()) {
+    if (getClass() != RRClass::ANY() && data_src->getClass() != getClass()) {
         dns_throw(Unexpected, "class mismatch");
     }
 
-    data_sources.push_back(ds);
+    data_sources.push_back(data_src);
 }
 
 void
 MetaDataSrc::findClosestEnclosure(NameMatch& match) const
 {
-    BOOST_FOREACH (DataSrc* ds, data_sources) {
-        if (getClass() != RRClass::ANY() && ds->getClass() != getClass()) {
+    BOOST_FOREACH (ConstDataSrcPtr data_src, data_sources) {
+        if (getClass() != RRClass::ANY() &&
+            data_src->getClass() != getClass()) {
             continue;
         }
 
-        ds->findClosestEnclosure(match);
+        data_src->findClosestEnclosure(match);
     }
 }
 

+ 8 - 2
src/lib/auth/cpp/data_source.h

@@ -19,6 +19,8 @@
 
 #include <vector>
 
+#include <boost/shared_ptr.hpp>
+
 #include <dns/name.h>
 #include <dns/rrclass.h>
 
@@ -36,6 +38,10 @@ namespace auth {
 class NameMatch;
 class Query;
 
+class DataSrc;
+typedef boost::shared_ptr<DataSrc> DataSrcPtr;
+typedef boost::shared_ptr<const DataSrc> ConstDataSrcPtr;
+
 class AbstractDataSrc {
     ///
     /// \name Constructors, Assignment Operator and Destructor.
@@ -220,7 +226,7 @@ public:
     virtual ~MetaDataSrc() {}
     //@}
 
-    void addDataSrc(DataSrc* ds);
+    void addDataSrc(ConstDataSrcPtr data_src);
     void findClosestEnclosure(NameMatch& match) const;
 
     // Actual queries for data should not be sent to a MetaDataSrc object,
@@ -273,7 +279,7 @@ public:
     }
 
 private:
-    std::vector<DataSrc*> data_sources;
+    std::vector<ConstDataSrcPtr> data_sources;
 };
 
 class NameMatch {