// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.

#include <config.h>

#include <server_common/socket_request.h>

#include <gtest/gtest.h>

#include <config/tests/fake_session.h>
#include <config/ccsession.h>
#include <exceptions/exceptions.h>

#include <server_common/tests/data_path.h>

#include <cstdlib>
#include <cstddef>
#include <cerrno>
#include <sys/socket.h>
#include <sys/un.h>

#include <boost/foreach.hpp>
#include <boost/scoped_ptr.hpp>

#include <util/io/fd.h>
#include <util/io/fd_share.h>

using namespace isc::data;
using namespace isc::config;
using namespace isc::server_common;
using namespace isc;

namespace {

// Check it throws an exception when it is not initialized
TEST(SocketRequestorAccess, unitialized) {
    // Make sure it is not initialized
    initTestSocketRequestor(NULL);
    EXPECT_THROW(socketRequestor(), InvalidOperation);
}

// It returns whatever it is initialized to
TEST(SocketRequestorAccess, initialized) {
    // A concrete implementation that does nothing, just can exist
    class DummyRequestor : public SocketRequestor {
    public:
        DummyRequestor() : SocketRequestor() {}
        virtual void releaseSocket(const std::string&) {}
        virtual SocketID requestSocket(Protocol, const std::string&, uint16_t,
                                       ShareMode, const std::string&)
        {
            return (SocketID(0, "")); // Just to silence warnings
        }
    };
    DummyRequestor requestor;
    // Make sure it is initialized (the test way, of course)
    initTestSocketRequestor(&requestor);
    // It returs the same "pointer" as inserted
    // The casts are there as the template system seemed to get confused
    // without them, the types should be correct even without them, but
    // the EXPECT_EQ wanted to use long long int instead of pointers.
    EXPECT_EQ(static_cast<const SocketRequestor*>(&requestor),
              static_cast<const SocketRequestor*>(&socketRequestor()));
    // Just that we don't have an invalid pointer anyway
    initTestSocketRequestor(NULL);
}

// This class contains a fake (module)ccsession to emulate answers from Boss
class SocketRequestorTest : public ::testing::Test {
public:
    SocketRequestorTest() : session(ElementPtr(new ListElement),
                                    ElementPtr(new ListElement),
                                    ElementPtr(new ListElement)),
                            specfile(std::string(TEST_DATA_PATH) +
                                     "/spec.spec")
    {
        session.getMessages()->add(createAnswer());
        cc_session.reset(new ModuleCCSession(specfile, session, NULL, NULL,
                                             false, false));
        initSocketReqeustor(*cc_session);
    }

    ~SocketRequestorTest() {
        cleanupSocketRequestor();
    }

    // Do a standard request with some default values
    SocketRequestor::SocketID
    doRequest() {
        return (socketRequestor().requestSocket(SocketRequestor::UDP,
                                                "192.0.2.1", 12345,
                                                SocketRequestor::DONT_SHARE,
                                                "test"));
    }

    // Creates a valid socket request answer, as it would be sent by
    // Boss. 'valid' in terms of format, not values
    void
    addAnswer(const std::string& token, const std::string& path) {
        ElementPtr answer_part = Element::createMap();
        answer_part->set("token", Element::create(token));
        answer_part->set("path", Element::create(path));
        session.getMessages()->add(createAnswer(0, answer_part));
    }

    // Clears the messages the client sent so far on the fake msgq
    // (for easier access to new messages later)
    void
    clearMsgQueue() {
        while (session.getMsgQueue()->size() > 0) {
            session.getMsgQueue()->remove(0);
        }
    }

