Michal 'vorner' Vaner 12 years ago
parent
commit
44c321c37e

+ 102 - 10
src/lib/dns/rrttl.cc

@@ -21,25 +21,117 @@
 #include <dns/messagerenderer.h>
 #include <dns/rrttl.h>
 
+#include <boost/lexical_cast.hpp>
+#include <algorithm>
+#include <cctype>
+
 using namespace std;
 using namespace isc::dns;
 using namespace isc::util;
 
+namespace {
+
+// We wrap the C isalpha, because it seems to be overloaded with something.
+// Then the find_if doesn't work.
+bool
+myIsalpha(char c) {
+    return (isalpha(c));
+}
+
+// The conversion of units to their size
+struct Unit {
+    char unit;
+    uint32_t multiply;
+};
+
+Unit units[] = {
+    { 'S', 1 },
+    { 'M', 60 },
+    { 'H', 60 * 60 },
+    { 'D', 24 * 60 * 60 },
+    { 'W', 7 * 24 * 60 * 60 }
+};
+
+}
+
 namespace isc {
 namespace dns {
 
 RRTTL::RRTTL(const std::string& ttlstr) {
-    // Some systems (at least gcc-4.4) flow negative values over into
-    // unsigned integer, where older systems failed to parse. We want
-    // that failure here, so we extract into int64 and check the value
-    int64_t val;
-
-    istringstream iss(ttlstr);
-    iss >> dec >> val;
-    if (iss.rdstate() == ios::eofbit && val >= 0 && val <= 0xffffffff) {
-        ttlval_ = static_cast<uint32_t>(val);
+    if (ttlstr.empty()) {
+        isc_throw(InvalidRRTTL, "Empty TTL string");
+    }
+    // We use a larger data type during the computation. This is because
+    // some compilers don't fail when out of range, so we check the range
+    // ourselves later.
+    int64_t val = 0;
+
+    const string::const_iterator end = ttlstr.end();
+    string::const_iterator pos = ttlstr.begin();
+
+    // When we detect we have some units
+    bool units_mode = false;
+
+    try {
+        while (pos != end) {
+            // Find the first unit, if there's any.
+            const string::const_iterator unit = find_if(pos, end, myIsalpha);
+            // No unit
+            if (unit == end) {
+                if (units_mode) {
+                    // We had some units before. The last one is missing unit.
+                    isc_throw(InvalidRRTTL, "Missing the last unit: " <<
+                              ttlstr);
+                } else {
+                    // Case without any units at all. Just convert and store
+                    // it.
+                    val = boost::lexical_cast<int64_t>(ttlstr);
+                    break;
+                }
+            }
+            // There's a unit now.
+            units_mode = true;
+            // Find the unit and get the size.
+            uint32_t multiply;
+            bool found = false;
+            for (size_t i = 0; i < sizeof(units) / sizeof(*units); ++i) {
+                if (toupper(*unit) == units[i].unit) {
+                    found = true;
+                    multiply = units[i].multiply;
+                    break;
+                }
+            }
+            if (!found) {
+                isc_throw(InvalidRRTTL, "Unknown unit used: " << *unit <<
+                          " in: " << ttlstr);
+            }
+            // Now extract the number.
+            if (unit == pos) {
+                isc_throw(InvalidRRTTL, "Missing number in TTL: " << ttlstr);
+            }
+            const int64_t value = boost::lexical_cast<int64_t>(string(pos,
+                                                                      unit));
+            // Add what we found
+            val += multiply * value;
+            // Check the partial value is still in range (the value can only
+            // grow, so if we get out of range now, it won't get better, so
+            // there's no need to continue).
+            if (value < 0 || value > 0xffffffff || val < 0 ||
+                val > 0xffffffff) {
+                isc_throw(InvalidRRTTL, "Part of TTL out of range: " <<
+                          ttlstr);
+            }
+            // Move to after the unit.
+            pos = unit + 1;
+        }
+    } catch (const boost::bad_lexical_cast&) {
+        isc_throw(InvalidRRTTL, "invalid TTL: " << ttlstr);
+    }
+
+    if (val >= 0 && val <= 0xffffffff) {
+        ttlval_ = val;
     } else {
-        isc_throw(InvalidRRTTL, "invalid TTL");
+        isc_throw(InvalidRRTTL, "TTL out of range: " << ttlstr);
     }
 }
 

+ 12 - 9
src/lib/dns/rrttl.h

@@ -74,15 +74,18 @@ public:
     explicit RRTTL(uint32_t ttlval) : ttlval_(ttlval) {}
     /// Constructor from a string.
     ///
-    /// This version of the implementation only accepts decimal TTL values in
-    /// seconds.
-    /// In a near future version, we'll extend it so that we can accept more
-    /// convenient ones such as "2H" or "1D".
-    ///
-    /// If the given string is not recognized as a valid representation of
-    /// an RR TTL, an exception of class \c InvalidRRTTL will be thrown.
-    ///
-    /// \param ttlstr A string representation of the \c RRTTL
+    /// It accepts either a decimal number, specifying number of seconds. Or,
+    /// it can be given a sequence of numbers and units, like "2H" (meaning
+    /// two hours), "1W3D" (one week and 3 days). The allowed units are W
+    /// (week), D (day), H (hour), M (minute) and S (second). They can be also
+    /// specified in lower-case. No further restrictions are checked (so they
+    /// can be specified in arbitrary order and even things like "1D1D" can
+    /// be used to specify two days).
+    ///
+    /// \param ttlstr A string representation of the \c RRTTL.
+    ///
+    /// \throw InvalidRRTTL in case the string is not recognized as valid
+    ///     TTL representation.
     explicit RRTTL(const std::string& ttlstr);
     /// Constructor from wire-format data.
     ///

+ 9 - 5
src/lib/dns/tests/masterload_unittest.cc

@@ -307,16 +307,20 @@ TEST_F(MasterLoadTest, loadNonAtopSOA) {
                  MasterLoadError);
 }
 
+// Load TTL with units
+TEST_F(MasterLoadTest, loadUnitTTL) {
+    stringstream rr_stream2("example.com. 1D IN A 192.0.2.1");
+    masterLoad(rr_stream2, origin, zclass, callback);
+    EXPECT_EQ(1, results.size());
+    EXPECT_EQ(0, results[0]->getRdataIterator()->getCurrent().compare(
+                  *rdata::createRdata(RRType::A(), zclass, "192.0.2.1")));
+}
+
 TEST_F(MasterLoadTest, loadBadRRText) {
     rr_stream << "example..com. 3600 IN A 192.0.2.1"; // bad owner name
     EXPECT_THROW(masterLoad(rr_stream, origin, zclass, callback),
                  MasterLoadError);
 
-    // currently we only support numeric TTLs
-    stringstream rr_stream2("example.com. 1D IN A 192.0.2.1");
-    EXPECT_THROW(masterLoad(rr_stream2, origin, zclass, callback),
-                 MasterLoadError);
-
     // bad RR class text
     stringstream rr_stream3("example.com. 3600 BAD A 192.0.2.1");
     EXPECT_THROW(masterLoad(rr_stream3, origin, zclass, callback),

+ 74 - 2
src/lib/dns/tests/rrttl_unittest.cc

@@ -65,20 +65,92 @@ RRTTLTest::rrttlFactoryFromWire(const char* datafile) {
     return (RRTTL(buffer));
 }
 
-TEST_F(RRTTLTest, fromText) {
+TEST_F(RRTTLTest, getValue) {
     EXPECT_EQ(0, ttl_0.getValue());
     EXPECT_EQ(3600, ttl_1h.getValue());
     EXPECT_EQ(86400, ttl_1d.getValue());
     EXPECT_EQ(0x12345678, ttl_32bit.getValue());
     EXPECT_EQ(0xffffffff, ttl_max.getValue());
+}
+
+TEST_F(RRTTLTest, fromText) {
+    // Border cases
+    EXPECT_EQ(0, RRTTL("0").getValue());
+    EXPECT_EQ(4294967295, RRTTL("4294967295").getValue());
 
-    EXPECT_THROW(RRTTL("1D"), InvalidRRTTL); // we don't support this form yet
+    // Invalid cases
     EXPECT_THROW(RRTTL("0xdeadbeef"), InvalidRRTTL); // must be decimal
     EXPECT_THROW(RRTTL("-1"), InvalidRRTTL); // must be positive
     EXPECT_THROW(RRTTL("1.1"), InvalidRRTTL); // must be integer
     EXPECT_THROW(RRTTL("4294967296"), InvalidRRTTL); // must be 32-bit
 }
 
+void
+checkUnit(unsigned multiply, char suffix) {
+    SCOPED_TRACE(string("Unit check with suffix ") + suffix);
+    const uint32_t value = 10 * multiply;
+    const string num = "10";
+    // Check both lower and upper version of the suffix
+    EXPECT_EQ(value,
+              RRTTL(num + static_cast<char>(tolower(suffix))).getValue());
+    EXPECT_EQ(value,
+              RRTTL(num + static_cast<char>(toupper(suffix))).getValue());
+}
+
+// Check parsing the unit form (1D, etc)
+TEST_F(RRTTLTest, fromTextUnit) {
+    // Check each of the units separately
+    checkUnit(1, 'S');
+    checkUnit(60, 'M');
+    checkUnit(60 * 60, 'H');
+    checkUnit(24 * 60 * 60, 'D');
+    checkUnit(7 * 24 * 60 * 60, 'W');
+
+    // Some border cases (with units)
+    EXPECT_EQ(4294967295, RRTTL("4294967295S").getValue());
+    EXPECT_EQ(0, RRTTL("0W0D0H0M0S").getValue());
+    EXPECT_EQ(4294967295, RRTTL("1193046H1695S").getValue());
+    // Leading zeroes are accepted
+    EXPECT_EQ(4294967295, RRTTL("0000000000000004294967295S").getValue());
+
+    // Now some compound ones. We allow any order (it would be much work to
+    // check the order anyway).
+    EXPECT_EQ(60 * 60 + 3, RRTTL("1H3S").getValue());
+
+    // Awkward, but allowed case - the same unit used twice.
+    EXPECT_EQ(20 * 3600, RRTTL("12H8H").getValue());
+
+    // Negative number in part of the expression, but the total is positive.
+    // Rejected.
+    EXPECT_THROW(RRTTL("-1S1H"), InvalidRRTTL);
+
+    // Some things out of range in the ttl, but it wraps to number in range
+    // in int64_t. Should still not get fooled and reject it.
+
+    // First part out of range
+    EXPECT_THROW(RRTTL("9223372036854775807S9223372036854775807S2S"),
+                 InvalidRRTTL);
+    // Second part out of range, but it immediately wraps (2S+2^64-2S)
+    EXPECT_THROW(RRTTL("2S18446744073709551614S"), InvalidRRTTL);
+    // The whole thing wraps right away (2^64S)
+    EXPECT_THROW(RRTTL("18446744073709551616S"), InvalidRRTTL);
+    // Second part out of range, and will become negative with the unit,
+    EXPECT_THROW(RRTTL("256S307445734561825856M"), InvalidRRTTL);
+
+    // Missing before unit.
+    EXPECT_THROW(RRTTL("W5H"), InvalidRRTTL);
+    EXPECT_THROW(RRTTL("5hW"), InvalidRRTTL);
+
+    // Empty string is not allowed
+    EXPECT_THROW(RRTTL(""), InvalidRRTTL);
+    // Missing the last unit is not allowed
+    EXPECT_THROW(RRTTL("3D5"), InvalidRRTTL);
+
+    // There are some wrong units
+    EXPECT_THROW(RRTTL("13X"), InvalidRRTTL);
+    EXPECT_THROW(RRTTL("3D5F"), InvalidRRTTL);
+}
+
 TEST_F(RRTTLTest, fromWire) {
     EXPECT_EQ(0x12345678,
               rrttlFactoryFromWire("rrcode32_fromWire1").getValue());