Browse Source

[2096] Implementation of the reader

The tests don't pass, for some reason, but it seems there are less
failures than before. There would probably be some off-by-something
error somewhere.
Michal 'vorner' Vaner 12 years ago
parent
commit
fdd121bda3

+ 104 - 3
src/lib/datasrc/memory/rdata_reader.cc

@@ -14,12 +14,14 @@
 
 #include "rdata_reader.h"
 
+using namespace isc::dns;
+
 namespace isc {
 namespace datasrc {
 namespace memory {
 
 void
-RdataReader::emptyNameAction(const dns::LabelSequence&, unsigned) {
+RdataReader::emptyNameAction(const LabelSequence&, unsigned) {
     // Do nothing here. On purpose, it's not unfinished.
 }
 
@@ -28,7 +30,7 @@ RdataReader::emptyDataAction(const uint8_t*, size_t) {
     // Do nothing here. On purpose, it's not unfinished.
 }
 
-RdataReader::Result::Result(const dns::LabelSequence& label,
+RdataReader::Result::Result(const LabelSequence& label,
                             unsigned attributes) :
     label_(label),
     data_(NULL),
@@ -39,7 +41,7 @@ RdataReader::Result::Result(const dns::LabelSequence& label,
 {}
 
 RdataReader::Result::Result(const uint8_t* data, size_t size) :
-    label_(dns::Name::ROOT_NAME()),
+    label_(Name::ROOT_NAME()),
     data_(data),
     size_(size),
     type_(DATA),
@@ -47,6 +49,105 @@ RdataReader::Result::Result(const uint8_t* data, size_t size) :
     additional_(false)
 {}
 
+const uint8_t*
+RdataReader::findSigs() const {
+    // Validate the lengths - we want to make sure all the lengths
+    // fit. The base would be the beginning of all the data.
+    const uint8_t* const
+        base(static_cast<const uint8_t*>(static_cast<const void*>(lengths_)));
+    if (base + size_ < data_) {
+        isc_throw(isc::BadValue, "Size won't even hold all the length fields");
+    }
+    // It is easier to look from the end than from the beginning -
+    // this way we just need to sum all the sig lenghs together.
+    size_t sum(0);
+    for (size_t i(0); i < sig_count_; ++ i) {
+        sum += lengths_[var_count_total_ + i];
+    }
+    const uint8_t* const result(data_ + size_ - sum);
+    // Validate the signatures fit.
+    if (result < data_) {
+        isc_throw(isc::BadValue, "Size won't even hold all the RRSigs");
+    }
+    return (result);
+}
+
+RdataReader::RdataReader(const RRClass& rrclass, const RRType& rrtype,
+                         size_t size, const uint8_t* data,
+                         size_t rdata_count, size_t sig_count,
+                         const NameAction& name_action,
+                         const DataAction& data_action) :
+    name_action_(name_action),
+    data_action_(data_action),
+    size_(size),
+    spec_(getRdataEncodeSpec(rrclass, rrtype)),
+    var_count_total_(spec_.varlen_count * rdata_count),
+    sig_count_(sig_count),
+    spec_count_(spec_.field_count * rdata_count),
+    // The casts, well, C++ decided it doesn't like completely valid
+    // and explicitly allowed cast in C, so we need to fool it through
+    // void.
+    lengths_(static_cast<const uint16_t*>(
+             static_cast<const void*>(data))), // The lenghts are stored first
+    // And the data just after all the lengths
+    data_(data + (var_count_total_ + sig_count_) * sizeof(uint16_t)),
+    sigs_(findSigs())
+{
+    rewind();
+}
+
+void
+RdataReader::rewind() {
+    data_pos_ = 0;
+    spec_pos_ = 0;
+    length_pos_ = 0;
+    sig_data_pos_ = 0;
+    sig_pos_ = 0;
+}
+
+RdataReader::Result
+RdataReader::next() {
+    // TODO: Add checks we are in the valid part of memory
+    if (spec_pos_ < spec_count_) {
+        const RdataFieldSpec& spec(spec_.fields[(spec_pos_ ++) %
+                                                spec_.field_count]);
+        if (spec.type == RdataFieldSpec::DOMAIN_NAME) {
+            const LabelSequence sequence(&data_[data_pos_]);
+            data_pos_ += sequence.getSerializedLength();
+            name_action_(sequence, spec.name_attributes);
+            return (Result(sequence, spec.name_attributes));
+        } else {
+            const size_t length(spec.type == RdataFieldSpec::FIXEDLEN_DATA ?
+                                spec.fixeddata_len : lengths_[length_pos_ ++]);
+            Result result(&data_[data_pos_], length);
+            data_pos_ += length;
+            data_action_(result.data(), result.size());
+            return (result);
+        }
+    } else {
+        return (Result());
+    }
+}
+
+RdataReader::Result
+RdataReader::nextSig() {
+    // We ensured the whole block of signatures is in the valid block,
+    // so we don't need any checking here.
+    if (sig_pos_ < sig_count_) {
+        // Extract the result
+        Result result(sigs_ + sig_data_pos_, lengths_[var_count_total_ +
+                      sig_pos_]);
+        // Move the position of iterator.
+        sig_data_pos_ += lengths_[var_count_total_ + sig_pos_];
+        sig_pos_ ++;
+        // Call the callback
+        data_action_(result.data(), result.size());
+        return (result);
+    } else {
+        return (Result());
+    }
+}
+
 }
 }
 }

+ 27 - 0
src/lib/datasrc/memory/rdata_reader.h

@@ -77,6 +77,12 @@ namespace memory {
 ///     }
 /// }
 /// \endcode
+///
+/// \note It is caller's responsibility to pass valid data here. This means
+///     the data returned by RdataEncoder and the corresponding class and type.
+///     The reader tries to detect some of the inconsistencies, but it's
+///     not perfect. Also, if anything throws, it is improper use of the class
+///     and the class might be in inconsistent or unknown state.
 class RdataReader {
 public:
     /// \brief Function called on each name encountered in the data.
@@ -108,10 +114,14 @@ public:
     /// \param rrtype The type of the encode rdata.
     /// \param size Number of bytes the data have in serialized form.
     /// \param data The actual data.
+    /// \param rdata_count The number of Rdata encoded in the data.
+    /// \param sig_count The number of RRSig rdata bundled with the data.
     /// \param name_action The callback to be called on each encountered name.
     /// \param data_action The callback to be called on each data chunk.
+    /// \throw isc::BadValue if mismatch of size and counts is detected.
     RdataReader(const dns::RRClass& rrclass, const dns::RRType& rrtype,
                 size_t size, const uint8_t* data,
+                size_t rdata_count, size_t sig_count,
                 const NameAction& name_action = &emptyNameAction,
                 const DataAction& data_action = &emptyDataAction);
 
@@ -245,6 +255,23 @@ public:
     ///
     /// This just returns whatever was passed to the constructor as size.
     size_t getSize() const;
+private:
+    const NameAction name_action_;
+    const DataAction data_action_;
+    const size_t size_;
+    const RdataEncodeSpec& spec_;
+    // Total number of var-length fields, count of signatures
+    const size_t var_count_total_, sig_count_, spec_count_;
+    // Pointer to the beginning of length fields
+    const uint16_t* const lengths_;
+    // Pointer to the beginning of the data (after the lengths)
+    const uint8_t* const data_;
+    // Pointer to the first data signature
+    const uint8_t* const sigs_;
+    const uint8_t* findSigs() const;
+    // The positions in data.
+    size_t data_pos_, spec_pos_, length_pos_;
+    size_t sig_pos_, sig_data_pos_;
 };
 
 }

+ 6 - 6
src/lib/datasrc/memory/tests/rdata_serialization_unittest.cc

@@ -317,12 +317,12 @@ class NextDecoder {
 public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
-                       size_t, size_t, size_t,
+                       size_t rdata_count, size_t sig_count, size_t,
                        const vector<uint8_t>& encoded_data,
                        MessageRenderer& renderer)
     {
         RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0]);
+                           &encoded_data[0], rdata_count, sig_count);
         RdataReader::Result field;
         while (field = reader.next()) {
             switch (field.type()) {
@@ -356,12 +356,12 @@ class CallbackDecoder {
 public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
-                       size_t, size_t, size_t,
+                       size_t rdata_count, size_t sig_count, size_t,
                        const vector<uint8_t>& encoded_data,
                        MessageRenderer& renderer)
     {
         RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0],
+                           &encoded_data[0], rdata_count, sig_count,
                            boost::bind(renderNameField, &renderer,
                                        additionalRequired(rrtype), _1, _2),
                            boost::bind(renderDataField, &renderer, _1, _2));
@@ -376,12 +376,12 @@ class IterateDecoder {
 public:
     static void decode(const isc::dns::RRClass& rrclass,
                        const isc::dns::RRType& rrtype,
-                       size_t, size_t, size_t,
+                       size_t rdata_count, size_t sig_count, size_t,
                        const vector<uint8_t>& encoded_data,
                        MessageRenderer& renderer)
     {
         RdataReader reader(rrclass, rrtype, encoded_data.size(),
-                           &encoded_data[0],
+                           &encoded_data[0], rdata_count, sig_count,
                            boost::bind(renderNameField, &renderer,
                                        additionalRequired(rrtype), _1, _2),
                            boost::bind(renderDataField, &renderer, _1, _2));