Parcourir la source

[2096] [2097] make sure RdataReader::getSize work w/o using a ctor parameter.

and, since it was the only purpose of the 'size' param of the constructor,
it's now removed.
test cases are adjusted accordingly.
JINMEI Tatuya il y a 12 ans
Parent
commit
d2a9f546a9

+ 37 - 2
src/lib/datasrc/memory/rdata_reader.cc

@@ -50,13 +50,12 @@ RdataReader::Result::Result(const uint8_t* data, size_t size) :
 {}
 
 RdataReader::RdataReader(const RRClass& rrclass, const RRType& rrtype,
-                         size_t size, const uint8_t* data,
+                         const uint8_t* data,
                          size_t rdata_count, size_t sig_count,
                          const NameAction& name_action,
                          const DataAction& data_action) :
     name_action_(name_action),
     data_action_(data_action),
-    size_(size),
     spec_(getRdataEncodeSpec(rrclass, rrtype)),
     var_count_total_(spec_.varlen_count * rdata_count),
     sig_count_(sig_count),
@@ -145,6 +144,42 @@ RdataReader::nextSig() {
     }
 }
 
+size_t
+RdataReader::getSize() const {
+    size_t storage_size = 0;    // this will be the end result
+    size_t data_pos = 0;
+    size_t length_pos = 0;
+
+    // Go over all data fields, adding their lengths to storage_size
+    for (size_t spec_pos = 0; spec_pos < spec_count_; ++spec_pos) {
+        const RdataFieldSpec& spec =
+            spec_.fields[spec_pos % spec_.field_count];
+        if (spec.type == RdataFieldSpec::DOMAIN_NAME) {
+            const size_t seq_len =
+                LabelSequence(data_ + data_pos).getSerializedLength();
+            data_pos += seq_len;
+            storage_size += seq_len;
+        } else {
+            const size_t data_len =
+                (spec.type == RdataFieldSpec::FIXEDLEN_DATA ?
+                 spec.fixeddata_len : lengths_[length_pos++]);
+            data_pos += data_len;
+            storage_size += data_len;
+        }
+    }
+    // Same for all RRSIG data
+    for (size_t sig_pos = 0; sig_pos < sig_count_; ++sig_pos) {
+        const size_t sig_data_len = lengths_[length_pos++];
+        storage_size += sig_data_len;
+    }
+
+    // Finally, add the size for 16-bit length fields
+    storage_size += (var_count_total_ * sizeof(uint16_t) +
+                     sig_count_ * sizeof(uint16_t));
+
+    return (storage_size);
+}
+
 }
 }
 }

+ 13 - 6
src/lib/datasrc/memory/rdata_reader.h

@@ -110,15 +110,13 @@ public:
     ///
     /// \param rrclass The class the encoded rdata belongs to.
     /// \param rrtype The type of the encode rdata.
-    /// \param size Number of bytes the data have in serialized form.
     /// \param data The actual data.
     /// \param rdata_count The number of Rdata encoded in the data.
     /// \param sig_count The number of RRSig rdata bundled with the data.
     /// \param name_action The callback to be called on each encountered name.
     /// \param data_action The callback to be called on each data chunk.
     RdataReader(const dns::RRClass& rrclass, const dns::RRType& rrtype,
-                size_t size, const uint8_t* data,
-                size_t rdata_count, size_t sig_count,
+                const uint8_t* data, size_t rdata_count, size_t sig_count,
                 const NameAction& name_action = &emptyNameAction,
                 const DataAction& data_action = &emptyDataAction);
 
@@ -250,12 +248,21 @@ public:
 
     /// \brief Returns the size of associated data.
     ///
-    /// This just returns whatever was passed to the constructor as size.
-    size_t getSize() const { return (size_); }
+    /// This should be the same as the return value of
+    /// RdataEncoder::getStorageLength() for the same set of data.
+    /// The intended use of this method is to tell the caller the size of
+    /// data that were possibly dynamically allocated so that the caller can
+    /// use it for deallocation.
+    ///
+    /// This method only uses the parameters given at the construction of the
+    /// object, and does not rely on or modify other mutable states.
+    /// In practice, when the caller wants to call this method, that would be
+    /// the only purpose of that RdataReader object (although it doesn't have
+    /// to be so).
+    size_t getSize() const;
 private:
     const NameAction name_action_;
     const DataAction data_action_;
-    const size_t size_;
     const RdataEncodeSpec& spec_;
     // Total number of var-length fields, count of signatures
     const size_t var_count_total_, sig_count_, spec_count_;

+ 21 - 19
src/lib/datasrc/memory/tests/rdata_serialization_unittest.cc

