Browse Source

Add weighted random integer generator

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac356@3656 e5f2f494-b856-4b98-b285-d166d9295462
Haidong Wang 14 years ago
parent
commit
fb7bf9b78c

+ 81 - 4
src/lib/nsas/random_number_generator.h

@@ -17,30 +17,35 @@
 #ifndef __NSAS_RANDOM_NUMBER_GENERATOR_H
 #define __NSAS_RANDOM_NUMBER_GENERATOR_H
 
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <ctime>
+#include <iterator>
+#include <numeric>
+#include <vector>
 #include <boost/random/mersenne_twister.hpp>
 #include <boost/random/uniform_int.hpp>
+#include <boost/random/uniform_real.hpp>
 #include <boost/random/variate_generator.hpp>
 
-
 namespace isc {
 namespace nsas {
 
 /// \brief Uniform random integer generator
 ///
 /// Generate uniformly distributed integers in range of [min, max]
-/// \param min The minimum number in the range
-/// \param max The maximum number in the range
 class UniformRandomIntegerGenerator{
 public:
+    /// \brief Constructor
+    ///
+    /// \param min The minimum number in the range
+    /// \param max The maximum number in the range
     UniformRandomIntegerGenerator(int min, int max):
         min_(min), max_(max), dist_(min, max), generator_(rng_, dist_)
     {
+        // Init with the current time
         rng_.seed(time(NULL));
     }
 
+    /// \brief Generate uniformly distributed integer
     int operator()() { return generator_(); }
 private:
-    ///< Hide default and copy constructor
+    /// Hide default and copy constructor
     UniformRandomIntegerGenerator();///< Default constructor
     UniformRandomIntegerGenerator(const UniformRandomIntegerGenerator&); ///< Copy constructor
 
@@ -51,6 +56,78 @@ private:
     boost::variate_generator<boost::mt19937&, boost::uniform_int<> > generator_; ///< Uniform generator
 };
 
+/// \brief Weighted random integer generator
+///
+/// Generate random integers according to different probabilities.
+///
+/// The generator draws a uniformly distributed real number and looks up the
+/// first entry of the cumulative probability table that is not smaller than
+/// it; the position of that entry (plus the minimum) is the generated integer.
+class WeightedRandomIntegerGenerator{
+public:
+    /// \brief Constructor
+    ///
+    /// \param probabilities The probabilities for all the integers. Each probability
+    /// must be between 0 and 1.0 and the sum of the probabilities must be equal to 1
+    /// (a small rounding error is tolerated). The check is done with assert(), so it
+    /// is only active when assertions are compiled in.
+    /// For example, if the probabilities contains the following values:
+    /// 0.5 0.3 0.2, the 1st integer will be generated more frequently than the
+    /// other integers and the probability is proportional to its value.
+    /// An empty vector is accepted; in that case min is always generated.
+    /// \param min The minimum integer that is generated, other integers will be
+    /// min, min + 1, ..., min + probabilities.size() - 1
+    WeightedRandomIntegerGenerator(const std::vector<double>& probabilities, int min = 0):
+        dist_(0, 1.0), uniform_real_gen_(rng_, dist_), min_(min)
+    {
+        // The probabilities must be valid
+        assert(isProbabilitiesValid(probabilities));
+        // Calculate the partial sum of probabilities
+        cumulative_.reserve(probabilities.size());
+        std::partial_sum(probabilities.begin(), probabilities.end(),
+                         std::back_inserter(cumulative_));
+        // The validation tolerates a sum that is slightly off 1. Scale the table so
+        // that its last entry is exactly 1: otherwise a random value larger than the
+        // last entry would select the non-existing integer min + probabilities.size()
+        if (!cumulative_.empty() && cumulative_.back() > 0) {
+            const double total = cumulative_.back();
+            for (std::vector<double>::iterator it = cumulative_.begin();
+                 it != cumulative_.end(); ++it) {
+                *it /= total;
+            }
+        }
+        // Init with the current time
+        rng_.seed(std::time(NULL));
+    }
+
+    /// \brief Generate weighted random integer
+    ///
+    /// \return An integer in [min, min + probabilities.size() - 1], or min when the
+    /// generator was built from an empty probabilities vector
+    int operator()()
+    {
+        const double value = uniform_real_gen_();
+        // Index of the first cumulative probability that is >= the random value
+        return static_cast<int>(std::lower_bound(cumulative_.begin(), cumulative_.end(), value)
+            - cumulative_.begin()) + min_;
+    }
+
+    /// \brief Destructor
+    ~WeightedRandomIntegerGenerator()
+    {
+    }
+
+private:
+    /// Hide copy constructor and assignment operator: uniform_real_gen_ keeps a
+    /// reference to the rng_ of the object it was created in, so a member-wise copy
+    /// would keep using the random engine of the source object
+    WeightedRandomIntegerGenerator(const WeightedRandomIntegerGenerator&); ///< Copy constructor
+    WeightedRandomIntegerGenerator& operator=(const WeightedRandomIntegerGenerator&); ///< Assignment operator
+
+    /// \brief Check the validation of probabilities vector
+    ///
+    /// The probability must be in range of [0, 1.0] and the sum must be equal to 1.0
+    /// (within a tolerance of 0.0001).
+    /// Empty probabilities is also valid.
+    ///
+    /// \param probabilities The probabilities to check
+    /// \return true if the probabilities are valid, false otherwise
+    bool isProbabilitiesValid(const std::vector<double>& probabilities) const
+    {
+        // Nothing to check for an empty vector
+        if (probabilities.empty()) return true;
+
+        typedef std::vector<double>::const_iterator Iterator;
+        double sum = 0;
+        for(Iterator it = probabilities.begin(); it != probabilities.end(); ++it){
+            // The probability must be in [0, 1.0]
+            if(*it < 0 || *it > 1) return false;
+
+            sum += *it;
+        }
+
+        // The sum must be equal to 1; allow for floating point rounding errors
+        const double epsilon = 0.0001;
+        return std::fabs(sum - 1) < epsilon;
+    }
+
+    // Shortcut typedefs
+    typedef boost::variate_generator<boost::mt19937&, boost::uniform_real<> > UniformRealGenerator;
+
+    std::vector<double> cumulative_;            ///< The partial sum of the probabilities
+    boost::mt19937 rng_;                        ///< Mersenne Twister: A 623-dimensionally equidistributed uniform pseudo-random number generator 
+    boost::uniform_real<> dist_;                ///< Uniformly distributed real numbers
+    UniformRealGenerator uniform_real_gen_;     ///< Uniformly distributed random real numbers generator
+    int min_;                                   ///< The minimum integer that will be generated
+};
+
 }   // namespace nsas
 }   // namespace isc
 

