Browse Source

incorporated the jelte-tcp branch, and made it possible to selectively build
the ASIO version or the builtin socket API version depending on the
availability of the boost::system library.


git-svn-id: svn://bind10.isc.org/svn/bind10/branches/jinmei-asio@1226 e5f2f494-b856-4b98-b285-d166d9295462

JINMEI Tatuya 15 years ago
parent
commit
91cfdfed1b
7 changed files with 334 additions and 79 deletions
  1. 6 2
      configure.ac
  2. 4 1
      src/bin/auth/Makefile.am
  3. 0 23
      src/bin/auth/common.cc
  4. 5 6
      src/bin/auth/common.h
  5. 304 47
      src/bin/auth/main.cc
  6. 10 0
      src/lib/cc/session.cc
  7. 5 0
      src/lib/config/ccsession.cc

+ 6 - 2
configure.ac

@@ -113,7 +113,7 @@ for BOOST_TRY_LIB in boost_system boost_system-mt; do
 		[ boost::system::error_code error_code;
 		  std::string message(error_code.message());
 		  return 0; ],
-	[ AC_MSG_RESULT(ok)
+	[ AC_MSG_RESULT(yes)
 	  BOOST_SYSTEM_LIB="-l${BOOST_TRY_LIB}"
 	  ],[])
 	if test "X${BOOST_SYSTEM_LIB}" != X; then
@@ -123,11 +123,15 @@ done
 
 if test "X${BOOST_SYSTEM_LIB}" = X; then
 	AC_MSG_RESULT(not found)
-	AC_MSG_ERROR(Unable to link with the boost::system library)
+else
+	AC_DEFINE(HAVE_BOOSTLIB, 1, Define to 1 if boost libraries are available)
 fi
 
+AM_CONDITIONAL(HAVE_BOOSTLIB, test "X${BOOST_SYSTEM_LIB}" != X)
+LDFLAGS="$LDFLAGS_SAVED"
 CPPFLAGS="$CPPFLAGS_SAVED"
 LIBS="$LIBS_SAVED"
+AC_SUBST(BOOST_LDFLAGS)
 AC_SUBST(BOOST_SYSTEM_LIB)
 
 #

+ 4 - 1
src/bin/auth/Makefile.am

@@ -6,7 +6,7 @@ CLEANFILES = *.gcno *.gcda
 
 pkglibexec_PROGRAMS = b10-auth
 b10_auth_SOURCES = auth_srv.cc auth_srv.h
-b10_auth_SOURCES += common.cc common.h
+b10_auth_SOURCES += common.h
 b10_auth_SOURCES += main.cc
 b10_auth_LDADD =  $(top_builddir)/src/lib/auth/.libs/libauth.a
 b10_auth_LDADD +=  $(top_builddir)/src/lib/dns/.libs/libdns.a
@@ -14,7 +14,10 @@ b10_auth_LDADD += $(top_builddir)/src/lib/config/libcfgclient.a
 b10_auth_LDADD += $(top_builddir)/src/lib/cc/libcc.a
 b10_auth_LDADD += $(top_builddir)/src/lib/exceptions/.libs/libexceptions.a
 b10_auth_LDADD += $(SQLITE_LIBS)
+if HAVE_BOOSTLIB
+b10_auth_LDFLAGS = $(AM_LDFLAGS) $(BOOST_LDFLAGS)
 b10_auth_LDADD += $(BOOST_SYSTEM_LIB)
+endif
 
 # TODO: config.h.in is wrong because doesn't honor pkgdatadir
 # and can't use @datadir@ because doesn't expand default ${prefix}

+ 0 - 23
src/bin/auth/common.cc

@@ -1,23 +0,0 @@
-// Copyright (C) 2009  Internet Systems Consortium, Inc. ("ISC")
-//
-// Permission to use, copy, modify, and/or distribute this software for any
-// purpose with or without fee is hereby granted, provided that the above
-// copyright notice and this permission notice appear in all copies.
-//
-// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
-// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
-// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
-// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
-// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
-// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
-// PERFORMANCE OF THIS SOFTWARE.
-
-// $Id$
-#include "common.h"
-#include <iostream>
-
-FatalError::FatalError(std::string m) {
-    msg = m;
-    std::cerr << msg << std::endl;
-    exit(1);
-}

+ 5 - 6
src/bin/auth/common.h

