Browse Source

[5099] Handle HTTP timeouts.

Marcin Siodelski 8 years ago
parent
commit
4a9d027f7d

+ 52 - 57
src/lib/http/connection.cc

@@ -18,6 +18,9 @@ namespace http {
 void
 HttpConnection::
 SocketCallback::operator()(boost::system::error_code ec, size_t length) {
+    if (ec.value() == boost::asio::error::operation_aborted) {
+        return;
+    }
     callback_(ec, length);
 }
 
@@ -25,8 +28,11 @@ HttpConnection:: HttpConnection(asiolink::IOService& io_service,
                                 HttpAcceptor& acceptor,
                                 HttpConnectionPool& connection_pool,
                                 const HttpResponseCreatorPtr& response_creator,
-                                const HttpAcceptorCallback& callback)
-    : socket_(io_service),
+                                const HttpAcceptorCallback& callback,
+                                const long request_timeout)
+    : request_timer_(io_service),
+      request_timeout_(request_timeout),
+      socket_(io_service),
       socket_callback_(boost::bind(&HttpConnection::socketReadCallback, this,
                                    _1, _2)),
       socket_write_callback_(boost::bind(&HttpConnection::socketWriteCallback,
@@ -46,6 +52,19 @@ HttpConnection::~HttpConnection() {
 }
 
 void
+HttpConnection::close() {
+    socket_.close();
+}
+
+void
+HttpConnection::stopThisConnection() {
+    try {
+        connection_pool_.stop(shared_from_this());
+    } catch (...) {
+    }
+}
+
+void
 HttpConnection::asyncAccept() {
     HttpAcceptorCallback cb = boost::bind(&HttpConnection::acceptorCallback,
                                           this, _1);
@@ -59,11 +78,6 @@ HttpConnection::asyncAccept() {
 }
 
 void
-HttpConnection::close() {
-    socket_.close();
-}
-
-void
 HttpConnection::doRead() {
     try {
         TCPEndpoint endpoint;
@@ -71,52 +85,47 @@ HttpConnection::doRead() {
                              0, &endpoint, socket_callback_);
 
     } catch (const std::exception& ex) {
-        isc_throw(HttpConnectionError, "unable to start asynchronous HTTP message"
-                  " receive over TCP socket: " << ex.what());
+        stopThisConnection();
     }
 }
 
 void
 HttpConnection::doWrite() {
-    if (!output_buf_.empty()) {
-        try {
+    try {
+        if (!output_buf_.empty()) {
             socket_.asyncSend(output_buf_.data(),
                               output_buf_.length(),
                               socket_write_callback_);
-
-        } catch (const std::exception& ex) {
-            isc_throw(HttpConnectionError, "unable to start asynchronous HTTP"
-                      " message write over TCP socket: " << ex.what());
         }
+    } catch (const std::exception& ex) {
+        stopThisConnection();
     }
 }
 
 void
+HttpConnection::asyncSendResponse(const ConstHttpResponsePtr& response) {
+    output_buf_ = response->toString();
+    doWrite();
+}
+
+
+void
 HttpConnection::acceptorCallback(const boost::system::error_code& ec) {
     if (!acceptor_.isOpen()) {
         return;
     }
 
-    try {
-        if (ec) {
-            connection_pool_.stop(shared_from_this());
-        }
-    } catch (...) {
+    if (ec) {
+        stopThisConnection();
     }
 
     acceptor_callback_(ec);
 
-    try {
-        if (!ec) {
-            doRead();
-        }
-    } catch (const std::exception& ex) {
-        try {
-            connection_pool_.stop(shared_from_this());
-        } catch (...) {
-        }
+    if (!ec) {
+        request_timer_.setup(boost::bind(&HttpConnection::requestTimeoutCallback, this),
+                             request_timeout_, IntervalTimer::ONE_SHOT);
+        doRead();
     }
-
 }
 
 void
@@ -125,31 +134,16 @@ HttpConnection::socketReadCallback(boost::system::error_code ec, size_t length)
     parser_->postBuffer(static_cast<void*>(buf_.data()), length);
     parser_->poll();
     if (parser_->needData()) {
-        try {
-            doRead();
-        } catch (const std::exception& ex) {
-            try {
-                connection_pool_.stop(shared_from_this());
-            } catch (...) {
-            }
-        }
+        doRead();
 
     } else {
         try {
             request_->finalize();
         } catch (...) {
         }
-        HttpResponsePtr response = response_creator_->createHttpResponse(request_);
-        output_buf_ = response->toString();
-        try {
-            doWrite();
 
-        } catch (const std::exception& ex) {
-            try {
-                connection_pool_.stop(shared_from_this());
-            } catch (...) {
-            }
-        }
+        HttpResponsePtr response = response_creator_->createHttpResponse(request_);
+        asyncSendResponse(response);
     }
 }
 
@@ -158,21 +152,22 @@ HttpConnection::socketWriteCallback(boost::system::error_code ec,
                                     size_t length) {
     if (length <= output_buf_.size()) {
         output_buf_.erase(0, length);
-        try {
-            doWrite();
-
-        } catch (const std::exception& ex) {
-            try {
-                connection_pool_.stop(shared_from_this());
-            } catch (...) {
-            }
-        }
+        doWrite();
 
     } else {
         output_buf_.clear();
     }
 }
 
+void
+HttpConnection::requestTimeoutCallback() {
+    HttpResponsePtr response =
+        response_creator_->createStockHttpResponse(request_,
+                                                   HttpStatusCode::REQUEST_TIMEOUT);
+    asyncSendResponse(response);
+}
+
+
 } // end of namespace isc::http
 } // end of namespace isc
 

+ 13 - 1
src/lib/http/connection.h

@@ -7,6 +7,7 @@
 #ifndef HTTP_CONNECTION_H
 #define HTTP_CONNECTION_H
 
+#include <asiolink/interval_timer.h>
 #include <asiolink/io_service.h>
 #include <http/http_acceptor.h>
 #include <http/request_parser.h>
@@ -58,7 +59,8 @@ public:
                    HttpAcceptor& acceptor,
                    HttpConnectionPool& connection_pool,
                    const HttpResponseCreatorPtr& response_creator,
-                   const HttpAcceptorCallback& callback);
+                   const HttpAcceptorCallback& callback,
+                   const long request_timeout);
 
     ~HttpConnection();
 
@@ -78,8 +80,18 @@ public:
     void socketWriteCallback(boost::system::error_code ec,
                              size_t length);
 
+    void requestTimeoutCallback();
+
 private:
 
+    void asyncSendResponse(const ConstHttpResponsePtr& response);
+
+    void stopThisConnection();
+
+    asiolink::IntervalTimer request_timer_;
+
+    long request_timeout_;
+
     asiolink::TCPSocket<SocketCallback> socket_;
 
     SocketCallback socket_callback_;

+ 1 - 1
src/lib/http/connection_pool.cc

@@ -19,7 +19,7 @@ HttpConnectionPool::start(const HttpConnectionPtr& connection) {
 void
 HttpConnectionPool::stop(const HttpConnectionPtr& connection) {
     connections_.erase(connection);
-    connection->close();
+//    connection->close();
 }
 
 void

+ 11 - 3
src/lib/http/listener.cc

@@ -15,9 +15,11 @@ namespace http {
 HttpListener::HttpListener(IOService& io_service,
                            const asiolink::IOAddress& server_address,
                            const unsigned short server_port,
-                           const HttpResponseCreatorFactoryPtr& creator_factory)
+                           const HttpResponseCreatorFactoryPtr& creator_factory,
+                           const long request_timeout)
     : io_service_(io_service), acceptor_(io_service),
-      endpoint_(), creator_factory_(creator_factory) {
+      endpoint_(), creator_factory_(creator_factory),
+      request_timeout_(request_timeout) {
     try {
         endpoint_.reset(new TCPEndpoint(server_address, server_port));
 
@@ -29,6 +31,11 @@ HttpListener::HttpListener(IOService& io_service,
         isc_throw(HttpListenerError, "HttpResponseCreatorFactory must not"
                   " be null");
     }
+
+    if (request_timeout_ <= 0) {
+        isc_throw(HttpListenerError, "Invalid desired HTTP request timeout "
+                  << request_timeout_);
+    }
 }
 
 HttpListener::~HttpListener() {
@@ -65,7 +72,8 @@ HttpListener::accept() {
     HttpConnectionPtr conn(new HttpConnection(io_service_, acceptor_,
                                               connections_,
                                               response_creator,
-                                              acceptor_callback));
+                                              acceptor_callback,
+                                              request_timeout_));
     connections_.start(conn);
 }
 

+ 3 - 2
src/lib/http/listener.h

@@ -32,7 +32,8 @@ public:
     HttpListener(asiolink::IOService& io_service,
                  const asiolink::IOAddress& server_address,
                  const unsigned short server_port,
-                 const HttpResponseCreatorFactoryPtr& creator_factory);
+                 const HttpResponseCreatorFactoryPtr& creator_factory,
+                 const long request_timeout);
 
     ~HttpListener();
 
@@ -51,7 +52,7 @@ private:
     boost::scoped_ptr<asiolink::TCPEndpoint> endpoint_;
     HttpConnectionPool connections_;
     HttpResponseCreatorFactoryPtr creator_factory_;
-
+    long request_timeout_;
 };
 
 } // end of namespace isc::http

+ 1 - 0
src/lib/http/response.cc

@@ -29,6 +29,7 @@ const std::map<HttpStatusCode, std::string> status_code_to_description = {
     { HttpStatusCode::UNAUTHORIZED, "Unauthorized" },
     { HttpStatusCode::FORBIDDEN, "Forbidden" },
     { HttpStatusCode::NOT_FOUND, "Not Found" },
+    { HttpStatusCode::REQUEST_TIMEOUT, "Request Timeout" },
     { HttpStatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error" },
     { HttpStatusCode::NOT_IMPLEMENTED, "Not Implemented" },
     { HttpStatusCode::BAD_GATEWAY, "Bad Gateway" },

+ 1 - 0
src/lib/http/response.h

@@ -39,6 +39,7 @@ enum class HttpStatusCode : std::uint16_t {
     UNAUTHORIZED = 401,
     FORBIDDEN = 403,
     NOT_FOUND = 404,
+    REQUEST_TIMEOUT = 408,
     INTERNAL_SERVER_ERROR = 500,
     NOT_IMPLEMENTED = 501,
     BAD_GATEWAY = 502,

+ 1 - 1
src/lib/http/response_creator.cc

@@ -19,7 +19,7 @@ HttpResponseCreator::createHttpResponse(const ConstHttpRequestPtr& request) {
 
     // If not finalized, the request parsing failed. Generate HTTP 400.
     if (!request->isFinalized()) {
-        return (createStockBadRequest(request));
+        return (createStockHttpResponse(request, HttpStatusCode::BAD_REQUEST));
     }
 
     // Message has been successfully parsed. Create implementation specific

+ 7 - 5
src/lib/http/response_creator.h

@@ -87,14 +87,16 @@ public:
     virtual HttpRequestPtr
     createNewHttpRequest() const = 0;
 
-protected:
-
-    /// @brief Creates implementation specific HTTP 400 response.
+    /// @brief Creates implementation specific HTTP response.
     ///
     /// @param request Pointer to an object representing HTTP request.
-    /// @return Pointer to an object representing HTTP 400 response.
+    /// @param status_code Status code of the response.
+    /// @return Pointer to an object representing HTTP response.
     virtual HttpResponsePtr
-    createStockBadRequest(const ConstHttpRequestPtr& request) const = 0;
+    createStockHttpResponse(const ConstHttpRequestPtr& request,
+                            const HttpStatusCode& status_code) const = 0;
+
+protected:
 
     /// @brief Creates implementation specific HTTP response.
     ///

+ 39 - 9
src/lib/http/tests/listener_unittests.cc

@@ -28,6 +28,7 @@ namespace {
 
 const std::string SERVER_ADDRESS = "127.0.0.1";
 const unsigned short SERVER_PORT = 18123;
+const long REQUEST_TIMEOUT = 20000;
 
 /// @brief Test timeout in ms.
 const long TEST_TIMEOUT = 10000;
@@ -52,12 +53,13 @@ public:
 
 private:
 
-    /// @brief Creates HTTP 400 response.
+    /// @brief Creates HTTP response.
     ///
     /// @param request Pointer to the HTTP request.
-    /// @return Pointer to the generated HTTP 400 response.
+    /// @return Pointer to the generated HTTP response.
     virtual HttpResponsePtr
-    createStockBadRequest(const ConstHttpRequestPtr& request) const {
+    createStockHttpResponse(const ConstHttpRequestPtr& request,
+                            const HttpStatusCode& status_code) const {
         // The request hasn't been finalized so the request object
         // doesn't contain any information about the HTTP version number
         // used. But, the context should have this data (assuming the
@@ -65,8 +67,7 @@ private:
         HttpVersion http_version(request->context()->http_version_major_,
                                  request->context()->http_version_minor_);
         // This will generate the response holding JSON content.
-        ResponsePtr response(new Response(http_version,
-                                          HttpStatusCode::BAD_REQUEST));
+        ResponsePtr response(new Response(http_version, status_code));
         return (response);
     }
 
@@ -261,7 +262,7 @@ TEST_F(HttpListenerTest, listen) {
         "{ }";
 
     HttpListener listener(io_service_, IOAddress(SERVER_ADDRESS), SERVER_PORT,
-                          factory_);
+                          factory_, REQUEST_TIMEOUT);
     ASSERT_NO_THROW(listener.start());
     ASSERT_NO_THROW(startRequest(request));
     ASSERT_NO_THROW(io_service_.run());
@@ -284,7 +285,7 @@ TEST_F(HttpListenerTest, badRequest) {
         "{ }";
 
     HttpListener listener(io_service_, IOAddress(SERVER_ADDRESS), SERVER_PORT,
-                          factory_);
+                          factory_, REQUEST_TIMEOUT);
     ASSERT_NO_THROW(listener.start());
     ASSERT_NO_THROW(startRequest(request));
     ASSERT_NO_THROW(io_service_.run());
@@ -302,7 +303,14 @@ TEST_F(HttpListenerTest, badRequest) {
 
 TEST_F(HttpListenerTest, invalidFactory) {
     EXPECT_THROW(HttpListener(io_service_, IOAddress(SERVER_ADDRESS),
-                              SERVER_PORT, HttpResponseCreatorFactoryPtr()),
+                              SERVER_PORT, HttpResponseCreatorFactoryPtr(),
+                              REQUEST_TIMEOUT),
+                 HttpListenerError);
+}
+
+TEST_F(HttpListenerTest, invalidRequestTimeout) {
+    EXPECT_THROW(HttpListener(io_service_, IOAddress(SERVER_ADDRESS),
+                              SERVER_PORT, factory_, 0),
                  HttpListenerError);
 }
 
@@ -315,8 +323,30 @@ TEST_F(HttpListenerTest, addressInUse) {
     acceptor.bind(endpoint);
 
     HttpListener listener(io_service_, IOAddress(SERVER_ADDRESS),
-                          SERVER_PORT + 1, factory_);
+                          SERVER_PORT + 1, factory_, REQUEST_TIMEOUT);
     EXPECT_THROW(listener.start(), HttpListenerError);
 }
 
+TEST_F(HttpListenerTest, requestTimeout) {
+    const std::string request = "POST /foo/bar HTTP/1.1\r\n"
+        "Content-Type: foo\r\n"
+        "Content-Length:";
+
+    HttpListener listener(io_service_, IOAddress(SERVER_ADDRESS), SERVER_PORT,
+                          factory_, 1000);
+    ASSERT_NO_THROW(listener.start());
+    ASSERT_NO_THROW(startRequest(request));
+    ASSERT_NO_THROW(io_service_.run());
+    ASSERT_EQ(1, clients_.size());
+    HttpClientPtr client = *clients_.begin();
+    ASSERT_TRUE(client);
+    EXPECT_EQ("HTTP/1.1 408 Request Timeout\r\n"
+              "Content-Length: 44\r\n"
+              "Content-Type: application/json\r\n"
+              "Date: Tue, 19 Dec 2016 18:53:35 GMT\r\n"
+              "\r\n"
+              "{ \"result\": 408, \"text\": \"Request Timeout\" }",
+              client->getResponse());
+}
+
 }

+ 5 - 5
src/lib/http/tests/response_creator_unittests.cc

@@ -39,12 +39,13 @@ public:
 
 private:
 
-    /// @brief Creates HTTP 400 response.
+    /// @brief Creates HTTP response..
     ///
     /// @param request Pointer to the HTTP request.
-    /// @return Pointer to the generated HTTP 400 response.
+    /// @return Pointer to the generated HTTP response.
     virtual HttpResponsePtr
-    createStockBadRequest(const ConstHttpRequestPtr& request) const {
+    createStockHttpResponse(const ConstHttpRequestPtr& request,
+                            const HttpStatusCode& status_code) const {
         // The request hasn't been finalized so the request object
         // doesn't contain any information about the HTTP version number
         // used. But, the context should have this data (assuming the
@@ -52,8 +53,7 @@ private:
         HttpVersion http_version(request->context()->http_version_major_,
                                  request->context()->http_version_minor_);
         // This will generate the response holding JSON content.
-        ResponsePtr response(new Response(http_version,
-                                          HttpStatusCode::BAD_REQUEST));
+        ResponsePtr response(new Response(http_version, status_code));
         return (response);
     }
 

+ 1 - 0
src/lib/http/tests/response_json_unittests.cc

@@ -135,6 +135,7 @@ TEST_F(HttpResponseJsonTest, genericResponse) {
     testGenericResponse(HttpStatusCode::UNAUTHORIZED, "Unauthorized");
     testGenericResponse(HttpStatusCode::FORBIDDEN, "Forbidden");
     testGenericResponse(HttpStatusCode::NOT_FOUND, "Not Found");
+    testGenericResponse(HttpStatusCode::REQUEST_TIMEOUT, "Request Timeout");
     testGenericResponse(HttpStatusCode::INTERNAL_SERVER_ERROR,
                         "Internal Server Error");
     testGenericResponse(HttpStatusCode::NOT_IMPLEMENTED, "Not Implemented");

+ 3 - 0
src/lib/http/tests/response_unittests.cc

@@ -90,6 +90,7 @@ TEST_F(HttpResponseTest, genericResponse) {
     testResponse(HttpStatusCode::UNAUTHORIZED, "Unauthorized");
     testResponse(HttpStatusCode::FORBIDDEN, "Forbidden");
     testResponse(HttpStatusCode::NOT_FOUND, "Not Found");
+    testResponse(HttpStatusCode::REQUEST_TIMEOUT, "Request Timeout");
     testResponse(HttpStatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error");
     testResponse(HttpStatusCode::NOT_IMPLEMENTED, "Not Implemented");
     testResponse(HttpStatusCode::BAD_GATEWAY, "Bad Gateway");
@@ -110,6 +111,7 @@ TEST_F(HttpResponseTest, isClientError) {
     EXPECT_TRUE(HttpResponse::isClientError(HttpStatusCode::UNAUTHORIZED));
     EXPECT_TRUE(HttpResponse::isClientError(HttpStatusCode::FORBIDDEN));
     EXPECT_TRUE(HttpResponse::isClientError(HttpStatusCode::NOT_FOUND));
+    EXPECT_TRUE(HttpResponse::isClientError(HttpStatusCode::REQUEST_TIMEOUT));
     EXPECT_FALSE(HttpResponse::isClientError(HttpStatusCode::INTERNAL_SERVER_ERROR));
     EXPECT_FALSE(HttpResponse::isClientError(HttpStatusCode::NOT_IMPLEMENTED));
     EXPECT_FALSE(HttpResponse::isClientError(HttpStatusCode::BAD_GATEWAY));
@@ -130,6 +132,7 @@ TEST_F(HttpResponseTest, isServerError) {
     EXPECT_FALSE(HttpResponse::isServerError(HttpStatusCode::UNAUTHORIZED));
     EXPECT_FALSE(HttpResponse::isServerError(HttpStatusCode::FORBIDDEN));
     EXPECT_FALSE(HttpResponse::isServerError(HttpStatusCode::NOT_FOUND));
+    EXPECT_FALSE(HttpResponse::isServerError(HttpStatusCode::REQUEST_TIMEOUT));
     EXPECT_TRUE(HttpResponse::isServerError(HttpStatusCode::INTERNAL_SERVER_ERROR));
     EXPECT_TRUE(HttpResponse::isServerError(HttpStatusCode::NOT_IMPLEMENTED));
     EXPECT_TRUE(HttpResponse::isServerError(HttpStatusCode::BAD_GATEWAY));