Browse Source

made the auth server dual-stack (kind of quick hack)

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1129 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
95a8834341
3 changed files with 128 additions and 100 deletions
  1. 7 63
      src/bin/auth/auth_srv.cc
  2. 2 3
      src/bin/auth/auth_srv.h
  3. 119 34
      src/bin/auth/main.cc

+ 7 - 63
src/bin/auth/auth_srv.cc

@@ -53,67 +53,18 @@ using namespace isc::dns::rdata;
 using namespace isc::data;
 using namespace isc::data;
 using namespace isc::config;
 using namespace isc::config;
 
 
-namespace {
-// This is a helper class to make construction of the AuthSrv class
-// exception safe.
-class AuthSocket {
-private:
-    // prohibit copy
-    AuthSocket(const AuthSocket& source);
-    AuthSocket& operator=(const AuthSocket& source);
-public:
-    AuthSocket(int port);
-    ~AuthSocket();
-    int getFD() const { return (fd_); }
-private:
-    int fd_;
-};
-
-AuthSocket::AuthSocket(int port) :
-    fd_(-1)
-{
-    fd_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
-    if (fd_ < 0) {
-        throw FatalError("failed to open socket");
-    }
-
-    struct sockaddr_in sin;
-    sin.sin_family = AF_INET;
-    sin.sin_addr.s_addr = INADDR_ANY;
-    sin.sin_port = htons(port);
-
-    socklen_t sa_len = sizeof(sin);
-#ifdef HAVE_SIN_LEN
-    sin.sin_len = sa_len;
-#endif
-
-    if (bind(fd_, (const struct sockaddr *)&sin, sa_len) < 0) {
-        close(fd_);
-        throw FatalError("could not bind socket");
-    }
-}
-
-AuthSocket::~AuthSocket() {
-    assert(fd_ >= 0);
-    close(fd_);
-}
-}
-
 class AuthSrvImpl {
 class AuthSrvImpl {
 private:
 private:
     // prohibit copy
     // prohibit copy
     AuthSrvImpl(const AuthSrvImpl& source);
     AuthSrvImpl(const AuthSrvImpl& source);
     AuthSrvImpl& operator=(const AuthSrvImpl& source);
     AuthSrvImpl& operator=(const AuthSrvImpl& source);
 public:
 public:
-    AuthSrvImpl(int port);
-    AuthSocket sock;
+    AuthSrvImpl();
     std::string _db_file;
     std::string _db_file;
     isc::auth::MetaDataSrc data_sources;
     isc::auth::MetaDataSrc data_sources;
 };
 };
 
 
-AuthSrvImpl::AuthSrvImpl(int port) :
-    sock(port)
-{
+AuthSrvImpl::AuthSrvImpl() {
     // add static data source
     // add static data source
     data_sources.addDataSrc(ConstDataSrcPtr(new StaticDataSrc));
     data_sources.addDataSrc(ConstDataSrcPtr(new StaticDataSrc));
 
 
@@ -123,9 +74,9 @@ AuthSrvImpl::AuthSrvImpl(int port) :
     data_sources.addDataSrc(ConstDataSrcPtr(sd));
     data_sources.addDataSrc(ConstDataSrcPtr(sd));
 }
 }
 
 
-AuthSrv::AuthSrv(int port)
+AuthSrv::AuthSrv()
 {
 {
-    impl_ = new AuthSrvImpl(port);
+    impl_ = new AuthSrvImpl;
 }
 }
 
 
 AuthSrv::~AuthSrv()
 AuthSrv::~AuthSrv()
@@ -133,23 +84,16 @@ AuthSrv::~AuthSrv()
     delete impl_;
     delete impl_;
 }
 }
 
 
-int
-AuthSrv::getSocket() const
-{
-    return (impl_->sock.getFD());
-}
-
 void
 void
-AuthSrv::processMessage()
+AuthSrv::processMessage(const int fd)
 {
 {
     struct sockaddr_storage ss;
     struct sockaddr_storage ss;
     socklen_t sa_len = sizeof(ss);
     socklen_t sa_len = sizeof(ss);
     struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
     struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
-    const int s = impl_->sock.getFD();
     char recvbuf[4096];
     char recvbuf[4096];
     int cc;
     int cc;
 
 
-    if ((cc = recvfrom(s, recvbuf, sizeof(recvbuf), 0, sa, &sa_len)) > 0) {
+    if ((cc = recvfrom(fd, recvbuf, sizeof(recvbuf), 0, sa, &sa_len)) > 0) {
         Message msg(Message::PARSE);
         Message msg(Message::PARSE);
         InputBuffer buffer(recvbuf, cc);
         InputBuffer buffer(recvbuf, cc);
 
 
@@ -184,7 +128,7 @@ AuthSrv::processMessage()
         cout << "sending a response (" <<
         cout << "sending a response (" <<
             boost::lexical_cast<string>(obuffer.getLength())
             boost::lexical_cast<string>(obuffer.getLength())
                   << " bytes):\n" << msg.toText() << endl;
                   << " bytes):\n" << msg.toText() << endl;
-        sendto(s, obuffer.getData(), obuffer.getLength(), 0, sa, sa_len);
+        sendto(fd, obuffer.getData(), obuffer.getLength(), 0, sa, sa_len);
     }
     }
 }
 }
 
 