@@ -20,13 +20,12 @@
 #include <stdlib.h>
 #include <string>
 
-class FatalError : public std::exception {
+#include <exceptions/exceptions.h>
+
+class FatalError : public isc::Exception {
 public:
-    FatalError(std::string m = "fatal error");
-    ~FatalError() throw() {}
-    const char* what() const throw() { return msg.c_str(); }
-private:
-    std::string msg;
+    FatalError(const char* file, size_t line, const char* what) :
+        isc::Exception(file, line, what) {}
 };
 
 #endif // __COMMON_H

+ 304 - 47
src/bin/auth/main.cc

@@ -14,18 +14,25 @@
 
 // $Id$
 
+#include "../../../config.h"
+
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/select.h>
 #include <netdb.h>
 #include <stdlib.h>
+#include <errno.h>
 
 #include <set>
 #include <iostream>
 
 #include <boost/foreach.hpp>
+#ifdef HAVE_BOOSTLIB
 #include <boost/bind.hpp>
 #include <boost/asio.hpp>
+#endif
+
+#include <exceptions/exceptions.h>
 
 #include <dns/buffer.h>
 #include <dns/name.h>
@@ -38,17 +45,19 @@
 #include <cc/data.h>
 #include <config/ccsession.h>
 
-#include "common.h"
 #include "config.h"
+#include "common.h"
 #include "auth_srv.h"
 
 #include <boost/foreach.hpp>
 
 using namespace std;
 
+#ifdef HAVE_BOOSTLIB
 using namespace boost::asio;
 using ip::udp;
 using ip::tcp;
+#endif
 
 using namespace isc::data;
 using namespace isc::cc;
@@ -57,7 +66,7 @@ using namespace isc::dns;
 
 namespace {
 const string PROGRAM = "Auth";
-const short DNSPORT = 5300;
+const char* DNSPORT = "5300";
 }
 
 /* need global var for config/command handlers.
@@ -67,6 +76,26 @@ namespace {
 AuthSrv *auth_server;
 }
 
+static ElementPtr
+my_config_handler(ElementPtr new_config)
+{
+    return auth_server->updateConfig(new_config);
+}
+
+static ElementPtr
+my_command_handler(const string& command, const ElementPtr args) {
+    ElementPtr answer = createAnswer(0);
+
+    if (command == "print_message") 
+    {
+        cout << args << endl;
+        /* let's add that message to our answer as well */
+        answer->get("result")->add(args);
+    }
+    return answer;
+}
+
+#ifdef HAVE_BOOSTLIB
 //
 // Helper classes for asynchronous I/O using boost::asio
 //
@@ -272,43 +301,294 @@ private:
     enum { MAX_LENGTH = 4096 };
     char data_[MAX_LENGTH];
 };
