Browse Source

[trac998] Modifications to IPCheck based on Jinmei's comments

Stephen Morris 14 years ago
parent
commit
632cd6151b
1 changed files with 53 additions and 52 deletions
  1. 53 52
      src/lib/acl/ip_check.h

+ 53 - 52
src/lib/acl/ip_check.h

@@ -134,6 +134,10 @@ private:
     static const size_t IPV6_SIZE = sizeof(struct in6_addr);
     static const size_t IPV6_SIZE = sizeof(struct in6_addr);
     static const size_t IPV4_SIZE = sizeof(struct in_addr);
     static const size_t IPV4_SIZE = sizeof(struct in_addr);
 
 
+    // Confirm our assumption of relative sizes - this allows us to assume that
+    // an array sized for an IPv6 address can hold an IPv4 address.
+    BOOST_STATIC_ASSERT(IPV6_SIZE > IPV4_SIZE);
+
 public:
 public:
     /// \brief String Constructor
     /// \brief String Constructor
     ///
     ///
@@ -149,7 +153,12 @@ public:
     ///        address).  If "n" is specified as zero, the match is for any
     ///        address).  If "n" is specified as zero, the match is for any
     ///        address in that address family.  The address can also be
     ///        address in that address family.  The address can also be
     ///        given as "any4" or "any6".
     ///        given as "any4" or "any6".
-    IPCheck(const std::string& ipprefix) : address_(), mask_(), family_(0) {
+    IPCheck(const std::string& ipprefix) : family_(0) {
+
+        // Ensure array elements are correctly initialized.
+        std::fill(address_, address_ + IPV6_SIZE, 0);
+        std::fill(mask_, mask_ + IPV6_SIZE, 0);
+
         // Check for special cases first.
         // Check for special cases first.
         if (ipprefix == "any4") {
         if (ipprefix == "any4") {
             family_ = AF_INET;
             family_ = AF_INET;
@@ -167,32 +176,28 @@ public:
             // Try to convert the address.  If successful, the result is in
             // Try to convert the address.  If successful, the result is in
             // network-byte order (most significant components at lower
             // network-byte order (most significant components at lower
             // addresses).
             // addresses).
-            BOOST_STATIC_ASSERT(IPV6_SIZE > IPV4_SIZE);
-            uint8_t address_bytes[IPV6_SIZE];
-            int status = inet_pton(AF_INET6, result.first.c_str(),
-                                   address_bytes);
+            int status = inet_pton(AF_INET6, result.first.c_str(), address_);
             if (status == 1) {
             if (status == 1) {
-                // It was an IPv6 address, copy into the address store
-                std::copy(address_bytes, address_bytes + IPV6_SIZE,
-                          std::back_inserter(address_));
+                // It was an IPv6 address.
                 family_ = AF_INET6;
                 family_ = AF_INET6;
-
             } else {
             } else {
-                // Not IPv6, try IPv4
-                int status = inet_pton(AF_INET, result.first.c_str(),
-                                       address_bytes);
+                // IPv6 interpretation failed, try IPv4.
+                status = inet_pton(AF_INET, result.first.c_str(), address_);
                 if (status == 1) {
                 if (status == 1) {
-                    std::copy(address_bytes, address_bytes + IPV4_SIZE,
-                              std::back_inserter(address_));
                     family_ = AF_INET;
                     family_ = AF_INET;
-
-                } else {
-                    isc_throw(isc::InvalidParameter, "address prefix of " <<
-                              ipprefix << " is a not valid");
                 }
                 }
             }
             }
+            
+            // Handle errors.
+            if (status == 0) {
+                isc_throw(isc::InvalidParameter, "address prefix of " <<
+                          ipprefix << " is not valid");
+            } else if (status < 0) {
+                isc_throw(isc::Unexpected, "address conversion of " <<
+                          ipprefix << " failed due to a system error");
+            }
 
 
-            // All done, so set the mask used in address comparison.
+            // All done, so set the mask used in the address comparison.
             setMask(result.second);
             setMask(result.second);
         }
         }
     }
     }
