123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646 |
- // Copyright (C) 2009 Internet Systems Consortium, Inc. ("ISC")
- //
- // Permission to use, copy, modify, and/or distribute this software for any
- // purpose with or without fee is hereby granted, provided that the above
- // copyright notice and this permission notice appear in all copies.
- //
- // THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
- // REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
- // AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
- // INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
- // LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
- // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
- // PERFORMANCE OF THIS SOFTWARE.
- // $Id$
- #include "config.h"
- #include <sys/types.h>
- #include <sys/socket.h>
- #include <sys/select.h>
- #include <netdb.h>
- #include <stdlib.h>
- #include <errno.h>
- #include <set>
- #include <iostream>
- #include <boost/foreach.hpp>
- #ifdef HAVE_BOOSTLIB
- #include <boost/bind.hpp>
- #include <boost/asio.hpp>
- #endif
- #include <exceptions/exceptions.h>
- #include <dns/buffer.h>
- #include <dns/name.h>
- #include <dns/message.h>
- #include <dns/rrset.h>
- #include <dns/message.h>
- #include <dns/messagerenderer.h>
- #include <cc/session.h>
- #include <cc/data.h>
- #include <config/ccsession.h>
- #include "spec_config.h"
- #include "common.h"
- #include "auth_srv.h"
- #include <boost/foreach.hpp>
- using namespace std;
- #ifdef HAVE_BOOSTLIB
- using namespace boost::asio;
- using ip::udp;
- using ip::tcp;
- #endif
- using namespace isc::data;
- using namespace isc::cc;
- using namespace isc::config;
- using namespace isc::dns;
namespace {
// Name of this program.
// NOTE(review): not referenced anywhere in the visible part of this file.
const std::string PROGRAM = "Auth";

// Default port, kept as a string because it is handed to code that expects
// a service/port string.  It is used when -p is not given on the command
// line.  The pointer itself is const so the default cannot be re-pointed
// by accident.
const char* const DNSPORT = "5300";
}
- /* need global var for config/command handlers.
- * todo: turn this around, and put handlers in the authserver
- * class itself? */
- namespace {
- AuthSrv *auth_server;
- }
- static ElementPtr
- my_config_handler(ElementPtr new_config)
- {
- return auth_server->updateConfig(new_config);
- }
- static ElementPtr
- my_command_handler(const string& command, const ElementPtr args) {
- ElementPtr answer = createAnswer(0);
- if (command == "print_message")
- {
- cout << args << endl;
- /* let's add that message to our answer as well */
- answer->get("result")->add(args);
- }
- return answer;
- }
- #ifdef HAVE_BOOSTLIB
- //
- // Helper classes for asynchronous I/O using boost::asio
- //
- namespace {
- class Completed {
- public:
- Completed(size_t len) : len_(len) {}
- bool operator()(const boost::system::error_code& error,
- size_t bytes_transferred) const
- {
- return (error != 0 || bytes_transferred >= len_);
- }
- private:
- size_t len_;
- };
- class TCPClient {
- public:
- TCPClient(io_service& io_service) :
- socket_(io_service),
- response_buffer_(0),
- responselen_buffer_(TCP_MESSAGE_LENGTHSIZE),
- response_renderer_(response_buffer_),
- dns_message_(Message::PARSE)
- {}
- void start() {
- async_read(socket_, boost::asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE),
- Completed(TCP_MESSAGE_LENGTHSIZE),
- boost::bind(&TCPClient::headerRead, this,
- placeholders::error,
- placeholders::bytes_transferred));
- }
- tcp::socket& getSocket() { return (socket_); }
- void headerRead(const boost::system::error_code& error,
- size_t bytes_transferred)
- {
- if (!error) {
- assert(bytes_transferred == TCP_MESSAGE_LENGTHSIZE);
- InputBuffer dnsbuffer(data_, TCP_MESSAGE_LENGTHSIZE);
- uint16_t msglen = dnsbuffer.readUint16();
- async_read(socket_, boost::asio::buffer(data_, msglen),
- Completed(msglen),
- boost::bind(&TCPClient::requestRead, this,
- placeholders::error,
- placeholders::bytes_transferred));
- } else {
- delete this;
- }
- }
- void requestRead(const boost::system::error_code& error,
- size_t bytes_transferred)
- {
- if (!error) {
- InputBuffer dnsbuffer(data_, bytes_transferred);
- if (auth_server->processMessage(dnsbuffer, dns_message_,
- response_renderer_, false) == 0) {
- responselen_buffer_.writeUint16(response_buffer_.getLength());
- async_write(socket_,
- boost::asio::buffer(
- responselen_buffer_.getData(),
- responselen_buffer_.getLength()),
- boost::bind(&TCPClient::responseWrite, this,
- placeholders::error));
- } else {
- delete this;
- }
- } else {
- delete this;
- }
- }
- void responseWrite(const boost::system::error_code& error)
- {
- if (!error) {
- async_write(socket_,
- boost::asio::buffer(response_buffer_.getData(),
- response_buffer_.getLength()),
- boost::bind(&TCPClient::handleWrite, this,
- placeholders::error));
- }
- }
- void handleWrite(const boost::system::error_code& error)
- {
- if (!error) {
- start(); // handle next request, if any.
- } else {
- delete this;
- }
- }
- private:
- tcp::socket socket_;
- OutputBuffer response_buffer_;
- OutputBuffer responselen_buffer_;
- MessageRenderer response_renderer_;
- Message dns_message_;
- enum { MAX_LENGTH = 65535 };
- static const size_t TCP_MESSAGE_LENGTHSIZE = 2;
- char data_[MAX_LENGTH];
- };
- class TCPServer
- {
- public:
- TCPServer(io_service& io_service, int af, short port) :
- io_service_(io_service),
- acceptor_(io_service,
- tcp::endpoint(af == AF_INET6 ? tcp::v6() : tcp::v4(), port))
- {
- TCPClient* new_client = new TCPClient(io_service_);
- // XXX: isn't the following exception free? Need to check it.
- acceptor_.async_accept(new_client->getSocket(),
- boost::bind(&TCPServer::handleAccept, this,
- new_client, placeholders::error));
- }
- void handleAccept(TCPClient* new_client,
- const boost::system::error_code& error)
- {
- if (!error) {
- new_client->start();
- new_client = new TCPClient(io_service_);
- acceptor_.async_accept(new_client->getSocket(),
- boost::bind(&TCPServer::handleAccept,
- this, new_client,
- placeholders::error));
- } else {
- delete new_client;
- }
- }
- private:
- io_service& io_service_;
- tcp::acceptor acceptor_;
- };
- class UDPServer {
- public:
- UDPServer(io_service& io_service, int af, short port) :
- io_service_(io_service),
- socket_(io_service,
- udp::endpoint(af == AF_INET6 ? udp::v6() : udp::v4(), port)),
- response_buffer_(0),
- response_renderer_(response_buffer_),
- dns_message_(Message::PARSE)
- {
- startReceive();
- }
- void handleRequest(const boost::system::error_code& error,
- size_t bytes_recvd)
- {
- if (!error && bytes_recvd > 0) {
- InputBuffer request_buffer(data_, bytes_recvd);
- dns_message_.clear(Message::PARSE);
- response_renderer_.clear();
- if (auth_server->processMessage(request_buffer, dns_message_,
- response_renderer_, true) == 0) {
- socket_.async_send_to(
- boost::asio::buffer(response_buffer_.getData(),
- response_buffer_.getLength()),
- sender_endpoint_,
- boost::bind(&UDPServer::sendCompleted,
- this,
- placeholders::error,
- placeholders::bytes_transferred));
- } else {
- startReceive();
- }
- } else {
- startReceive();
- }
- }
- void sendCompleted(const boost::system::error_code& error,
- size_t bytes_sent)
- {
- startReceive();
- }
- private:
- void startReceive() {
- socket_.async_receive_from(
- boost::asio::buffer(data_, MAX_LENGTH), sender_endpoint_,
- boost::bind(&UDPServer::handleRequest, this,
- placeholders::error,
- placeholders::bytes_transferred));
- }
- private:
- io_service& io_service_;
- udp::socket socket_;
- OutputBuffer response_buffer_;
- MessageRenderer response_renderer_;
- Message dns_message_;
- udp::endpoint sender_endpoint_;
- enum { MAX_LENGTH = 4096 };
- char data_[MAX_LENGTH];
- };
- struct ServerSet {
- ServerSet() : udp4_server(NULL), udp6_server(NULL),
- tcp4_server(NULL), tcp6_server(NULL)
- {}
- ~ServerSet()
- {
- delete udp4_server;
- delete udp6_server;
- delete tcp4_server;
- delete tcp6_server;
- }
- UDPServer* udp4_server;
- UDPServer* udp6_server;
- TCPServer* tcp4_server;
- TCPServer* tcp6_server;
- };
- static void
- run_server(const char* port, const bool use_ipv4, const bool use_ipv6,
- const string& specfile)
- {
- ServerSet servers;
- boost::asio::io_service io_service;
- short portnum = atoi(port);
- ModuleCCSession cs(specfile, io_service, my_config_handler,
- my_command_handler);
- if (use_ipv4) {
- servers.udp4_server = new UDPServer(io_service, AF_INET, portnum);
- servers.tcp4_server = new TCPServer(io_service, AF_INET, portnum);
- }
- if (use_ipv6) {
- servers.udp6_server = new UDPServer(io_service, AF_INET6, portnum);
- servers.tcp6_server = new TCPServer(io_service, AF_INET6, portnum);
- }
- cout << "Server started." << endl;
- io_service.run();
- }
- }
- #else // !HAVE_BOOSTLIB
// Owns the listening descriptors used by the select()-based server:
// UDP/TCP for IPv4 (ups4/tps4) and UDP/TCP for IPv6 (ups6/tps6).
// A value of -1 means "not open"; every open descriptor is closed when the
// set is destroyed.
struct SocketSet {
    SocketSet() : ups4(-1), tps4(-1), ups6(-1), tps6(-1) {}

