`<random>`: Skip reset in `normal_distribution::operator()` with custom params (#4618)

Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
This commit is contained in:
Martin Hořeňovský 2024-04-27 01:35:06 +02:00 коммит произвёл GitHub
Родитель bf1e9bdb30
Коммит 6b544648ce
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
6 изменённых файлов: 118 добавлений и 11 удалений

Просмотреть файл

@ -3427,8 +3427,7 @@ public:
template <class _Engine>
_NODISCARD result_type operator()(_Engine& _Eng, const param_type& _Par0) {
reset();
return _Eval(_Eng, _Par0, false);
return _Eval(_Eng, _Par0);
}
_NODISCARD_FRIEND bool operator==(const normal_distribution& _Left, const normal_distribution& _Right) noexcept
@ -3470,11 +3469,11 @@ public:
private:
template <class _Engine>
result_type _Eval(_Engine& _Eng, const param_type& _Par0,
bool _Keep = true) { // compute next value
// Knuth, vol. 2, p. 122, alg. P
result_type _Eval(_Engine& _Eng, const param_type& _Par0) {
// compute next value
// Knuth, vol. 2, p. 122, alg. P
_Ty _Res;
if (_Keep && _Valid) {
if (_Valid) {
_Res = _Xx2;
_Valid = false;
} else { // generate two values, store one, return one
@ -3507,11 +3506,9 @@ private:
}
const _Ty _Fx{_STD sqrt(_Ty{-2} * _LogSx / _Sx)};
if (_Keep) { // save second value for next call
_Xx2 = _Fx * _Vx2;
_Valid = true;
}
_Res = _Fx * _Vx1;
_Xx2 = _Fx * _Vx2; // save second value for next call
_Valid = true;
_Res = _Fx * _Vx1;
}
return _Res * _Par0._Sigma + _Par0._Mean;
}

Просмотреть файл

@ -243,6 +243,8 @@ tests\GH_004201_chrono_formatter
tests\GH_004275_seeking_fancy_iterators
tests\GH_004388_unordered_meow_operator_equal
tests\GH_004477_mdspan_warning_5246
tests\GH_004618_mixed_operator_usage_keeps_statistical_properties
tests\GH_004618_normal_distribution_avoids_resets
tests\LWG2381_num_get_floating_point
tests\LWG2597_complex_branch_cut
tests\LWG3018_shared_ptr_function

Просмотреть файл

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
RUNALL_INCLUDE ..\usual_matrix.lst

Просмотреть файл

@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>
constexpr std::size_t N = 1000000;
void check_results(const std::vector<double>& generated, const double expected_mean, const double expected_variance) {
const double actual_mean = std::accumulate(generated.begin(), generated.end(), 0.0) / generated.size();
assert(std::abs((actual_mean - expected_mean) / expected_mean) < 0.01);
double actual_variance = 0;
double actual_skew = 0;
for (const auto& value : generated) {
const auto deviation = value - actual_mean;
actual_variance += deviation * deviation;
actual_skew += deviation * deviation * deviation;
}
actual_variance /= generated.size();
assert(std::abs((actual_variance - expected_variance) / expected_variance) < 0.01);
actual_skew /= generated.size() * actual_variance * std::sqrt(actual_variance);
assert(std::abs(actual_skew) < 0.01);
}
int main() {
// The basic idea is to generate values from two different distributions,
// one defined through the type's constructor, one through the params
// overload of operator(). The generated values are then checked for
// their expected statistical properties.
std::mt19937 rng;
std::normal_distribution<> dist(5.0, 4.0);
using dist_params = std::normal_distribution<>::param_type;
const dist_params params(50.0, 0.5);
std::vector<double> dist_results;
dist_results.reserve(N);
std::vector<double> param_results;
param_results.reserve(N);
// Make sure that we get some first and some second values for both
// generated distributions
for (std::size_t i = 0; i < N; i += 2) {
dist_results.push_back(dist(rng));
param_results.push_back(dist(rng, params));
param_results.push_back(dist(rng, params));
dist_results.push_back(dist(rng));
}
check_results(dist_results, dist.mean(), dist.stddev() * dist.stddev());
check_results(param_results, params.mean(), params.stddev() * params.stddev());
}

Просмотреть файл

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
RUNALL_INCLUDE ..\usual_matrix.lst

Просмотреть файл

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <cassert>
#include <cstddef>
#include <random>
class FakeGenerator {
private:
std::mt19937 m_underlying_gen;
std::size_t m_operator_calls = 0;
public:
using GenType = std::mt19937;
using result_type = GenType::result_type;
static constexpr result_type min() {
return GenType::min();
}
static constexpr result_type max() {
return GenType::max();
}
result_type operator()() {
++m_operator_calls;
return m_underlying_gen();
}
std::size_t calls() const {
return m_operator_calls;
}
};
int main() {
FakeGenerator rng;
std::normal_distribution<> dist(0.0, 1.0);
using dist_params = std::normal_distribution<>::param_type;
dist_params params(50.0, 0.5);
(void) dist(rng);
const auto calls_before = rng.calls();
(void) dist(rng);
const auto calls_after = rng.calls();
assert(calls_before == calls_after);
}