@@ -225,19 +230,21 @@ public:
 
 
     /// \return Stored IP address
     /// \return Stored IP address
     std::vector<uint8_t> getAddress() const {
     std::vector<uint8_t> getAddress() const {
-        return (address_);
+        const size_t vector_len = (family_ == AF_INET ? IPV4_SIZE : IPV6_SIZE);
+        return (std::vector<uint8_t>(address_, address_ + vector_len));
     }
     }
 
 
     /// \return Network mask applied to match
     /// \return Network mask applied to match
     std::vector<uint8_t> getMask() const {
     std::vector<uint8_t> getMask() const {
-        return (mask_);
+        const size_t vector_len = (family_ == AF_INET ? IPV4_SIZE : IPV6_SIZE);
+        return (std::vector<uint8_t>(mask_, mask_ + vector_len));
     }
     }
 
 
     /// \return Prefix length of the match
     /// \return Prefix length of the match
     size_t getPrefixlen() const {
     size_t getPrefixlen() const {
-        // Work this out by shifting bits out of the mask
+        // Work this out by shifting bits out of the mask.
         size_t count = 0;
         size_t count = 0;
-        for (size_t i = 0; i < mask_.size(); ++i) {
+        for (size_t i = 0; i < IPV6_SIZE; ++i) {
             if (mask_[i] == 0xff) {
             if (mask_[i] == 0xff) {
                 // Full byte, 8 bit set
                 // Full byte, 8 bit set
                 count += 8;
                 count += 8;
@@ -245,12 +252,13 @@ public:
             } else if (mask_[i] != 0) {
             } else if (mask_[i] != 0) {
                 // Partial set, count the bits
                 // Partial set, count the bits
                 uint8_t byte = mask_[i];
                 uint8_t byte = mask_[i];
-                for (int i = 0; i < 8 * sizeof(uint8_t); ++i) {
+                for (int j = 0; j < 8; ++j) {
                     count += byte & 0x01;   // Add one if the bit is set
                     count += byte & 0x01;   // Add one if the bit is set
                     byte >>= 1;             // Go for next bit
                     byte >>= 1;             // Go for next bit
                 }
                 }
-
-                // There won't be any more bits set after this, so exit
+            } else {
+                // Encountered a zero byte, so exit - there are no more bits
+                // set.
                 break;
                 break;
             }
             }
         }
         }
@@ -259,17 +267,10 @@ public:
 
 
     /// \return Address family
     /// \return Address family
     int getFamily() const {
     int getFamily() const {
-        // Check that a family_  value of 0 does not imply IPv4 or IPv6.
-        // This avoids confusion if getFamily() is called on an object that
-        // has been initialized by default.
-        BOOST_STATIC_ASSERT(AF_INET != 0);
-        BOOST_STATIC_ASSERT(AF_INET6 != 0);
-
         return (family_);
         return (family_);
     }
     }
     ///@}
     ///@}
 
 
-private:
     /// \brief Comparison
     /// \brief Comparison
     ///
     ///
     /// This is the actual comparison function that checks the IP address passed
     /// This is the actual comparison function that checks the IP address passed
@@ -297,23 +298,26 @@ private:
         //
         //
         // The result is checked for all bytes for which there are bits set in
         // The result is checked for all bytes for which there are bits set in
         // the mask.  We stop at the first non-match (or when we run out of bits
         // the mask.  We stop at the first non-match (or when we run out of bits
-        // in the mask). (Note that the mask represents a contiguous set of
-        // bits.  As such, as soon as we find a mask byte of zeroes, we have run
-        // past the part of the address where we need to match.
+        // in the mask). 
         //
         //
-        // Note that if the passed address was any4 or any6, we rely on the
-        // fact that the size of address_ is zero - the loop will terminate
-        // before the first iteration.
+        // Note that the mask represents a contiguous set of bits.  As such, as
+        // soon as we find a mask byte of zeroes, we have run past the part of
+        // the address where we need to match.
+        //
+        // Note also that when checking an IPv4 address, the constructor has
+        // set all bytes in the mask beyond the first four bytes that may be
+        // taken up by a mask for that address to zero, which will cause the
+        // loop to terminate.  This means that if the ACL is for an IPv4
+        // address, the loop will never check more than four bytes of testaddr.
 
 
         bool match = true;
         bool match = true;
-        for (int i = 0; match && (i < address_.size()) &&
-                       (mask_[i] != 0); ++i) {
+        for (int i = 0; match && (i < IPV6_SIZE) && (mask_[i] != 0); ++i) {
              match = ((testaddr[i] & mask_[i]) == (address_[i] & mask_[i]));
              match = ((testaddr[i] & mask_[i]) == (address_[i] & mask_[i]));
         }
         }
         return (match);
         return (match);
     }
     }
 
 
-
+private:
     /// \brief Set Mask
     /// \brief Set Mask
     ///
     ///
     /// Sets up the mask from the prefix length.  This involves setting
     /// Sets up the mask from the prefix length.  This involves setting
@@ -328,11 +332,9 @@ private:
     ///        was given.)
     ///        was given.)
     void setMask(int requested) {
     void setMask(int requested) {
 
 
-        mask_.clear();
-        mask_.resize((family_ == AF_INET) ? IPV4_SIZE : IPV6_SIZE);
-
-        // Set the maximum number of bits allowed in the mask.
-        int maxmask = 8 * (mask_.size());
+        // Set the maximum number of bits allowed in the mask, and request
+        // that number of bits if no prefix length was given in the constructor.
+        int maxmask = 8 * ((family_ == AF_INET) ? IPV4_SIZE : IPV6_SIZE);
         if (requested < 0) {
         if (requested < 0) {
             requested = maxmask;
             requested = maxmask;
         }
         }
@@ -364,11 +366,10 @@ private:
         }
         }
     }
     }
 
 
-    // Member variables
-
-    std::vector<uint8_t> address_;  ///< Address in binary form
-    std::vector<uint8_t> mask_;     ///< Address mask
-    int         family_;            ///< Address family
+    // Member variables.
+    uint8_t address_[IPV6_SIZE];  ///< Address in binary form
+    uint8_t mask_[IPV6_SIZE];     ///< Address mask
+    int     family_;              ///< Address family
 };
 };
 
 
 } // namespace acl
 } // namespace acl