+ 145 - 4
src/lib/nsas/tests/random_number_generator_unittest.cc

@@ -54,7 +54,7 @@ private:
 
 // Test of the constructor
 TEST_F(UniformRandomIntegerGeneratorTest, Constructor) {
-    //The range must be min<=max
+    // The range must be min<=max: a generator built with min > max is expected to terminate the process
     ASSERT_DEATH(UniformRandomIntegerGenerator(3, 2), "");
 }
 
@@ -62,18 +62,159 @@ TEST_F(UniformRandomIntegerGeneratorTest, Constructor) {
 TEST_F(UniformRandomIntegerGeneratorTest, IntegerRange) {
     vector<int> numbers;
 
-    //Generate a lot of random integers
+    // Generate max()*10 random integers
     for(int i = 0; i < max()*10; ++i){
         numbers.push_back(gen());
     }
 
-    //Remove the duplicated values
+    // Sort the values and move the distinct ones to the front; 'it' marks the end of the distinct values
     sort(numbers.begin(), numbers.end());
     vector<int>::iterator it = unique(numbers.begin(), numbers.end());
 
-    //make sure the numbers are in range [min, max]
+    // The number of distinct values must be equal to the size of the range [min, max]
     ASSERT_EQ(it - numbers.begin(), max() - min() + 1); 
 }
 
+
+/// \brief Fixture for the weighted random number generator tests
+///
+/// Owns a generator created from the probabilities 0.5, 0.3 and 0.2 with
+/// 1 as the minimum integer.
+class WeightedRandomIntegerGeneratorTest : public ::testing::Test {
+public:
+    WeightedRandomIntegerGeneratorTest():
+        gen_(NULL), min_(1)
+    {
+        // The probabilities have to be in place before the generator is created from them
+        const double weights[] = { 0.5, 0.3, 0.2 };
+        probabilities_.assign(weights, weights + sizeof(weights) / sizeof(weights[0]));
+
+        gen_ = new WeightedRandomIntegerGenerator(probabilities_, min_);
+    }
+
+    virtual ~WeightedRandomIntegerGeneratorTest()
+    {
+        // The generator was allocated in the constructor
+        delete gen_;
+    }
+
+    /// Draw one integer from the generator
+    int gen() { return gen_->operator()(); }
+    /// The minimum integer handed to the generator
+    int min() const { return min_; }
+    /// The minimum plus the number of probabilities minus one
+    int max() const { return min_ + probabilities_.size() - 1; }
+
+private:
+    vector<double> probabilities_;          ///< Probabilities handed to the generator
+    WeightedRandomIntegerGenerator *gen_;   ///< Generator created in the constructor, deleted in the destructor
+    int min_;                               ///< Minimum integer handed to the generator
+};
+
+// Constructor checks of the weighted random number generator
+TEST_F(WeightedRandomIntegerGeneratorTest, Constructor)
+{
+    // Without any probability the smallest integer is the only one generated
+    const vector<double> none;
+    WeightedRandomIntegerGenerator gen(none, 123);
+    for(int round = 0; round != 100; ++round){
+        ASSERT_EQ(gen(), 123);
+    }
+
+    // A probability below 0 is not accepted
+    const double negative[] = { -0.1, 1.1 };
+    const vector<double> below_zero(negative, negative + 2);
+    ASSERT_DEATH(WeightedRandomIntegerGenerator gen2(below_zero), "");
+
+    // A probability above 1.0 is not accepted
+    const double too_big[] = { 0.1, 1.1 };
+    const vector<double> above_one(too_big, too_big + 2);
+    ASSERT_DEATH(WeightedRandomIntegerGenerator gen3(above_one), "");
+
+    // A sum larger than 1.0 is not accepted
+    const double too_much[] = { 0.2, 0.9 };
+    const vector<double> sum_above_one(too_much, too_much + 2);
+    ASSERT_DEATH(WeightedRandomIntegerGenerator gen4(sum_above_one), "");
+
+    // A sum smaller than 1.0 is not accepted
+    const double too_little[] = { 0.3, 0.2, 0.1 };
+    const vector<double> sum_below_one(too_little, too_little + 3);
+    ASSERT_DEATH(WeightedRandomIntegerGenerator gen5(sum_below_one), "");
+}
+
+// Check that the generated integers follow the requested weights
+TEST_F(WeightedRandomIntegerGeneratorTest, WeightedRandomization)
+{
+    {
+        // Two integers with the same probability
+        const double weights[] = { 0.5, 0.5 };
+        const vector<double> probabilities(weights, weights + 2);
+        WeightedRandomIntegerGenerator gen(probabilities);
+
+        int hits[2] = { 0, 0 };
+        for(int round = 0; round != 100000; ++round){
+            const int n = gen();
+            if(n == 0 || n == 1) ++hits[n];
+        }
+        // The rounded ratio of both counts should be 1
+        ASSERT_EQ(1, static_cast<int>(hits[0]*1.0/hits[1] + 0.5));
+    }
+
+    {
+        // The 2nd integer is 4 times as likely as the 1st one
+        const double weights[] = { 0.2, 0.8 };
+        const vector<double> probabilities(weights, weights + 2);
+        WeightedRandomIntegerGenerator gen(probabilities);
+
+        int hits[2] = { 0, 0 };
+        for(int round = 0; round != 100000; ++round){
+            const int n = gen();
+            if(n == 0 || n == 1) ++hits[n];
+        }
+        // The rounded ratio of 2nd to 1st count should be 4
+        ASSERT_EQ(4, static_cast<int>(hits[1]*1.0/hits[0] + 0.5));
+    }
+
+    {
+        // The 1st integer is 4 times as likely as the 2nd one
+        const double weights[] = { 0.8, 0.2 };
+        const vector<double> probabilities(weights, weights + 2);
+        WeightedRandomIntegerGenerator gen(probabilities);
+
+        int hits[2] = { 0, 0 };
+        for(int round = 0; round != 100000; ++round){
+            const int n = gen();
+            if(n == 0 || n == 1) ++hits[n];
+        }
+        // The rounded ratio of 1st to 2nd count should be 4
+        ASSERT_EQ(4, static_cast<int>(hits[0]*1.0/hits[1] + 0.5));
+    }
+
+    {
+        // The 1st integer is twice as likely as each of the other two
+        const double weights[] = { 0.5, 0.25, 0.25 };
+        const vector<double> probabilities(weights, weights + 3);
+        WeightedRandomIntegerGenerator gen(probabilities);
+
+        int hits[3] = { 0, 0, 0 };
+        for(int round = 0; round != 100000; ++round){
+            const int n = gen();
+            if(n >= 0 && n <= 2) ++hits[n];
+        }
+        // The rounded ratio of 1st to 2nd count should be 2
+        ASSERT_EQ(2, static_cast<int>(hits[0]*1.0/hits[1] + 0.5));
+        // The rounded ratio of 1st to 3rd count should be 2
+        ASSERT_EQ(2, static_cast<int>(hits[0]*1.0/hits[2] + 0.5));
+    }
+}
+
 } // namespace nsas
 } // namespace isc