+
+struct ServerSet {
+    ServerSet() : udp4_server(NULL), udp6_server(NULL),
+                  tcp4_server(NULL), tcp6_server(NULL)
+    {}
+    ~ServerSet()
+    {
+        delete udp4_server;
+        delete udp6_server;
+        delete tcp4_server;
+        delete tcp6_server;
+    }
+    UDPServer* udp4_server;
+    UDPServer* udp6_server;
+    TCPServer* tcp4_server;
+    TCPServer* tcp6_server;
+};
+
+static void
+run_server(const char* port, const bool use_ipv4, const bool use_ipv6,
+           const string& specfile)
+{
+    ServerSet servers;
+    boost::asio::io_service io_service;
+    short portnum = atoi(port);
+
+    ModuleCCSession cs(specfile, io_service, my_config_handler,
+                       my_command_handler);
+
+    if (use_ipv4) {
+        servers.udp4_server = new UDPServer(io_service, AF_INET, portnum);
+        servers.tcp4_server = new TCPServer(io_service, AF_INET, portnum);
+    }
+    if (use_ipv6) {
+        servers.udp6_server = new UDPServer(io_service, AF_INET6, portnum);
+        servers.tcp6_server = new TCPServer(io_service, AF_INET6, portnum);
+    }
+
+    cout << "Server started." << endl;
+    io_service.run();
+}
+}
+#else  // !HAVE_BOOSTLIB
+struct SocketSet {
+    SocketSet() : ups4(-1), tps4(-1), ups6(-1), tps6(-1) {}
+    ~SocketSet()
+    {
+        if (ups4 >= 0) {
+            close(ups4);
+        }
+        if (tps4 >= 0) {
+            close(tps4);
+        }
+        if (ups6 >= 0) {
+            close(ups6);
+        }
+        if (tps4 >= 0) {
+            close(tps6);
+        }
+    }
+    int ups4, tps4, ups6, tps6;
+};
+
+static int
+getUDPSocket(int af, const char* port) {
+    struct addrinfo hints, *res;
+
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = af;
+    hints.ai_socktype = SOCK_DGRAM;
+    hints.ai_flags = AI_PASSIVE;
+    hints.ai_protocol = IPPROTO_UDP;
+
+    int error = getaddrinfo(NULL, port, &hints, &res);
+    if (error != 0) {
+        isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error));
+    }
+
+    int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+    if (s < 0) {
+        isc_throw(FatalError, "failed to open socket");
+    }
+
+    if (af == AF_INET6) {
+        int on = 1;
+        if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
+            cerr << "couldn't set IPV6_V6ONLY socket option" << endl;
+            // proceed anyway
+        }
+    }
+
+    if (bind(s, res->ai_addr, res->ai_addrlen) < 0) {
+        isc_throw(FatalError, "binding socket failure");
+    }
+
+    return (s);
+}
+
+static int
+getTCPSocket(int af, const char* port) {
+    struct addrinfo hints, *res;
+
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = af;
+    hints.ai_socktype = SOCK_STREAM;
+    hints.ai_flags = AI_PASSIVE;
+    hints.ai_protocol = IPPROTO_TCP;
+
+    int error = getaddrinfo(NULL, port, &hints, &res);
+    if (error != 0) {
+        isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error));
+    }
+
+    int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+    if (s < 0) {
+        isc_throw(FatalError, "failed to open socket");
+    }
+
+    int on = 1;
+    if (af == AF_INET6) {
+        if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
+            cerr << "couldn't set IPV6_V6ONLY socket option" << endl;
+        }
+    }
+
+    if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
+        cerr << "couldn't set SO_REUSEADDR socket option" << endl;
+    }
+
+    if (bind(s, res->ai_addr, res->ai_addrlen) < 0) {
+        isc_throw(FatalError, "binding socket failure");
+    }
+
+    listen(s, 100);
+    return (s);
 }
 
 static void
