Parcourir la source

71. [func] each
Add "-a" (address) option to bind10 to specify an address for
the auth server to listen on.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@2396 e5f2f494-b856-4b98-b285-d166d9295462

Evan Hunt il y a 15 ans
Parent
commit
52ac5cf227

+ 4 - 0
ChangeLog

@@ -1,3 +1,7 @@
+  71.  [func]		each
+  	Add "-a" (address) option to bind10 to specify an address for
+	the auth server to listen on.
+
   70.  [func]		each
   70.  [func]		each
   	Added a hot-spot cache to libdatasrc to speed up access to
   	Added a hot-spot cache to libdatasrc to speed up access to
 	repeatedly-queried data and reduce the number of queries to
 	repeatedly-queried data and reduce the number of queries to

+ 85 - 19
src/bin/auth/asio_link.cc

@@ -18,6 +18,7 @@
 
 
 #include <unistd.h>             // for some IPC/network system calls
 #include <unistd.h>             // for some IPC/network system calls
 #include <asio.hpp>
 #include <asio.hpp>
+#include <boost/lexical_cast.hpp>
 #include <boost/bind.hpp>
 #include <boost/bind.hpp>
 
 
 #include <dns/buffer.h>
 #include <dns/buffer.h>
@@ -30,6 +31,7 @@
 
 
 #include "spec_config.h"        // for XFROUT.  should not be here.
 #include "spec_config.h"        // for XFROUT.  should not be here.
 #include "auth_srv.h"
 #include "auth_srv.h"
+#include "common.h"
 
 
 using namespace asio;
 using namespace asio;
 using ip::udp;
 using ip::udp;
@@ -200,7 +202,7 @@ private:
 class TCPServer {
 class TCPServer {
 public:
 public:
     TCPServer(AuthSrv* auth_server, io_service& io_service,
     TCPServer(AuthSrv* auth_server, io_service& io_service,
-              int af, short port) :
+              int af, uint16_t port) :
         auth_server_(auth_server), io_service_(io_service),
         auth_server_(auth_server), io_service_(io_service),
         acceptor_(io_service_), listening_(new TCPClient(auth_server_,
         acceptor_(io_service_), listening_(new TCPClient(auth_server_,
                                                          io_service_))
                                                          io_service_))
@@ -220,6 +222,23 @@ public:
                                            listening_, placeholders::error));
                                            listening_, placeholders::error));
     }
     }
 
 