+ 2 - 3
src/bin/auth/auth_srv.h

@@ -34,11 +34,10 @@ private:
     AuthSrv(const AuthSrv& source);
     AuthSrv(const AuthSrv& source);
     AuthSrv& operator=(const AuthSrv& source);
     AuthSrv& operator=(const AuthSrv& source);
 public:
 public:
-    explicit AuthSrv(int port);
+    explicit AuthSrv();
     ~AuthSrv();
     ~AuthSrv();
     //@}
     //@}
-    int getSocket() const;
-    void processMessage();
+    void processMessage(int fd);
     void serve(std::string zone_name);
     void serve(std::string zone_name);
     void setDbFile(const std::string& db_file);
     void setDbFile(const std::string& db_file);
     isc::data::ElementPtr updateConfig(isc::data::ElementPtr config);
     isc::data::ElementPtr updateConfig(isc::data::ElementPtr config);

+ 119 - 34
src/bin/auth/main.cc

@@ -41,31 +41,38 @@
 #include <boost/foreach.hpp>
 #include <boost/foreach.hpp>
 
 
 using namespace std;
 using namespace std;
+using namespace isc::data;
+using namespace isc::cc;
+using namespace isc::config;
 
 
+namespace {
 const string PROGRAM = "Auth";
 const string PROGRAM = "Auth";
-const int DNSPORT = 5300;
+const char* DNSPORT = "5300";
+}
 
 
 /* need global var for config/command handlers.
 /* need global var for config/command handlers.
  * todo: turn this around, and put handlers in the authserver
  * todo: turn this around, and put handlers in the authserver
  * class itself? */
  * class itself? */
-AuthSrv auth(DNSPORT);
+namespace {
+AuthSrv *auth_server;
+}
 
 
 static void
 static void
 usage() {
 usage() {
-    cerr << "Usage: b10-auth [-p port]" << endl;
+    cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl;
     exit(1);
     exit(1);
 }
 }
 
 
-isc::data::ElementPtr
-my_config_handler(isc::data::ElementPtr new_config)
+ElementPtr
+my_config_handler(ElementPtr new_config)
 {
 {
-    auth.updateConfig(new_config);
-    return isc::config::createAnswer(0);
+    auth_server->updateConfig(new_config);
+    return createAnswer(0);
 }
 }
 
 
-isc::data::ElementPtr
-my_command_handler(const std::string& command, const isc::data::ElementPtr args) {
-    isc::data::ElementPtr answer = isc::config::createAnswer(0);
+ElementPtr
+my_command_handler(const string& command, const ElementPtr args) {
+    ElementPtr answer = createAnswer(0);
 
 
     cout << "[XX] Handle command: " << endl << command << endl;
     cout << "[XX] Handle command: " << endl << command << endl;
     if (command == "print_message") 
     if (command == "print_message") 
@@ -77,15 +84,54 @@ my_command_handler(const std::string& command, const isc::data::ElementPtr args)
     return answer;
     return answer;
 }
 }
 
 
