Browse Source

[3010] Avoid signed integer overflow

Shane Kerr 11 years ago
parent
commit
2d742ae954
1 changed files with 28 additions and 16 deletions
  1. 28 16
      src/lib/dns/rrttl.cc

+ 28 - 16
src/lib/dns/rrttl.cc

@@ -42,14 +42,15 @@ myIsalpha(char c) {
 struct Unit {
     char unit;
     uint32_t multiply;
+    uint32_t max_allowed;
 };
 
 Unit units[] = {
-    { 'S', 1 },
-    { 'M', 60 },
-    { 'H', 60 * 60 },
-    { 'D', 24 * 60 * 60 },
-    { 'W', 7 * 24 * 60 * 60 }
+    { 'S', 1,                0xffffffff / 1 },
+    { 'M', 60,               0xffffffff / 60 },
+    { 'H', 60 * 60,          0xffffffff / (60 * 60) },
+    { 'D', 24 * 60 * 60,     0xffffffff / (24 * 60 * 60) },
+    { 'W', 7 * 24 * 60 * 60, 0xffffffff / (7 * 24 * 60 * 60) }
 };
 
 }
@@ -66,11 +67,9 @@ parseTTLString(const string& ttlstr, uint32_t& ttlval, string* error_txt) {
         }
         return (false);
     }
-    // We use a larger data type during the computation. This is because
-    // some compilers don't fail when out of range, so we check the range
-    // ourselves later.
-    int64_t val = 0;
 
+    // We use a larger data type to handle negative number cases.
+    uint64_t val = 0;
     const string::const_iterator end = ttlstr.end();
     string::const_iterator pos = ttlstr.begin();
 
@@ -92,7 +91,7 @@ parseTTLString(const string& ttlstr, uint32_t& ttlval, string* error_txt) {
                 } else {
                     // Case without any units at all. Just convert and store
                     // it.
-                    val = boost::lexical_cast<int64_t>(ttlstr);
+                    val = boost::lexical_cast<uint64_t>(ttlstr);
                     break;
                 }
             }
@@ -100,11 +99,13 @@ parseTTLString(const string& ttlstr, uint32_t& ttlval, string* error_txt) {
             units_mode = true;
             // Find the unit and get the size.
             uint32_t multiply = 1;  // initialize to silence compiler warnings
+            uint32_t max_allowed = 0xffffffff;
             bool found = false;
             for (size_t i = 0; i < sizeof(units) / sizeof(*units); ++i) {
                 if (toupper(*unit) == units[i].unit) {
                     found = true;
                     multiply = units[i].multiply;
+                    max_allowed = units[i].max_allowed;
                     break;
                 }
             }
@@ -122,15 +123,25 @@ parseTTLString(const string& ttlstr, uint32_t& ttlval, string* error_txt) {
                 }
                 return (false);
             }
-            const int64_t value = boost::lexical_cast<int64_t>(string(pos,
-                                                                      unit));
+            const uint64_t value =
+                boost::lexical_cast<uint64_t>(string(pos, unit));
+            if (value > max_allowed) {
+                if (error_txt != NULL) {
+                    *error_txt = "Part of TTL out of range: "  + ttlstr;
+                }
+                return (false);
+            }
+
+            // seconds cannot be out of range at this point.
+            const uint64_t seconds = value * multiply;
+            assert(seconds <= 0xffffffff);
+
             // Add what we found
-            val += multiply * value;
+            val += seconds;
             // Check the partial value is still in range (the value can only
             // grow, so if we get out of range now, it won't get better, so
             // there's no need to continue).
-            if (value < 0 || value > 0xffffffff || val < 0 ||
-                val > 0xffffffff) {
+            if (val < seconds || val > 0xffffffff) {
                 if (error_txt != NULL) {
                     *error_txt = "Part of TTL out of range: "  + ttlstr;
                 }
@@ -146,9 +157,10 @@ parseTTLString(const string& ttlstr, uint32_t& ttlval, string* error_txt) {
         return (false);
     }
 
-    if (val >= 0 && val <= 0xffffffff) {
+    if (val <= 0xffffffff) {
         ttlval = val;
     } else {
+        // This could be due to negative numbers in input, etc.
         if (error_txt != NULL) {
             *error_txt = "TTL out of range: " + ttlstr;
         }