    isc::cc::FakeSession session;
    boost::scoped_ptr<ModuleCCSession> cc_session;
    const std::string specfile;
};

// helper function to create the request packet as we expect the
// socket requestor to send
ConstElementPtr
createExpectedRequest(const std::string& address,
                      int port,
                      const std::string& protocol,
                      const std::string& share_mode,
                      const std::string& share_name)
{
    // create command arguments
    const ElementPtr command_args = Element::createMap();
    command_args->set("address", Element::create(address));
    command_args->set("port", Element::create(port));
    command_args->set("protocol", Element::create(protocol));
    command_args->set("share_mode", Element::create(share_mode));
    command_args->set("share_name", Element::create(share_name));

    // create the envelope
    const ElementPtr packet = Element::createList();
    packet->add(Element::create("Boss"));
    packet->add(Element::create("*"));
    packet->add(createCommand("get_socket", command_args));

    return (packet);
}

TEST_F(SocketRequestorTest, testSocketRequestMessages) {
    // For each request, it will raise CCSessionError, since we don't
    // answer here.
    // We are only testing the request messages that are sent,
    // so for this test that is no problem
    clearMsgQueue();
    ConstElementPtr expected_request;

    expected_request = createExpectedRequest("192.0.2.1", 12345, "UDP",
                                             "NO", "test");
    ASSERT_THROW(socketRequestor().requestSocket(SocketRequestor::UDP,
                                   "192.0.2.1", 12345,
                                   SocketRequestor::DONT_SHARE,
                                   "test"),
                 CCSessionError);
    ASSERT_EQ(1, session.getMsgQueue()->size());
    ASSERT_EQ(*expected_request, *(session.getMsgQueue()->get(0)));

    clearMsgQueue();
    expected_request = createExpectedRequest("192.0.2.2", 1, "TCP",
                                             "ANY", "test2");
    ASSERT_THROW(socketRequestor().requestSocket(SocketRequestor::TCP,
                                   "192.0.2.2", 1,
                                   SocketRequestor::SHARE_ANY,
                                   "test2"),
                 CCSessionError);
    ASSERT_EQ(1, session.getMsgQueue()->size());
    ASSERT_EQ(*expected_request, *(session.getMsgQueue()->get(0)));

    clearMsgQueue();
    expected_request = createExpectedRequest("::1", 2, "UDP",
                                             "SAMEAPP", "test3");
    ASSERT_THROW(socketRequestor().requestSocket(SocketRequestor::UDP,
                                   "::1", 2,
                                   SocketRequestor::SHARE_SAME,
                                   "test3"),
                 CCSessionError);
    ASSERT_EQ(1, session.getMsgQueue()->size());
    ASSERT_EQ(*expected_request, *(session.getMsgQueue()->get(0)));
}

TEST_F(SocketRequestorTest, invalidParameterForSocketRequest) {
    // Bad protocol
    EXPECT_THROW(socketRequestor().
                 requestSocket(static_cast<SocketRequestor::Protocol>(2),
                               "192.0.2.1", 12345,
                               SocketRequestor::DONT_SHARE,
                               "test"),
                 InvalidParameter);

    // Bad share mode
    EXPECT_THROW(socketRequestor().
                 requestSocket(SocketRequestor::UDP,
                               "192.0.2.1", 12345,
                               static_cast<SocketRequestor::ShareMode>(3),
                               "test"),
                 InvalidParameter);
}

TEST_F(SocketRequestorTest, testBadRequestAnswers) {
    // Test various scenarios where the requestor gets back bad answers

    // Should raise CCSessionError if there is no answer
    ASSERT_THROW(doRequest(), CCSessionError);

    // Also if the answer does not match the format
    session.getMessages()->add(createAnswer());
    ASSERT_THROW(doRequest(), CCSessionError);

    // Now a 'real' answer, should fail on socket connect (no such file)
    addAnswer("foo", "/does/not/exist");
    ASSERT_THROW(doRequest(), SocketRequestor::SocketError);

    // Another failure (domain socket path too long)
    addAnswer("foo", std::string(1000, 'x'));
    ASSERT_THROW(doRequest(), SocketRequestor::SocketError);

    // Test values around path boundary
    struct sockaddr_un sock_un;
    const std::string max_len(sizeof(sock_un.sun_path) - 1, 'x');
    addAnswer("foo", max_len);
    // The failure should NOT contain 'too long'
    // (explicitly checking for existance of nonexistence of 'too long',
    // as opposed to the actual error, since 'too long' is a value we set).
    try {
        doRequest();
        FAIL() << "doRequest did not throw an exception";
    } catch (const SocketRequestor::SocketError& se) {
        ASSERT_EQ(std::string::npos, std::string(se.what()).find("too long"));
    }

    const std::string too_long(sizeof(sock_un.sun_path), 'x');
    addAnswer("foo", too_long);
    // The failure SHOULD contain 'too long'
    try {
        doRequest();
        FAIL() << "doRequest did not throw an exception";
    } catch (const SocketRequestor::SocketError& se) {
        ASSERT_NE(std::string::npos, std::string(se.what()).find("too long"));
    }

    // Send back an error response
    session.getMessages()->add(createAnswer(1, "error"));
    ASSERT_THROW(doRequest(), CCSessionError);
}

// Helper function to create the release commands as we expect
// them to be sent by the SocketRequestor class
ConstElementPtr
createExpectedRelease(const std::string& token) {
    // create command arguments
    const ElementPtr command_args = Element::createMap();
    command_args->set("token", Element::create(token));

    // create the envelope
    const ElementPtr packet = Element::createList();
    packet->add(Element::create("Boss"));
    packet->add(Element::create("*"));
    packet->add(createCommand("drop_socket", command_args));

    return (packet);
}

TEST_F(SocketRequestorTest, testSocketReleaseMessages) {
    ConstElementPtr expected_release;

    session.getMessages()->add(createAnswer());

    clearMsgQueue();
    expected_release = createExpectedRelease("foo");
    socketRequestor().releaseSocket("foo");
    ASSERT_EQ(1, session.getMsgQueue()->size());
    ASSERT_EQ(*expected_release, *(session.getMsgQueue()->get(0)));

    session.getMessages()->add(createAnswer());
    clearMsgQueue();
    expected_release = createExpectedRelease("bar");
    socketRequestor().releaseSocket("bar");
    ASSERT_EQ(1, session.getMsgQueue()->size());
    ASSERT_EQ(*expected_release, *(session.getMsgQueue()->get(0)));
}

TEST_F(SocketRequestorTest, testBadSocketReleaseAnswers) {
    // Should fail if there is no answer at all
    ASSERT_THROW(socketRequestor().releaseSocket("bar"),
                 CCSessionError);

    // Should also fail if the answer is an error
    session.getMessages()->add(createAnswer(1, "error"));
    ASSERT_THROW(socketRequestor().releaseSocket("bar"),
                 SocketRequestor::SocketError);
}

// A helper function to impose a read timeout for the server socket
// in order to avoid deadlock when the client side has a bug and doesn't
// send expected data.
// It returns true when the timeout is set successfully; otherwise false.
bool
setRecvTimo(int s) {
    const struct timeval timeo = { 10, 0 }; // 10sec, arbitrary choice
    if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &timeo, sizeof(timeo)) == 0) {
        return (true);
    }
    if (errno == ENOPROTOOPT) { // deviant OS, give up using it.
        return (false);
    }
    isc_throw(isc::Unexpected, "set RCVTIMEO failed: " << strerror(errno));
}

