Browse Source

The first implementation of address selection logic

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac356@3671 e5f2f494-b856-4b98-b285-d166d9295462
Haidong Wang 14 years ago
parent
commit
1ae6d476f4

+ 14 - 9
src/lib/nsas/nameserver_address.h

@@ -56,22 +56,30 @@ public:
     /// the shared_ptr can avoid the NameserverEntry object being dropped while the
     /// request is processing.
     /// \param index The address's index in NameserverEntry's addresses vector
-    NameserverAddress(boost::shared_ptr<NameserverEntry>& nameserver, uint32_t index):
-        ns_(nameserver), index_(index)
+    /// \param family Address family, AF_INET or AF_INET6
+    NameserverAddress(boost::shared_ptr<NameserverEntry>& nameserver, uint32_t index, short family):
+        ns_(nameserver), index_(index), family_(family)
     {
         if(!ns_.get()) isc_throw(NullNameserverEntryPointer, "NULL NameserverEntry pointer.");
     }
 
+    /// \brief Default Constructor
+    ///
+    NameserverAddress(): index_(0), family_(AF_INET)
+    {
+    }
+
     /// \brief Destructor
     ///
     /// Empty destructor.
     ~NameserverAddress()
-    {}
+    {
+    }
 
     /// \brief Return address
     ///
     asiolink::IOAddress getAddress() const { 
-        return ns_.get()->getAddressAtIndex(index_); 
+        return ns_.get()->getAddressAtIndex(index_, family_); 
     }
 
     /// \brief Update Round-trip Time
@@ -80,16 +88,13 @@ public:
     /// update the address's RTT.
     /// \param rtt The new Round-Trip Time
     void updateRTT(uint32_t rtt) { 
-        ns_.get()->updateAddressRTTAtIndex(rtt, index_); 
+        ns_.get()->updateAddressRTTAtIndex(rtt, index_, family_); 
     }
 private:
-    /// \brief Default Constructor
-    ///
-    /// A private default constructor to avoid creating an empty object.
-    NameserverAddress();
 
     boost::shared_ptr<NameserverEntry> ns_;  ///< Shared-pointer to NameserverEntry object
     uint32_t index_;                         ///< The address index in NameserverEntry
+    short family_;                           ///< Address family AF_INET or AF_INET6
 };
 
 } // namespace nsas

+ 109 - 22
src/lib/nsas/nameserver_entry.cc

@@ -30,6 +30,7 @@
 #include "rrttl.h"
 
 #include "address_entry.h"
+#include "nameserver_address.h"
 #include "nameserver_entry.h"
 
 using namespace asiolink;
