Browse Source

[trac812] some cleanups

JINMEI Tatuya 14 years ago
parent
commit
ed5025a132
2 changed files with 35 additions and 31 deletions
  1. 30 26
      src/lib/dns/tests/tsig_unittest.cc
  2. 5 5
      src/lib/dns/tsig.cc

+ 30 - 26
src/lib/dns/tests/tsig_unittest.cc

@@ -16,6 +16,8 @@
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
 
 
+#include <boost/scoped_ptr.hpp>
+
 #include <gtest/gtest.h>
 #include <gtest/gtest.h>
 
 
 #include <exceptions/exceptions.h>
 #include <exceptions/exceptions.h>
@@ -72,18 +74,16 @@ protected:
         // confused due to other tests that tweak the time.
         // confused due to other tests that tweak the time.
         tsig::detail::gettimeFunction = NULL;
         tsig::detail::gettimeFunction = NULL;
 
 
-        // Note: the following code is not exception safe, but we ignore it for
-        // simplicity
         decodeBase64("SFuWd/q99SzF8Yzd1QbB9g==", secret);
         decodeBase64("SFuWd/q99SzF8Yzd1QbB9g==", secret);
-        tsig_ctx = new TSIGContext(TSIGKey(test_name, TSIGKey::HMACMD5_NAME(),
-                                           &secret[0], secret.size()));
-        tsig_verify_ctx = new TSIGContext(TSIGKey(test_name,
-                                                  TSIGKey::HMACMD5_NAME(),
-                                                  &secret[0], secret.size()));
+        tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACMD5_NAME(),
+                                               &secret[0], secret.size())));
+        tsig_verify_ctx.reset(new TSIGContext(TSIGKey(test_name,
+                                                      TSIGKey::HMACMD5_NAME(),
+                                                      &secret[0],
+                                                      secret.size())));
     }
     }
     ~TSIGTest() {
     ~TSIGTest() {
-        delete tsig_ctx;
-        delete tsig_verify_ctx;
         tsig::detail::gettimeFunction = NULL;
         tsig::detail::gettimeFunction = NULL;
     }
     }
 
 
@@ -106,8 +106,8 @@ protected:
     static const unsigned int AA_FLAG = 0x2;
     static const unsigned int AA_FLAG = 0x2;
     static const unsigned int RD_FLAG = 0x4;
     static const unsigned int RD_FLAG = 0x4;
 
 
-    TSIGContext* tsig_ctx;
-    TSIGContext* tsig_verify_ctx;
+    boost::scoped_ptr<TSIGContext> tsig_ctx;
+    boost::scoped_ptr<TSIGContext> tsig_verify_ctx;
     const uint16_t qid;
     const uint16_t qid;
     const Name test_name;
     const Name test_name;
     const RRClass test_class;
     const RRClass test_class;
@@ -210,8 +210,8 @@ TEST_F(TSIGTest, sign) {
 
 
     {
     {
         SCOPED_TRACE("Sign test for query");
         SCOPED_TRACE("Sign test for query");
-        commonTSIGChecks(createMessageAndSign(qid, test_name, tsig_ctx), qid,
-                         0x4da8877a, common_expected_mac,
+        commonTSIGChecks(createMessageAndSign(qid, test_name, tsig_ctx.get()),
+                         qid, 0x4da8877a, common_expected_mac,
                          sizeof(common_expected_mac));
                          sizeof(common_expected_mac));
     }
     }
 }
 }