// Helper test class that creates a randomly named domain socket
// Upon init, it will only reserve the name (and place an empty file in its
// place).
// When run() is called, it creates the socket, forks, and the child will
// listen for a connection, then send all the data passed to run to that
// connection, and then close the socket
class TestSocket {
public:
    TestSocket() : fd_(-1) {
        path_ = strdup("test_socket.XXXXXX");
        // Misuse mkstemp to generate a file name.
        const int f = mkstemp(path_);
        if (f == -1) {
            isc_throw(Unexpected, "mkstemp failed: " << strerror(errno));
        }
        // Just need the name, so immediately close
        close(f);
    }

    ~TestSocket() {
        cleanup();
    }

    void
    cleanup() {
        unlink(path_);
        if (path_ != NULL) {
            free(path_);
            path_ = NULL;
        }
        if (fd_ != -1) {
            close(fd_);
            fd_ = -1;
        }
    }

    // Returns the path used for the socket
    const char* getPath() const {
        return (path_);
    }

    // create socket, fork, and serve if child (child will exit when done).
    // If the underlying system doesn't allow to set read timeout, tell the
    // caller that via a false return value so that the caller can avoid
    // performing tests that could result in a dead lock.
    bool run(const std::vector<std::pair<std::string, int> >& data) {
        create();
        const bool timo_ok = setRecvTimo(fd_);
        const int child_pid = fork();
        if (child_pid == 0) {
            serve(data);
            exit(0);
        } else {
            // parent does not need fd anymore
            close(fd_);
            fd_ = -1;
        }
        return (timo_ok);
    }
private:
    // Actually create the socket and listen on it
    void
    create() {
        fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
        if (fd_ == -1) {
            isc_throw(Unexpected, "Unable to create socket");
        }
        struct sockaddr_un socket_address;
        socket_address.sun_family = AF_UNIX;
        socklen_t len = strlen(path_);
        if (len > sizeof(socket_address.sun_path)) {
            isc_throw(Unexpected,
                      "mkstemp() created a filename too long for sun_path");
        }
        strncpy(socket_address.sun_path, path_, len);
#ifdef HAVE_SA_LEN
        socket_address.sun_len = len;
#endif

        len += offsetof(struct sockaddr_un, sun_path);
        // Remove the random file we created so we can reuse it for
        // a domain socket connection. This contains a minor race condition
        // but for the purposes of this test it should be small enough
        unlink(path_);
        if (bind(fd_, (const struct sockaddr*)&socket_address, len) == -1) {
            isc_throw(Unexpected,
                      "unable to bind to test domain socket " << path_ <<
                      ": " << strerror(errno));
        }

        if (listen(fd_, 1) == -1) {
            isc_throw(Unexpected,
                      "unable to listen on test domain socket " << path_ <<
                      ": " << strerror(errno));
        }
    }

