asio_link.cc 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. // Copyright (C) 2010 Internet Systems Consortium, Inc. ("ISC")
  2. //
  3. // Permission to use, copy, modify, and/or distribute this software for any
  4. // purpose with or without fee is hereby granted, provided that the above
  5. // copyright notice and this permission notice appear in all copies.
  6. //
  7. // THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
  8. // REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
  9. // AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
  10. // INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
  11. // LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
  12. // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
  13. // PERFORMANCE OF THIS SOFTWARE.
  14. // $Id$
  15. #include <config.h>
  16. #include <unistd.h> // for some IPC/network system calls
  17. #include <sys/socket.h>
  18. #include <netinet/in.h>
  19. #include <asio.hpp>
  20. #include <boost/lexical_cast.hpp>
  21. #include <boost/bind.hpp>
  22. #include <boost/shared_ptr.hpp>
  23. #include <dns/buffer.h>
  24. #include <dns/message.h>
  25. #include <dns/messagerenderer.h>
  26. #include <asio_link.h>
  27. #include <auth/auth_srv.h>
  28. #include <auth/common.h>
  29. using namespace asio;
  30. using asio::ip::udp;
  31. using asio::ip::tcp;
  32. using namespace std;
  33. using namespace isc::dns;
  34. namespace asio_link {
  35. IOAddress::IOAddress(const string& address_str)
  36. // XXX: we cannot simply construct the address in the initialization list
  37. // because we'd like to throw our own exception on failure.
  38. {
  39. error_code err;
  40. asio_address_ = ip::address::from_string(address_str, err);
  41. if (err) {
  42. isc_throw(IOError, "Failed to convert string to address '"
  43. << address_str << "': " << err.message());
  44. }
  45. }
  46. IOAddress::IOAddress(const ip::address& asio_address) :
  47. asio_address_(asio_address)
  48. {}
  49. string
  50. IOAddress::toText() const {
  51. return (asio_address_.to_string());
  52. }
  // Note: this implementation is optimized for the case where this object
  // is created from an ASIO endpoint object in a receiving code path
  // by avoiding making a copy of the base endpoint. For TCP it may not be
  // a big deal, but when we receive UDP packets at a high rate, the copy
  // overhead might be significant.  As a consequence, an endpoint wrapper
  // built from an ASIO endpoint only refers to it, and that ASIO endpoint
  // must outlive the wrapper.
  58. class TCPEndpoint : public IOEndpoint {
  59. public:
  60. TCPEndpoint(const IOAddress& address, const unsigned short port) :
  61. asio_endpoint_placeholder_(
  62. new tcp::endpoint(ip::address::from_string(address.toText()),
  63. port)),
  64. asio_endpoint_(*asio_endpoint_placeholder_)
  65. {}
  66. TCPEndpoint(const tcp::endpoint& asio_endpoint) :
  67. asio_endpoint_placeholder_(NULL), asio_endpoint_(asio_endpoint)
  68. {}
  69. ~TCPEndpoint() { delete asio_endpoint_placeholder_; }
  70. virtual IOAddress getAddress() const {
  71. return (asio_endpoint_.address());
  72. }
  73. private:
  74. const tcp::endpoint* asio_endpoint_placeholder_;
  75. const tcp::endpoint& asio_endpoint_;
  76. };
  77. class UDPEndpoint : public IOEndpoint {
  78. public:
  79. UDPEndpoint(const IOAddress& address, const unsigned short port) :
  80. asio_endpoint_placeholder_(
  81. new udp::endpoint(ip::address::from_string(address.toText()),
  82. port)),
  83. asio_endpoint_(*asio_endpoint_placeholder_)
  84. {}
  85. UDPEndpoint(const udp::endpoint& asio_endpoint) :
  86. asio_endpoint_placeholder_(NULL), asio_endpoint_(asio_endpoint)
  87. {}
  88. ~UDPEndpoint() { delete asio_endpoint_placeholder_; }
  89. virtual IOAddress getAddress() const {
  90. return (asio_endpoint_.address());
  91. }
  92. private:
  93. const udp::endpoint* asio_endpoint_placeholder_;
  94. const udp::endpoint& asio_endpoint_;
  95. };
  96. const IOEndpoint*
  97. IOEndpoint::create(const int protocol, const IOAddress& address,
  98. const unsigned short port)
  99. {
  100. if (protocol == IPPROTO_UDP) {
  101. return (new UDPEndpoint(address, port));
  102. } else if (protocol == IPPROTO_TCP) {
  103. return (new TCPEndpoint(address, port));
  104. }
  105. isc_throw(IOError,
  106. "IOEndpoint creation attempt for unsupported protocol: " <<
  107. protocol);
  108. }
  109. class TCPSocket : public IOSocket {
  110. private:
  111. TCPSocket(const TCPSocket& source);
  112. TCPSocket& operator=(const TCPSocket& source);
  113. public:
  114. TCPSocket(tcp::socket& socket) : socket_(socket) {}
  115. virtual int getNative() const { return (socket_.native()); }
  116. virtual int getProtocol() const { return (IPPROTO_TCP); }
  117. private:
  118. tcp::socket& socket_;
  119. };
  120. class UDPSocket : public IOSocket {
  121. private:
  122. UDPSocket(const UDPSocket& source);
  123. UDPSocket& operator=(const UDPSocket& source);
  124. public:
  125. UDPSocket(udp::socket& socket) : socket_(socket) {}
  126. virtual int getNative() const { return (socket_.native()); }
  127. virtual int getProtocol() const { return (IPPROTO_UDP); }
  128. private:
  129. udp::socket& socket_;
  130. };
  131. class DummySocket : public IOSocket {
  132. private:
  133. DummySocket(const DummySocket& source);
  134. DummySocket& operator=(const DummySocket& source);
  135. public:
  136. DummySocket(const int protocol) : protocol_(protocol) {}
  137. virtual int getNative() const { return (-1); }
  138. virtual int getProtocol() const { return (protocol_); }
  139. private:
  140. const int protocol_;
  141. };
  142. IOSocket&
  143. IOSocket::getDummyUDPSocket() {
  144. static DummySocket socket(IPPROTO_UDP);
  145. return (socket);
  146. }
  147. IOSocket&
  148. IOSocket::getDummyTCPSocket() {
  149. static DummySocket socket(IPPROTO_TCP);
  150. return (socket);
  151. }
  152. IOMessage::IOMessage(const void* data, const size_t data_size,
  153. IOSocket& io_socket, const IOEndpoint& remote_endpoint) :
  154. data_(data), data_size_(data_size), io_socket_(io_socket),
  155. remote_endpoint_(remote_endpoint)
  156. {}
  157. //
  158. // Helper classes for asynchronous I/O using asio
  159. //
  160. class TCPClient {
  161. public:
  162. TCPClient(AuthSrv* auth_server, io_service& io_service) :
  163. auth_server_(auth_server),
  164. socket_(io_service),
  165. io_socket_(socket_),
  166. response_buffer_(0),
  167. responselen_buffer_(TCP_MESSAGE_LENGTHSIZE),
  168. response_renderer_(response_buffer_),
  169. dns_message_(Message::PARSE),
  170. custom_callback_(NULL)
  171. {}
  172. void start() {
  173. // Check for queued configuration commands
  174. if (auth_server_ != NULL &&
  175. auth_server_->configSession()->hasQueuedMsgs()) {
  176. auth_server_->configSession()->checkCommand();
  177. }
  178. async_read(socket_, asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE),
  179. boost::bind(&TCPClient::headerRead, this,
  180. placeholders::error,
  181. placeholders::bytes_transferred));
  182. }
  183. tcp::socket& getSocket() { return (socket_); }
  184. void headerRead(const asio::error_code& error,
  185. size_t bytes_transferred)
  186. {
  187. if (!error) {
  188. InputBuffer dnsbuffer(data_, bytes_transferred);
  189. uint16_t msglen = dnsbuffer.readUint16();
  190. async_read(socket_, asio::buffer(data_, msglen),
  191. boost::bind(&TCPClient::requestRead, this,
  192. placeholders::error,
  193. placeholders::bytes_transferred));
  194. } else {
  195. delete this;
  196. }
  197. }
  198. void requestRead(const asio::error_code& error,
  199. size_t bytes_transferred)
  200. {
  201. if (!error) {
  202. const TCPEndpoint remote_endpoint(socket_.remote_endpoint());
  203. const IOMessage io_message(data_, bytes_transferred, io_socket_,
  204. remote_endpoint);
  205. // currently, for testing purpose only
  206. if (custom_callback_ != NULL) {
  207. (*custom_callback_)(io_message);
  208. start();
  209. return;
  210. }
  211. if (auth_server_->processMessage(io_message, dns_message_,
  212. response_renderer_)) {
  213. responselen_buffer_.writeUint16(
  214. response_buffer_.getLength());
  215. async_write(socket_,
  216. asio::buffer(
  217. responselen_buffer_.getData(),
  218. responselen_buffer_.getLength()),
  219. boost::bind(&TCPClient::responseWrite, this,
  220. placeholders::error));
  221. } else {
  222. delete this;
  223. }
  224. } else {
  225. delete this;
  226. }
  227. }
  228. void responseWrite(const asio::error_code& error) {
  229. if (!error) {
  230. async_write(socket_,
  231. asio::buffer(response_buffer_.getData(),
  232. response_buffer_.getLength()),
  233. boost::bind(&TCPClient::handleWrite, this,
  234. placeholders::error));
  235. } else {
  236. delete this;
  237. }
  238. }
  239. void handleWrite(const asio::error_code& error) {
  240. if (!error) {
  241. start(); // handle next request, if any.
  242. } else {
  243. delete this;
  244. }
  245. }
  246. // Currently this is for tests only
  247. void setCallBack(const IOService::IOCallBack* callback) {
  248. custom_callback_ = callback;
  249. }
  250. private:
  251. AuthSrv* auth_server_;
  252. tcp::socket socket_;
  253. TCPSocket io_socket_;
  254. OutputBuffer response_buffer_;
  255. OutputBuffer responselen_buffer_;
  256. MessageRenderer response_renderer_;
  257. Message dns_message_;
  258. enum { MAX_LENGTH = 65535 };
  259. static const size_t TCP_MESSAGE_LENGTHSIZE = 2;
  260. char data_[MAX_LENGTH];
  261. // currently, for testing purpose only.
  262. const IOService::IOCallBack* custom_callback_;
  263. };
  264. class TCPServer {
  265. public:
  266. TCPServer(AuthSrv* auth_server, io_service& io_service,
  267. const ip::address& addr, const uint16_t port) :
  268. auth_server_(auth_server), io_service_(io_service),
  269. acceptor_(io_service_), listening_(new TCPClient(auth_server_,
  270. io_service_)),
  271. custom_callback_(NULL)
  272. {
  273. tcp::endpoint endpoint(addr, port);
  274. acceptor_.open(endpoint.protocol());
  275. // Set v6-only (we use a different instantiation for v4,
  276. // otherwise asio will bind to both v4 and v6
  277. if (addr.is_v6()) {
  278. acceptor_.set_option(ip::v6_only(true));
  279. }
  280. acceptor_.set_option(tcp::acceptor::reuse_address(true));
  281. acceptor_.bind(endpoint);
  282. acceptor_.listen();
  283. acceptor_.async_accept(listening_->getSocket(),
  284. boost::bind(&TCPServer::handleAccept, this,
  285. listening_, placeholders::error));
  286. }
  287. ~TCPServer() { delete listening_; }
  288. void handleAccept(TCPClient* new_client,
  289. const asio::error_code& error)
  290. {
  291. if (!error) {
  292. assert(new_client == listening_);
  293. new_client->setCallBack(custom_callback_);
  294. new_client->start();
  295. listening_ = new TCPClient(auth_server_, io_service_);
  296. acceptor_.async_accept(listening_->getSocket(),
  297. boost::bind(&TCPServer::handleAccept,
  298. this, listening_,
  299. placeholders::error));
  300. } else {
  301. delete new_client;
  302. }
  303. }
  304. // Currently this is for tests only
  305. void setCallBack(const IOService::IOCallBack* callback) {
  306. custom_callback_ = callback;
  307. }
  308. private:
  309. AuthSrv* auth_server_;
  310. io_service& io_service_;
  311. tcp::acceptor acceptor_;
  312. TCPClient* listening_;
  313. // currently, for testing purpose only.
  314. const IOService::IOCallBack* custom_callback_;
  315. };
  316. class UDPServer {
  317. public:
  318. UDPServer(AuthSrv* auth_server, io_service& io_service,
  319. const ip::address& addr, const uint16_t port) :
  320. auth_server_(auth_server),
  321. io_service_(io_service),
  322. socket_(io_service, addr.is_v6() ? udp::v6() : udp::v4()),
  323. io_socket_(socket_),
  324. response_buffer_(0),
  325. response_renderer_(response_buffer_),
  326. dns_message_(Message::PARSE),
  327. custom_callback_(NULL)
  328. {
  329. socket_.set_option(socket_base::reuse_address(true));
  330. // Set v6-only (we use a different instantiation for v4,
  331. // otherwise asio will bind to both v4 and v6
  332. if (addr.is_v6()) {
  333. socket_.set_option(asio::ip::v6_only(true));
  334. socket_.bind(udp::endpoint(addr, port));
  335. } else {
  336. socket_.bind(udp::endpoint(addr, port));
  337. }
  338. startReceive();
  339. }
  340. void handleRequest(const asio::error_code& error,
  341. size_t bytes_recvd)
  342. {
  343. // Check for queued configuration commands
  344. if (auth_server_ != NULL &&
  345. auth_server_->configSession()->hasQueuedMsgs()) {
  346. auth_server_->configSession()->checkCommand();
  347. }
  348. if (!error && bytes_recvd > 0) {
  349. const UDPEndpoint remote_endpoint(sender_endpoint_);
  350. const IOMessage io_message(data_, bytes_recvd, io_socket_,
  351. remote_endpoint);
  352. // currently, for testing purpose only
  353. if (custom_callback_ != NULL) {
  354. (*custom_callback_)(io_message);
  355. startReceive();
  356. return;
  357. }
  358. dns_message_.clear(Message::PARSE);
  359. response_renderer_.clear();
  360. if (auth_server_->processMessage(io_message, dns_message_,
  361. response_renderer_)) {
  362. socket_.async_send_to(
  363. asio::buffer(response_buffer_.getData(),
  364. response_buffer_.getLength()),
  365. sender_endpoint_,
  366. boost::bind(&UDPServer::sendCompleted,
  367. this,
  368. placeholders::error,
  369. placeholders::bytes_transferred));
  370. } else {
  371. startReceive();
  372. }
  373. } else {
  374. startReceive();
  375. }
  376. }
  377. void sendCompleted(const asio::error_code& error UNUSED_PARAM,
  378. size_t bytes_sent UNUSED_PARAM)
  379. {
  380. // Even if error occurred there's nothing to do. Simply handle
  381. // the next request.
  382. startReceive();
  383. }
  384. // Currently this is for tests only
  385. void setCallBack(const IOService::IOCallBack* callback) {
  386. custom_callback_ = callback;
  387. }
  388. private:
  389. void startReceive() {
  390. socket_.async_receive_from(
  391. asio::buffer(data_, MAX_LENGTH), sender_endpoint_,
  392. boost::bind(&UDPServer::handleRequest, this,
  393. placeholders::error,
  394. placeholders::bytes_transferred));
  395. }
  396. private:
  397. AuthSrv* auth_server_;
  398. io_service& io_service_;
  399. udp::socket socket_;
  400. UDPSocket io_socket_;
  401. OutputBuffer response_buffer_;
  402. MessageRenderer response_renderer_;
  403. Message dns_message_;
  404. udp::endpoint sender_endpoint_;
  405. enum { MAX_LENGTH = 4096 };
  406. char data_[MAX_LENGTH];
  407. // currently, for testing purpose only.
  408. const IOService::IOCallBack* custom_callback_;
  409. };
  410. class IOServiceImpl {
  411. public:
  412. IOServiceImpl(AuthSrv* auth_server, const char& port,
  413. const ip::address* v4addr, const ip::address* v6addr);
  414. asio::io_service io_service_;
  415. AuthSrv* auth_server_;
  416. typedef boost::shared_ptr<UDPServer> UDPServerPtr;
  417. typedef boost::shared_ptr<TCPServer> TCPServerPtr;
  418. UDPServerPtr udp4_server_;
  419. UDPServerPtr udp6_server_;
  420. TCPServerPtr tcp4_server_;
  421. TCPServerPtr tcp6_server_;
  422. // This member is used only for testing at the moment.
  423. IOService::IOCallBack callback_;
  424. };
  425. IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char& port,
  426. const ip::address* const v4addr,
  427. const ip::address* const v6addr) :
  428. auth_server_(auth_server),
  429. udp4_server_(UDPServerPtr()), udp6_server_(UDPServerPtr()),
  430. tcp4_server_(TCPServerPtr()), tcp6_server_(TCPServerPtr())
  431. {
  432. uint16_t portnum;
  433. try {
  434. // XXX: SunStudio with stlport4 doesn't reject some invalid
  435. // representation such as "-1" by lexical_cast<uint16_t>, so
  436. // we convert it into a signed integer of a larger size and perform
  437. // range check ourselves.
  438. const int32_t portnum32 = boost::lexical_cast<int32_t>(&port);
  439. if (portnum32 < 0 || portnum32 > 65535) {
  440. isc_throw(IOError, "Invalid port number '" << &port);
  441. }
  442. portnum = portnum32;
  443. } catch (const boost::bad_lexical_cast& ex) {
  444. isc_throw(IOError, "Invalid port number '" << &port << "': " <<
  445. ex.what());
  446. }
  447. try {
  448. if (v4addr != NULL) {
  449. udp4_server_ = UDPServerPtr(new UDPServer(auth_server, io_service_,
  450. *v4addr, portnum));
  451. tcp4_server_ = TCPServerPtr(new TCPServer(auth_server, io_service_,
  452. *v4addr, portnum));
  453. }
  454. if (v6addr != NULL) {
  455. udp6_server_ = UDPServerPtr(new UDPServer(auth_server, io_service_,
  456. *v6addr, portnum));
  457. tcp6_server_ = TCPServerPtr(new TCPServer(auth_server, io_service_,
  458. *v6addr, portnum));
  459. }
  460. } catch (const asio::system_error& err) {
  461. // We need to catch and convert any ASIO level exceptions.
  462. // This can happen for unavailable address, binding a privilege port
  463. // without the privilege, etc.
  464. isc_throw(IOError, "Failed to initialize network servers: " <<
  465. err.what());
  466. }
  467. }
  468. IOService::IOService(AuthSrv* auth_server, const char& port,
  469. const char& address) :
  470. impl_(NULL)
  471. {
  472. error_code err;
  473. const ip::address addr = ip::address::from_string(&address, err);
  474. if (err) {
  475. isc_throw(IOError, "Invalid IP address '" << &address << "': "
  476. << err.message());
  477. }
  478. impl_ = new IOServiceImpl(auth_server, port,
  479. addr.is_v4() ? &addr : NULL,
  480. addr.is_v6() ? &addr : NULL);
  481. }
  482. IOService::IOService(AuthSrv* auth_server, const char& port,
  483. const bool use_ipv4, const bool use_ipv6) :
  484. impl_(NULL)
  485. {
  486. const ip::address v4addr_any = ip::address(ip::address_v4::any());
  487. const ip::address* const v4addrp = use_ipv4 ? &v4addr_any : NULL;
  488. const ip::address v6addr_any = ip::address(ip::address_v6::any());
  489. const ip::address* const v6addrp = use_ipv6 ? &v6addr_any : NULL;
  490. impl_ = new IOServiceImpl(auth_server, port, v4addrp, v6addrp);
  491. }
  492. IOService::~IOService() {
  493. delete impl_;
  494. }
  495. void
  496. IOService::run() {
  497. impl_->io_service_.run();
  498. }
  499. void
  500. IOService::stop() {
  501. impl_->io_service_.stop();
  502. }
  503. asio::io_service&
  504. IOService::get_io_service() {
  505. return impl_->io_service_;
  506. }
  507. void
  508. IOService::setCallBack(const IOCallBack callback) {
  509. impl_->callback_ = callback;
  510. if (impl_->udp4_server_ != NULL) {
  511. impl_->udp4_server_->setCallBack(&impl_->callback_);
  512. }
  513. if (impl_->udp6_server_ != NULL) {
  514. impl_->udp6_server_->setCallBack(&impl_->callback_);
  515. }
  516. if (impl_->tcp4_server_ != NULL) {
  517. impl_->tcp4_server_->setCallBack(&impl_->callback_);
  518. }
  519. if (impl_->tcp6_server_ != NULL) {
  520. impl_->tcp6_server_->setCallBack(&impl_->callback_);
  521. }
  522. }
  523. }