@@ -259,7 +259,7 @@ TEST_F(TSIGTest, signAtActualTime) {
     {
     {
         SCOPED_TRACE("Sign test for query at actual time");
         SCOPED_TRACE("Sign test for query at actual time");
         ConstTSIGRecordPtr tsig = createMessageAndSign(qid, test_name,
         ConstTSIGRecordPtr tsig = createMessageAndSign(qid, test_name,
-                                                       tsig_ctx);
+                                                       tsig_ctx.get());
         const any::TSIG& tsig_rdata = tsig->getRdata();
         const any::TSIG& tsig_rdata = tsig->getRdata();
 
 
         // Check the resulted time signed is in the range of [now, now + 5]
         // Check the resulted time signed is in the range of [now, now + 5]
@@ -308,13 +308,14 @@ TEST_F(TSIGTest, signUsingHMACSHA1) {
 TEST_F(TSIGTest, signResponse) {
 TEST_F(TSIGTest, signResponse) {
     tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
     tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
 
 
-    ConstTSIGRecordPtr tsig = createMessageAndSign(qid, test_name, tsig_ctx);
+    ConstTSIGRecordPtr tsig = createMessageAndSign(qid, test_name,
+                                                   tsig_ctx.get());
     tsig_verify_ctx->verifyTentative(tsig);
     tsig_verify_ctx->verifyTentative(tsig);
     EXPECT_EQ(TSIGContext::CHECKED, tsig_verify_ctx->getState());
     EXPECT_EQ(TSIGContext::CHECKED, tsig_verify_ctx->getState());
 
 
     // Transform the original message to a response, then sign the response
     // Transform the original message to a response, then sign the response
     // with the context of "verified state".
     // with the context of "verified state".
-    tsig = createMessageAndSign(qid, test_name, tsig_verify_ctx,
+    tsig = createMessageAndSign(qid, test_name, tsig_verify_ctx.get(),
                                 QR_FLAG|AA_FLAG|RD_FLAG,
                                 QR_FLAG|AA_FLAG|RD_FLAG,
                                 RRType::A(), "192.0.2.1");
                                 RRType::A(), "192.0.2.1");
     const uint8_t expected_mac[] = {
     const uint8_t expected_mac[] = {
@@ -347,13 +348,13 @@ TEST_F(TSIGTest, signContinuation) {
 
 
     // Create and sign the AXFR request, then verify it.
     // Create and sign the AXFR request, then verify it.
     tsig_verify_ctx->verifyTentative(createMessageAndSign(axfr_qid, zone_name,
     tsig_verify_ctx->verifyTentative(createMessageAndSign(axfr_qid, zone_name,
-                                                          tsig_ctx, 0,
+                                                          tsig_ctx.get(), 0,
                                                           RRType::AXFR()));
                                                           RRType::AXFR()));
     EXPECT_EQ(TSIGContext::CHECKED, tsig_verify_ctx->getState());
     EXPECT_EQ(TSIGContext::CHECKED, tsig_verify_ctx->getState());
 
 
     // Create and sign the first response message (we don't need the result
     // Create and sign the first response message (we don't need the result
     // for the purpose of this test)
     // for the purpose of this test)
-    createMessageAndSign(axfr_qid, zone_name, tsig_verify_ctx,
+    createMessageAndSign(axfr_qid, zone_name, tsig_verify_ctx.get(),
                          AA_FLAG|QR_FLAG, RRType::AXFR(),
                          AA_FLAG|QR_FLAG, RRType::AXFR(),
                          "ns.example.com. root.example.com. "
                          "ns.example.com. root.example.com. "
                          "2011041503 7200 3600 2592000 1200",
                          "2011041503 7200 3600 2592000 1200",
@@ -367,8 +368,8 @@ TEST_F(TSIGTest, signContinuation) {
     {
     {
         SCOPED_TRACE("Sign test for continued response in TCP stream");
         SCOPED_TRACE("Sign test for continued response in TCP stream");
         commonTSIGChecks(createMessageAndSign(axfr_qid, zone_name,
         commonTSIGChecks(createMessageAndSign(axfr_qid, zone_name,
-                                              tsig_verify_ctx, AA_FLAG|QR_FLAG,
-                                              RRType::AXFR(),
+                                              tsig_verify_ctx.get(),
+                                              AA_FLAG|QR_FLAG, RRType::AXFR(),
                                               "ns.example.com.", &RRType::NS(),
                                               "ns.example.com.", &RRType::NS(),
                                               false),
                                               false),
                          axfr_qid, 0x4da8e951,
                          axfr_qid, 0x4da8e951,
@@ -394,7 +395,8 @@ TEST_F(TSIGTest, badtimeResponse) {
 
 
     const uint16_t test_qid = 0x7fc4;
     const uint16_t test_qid = 0x7fc4;
     ConstTSIGRecordPtr tsig = createMessageAndSign(test_qid, test_name,
     ConstTSIGRecordPtr tsig = createMessageAndSign(test_qid, test_name,
-                                                   tsig_ctx, 0, RRType::SOA());
+                                                   tsig_ctx.get(), 0,
+                                                   RRType::SOA());
 
 
     // "advance the clock" and try validating, which should fail due to BADTIME
     // "advance the clock" and try validating, which should fail due to BADTIME
     // (verifyTentative actually doesn't check the time, though)
     // (verifyTentative actually doesn't check the time, though)
@@ -403,7 +405,7 @@ TEST_F(TSIGTest, badtimeResponse) {
     EXPECT_EQ(TSIGError::BAD_TIME(), tsig_verify_ctx->getError());
     EXPECT_EQ(TSIGError::BAD_TIME(), tsig_verify_ctx->getError());
 
 
     // make and sign a response in the context of TSIG error.
     // make and sign a response in the context of TSIG error.
-    tsig = createMessageAndSign(test_qid, test_name, tsig_verify_ctx,
+    tsig = createMessageAndSign(test_qid, test_name, tsig_verify_ctx.get(),
                                 QR_FLAG, RRType::SOA(), NULL, NULL,
                                 QR_FLAG, RRType::SOA(), NULL, NULL,
                                 true, Rcode::NOTAUTH());
                                 true, Rcode::NOTAUTH());
     const uint8_t expected_otherdata[] = { 0, 0, 0x4d, 0xa8, 0xbe, 0x86 };
     const uint8_t expected_otherdata[] = { 0, 0, 0x4d, 0xa8, 0xbe, 0x86 };
@@ -427,14 +429,15 @@ TEST_F(TSIGTest, badsigResponse) {
     // Sign a simple message, and force the verification to fail with
     // Sign a simple message, and force the verification to fail with
     // BADSIG.
     // BADSIG.
     tsig_verify_ctx->verifyTentative(createMessageAndSign(qid, test_name,
     tsig_verify_ctx->verifyTentative(createMessageAndSign(qid, test_name,
-                                                          tsig_ctx),
+                                                          tsig_ctx.get()),
                                      TSIGError::BAD_SIG());
                                      TSIGError::BAD_SIG());
 
 
     // Sign the same message (which doesn't matter for this test) with the
     // Sign the same message (which doesn't matter for this test) with the
     // context of "checked state".
     // context of "checked state".
     {
     {
         SCOPED_TRACE("Sign test for response with BADSIG error");
         SCOPED_TRACE("Sign test for response with BADSIG error");
-        commonTSIGChecks(createMessageAndSign(qid, test_name, tsig_verify_ctx),
+        commonTSIGChecks(createMessageAndSign(qid, test_name,
+                                              tsig_verify_ctx.get()),
                          message.getQid(), 0x4da8877a, NULL, 0,
                          message.getQid(), 0x4da8877a, NULL, 0,
                          16);   // 16: BADSIG
                          16);   // 16: BADSIG
     }
     }
@@ -444,11 +447,12 @@ TEST_F(TSIGTest, badkeyResponse) {
     // A similar test as badsigResponse but for BADKEY
     // A similar test as badsigResponse but for BADKEY
     tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
     tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
     tsig_verify_ctx->verifyTentative(createMessageAndSign(qid, test_name,
     tsig_verify_ctx->verifyTentative(createMessageAndSign(qid, test_name,
-                                                          tsig_ctx),
+                                                          tsig_ctx.get()),
                                      TSIGError::BAD_KEY());
                                      TSIGError::BAD_KEY());
     {
     {
         SCOPED_TRACE("Sign test for response with BADKEY error");
         SCOPED_TRACE("Sign test for response with BADKEY error");
-        commonTSIGChecks(createMessageAndSign(qid, test_name, tsig_verify_ctx),
+        commonTSIGChecks(createMessageAndSign(qid, test_name,
+                                              tsig_verify_ctx.get()),
                          message.getQid(), 0x4da8877a, NULL, 0,
                          message.getQid(), 0x4da8877a, NULL, 0,
                          17);   // 17: BADKEYSIG
                          17);   // 17: BADKEYSIG
     }
     }

+ 5 - 5
src/lib/dns/tsig.cc

@@ -129,11 +129,11 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
     }
     }
 
 
     OutputBuffer variables(0);
     OutputBuffer variables(0);
-    HMACPtr hmac = HMACPtr(CryptoLink::getCryptoLink().createHMAC(
-                               impl_->key_.getSecret(),
-                               impl_->key_.getSecretLength(),
-                               impl_->key_.getCryptoAlgorithm()),
-                           deleteHMAC);
+    HMACPtr hmac(CryptoLink::getCryptoLink().createHMAC(
+                     impl_->key_.getSecret(),
+                     impl_->key_.getSecretLength(),
+                     impl_->key_.getCryptoAlgorithm()),
+                 deleteHMAC);
 
 
     // If the context has previous MAC (either the Request MAC or its own
     // If the context has previous MAC (either the Request MAC or its own
     // previous MAC), digest it.
     // previous MAC), digest it.