52 строки
1.4 KiB
Python
52 строки
1.4 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
|
|
import numpy as np
|
|
|
|
def online_sampling(x_gen, k):
    """
    Draw up to k random samples from a stream in a single pass.

    This is reservoir sampling: the first k items fill the reservoir; the
    item at 1-based position cnt is then accepted when
    ``uniform() * cnt <= k`` (probability k / cnt) and overwrites a
    uniformly chosen slot of the reservoir.

    Args:
        x_gen: iterator, generator, or any other iterable yielding the items.
        k: maximum number of samples to keep.

    Returns:
        A list holding min(k, number of items) sampled items. While the
        reservoir is still filling, items keep their arrival order. The list
        is empty when k <= 0 (the input is not consumed in that case).
    """
    sampled = []
    if k <= 0:
        # Nothing can be kept; returning early also avoids trying to evict
        # from an empty reservoir when the uniform draw is exactly 0.0.
        return sampled

    # enumerate() accepts iterators and plain iterables alike and ends the
    # loop on exhaustion, so no explicit StopIteration handling is needed.
    for cnt, x in enumerate(x_gen, start=1):
        # Always true while cnt <= k because uniform() is in [0, 1).
        if np.random.uniform() * cnt > k:
            continue
        if len(sampled) < k:
            sampled.append(x)
        else:
            # choice(k) samples from arange(k) without building a
            # k-element array from a range on every eviction.
            sampled[np.random.choice(k)] = x
    return sampled
|
|
|
|
|
|
def stochastic_log_dense(n, k):
    """
    Sample k items from n options in log-dense fashion.

    The index range [0, n) is cut at the points n - 2**i into consecutive
    chunks: the last chunk holds index n - 1 alone, the chunks before it
    hold 1, 2, 4, ... indices, and the first chunk takes whatever remains
    down to index 0. Chunk j (counted from index 0) is drawn with weight
    2**j, so picks concentrate near n - 1. k distinct chunks are chosen
    without replacement and one index is drawn uniformly from each.

    Args:
        n: number of options; must be an integer >= 1.
        k: number of indices requested. It is silently capped at the number
            of chunks, ``(n - 1).bit_length() + 1``.

    Returns:
        A list of distinct indices in [0, n), in the order their chunks
        were drawn.

    Raises:
        ValueError: if n < 1.
    """
    n = int(n)
    if n < 1:
        raise ValueError("n must be a positive integer, got %d" % n)

    # Exact integer form of int(log2(n - 1)) + 2. Unlike the float version
    # it is defined for n == 1 (a single chunk) and cannot be thrown off by
    # rounding for very large n.
    n_chunks = (n - 1).bit_length() + 1

    prob_chunks = np.power(2.0, range(n_chunks))
    prob_chunks /= float(np.sum(prob_chunks))

    # TODO: callers asking for more than n_chunks samples get fewer back
    # without any notice.
    k = min(k, n_chunks)
    selected_chunks = np.random.choice(
        n_chunks, k, replace=False, p=prob_chunks)

    # bounds[j] is the first index of chunk j; the trailing n closes the
    # last chunk so every chunk j spans [bounds[j], bounds[j + 1]).
    bounds = [max(0, n - (1 << i)) for i in reversed(range(n_chunks))]
    bounds.append(n)

    sampled_idx = []
    for chunki in selected_chunks:
        start = bounds[chunki]
        width = bounds[chunki + 1] - start
        # Draw an offset inside the chunk instead of materialising the
        # whole range(start, next_start) as an array.
        sampled_idx.append(start + np.random.choice(width))
    return sampled_idx