@@ -40,19 +41,12 @@ using namespace std;
 namespace isc {
 namespace nsas {
 
-// Generate a small random RTT when initialize the list of addresses
-// to select all the addresses in unpredicable order
-// The initia RTT is between 0ms and 7ms which is as the same as bind9
-#define MIN_INIT_RTT 0
-#define MAX_INIT_RTT 7
-UniformRandomIntegerGenerator NameserverEntry::rndRttGen_(MIN_INIT_RTT, MAX_INIT_RTT);
-
 // Constructor, initialized with the list of addresses associated with this
 // nameserver.
 NameserverEntry::NameserverEntry(const AbstractRRset* v4Set,
     const AbstractRRset* v6Set, time_t curtime) : expiration_(0)
 {
-    uint32_t rtt = 0;       // Round-trip time for an address
+    uint32_t rtt = 1;       // Round-trip time for an address
     string v4name = "";     // Name from the V4 RRset
     string v6name = "";     // Name from the v6 RRset
     uint16_t v4class = 0;   // Class of V4 RRset
@@ -74,8 +68,8 @@ NameserverEntry::NameserverEntry(const AbstractRRset* v4Set,
         RdataIteratorPtr i = v4Set->getRdataIterator();
         i->first();
         while (! i->isLast()) {
-            address_.push_back(AddressEntry(IOAddress(i->getCurrent().toText()),
-            ++rtt));
+            v4_addresses_.push_back(AddressEntry(IOAddress(i->getCurrent().toText()),
+            rtt));
             i->next();
         }
 
@@ -83,6 +77,9 @@ NameserverEntry::NameserverEntry(const AbstractRRset* v4Set,
         expiration_ = curtime + v4Set->getTTL().getValue();
         v4name = v4Set->getName().toText(false);    // Ensure trailing dot
         v4class = v4Set->getClass().getCode();
+
+        // Update the address selector
+        updateAddressSelector(v4_addresses_, v4_address_selector_);
     }
 
     // Now the v6 addresses
@@ -91,8 +88,8 @@ NameserverEntry::NameserverEntry(const AbstractRRset* v4Set,
         RdataIteratorPtr i = v6Set->getRdataIterator();
         i->first();
         while (! i->isLast()) {
-            address_.push_back(AddressEntry(IOAddress(i->getCurrent().toText()),
-            ++rtt));
+            v6_addresses_.push_back(AddressEntry(IOAddress(i->getCurrent().toText()),
+            rtt));
             i->next();
         }
 
@@ -108,6 +105,9 @@ NameserverEntry::NameserverEntry(const AbstractRRset* v4Set,
         // Extract the name of the v6 set and its class
         v6name = v6Set->getName().toText(false);    // Ensure trailing dot
         v6class = v6Set->getClass().getCode();
+
+        // Update the address selector
+        updateAddressSelector(v6_addresses_, v6_address_selector_);
     }
 
     // TODO: Log a problem if both V4 and V6 address were null.
@@ -144,35 +144,94 @@ void NameserverEntry::getAddresses(AddressVector& addresses, short family) const
     // Now copy all entries that meet the criteria.  Since remove_copy_if
     // does the inverse (copies all entries that do not meet the criteria),
     // the predicate for address selection is negated.
-    remove_copy_if(address_.begin(), address_.end(), back_inserter(addresses),
+    remove_copy_if(v4_addresses_.begin(), v4_addresses_.end(), back_inserter(addresses),
+        bind1st(AddressSelection(), family));
+    remove_copy_if(v6_addresses_.begin(), v6_addresses_.end(), back_inserter(addresses),
         bind1st(AddressSelection(), family));
 }
 
-asiolink::IOAddress NameserverEntry::getAddressAtIndex(uint32_t index) const
+// Return one address matching the given family
+bool NameserverEntry::getAddress(boost::shared_ptr<NameserverEntry>& nameserver, 
+        NameserverAddress& address, short family)
+{
+
+    // The shared_ptr must contain this pointer
+    assert(nameserver.get() == this);
+
+    if(family == AF_INET){
+        if(v4_addresses_.size() == 0) return false;
+
+        address = NameserverAddress(nameserver, v4_address_selector_(), AF_INET);
+        return true;
+    } else if(family == AF_INET6){
+        if(v6_addresses_.size() == 0) return false;
+
+        address = NameserverAddress(nameserver, v6_address_selector_(), AF_INET6);
+        return true;
+    }
+    return false;
+}
+
+// Return the address corresponding to the family
+asiolink::IOAddress NameserverEntry::getAddressAtIndex(uint32_t index, short family) const
 {
-    assert(index < address_.size());
+    const vector<AddressEntry> *addresses = &v4_addresses_;
+    if(family == AF_INET6){
+        addresses = &v6_addresses_;
+    }
+    assert(index < addresses->size());
 
-    return address_[index].getAddress();
+    return (*addresses)[index].getAddress();
 }
 
 // Set the address RTT to a specific value
 void NameserverEntry::setAddressRTT(const IOAddress& address, uint32_t rtt) {
 
     // Search through the list of addresses for a match
-    for (AddressVectorIterator i = address_.begin(); i != address_.end(); ++i) {
+    for (AddressVectorIterator i = v4_addresses_.begin(); i != v4_addresses_.end(); ++i) {
+        if (i->getAddress().equal(address)) {
+            i->setRTT(rtt);
+
+            // Update the selector
+            updateAddressSelector(v4_addresses_, v4_address_selector_);
+            return;
+        }
+    }
+
+    // Search the v6 list
+    for (AddressVectorIterator i = v6_addresses_.begin(); i != v6_addresses_.end(); ++i) {
         if (i->getAddress().equal(address)) {
             i->setRTT(rtt);
+
+            // Update the selector
+            updateAddressSelector(v6_addresses_, v6_address_selector_);
+            return;
         }
     }
 }
 
 // Update the address's rtt 
-void NameserverEntry::updateAddressRTTAtIndex(uint32_t rtt, uint32_t index) {
-    //make sure it is a valid index
-    if(index >= address_.size()) return;
+#define UPDATE_RTT_ALPHA 0.7
+void NameserverEntry::updateAddressRTTAtIndex(uint32_t rtt, uint32_t index, short family) {
+    vector<AddressEntry>* addresses = &v4_addresses_;
+    if(family == AF_INET6){
+        addresses = &v6_addresses_;
+    }
 
-    //update the rtt
-    address_[index].setRTT(rtt);
+    //make sure it is a valid index
+    if(index >= addresses->size()) return;
+
+    // Smoothly update the rtt
+    // The algorithm is as the same as bind8/bind9:
+    //    new_rtt = old_rtt * alpha + new_rtt * (1 - alpha), where alpha is a float number in [0, 1.0]
+    // The default value for alpha is 0.7
+    uint32_t old_rtt = (*addresses)[index].getRTT();
+    uint32_t new_rtt = (int)(old_rtt * UPDATE_RTT_ALPHA + rtt * (1 - UPDATE_RTT_ALPHA));
+    (*addresses)[index].setRTT(new_rtt);
+
+    // Update the selector
+    if(family == AF_INET) updateAddressSelector(v4_addresses_, v4_address_selector_);
+    else if(family == AF_INET6) updateAddressSelector(v6_addresses_, v6_address_selector_);
 }
 
 // Sets the address to be unreachable
@@ -180,5 +239,33 @@ void NameserverEntry::setAddressUnreachable(const IOAddress& address) {
     setAddressRTT(address, AddressEntry::UNREACHABLE);
 }
 
+// Update the address selector according to the RTTs
+//
+// Each address has a probability to be selected if multiple addresses are available
+// The weight factor is equal to 1/(rtt*rtt), then all the weight factors are normalized
+// to make the sum equal to 1.0
+void NameserverEntry::updateAddressSelector(const std::vector<AddressEntry>& addresses, 
+        WeightedRandomIntegerGenerator& selector)
+{
+    vector<double> probabilities;
+    for(vector<AddressEntry>::const_iterator it = addresses.begin(); 
+            it != addresses.end(); ++it){
+        uint32_t rtt = (*it).getRTT();
+        if(rtt == 0) isc_throw(RTTIsZero, "The RTT is 0");
+
+        probabilities.push_back(1.0/(rtt*rtt));
+    }
+    // Calculate the sum
+    double sum = accumulate(probabilities.begin(), probabilities.end(), 0.0);
+
+    // Normalize the probabilities to make the sum equal to 1.0
+    for(vector<double>::iterator it = probabilities.begin(); 
+            it != probabilities.end(); ++it){
+        (*it) /= sum;
+    }
+
+    selector.reset(probabilities);
+}
+
 } // namespace dns
 } // namespace isc

+ 43 - 9
src/lib/nsas/nameserver_entry.h

@@ -33,6 +33,8 @@
 namespace isc {
 namespace nsas {
 
+class NameserverAddress;
+
 /// \brief Inconsistent Owner Names
 ///
 /// Thrown if a NameserverEntry is constructed from both an A and AAAA RRset
@@ -44,6 +46,16 @@ public:
     {}
 };
 
+/// \brief RTT is zero
+///
+/// Thrown if a RTT related with an address is 0.
+class RTTIsZero : public Exception {
+public:
+    RTTIsZero(const char* file, size_t line, const char* what) :
+        isc::Exception(file, line, what)
+    {}
+};
+
 /// \brief Inconsistent Class
 ///
 /// Thrown if a NameserverEntry is constructed from both an A and AAAA RRset
@@ -127,10 +139,22 @@ public:
     virtual void getAddresses(NameserverEntry::AddressVector& addresses,
         short family = 0) const;
 
+    /// \brief Return one address
+    ///
+    /// Return one address corresponding to this nameserver
+    /// \param nameserver The NamerserverEntry shared_ptr object. The NameserverAddress
+    ///        need to hold it to avoid NameserverEntry being released
+    /// \param address NameserverAddress object used to receive the address
+    /// \param family The family of user request, AF_INET or AF_INET6
+    /// \return true if one address is found, false otherwise
+    virtual bool getAddress(boost::shared_ptr<NameserverEntry>& nameserver, 
+            NameserverAddress& address, short family);
+
     /// \brief Return Address that corresponding to the index
     ///
     /// \param index The address index in the address vector
-    virtual asiolink::IOAddress getAddressAtIndex(uint32_t index) const;
+    /// \param family The address family, AF_INET or AF_INET6
+    virtual asiolink::IOAddress getAddressAtIndex(uint32_t index, short family) const;
 
     /// \brief Update RTT
     ///
@@ -144,7 +168,8 @@ public:
     ///
     /// \param rtt Round-Trip Time
     /// \param index The address's index in address vector
-    virtual void updateAddressRTTAtIndex(uint32_t rtt, uint32_t index);
+    /// \param family The address family, AF_INET or AF_INET6
+    virtual void updateAddressRTTAtIndex(uint32_t rtt, uint32_t index, short family);
 
     /// \brief Set Address Unreachable
     ///
@@ -196,13 +221,22 @@ public:
     };
 
 private:
-    boost::mutex    mutex_;                          ///< Mutex protecting this object
-    std::string     name_;                           ///< Canonical name of the nameserver
-    uint16_t        classCode_;                      ///< Class of the nameserver
-    std::vector<AddressEntry> address_;              ///< Set of V4/V6 addresses
-    time_t          expiration_;                     ///< Summary expiration time
-    time_t          last_access_;                    ///< Last access time to the structure
-    static UniformRandomIntegerGenerator rndRttGen_; ///< Small random RTT generator
+    /// \brief Update the address selector according to the RTTs of addresses
+    ///
+    /// \param addresses The address list
+    /// \param selector Weighted random generator
+    void updateAddressSelector(const std::vector<AddressEntry>& addresses, 
+            WeightedRandomIntegerGenerator& selector);
+
+    boost::mutex    mutex_;                              ///< Mutex protecting this object
+    std::string     name_;                               ///< Canonical name of the nameserver
+    uint16_t        classCode_;                          ///< Class of the nameserver
+    std::vector<AddressEntry> v4_addresses_;             ///< Set of V4 addresses
+    std::vector<AddressEntry> v6_addresses_;             ///< Set of V6 addresses
+    time_t          expiration_;                         ///< Summary expiration time
+    time_t          last_access_;                        ///< Last access time to the structure
+    WeightedRandomIntegerGenerator v4_address_selector_; ///< Generate one integer according to different probability
+    WeightedRandomIntegerGenerator v6_address_selector_; ///< Generate one integer according to different probability
 };
 
 }   // namespace dns

+ 31 - 1
src/lib/nsas/random_number_generator.h

@@ -81,6 +81,37 @@ public:
         // Init with the current time
         rng_.seed(time(NULL));
     }
+    
+    /// \brief Default constructor
+    ///
+    WeightedRandomIntegerGenerator():
+        dist_(0, 1.0), uniform_real_gen_(rng_, dist_), min_(0)
+    {
+    }
+
+    /// \brief Reset the probabilities
+    ///
+    /// Change the weights of each integers
+    /// \param probabilities The probabies for all the integers
+    /// \param min The minimum integer that generated
+    void reset(const std::vector<double>& probabilities, int min = 0)
+    {
+        // The probabilities must be valid
+        assert(isProbabilitiesValid(probabilities));
+
+        // Reset the cumulative sum
+        cumulative_.clear();
+
+        // Calculate the partial sum of probabilities
+        std::partial_sum(probabilities.begin(), probabilities.end(),
+                                     std::back_inserter(cumulative_));
+
+        // Reset the minimum integer
+        min_ = min;
+
+        // Reset the random number generator
+        rng_.seed(time(NULL));
+    }
 
     /// \brief Generate weighted random integer
     int operator()()
@@ -112,7 +143,6 @@ private:
             sum += *it;
         }
 
-        std::cout << sum << " " << (sum == 1.0) << std::endl;
         double epsilon = 0.0001;
         // The sum must be equal to 1
         return fabs(sum - 1) < epsilon;

+ 8 - 6
src/lib/nsas/tests/nameserver_address_unittest.cc

@@ -51,7 +51,7 @@ public:
     boost::shared_ptr<NameserverEntry>& getNameserverEntry() { return ns_; }
 
     // Return the IOAddress corresponding to the index in rrv4_
-    asiolink::IOAddress getAddressAtIndex(uint32_t index) { return ns_.get()->getAddressAtIndex(index); }
+    asiolink::IOAddress getAddressAtIndex(uint32_t index) { return ns_.get()->getAddressAtIndex(index, AF_INET); }
 
     // Return the addresses count stored in RRset
     unsigned int getAddressesCount() const { return rrv4_.getRdataCount(); }
@@ -73,8 +73,8 @@ class NameserverAddressTest : public ::testing::Test {
 protected:
     // Constructor
     NameserverAddressTest(): 
-        ns_address_(ns_sample_.getNameserverEntry(), TEST_ADDRESS_INDEX),
-        invalid_ns_address_(ns_sample_.getNameserverEntry(), ns_sample_.getAddressesCount())
+        ns_address_(ns_sample_.getNameserverEntry(), TEST_ADDRESS_INDEX, AF_INET),
+        invalid_ns_address_(ns_sample_.getNameserverEntry(), ns_sample_.getAddressesCount(), AF_INET)
     {
     }
 
@@ -95,7 +95,7 @@ TEST_F(NameserverAddressTest, Address) {
 
     boost::shared_ptr<NameserverEntry> empty_ne((NameserverEntry*)NULL);
     // It will throw an NullNameserverEntryPointer exception with the empty NameserverEntry shared pointer
-    ASSERT_THROW({NameserverAddress empty_ns_address(empty_ne, 0);}, NullNameserverEntryPointer);
+    ASSERT_THROW({NameserverAddress empty_ns_address(empty_ne, 0, AF_INET);}, NullNameserverEntryPointer);
 }
 
 // Test that the RTT is updated
@@ -106,10 +106,12 @@ TEST_F(NameserverAddressTest, UpdateRTT) {
     uint32_t old_rtt0 = ns_sample_.getAddressRTTAtIndex(0);
     uint32_t old_rtt2 = ns_sample_.getAddressRTTAtIndex(2);
 
-    ns_address_.updateRTT(new_rtt);
+    for(int i = 0; i < 10000; ++i){
+        ns_address_.updateRTT(new_rtt);
+    }
 
     //The RTT should have been updated
-    EXPECT_EQ(new_rtt, ns_sample_.getAddressRTTAtIndex(TEST_ADDRESS_INDEX));
+    EXPECT_NE(new_rtt, ns_sample_.getAddressRTTAtIndex(TEST_ADDRESS_INDEX));
 
     //The RTTs not been updated should remain unchanged
     EXPECT_EQ(old_rtt0, ns_sample_.getAddressRTTAtIndex(0));

+ 77 - 28
src/lib/nsas/tests/nameserver_entry_unittest.cc

@@ -29,6 +29,7 @@
 
 #include "asiolink.h"
 #include "address_entry.h"
+#include "nameserver_address.h"
 #include "nameserver_entry.h"
 
 #include "nsas_test.h"
@@ -279,34 +280,6 @@ TEST_F(NameserverEntryTest, AddressListConstructor) {
     CompareAddressVectors(dv, dvcomponent);
 }
 
-// Test the the RTT on tthe created addresses is not 0 and is different
-TEST_F(NameserverEntryTest, InitialRTT) {
-
-    // Get the RTT for the different addresses
-    NameserverEntry alpha(&rrv4_, &rrv6_);
-    NameserverEntry::AddressVector vec;
-    alpha.getAddresses(vec);
-
-    // Copy into a vector of time_t.
-    vector<uint32_t> rtt;
-    for (NameserverEntry::AddressVectorIterator i = vec.begin();
-        i != vec.end(); ++i) {
-        rtt.push_back(i->getRTT());
-    }
-
-    // Ensure that the addresses are sorted and note how many RTTs we have.
-    sort(rtt.begin(), rtt.end());
-    int oldcount = rtt.size();
-
-    // Remove duplicates and notw the new size.
-    vector<uint32_t>::iterator newend = unique(rtt.begin(), rtt.end());
-    rtt.erase(newend, rtt.end());
-    int newcount = rtt.size();
-
-    // .. and we don't expect to have lost anything.
-    EXPECT_EQ(oldcount, newcount);
-}
-
 // Set an address RTT to a given value
 TEST_F(NameserverEntryTest, SetRTT) {
 
@@ -450,6 +423,82 @@ TEST_F(NameserverEntryTest, CheckClass) {
 
 }
 
+// Select one address from the address list
+TEST_F(NameserverEntryTest, AddressSelection) {
+    boost::shared_ptr<NameserverEntry> ns(new NameserverEntry(&rrv4_, &rrv6_));
+
+    NameserverEntry::AddressVector v4Addresses;
+    NameserverEntry::AddressVector v6Addresses;
+    ns->getAddresses(v4Addresses, AF_INET);
+    ns->getAddresses(v6Addresses, AF_INET6);
+
+    int c1 = 0;
+    int c2 = 0;
+    int c3 = 0;
+    NameserverAddress ns_address;
+    for(int i = 0; i < 10000; ++i){
+        ns.get()->getAddress(ns, ns_address, AF_INET);
+        asiolink::IOAddress io_address = ns_address.getAddress();
+        if(io_address.toText() == v4Addresses[0].getAddress().toText()) ++c1;
+        else if(io_address.toText() == v4Addresses[1].getAddress().toText()) ++c2;
+        else if(io_address.toText() == v4Addresses[2].getAddress().toText()) ++c3;
+    }
+    // c1, c2 and c3 should almost be equal
+    ASSERT_EQ(1, (int)(c1*1.0/c2 + 0.5));
+    ASSERT_EQ(1, (int)(c2*1.0/c3 + 0.5));
+
+    // update the rtt to 1, 2, 3
+    ns->setAddressRTT(v4Addresses[0].getAddress(), 1);
+    ns->setAddressRTT(v4Addresses[1].getAddress(), 2);
+    ns->setAddressRTT(v4Addresses[2].getAddress(), 3);
+    c1 = c2 = c3 = 0; 
+    for(int i = 0; i < 100000; ++i){
+        ns.get()->getAddress(ns, ns_address, AF_INET);
+        asiolink::IOAddress io_address = ns_address.getAddress();
+        if(io_address.toText() == v4Addresses[0].getAddress().toText()) ++c1;
+        else if(io_address.toText() == v4Addresses[1].getAddress().toText()) ++c2;
+        else if(io_address.toText() == v4Addresses[2].getAddress().toText()) ++c3;
+    }
+
+    // c1 should be (2*2) times of c2
+    ASSERT_EQ(4, (int)(c1*1.0/c2 + 0.5));
+    // c1 should be (3*3) times of c3
+    ASSERT_EQ(9, (int)(c1*1.0/c3 + 0.5));
+}
+
+// Test the RTT is updated smoothly
+TEST_F(NameserverEntryTest, UpdateRTT) {
+    NameserverEntry ns(&rrv4_, &rrv6_);
+    NameserverEntry::AddressVector vec;
+    ns.getAddresses(vec);
+
+    // Initialize the rtt with a small value
+    uint32_t init_rtt = 1;
+    ns.setAddressRTT(vec[0].getAddress(), init_rtt);
+    // The rtt will be stablized to a large value
+    uint32_t stable_rtt = 100;
+
+    // Update the rtt
+    ns.updateAddressRTTAtIndex(stable_rtt, 0, AF_INET);
+
+    vec.clear();
+    ns.getAddresses(vec);
+    uint32_t new_rtt = vec[0].getRTT();
+
+    // The rtt should not close to new rtt immediately
+    ASSERT_TRUE((stable_rtt - new_rtt) > (new_rtt - init_rtt));
+
+    // Update the rtt for enough times
+    for(int i = 0; i < 10000; ++i){
+        ns.updateAddressRTTAtIndex(stable_rtt, 0, AF_INET);
+    }
+    vec.clear();
+    ns.getAddresses(vec);
+    new_rtt = vec[0].getRTT();
+
+    // The rtt should be close to stable rtt value
+    ASSERT_TRUE((stable_rtt - new_rtt) < (new_rtt - init_rtt));
+}
 
 }   // namespace nsas
 }   // namespace isc

+ 35 - 22
src/lib/nsas/tests/random_number_generator_unittest.cc

@@ -79,30 +79,11 @@ TEST_F(UniformRandomIntegerGeneratorTest, IntegerRange) {
 /// \brief Test Fixture Class for weighted random number generator
 class WeightedRandomIntegerGeneratorTest : public ::testing::Test {
 public:
-    WeightedRandomIntegerGeneratorTest():
-        gen_(NULL), min_(1)
-    {
-        // Initialize the probabilites vector
-        probabilities_.push_back(0.5);
-        probabilities_.push_back(0.3);
-        probabilities_.push_back(0.2);
-
-        gen_ = new WeightedRandomIntegerGenerator(probabilities_, min_);
-    }
-
-    int gen() { return (*gen_)(); }
-    int min() const { return min_; }
-    int max() const { return min_ + probabilities_.size() - 1; }
+    WeightedRandomIntegerGeneratorTest()
+    { }
 
     virtual ~WeightedRandomIntegerGeneratorTest()
-    {
-        delete gen_;
-    }
-
-private:
-    vector<double> probabilities_;
-    WeightedRandomIntegerGenerator *gen_;
-    int min_;
+    { }
 };
 
 // Test of the weighted random number generator constructor
@@ -216,5 +197,37 @@ TEST_F(WeightedRandomIntegerGeneratorTest, WeightedRandomization)
     }
 }
 
+// Test the reset function of generator
+TEST_F(WeightedRandomIntegerGeneratorTest, ResetProbabilities) 
+{
+        vector<double> probabilities;
+        int c1 = 0;
+        int c2 = 0;
+        probabilities.push_back(0.8);
+        probabilities.push_back(0.2);
+        WeightedRandomIntegerGenerator gen(probabilities);
+        for(int i = 0; i < 100000; ++i){
+            int n = gen();
+            if(n == 0) ++c1;
+            else if(n == 1) ++c2;
+        }
+        // The 1st integer count should be 4 times of 2nd one
+        ASSERT_EQ(4, (int)(c1*1.0/c2 + 0.5));
+
+        // Reset the probabilities
+        probabilities.clear();
+        c1 = c2 = 0;
+        probabilities.push_back(0.2);
+        probabilities.push_back(0.8);
+        gen.reset(probabilities);
+        for(int i = 0; i < 100000; ++i){
+            int n = gen();
+            if(n == 0) ++c1;
+            else if(n == 1) ++c2;
+        }
+        // The 2nd integer count should be 4 times of 1st one
+        ASSERT_EQ(4, (int)(c2*1.0/c1 + 0.5));
+}
+
 } // namespace nsas
 } // namespace isc