Browse Source

[trac998] IPV4 checks working

Stephen Morris 14 years ago
parent
commit
f5edd31046
2 changed files with 247 additions and 177 deletions
  1. 155 119
      src/lib/acl/ip_check.h
  2. 92 58
      src/lib/acl/tests/ip_check_unittest.cc

+ 155 - 119
src/lib/acl/ip_check.h

@@ -16,6 +16,7 @@
 #define __IP_CHECK_H
 
 #include <boost/lexical_cast.hpp>
+#include <utility>
 #include <vector>
 
 #include <stdint.h>
@@ -29,16 +30,18 @@
 namespace isc {
 namespace acl {
 
+// Free functions
+
 /// \brief Convert Mask Size to Mask
 ///
 /// Given a mask size and a data type, return a value of that data type with the
-/// most significant maasksize bits set.  For example, if the data type is an
-/// unsigned either-bit byte and the masksize is 3, the function would return
-/// an eight-bit byte with the binary value 11100000.
+/// most significant masksize bits set.  For example, if the data type is an
+/// unsigned eight-bit byte and the masksize is 3, the function would return
+/// an unsigned eight-bit byte with the binary value 11100000.
 ///
-/// This is a templated function.  The template parameter must be a signed type.
+/// The function is templated on the data type of the mask.
 ///
-/// \param masksize Size of the mask.  This must be between 1 and sizeof(T).
+/// \param masksize Size of the mask.  This must be between 1 and 8*sizeof(T).
 ///        An out of range exception is thrown if this is not the case.
 ///
 /// \return Value with the most significant "masksize" bits set.
@@ -48,67 +51,99 @@ T createNetmask(size_t masksize) {
 
     if ((masksize > 0) && (masksize <= 8 * sizeof(T))) {
 
-        // To explain the logic, consider a single 4-bit word.  The masksize in
-        // this case can be between 1 and 4.  The following table has the
-        // following columns:
-        //
-        // Mask size (m): number of contiguous bits set
-        //
-        // Low value (lo): unsigned value of 4-bit word with the least-
-        // significant m contiguous bits set.
+        // In the following discussion:
         //
-        // High value (hi): unsigned value of 4-bit word with the most-
-        // significant m contiguous bits set.
+        // w is the width of the data type T in bits
+        // m is the value of masksize, the number of most signifcant bits we
+        // want to set.
         //
-        //    m   lo   hi
-        //    1    1    8
-        //    2    3   12
-        //    3    7   14
-        //    4   15   15
+        // We note that the value of 2**m - 1 gives a result with the least
+        // significant m bits set and the most signficant (w-m) bits clear.
         //
-        // Clearly the low value is equal to (2**m - 1) (using ** to indicate
-        // exponentiation).  It takes a little thought to see that the high
-        // value is equal to 2**4 - 2**(4-m).  Unfortunately, this formula will
-        // overflow as the intermediate value 2 << sizeof(T) will overflow in an
-        // element of type T.
+        // Hence the value 2**(w-m) - 1 gives a result with the least signficant
+        // w-m bits set and the most significant m bits clear.
         //
-        // However, another way of looking it is that to set the most signifcant
-        // m bits, we set all bits and clear the least-significant 4-m bits.  If
-        // T is a signed value, we can set all bits by setting it to -1.  If m
-        // is 4, we omit clearing any bits, otherwise we clear the bits
-        // represented by the bit pattern 2**(3-m) - 1.  (The value 2**(3-m)
-        // will be greater than 0 and within the range of an unsigned data
-        // type of the same size as T.  So it should not overflow.)
+        // The 1's complement of this value gives is the result we want.
         //
-        // Therefore we proceed on the assumption that T is signed
-        T mask = -1;
-        if (masksize < 8 * sizeof(T)) {
-            mask &= ~((2 << (8 * sizeof(T) - 1 - masksize)) - 1);
+        // Final note:  masksize is non-zero, so we are assured that no term in 
+        // the expression below will overflow.
+
+        return (~((1 << (8 * sizeof(T) - masksize)) - 1));
+    }
+
+    isc_throw(isc::OutOfRange, "mask size must be between 1 and " <<
+                               8 * sizeof(T));
+}
+
+/// \brief Split IP Address
+///
+/// Splits an IP address (given in the form of "xxxxxx/n" or "xxxxx" into a
+/// string representing the IP address and a number giving the size of the
+/// network mask in bits.
+///
+/// An exception will be thrown if the string format is invalid.  N.B. This
+/// does NOT check that the address component is a valid IP address - only that
+/// some string is present.
+///
+/// \param addrmask Address and/or address/mask.  The string should be passed
+///                 without leading or trailing spaces.
+/// \param defmask  Default value of the mask size, used if no mask is given.
+/// \param maxmask  Maximum valid value of the mask size.
+///
+/// \return Pair of (string, uint32) holding the address string and the mask
+///         size value.
+
+std::pair<std::string, uint32_t>
+splitIpAddress(const std::string& addrmask, uint32_t defmask, uint32_t maxmask){
+
+    uint32_t masksize = defmask;
+
+    // See if a mask size was given
+    std::vector<std::string> components = isc::util::str::tokens(addrmask, "/");
+    if (components.size() == 2) {
+
+        // There appears to be, try converting it to a number.
+        try {
+            masksize = boost::lexical_cast<size_t>(components[1]);
+        } catch (boost::bad_lexical_cast&) {
+            isc_throw(isc::InvalidParameter,
+                      "mask size specified in address/masksize " << addrmask <<
+                      " is not valid");
         }
 
-        return (mask);
+        // Is it in the valid range?
+        if ((masksize == 0) || (masksize > maxmask)) {
+            isc_throw(isc::OutOfRange,
+                      "mask size specified in address/masksize " << addrmask <<
+                      " must be in range 1 <= masksize <= " << maxmask);
+        }
+
+    } else if (components.size() > 2) {
+        isc_throw(isc::InvalidParameter, "address/masksize of " <<
+                  addrmask << " is not valid");
     }
 
-    // Invalid mask size
-    isc_throw(isc::OutOfRange, "mask size of " << masksize << " is invalid " <<
-                               "for the data type which is " << sizeof(T) <<
-                               " bytes long");
+    return (std::make_pair(components[0], masksize));
 }
 
-/// \brief IP V4 Check
+
+/// \brief IPV4 Check
 ///
 /// This class performs a match between an IPv4 address specified in an ACL
 /// (IP address, network mask and a flag indicating whether the check should
-/// be for a match or for no-match) and a given IPv4 address.
+/// be for a match or for a non-match) and a given IP address.
 ///
 /// \param Context Structure holding address to be matched.
 
-template <typename Context> class Ipv4Check : public Check<Context> {
+template <typename Context>
+class Ipv4Check : public Check<Context> {
 public:
-
-    /// \brief Constructor
+    /// \brief IPV4 Constructor
+    ///
+    /// Constructs an IPv4 Check object from a network address given as a
+    /// 32-bit value in network byte order.
     ///
-    /// \param address IP address to check for (as an address in host-byte
+    /// \param address IP address to check for (as an address in network-byte
     ///        order).
     /// \param mask The network mask specified as an integer between 1 and
     ///        32 This determines the number of bits in the mask to check.
@@ -117,84 +152,50 @@ public:
     /// \param inverse If false (the default), matches() returns true if the
     ///        condition matches.  If true, matches() returns true if the
     ///        condition does not match.
-    Ipv4Check(uint32_t address, size_t masksize = 32, bool inverse = false) :
-        address_(address), masksize_(masksize), netmask_(0), inverse_(inverse)
+    Ipv4Check(uint32_t address = 1, size_t masksize = 32, bool inverse = false):
+        address_(address), masksize_(masksize), inverse_(inverse), netmask_(0)
     {
-        init();
+        setNetmask();
     }
 
-    /// \brief Constructor
+
+
+    /// \brief String Constructor
+    ///
+    /// Constructs an IPv4 Check object from a network address and size of mask
+    /// given as a string of the form "a.b.c.d/n", where the "/n" part is
+    /// optional.
     ///
     /// \param address IP address and netmask in the form "a.b.c.d/n" (where
-    ///        the "/n" part is optional.
+    ///        the "/n" part is optional).
     /// \param inverse If false (the default), matches() returns true if the
     ///        condition matches.  If true, matches() returns true if the
     ///        condition does not match.
     Ipv4Check(const std::string& address, bool inverse = false) :
-        address_(0), masksize_(32), netmask_(0), inverse_(inverse)
+        address_(1), masksize_(32), inverse_(inverse), netmask_(0)
     {
-        // See if there is a netmask.
-        std::vector<std::string> components =
-            isc::util::str::tokens(address, "/");
-        if (components.size() == 2) {
-
-            // Yes there is, convert to a mask
-            try {
-                masksize_ = boost::lexical_cast<size_t>(components[1]);
-            } catch (boost::bad_lexical_cast&) {
-                isc_throw(isc::InvalidParameter,
-                          "mask specified in address/mask " << address <<
-                          " is not valid");
-            }
-        } else if (components.size() > 2) {
-            isc_throw(isc::InvalidParameter, "address/mask of " <<
-                      address << " is not valid");
-        }
+        // Split the address into address part and mask.
+        std::pair<std::string, uint32_t> result =
+            splitIpAddress(address, 8 * sizeof(uint32_t), 8 * sizeof(uint32_t));
 
         // Try to convert the address.
-        int result = inet_pton(AF_INET, components[0].c_str(), &address_);
-        if (result == 0) {
-            isc_throw(isc::InvalidParameter, "address/mask of " <<
-                      address << " is not valid");
+        int status = inet_pton(AF_INET, result.first.c_str(), &address_);
+        if (status == 0) {
+            isc_throw(isc::InvalidParameter, "address/masksize of " <<
+                      address << " is not valid IPV4 address");
         }
-        address_ = ntohl(address_);
 
         // All done, so finish initialization.
-        init();
+        masksize_ = result.second;
+        setNetmask();
     }
 
+
+
     /// \brief Destructor
     virtual ~Ipv4Check() {}
 
-    /// \brief Comparison
-    ///
-    /// This is the actual comparison function that checks the IP address passed
-    /// to this class with the matching information in the class itself.
-    ///
-    /// \param address Address to match against the check condition in the
-    ///        class.
-    ///
-    /// \return true if the address matches, false if it does not.
-    virtual bool compare(uint32_t address) {
-
-        // To check that the address given matches the stored network address
-        // and netmask, we check the simple condition that:
-        //
-        //     address_given & netmask_ == maskaddr_.
-        //
-        // However, we must return the negation of the result if inverse_ is
-        // set.  This leads to the truth table:
-        //
-        // Result inverse_ Return
-        // false  false    false
-        // false  true     true
-        // true   false    true
-        // true   true     false
-        //
-        // ... which is an XOR function.
 
-        return (((address & netmask_) == maskaddr_) ^ inverse_);
-    }
 
     /// \brief The check itself
     ///
@@ -203,16 +204,20 @@ public:
     /// link will fail if used for a type for which no match is provided.
     ///
     /// \param context Information to be matched
-    virtual bool matches(const Context& context) const {return false; }
+    virtual bool matches(const Context& context) const = 0;
+
+
 
     /// \brief Estimated cost
     ///
     /// Assume that the cost of the match is linear and depends on the number
-    /// of compariosn operations.
+    /// of comparison operations.
     virtual unsigned cost() const {
         return (1);             // Single check on a 32-bit word
     }
 
+
+
     ///@{
     /// Access methods - mainly for testing
 
@@ -235,24 +240,21 @@ public:
     bool getInverse() {
         return (inverse_);
     }
-
     ///@}
 
-private:
-    /// \brief Initialization
+    /// \brief Set Network Mask
     ///
-    /// Common code shared by all constructors to set up the net mask and
-    /// addresses.
-    void init() {
+    /// Sets up the network mask from the mask size.
+    void setNetmask() {
         // Validate that the mask is valid.
-        if ((masksize_ >= 1) && (masksize_ <= 32)) {
+        if ((masksize_ >= 1) && (masksize_ <=  8 * sizeof(uint32_t))) {
 
             // Calculate the bitmask given by the number of bits.
-            netmask_ = isc::acl::createNetmask<int32_t>(masksize_);
+            netmask_ = isc::acl::createNetmask<uint32_t>(masksize_);
+
+            // ... and convert to network byte order.
+            netmask_ = htonl(netmask_);
 
-            // For speed, store the masked off address.   This saves a mask
-            // operation every time the value is checked.
-            maskaddr_ = address_ & netmask_;
         } else {
             isc_throw(isc::OutOfRange,
                       "mask size of " << masksize_ << " is invalid " <<
@@ -261,11 +263,45 @@ private:
         }
     }
 
+    /// \brief Comparison
+    ///
+    /// This is the actual comparison function that checks the IP address passed
+    /// to this class with the matching information in the class itself.  It is
+    /// expected to be called from matches().
+    ///
+    /// \param address Address (in network byte order) to match against the
+    ///                check condition in the class.
+    ///
+    /// \return true if the address matches, false if it does not.
+    virtual bool compare(uint32_t address) const {
+
+        // To check that the address given matches the stored network address
+        // and netmask, we check the simple condition that:
+        //
+        //     address_given & netmask_ == stored_address & netmask_
+        //
+        // However, we must return the negation of the result if inverse_ is
+        // set.  This leads to the truth table:
+        //
+        // Result inverse_ Return
+        // false  false    false
+        // false  true     true
+        // true   false    true
+        // true   true     false
+        //
+        // ... which is an XOR function.  Although there is no explicit logical
+        /// XOR operator, with two bool arguments, "!=" serves that function.
+
+        return (((address & netmask_) == (address_ & netmask_)) != inverse_);
+    }
+
+    // Member variables
+
     uint32_t    address_;   ///< IPv4 address
-    uint32_t    maskaddr_;  ///< Masked IPV4 address
     size_t      masksize_;  ///< Mask size passed to constructor
-    int32_t     netmask_;   ///< Network mask applied to match
     bool        inverse_;   ///< test for equality or inequality
+    uint32_t    netmask_;   ///< Network mask applied to match
+
 };
 
 } // namespace acl

+ 92 - 58
src/lib/acl/tests/ip_check_unittest.cc

@@ -18,7 +18,34 @@
 
 using namespace isc::acl;
 
-/// General tests
+// Declare a derived class to allow the abstract function to be declared
+// as a concrete one.
+
+class DerivedV4Check : public Ipv4Check<uint32_t> {
+public:
+    // Basic constructor
+    DerivedV4Check(uint32_t address = 1, size_t masksize = 32,
+                   bool inverse = false) :
+                   Ipv4Check<uint32_t>(address, masksize, inverse)
+    {}
+
+    // String constructor
+    DerivedV4Check(const std::string& address, bool inverse = false) :
+        Ipv4Check<uint32_t>(address, inverse)
+    {}
+
+    // Destructor
+    virtual ~DerivedV4Check()
+    {}
+
+    // Concrete implementation of abstract method
+    virtual bool matches(const uint32_t& context) const {
+        return (compare(context));
+    }
+};
+
+
+/// Tests of the free functions.
 
 TEST(IpCheck, CreateNetmask) {
     size_t  i;
@@ -26,14 +53,16 @@ TEST(IpCheck, CreateNetmask) {
     // 8-bit tests.
 
     // Invalid arguments should throw.
-    EXPECT_THROW(createNetmask<int8_t>(0), isc::OutOfRange);
-    EXPECT_THROW(createNetmask<int8_t>(9), isc::OutOfRange);
+    EXPECT_THROW(createNetmask<uint8_t>(0), isc::OutOfRange);
+    EXPECT_THROW(createNetmask<uint8_t>(9), isc::OutOfRange);
 
-    // Check on all possible 8-bit values
+    // Check on all possible 8-bit values.  Use a signed type to generate a
+    // variable with the most significant bits set, as right-shifting it is
+    // guaranteed to introduce additional bits.
     int8_t  expected8;
     for (i = 1, expected8 = 0x80; i <= 8; ++i, expected8 >>= 1) {
-        EXPECT_EQ(static_cast<int32_t>(expected8),
-                  static_cast<int32_t>(createNetmask<int8_t>(i)));
+        EXPECT_EQ(static_cast<uint8_t>(expected8),
+                  createNetmask<uint8_t>(i));
     }
 
     // Do the same for 32 bits.
@@ -43,99 +72,104 @@ TEST(IpCheck, CreateNetmask) {
     // Check on all possible 8-bit values
     int32_t expected32;
     for (i = 1, expected32 = 0x80000000; i <= 32; ++i, expected32 >>= 1) {
-        EXPECT_EQ(expected32, createNetmask<int32_t>(i));
+        EXPECT_EQ(static_cast<uint32_t>(expected32),
+                  createNetmask<uint32_t>(i));
     }
 }
-
-// V4 tests
+// IPV4 tests
 
 // Check that the constructor expands the network mask and stores the elements
 // correctly.  For these tests, we don't worry about the type of the context,
 // so we declare it as an int.
 
 TEST(IpCheck, V4ConstructorAddress) {
-    // Alternating bits
-    Ipv4Check<int> acl1(0x55555555);
-    EXPECT_EQ(0x55555555, acl1.getAddress());
-
-    Ipv4Check<int> acl2(0xcccccccc);
-    EXPECT_EQ(0xcccccccc, acl2.getAddress());
+    DerivedV4Check acl1(0x12345678);
+    EXPECT_EQ(0x12345678, acl1.getAddress());
 }
 
+// The mask is stored in network byte order, so the pattern expected must
+// also be converted to network byte order for the comparison to succeed.
 TEST(IpCheck, V4ConstructorMask) {
     // Valid values. Address of "1" is used as a placeholder
-    Ipv4Check<int> acl1(1, 1);
-    EXPECT_EQ(0x80000000, acl1.getNetmask());
+    DerivedV4Check acl1(1, 1);
+    uint32_t expected = htonl(0x80000000);
+    EXPECT_EQ(expected, acl1.getNetmask());
     EXPECT_EQ(1, acl1.getMasksize());
 
-    Ipv4Check<int> acl2(1, 24);
-    EXPECT_EQ(0xffffff00, acl2.getNetmask());
+    DerivedV4Check acl2(1, 24);
+    expected = htonl(0xffffff00);
+    EXPECT_EQ(expected, acl2.getNetmask());
     EXPECT_EQ(24, acl2.getMasksize());
 
     // ... and some invalid network masks
-    EXPECT_THROW(Ipv4Check<int>(1, 0), isc::OutOfRange);
+    EXPECT_THROW(DerivedV4Check(1, 0), isc::OutOfRange);
+    EXPECT_THROW(DerivedV4Check(1, 33), isc::OutOfRange);
 }
 
 TEST(IpCheck, V4ConstructorInverse) {
     // Valid values. Address/mask of "1" is used as a placeholder
-    Ipv4Check<int> acl1(1, 1);
+    DerivedV4Check acl1(1, 1);
     EXPECT_FALSE(acl1.getInverse());
 
-    Ipv4Check<int> acl2(1, 1, true);
+    DerivedV4Check acl2(1, 1, true);
     EXPECT_TRUE(acl2.getInverse());
 
-    Ipv4Check<int> acl3(1, 1, false);
+    DerivedV4Check acl3(1, 1, false);
     EXPECT_FALSE(acl3.getInverse());
 }
 
 TEST(IpCheck, V4StringConstructor) {
-    Ipv4Check<int> acl1("127.0.0.1");
-    EXPECT_EQ(0x7f000001, acl1.getAddress());
+    DerivedV4Check acl1("127.0.0.1");
+    uint32_t expected = htonl(0x7f000001);
+    EXPECT_EQ(expected, acl1.getAddress());
     EXPECT_EQ(32, acl1.getMasksize());
 
-    Ipv4Check<int> acl2("255.255.255.0/24");
-    EXPECT_EQ(0xffffff00, acl2.getAddress());
+    DerivedV4Check acl2("255.255.255.0/24");
+    expected = htonl(0xffffff00);
+    EXPECT_EQ(expected, acl2.getAddress());
     EXPECT_EQ(24, acl2.getMasksize());
 
-    EXPECT_THROW(Ipv4Check<int>("255.255.255.0/0"), isc::OutOfRange);
-    EXPECT_THROW(Ipv4Check<int>("255.255.255.0/33"), isc::OutOfRange);
-    EXPECT_THROW(Ipv4Check<int>("255.255.255.0/24/3"), isc::InvalidParameter);
-    EXPECT_THROW(Ipv4Check<int>("255.255.255.0/ww"), isc::InvalidParameter);
-    EXPECT_THROW(Ipv4Check<int>("aa.255.255.0/ww"), isc::InvalidParameter);
+    EXPECT_THROW(DerivedV4Check("255.255.255.0/0"), isc::OutOfRange);
+    EXPECT_THROW(DerivedV4Check("255.255.255.0/33"), isc::OutOfRange);
+    EXPECT_THROW(DerivedV4Check("255.255.255.0/24/3"), isc::InvalidParameter);
+    EXPECT_THROW(DerivedV4Check("255.255.255.0/ww"), isc::InvalidParameter);
+    EXPECT_THROW(DerivedV4Check("aa.255.255.0/ww"), isc::InvalidParameter);
 }
 
-// Check that the comparison works - until we have a a message structure,
-// we can't check the matches function.
+// Check that the comparison works - note that "matches" just calls the
+// internal compare() code.
+//
+// Note that addresses passed to the class are expected to be in network-
+// byte order.  Therefore for the comparisons to work as expected, we must
+// convert the values to network-byte order first.
 
 TEST(IpCheck, V4Compare) {
-    // Exact address - match if given address matches stored address
-    Ipv4Check<int> acl1(0x23457f13, 32);
-    EXPECT_TRUE(acl1.compare(0x23457f13));
-    EXPECT_FALSE(acl1.compare(0x23457f12));
-    EXPECT_FALSE(acl1.compare(0x13457f13));
+    // Exact address - match if given address matches stored address.
+    DerivedV4Check acl1(htonl(0x23457f13), 32);
+    EXPECT_TRUE(acl1.matches(htonl(0x23457f13)));
+    EXPECT_FALSE(acl1.matches(htonl(0x23457f12)));
+    EXPECT_FALSE(acl1.matches(htonl(0x13457f13)));
 
     // Exact address - match if address does not match stored address
-    Ipv4Check<int> acl2(0x23457f13, 32, true);
-    EXPECT_FALSE(acl2.compare(0x23457f13));
-    EXPECT_TRUE(acl2.compare(0x23457f12));
-    EXPECT_TRUE(acl2.compare(0x13457f13));
+    DerivedV4Check acl2(htonl(0x23457f13), 32, true);
+    EXPECT_FALSE(acl2.matches(htonl(0x23457f13)));
+    EXPECT_TRUE(acl2.matches(htonl(0x23457f12)));
+    EXPECT_TRUE(acl2.matches(htonl(0x13457f13)));
 
     // Match if the address matches a mask
-    Ipv4Check<int> acl3(0x23450000, 16);
-    EXPECT_TRUE(acl3.compare(0x23450000));
-    EXPECT_TRUE(acl3.compare(0x23450001));
-    EXPECT_TRUE(acl3.compare(0x2345ffff));
-    EXPECT_FALSE(acl3.compare(0x23460000));
-    EXPECT_FALSE(acl3.compare(0x2346ffff));
+    DerivedV4Check acl3(htonl(0x23450000), 16);
+    EXPECT_TRUE(acl3.matches(htonl(0x23450000)));
+    EXPECT_TRUE(acl3.matches(htonl(0x23450001)));
+    EXPECT_TRUE(acl3.matches(htonl(0x2345ffff)));
+    EXPECT_FALSE(acl3.matches(htonl(0x23460000)));
+    EXPECT_FALSE(acl3.matches(htonl(0x2346ffff)));
 
     // Match if the address does not match a mask
-    Ipv4Check<int> acl4(0x23450000, 16, true);
-    EXPECT_FALSE(acl4.compare(0x23450000));
-    EXPECT_FALSE(acl4.compare(0x23450001));
-    EXPECT_FALSE(acl4.compare(0x2345ffff));
-    EXPECT_TRUE(acl4.compare(0x23460000));
-    EXPECT_TRUE(acl4.compare(0x2346ffff));
-
-    // 
-}
+    DerivedV4Check acl4(htonl(0x23450000), 16, true);
+    EXPECT_FALSE(acl4.matches(htonl(0x23450000)));
+    EXPECT_FALSE(acl4.matches(htonl(0x23450001)));
+    EXPECT_FALSE(acl4.matches(htonl(0x2345ffff)));
+    EXPECT_TRUE(acl4.matches(htonl(0x23460000)));
+    EXPECT_TRUE(acl4.matches(htonl(0x2346ffff)));
 
+}