sockcreator.cc 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // Copyright (C) 2011 Internet Systems Consortium, Inc. ("ISC")
  2. //
  3. // Permission to use, copy, modify, and/or distribute this software for any
  4. // purpose with or without fee is hereby granted, provided that the above
  5. // copyright notice and this permission notice appear in all copies.
  6. //
  7. // THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
  8. // REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
  9. // AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
  10. // INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
  11. // LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
  12. // OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
  13. // PERFORMANCE OF THIS SOFTWARE.
  14. #include "sockcreator.h"
  15. #include <util/io/fd.h>
  16. #include <cerrno>
  17. #include <cstring>
  18. #include <unistd.h>
  19. #include <sys/types.h>
  20. #include <sys/socket.h>
  21. #include <netinet/in.h>
  22. using namespace isc::util::io;
  23. using namespace isc::socket_creator;
  24. namespace {
  25. // Simple wrappers for read_data/write_data that throw an exception on error.
  26. void
  27. readMessage(const int fd, void* where, const size_t length) {
  28. if (read_data(fd, where, length) < length) {
  29. isc_throw(ReadError, "Error reading from socket creator client");
  30. }
  31. }
  32. void
  33. writeMessage(const int fd, const void* what, const size_t length) {
  34. if (!write_data(fd, what, length)) {
  35. isc_throw(WriteError, "Error writing to socket creator client");
  36. }
  37. }
  38. // Exit on a protocol error after informing the client of the problem.
  39. void
  40. protocolError(const int fd, const char reason = 'I') {
  41. // Tell client we have a problem
  42. char message[2];
  43. message[0] = 'F';
  44. message[1] = reason;
  45. writeMessage(fd, message, sizeof(message));
  46. // ... and exit
  47. isc_throw(ProtocolError, "Fatal error, reason: " << reason);
  48. }
  49. // Return appropriate socket type constant for the socket type requested.
  50. // The output_fd argument is required to report a protocol error.
  51. int getSocketType(const char type_code, const int output_fd) {
  52. int socket_type = 0;
  53. switch (type_code) {
  54. case 'T':
  55. socket_type = SOCK_STREAM;
  56. break;
  57. case 'U':
  58. socket_type = SOCK_DGRAM;
  59. break;
  60. default:
  61. protocolError(output_fd); // Does not return
  62. }
  63. return (socket_type);
  64. }
  65. // Convert return status from getSock() to a character to be sent back to
  66. // the caller.
  67. char getErrorCode(const int status) {
  68. char error_code = ' ';
  69. switch (status) {
  70. case -1:
  71. error_code = 'S';
  72. break;
  73. case -2:
  74. error_code = 'B';
  75. break;
  76. default:
  77. isc_throw(InternalError, "Error creating socket");
  78. }
  79. return (error_code);
  80. }
  81. // Handle the request from the client.
  82. //
  83. // Reads the type and family of socket required, creates the socket and returns
  84. // it to the client.
  85. //
  86. // The arguments passed (and the exceptions thrown) are the same as those for
  87. // run().
  88. void
  89. handleRequest(const int input_fd, const int output_fd,
  90. const get_sock_t get_sock, const send_fd_t send_fd_fun,
  91. const close_t close_fun)
  92. {
  93. // Read the message from the client
  94. char type[2];
  95. readMessage(input_fd, type, sizeof(type));
  96. // Decide what type of socket is being asked for
  97. const int sock_type = getSocketType(type[0], output_fd);
  98. // Read the address they ask for depending on what address family was
  99. // specified.
  100. sockaddr* addr = NULL;
  101. size_t addr_len = 0;
  102. sockaddr_in addr_in;
  103. sockaddr_in6 addr_in6;
  104. switch (type[1]) { // The address family
  105. // The casting to apparently incompatible types by reinterpret_cast
  106. // is required by the C low-level interface.
  107. case '4':
  108. addr = reinterpret_cast<sockaddr*>(&addr_in);
  109. addr_len = sizeof(addr_in);
  110. memset(&addr_in, 0, sizeof(addr_in));
  111. addr_in.sin_family = AF_INET;
  112. readMessage(input_fd, &addr_in.sin_port, sizeof(addr_in.sin_port));
  113. readMessage(input_fd, &addr_in.sin_addr.s_addr,
  114. sizeof(addr_in.sin_addr.s_addr));
  115. break;
  116. case '6':
  117. addr = reinterpret_cast<sockaddr*>(&addr_in6);
  118. addr_len = sizeof addr_in6;
  119. memset(&addr_in6, 0, sizeof(addr_in6));
  120. addr_in6.sin6_family = AF_INET6;
  121. readMessage(input_fd, &addr_in6.sin6_port,
  122. sizeof(addr_in6.sin6_port));
  123. readMessage(input_fd, &addr_in6.sin6_addr.s6_addr,
  124. sizeof(addr_in6.sin6_addr.s6_addr));
  125. break;
  126. default:
  127. protocolError(output_fd);
  128. }
  129. // Obtain the socket
  130. const int result = get_sock(sock_type, addr, addr_len);
  131. if (result >= 0) {
  132. // Got the socket, send it to the client.
  133. writeMessage(output_fd, "S", 1);
  134. if (send_fd_fun(output_fd, result) != 0) {
  135. // Error. Close the socket (ignore any error from that operation)
  136. // and abort.
  137. close_fun(result);
  138. isc_throw(InternalError, "Error sending descriptor");
  139. }
  140. // Successfully sent the socket, so free up resources we still hold
  141. // for it.
  142. if (close_fun(result) == -1) {
  143. isc_throw(InternalError, "Error closing socket");
  144. }
  145. } else {
  146. // Error. Tell the client.
  147. char error_message[2];
  148. error_message[0] = 'E';
  149. error_message[1] = getErrorCode(result);
  150. writeMessage(output_fd, error_message, sizeof(error_message));
  151. // ...and append the reason code to the error message
  152. const int error_number = errno;
  153. writeMessage(output_fd, &error_number, sizeof(error_number));
  154. }
  155. }
  156. } // Anonymous namespace
  157. namespace isc {
  158. namespace socket_creator {
  // Get the socket and bind to it.
  //
  // Creates a socket of the given 'type' in the address family found in
  // 'bind_addr', sets SO_REUSEADDR (and IPV6_V6ONLY for AF_INET6 addresses)
  // and binds it to 'bind_addr', whose size is 'addr_len'.
  //
  // Returns the bound socket descriptor on success, -1 if the socket could
  // not be created and -2 if setting the options or binding failed. In the
  // -2 case the socket is closed again before returning, so no descriptor is
  // leaked, and errno still holds the value set by the call that failed.
  int
  getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len) {
      const int sock = socket(bind_addr->sa_family, type, 0);
      if (sock == -1) {
          return (-1);
      }

      // Setting the options is part of the binding process, so a failure
      // there is reported as a bind error as well.
      const int on = 1;
      const bool bound =
          setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != -1 &&
          (bind_addr->sa_family != AF_INET6 ||
           setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &on,
                      sizeof(on)) != -1) &&
          bind(sock, bind_addr, addr_len) != -1;

      if (!bound) {
          // Release the half-configured socket, but keep the errno of the
          // failed call intact as close() may change it.
          const int saved_errno = errno;
          close(sock);
          errno = saved_errno;
          return (-2);
      }
      return (sock);
  }
  181. // Main run loop.
  182. void
  183. run(const int input_fd, const int output_fd, const get_sock_t get_sock,
  184. const send_fd_t send_fd_fun, const close_t close_fun)
  185. {
  186. for (;;) {
  187. char command;
  188. readMessage(input_fd, &command, sizeof(command));
  189. switch (command) {
  190. case 'S': // The "get socket" command
  191. handleRequest(input_fd, output_fd, get_sock,
  192. send_fd_fun, close_fun);
  193. break;
  194. case 'T': // The "terminate" command
  195. return;
  196. default: // Don't recognise anything else
  197. protocolError(output_fd);
  198. }
  199. }
  200. }
  201. } // End of the namespaces
  202. }