main.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. // Copyright (C) 2009 Internet Systems Consortium, Inc. ("ISC")
  2. //
  3. // Permission to use, copy, modify, and/or distribute this software for any
  4. // purpose with or without fee is hereby granted, provided that the above
  5. // copyright notice and this permission notice appear in all copies.
  6. //
  7. // THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
  8. // REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
  9. // AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
  10. // INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
  11. // LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
  12. // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
  13. // PERFORMANCE OF THIS SOFTWARE.
  14. // $Id$
  15. #include <sys/types.h>
  16. #include <sys/socket.h>
  17. #include <sys/select.h>
  18. #include <netdb.h>
  19. #include <stdlib.h>
  20. #include <set>
  21. #include <iostream>
  22. #include <boost/foreach.hpp>
  23. #include <boost/bind.hpp>
  24. #include <boost/asio.hpp>
  25. #include <dns/buffer.h>
  26. #include <dns/name.h>
  27. #include <dns/message.h>
  28. #include <dns/rrset.h>
  29. #include <dns/message.h>
  30. #include <dns/messagerenderer.h>
  31. #include <cc/session.h>
  32. #include <cc/data.h>
  33. #include <config/ccsession.h>
  34. #include "common.h"
  35. #include "config.h"
  36. #include "auth_srv.h"
  37. #include <boost/foreach.hpp>
  38. using namespace std;
  39. using namespace boost::asio;
  40. using ip::udp;
  41. using ip::tcp;
  42. using namespace isc::data;
  43. using namespace isc::cc;
  44. using namespace isc::config;
  45. using namespace isc::dns;
  namespace {
  // Port used for both the UDP and the TCP listeners unless -p overrides it.
  const short DNSPORT = 5300;
  // Program/module name string (not referenced by the code in this file).
  const std::string PROGRAM = "Auth";
  }
  50. /* need global var for config/command handlers.
  51. * todo: turn this around, and put handlers in the authserver
  52. * class itself? */
  53. namespace {
  54. AuthSrv *auth_server;
  55. }
  56. //
  57. // Helper classes for asynchronous I/O using boost::asio
  58. //
  59. namespace {
  60. class Completed {
  61. public:
  62. Completed(size_t len) : len_(len) {}
  63. bool operator()(const boost::system::error_code& error,
  64. size_t bytes_transferred) const
  65. {
  66. return (error != 0 || bytes_transferred >= len_);
  67. }
  68. private:
  69. size_t len_;
  70. };
  71. class TCPClient {
  72. public:
  73. TCPClient(io_service& io_service) :
  74. socket_(io_service),
  75. response_buffer_(0),
  76. responselen_buffer_(TCP_MESSAGE_LENGTHSIZE),
  77. response_renderer_(response_buffer_),
  78. dns_message_(Message::PARSE)
  79. {}
  80. void start() {
  81. async_read(socket_, boost::asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE),
  82. Completed(TCP_MESSAGE_LENGTHSIZE),
  83. boost::bind(&TCPClient::headerRead, this,
  84. placeholders::error,
  85. placeholders::bytes_transferred));
  86. }
  87. tcp::socket& getSocket() { return (socket_); }
  88. void headerRead(const boost::system::error_code& error,
  89. size_t bytes_transferred)
  90. {
  91. if (!error) {
  92. assert(bytes_transferred == TCP_MESSAGE_LENGTHSIZE);
  93. InputBuffer dnsbuffer(data_, TCP_MESSAGE_LENGTHSIZE);
  94. uint16_t msglen = dnsbuffer.readUint16();
  95. async_read(socket_, boost::asio::buffer(data_, msglen),
  96. Completed(msglen),
  97. boost::bind(&TCPClient::requestRead, this,
  98. placeholders::error,
  99. placeholders::bytes_transferred));
  100. } else {
  101. delete this;
  102. }
  103. }
  104. void requestRead(const boost::system::error_code& error,
  105. size_t bytes_transferred)
  106. {
  107. if (!error) {
  108. InputBuffer dnsbuffer(data_, bytes_transferred);
  109. if (auth_server->processMessage(dnsbuffer, dns_message_,
  110. response_renderer_) == 0) {
  111. responselen_buffer_.writeUint16(response_buffer_.getLength());
  112. async_write(socket_,
  113. boost::asio::buffer(
  114. responselen_buffer_.getData(),
  115. responselen_buffer_.getLength()),
  116. boost::bind(&TCPClient::responseWrite, this,
  117. placeholders::error));
  118. } else {
  119. delete this;
  120. }
  121. } else {
  122. delete this;
  123. }
  124. }
  125. void responseWrite(const boost::system::error_code& error)
  126. {
  127. if (!error) {
  128. async_write(socket_,
  129. boost::asio::buffer(response_buffer_.getData(),
  130. response_buffer_.getLength()),
  131. boost::bind(&TCPClient::handleWrite, this,
  132. placeholders::error));
  133. }
  134. }
  135. void handleWrite(const boost::system::error_code& error)
  136. {
  137. if (!error) {
  138. start(); // handle next request, if any.
  139. } else {
  140. delete this;
  141. }
  142. }
  143. private:
  144. tcp::socket socket_;
  145. OutputBuffer response_buffer_;
  146. OutputBuffer responselen_buffer_;
  147. MessageRenderer response_renderer_;
  148. Message dns_message_;
  149. enum { MAX_LENGTH = 65535 };
  150. static const size_t TCP_MESSAGE_LENGTHSIZE = 2;
  151. char data_[MAX_LENGTH];
  152. };
  153. class TCPServer
  154. {
  155. public:
  156. TCPServer(io_service& io_service, int af, short port) :
  157. io_service_(io_service),
  158. acceptor_(io_service,
  159. tcp::endpoint(af == AF_INET6 ? tcp::v6() : tcp::v4(), port))
  160. {
  161. TCPClient* new_client = new TCPClient(io_service_);
  162. // XXX: isn't the following exception free? Need to check it.
  163. acceptor_.async_accept(new_client->getSocket(),
  164. boost::bind(&TCPServer::handleAccept, this,
  165. new_client, placeholders::error));
  166. }
  167. void handleAccept(TCPClient* new_client,
  168. const boost::system::error_code& error)
  169. {
  170. if (!error) {
  171. new_client->start();
  172. new_client = new TCPClient(io_service_);
  173. acceptor_.async_accept(new_client->getSocket(),
  174. boost::bind(&TCPServer::handleAccept,
  175. this, new_client,
  176. placeholders::error));
  177. } else {
  178. delete new_client;
  179. }
  180. }
  181. private:
  182. io_service& io_service_;
  183. tcp::acceptor acceptor_;
  184. };
  185. class UDPServer {
  186. public:
  187. UDPServer(io_service& io_service, int af, short port) :
  188. io_service_(io_service),
  189. socket_(io_service,
  190. udp::endpoint(af == AF_INET6 ? udp::v6() : udp::v4(), port)),
  191. response_buffer_(0),
  192. response_renderer_(response_buffer_),
  193. dns_message_(Message::PARSE)
  194. {
  195. startReceive();
  196. }
  197. void handleRequest(const boost::system::error_code& error,
  198. size_t bytes_recvd)
  199. {
  200. if (!error && bytes_recvd > 0) {
  201. InputBuffer request_buffer(data_, bytes_recvd);
  202. dns_message_.clear(Message::PARSE);
  203. response_renderer_.clear();
  204. if (auth_server->processMessage(request_buffer, dns_message_,
  205. response_renderer_) == 0) {
  206. socket_.async_send_to(
  207. boost::asio::buffer(response_buffer_.getData(),
  208. response_buffer_.getLength()),
  209. sender_endpoint_,
  210. boost::bind(&UDPServer::sendCompleted,
  211. this,
  212. placeholders::error,
  213. placeholders::bytes_transferred));
  214. } else {
  215. startReceive();
  216. }
  217. } else {
  218. startReceive();
  219. }
  220. }
  221. void sendCompleted(const boost::system::error_code& error,
  222. size_t bytes_sent)
  223. {
  224. startReceive();
  225. }
  226. private:
  227. void startReceive() {
  228. socket_.async_receive_from(
  229. boost::asio::buffer(data_, MAX_LENGTH), sender_endpoint_,
  230. boost::bind(&UDPServer::handleRequest, this,
  231. placeholders::error,
  232. placeholders::bytes_transferred));
  233. }
  234. private:
  235. io_service& io_service_;
  236. udp::socket socket_;
  237. OutputBuffer response_buffer_;
  238. MessageRenderer response_renderer_;
  239. Message dns_message_;
  240. udp::endpoint sender_endpoint_;
  241. enum { MAX_LENGTH = 4096 };
  242. char data_[MAX_LENGTH];
  243. };
  244. }
  245. static void
  246. usage() {
  247. cerr << "Usage: b10-auth [-p port] [-4|-6]" << endl;
  248. exit(1);
  249. }
  250. ElementPtr
  251. my_config_handler(ElementPtr new_config)
  252. {
  253. auth_server->updateConfig(new_config);
  254. return createAnswer(0);
  255. }
  256. ElementPtr
  257. my_command_handler(const string& command, const ElementPtr args) {
  258. ElementPtr answer = createAnswer(0);
  259. cout << "[XX] Handle command: " << endl << command << endl;
  260. if (command == "print_message")
  261. {
  262. cout << args << endl;
  263. /* let's add that message to our answer as well */
  264. answer->get("result")->add(args);
  265. }
  266. return answer;
  267. }
  268. int
  269. main(int argc, char* argv[]) {
  270. int ch;
  271. short port = DNSPORT;
  272. bool ipv4_only = false, ipv6_only = false;
  273. bool use_ipv4 = false, use_ipv6 = false;
  274. UDPServer* udp4_server = NULL;
  275. UDPServer* udp6_server = NULL;
  276. TCPServer* tcp4_server = NULL;
  277. TCPServer* tcp6_server = NULL;
  278. while ((ch = getopt(argc, argv, "46p:")) != -1) {
  279. switch (ch) {
  280. case '4':
  281. ipv4_only = true;
  282. break;
  283. case '6':
  284. ipv6_only = true;
  285. break;
  286. case 'p':
  287. port = atoi(optarg);
  288. break;
  289. case '?':
  290. default:
  291. usage();
  292. }
  293. }
  294. if (argc - optind > 0) {
  295. usage();
  296. }
  297. if (ipv4_only && ipv6_only) {
  298. cerr << "-4 and -6 can't coexist" << endl;
  299. usage();
  300. }
  301. if (!ipv6_only) {
  302. use_ipv4 = true;
  303. }
  304. if (!ipv4_only) {
  305. use_ipv4 = true;
  306. }
  307. auth_server = new AuthSrv;
  308. // initialize command channel
  309. int ret = 0;
  310. try {
  311. string specfile;
  312. if (getenv("B10_FROM_SOURCE")) {
  313. specfile = string(getenv("B10_FROM_SOURCE")) +
  314. "/src/bin/auth/auth.spec";
  315. } else {
  316. specfile = string(AUTH_SPECFILE_LOCATION);
  317. }
  318. ModuleCCSession cs = ModuleCCSession(specfile, my_config_handler,
  319. my_command_handler);
  320. // XXX: in this prototype code we'll ignore any message on the command
  321. // channel.
  322. boost::asio::io_service io_service;
  323. if (use_ipv4) {
  324. udp4_server = new UDPServer(io_service, AF_INET, port);
  325. tcp4_server = new TCPServer(io_service, AF_INET, port);
  326. }
  327. if (use_ipv6) {
  328. udp6_server = new UDPServer(io_service, AF_INET6, port);
  329. tcp6_server = new TCPServer(io_service, AF_INET6, port);
  330. }
  331. cout << "Server started." << endl;
  332. io_service.run();
  333. } catch (const std::exception& ex) {
  334. cerr << ex.what() << endl;
  335. ret = 1;
  336. }
  337. delete udp4_server;
  338. delete tcp4_server;
  339. delete udp6_server;
  340. delete tcp6_server;
  341. delete auth_server;
  342. return (ret);
  343. }