-usage() {
-    cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl;
-    exit(1);
+processMessageUDP(const int fd, Message& dns_message,
+                  MessageRenderer& response_renderer)
+{
+    struct sockaddr_storage ss;
+    socklen_t sa_len = sizeof(ss);
+    struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
+    char recvbuf[4096];
+    int cc;
+
+    dns_message.clear(Message::PARSE);
+    response_renderer.clear();
+    if ((cc = recvfrom(fd, recvbuf, sizeof(recvbuf), 0, sa, &sa_len)) > 0) {
+        InputBuffer buffer(recvbuf, cc);
+        if (auth_server->processMessage(buffer, dns_message, response_renderer,
+                                        true) == 0) {
+            sendto(fd, response_renderer.getData(),
+                   response_renderer.getLength(), 0, sa, sa_len);
+        }
+    }
 }
 
-ElementPtr
-my_config_handler(ElementPtr new_config)
+static void
+processMessageTCP(const int fd, Message& dns_message,
+                  MessageRenderer& response_renderer)
 {
-    return auth_server->updateConfig(new_config);
+    struct sockaddr_storage ss;
+    socklen_t sa_len = sizeof(ss);
+    struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
+    char sizebuf[2];
+    int cc;
+    int ts = accept(fd, sa, &sa_len);
+
+    cout << "[XX] process TCP" << endl;
+    cc = recv(ts, sizebuf, 2, 0);
+    cout << "[XX] got: " << cc << endl;
+    uint16_t size, size_n;
+    memcpy(&size_n, sizebuf, 2);
+    size = ntohs(size_n);
+    cout << "[XX] got: " << size << endl;
+
+    vector<char> message_buffer;
+    message_buffer.reserve(size);
+    cc = 0;
+    while (cc < size) {
+        cout << "[XX] cc now: " << cc << " of " << size << endl;
+        cc += recv(ts, &message_buffer[0] + cc, size - cc, 0);
+    }
+
+    InputBuffer buffer(&message_buffer[0], size);
+    dns_message.clear(Message::PARSE);
+    response_renderer.clear();
+    if (auth_server->processMessage(buffer, dns_message, response_renderer,
+                                    false) == 0) {
+        size = response_renderer.getLength();
+        size_n = htons(size);
+        if (send(ts, &size_n, 2, 0) == 2) {
+            cc = send(ts, response_renderer.getData(),
+                      response_renderer.getLength(), 0);
+            if (cc == -1) {
+                cerr << "[AuthSrv] error in sending TCP response message" <<
+                    endl;
+            } else {
+                cout << "[XX] sent TCP response: " << cc << " bytes" << endl;
+            }
+        }
+    }
+ 
+   // TODO: we don't check for more queries on the stream atm
+    close(ts);
 }
 
-ElementPtr
-my_command_handler(const string& command, const ElementPtr args) {
-    ElementPtr answer = createAnswer(0);
+static void
+run_server(const char* port, const bool use_ipv4, const bool use_ipv6,
+           const string& specfile)
+{
+    SocketSet socket_set;
+    fd_set fds_base;
+    int nfds = -1;
+
+    FD_ZERO(&fds_base);
+    if (use_ipv4) {
+        socket_set.ups4 = getUDPSocket(AF_INET, port);
+        FD_SET(socket_set.ups4, &fds_base);
+        nfds = max(nfds, socket_set.ups4);
+        socket_set.tps4 = getTCPSocket(AF_INET, port);
+        FD_SET(socket_set.tps4, &fds_base);
+        nfds = max(nfds, socket_set.tps4);
+    }
+    if (use_ipv6) {
+        socket_set.ups6 = getUDPSocket(AF_INET6, port);
+        FD_SET(socket_set.ups6, &fds_base);
+        nfds = max(nfds, socket_set.ups6);
+        socket_set.tps6 = getTCPSocket(AF_INET6, port);
+        FD_SET(socket_set.tps6, &fds_base);
+        nfds = max(nfds, socket_set.tps6);
+    }
+    ++nfds;
 
-    if (command == "print_message") 
-    {
-        cout << args << endl;
-        /* let's add that message to our answer as well */
-        answer->get("result")->add(args);
+    ModuleCCSession cs(specfile, my_config_handler, my_command_handler);
+
+    cout << "Server started." << endl;
+    
+    int ss = cs.getSocket();
+    Message dns_message(Message::PARSE);
+    OutputBuffer resonse_buffer(0);
+    MessageRenderer response_renderer(resonse_buffer);
+
+    while (true) {
+        fd_set fds = fds_base;
+        FD_SET(ss, &fds);
+
+        int n = select(nfds, &fds, NULL, NULL, NULL);
+        if (n < 0) {
+            if (errno != EINTR) {
+                isc_throw(FatalError, "select error");
+            }
+            continue;
+        }
+
+        if (socket_set.ups4 >= 0 && FD_ISSET(socket_set.ups4, &fds)) {
+            processMessageUDP(socket_set.ups4, dns_message, response_renderer);
+        }
+        if (socket_set.ups6 >= 0 && FD_ISSET(socket_set.ups6, &fds)) {
+            processMessageUDP(socket_set.ups6, dns_message, response_renderer);
+        }
+        if (socket_set.tps4 >= 0 && FD_ISSET(socket_set.tps4, &fds)) {
+            processMessageTCP(socket_set.tps4, dns_message, response_renderer);
+        }
+        if (socket_set.tps6 >= 0 && FD_ISSET(socket_set.tps6, &fds)) {
+            processMessageTCP(socket_set.tps6, dns_message, response_renderer);
+        }
+        if (FD_ISSET(ss, &fds)) {
+            cs.check_command();
+        }
     }
-    return answer;
+}
+#endif // HAVE_BOOSTLIB
+
+static void
+usage() {
+    cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl;
+    exit(1);
 }
 
 int
 main(int argc, char* argv[]) {
     int ch;
-    short port = DNSPORT;
+    const char* port = DNSPORT;
     bool ipv4_only = false, ipv6_only = false;
     bool use_ipv4 = false, use_ipv6 = false;
-    UDPServer* udp4_server = NULL;
-    UDPServer* udp6_server = NULL;
-    TCPServer* tcp4_server = NULL;
-    TCPServer* tcp6_server = NULL;
 
     while ((ch = getopt(argc, argv, "46p:")) != -1) {
         switch (ch) {
@@ -319,7 +599,7 @@ main(int argc, char* argv[]) {
             ipv6_only = true;
             break;
         case 'p':
-            port = atoi(optarg);
+            port = optarg;
             break;
         case '?':
         default:
@@ -355,35 +635,12 @@ main(int argc, char* argv[]) {
             specfile = string(AUTH_SPECFILE_LOCATION);
         }
 
-        // XXX: in this prototype code we'll ignore any message on the command
-        // channel.
-
-        boost::asio::io_service io_service;
-
-        ModuleCCSession cs(specfile, io_service, my_config_handler,
-                           my_command_handler);
-
-        if (use_ipv4) {
-            udp4_server = new UDPServer(io_service, AF_INET, port);
-            tcp4_server = new TCPServer(io_service, AF_INET, port);
-        }
-        if (use_ipv6) {
-            udp6_server = new UDPServer(io_service, AF_INET6, port);
-            tcp6_server = new TCPServer(io_service, AF_INET6, port);
-        }
-
-        cout << "Server started." << endl;
-        io_service.run();
+        run_server(port, use_ipv4, use_ipv6, specfile);
     } catch (const std::exception& ex) {
         cerr << ex.what() << endl;
         ret = 1;
     }
 
-    delete udp4_server;
-    delete tcp4_server;
-    delete udp6_server;
-    delete tcp6_server;
-
     delete auth_server;
     return (ret);
 }

+ 10 - 0
src/lib/cc/session.cc

@@ -14,6 +14,8 @@
 
 // $Id$
 
+#include "config.h"
+
 #include <stdint.h>
 
 #include <cstdio>
@@ -21,9 +23,11 @@
 #include <iostream>
 #include <sstream>
 
+#ifdef HAVE_BOOSTLIB
 #include <boost/bind.hpp>
 #include <boost/function.hpp>
 #include <boost/asio.hpp>
+#endif
 
 #include <exceptions/exceptions.h>
 
@@ -34,10 +38,12 @@ using namespace std;
 using namespace isc::cc;
 using namespace isc::data;
 
+#ifdef HAVE_BOOSTLIB
 // some of the boost::asio names conflict with socket API system calls
 // (e.g. write(2)) so we don't import the entire boost::asio namespace.
 using boost::asio::io_service;
 using boost::asio::ip::tcp;
+#endif
 
 #include <sys/types.h>
 #include <sys/socket.h>
@@ -62,6 +68,7 @@ public:
     std::string lname_;
 };
 
+#ifdef HAVE_BOOSTLIB
 class ASIOSession : public SessionImpl {
 public:
     ASIOSession(io_service& io_service) :
@@ -163,6 +170,7 @@ ASIOSession::internalRead(const boost::system::error_code& error,
         isc_throw(SessionError, "asynchronous read failed");
     }
 }
+#endif
 
 class SocketSession : public SessionImpl {
 public:
@@ -257,8 +265,10 @@ SocketSession::readData(void* data, const size_t datalen) {
 Session::Session() : impl_(new SocketSession)
 {}
 
+#ifdef HAVE_BOOSTLIB
 Session::Session(io_service& io_service) : impl_(new ASIOSession(io_service))
 {}
+#endif
 
 Session::~Session() {
     delete impl_;

+ 5 - 0
src/lib/config/ccsession.cc

@@ -20,6 +20,7 @@
 //               react on config change announcements)
 //
 
+#include "config.h"
 
 #include <stdexcept>
 #include <stdlib.h>
@@ -31,7 +32,9 @@
 #include <sstream>
 #include <cerrno>
 
+#ifdef HAVE_BOOSTLIB
 #include <boost/bind.hpp>
+#endif
 #include <boost/foreach.hpp>
 
 #include <cc/data.h>
@@ -163,6 +166,7 @@ ModuleCCSession::read_module_specification(const std::string& filename) {
     file.close();
 }
 
+#ifdef HAVE_BOOSTLIB
 void
 ModuleCCSession::startCheck() {
     // data available on the command channel.  process it in the synchronous
@@ -187,6 +191,7 @@ ModuleCCSession::ModuleCCSession(
     // register callback for asynchronous read
     session_.startRead(boost::bind(&ModuleCCSession::startCheck, this));
 }
+#endif
 
 ModuleCCSession::ModuleCCSession(
     std::string spec_file_name,