+static int
+getSocket(int af, const char* port) {
+    struct addrinfo hints, *res;
+
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = af;
+    hints.ai_socktype = SOCK_DGRAM;
+    hints.ai_flags = AI_PASSIVE;
+    hints.ai_protocol = IPPROTO_UDP;
+
+    int error = getaddrinfo(NULL, port, &hints, &res);
+    if (error != 0) {
+        cerr << "getaddrinfo failed: " << gai_strerror(error);
+        return (-1);
+    }
+
+    int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+    if (s < 0) {
+        cerr << "failed to open socket" << endl;
+        return (-1);
+    }
+
+    if (bind(s, res->ai_addr, res->ai_addrlen) < 0) {
+        cerr << "binding socket failure" << endl;
+        close(s);
+        return (-1);
+    }
+
+    return (s);
+}
+
 int
 int
 main(int argc, char* argv[]) {
 main(int argc, char* argv[]) {
     int ch;
     int ch;
-    int port = DNSPORT;
+    const char* port = DNSPORT;
+    bool ipv4_only = false, ipv6_only = false;
+    int ps4 = -1, ps6 = -1;
 
 
-    while ((ch = getopt(argc, argv, "p:")) != -1) {
+    while ((ch = getopt(argc, argv, "46p:")) != -1) {
         switch (ch) {
         switch (ch) {
+        case '4':
+            ipv4_only = true;
+            break;
+        case '6':
+            ipv6_only = true;
+            break;
         case 'p':
         case 'p':
-            port = atoi(optarg);
+            port = optarg;
             break;
             break;
         case '?':
         case '?':
         default:
         default:
@@ -93,52 +139,91 @@ main(int argc, char* argv[]) {
         }
         }
     }
     }
 
 
-    if (argc - optind > 0)
+    if (argc - optind > 0) {
         usage();
         usage();
+    }
+
+    if (ipv4_only && ipv6_only) {
+        cerr << "-4 and -6 can't coexist" << endl;
+        usage();
+    }
+    if (!ipv4_only) {
+        ps4 = getSocket(AF_INET, port);
+        if (ps4 < 0) {
+            exit(1);
+        }
+    }
+    if (!ipv6_only) {
+        ps6 = getSocket(AF_INET6, port);
+        if (ps6 < 0) {
+            if (ps4 < 0) {
+                close(ps4);
+            }
+            exit(1);
+        }
+    }
+
+    auth_server = new AuthSrv;
 
 
     // initialize command channel
     // initialize command channel
+    int ret = 0;
     try {
     try {
-        std::string specfile;
+        string specfile;
         if (getenv("B10_FROM_SOURCE")) {
         if (getenv("B10_FROM_SOURCE")) {
-            specfile = std::string(getenv("B10_FROM_SOURCE")) + "/src/bin/auth/auth.spec";
+            specfile = string(getenv("B10_FROM_SOURCE")) +
+                "/src/bin/auth/auth.spec";
         } else {
         } else {
-            specfile = std::string(AUTH_SPECFILE_LOCATION);
+            specfile = string(AUTH_SPECFILE_LOCATION);
         }
         }
-        isc::config::ModuleCCSession cs = isc::config::ModuleCCSession(specfile,
-                                                                       my_config_handler,
-                                                                       my_command_handler);
+        ModuleCCSession cs = ModuleCCSession(specfile, my_config_handler,
+                                             my_command_handler);
 
 
         // main server loop
         // main server loop
         fd_set fds;
         fd_set fds;
-        int ps = auth.getSocket();
         int ss = cs.getSocket();
         int ss = cs.getSocket();
-        int nfds = max(ps, ss) + 1;
+        int nfds = max(max(ps4, ps6), ss) + 1;
         int counter = 0;
         int counter = 0;
-    
+
         cout << "Server started." << endl;
         cout << "Server started." << endl;
         while (true) {
         while (true) {
             FD_ZERO(&fds);
             FD_ZERO(&fds);
-            FD_SET(ps, &fds);
+            if (ps4 >= 0) {
+                FD_SET(ps4, &fds);
+            }
+            if (ps6 >= 0) {
+                FD_SET(ps6, &fds);
+            }
             FD_SET(ss, &fds);
             FD_SET(ss, &fds);
-    
+
             int n = select(nfds, &fds, NULL, NULL, NULL);
             int n = select(nfds, &fds, NULL, NULL, NULL);
-            if (n < 0)
+            if (n < 0) {
                 throw FatalError("select error");
                 throw FatalError("select error");
-    
-            if (FD_ISSET(ps, &fds)) {
+            }
+
+            if (FD_ISSET(ps4, &fds)) {
+                ++counter;
+                auth_server->processMessage(ps4);
+            }
+            if (FD_ISSET(ps6, &fds)) {
                 ++counter;
                 ++counter;
-                auth.processMessage();
+                auth_server->processMessage(ps6);
             }
             }
     
     
-            /* isset not really necessary, but keep it for now */
             if (FD_ISSET(ss, &fds)) {
             if (FD_ISSET(ss, &fds)) {
                 cs.check_command();
                 cs.check_command();
             }
             }
         }
         }
-    } catch (isc::cc::SessionError se) {
+    } catch (SessionError se) {
         cout << se.what() << endl;
         cout << se.what() << endl;
-        exit(1);
+        ret = 1;
     }
     }
-    
-    return (0);
+
+    if (ps4 >= 0) {
+        close(ps4);
+    }
+    if (ps6 >= 0) {
+        close(ps6);
+    }
+    delete auth_server;
+    return (ret);
 }
 }