    // Accept one connection, then for each value of the vector,
    // read the socket token from the connection and match the string
    // part of the vector element, and send the integer part of the element
    // using send_fd() (prepended by a status code 'ok').  For simplicity
    // we assume the tokens are 4 bytes long; if the test case uses a
    // different size of token the test will fail.
    //
    // There are a few specific exceptions;
    // when the value is -1, it will send back an error value (signaling
    // CREATOR_SOCKET_UNAVAILABLE)
    // when the value is -2, it will send a byte signaling CREATOR_SOCKET_OK
    // first, and then one byte from some string (i.e. bad data, not using
    // send_fd())
    //
    // NOTE: client_fd could leak on exception.  This should be cleaned up.
    // See the note about SocketSessionReceiver in socket_request.cc.
    void
    serve(const std::vector<std::pair<std::string, int> > data) {
        const int client_fd = accept(fd_, NULL, NULL);
        if (client_fd == -1) {
            isc_throw(Unexpected, "Error in accept(): " << strerror(errno));
        }
        if (!setRecvTimo(client_fd)) {
            // In the loop below we do blocking read.  To avoid deadlock
            // when the parent is buggy we'll skip it unless we can
            // set a read timeout on the socket.
            return;
        }
        typedef std::pair<std::string, int> DataPair;
        BOOST_FOREACH(DataPair cur_data, data) {
            char buf[5];
            memset(buf, 0, 5);
            if (isc::util::io::read_data(client_fd, buf, 4) != 4) {
                isc_throw(Unexpected, "unable to receive socket token");
            }
            if (cur_data.first != buf) {
                isc_throw(Unexpected, "socket token mismatch: expected="
                          << cur_data.first << ", actual=" << buf);
            }

            bool result;
            if (cur_data.second == -1) {
                // send 'CREATOR_SOCKET_UNAVAILABLE'
                result = isc::util::io::write_data(client_fd, "0\n", 2);
            } else if (cur_data.second == -2) {
                // send 'CREATOR_SOCKET_OK' first
                result = isc::util::io::write_data(client_fd, "1\n", 2);
                if (result) {
                    if (send(client_fd, "a", 1, 0) != 1) {
                        result = false;
                    }
                }
            } else {
                // send 'CREATOR_SOCKET_OK' first
                result = isc::util::io::write_data(client_fd, "1\n", 2);
                if (result) {
                    if (isc::util::io::send_fd(client_fd,
                                               cur_data.second) != 0) {
                        result = false;
                    }
                }
            }
            if (!result) {
                isc_throw(Exception, "Error in send_fd(): " <<
                          strerror(errno));
            }
        }
        close(client_fd);
    }

