Parcourir la source

[5078] Created UnixDomainSocket class in asiolink.

Marcin Siodelski il y a 8 ans
Parent
commit
0576c80854

+ 1 - 0
src/lib/asiolink/Makefile.am

@@ -32,6 +32,7 @@ libkea_asiolink_la_SOURCES += tcp_endpoint.h
 libkea_asiolink_la_SOURCES += tcp_socket.h
 libkea_asiolink_la_SOURCES += udp_endpoint.h
 libkea_asiolink_la_SOURCES += udp_socket.h
+libkea_asiolink_la_SOURCES += unix_domain_socket.cc unix_domain_socket.h
 
 # Note: the ordering matters: -Wno-... must follow -Wextra (defined in
 # KEA_CXXFLAGS)

+ 3 - 1
src/lib/asiolink/tests/Makefile.am

@@ -1,6 +1,7 @@
 AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib
 AM_CPPFLAGS += $(BOOST_INCLUDES)
 AM_CPPFLAGS += -DTEST_DATA_DIR=\"$(srcdir)/testdata\"
+AM_CPPFLAGS += -DTEST_DATA_BUILDDIR=\"$(abs_top_builddir)/src/lib/asiolink/tests\"
 
 AM_CXXFLAGS = $(KEA_CXXFLAGS)
 
@@ -8,7 +9,7 @@ if USE_STATIC_LINK
 AM_LDFLAGS = -static
 endif
 
-CLEANFILES = *.gcno *.gcda
+CLEANFILES = *.gcno *.gcda test-socket
 
 TESTS_ENVIRONMENT = \
 	$(LIBTOOL) --mode=execute $(VALGRIND_COMMAND)
@@ -28,6 +29,7 @@ run_unittests_SOURCES += udp_socket_unittest.cc
 run_unittests_SOURCES += io_service_unittest.cc
 run_unittests_SOURCES += dummy_io_callback_unittest.cc
 run_unittests_SOURCES += tcp_acceptor_unittest.cc
+run_unittests_SOURCES += unix_domain_socket_unittest.cc
 
 run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
 

+ 193 - 0
src/lib/asiolink/tests/unix_domain_socket_unittest.cc