    ~SocketSet() {
        closeIfOpen(ups4);
        closeIfOpen(tps4);
        closeIfOpen(ups6);
        // Each descriptor is tested on its own: tps6 must not depend on
        // whether tps4 happens to be open.
        closeIfOpen(tps6);
    }

    int ups4, tps4, ups6, tps6;

private:
    // Close fd unless it still holds the "not open" marker.
    static void closeIfOpen(const int fd) {
        if (fd >= 0) {
            close(fd);
        }
    }
};
// Create a UDP socket for address family af, bound to the wildcard address
// on the given port (decimal string or service name).
//
// Returns the descriptor of the bound socket.
// Throws FatalError if address resolution, socket creation or binding
// fails; no descriptor or addrinfo list is left behind in that case.
static int
getUDPSocket(int af, const char* port) {
    struct addrinfo hints, *res;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = af;
    hints.ai_socktype = SOCK_DGRAM;
    hints.ai_flags = AI_PASSIVE;
    hints.ai_protocol = IPPROTO_UDP;

    const int error = getaddrinfo(NULL, port, &hints, &res);
    if (error != 0) {
        isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error));
    }

    const int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
    if (s < 0) {
        freeaddrinfo(res);
        isc_throw(FatalError, "failed to open socket");
    }

    if (af == AF_INET6) {
        int on = 1;
        if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
            cerr << "couldn't set IPV6_V6ONLY socket option" << endl;
            // proceed anyway
        }
    }

    // "::bind" makes sure the socket call is meant, whatever other "bind"
    // "using namespace std" may have brought into scope.
    const int bind_result = ::bind(s, res->ai_addr, res->ai_addrlen);
    // The address list is no longer needed, whether bind worked or not.
    freeaddrinfo(res);
    if (bind_result < 0) {
        close(s);
        isc_throw(FatalError, "binding socket failure");
    }

    return (s);
}
// Create a listening TCP socket for address family af, bound to the
// wildcard address on the given port (decimal string or service name).
//
// Returns the descriptor of the listening socket.
// Throws FatalError if address resolution, socket creation, binding or
// listen() fails; no descriptor or addrinfo list is left behind in that
// case.  Failing to set a socket option is only reported on stderr.
static int
getTCPSocket(int af, const char* port) {
    struct addrinfo hints, *res;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = af;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;
    hints.ai_protocol = IPPROTO_TCP;

    const int error = getaddrinfo(NULL, port, &hints, &res);
    if (error != 0) {
        isc_throw(FatalError, "getaddrinfo failed: " << gai_strerror(error));
    }

    const int s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
    if (s < 0) {
        freeaddrinfo(res);
        isc_throw(FatalError, "failed to open socket");
    }

    int on = 1;
    if (af == AF_INET6) {
        if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
            cerr << "couldn't set IPV6_V6ONLY socket option" << endl;
        }
    }
    if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
        cerr << "couldn't set SO_REUSEADDR socket option" << endl;
    }

    // "::bind" makes sure the socket call is meant, whatever other "bind"
    // "using namespace std" may have brought into scope.
    const int bind_result = ::bind(s, res->ai_addr, res->ai_addrlen);
    // The address list is no longer needed, whether bind worked or not.
    freeaddrinfo(res);
    if (bind_result < 0) {
        close(s);
        isc_throw(FatalError, "binding socket failure");
    }

    // A TCP socket that cannot listen would never become readable in the
    // select loop, so treat this like any other setup failure.
    if (listen(s, 100) < 0) {
        close(s);
        isc_throw(FatalError, "failed to listen on a TCP socket");
    }

    return (s);
}
- static void
- processMessageUDP(const int fd, Message& dns_message,
- MessageRenderer& response_renderer)
- {
- struct sockaddr_storage ss;
- socklen_t sa_len = sizeof(ss);
- struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
- char recvbuf[4096];
- int cc;
- dns_message.clear(Message::PARSE);
- response_renderer.clear();
- if ((cc = recvfrom(fd, recvbuf, sizeof(recvbuf), 0, sa, &sa_len)) > 0) {
- InputBuffer buffer(recvbuf, cc);
- if (auth_server->processMessage(buffer, dns_message, response_renderer,
- true) == 0) {
- sendto(fd, response_renderer.getData(),
- response_renderer.getLength(), 0, sa, sa_len);
- }
- }
- }
- static void
- processMessageTCP(const int fd, Message& dns_message,
- MessageRenderer& response_renderer)
- {
- struct sockaddr_storage ss;
- socklen_t sa_len = sizeof(ss);
- struct sockaddr* sa = static_cast<struct sockaddr*>((void*)&ss);
- char sizebuf[2];
- int cc;
- int ts = accept(fd, sa, &sa_len);
- cout << "[XX] process TCP" << endl;
- cc = recv(ts, sizebuf, 2, 0);
- cout << "[XX] got: " << cc << endl;
- uint16_t size, size_n;
- memcpy(&size_n, sizebuf, 2);
- size = ntohs(size_n);
- cout << "[XX] got: " << size << endl;
- vector<char> message_buffer;
- message_buffer.reserve(size);
- cc = 0;
- while (cc < size) {
- cout << "[XX] cc now: " << cc << " of " << size << endl;
- cc += recv(ts, &message_buffer[0] + cc, size - cc, 0);
- }
- InputBuffer buffer(&message_buffer[0], size);
- dns_message.clear(Message::PARSE);
- response_renderer.clear();
- if (auth_server->processMessage(buffer, dns_message, response_renderer,
- false) == 0) {
- size = response_renderer.getLength();
- size_n = htons(size);
- if (send(ts, &size_n, 2, 0) == 2) {
- cc = send(ts, response_renderer.getData(),
- response_renderer.getLength(), 0);
- if (cc == -1) {
- cerr << "[AuthSrv] error in sending TCP response message" <<
- endl;
- } else {
- cout << "[XX] sent TCP response: " << cc << " bytes" << endl;
- }
- }
- }
-
- // TODO: we don't check for more queries on the stream atm
- close(ts);
- }
- static void
- run_server(const char* port, const bool use_ipv4, const bool use_ipv6,
- const string& specfile)
- {
- SocketSet socket_set;
- fd_set fds_base;
- int nfds = -1;
- FD_ZERO(&fds_base);
- if (use_ipv4) {
- socket_set.ups4 = getUDPSocket(AF_INET, port);
- FD_SET(socket_set.ups4, &fds_base);
- nfds = max(nfds, socket_set.ups4);
- socket_set.tps4 = getTCPSocket(AF_INET, port);
- FD_SET(socket_set.tps4, &fds_base);
- nfds = max(nfds, socket_set.tps4);
- }
- if (use_ipv6) {
- socket_set.ups6 = getUDPSocket(AF_INET6, port);
- FD_SET(socket_set.ups6, &fds_base);
- nfds = max(nfds, socket_set.ups6);
- socket_set.tps6 = getTCPSocket(AF_INET6, port);
- FD_SET(socket_set.tps6, &fds_base);
- nfds = max(nfds, socket_set.tps6);
- }
- ++nfds;
- ModuleCCSession cs(specfile, my_config_handler, my_command_handler);
- cout << "Server started." << endl;
-
- int ss = cs.getSocket();
- Message dns_message(Message::PARSE);
- OutputBuffer resonse_buffer(0);
- MessageRenderer response_renderer(resonse_buffer);
- while (true) {
- fd_set fds = fds_base;
- FD_SET(ss, &fds);
- int n = select(nfds, &fds, NULL, NULL, NULL);
- if (n < 0) {
- if (errno != EINTR) {
- isc_throw(FatalError, "select error");
- }
- continue;
- }
- if (socket_set.ups4 >= 0 && FD_ISSET(socket_set.ups4, &fds)) {
- processMessageUDP(socket_set.ups4, dns_message, response_renderer);
- }
- if (socket_set.ups6 >= 0 && FD_ISSET(socket_set.ups6, &fds)) {
- processMessageUDP(socket_set.ups6, dns_message, response_renderer);
- }
- if (socket_set.tps4 >= 0 && FD_ISSET(socket_set.tps4, &fds)) {
- processMessageTCP(socket_set.tps4, dns_message, response_renderer);
- }
- if (socket_set.tps6 >= 0 && FD_ISSET(socket_set.tps6, &fds)) {
- processMessageTCP(socket_set.tps6, dns_message, response_renderer);
- }
- if (FD_ISSET(ss, &fds)) {
- cs.check_command();
- }
- }
- }
- #endif // HAVE_BOOSTLIB
- static void
- usage() {
- cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl;
- exit(1);
- }
- int
- main(int argc, char* argv[]) {
- int ch;
- const char* port = DNSPORT;
- bool ipv4_only = false, ipv6_only = false;
- bool use_ipv4 = false, use_ipv6 = false;
- while ((ch = getopt(argc, argv, "46p:")) != -1) {
- switch (ch) {
- case '4':
- ipv4_only = true;
- break;
- case '6':
- ipv6_only = true;
- break;
- case 'p':
- port = optarg;
- break;
- case '?':
- default:
- usage();
- }
- }
- if (argc - optind > 0) {
- usage();
- }
- if (ipv4_only && ipv6_only) {
- cerr << "-4 and -6 can't coexist" << endl;
- usage();
- }
- if (!ipv6_only) {
- use_ipv4 = true;
- }
- if (!ipv4_only) {
- use_ipv4 = true;
- }
- auth_server = new AuthSrv;
- // initialize command channel
- int ret = 0;
- try {
- string specfile;
- if (getenv("B10_FROM_SOURCE")) {
- specfile = string(getenv("B10_FROM_SOURCE")) +
- "/src/bin/auth/auth.spec";
- } else {
- specfile = string(AUTH_SPECFILE_LOCATION);
- }
- run_server(port, use_ipv4, use_ipv6, specfile);
- } catch (const std::exception& ex) {
- cerr << ex.what() << endl;
- ret = 1;
- }
- delete auth_server;
- return (ret);
- }
|