+    TCPServer(AuthSrv* auth_server, io_service& io_service,
+              asio::ip::address addr, uint16_t port) :
+        auth_server_(auth_server),
+        io_service_(io_service), acceptor_(io_service_),
+        listening_(new TCPClient(auth_server, io_service_))
+    {
+        tcp::endpoint endpoint(addr, port);
+        acceptor_.open(endpoint.protocol());
+
+        acceptor_.set_option(tcp::acceptor::reuse_address(true));
+        acceptor_.bind(endpoint);
+        acceptor_.listen();
+        acceptor_.async_accept(listening_->getSocket(),
+                               boost::bind(&TCPServer::handleAccept, this,
+                                           listening_, placeholders::error));
+    }
+
     ~TCPServer() { delete listening_; }
     ~TCPServer() { delete listening_; }
 
 
     void handleAccept(TCPClient* new_client,
     void handleAccept(TCPClient* new_client,
@@ -248,7 +267,7 @@ private:
 class UDPServer {
 class UDPServer {
 public:
 public:
     UDPServer(AuthSrv* auth_server, io_service& io_service,
     UDPServer(AuthSrv* auth_server, io_service& io_service,
-              int af, short port) :
+              int af, uint16_t port) :
         auth_server_(auth_server),
         auth_server_(auth_server),
         io_service_(io_service),
         io_service_(io_service),
         socket_(io_service, af == AF_INET6 ? udp::v6() : udp::v4()),
         socket_(io_service, af == AF_INET6 ? udp::v6() : udp::v4()),
@@ -267,6 +286,18 @@ public:
         startReceive();
         startReceive();
     }
     }
 
 
+    UDPServer(AuthSrv* auth_server, io_service& io_service,
+              asio::ip::address addr, uint16_t port) :
+        auth_server_(auth_server), io_service_(io_service),
+        socket_(io_service, addr.is_v6() ? udp::v6() : udp::v4()),
+        response_buffer_(0),
+        response_renderer_(response_buffer_),
+        dns_message_(Message::PARSE)
+    {
+        socket_.bind(udp::endpoint(addr, port));
+        startReceive();
+    }
+
     void handleRequest(const asio::error_code& error,
     void handleRequest(const asio::error_code& error,
                        size_t bytes_recvd)
                        size_t bytes_recvd)
     {
     {
@@ -347,7 +378,7 @@ struct ServerSet {
 
 
 class IOServiceImpl {
 class IOServiceImpl {
 public:
 public:
-    IOServiceImpl(AuthSrv* auth_server, const char* port,
+    IOServiceImpl(AuthSrv* auth_server, const char* address, const char* port,
                   const bool use_ipv4, const bool use_ipv6);
                   const bool use_ipv4, const bool use_ipv6);
     ~IOServiceImpl();
     ~IOServiceImpl();
     asio::io_service io_service_;
     asio::io_service io_service_;
@@ -358,25 +389,59 @@ public:
     TCPServer* tcp6_server_;
     TCPServer* tcp6_server_;
 };
 };
 
 
-IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char* const port,
-                             const bool use_ipv4, const bool use_ipv6) :
+IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char* const address,
+                             const char* const port, const bool use_ipv4,
+                             const bool use_ipv6) :
     auth_server_(auth_server), udp4_server_(NULL), udp6_server_(NULL),
     auth_server_(auth_server), udp4_server_(NULL), udp6_server_(NULL),
     tcp4_server_(NULL), tcp6_server_(NULL)
     tcp4_server_(NULL), tcp6_server_(NULL)
 {
 {
     ServerSet servers;
     ServerSet servers;
-    short portnum = atoi(port);
+    uint16_t portnum = atoi(port);
 
 
-    if (use_ipv4) {
-        servers.udp4_server = new UDPServer(auth_server, io_service_,
-                                            AF_INET, portnum);
-        servers.tcp4_server = new TCPServer(auth_server, io_service_,
-                                            AF_INET, portnum);
+    try {
+        portnum = boost::lexical_cast<uint16_t>(port);
+    } catch (const std::exception& ex) {
+        isc_throw(FatalError, "[b10-auth] Invalid port number '"
+                              << port << "'");
     }
     }
-    if (use_ipv6) {
-        servers.udp6_server = new UDPServer(auth_server, io_service_,
-                                            AF_INET6, portnum);
-        servers.tcp6_server = new TCPServer(auth_server, io_service_,
-                                            AF_INET6, portnum);
+
+    if (address != NULL) {
+        asio::ip::address addr = asio::ip::address::from_string(address);
+
+        if ((addr.is_v6() && !use_ipv6)) {
+            isc_throw(FatalError,
+                      "[b10-auth] Error: -4 conflicts with " << addr);
+        }
+
+        if ((addr.is_v4() && !use_ipv4)) {
+            isc_throw(FatalError,
+                      "[b10-auth] Error: -6 conflicts with " << addr);
+        }
+
+        if (addr.is_v4()) {
+            servers.udp4_server = new UDPServer(auth_server, io_service_,
+                                                addr, portnum);
+            servers.tcp4_server = new TCPServer(auth_server, io_service_,
+                                                addr, portnum);
+         } else {
+            servers.udp6_server = new UDPServer(auth_server, io_service_,
+                                                addr, portnum);
+            servers.tcp6_server = new TCPServer(auth_server, io_service_,
+                                                addr, portnum);
+        }
+    } else {
+        if (use_ipv4) {
+            servers.udp4_server = new UDPServer(auth_server, io_service_,
+                                                AF_INET, portnum);
+            servers.tcp4_server = new TCPServer(auth_server, io_service_,
+                                                AF_INET, portnum);
+        }
+        if (use_ipv6) {
+            servers.udp6_server = new UDPServer(auth_server, io_service_,
+                                                AF_INET6, portnum);
+            servers.tcp6_server = new TCPServer(auth_server, io_service_,
+                                                AF_INET6, portnum);
+        }
     }
     }
 
 
     // Now we don't have to worry about exception, and need to make sure that
     // Now we don't have to worry about exception, and need to make sure that
@@ -394,9 +459,10 @@ IOServiceImpl::~IOServiceImpl() {
     delete tcp6_server_;
     delete tcp6_server_;
 }
 }
 
 
-IOService::IOService(AuthSrv* auth_server, const char* const port,
-                     const bool use_ipv4, const bool use_ipv6) {
-    impl_ = new IOServiceImpl(auth_server, port, use_ipv4, use_ipv6);
+IOService::IOService(AuthSrv* auth_server, const char* const address,
+                     const char* const port, const bool use_ipv4,
+                     const bool use_ipv6) {
+    impl_ = new IOServiceImpl(auth_server, address, port, use_ipv4, use_ipv6);
 }
 }
 
 
 IOService::~IOService() {
 IOService::~IOService() {

+ 2 - 1
src/bin/auth/asio_link.h

@@ -24,7 +24,8 @@ struct IOServiceImpl;
 
 
 class IOService {
 class IOService {
 public:
 public:
-    IOService(AuthSrv* auth_server, const char* const port,
+    IOService(AuthSrv* auth_server,
+              const char* const address, const char* const port,
               const bool use_ipv4, const bool use_ipv6);
               const bool use_ipv4, const bool use_ipv6);
     ~IOService();
     ~IOService();
     void run();
     void run();

+ 6 - 2
src/bin/auth/main.cc

@@ -95,6 +95,7 @@ int
 main(int argc, char* argv[]) {
 main(int argc, char* argv[]) {
     int ch;
     int ch;
     const char* port = DNSPORT;
     const char* port = DNSPORT;
+    const char* address = NULL;
     bool use_ipv4 = true, use_ipv6 = true, cache = true;
     bool use_ipv4 = true, use_ipv6 = true, cache = true;
 
 
     while ((ch = getopt(argc, argv, "46np:v")) != -1) {
     while ((ch = getopt(argc, argv, "46np:v")) != -1) {
@@ -113,6 +114,9 @@ main(int argc, char* argv[]) {
         case 'n':
         case 'n':
             cache = false;
             cache = false;
             break;
             break;
+        case 'a':
+            address = optarg;
+            break;
         case 'p':
         case 'p':
             port = optarg;
             port = optarg;
             break;
             break;
@@ -148,8 +152,8 @@ main(int argc, char* argv[]) {
         auth_server = new AuthSrv(cache);
         auth_server = new AuthSrv(cache);
         auth_server->setVerbose(verbose_mode);
         auth_server->setVerbose(verbose_mode);
 
 
-        io_service = new asio_link::IOService(auth_server, port, use_ipv4,
-                                              use_ipv6);
+        io_service = new asio_link::IOService(auth_server, address, port,
+                                              use_ipv4, use_ipv6);
 
 
         ModuleCCSession cs(specfile, io_service->get_io_service(),
         ModuleCCSession cs(specfile, io_service->get_io_service(),
                            my_config_handler, my_command_handler);
                            my_config_handler, my_command_handler);

+ 60 - 10
src/bin/bind10/bind10.py.in

@@ -56,6 +56,7 @@ import errno
 import time
 import time
 import select
 import select
 import random
 import random
+import socket
 from optparse import OptionParser, OptionValueError
 from optparse import OptionParser, OptionValueError
 import io
 import io
 import pwd
 import pwd
@@ -173,11 +174,36 @@ class ProcessInfo:
     def respawn(self):
     def respawn(self):
         self._spawn()
         self._spawn()
 
 
+class IPAddr:
+    """Stores an IPv4 or IPv6 address."""
+    family = None
+    addr = None
+
+    def __init__(self, addr):
+        try:
+            a = socket.inet_pton(socket.AF_INET, addr)
+            self.family = socket.AF_INET
+            self.addr = a
+            return
+        except:
+            pass
+
+        try:
+            a = socket.inet_pton(socket.AF_INET6, addr)
+            self.family = socket.AF_INET6
+            self.addr = a
+            return
+        except Exception as e:
+            raise e
+    
+    def __str__(self):
+        return socket.inet_ntop(self.family, self.addr)
+
 class BoB:
 class BoB:
     """Boss of BIND class."""
     """Boss of BIND class."""
     
     
-    def __init__(self, msgq_socket_file=None, auth_port=5300, nocache=False,
-                 verbose=False, setuid=None, username=None):
+    def __init__(self, msgq_socket_file=None, auth_port=5300, address='',
+                 nocache=False, verbose=False, setuid=None, username=None):
         """Initialize the Boss of BIND. This is a singleton (only one
         """Initialize the Boss of BIND. This is a singleton (only one
         can run).
         can run).
         
         
@@ -188,6 +214,9 @@ class BoB:
         self.verbose = verbose
         self.verbose = verbose
         self.msgq_socket_file = msgq_socket_file
         self.msgq_socket_file = msgq_socket_file
         self.auth_port = auth_port
         self.auth_port = auth_port
+        self.address = None
+        if address:
+            self.address = IPAddr(address)
         self.cc_session = None
         self.cc_session = None
         self.ccs = None
         self.ccs = None
         self.processes = {}
         self.processes = {}
@@ -303,12 +332,17 @@ class BoB:
         # start b10-auth
         # start b10-auth
         # XXX: this must be read from the configuration manager in the future
         # XXX: this must be read from the configuration manager in the future
         authargs = ['b10-auth', '-p', str(self.auth_port)]
         authargs = ['b10-auth', '-p', str(self.auth_port)]
+        if self.address:
+            authargs += ['-a', str(self.address)]
         if self.nocache:
         if self.nocache:
             authargs += ['-n']
             authargs += ['-n']
         if self.verbose:
         if self.verbose:
-            sys.stdout.write("[bind10] Starting b10-auth using port %d\n" %
-                             self.auth_port)
             authargs += ['-v']
             authargs += ['-v']
+            sys.stdout.write("Starting b10-auth using port %d" %
+                             self.auth_port)
+            if self.address:
+                sys.stdout.write(" on %s" % str(self.address))
+            sys.stdout.write("\n")
         try:
         try:
             auth = ProcessInfo("b10-auth", authargs,
             auth = ProcessInfo("b10-auth", authargs,
                                c_channel_env)
                                c_channel_env)
@@ -549,6 +583,18 @@ def check_port(option, opt_str, value, parser):
     else:
     else:
         raise OptionValueError("Unknown option " + opt_str)
         raise OptionValueError("Unknown option " + opt_str)
   
   
+def check_addr(option, opt_str, value, parser):
+    """Function to insure that the address we are passed is actually 
+    a valid address. Used by OptionParser() on startup."""
+    try:
+        IPAddr(value)
+    except:
+        raise OptionValueError("%s requires a valid IPv4 or IPv6 address" % opt_str)
+    if (opt_str == '-a' or opt_str == '--address'):
+        parser.values.address = value
+    else:
+        raise OptionValueError("Unknown option " + opt_str)
+  
 def main():
 def main():
     global options
     global options
     global boss_of_bind
     global boss_of_bind
@@ -558,19 +604,22 @@ def main():
 
 
     # Parse any command-line options.
     # Parse any command-line options.
     parser = OptionParser(version=__version__)
     parser = OptionParser(version=__version__)
-    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
-                      help="display more about what is going on")
+    parser.add_option("-a", "--address", dest="address", type="string",
+                      action="callback", callback=check_addr, default='',
+                      help="address the b10-auth daemon will use (default: listen on all addresses)")
+    parser.add_option("-m", "--msgq-socket-file", dest="msgq_socket_file",
+                      type="string", default=None,
+                      help="UNIX domain socket file the b10-msgq daemon will use")
     parser.add_option("-n", "--no-cache", action="store_true", dest="nocache",
     parser.add_option("-n", "--no-cache", action="store_true", dest="nocache",
                       default=False, help="disable hot-spot cache in b10-auth")
                       default=False, help="disable hot-spot cache in b10-auth")
     parser.add_option("-p", "--port", dest="auth_port", type="string",
     parser.add_option("-p", "--port", dest="auth_port", type="string",
                       action="callback", callback=check_port, default="5300",
                       action="callback", callback=check_port, default="5300",
                       help="port the b10-auth daemon will use (default 5300)")
                       help="port the b10-auth daemon will use (default 5300)")
-    parser.add_option("-m", "--msgq-socket-file", dest="msgq_socket_file",
-                      type="string", default=None,
-                      help="UNIX domain socket file the b10-msgq daemon will use")
     parser.add_option("-u", "--user", dest="user",
     parser.add_option("-u", "--user", dest="user",
                       type="string", default=None,
                       type="string", default=None,
                       help="Change user after startup (must run as root)")
                       help="Change user after startup (must run as root)")
+    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
+                      help="display more about what is going on")
     (options, args) = parser.parse_args()
     (options, args) = parser.parse_args()
     if args:
     if args:
         parser.print_help()
         parser.print_help()
@@ -626,7 +675,8 @@ def main():
 
 
     # Go bob!
     # Go bob!
     boss_of_bind = BoB(options.msgq_socket_file, int(options.auth_port),
     boss_of_bind = BoB(options.msgq_socket_file, int(options.auth_port),
-                       options.nocache, options.verbose, setuid, username)
+                       options.address, options.nocache, options.verbose,
+                       setuid, username)
     startup_result = boss_of_bind.startup()
     startup_result = boss_of_bind.startup()
     if startup_result:
     if startup_result:
         sys.stderr.write("[bind10] Error on startup: %s\n" % startup_result)
         sys.stderr.write("[bind10] Error on startup: %s\n" % startup_result)

+ 47 - 1
src/bin/bind10/tests/bind10_test.py

@@ -1,4 +1,4 @@
-from bind10 import ProcessInfo, BoB
+from bind10 import ProcessInfo, BoB, IPAddr
 
 
 # XXX: environment tests are currently disabled, due to the preprocessor
 # XXX: environment tests are currently disabled, due to the preprocessor
 #      setup that we have now complicating the environment
 #      setup that we have now complicating the environment
@@ -7,6 +7,7 @@ import unittest
 import sys
 import sys
 import os
 import os
 import signal
 import signal
+import socket
 
 
 class TestProcessInfo(unittest.TestCase):
 class TestProcessInfo(unittest.TestCase):
     def setUp(self):
     def setUp(self):
@@ -71,12 +72,36 @@ class TestProcessInfo(unittest.TestCase):
         self.assertTrue(type(pi.pid) is int)
         self.assertTrue(type(pi.pid) is int)
         self.assertNotEqual(pi.pid, old_pid)
         self.assertNotEqual(pi.pid, old_pid)
 
 
+class TestIPAddr(unittest.TestCase):
+    def test_v6ok(self):
+        addr = IPAddr('2001:4f8::1')
+        self.assertEqual(addr.family, socket.AF_INET6)
+        self.assertEqual(addr.addr, socket.inet_pton(socket.AF_INET6, '2001:4f8::1'))
+
+    def test_v4ok(self):
+        addr = IPAddr('127.127.127.127')
+        self.assertEqual(addr.family, socket.AF_INET)
+        self.assertEqual(addr.addr, socket.inet_aton('127.127.127.127'))
+
+    def test_badaddr(self):
+        self.assertRaises(socket.error, IPAddr, 'foobar')
+        self.assertRaises(socket.error, IPAddr, 'foo::bar')
+        self.assertRaises(socket.error, IPAddr, '123')
+        self.assertRaises(socket.error, IPAddr, '123.456.789.0')
+        self.assertRaises(socket.error, IPAddr, '127/8')
+        self.assertRaises(socket.error, IPAddr, '0/0')
+        self.assertRaises(socket.error, IPAddr, '1.2.3.4/32')
+        self.assertRaises(socket.error, IPAddr, '0')
+        self.assertRaises(socket.error, IPAddr, '')
+
 class TestBoB(unittest.TestCase):
 class TestBoB(unittest.TestCase):
     def test_init(self):
     def test_init(self):
         bob = BoB()
         bob = BoB()
         self.assertEqual(bob.verbose, False)
         self.assertEqual(bob.verbose, False)
         self.assertEqual(bob.msgq_socket_file, None)
         self.assertEqual(bob.msgq_socket_file, None)
+        self.assertEqual(bob.auth_port, 5300)
         self.assertEqual(bob.cc_session, None)
         self.assertEqual(bob.cc_session, None)
+        self.assertEqual(bob.address, None)
         self.assertEqual(bob.processes, {})
         self.assertEqual(bob.processes, {})
         self.assertEqual(bob.dead_processes, {})
         self.assertEqual(bob.dead_processes, {})
         self.assertEqual(bob.runnable, False)
         self.assertEqual(bob.runnable, False)
@@ -90,6 +115,27 @@ class TestBoB(unittest.TestCase):
         self.assertEqual(bob.dead_processes, {})
         self.assertEqual(bob.dead_processes, {})
         self.assertEqual(bob.runnable, False)
         self.assertEqual(bob.runnable, False)
 
 
+    def test_init_alternate_auth_port(self):
+        bob = BoB(None, 9999)
+        self.assertEqual(bob.verbose, False)
+        self.assertEqual(bob.msgq_socket_file, None)
+        self.assertEqual(bob.auth_port, 9999)
+        self.assertEqual(bob.cc_session, None)
+        self.assertEqual(bob.address, None)
+        self.assertEqual(bob.processes, {})
+        self.assertEqual(bob.dead_processes, {})
+        self.assertEqual(bob.runnable, False)
+
+    def test_init_alternate_address(self):
+        bob = BoB(None, 5300, '127.127.127.127')
+        self.assertEqual(bob.verbose, False)
+        self.assertEqual(bob.auth_port, 5300)
+        self.assertEqual(bob.msgq_socket_file, None)
+        self.assertEqual(bob.cc_session, None)
+        self.assertEqual(bob.address.addr, socket.inet_aton('127.127.127.127'))
+        self.assertEqual(bob.processes, {})
+        self.assertEqual(bob.dead_processes, {})
+        self.assertEqual(bob.runnable, False)
     # verbose testing...
     # verbose testing...
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':