@@ -268,7 +268,7 @@ public:
                        size_t expected_varlen_fields,
                        // Warning: this test actualy might change the
                        // encoded_data !
-                       vector<uint8_t>& encoded_data,
+                       vector<uint8_t>& encoded_data, size_t,
                        MessageRenderer& renderer)
     {
         // If this type of RDATA is expected to contain variable-length fields,
@@ -319,11 +319,11 @@ public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
                        size_t rdata_count, size_t sig_count, size_t,
-                       const vector<uint8_t>& encoded_data,
+                       const vector<uint8_t>& encoded_data, size_t,
                        MessageRenderer& renderer)
     {
-        RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0], rdata_count, sig_count);
+        RdataReader reader(rrclass, rrtype, &encoded_data[0], rdata_count,
+                           sig_count);
         RdataReader::Result field;
         while ((field = reader.next())) {
             switch (field.type()) {
@@ -359,11 +359,11 @@ public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
                        size_t rdata_count, size_t sig_count, size_t,
-                       const vector<uint8_t>& encoded_data,
+                       const vector<uint8_t>& encoded_data, size_t,
                        MessageRenderer& renderer)
     {
-        RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0], rdata_count, sig_count);
+        RdataReader reader(rrclass, rrtype, &encoded_data[0], rdata_count,
+                           sig_count);
         // Use the reader first and rewind it
         reader.iterateSig();
         reader.iterate();
@@ -402,11 +402,11 @@ public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
                        size_t rdata_count, size_t sig_count, size_t,
-                       const vector<uint8_t>& encoded_data,
+                       const vector<uint8_t>& encoded_data, size_t,
                        MessageRenderer& renderer)
     {
-        RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0], rdata_count, sig_count,
+        RdataReader reader(rrclass, rrtype, &encoded_data[0], rdata_count,
+                           sig_count,
                            boost::bind(renderNameField, &renderer,
                                        additionalRequired(rrtype), _1, _2),
                            boost::bind(renderDataField, &renderer, _1, _2));
@@ -422,11 +422,11 @@ public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
                        size_t rdata_count, size_t sig_count, size_t,
-                       const vector<uint8_t>& encoded_data,
+                       const vector<uint8_t>& encoded_data, size_t,
                        MessageRenderer& renderer)
     {
-        RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0], rdata_count, sig_count,
+        RdataReader reader(rrclass, rrtype, &encoded_data[0],
+                           rdata_count, sig_count,
                            boost::bind(renderNameField, &renderer,
                                        additionalRequired(rrtype), _1, _2),
                            boost::bind(renderDataField, &renderer, _1, _2));
@@ -460,17 +460,18 @@ public:
                        const isc::dns::RRType& rrtype,
                        size_t rdata_count, size_t sig_count, size_t,
                        const vector<uint8_t>& encoded_data,
+                       size_t encoded_data_len,
                        MessageRenderer& renderer)
     {
         vector<uint8_t> data;
         MessageRenderer* current;
-        RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0], rdata_count, sig_count,
+        RdataReader reader(rrclass, rrtype, &encoded_data[0],
+                           rdata_count, sig_count,
                            boost::bind(renderNameField, &renderer,
                                        additionalRequired(rrtype), _1, _2),
                            boost::bind(appendData, &data, &current, _1, _2));
         // The size matches
-        EXPECT_EQ(encoded_data.size(), reader.getSize());
+        EXPECT_EQ(encoded_data_len, reader.getSize());
         if (start_sig) {
             current = NULL;
             reader.nextSig();
@@ -496,7 +497,7 @@ public:
         renderer.writeName(dummy_name2);
         renderer.writeData(&data[0], data.size());
         // The size matches even after use
-        EXPECT_EQ(encoded_data.size(), reader.getSize());
+        EXPECT_EQ(encoded_data_len, reader.getSize());
     }
 };
 
@@ -566,10 +567,11 @@ checkEncode(RRClass rrclass, RRType rrtype,
     BOOST_FOREACH(const ConstRdataPtr& rdata, rrsig_list) {
         encoder_.addSIGRdata(*rdata);
     }
-    encodeWrapper(encoder_.getStorageLength());
+    const size_t storage_len = encoder_.getStorageLength();
+    encodeWrapper(storage_len);
 
     DecoderStyle::decode(rrclass, rrtype, rdata_list.size(), rrsig_list.size(),
-                         expected_varlen_fields, encoded_data_,
+                         expected_varlen_fields, encoded_data_, storage_len,
                          actual_renderer_);
 
     // Two sets of wire-format data should be identical.