pgsql_connection.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. // Copyright (C) 2016-2017 Internet Systems Consortium, Inc. ("ISC")
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this
  5. // file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6. #include <config.h>
  7. #include <dhcpsrv/dhcpsrv_log.h>
  8. #include <dhcpsrv/pgsql_connection.h>
// PostgreSQL errors should be tested based on the SQL state code (SQLSTATE).
// Each state code is five ASCII characters (digits or upper-case letters):
// the first two define the category (class) of the error, and the last three
// identify the specific error. In this file state codes are handled as
// char[5] arrays with no terminating NUL. Macros for each code are defined in
// PostgreSQL's server/utils/errcodes.h, although they require a second macro,
// MAKE_SQLSTATE, for completion. For example, the duplicate key error is
// defined as:
//
// #define ERRCODE_UNIQUE_VIOLATION MAKE_SQLSTATE('2','3','5','0','5')
//
// PostgreSQL deliberately omits the MAKE_SQLSTATE macro so callers can/must
// supply their own. We define it as an initializer list:
  20. #define MAKE_SQLSTATE(ch1,ch2,ch3,ch4,ch5) {ch1,ch2,ch3,ch4,ch5}
  21. // So we can use it like this: const char some_error[] = ERRCODE_xxxx;
  22. #define PGSQL_STATECODE_LEN 5
  23. #include <utils/errcodes.h>
  24. using namespace std;
  25. namespace isc {
  26. namespace dhcp {
  27. // Default connection timeout
  28. /// @todo: migrate this default timeout to src/bin/dhcpX/simple_parserX.cc
  29. const int PGSQL_DEFAULT_CONNECTION_TIMEOUT = 5; // seconds
  30. const char PgSqlConnection::DUPLICATE_KEY[] = ERRCODE_UNIQUE_VIOLATION;
  31. PgSqlResult::PgSqlResult(PGresult *result)
  32. : result_(result), rows_(0), cols_(0) {
  33. if (!result) {
  34. isc_throw (BadValue, "PgSqlResult result pointer cannot be null");
  35. }
  36. rows_ = PQntuples(result);
  37. cols_ = PQnfields(result);
  38. }
  39. void
  40. PgSqlResult::rowCheck(int row) const {
  41. if (row < 0 || row >= rows_) {
  42. isc_throw (DbOperationError, "row: " << row
  43. << ", out of range: 0.." << rows_);
  44. }
  45. }
  46. PgSqlResult::~PgSqlResult() {
  47. if (result_) {
  48. PQclear(result_);
  49. }
  50. }
  51. void
  52. PgSqlResult::colCheck(int col) const {
  53. if (col < 0 || col >= cols_) {
  54. isc_throw (DbOperationError, "col: " << col
  55. << ", out of range: 0.." << cols_);
  56. }
  57. }
  58. void
  59. PgSqlResult::rowColCheck(int row, int col) const {
  60. rowCheck(row);
  61. colCheck(col);
  62. }
  63. std::string
  64. PgSqlResult::getColumnLabel(const int col) const {
  65. const char* label = NULL;
  66. try {
  67. colCheck(col);
  68. label = PQfname(result_, col);
  69. } catch (...) {
  70. std::ostringstream os;
  71. os << "Unknown column:" << col;
  72. return (os.str());
  73. }
  74. return (label);
  75. }
  76. PgSqlTransaction::PgSqlTransaction(PgSqlConnection& conn)
  77. : conn_(conn), committed_(false) {
  78. conn_.startTransaction();
  79. }
  80. PgSqlTransaction::~PgSqlTransaction() {
  81. // If commit() wasn't explicitly called, rollback.
  82. if (!committed_) {
  83. conn_.rollback();
  84. }
  85. }
  86. void
  87. PgSqlTransaction::commit() {
  88. conn_.commit();
  89. committed_ = true;
  90. }
  91. PgSqlConnection::~PgSqlConnection() {
  92. if (conn_) {
  93. // Deallocate the prepared queries.
  94. PgSqlResult r(PQexec(conn_, "DEALLOCATE all"));
  95. if(PQresultStatus(r) != PGRES_COMMAND_OK) {
  96. // Highly unlikely but we'll log it and go on.
  97. LOG_ERROR(dhcpsrv_logger, DHCPSRV_PGSQL_DEALLOC_ERROR)
  98. .arg(PQerrorMessage(conn_));
  99. }
  100. }
  101. }
  102. void
  103. PgSqlConnection::prepareStatement(const PgSqlTaggedStatement& statement) {
  104. // Prepare all statements queries with all known fields datatype
  105. PgSqlResult r(PQprepare(conn_, statement.name, statement.text,
  106. statement.nbparams, statement.types));
  107. if(PQresultStatus(r) != PGRES_COMMAND_OK) {
  108. isc_throw(DbOperationError, "unable to prepare PostgreSQL statement: "
  109. << statement.text << ", reason: " << PQerrorMessage(conn_));
  110. }
  111. }
  112. void
  113. PgSqlConnection::prepareStatements(const PgSqlTaggedStatement* start_statement,
  114. const PgSqlTaggedStatement* end_statement) {
  115. // Created the PostgreSQL prepared statements.
  116. for (const PgSqlTaggedStatement* tagged_statement = start_statement;
  117. tagged_statement != end_statement; ++tagged_statement) {
  118. prepareStatement(*tagged_statement);
  119. }
  120. }
  121. void
  122. PgSqlConnection::openDatabase() {
  123. string dbconnparameters;
  124. string shost = "localhost";
  125. try {
  126. shost = getParameter("host");
  127. } catch(...) {
  128. // No host. Fine, we'll use "localhost"
  129. }
  130. dbconnparameters += "host = '" + shost + "'" ;
  131. string sport;
  132. try {
  133. sport = getParameter("port");
  134. } catch (...) {
  135. // No port parameter, we are going to use the default port.
  136. sport = "";
  137. }
  138. if (sport.size() > 0) {
  139. unsigned int port = 0;
  140. // Port was given, so try to convert it to an integer.
  141. try {
  142. port = boost::lexical_cast<unsigned int>(sport);
  143. } catch (...) {
  144. // Port given but could not be converted to an unsigned int.
  145. // Just fall back to the default value.
  146. port = 0;
  147. }
  148. // The port is only valid when it is in the 0..65535 range.
  149. // Again fall back to the default when the given value is invalid.
  150. if (port > numeric_limits<uint16_t>::max()) {
  151. port = 0;
  152. }
  153. // Add it to connection parameters when not default.
  154. if (port > 0) {
  155. std::ostringstream oss;
  156. oss << port;
  157. dbconnparameters += " port = " + oss.str();
  158. }
  159. }
  160. string suser;
  161. try {
  162. suser = getParameter("user");
  163. dbconnparameters += " user = '" + suser + "'";
  164. } catch(...) {
  165. // No user. Fine, we'll use NULL
  166. }
  167. string spassword;
  168. try {
  169. spassword = getParameter("password");
  170. dbconnparameters += " password = '" + spassword + "'";
  171. } catch(...) {
  172. // No password. Fine, we'll use NULL
  173. }
  174. string sname;
  175. try {
  176. sname = getParameter("name");
  177. dbconnparameters += " dbname = '" + sname + "'";
  178. } catch(...) {
  179. // No database name. Throw a "NoDatabaseName" exception
  180. isc_throw(NoDatabaseName, "must specify a name for the database");
  181. }
  182. unsigned int connect_timeout = PGSQL_DEFAULT_CONNECTION_TIMEOUT;
  183. string stimeout;
  184. try {
  185. stimeout = getParameter("connect-timeout");
  186. } catch (...) {
  187. // No timeout parameter, we are going to use the default timeout.
  188. stimeout = "";
  189. }
  190. if (stimeout.size() > 0) {
  191. // Timeout was given, so try to convert it to an integer.
  192. try {
  193. connect_timeout = boost::lexical_cast<unsigned int>(stimeout);
  194. } catch (...) {
  195. // Timeout given but could not be converted to an unsigned int. Set
  196. // the connection timeout to an invalid value to trigger throwing
  197. // of an exception.
  198. connect_timeout = 0;
  199. }
  200. // The timeout is only valid if greater than zero, as depending on the
  201. // database, a zero timeout might signify something like "wait
  202. // indefinitely".
  203. //
  204. // The check below also rejects a value greater than the maximum
  205. // integer value. The lexical_cast operation used to obtain a numeric
  206. // value from a string can get confused if trying to convert a negative
  207. // integer to an unsigned int: instead of throwing an exception, it may
  208. // produce a large positive value.
  209. if ((connect_timeout == 0) ||
  210. (connect_timeout > numeric_limits<int>::max())) {
  211. isc_throw(DbInvalidTimeout, "database connection timeout (" <<
  212. stimeout << ") must be an integer greater than 0");
  213. }
  214. }
  215. std::ostringstream oss;
  216. oss << connect_timeout;
  217. dbconnparameters += " connect_timeout = " + oss.str();
  218. // Connect to Postgres, saving the low level connection pointer
  219. // in the holder object
  220. PGconn* new_conn = PQconnectdb(dbconnparameters.c_str());
  221. if (!new_conn) {
  222. isc_throw(DbOpenError, "could not allocate connection object");
  223. }
  224. if (PQstatus(new_conn) != CONNECTION_OK) {
  225. // If we have a connection object, we have to call finish
  226. // to release it, but grab the error message first.
  227. std::string error_message = PQerrorMessage(new_conn);
  228. PQfinish(new_conn);
  229. isc_throw(DbOpenError, error_message);
  230. }
  231. // We have a valid connection, so let's save it to our holder
  232. conn_.setConnection(new_conn);
  233. }
  234. bool
  235. PgSqlConnection::compareError(const PgSqlResult& r, const char* error_state) {
  236. const char* sqlstate = PQresultErrorField(r, PG_DIAG_SQLSTATE);
  237. // PostgreSQL guarantees it will always be 5 characters long
  238. return ((sqlstate != NULL) &&
  239. (memcmp(sqlstate, error_state, PGSQL_STATECODE_LEN) == 0));
  240. }
  241. void
  242. PgSqlConnection::checkStatementError(const PgSqlResult& r,
  243. PgSqlTaggedStatement& statement) const {
  244. int s = PQresultStatus(r);
  245. if (s != PGRES_COMMAND_OK && s != PGRES_TUPLES_OK) {
  246. // We're testing the first two chars of SQLSTATE, as this is the
  247. // error class. Note, there is a severity field, but it can be
  248. // misleadingly returned as fatal.
  249. const char* sqlstate = PQresultErrorField(r, PG_DIAG_SQLSTATE);
  250. if ((sqlstate != NULL) &&
  251. ((memcmp(sqlstate, "08", 2) == 0) || // Connection Exception
  252. (memcmp(sqlstate, "53", 2) == 0) || // Insufficient resources
  253. (memcmp(sqlstate, "54", 2) == 0) || // Program Limit exceeded
  254. (memcmp(sqlstate, "57", 2) == 0) || // Operator intervention
  255. (memcmp(sqlstate, "58", 2) == 0))) { // System error
  256. LOG_ERROR(dhcpsrv_logger, DHCPSRV_PGSQL_FATAL_ERROR)
  257. .arg(statement.name)
  258. .arg(PQerrorMessage(conn_))
  259. .arg(sqlstate);
  260. exit (-1);
  261. }
  262. const char* error_message = PQerrorMessage(conn_);
  263. isc_throw(DbOperationError, "Statement exec failed:" << " for: "
  264. << statement.name << ", reason: "
  265. << error_message);
  266. }
  267. }
  268. void
  269. PgSqlConnection::startTransaction() {
  270. LOG_DEBUG(dhcpsrv_logger, DHCPSRV_DBG_TRACE_DETAIL,
  271. DHCPSRV_PGSQL_START_TRANSACTION);
  272. PgSqlResult r(PQexec(conn_, "START TRANSACTION"));
  273. if (PQresultStatus(r) != PGRES_COMMAND_OK) {
  274. const char* error_message = PQerrorMessage(conn_);
  275. isc_throw(DbOperationError, "unable to start transaction"
  276. << error_message);
  277. }
  278. }
  279. void
  280. PgSqlConnection::commit() {
  281. LOG_DEBUG(dhcpsrv_logger, DHCPSRV_DBG_TRACE_DETAIL, DHCPSRV_PGSQL_COMMIT);
  282. PgSqlResult r(PQexec(conn_, "COMMIT"));
  283. if (PQresultStatus(r) != PGRES_COMMAND_OK) {
  284. const char* error_message = PQerrorMessage(conn_);
  285. isc_throw(DbOperationError, "commit failed: " << error_message);
  286. }
  287. }
  288. void
  289. PgSqlConnection::rollback() {
  290. LOG_DEBUG(dhcpsrv_logger, DHCPSRV_DBG_TRACE_DETAIL, DHCPSRV_PGSQL_ROLLBACK);
  291. PgSqlResult r(PQexec(conn_, "ROLLBACK"));
  292. if (PQresultStatus(r) != PGRES_COMMAND_OK) {
  293. const char* error_message = PQerrorMessage(conn_);
  294. isc_throw(DbOperationError, "rollback failed: " << error_message);
  295. }
  296. }
  297. }; // end of isc::dhcp namespace
  298. }; // end of isc namespace