@@ -0,0 +1,193 @@
+// Copyright (C) 2017 Internet Systems Consortium, Inc. ("ISC")
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include <config.h>
+#include <asiolink/asio_wrapper.h>
+#include <asiolink/interval_timer.h>
+#include <asiolink/io_service.h>
+#include <asiolink/unix_domain_socket.h>
+#include <boost/bind.hpp>
+#include <gtest/gtest.h>
+#include <array>
+#include <cstdio>
+#include <sstream>
+#include <string>
+
+using namespace boost::asio;
+using namespace boost::asio::local;
+using namespace isc::asiolink;
+
+namespace  {
+
+/// @brief Test unix socket file name.
+const std::string TEST_SOCKET = "test-socket";
+
+/// @brief Test timeout in ms.
+const long TEST_TIMEOUT = 10000;
+
+/// @brief Test fixture class for @ref UnixDomainSocket class.
+class UnixDomainSocketTest : public ::testing::Test {
+public:
+
+    /// @brief Constructor.
+    ///
+    /// Removes unix socket descriptor before the test.
+    UnixDomainSocketTest() : io_service_(),
+                             server_endpoint_(unixSocketFilePath()),
+                             server_acceptor_(io_service_.get_io_service()),
+                             server_socket_(io_service_.get_io_service()),
+                             test_timer_(io_service_) {
+        removeUnixSocketFile();
+        test_timer_.setup(boost::bind(&UnixDomainSocketTest::timeoutHandler, this),
+                                      TEST_TIMEOUT, IntervalTimer::ONE_SHOT);
+    }
+
+    /// @brief Destructor.
+    ///
+    /// Removes unix socket descriptor after the test.
+    virtual ~UnixDomainSocketTest() {
+        removeUnixSocketFile();
+    }
+
+    /// @brief Returns socket file path.
+    static std::string unixSocketFilePath() {
+        std::ostringstream s;
+        s << TEST_DATA_BUILDDIR << "/" << TEST_SOCKET;
+        return (s.str());
+    }
+
+    /// @brief Removes unix socket descriptor.
+    void removeUnixSocketFile() {
+        static_cast<void>(remove(unixSocketFilePath().c_str()));
+    }
+
+    /// @brief Creates and binds server socket.
+    void bindServerSocket() {
+        server_acceptor_.open();
+        server_acceptor_.bind(server_endpoint_);
+        server_acceptor_.listen();
+        server_acceptor_.async_accept(server_socket_,
+                                      boost::bind(&UnixDomainSocketTest::
+                                                  acceptHandler, this, _1));
+    }
+
+    /// @brief Server acceptor handler.
+    ///
+    /// @param ec Error code.
+    void acceptHandler(const boost::system::error_code& ec) {
+        if (ec) {
+            ADD_FAILURE() << ec.message();
+        }
+        server_socket_.async_read_some(boost::asio::buffer(&raw_buf_[0],
+                                                           raw_buf_.size()),
+                                       boost::bind(&UnixDomainSocketTest::
+                                                   readHandler, this, _1, _2));
+    }
+
+    /// @brief Server read handler.
+    ///
+    /// @param ec Error code.
+    /// @param bytes_transferred Number of bytes read.
+    void readHandler(const boost::system::error_code& ec,
+                     size_t bytes_transferred) {
+        std::string received(&raw_buf_[0], bytes_transferred);
+        std::string response("received " + received);
+        boost::asio::write(server_socket_, boost::asio::buffer(response.c_str(),
+                                                               response.size()));
+        io_service_.stop();
+    }
+
+    /// @brief Callback function invoke upon test timeout.
+    ///
+    /// It stops the IO service and reports test timeout.
+    void timeoutHandler() {
+        ADD_FAILURE() << "Timeout occurred while running the test!";
+        io_service_.stop();
+    }
+
+    /// @brief IO service used by the tests.
+    IOService io_service_;
+
+    /// @brief Server endpoint.
+    local::stream_protocol::endpoint server_endpoint_;
+
+    /// @brief Server acceptor.
+    local::stream_protocol::acceptor server_acceptor_;
+
+    /// @brief Server side unix domain socket.
+    stream_protocol::socket server_socket_;
+
+    /// @brief Receive buffer.
+    std::array<char, 1024> raw_buf_;
+
+    /// @brief Asynchronous timer service to detect timeouts.
+    IntervalTimer test_timer_;
+};
+
+// This test verifies that the client can send data over the unix
+// domain socket and receive a response.
+TEST_F(UnixDomainSocketTest, sendReceive) {
+    // Start the server.
+    bindServerSocket();
+
+    // Setup client side.
+    UnixDomainSocket socket(io_service_);
+    ASSERT_NO_THROW(socket.connect(TEST_SOCKET));
+
+    // Send "foo".
+    const std::string outbound_data = "foo";
+    size_t sent_size = 0;
+    ASSERT_NO_THROW(sent_size = socket.write(outbound_data.c_str(),
+                                             outbound_data.size()));
+    // Make sure all data have been sent.
+    ASSERT_EQ(outbound_data.size(), sent_size);
+
+    // Run IO service to generate server's response.
+    io_service_.run();
+
+    // Receive response from the socket.
+    std::array<char, 1024> read_buf;
+    size_t bytes_read = 0;
+    ASSERT_NO_THROW(bytes_read = socket.receive(&read_buf[0], read_buf.size()));
+    std::string response(&read_buf[0], bytes_read);
+
+    // The server should prepend "received" to the data we had sent.
+    EXPECT_EQ("received foo", response);
+}
+
+// This test verifies that UnixDomainSocketError exception is thrown
+// on attempt to connect, write or receive when the server socket
+// is not available.
+TEST_F(UnixDomainSocketTest, clientErrors) {
+    UnixDomainSocket socket(io_service_);
+    ASSERT_THROW(socket.connect(TEST_SOCKET), UnixDomainSocketError);
+    const std::string outbound_data = "foo";
+    ASSERT_THROW(socket.write(outbound_data.c_str(), outbound_data.size()),
+                 UnixDomainSocketError);
+    std::array<char, 1024> read_buf;
+    ASSERT_THROW(socket.receive(&read_buf[0], read_buf.size()),
+                 UnixDomainSocketError);
+}
+
+// Check that native socket descriptor is returned correctly when
+// the socket is connected.
+TEST_F(UnixDomainSocketTest, getNative) {
+    // Start the server.
+    bindServerSocket();
+
+    // Setup client side.
+    UnixDomainSocket socket(io_service_);
+    ASSERT_NO_THROW(socket.connect(TEST_SOCKET));
+    ASSERT_GE(socket.getNative(), 0);
+}
+
+// Check that protocol returned is 0.
+TEST_F(UnixDomainSocketTest, getProtocol) {
+    UnixDomainSocket socket(io_service_);
+    EXPECT_EQ(0, socket.getProtocol());
+}
+
+}

+ 96 - 0
src/lib/asiolink/unix_domain_socket.cc

