当前位置 : 主页 > 编程语言 > python >

python中的加权随机样本

来源:互联网 收集:自由互联 发布时间:2021-06-25
我正在寻找一个函数weighted_sample的合理定义,它不会为给定权重列表返回一个随机索引(这类似于 def weighted_choice(weights, random=random): """ Given a list of weights [w_0, w_1, ..., w_n-1], return an index i
我正在寻找一个函数weighted_sample的合理定义,它不会为给定权重列表返回一个随机索引(这类似于

def weighted_choice(weights, random=random):
    """ Given a list of weights [w_0, w_1, ..., w_n-1],
        return an index i in range(n) with probability proportional to w_i. """
    rnd = random.random() * sum(weights)
    for i, w in enumerate(weights):
        if w<0:
            raise ValueError("Negative weight encountered.")
        rnd -= w
        if rnd < 0:
            return i
    raise ValueError("Sum of weights is not positive")

给出一个具有恒定权重的分类分布)但是随机抽样的那些k,没有替换,就像random.sample行为与random.choice相比.

就像weighted_choice可以写成

lambda weights: random.choice([val for val, cnt in enumerate(weights)
    for i in range(cnt)])

weighted_sample可以写成

lambda weights, k: random.sample([val for val, cnt in enumerate(weights)
    for i in range(cnt)], k)

但我想要一个解决方案,不需要我将权重解析为(可能是巨大的)列表.

编辑:如果有任何好的算法可以返回一个直方图/频率列表(与参数权重的格式相同)而不是一系列索引,这也是非常有用的.

从你的代码:..

weight_sample_indexes = lambda weights, k: random.sample([val 
        for val, cnt in enumerate(weights) for i in range(cnt)], k)

..我认为权重是正整数,而“没有替换”你的意思是没有替换解开的序列.

这是一个基于random.sample和O(log n)__getitem__的解决方案:

import bisect
import random
from collections import Counter, Sequence

def weighted_sample(population, weights, k):
    return random.sample(WeightedPopulation(population, weights), k)

class WeightedPopulation(Sequence):
    def __init__(self, population, weights):
        assert len(population) == len(weights) > 0
        self.population = population
        self.cumweights = []
        cumsum = 0 # compute cumulative weight
        for w in weights:
            cumsum += w   
            self.cumweights.append(cumsum)  
    def __len__(self):
        return self.cumweights[-1]
    def __getitem__(self, i):
        if not 0 <= i < len(self):
            raise IndexError(i)
        return self.population[bisect.bisect(self.cumweights, i)]

total = Counter()
for _ in range(1000):
    sample = weighted_sample("abc", [1,10,2], 5)
    total.update(sample)
print(sample)
print("Frequences %s" % (dict(Counter(sample)),))

# Check that values are sane
print("Total " + ', '.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                           for val, count in total.most_common()))

产量

['b', 'b', 'b', 'c', 'c']
Frequences {'c': 2, 'b': 3}
Total b: 10, c: 2, a: 1
网友评论