    int fd_;
    char* path_;
};

TEST_F(SocketRequestorTest, testSocketPassing) {
    TestSocket ts;
    std::vector<std::pair<std::string, int> > data;
    data.push_back(std::pair<std::string, int>("foo\n", 1));
    data.push_back(std::pair<std::string, int>("bar\n", 2));
    data.push_back(std::pair<std::string, int>("foo\n", 3));
    data.push_back(std::pair<std::string, int>("foo\n", 1));
    data.push_back(std::pair<std::string, int>("foo\n", -1));
    data.push_back(std::pair<std::string, int>("foo\n", -2));

    // run() returns true iff we can specify read timeout so we avoid a
    // deadlock.  Unless there's a bug the test should succeed even without the
    // timeout, but we don't want to make the test hang up in case with an
    // unexpected bug, so we'd rather skip most of the tests in that case.
    const bool timo_ok = ts.run(data);
    SocketRequestor::SocketID socket_id;
    if (timo_ok) {
        // 1 should be ok
        addAnswer("foo", ts.getPath());
        socket_id = doRequest();
        ASSERT_EQ("foo", socket_id.second);
        ASSERT_EQ(0, close(socket_id.first));

        // 2 should be ok too
        addAnswer("bar", ts.getPath());
        socket_id = doRequest();
        ASSERT_EQ("bar", socket_id.second);
        ASSERT_EQ(0, close(socket_id.first));

        // 3 should be ok too (reuse earlier token)
        addAnswer("foo", ts.getPath());
        socket_id = doRequest();
        ASSERT_EQ("foo", socket_id.second);
        ASSERT_EQ(0, close(socket_id.first));
    }

    // Create a second socket server, to test that multiple different
    // domains sockets would work as well (even though we don't actually
    // use that feature)
    TestSocket ts2;
    std::vector<std::pair<std::string, int> > data2;
    data2.push_back(std::pair<std::string, int>("foo\n", 1));
    const bool timo_ok2 = ts2.run(data2);

    if (timo_ok2) {
        // 1 should be ok
        addAnswer("foo", ts2.getPath());
        socket_id = doRequest();
        ASSERT_EQ("foo", socket_id.second);
        ASSERT_EQ(0, close(socket_id.first));
    }

    if (timo_ok) {
        // Now use first socket again
        addAnswer("foo", ts.getPath());
        socket_id = doRequest();
        ASSERT_EQ("foo", socket_id.second);
        ASSERT_EQ(0, close(socket_id.first));

        // -1 is a "normal" error
        addAnswer("foo", ts.getPath());
        ASSERT_THROW(doRequest(), SocketRequestor::SocketError);

        // -2 is an unexpected error.  After this point it's not guaranteed the
        // connection works as intended.
        addAnswer("foo", ts.getPath());
        ASSERT_THROW(doRequest(), SocketRequestor::SocketError);
    }

    // Vector is of first socket is now empty, so the socket should be gone
    addAnswer("foo", ts.getPath());
    ASSERT_THROW(doRequest(), SocketRequestor::SocketError);

    // Vector is of second socket is now empty too, so the socket should be
    // gone
    addAnswer("foo", ts2.getPath());
    ASSERT_THROW(doRequest(), SocketRequestor::SocketError);
}

}