Browse Source

[2376] Disallow missing last unit

Also, simplify the code little bit and make it run faster in the case of
TTL without any units.
Michal 'vorner' Vaner 12 years ago
parent
commit
1a23885959
1 changed files with 30 additions and 20 deletions
  1. 30 20
      src/lib/dns/rrttl.cc

+ 30 - 20
src/lib/dns/rrttl.cc

@@ -69,27 +69,41 @@ RRTTL::RRTTL(const std::string& ttlstr) {
     const string::const_iterator end = ttlstr.end();
     string::const_iterator pos = ttlstr.begin();
 
+    // When we detect we have some units
+    bool units_mode = false;
+
     try {
         while (pos != end) {
             // Find the first unit, if there's any.
             const string::const_iterator unit = find_if(pos, end, myIsalpha);
-            // Default multiplication if no unit.
-            uint32_t multiply = 1;
-            if (unit != end) {
-                // Find the unit and get the size.
-                bool found = false;
-                for (size_t i = 0; i < sizeof(units) / sizeof(*units); ++i) {
-                    if (toupper(*unit) == units[i].unit) {
-                        found = true;
-                        multiply = units[i].multiply;
-                        break;
-                    }
+            // No unit
+            if (unit == end) {
+                if (units_mode) {
+                    // We had some units before. The last one is missing unit.
+                    isc_throw(InvalidRRTTL, "Missing the last unit: " <<
+                              ttlstr);
+                } else {
+                    // Case without any units at all. Just convert and store it.
+                    val = boost::lexical_cast<int64_t>(ttlstr);
+                    break;
                 }
-                if (!found) {
-                    isc_throw(InvalidRRTTL, "Unknown unit used: " << *unit <<
-                              " in: " << ttlstr);
+            }
+            // There's a unit now.
+            units_mode = true;
+            // Find the unit and get the size.
+            uint32_t multiply;
+            bool found = false;
+            for (size_t i = 0; i < sizeof(units) / sizeof(*units); ++i) {
+                if (toupper(*unit) == units[i].unit) {
+                    found = true;
+                    multiply = units[i].multiply;
+                    break;
                 }
             }
+            if (!found) {
+                isc_throw(InvalidRRTTL, "Unknown unit used: " << *unit <<
+                          " in: " << ttlstr);
+            }
             // Now extract the number.
             if (unit == pos) {
                 isc_throw(InvalidRRTTL, "Missing number in TTL: " << ttlstr);
@@ -105,12 +119,8 @@ RRTTL::RRTTL(const std::string& ttlstr) {
                 isc_throw(InvalidRRTTL, "Part of TTL out of range: " <<
                           ttlstr);
             }
-            // Move to after the unit (if any). But make sure not to increment
-            // past end, which is, strictly speaking, illegal.
-            pos = unit;
-            if (pos != end) {
-                ++pos;
-            }
+            // Move to after the unit.
+            pos = unit + 1;
         }
     } catch (const boost::bad_lexical_cast&) {
         isc_throw(InvalidRRTTL, "invalid TTL: " << ttlstr);