@@ -0,0 +1,96 @@
+// Copyright (C) 2017 Internet Systems Consortium, Inc. ("ISC")
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include <asiolink/asio_wrapper.h>
+#include <asiolink/unix_domain_socket.h>
+#include <iostream>
+using namespace boost::asio::local;
+
+namespace isc {
+namespace asiolink {
+
+/// @brief Implementation of the unix domain socket.
+class UnixDomainSocketImpl {
+public:
+
+    /// @brief Constructor.
+    ///
+    /// @param io_service IO service to be used by the socket class.
+    UnixDomainSocketImpl(IOService& io_service)
+        : socket_(io_service.get_io_service()) {
+    }
+
+    /// @brief Destructor.
+    ///
+    /// Closes the socket.
+    ~UnixDomainSocketImpl() {
+        close();
+    }
+
+    /// @brief Closes the socket.
+    void close();
+
+    /// @brief Instance of the boost asio unix domain socket.
+    stream_protocol::socket socket_;
+};
+
+void
+UnixDomainSocketImpl::close() {
+    static_cast<void>(socket_.close());
+}
+
+UnixDomainSocket::UnixDomainSocket(IOService& io_service)
+    : impl_(new UnixDomainSocketImpl(io_service)) {
+}
+
+int
+UnixDomainSocket::getNative() const {
+    return (impl_->socket_.native());
+}
+
+int
+UnixDomainSocket::getProtocol() const {
+    return (0);
+}
+
+void
+UnixDomainSocket::connect(const std::string& path) {
+    boost::system::error_code ec;
+    impl_->socket_.connect(stream_protocol::endpoint(path.c_str()), ec);
+    if (ec) {
+        isc_throw(UnixDomainSocketError, ec.message());
+    }
+}
+
+size_t
+UnixDomainSocket::write(const void* data, size_t length) {
+    boost::system::error_code ec;
+    size_t res = boost::asio::write(impl_->socket_,
+                                    boost::asio::buffer(data, length),
+                                    ec);
+    if (ec) {
+        isc_throw(UnixDomainSocketError, ec.message());
+    }
+    return (res);
+}
+
+size_t
+UnixDomainSocket::receive(void* data, size_t length) {
+    boost::system::error_code ec;
+    size_t res = impl_->socket_.receive(boost::asio::buffer(data, length), 0, ec);
+    if (ec) {
+        isc_throw(UnixDomainSocketError, ec.message());
+    }
+    return (res);
+}
+
+void
+UnixDomainSocket::close() {
+    impl_->close();
+}
+
+}
+}

+ 89 - 0
src/lib/asiolink/unix_domain_socket.h

@@ -0,0 +1,89 @@
+// Copyright (C) 2017 Internet Systems Consortium, Inc. ("ISC")
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef UNIX_DOMAIN_SOCKET_H
+#define UNIX_DOMAIN_SOCKET_H
+
+#include <asiolink/io_service.h>
+#include <asiolink/io_socket.h>
+#include <boost/shared_ptr.hpp>
+#include <string>
+
+namespace isc {
+namespace asiolink {
+
+/// @brief Exception thrown upon socket error.
+class UnixDomainSocketError : public Exception {
+public:
+    UnixDomainSocketError(const char* file, size_t line, const char* what) :
+        isc::Exception(file, line, what) { };
+};
+
+class UnixDomainSocketImpl;
+
+/// @brief Represents unix domain socket implemented in terms
+/// of boost asio.
+class UnixDomainSocket : public IOSocket {
+public:
+
+    /// @brief Constructor.
+    ///
+    /// @param io_service Reference to IOService to be used by this
+    /// class.
+    UnixDomainSocket(IOService& io_service);
+
+    /// @brief Returns native socket representation.
+    virtual int getNative() const;
+
+    /// @brief Always returns 0.
+    virtual int getProtocol() const;
+
+    /// @brief Connects the socket to the specified endpoint.
+    ///
+    /// @param endpoint Endpoint to connect to.
+    /// @param [out] ec Error code returned as a result of an attempt to
+    /// connect.
+    ///
+    /// @throw UnixDomainSocketError if error occurs.
+    void connect(const std::string& path);
+
+    /// @brief Writes specified amount of data to a socket.
+    ///
+    /// @param data Pointer to data to be written.
+    /// @param length Number of bytes to be written.
+    /// @param [out] ec Error code returned as a result of an attempt to
+    /// write to socket.
+    ///
+    /// @return Number of bytes written.
+    /// @throw UnixDomainSocketError if error occurs.
+    size_t write(const void* data, size_t length);
+
+    /// @brief Receives data from a socket.
+    ///
+    /// @param [out] data Pointer to a location into which the read data should
+    /// be stored.
+    /// @param length Length of the buffer.
+    /// @param [out] ec Error code returned as a result of an attempt to
+    /// read from socket.
+    ///
+    /// @return Number of bytes read.
+    /// @throw UnixDomainSocketError if error occurs.
+    size_t receive(void* data, size_t length);
+
+    /// @brief Closes the socket.
+    void close();
+
+private:
+
+    /// @brief Pointer to the implementation of this class.
+    boost::shared_ptr<UnixDomainSocketImpl> impl_;
+
+};
+
+} // end of namespace isc::asiolink
+} // end of namespace isc
+
+#endif // UNIX_DOMAIN_SOCKET_H