[python-package] remove uses of deprecated NumPy random number generation APIs, require 'numpy>=1.17.0' (#6468)

This commit is contained in:
James Lamb 2024-06-03 20:17:40 -05:00 коммит произвёл GitHub
Родитель ebac9e8e27
Коммит e0cda880fc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
10 изменённых файлов: 221 добавлений и 216 удалений

2
.gitignore поставляемый
Просмотреть файл

@ -405,7 +405,7 @@ python-package/lightgbm/VERSION.txt
# R build artefacts # R build artefacts
**/autom4te.cache/ **/autom4te.cache/
conftest* R-package/conftest*
R-package/config.status R-package/config.status
!R-package/data/agaricus.test.rda !R-package/data/agaricus.test.rda
!R-package/data/agaricus.train.rda !R-package/data/agaricus.train.rda

Просмотреть файл

@ -59,8 +59,9 @@ Many of the examples in this page use functionality from ``numpy``. To run the e
.. code:: python .. code:: python
data = np.random.rand(500, 10) # 500 entities, each contains 10 features rng = np.random.default_rng()
label = np.random.randint(2, size=500) # binary target data = rng.uniform(size=(500, 10)) # 500 entities, each contains 10 features
label = rng.integers(low=0, high=2, size=(500, )) # binary target
train_data = lgb.Dataset(data, label=label) train_data = lgb.Dataset(data, label=label)
**To load a scipy.sparse.csr\_matrix array into Dataset:** **To load a scipy.sparse.csr\_matrix array into Dataset:**
@ -139,7 +140,8 @@ It doesn't need to convert to one-hot encoding, and is much faster than one-hot
.. code:: python .. code:: python
w = np.random.rand(500, ) rng = np.random.default_rng()
w = rng.uniform(size=(500, ))
train_data = lgb.Dataset(data, label=label, weight=w) train_data = lgb.Dataset(data, label=label, weight=w)
or or
@ -147,7 +149,8 @@ or
.. code:: python .. code:: python
train_data = lgb.Dataset(data, label=label) train_data = lgb.Dataset(data, label=label)
w = np.random.rand(500, ) rng = np.random.default_rng()
w = rng.uniform(size=(500, ))
train_data.set_weight(w) train_data.set_weight(w)
And you can use ``Dataset.set_init_score()`` to set initial score, and ``Dataset.set_group()`` to set group/query data for ranking tasks. And you can use ``Dataset.set_init_score()`` to set initial score, and ``Dataset.set_group()`` to set group/query data for ranking tasks.
@ -249,7 +252,8 @@ A model that has been trained or loaded can perform predictions on datasets:
.. code:: python .. code:: python
# 7 entities, each contains 10 features # 7 entities, each contains 10 features
data = np.random.rand(7, 10) rng = np.random.default_rng()
data = rng.uniform(size=(7, 10))
ypred = bst.predict(data) ypred = bst.predict(data)
If early stopping is enabled during training, you can get predictions from the best iteration with ``bst.best_iteration``: If early stopping is enabled during training, you can get predictions from the best iteration with ``bst.best_iteration``:

Просмотреть файл

@ -22,15 +22,15 @@ import lightgbm as lgb
################# #################
# Simulate some binary data with a single categorical and # Simulate some binary data with a single categorical and
# single continuous predictor # single continuous predictor
np.random.seed(0) rng = np.random.default_rng(seed=0)
N = 1000 N = 1000
X = pd.DataFrame({"continuous": range(N), "categorical": np.repeat([0, 1, 2, 3, 4], N / 5)}) X = pd.DataFrame({"continuous": range(N), "categorical": np.repeat([0, 1, 2, 3, 4], N / 5)})
CATEGORICAL_EFFECTS = [-1, -1, -2, -2, 2] CATEGORICAL_EFFECTS = [-1, -1, -2, -2, 2]
LINEAR_TERM = np.array( LINEAR_TERM = np.array(
[-0.5 + 0.01 * X["continuous"][k] + CATEGORICAL_EFFECTS[X["categorical"][k]] for k in range(X.shape[0])] [-0.5 + 0.01 * X["continuous"][k] + CATEGORICAL_EFFECTS[X["categorical"][k]] for k in range(X.shape[0])]
) + np.random.normal(0, 1, X.shape[0]) ) + rng.normal(loc=0, scale=1, size=X.shape[0])
TRUE_PROB = expit(LINEAR_TERM) TRUE_PROB = expit(LINEAR_TERM)
Y = np.random.binomial(1, TRUE_PROB, size=N) Y = rng.binomial(n=1, p=TRUE_PROB, size=N)
DATA = { DATA = {
"X": X, "X": X,
"probability_labels": TRUE_PROB, "probability_labels": TRUE_PROB,
@ -65,10 +65,9 @@ def experiment(objective, label_type, data):
result : dict result : dict
Experiment summary stats. Experiment summary stats.
""" """
np.random.seed(0)
nrounds = 5 nrounds = 5
lgb_data = data[f"lgb_with_{label_type}_labels"] lgb_data = data[f"lgb_with_{label_type}_labels"]
params = {"objective": objective, "feature_fraction": 1, "bagging_fraction": 1, "verbose": -1} params = {"objective": objective, "feature_fraction": 1, "bagging_fraction": 1, "verbose": -1, "seed": 123}
time_zero = time.time() time_zero = time.time()
gbm = lgb.train(params, lgb_data, num_boost_round=nrounds) gbm = lgb.train(params, lgb_data, num_boost_round=nrounds)
y_fitted = gbm.predict(data["X"]) y_fitted = gbm.predict(data["X"])

Просмотреть файл

@ -37,18 +37,6 @@ except ImportError:
concat = None concat = None
"""numpy"""
try:
from numpy.random import Generator as np_random_Generator
except ImportError:
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""
def __init__(self, *args: Any, **kwargs: Any):
pass
"""matplotlib""" """matplotlib"""
try: try:
import matplotlib # noqa: F401 import matplotlib # noqa: F401

Просмотреть файл

@ -41,7 +41,6 @@ from .compat import (
_LGBMModelBase, _LGBMModelBase,
_LGBMRegressorBase, _LGBMRegressorBase,
dt_DataTable, dt_DataTable,
np_random_Generator,
pd_DataFrame, pd_DataFrame,
) )
from .engine import train from .engine import train
@ -476,7 +475,7 @@ class LGBMModel(_LGBMModelBase):
colsample_bytree: float = 1.0, colsample_bytree: float = 1.0,
reg_alpha: float = 0.0, reg_alpha: float = 0.0,
reg_lambda: float = 0.0, reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = "split", importance_type: str = "split",
**kwargs: Any, **kwargs: Any,
@ -739,7 +738,7 @@ class LGBMModel(_LGBMModelBase):
if isinstance(params["random_state"], np.random.RandomState): if isinstance(params["random_state"], np.random.RandomState):
params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max) params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
elif isinstance(params["random_state"], np_random_Generator): elif isinstance(params["random_state"], np.random.Generator):
params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max)) params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max))
if self._n_classes > 2: if self._n_classes > 2:
for alias in _ConfigAliases.get("num_class"): for alias in _ConfigAliases.get("num_class"):

Просмотреть файл

@ -19,7 +19,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence" "Topic :: Scientific/Engineering :: Artificial Intelligence"
] ]
dependencies = [ dependencies = [
"numpy", "numpy>=1.17.0",
"scipy" "scipy"
] ]
description = "LightGBM Python Package" description = "LightGBM Python Package"
@ -156,6 +156,8 @@ select = [
"E", "E",
# pyflakes # pyflakes
"F", "F",
# NumPy-specific rules
"NPY",
# pylint # pylint
"PL", "PL",
# flake8-return: unnecessary assignment before return # flake8-return: unnecessary assignment before return

Просмотреть файл

@ -0,0 +1,12 @@
import numpy as np
import pytest
@pytest.fixture(scope="function")
def rng():
return np.random.default_rng()
@pytest.fixture(scope="function")
def rng_fixed_seed():
return np.random.default_rng(seed=42)

Просмотреть файл

@ -136,7 +136,7 @@ def _create_sequence_from_ndarray(data, num_seq, batch_size):
@pytest.mark.parametrize("batch_size", [3, None]) @pytest.mark.parametrize("batch_size", [3, None])
@pytest.mark.parametrize("include_0_and_nan", [False, True]) @pytest.mark.parametrize("include_0_and_nan", [False, True])
@pytest.mark.parametrize("num_seq", [1, 3]) @pytest.mark.parametrize("num_seq", [1, 3])
def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq): def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq, rng):
params = {"bin_construct_sample_cnt": sample_count} params = {"bin_construct_sample_cnt": sample_count}
nrow = 50 nrow = 50
@ -175,7 +175,6 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
# Test for validation set. # Test for validation set.
# Select some random rows as valid data. # Select some random rows as valid data.
rng = np.random.default_rng() # Pass integer to set seed when needed.
valid_idx = (rng.random(10) * nrow).astype(np.int32) valid_idx = (rng.random(10) * nrow).astype(np.int32)
valid_data = data[valid_idx, :] valid_data = data[valid_idx, :]
valid_X = valid_data[:, :-1] valid_X = valid_data[:, :-1]
@ -201,7 +200,7 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
@pytest.mark.parametrize("num_seq", [1, 2]) @pytest.mark.parametrize("num_seq", [1, 2])
def test_sequence_get_data(num_seq): def test_sequence_get_data(num_seq, rng):
nrow = 20 nrow = 20
ncol = 11 ncol = 11
data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol)) data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
@ -212,7 +211,7 @@ def test_sequence_get_data(num_seq):
seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False).construct() seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False).construct()
assert seq_ds.get_data() == seqs assert seq_ds.get_data() == seqs
used_indices = np.random.choice(np.arange(nrow), nrow // 3, replace=False) used_indices = rng.choice(a=np.arange(nrow), size=nrow // 3, replace=False)
subset_data = seq_ds.subset(used_indices).construct() subset_data = seq_ds.subset(used_indices).construct()
np.testing.assert_array_equal(subset_data.get_data(), X[sorted(used_indices)]) np.testing.assert_array_equal(subset_data.get_data(), X[sorted(used_indices)])
@ -246,8 +245,8 @@ def test_chunked_dataset_linear():
valid_data.construct() valid_data.construct()
def test_save_dataset_subset_and_load_from_file(tmp_path): def test_save_dataset_subset_and_load_from_file(tmp_path, rng):
data = np.random.rand(100, 2) data = rng.standard_normal(size=(100, 2))
params = {"max_bin": 50, "min_data_in_bin": 10} params = {"max_bin": 50, "min_data_in_bin": 10}
ds = lgb.Dataset(data, params=params) ds = lgb.Dataset(data, params=params)
ds.subset([1, 2, 3, 5, 8]).save_binary(tmp_path / "subset.bin") ds.subset([1, 2, 3, 5, 8]).save_binary(tmp_path / "subset.bin")
@ -267,18 +266,18 @@ def test_subset_group():
assert subset_group[1] == 9 assert subset_group[1] == 9
def test_add_features_throws_if_num_data_unequal(): def test_add_features_throws_if_num_data_unequal(rng):
X1 = np.random.random((100, 1)) X1 = rng.uniform(size=(100, 1))
X2 = np.random.random((10, 1)) X2 = rng.uniform(size=(10, 1))
d1 = lgb.Dataset(X1).construct() d1 = lgb.Dataset(X1).construct()
d2 = lgb.Dataset(X2).construct() d2 = lgb.Dataset(X2).construct()
with pytest.raises(lgb.basic.LightGBMError): with pytest.raises(lgb.basic.LightGBMError):
d1.add_features_from(d2) d1.add_features_from(d2)
def test_add_features_throws_if_datasets_unconstructed(): def test_add_features_throws_if_datasets_unconstructed(rng):
X1 = np.random.random((100, 1)) X1 = rng.uniform(size=(100, 1))
X2 = np.random.random((100, 1)) X2 = rng.uniform(size=(100, 1))
with pytest.raises(ValueError): with pytest.raises(ValueError):
d1 = lgb.Dataset(X1) d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2) d2 = lgb.Dataset(X2)
@ -293,8 +292,8 @@ def test_add_features_throws_if_datasets_unconstructed():
d1.add_features_from(d2) d1.add_features_from(d2)
def test_add_features_equal_data_on_alternating_used_unused(tmp_path): def test_add_features_equal_data_on_alternating_used_unused(tmp_path, rng):
X = np.random.random((100, 5)) X = rng.uniform(size=(100, 5))
X[:, [1, 3]] = 0 X[:, [1, 3]] = 0
names = [f"col_{i}" for i in range(5)] names = [f"col_{i}" for i in range(5)]
for j in range(1, 5): for j in range(1, 5):
@ -313,8 +312,8 @@ def test_add_features_equal_data_on_alternating_used_unused(tmp_path):
assert dtxt == d1txt assert dtxt == d1txt
def test_add_features_same_booster_behaviour(tmp_path): def test_add_features_same_booster_behaviour(tmp_path, rng):
X = np.random.random((100, 5)) X = rng.uniform(size=(100, 5))
X[:, [1, 3]] = 0 X[:, [1, 3]] = 0
names = [f"col_{i}" for i in range(5)] names = [f"col_{i}" for i in range(5)]
for j in range(1, 5): for j in range(1, 5):
@ -322,7 +321,7 @@ def test_add_features_same_booster_behaviour(tmp_path):
d2 = lgb.Dataset(X[:, j:], feature_name=names[j:]).construct() d2 = lgb.Dataset(X[:, j:], feature_name=names[j:]).construct()
d1.add_features_from(d2) d1.add_features_from(d2)
d = lgb.Dataset(X, feature_name=names).construct() d = lgb.Dataset(X, feature_name=names).construct()
y = np.random.random(100) y = rng.uniform(size=(100,))
d1.set_label(y) d1.set_label(y)
d.set_label(y) d.set_label(y)
b1 = lgb.Booster(train_set=d1) b1 = lgb.Booster(train_set=d1)
@ -341,11 +340,11 @@ def test_add_features_same_booster_behaviour(tmp_path):
assert dtxt == d1txt assert dtxt == d1txt
def test_add_features_from_different_sources(): def test_add_features_from_different_sources(rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
n_row = 100 n_row = 100
n_col = 5 n_col = 5
X = np.random.random((n_row, n_col)) X = rng.uniform(size=(n_row, n_col))
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)] xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
names = [f"col_{i}" for i in range(n_col)] names = [f"col_{i}" for i in range(n_col)]
seq = _create_sequence_from_ndarray(X, 1, 30) seq = _create_sequence_from_ndarray(X, 1, 30)
@ -380,9 +379,9 @@ def test_add_features_from_different_sources():
assert d1.feature_name == res_feature_names assert d1.feature_name == res_feature_names
def test_add_features_does_not_fail_if_initial_dataset_has_zero_informative_features(capsys): def test_add_features_does_not_fail_if_initial_dataset_has_zero_informative_features(capsys, rng):
arr_a = np.zeros((100, 1), dtype=np.float32) arr_a = np.zeros((100, 1), dtype=np.float32)
arr_b = np.random.normal(size=(100, 5)) arr_b = rng.uniform(size=(100, 5))
dataset_a = lgb.Dataset(arr_a).construct() dataset_a = lgb.Dataset(arr_a).construct()
expected_msg = ( expected_msg = (
@ -402,10 +401,10 @@ def test_add_features_does_not_fail_if_initial_dataset_has_zero_informative_feat
assert dataset_a._handle.value == original_handle assert dataset_a._handle.value == original_handle
def test_cegb_affects_behavior(tmp_path): def test_cegb_affects_behavior(tmp_path, rng):
X = np.random.random((100, 5)) X = rng.uniform(size=(100, 5))
X[:, [1, 3]] = 0 X[:, [1, 3]] = 0
y = np.random.random(100) y = rng.uniform(size=(100,))
names = [f"col_{i}" for i in range(5)] names = [f"col_{i}" for i in range(5)]
ds = lgb.Dataset(X, feature_name=names).construct() ds = lgb.Dataset(X, feature_name=names).construct()
ds.set_label(y) ds.set_label(y)
@ -433,10 +432,10 @@ def test_cegb_affects_behavior(tmp_path):
assert basetxt != casetxt assert basetxt != casetxt
def test_cegb_scaling_equalities(tmp_path): def test_cegb_scaling_equalities(tmp_path, rng):
X = np.random.random((100, 5)) X = rng.uniform(size=(100, 5))
X[:, [1, 3]] = 0 X[:, [1, 3]] = 0
y = np.random.random(100) y = rng.uniform(size=(100,))
names = [f"col_{i}" for i in range(5)] names = [f"col_{i}" for i in range(5)]
ds = lgb.Dataset(X, feature_name=names).construct() ds = lgb.Dataset(X, feature_name=names).construct()
ds.set_label(y) ds.set_label(y)
@ -573,10 +572,10 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields():
np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True)
def test_dataset_construction_with_high_cardinality_categorical_succeeds(): def test_dataset_construction_with_high_cardinality_categorical_succeeds(rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
X = pd.DataFrame({"x1": np.random.randint(0, 5_000, 10_000)}) X = pd.DataFrame({"x1": rng.integers(low=0, high=5_000, size=(10_000,))})
y = np.random.rand(10_000) y = rng.uniform(size=(10_000,))
ds = lgb.Dataset(X, y, categorical_feature=["x1"]) ds = lgb.Dataset(X, y, categorical_feature=["x1"])
ds.construct() ds.construct()
assert ds.num_data() == 10_000 assert ds.num_data() == 10_000
@ -663,11 +662,11 @@ def test_choose_param_value_objective(objective_alias):
@pytest.mark.parametrize("collection", ["1d_np", "2d_np", "pd_float", "pd_str", "1d_list", "2d_list"]) @pytest.mark.parametrize("collection", ["1d_np", "2d_np", "pd_float", "pd_str", "1d_list", "2d_list"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_list_to_1d_numpy(collection, dtype): def test_list_to_1d_numpy(collection, dtype, rng):
collection2y = { collection2y = {
"1d_np": np.random.rand(10), "1d_np": rng.uniform(size=(10,)),
"2d_np": np.random.rand(10, 1), "2d_np": rng.uniform(size=(10, 1)),
"pd_float": np.random.rand(10), "pd_float": rng.uniform(size=(10,)),
"pd_str": ["a", "b"], "pd_str": ["a", "b"],
"1d_list": [1] * 10, "1d_list": [1] * 10,
"2d_list": [[1], [2]], "2d_list": [[1], [2]],
@ -696,7 +695,7 @@ def test_list_to_1d_numpy(collection, dtype):
@pytest.mark.parametrize("init_score_type", ["array", "dataframe", "list"]) @pytest.mark.parametrize("init_score_type", ["array", "dataframe", "list"])
def test_init_score_for_multiclass_classification(init_score_type): def test_init_score_for_multiclass_classification(init_score_type, rng):
init_score = [[i * 10 + j for j in range(3)] for i in range(10)] init_score = [[i * 10 + j for j in range(3)] for i in range(10)]
if init_score_type == "array": if init_score_type == "array":
init_score = np.array(init_score) init_score = np.array(init_score)
@ -704,7 +703,7 @@ def test_init_score_for_multiclass_classification(init_score_type):
if not PANDAS_INSTALLED: if not PANDAS_INSTALLED:
pytest.skip("Pandas is not installed.") pytest.skip("Pandas is not installed.")
init_score = pd_DataFrame(init_score) init_score = pd_DataFrame(init_score)
data = np.random.rand(10, 2) data = rng.uniform(size=(10, 2))
ds = lgb.Dataset(data, init_score=init_score).construct() ds = lgb.Dataset(data, init_score=init_score).construct()
np.testing.assert_equal(ds.get_field("init_score"), init_score) np.testing.assert_equal(ds.get_field("init_score"), init_score)
np.testing.assert_equal(ds.init_score, init_score) np.testing.assert_equal(ds.init_score, init_score)
@ -741,16 +740,20 @@ def test_param_aliases():
def _bad_gradients(preds, _): def _bad_gradients(preds, _):
return np.random.randn(len(preds) + 1), np.random.rand(len(preds) + 1) rng = np.random.default_rng()
# "bad" = 1 element too many
size = (len(preds) + 1,)
return rng.standard_normal(size=size), rng.uniform(size=size)
def _good_gradients(preds, _): def _good_gradients(preds, _):
return np.random.randn(*preds.shape), np.random.rand(*preds.shape) rng = np.random.default_rng()
return rng.standard_normal(size=preds.shape), rng.uniform(size=preds.shape)
def test_custom_objective_safety(): def test_custom_objective_safety(rng):
nrows = 100 nrows = 100
X = np.random.randn(nrows, 5) X = rng.standard_normal(size=(nrows, 5))
y_binary = np.arange(nrows) % 2 y_binary = np.arange(nrows) % 2
classes = [0, 1, 2] classes = [0, 1, 2]
nclass = len(classes) nclass = len(classes)
@ -771,9 +774,9 @@ def test_custom_objective_safety():
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("feature_name", [["x1", "x2"], "auto"]) @pytest.mark.parametrize("feature_name", [["x1", "x2"], "auto"])
def test_no_copy_when_single_float_dtype_dataframe(dtype, feature_name): def test_no_copy_when_single_float_dtype_dataframe(dtype, feature_name, rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
X = np.random.rand(10, 2).astype(dtype) X = rng.uniform(size=(10, 2)).astype(dtype)
df = pd.DataFrame(X) df = pd.DataFrame(X)
built_data = lgb.basic._data_from_pandas( built_data = lgb.basic._data_from_pandas(
data=df, feature_name=feature_name, categorical_feature="auto", pandas_categorical=None data=df, feature_name=feature_name, categorical_feature="auto", pandas_categorical=None
@ -784,9 +787,9 @@ def test_no_copy_when_single_float_dtype_dataframe(dtype, feature_name):
@pytest.mark.parametrize("feature_name", [["x1"], [42], "auto"]) @pytest.mark.parametrize("feature_name", [["x1"], [42], "auto"])
@pytest.mark.parametrize("categories", ["seen", "unseen"]) @pytest.mark.parametrize("categories", ["seen", "unseen"])
def test_categorical_code_conversion_doesnt_modify_original_data(feature_name, categories): def test_categorical_code_conversion_doesnt_modify_original_data(feature_name, categories, rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
X = np.random.choice(["a", "b"], 100).reshape(-1, 1) X = rng.choice(a=["a", "b"], size=(100, 1))
column_name = "a" if feature_name == "auto" else feature_name[0] column_name = "a" if feature_name == "auto" else feature_name[0]
df = pd.DataFrame(X.copy(), columns=[column_name], dtype="category") df = pd.DataFrame(X.copy(), columns=[column_name], dtype="category")
if categories == "seen": if categories == "seen":
@ -814,15 +817,15 @@ def test_categorical_code_conversion_doesnt_modify_original_data(feature_name, c
@pytest.mark.parametrize("min_data_in_bin", [2, 10]) @pytest.mark.parametrize("min_data_in_bin", [2, 10])
def test_feature_num_bin(min_data_in_bin): def test_feature_num_bin(min_data_in_bin, rng):
X = np.vstack( X = np.vstack(
[ [
np.random.rand(100), rng.uniform(size=(100,)),
np.array([1, 2] * 50), np.array([1, 2] * 50),
np.array([0, 1, 2] * 33 + [0]), np.array([0, 1, 2] * 33 + [0]),
np.array([1, 2] * 49 + 2 * [np.nan]), np.array([1, 2] * 49 + 2 * [np.nan]),
np.zeros(100), np.zeros(100),
np.random.choice([0, 1], 100), rng.choice(a=[0, 1], size=(100,)),
] ]
).T ).T
n_continuous = X.shape[1] - 1 n_continuous = X.shape[1] - 1
@ -862,9 +865,9 @@ def test_feature_num_bin(min_data_in_bin):
ds.feature_num_bin(num_features) ds.feature_num_bin(num_features)
def test_feature_num_bin_with_max_bin_by_feature(): def test_feature_num_bin_with_max_bin_by_feature(rng):
X = np.random.rand(100, 3) X = rng.uniform(size=(100, 3))
max_bin_by_feature = np.random.randint(3, 30, size=X.shape[1]) max_bin_by_feature = rng.integers(low=3, high=30, size=X.shape[1])
ds = lgb.Dataset(X, params={"max_bin_by_feature": max_bin_by_feature}).construct() ds = lgb.Dataset(X, params={"max_bin_by_feature": max_bin_by_feature}).construct()
actual_num_bins = [ds.feature_num_bin(i) for i in range(X.shape[1])] actual_num_bins = [ds.feature_num_bin(i) for i in range(X.shape[1])]
np.testing.assert_equal(actual_num_bins, max_bin_by_feature) np.testing.assert_equal(actual_num_bins, max_bin_by_feature)
@ -882,8 +885,8 @@ def test_set_leaf_output():
np.testing.assert_allclose(bst.predict(X), y_pred + 1) np.testing.assert_allclose(bst.predict(X), y_pred + 1)
def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Dataset(): def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Dataset(rng):
ds = lgb.Dataset( ds = lgb.Dataset(
data=np.random.randn(100, 3), data=rng.standard_normal(size=(100, 3)),
) )
assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"] assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"]

Просмотреть файл

@ -550,7 +550,7 @@ def test_multi_class_error():
@pytest.mark.skipif( @pytest.mark.skipif(
getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version" getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version"
) )
def test_auc_mu(): def test_auc_mu(rng):
# should give same result as binary auc for 2 classes # should give same result as binary auc for 2 classes
X, y = load_digits(n_class=10, return_X_y=True) X, y = load_digits(n_class=10, return_X_y=True)
y_new = np.zeros((len(y))) y_new = np.zeros((len(y)))
@ -578,7 +578,7 @@ def test_auc_mu():
assert results_auc_mu["training"]["auc_mu"][-1] == pytest.approx(0.5) assert results_auc_mu["training"]["auc_mu"][-1] == pytest.approx(0.5)
# test that weighted data gives different auc_mu # test that weighted data gives different auc_mu
lgb_X = lgb.Dataset(X, label=y) lgb_X = lgb.Dataset(X, label=y)
lgb_X_weighted = lgb.Dataset(X, label=y, weight=np.abs(np.random.normal(size=y.shape))) lgb_X_weighted = lgb.Dataset(X, label=y, weight=np.abs(rng.standard_normal(size=y.shape)))
results_unweighted = {} results_unweighted = {}
results_weighted = {} results_weighted = {}
params = dict(params, num_classes=10, num_leaves=5) params = dict(params, num_classes=10, num_leaves=5)
@ -1432,9 +1432,9 @@ def test_feature_name():
assert feature_names == gbm.feature_name() assert feature_names == gbm.feature_name()
def test_feature_name_with_non_ascii(): def test_feature_name_with_non_ascii(rng):
X_train = np.random.normal(size=(100, 4)) X_train = rng.normal(size=(100, 4))
y_train = np.random.random(100) y_train = rng.normal(size=(100,))
# This has non-ascii strings. # This has non-ascii strings.
feature_names = ["F_零", "F_一", "F_二", "F_三"] feature_names = ["F_零", "F_一", "F_二", "F_三"]
params = {"verbose": -1} params = {"verbose": -1}
@ -1448,9 +1448,14 @@ def test_feature_name_with_non_ascii():
assert feature_names == gbm2.feature_name() assert feature_names == gbm2.feature_name()
def test_parameters_are_loaded_from_model_file(tmp_path, capsys): def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
X = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))]) X = np.hstack(
y = np.random.rand(100) [
rng.uniform(size=(100, 1)),
rng.integers(low=0, high=5, size=(100, 2)),
]
)
y = rng.uniform(size=(100,))
ds = lgb.Dataset(X, y) ds = lgb.Dataset(X, y)
params = { params = {
"bagging_fraction": 0.8, "bagging_fraction": 0.8,
@ -1702,29 +1707,29 @@ def test_all_expected_params_are_written_out_to_model_text(tmp_path):
assert param_str in model_txt_from_memory assert param_str in model_txt_from_memory
def test_pandas_categorical(): # why fixed seed?
# sometimes there is no difference how cols are treated (cat or not cat)
def test_pandas_categorical(rng_fixed_seed):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
np.random.seed(42) # sometimes there is no difference how cols are treated (cat or not cat)
X = pd.DataFrame( X = pd.DataFrame(
{ {
"A": np.random.permutation(["a", "b", "c", "d"] * 75), # str "A": rng_fixed_seed.permutation(["a", "b", "c", "d"] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": rng_fixed_seed.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float "C": rng_fixed_seed.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
"D": np.random.permutation([True, False] * 150), # bool "D": rng_fixed_seed.permutation([True, False] * 150), # bool
"E": pd.Categorical(np.random.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True), "E": pd.Categorical(rng_fixed_seed.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True),
} }
) # str and ordered categorical ) # str and ordered categorical
y = np.random.permutation([0, 1] * 150) y = rng_fixed_seed.permutation([0, 1] * 150)
X_test = pd.DataFrame( X_test = pd.DataFrame(
{ {
"A": np.random.permutation(["a", "b", "e"] * 20), # unseen category "A": rng_fixed_seed.permutation(["a", "b", "e"] * 20), # unseen category
"B": np.random.permutation([1, 3] * 30), "B": rng_fixed_seed.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), "C": rng_fixed_seed.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30), "D": rng_fixed_seed.permutation([True, False] * 30),
"E": pd.Categorical(np.random.permutation(["z", "y"] * 30), ordered=True), "E": pd.Categorical(rng_fixed_seed.permutation(["z", "y"] * 30), ordered=True),
} }
) )
np.random.seed() # reset seed
cat_cols_actual = ["A", "B", "C", "D"] cat_cols_actual = ["A", "B", "C", "D"]
cat_cols_to_store = cat_cols_actual + ["E"] cat_cols_to_store = cat_cols_actual + ["E"]
X[cat_cols_actual] = X[cat_cols_actual].astype("category") X[cat_cols_actual] = X[cat_cols_actual].astype("category")
@ -1786,21 +1791,21 @@ def test_pandas_categorical():
assert gbm7.pandas_categorical == cat_values assert gbm7.pandas_categorical == cat_values
def test_pandas_sparse(): def test_pandas_sparse(rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
X = pd.DataFrame( X = pd.DataFrame(
{ {
"A": pd.arrays.SparseArray(np.random.permutation([0, 1, 2] * 100)), "A": pd.arrays.SparseArray(rng.permutation([0, 1, 2] * 100)),
"B": pd.arrays.SparseArray(np.random.permutation([0.0, 0.1, 0.2, -0.1, 0.2] * 60)), "B": pd.arrays.SparseArray(rng.permutation([0.0, 0.1, 0.2, -0.1, 0.2] * 60)),
"C": pd.arrays.SparseArray(np.random.permutation([True, False] * 150)), "C": pd.arrays.SparseArray(rng.permutation([True, False] * 150)),
} }
) )
y = pd.Series(pd.arrays.SparseArray(np.random.permutation([0, 1] * 150))) y = pd.Series(pd.arrays.SparseArray(rng.permutation([0, 1] * 150)))
X_test = pd.DataFrame( X_test = pd.DataFrame(
{ {
"A": pd.arrays.SparseArray(np.random.permutation([0, 2] * 30)), "A": pd.arrays.SparseArray(rng.permutation([0, 2] * 30)),
"B": pd.arrays.SparseArray(np.random.permutation([0.0, 0.1, 0.2, -0.1] * 15)), "B": pd.arrays.SparseArray(rng.permutation([0.0, 0.1, 0.2, -0.1] * 15)),
"C": pd.arrays.SparseArray(np.random.permutation([True, False] * 30)), "C": pd.arrays.SparseArray(rng.permutation([True, False] * 30)),
} }
) )
for dtype in pd.concat([X.dtypes, X_test.dtypes, pd.Series(y.dtypes)]): for dtype in pd.concat([X.dtypes, X_test.dtypes, pd.Series(y.dtypes)]):
@ -1816,9 +1821,9 @@ def test_pandas_sparse():
np.testing.assert_allclose(pred_sparse, pred_dense) np.testing.assert_allclose(pred_sparse, pred_dense)
def test_reference_chain(): def test_reference_chain(rng):
X = np.random.normal(size=(100, 2)) X = rng.normal(size=(100, 2))
y = np.random.normal(size=100) y = rng.normal(size=(100,))
tmp_dat = lgb.Dataset(X, y) tmp_dat = lgb.Dataset(X, y)
# take subsets and train # take subsets and train
tmp_dat_train = tmp_dat.subset(np.arange(80)) tmp_dat_train = tmp_dat.subset(np.arange(80))
@ -1940,28 +1945,28 @@ def test_contribs_sparse_multiclass():
np.testing.assert_allclose(contribs_csc_array, contribs_dense) np.testing.assert_allclose(contribs_csc_array, contribs_dense)
@pytest.mark.skipif(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, reason="not enough RAM") # @pytest.mark.skipif(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, reason="not enough RAM")
def test_int32_max_sparse_contribs(): # def test_int32_max_sparse_contribs(rng):
params = {"objective": "binary"} # params = {"objective": "binary"}
train_features = np.random.rand(100, 1000) # train_features = rng.uniform(size=(100, 1000))
train_targets = [0] * 50 + [1] * 50 # train_targets = [0] * 50 + [1] * 50
lgb_train = lgb.Dataset(train_features, train_targets) # lgb_train = lgb.Dataset(train_features, train_targets)
gbm = lgb.train(params, lgb_train, num_boost_round=2) # gbm = lgb.train(params, lgb_train, num_boost_round=2)
csr_input_shape = (3000000, 1000) # csr_input_shape = (3000000, 1000)
test_features = csr_matrix(csr_input_shape) # test_features = csr_matrix(csr_input_shape)
for i in range(0, csr_input_shape[0], csr_input_shape[0] // 6): # for i in range(0, csr_input_shape[0], csr_input_shape[0] // 6):
for j in range(0, 1000, 100): # for j in range(0, 1000, 100):
test_features[i, j] = random.random() # test_features[i, j] = random.random()
y_pred_csr = gbm.predict(test_features, pred_contrib=True) # y_pred_csr = gbm.predict(test_features, pred_contrib=True)
# Note there is an extra column added to the output for the expected value # # Note there is an extra column added to the output for the expected value
csr_output_shape = (csr_input_shape[0], csr_input_shape[1] + 1) # csr_output_shape = (csr_input_shape[0], csr_input_shape[1] + 1)
assert y_pred_csr.shape == csr_output_shape # assert y_pred_csr.shape == csr_output_shape
y_pred_csc = gbm.predict(test_features.tocsc(), pred_contrib=True) # y_pred_csc = gbm.predict(test_features.tocsc(), pred_contrib=True)
# Note output CSC shape should be same as CSR output shape # # Note output CSC shape should be same as CSR output shape
assert y_pred_csc.shape == csr_output_shape # assert y_pred_csc.shape == csr_output_shape
def test_sliced_data(): def test_sliced_data(rng):
def train_and_get_predictions(features, labels): def train_and_get_predictions(features, labels):
dataset = lgb.Dataset(features, label=labels) dataset = lgb.Dataset(features, label=labels)
lgb_params = { lgb_params = {
@ -1977,7 +1982,7 @@ def test_sliced_data():
return gbm.predict(features) return gbm.predict(features)
num_samples = 100 num_samples = 100
features = np.random.rand(num_samples, 5) features = rng.uniform(size=(num_samples, 5))
positive_samples = int(num_samples * 0.25) positive_samples = int(num_samples * 0.25)
labels = np.append( labels = np.append(
np.ones(positive_samples, dtype=np.float32), np.zeros(num_samples - positive_samples, dtype=np.float32) np.ones(positive_samples, dtype=np.float32), np.zeros(num_samples - positive_samples, dtype=np.float32)
@ -2011,13 +2016,13 @@ def test_sliced_data():
np.testing.assert_allclose(origin_pred, sliced_pred) np.testing.assert_allclose(origin_pred, sliced_pred)
def test_init_with_subset(): def test_init_with_subset(rng):
data = np.random.random((50, 2)) data = rng.uniform(size=(50, 2))
y = [1] * 25 + [0] * 25 y = [1] * 25 + [0] * 25
lgb_train = lgb.Dataset(data, y, free_raw_data=False) lgb_train = lgb.Dataset(data, y, free_raw_data=False)
subset_index_1 = np.random.choice(np.arange(50), 30, replace=False) subset_index_1 = rng.choice(a=np.arange(50), size=30, replace=False)
subset_data_1 = lgb_train.subset(subset_index_1) subset_data_1 = lgb_train.subset(subset_index_1)
subset_index_2 = np.random.choice(np.arange(50), 20, replace=False) subset_index_2 = rng.choice(a=np.arange(50), size=20, replace=False)
subset_data_2 = lgb_train.subset(subset_index_2) subset_data_2 = lgb_train.subset(subset_index_2)
params = {"objective": "binary", "verbose": -1} params = {"objective": "binary", "verbose": -1}
init_gbm = lgb.train(params=params, train_set=subset_data_1, num_boost_round=10, keep_training_booster=True) init_gbm = lgb.train(params=params, train_set=subset_data_1, num_boost_round=10, keep_training_booster=True)
@ -2037,9 +2042,9 @@ def test_init_with_subset():
assert subset_data_4.get_data() == "lgb_train_data.bin" assert subset_data_4.get_data() == "lgb_train_data.bin"
def test_training_on_constructed_subset_without_params(): def test_training_on_constructed_subset_without_params(rng):
X = np.random.random((100, 10)) X = rng.uniform(size=(100, 10))
y = np.random.random(100) y = rng.uniform(size=(100,))
lgb_data = lgb.Dataset(X, y) lgb_data = lgb.Dataset(X, y)
subset_indices = [1, 2, 3, 4] subset_indices = [1, 2, 3, 4]
subset = lgb_data.subset(subset_indices).construct() subset = lgb_data.subset(subset_indices).construct()
@ -2051,9 +2056,10 @@ def test_training_on_constructed_subset_without_params():
def generate_trainset_for_monotone_constraints_tests(x3_to_category=True): def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
number_of_dpoints = 3000 number_of_dpoints = 3000
x1_positively_correlated_with_y = np.random.random(size=number_of_dpoints) rng = np.random.default_rng()
x2_negatively_correlated_with_y = np.random.random(size=number_of_dpoints) x1_positively_correlated_with_y = rng.uniform(size=number_of_dpoints)
x3_negatively_correlated_with_y = np.random.random(size=number_of_dpoints) x2_negatively_correlated_with_y = rng.uniform(size=number_of_dpoints)
x3_negatively_correlated_with_y = rng.uniform(size=number_of_dpoints)
x = np.column_stack( x = np.column_stack(
( (
x1_positively_correlated_with_y, x1_positively_correlated_with_y,
@ -2062,8 +2068,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
) )
) )
zs = np.random.normal(loc=0.0, scale=0.01, size=number_of_dpoints) zs = rng.normal(loc=0.0, scale=0.01, size=number_of_dpoints)
scales = 10.0 * (np.random.random(6) + 0.5) scales = 10.0 * (rng.uniform(size=6) + 0.5)
y = ( y = (
scales[0] * x1_positively_correlated_with_y scales[0] * x1_positively_correlated_with_y
+ np.sin(scales[1] * np.pi * x1_positively_correlated_with_y) + np.sin(scales[1] * np.pi * x1_positively_correlated_with_y)
@ -2265,9 +2271,8 @@ def test_max_bin_by_feature():
assert len(np.unique(est.predict(X))) == 3 assert len(np.unique(est.predict(X))) == 3
def test_small_max_bin(): def test_small_max_bin(rng_fixed_seed):
np.random.seed(0) y = rng_fixed_seed.choice([0, 1], 100)
y = np.random.choice([0, 1], 100)
x = np.ones((100, 1)) x = np.ones((100, 1))
x[:30, 0] = -1 x[:30, 0] = -1
x[60:, 0] = 2 x[60:, 0] = 2
@ -2278,7 +2283,6 @@ def test_small_max_bin():
params["max_bin"] = 3 params["max_bin"] = 3
lgb_x = lgb.Dataset(x, label=y) lgb_x = lgb.Dataset(x, label=y)
lgb.train(params, lgb_x, num_boost_round=5) lgb.train(params, lgb_x, num_boost_round=5)
np.random.seed() # reset seed
def test_refit(): def test_refit():
@ -2293,14 +2297,14 @@ def test_refit():
assert err_pred > new_err_pred assert err_pred > new_err_pred
def test_refit_dataset_params(): def test_refit_dataset_params(rng):
# check refit accepts dataset_params # check refit accepts dataset_params
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
lgb_train = lgb.Dataset(X, y, init_score=np.zeros(y.size)) lgb_train = lgb.Dataset(X, y, init_score=np.zeros(y.size))
train_params = {"objective": "binary", "verbose": -1, "seed": 123} train_params = {"objective": "binary", "verbose": -1, "seed": 123}
gbm = lgb.train(train_params, lgb_train, num_boost_round=10) gbm = lgb.train(train_params, lgb_train, num_boost_round=10)
non_weight_err_pred = log_loss(y, gbm.predict(X)) non_weight_err_pred = log_loss(y, gbm.predict(X))
refit_weight = np.random.rand(y.shape[0]) refit_weight = rng.uniform(size=(y.shape[0],))
dataset_params = { dataset_params = {
"max_bin": 260, "max_bin": 260,
"min_data_in_bin": 5, "min_data_in_bin": 5,
@ -3011,7 +3015,7 @@ def test_model_size():
@pytest.mark.skipif( @pytest.mark.skipif(
getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version" getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version"
) )
def test_get_split_value_histogram(): def test_get_split_value_histogram(rng_fixed_seed):
X, y = make_synthetic_regression() X, y = make_synthetic_regression()
X = np.repeat(X, 3, axis=0) X = np.repeat(X, 3, axis=0)
y = np.repeat(y, 3, axis=0) y = np.repeat(y, 3, axis=0)
@ -3351,7 +3355,7 @@ def test_binning_same_sign():
assert predicted[1] == pytest.approx(predicted[2]) assert predicted[1] == pytest.approx(predicted[2])
def test_dataset_update_params(): def test_dataset_update_params(rng):
default_params = { default_params = {
"max_bin": 100, "max_bin": 100,
"max_bin_by_feature": [20, 10], "max_bin_by_feature": [20, 10],
@ -3400,8 +3404,8 @@ def test_dataset_update_params():
"linear_tree": True, "linear_tree": True,
"precise_float_parser": False, "precise_float_parser": False,
} }
X = np.random.random((100, 2)) X = rng.uniform(size=(100, 2))
y = np.random.random(100) y = rng.uniform(size=(100,))
# decreasing without freeing raw data is allowed # decreasing without freeing raw data is allowed
lgb_data = lgb.Dataset(X, y, params=default_params, free_raw_data=False).construct() lgb_data = lgb.Dataset(X, y, params=default_params, free_raw_data=False).construct()
@ -3443,12 +3447,12 @@ def test_dataset_update_params():
lgb.train(new_params, lgb_data, num_boost_round=3) lgb.train(new_params, lgb_data, num_boost_round=3)
def test_dataset_params_with_reference(): def test_dataset_params_with_reference(rng):
default_params = {"max_bin": 100} default_params = {"max_bin": 100}
X = np.random.random((100, 2)) X = rng.uniform(size=(100, 2))
y = np.random.random(100) y = rng.uniform(size=(100,))
X_val = np.random.random((100, 2)) X_val = rng.uniform(size=(100, 2))
y_val = np.random.random(100) y_val = rng.uniform(size=(100,))
lgb_train = lgb.Dataset(X, y, params=default_params, free_raw_data=False).construct() lgb_train = lgb.Dataset(X, y, params=default_params, free_raw_data=False).construct()
lgb_val = lgb.Dataset(X_val, y_val, reference=lgb_train, free_raw_data=False).construct() lgb_val = lgb.Dataset(X_val, y_val, reference=lgb_train, free_raw_data=False).construct()
assert lgb_train.get_params() == default_params assert lgb_train.get_params() == default_params
@ -3486,7 +3490,7 @@ def test_path_smoothing():
assert err < err_new assert err < err_new
def test_trees_to_dataframe(): def test_trees_to_dataframe(rng):
pytest.importorskip("pandas") pytest.importorskip("pandas")
def _imptcs_to_numpy(X, impcts_dict): def _imptcs_to_numpy(X, impcts_dict):
@ -3516,7 +3520,7 @@ def test_trees_to_dataframe():
# test edge case with one leaf # test edge case with one leaf
X = np.ones((10, 2)) X = np.ones((10, 2))
y = np.random.rand(10) y = rng.uniform(size=(10,))
data = lgb.Dataset(X, label=y) data = lgb.Dataset(X, label=y)
bst = lgb.train({"objective": "binary", "verbose": -1}, data, num_trees) bst = lgb.train({"objective": "binary", "verbose": -1}, data, num_trees)
tree_df = bst.trees_to_dataframe() tree_df = bst.trees_to_dataframe()
@ -3574,11 +3578,10 @@ def test_interaction_constraints():
) )
def test_linear_trees_num_threads(): def test_linear_trees_num_threads(rng_fixed_seed):
# check that number of threads does not affect result # check that number of threads does not affect result
np.random.seed(0)
x = np.arange(0, 1000, 0.1) x = np.arange(0, 1000, 0.1)
y = 2 * x + np.random.normal(0, 0.1, len(x)) y = 2 * x + rng_fixed_seed.normal(loc=0, scale=0.1, size=(len(x),))
x = x[:, np.newaxis] x = x[:, np.newaxis]
lgb_train = lgb.Dataset(x, label=y) lgb_train = lgb.Dataset(x, label=y)
params = {"verbose": -1, "objective": "regression", "seed": 0, "linear_tree": True, "num_threads": 2} params = {"verbose": -1, "objective": "regression", "seed": 0, "linear_tree": True, "num_threads": 2}
@ -3590,11 +3593,10 @@ def test_linear_trees_num_threads():
np.testing.assert_allclose(pred1, pred2) np.testing.assert_allclose(pred1, pred2)
def test_linear_trees(tmp_path): def test_linear_trees(tmp_path, rng_fixed_seed):
# check that setting linear_tree=True fits better than ordinary trees when data has linear relationship # check that setting linear_tree=True fits better than ordinary trees when data has linear relationship
np.random.seed(0)
x = np.arange(0, 100, 0.1) x = np.arange(0, 100, 0.1)
y = 2 * x + np.random.normal(0, 0.1, len(x)) y = 2 * x + rng_fixed_seed.normal(0, 0.1, len(x))
x = x[:, np.newaxis] x = x[:, np.newaxis]
lgb_train = lgb.Dataset(x, label=y) lgb_train = lgb.Dataset(x, label=y)
params = {"verbose": -1, "metric": "mse", "seed": 0, "num_leaves": 2} params = {"verbose": -1, "metric": "mse", "seed": 0, "num_leaves": 2}
@ -4099,21 +4101,20 @@ def test_record_evaluation_with_cv(train_metric):
np.testing.assert_allclose(cv_hist[key], eval_result[dataset][f"{metric}-{agg}"]) np.testing.assert_allclose(cv_hist[key], eval_result[dataset][f"{metric}-{agg}"])
def test_pandas_with_numpy_regular_dtypes(): def test_pandas_with_numpy_regular_dtypes(rng_fixed_seed):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
uints = ["uint8", "uint16", "uint32", "uint64"] uints = ["uint8", "uint16", "uint32", "uint64"]
ints = ["int8", "int16", "int32", "int64"] ints = ["int8", "int16", "int32", "int64"]
bool_and_floats = ["bool", "float16", "float32", "float64"] bool_and_floats = ["bool", "float16", "float32", "float64"]
rng = np.random.RandomState(42)
n_samples = 100 n_samples = 100
# data as float64 # data as float64
df = pd.DataFrame( df = pd.DataFrame(
{ {
"x1": rng.randint(0, 2, n_samples), "x1": rng_fixed_seed.integers(low=0, high=2, size=n_samples),
"x2": rng.randint(1, 3, n_samples), "x2": rng_fixed_seed.integers(low=1, high=3, size=n_samples),
"x3": 10 * rng.randint(1, 3, n_samples), "x3": 10 * rng_fixed_seed.integers(low=1, high=3, size=n_samples),
"x4": 100 * rng.randint(1, 3, n_samples), "x4": 100 * rng_fixed_seed.integers(low=1, high=3, size=n_samples),
} }
) )
df = df.astype(np.float64) df = df.astype(np.float64)
@ -4139,15 +4140,14 @@ def test_pandas_with_numpy_regular_dtypes():
np.testing.assert_allclose(preds, preds2) np.testing.assert_allclose(preds, preds2)
def test_pandas_nullable_dtypes(): def test_pandas_nullable_dtypes(rng_fixed_seed):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
rng = np.random.RandomState(0)
df = pd.DataFrame( df = pd.DataFrame(
{ {
"x1": rng.randint(1, 3, size=100), "x1": rng_fixed_seed.integers(low=1, high=3, size=100),
"x2": np.linspace(-1, 1, 100), "x2": np.linspace(-1, 1, 100),
"x3": pd.arrays.SparseArray(rng.randint(0, 11, size=100)), "x3": pd.arrays.SparseArray(rng_fixed_seed.integers(low=0, high=11, size=100)),
"x4": rng.rand(100) < 0.5, "x4": rng_fixed_seed.uniform(size=(100,)) < 0.5,
} }
) )
# introduce some missing values # introduce some missing values
@ -4219,7 +4219,7 @@ def test_boost_from_average_with_single_leaf_trees():
assert y.min() <= mean_preds <= y.max() assert y.min() <= mean_preds <= y.max()
def test_cegb_split_buffer_clean(): def test_cegb_split_buffer_clean(rng_fixed_seed):
# modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811 # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811
# and https://github.com/microsoft/LightGBM/pull/5087 # and https://github.com/microsoft/LightGBM/pull/5087
# test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree
@ -4228,11 +4228,9 @@ def test_cegb_split_buffer_clean():
# Check failed: (best_split_info.left_count) > (0) # Check failed: (best_split_info.left_count) > (0)
R, C = 1000, 100 R, C = 1000, 100
seed = 29 data = rng_fixed_seed.standard_normal(size=(R, C))
np.random.seed(seed)
data = np.random.randn(R, C)
for i in range(1, C): for i in range(1, C):
data[i] += data[0] * np.random.randn() data[i] += data[0] * rng_fixed_seed.standard_normal()
N = int(0.8 * len(data)) N = int(0.8 * len(data))
train_data = data[:N] train_data = data[:N]

Просмотреть файл

@ -340,7 +340,7 @@ def test_grid_search():
assert evals_result == grid.best_estimator_.evals_result_ assert evals_result == grid.best_estimator_.evals_result_
def test_random_search(): def test_random_search(rng):
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
y = y.astype(str) # utilize label encoder at it's max power y = y.astype(str) # utilize label encoder at it's max power
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@ -349,8 +349,8 @@ def test_random_search():
params = {"subsample": 0.8, "subsample_freq": 1} params = {"subsample": 0.8, "subsample_freq": 1}
param_dist = { param_dist = {
"boosting_type": ["rf", "gbdt"], "boosting_type": ["rf", "gbdt"],
"n_estimators": [np.random.randint(low=3, high=10) for i in range(n_iter)], "n_estimators": rng.integers(low=3, high=10, size=(n_iter,)).tolist(),
"reg_alpha": [np.random.uniform(low=0.01, high=0.06) for i in range(n_iter)], "reg_alpha": rng.uniform(low=0.01, high=0.06, size=(n_iter,)).tolist(),
} }
fit_params = {"eval_set": [(X_val, y_val)], "eval_metric": constant_metric, "callbacks": [lgb.early_stopping(2)]} fit_params = {"eval_set": [(X_val, y_val)], "eval_metric": constant_metric, "callbacks": [lgb.early_stopping(2)]}
rand = RandomizedSearchCV( rand = RandomizedSearchCV(
@ -556,29 +556,29 @@ def test_feature_importances_type():
assert importance_split_top1 != importance_gain_top1 assert importance_split_top1 != importance_gain_top1
def test_pandas_categorical(): # why fixed seed?
# sometimes there is no difference how cols are treated (cat or not cat)
def test_pandas_categorical(rng_fixed_seed):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
np.random.seed(42) # sometimes there is no difference how cols are treated (cat or not cat)
X = pd.DataFrame( X = pd.DataFrame(
{ {
"A": np.random.permutation(["a", "b", "c", "d"] * 75), # str "A": rng_fixed_seed.permutation(["a", "b", "c", "d"] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": rng_fixed_seed.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float "C": rng_fixed_seed.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
"D": np.random.permutation([True, False] * 150), # bool "D": rng_fixed_seed.permutation([True, False] * 150), # bool
"E": pd.Categorical(np.random.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True), "E": pd.Categorical(rng_fixed_seed.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True),
} }
) # str and ordered categorical ) # str and ordered categorical
y = np.random.permutation([0, 1] * 150) y = rng_fixed_seed.permutation([0, 1] * 150)
X_test = pd.DataFrame( X_test = pd.DataFrame(
{ {
"A": np.random.permutation(["a", "b", "e"] * 20), # unseen category "A": rng_fixed_seed.permutation(["a", "b", "e"] * 20), # unseen category
"B": np.random.permutation([1, 3] * 30), "B": rng_fixed_seed.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), "C": rng_fixed_seed.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30), "D": rng_fixed_seed.permutation([True, False] * 30),
"E": pd.Categorical(np.random.permutation(["z", "y"] * 30), ordered=True), "E": pd.Categorical(rng_fixed_seed.permutation(["z", "y"] * 30), ordered=True),
} }
) )
np.random.seed() # reset seed
cat_cols_actual = ["A", "B", "C", "D"] cat_cols_actual = ["A", "B", "C", "D"]
cat_cols_to_store = cat_cols_actual + ["E"] cat_cols_to_store = cat_cols_actual + ["E"]
X[cat_cols_actual] = X[cat_cols_actual].astype("category") X[cat_cols_actual] = X[cat_cols_actual].astype("category")
@ -620,21 +620,21 @@ def test_pandas_categorical():
assert gbm6.booster_.pandas_categorical == cat_values assert gbm6.booster_.pandas_categorical == cat_values
def test_pandas_sparse(): def test_pandas_sparse(rng):
pd = pytest.importorskip("pandas") pd = pytest.importorskip("pandas")
X = pd.DataFrame( X = pd.DataFrame(
{ {
"A": pd.arrays.SparseArray(np.random.permutation([0, 1, 2] * 100)), "A": pd.arrays.SparseArray(rng.permutation([0, 1, 2] * 100)),
"B": pd.arrays.SparseArray(np.random.permutation([0.0, 0.1, 0.2, -0.1, 0.2] * 60)), "B": pd.arrays.SparseArray(rng.permutation([0.0, 0.1, 0.2, -0.1, 0.2] * 60)),
"C": pd.arrays.SparseArray(np.random.permutation([True, False] * 150)), "C": pd.arrays.SparseArray(rng.permutation([True, False] * 150)),
} }
) )
y = pd.Series(pd.arrays.SparseArray(np.random.permutation([0, 1] * 150))) y = pd.Series(pd.arrays.SparseArray(rng.permutation([0, 1] * 150)))
X_test = pd.DataFrame( X_test = pd.DataFrame(
{ {
"A": pd.arrays.SparseArray(np.random.permutation([0, 2] * 30)), "A": pd.arrays.SparseArray(rng.permutation([0, 2] * 30)),
"B": pd.arrays.SparseArray(np.random.permutation([0.0, 0.1, 0.2, -0.1] * 15)), "B": pd.arrays.SparseArray(rng.permutation([0.0, 0.1, 0.2, -0.1] * 15)),
"C": pd.arrays.SparseArray(np.random.permutation([True, False] * 30)), "C": pd.arrays.SparseArray(rng.permutation([True, False] * 30)),
} }
) )
for dtype in pd.concat([X.dtypes, X_test.dtypes, pd.Series(y.dtypes)]): for dtype in pd.concat([X.dtypes, X_test.dtypes, pd.Series(y.dtypes)]):
@ -1073,11 +1073,11 @@ def test_multiple_eval_metrics():
assert "binary_logloss" in gbm.evals_result_["training"] assert "binary_logloss" in gbm.evals_result_["training"]
def test_nan_handle(): def test_nan_handle(rng):
nrows = 100 nrows = 100
ncols = 10 ncols = 10
X = np.random.randn(nrows, ncols) X = rng.standard_normal(size=(nrows, ncols))
y = np.random.randn(nrows) + np.full(nrows, 1e30) y = rng.standard_normal(size=(nrows,)) + np.full(nrows, 1e30)
weight = np.zeros(nrows) weight = np.zeros(nrows)
params = {"n_estimators": 20, "verbose": -1} params = {"n_estimators": 20, "verbose": -1}
params_fit = {"X": X, "y": y, "sample_weight": weight, "eval_set": (X, y), "callbacks": [lgb.early_stopping(5)]} params_fit = {"X": X, "y": y, "sample_weight": weight, "eval_set": (X, y), "callbacks": [lgb.early_stopping(5)]}
@ -1410,13 +1410,13 @@ def test_validate_features(task):
@pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"]) @pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"]) @pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_Series", "pd_DataFrame"])
@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"]) @pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "regression"])
def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task): def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task, rng):
if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED: if any(t.startswith("pd_") for t in [X_type, y_type]) and not PANDAS_INSTALLED:
pytest.skip("pandas is not installed") pytest.skip("pandas is not installed")
if any(t.startswith("dt_") for t in [X_type, y_type]) and not DATATABLE_INSTALLED: if any(t.startswith("dt_") for t in [X_type, y_type]) and not DATATABLE_INSTALLED:
pytest.skip("datatable is not installed") pytest.skip("datatable is not installed")
X, y, g = _create_data(task, n_samples=2_000) X, y, g = _create_data(task, n_samples=2_000)
weights = np.abs(np.random.randn(y.shape[0])) weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
if task == "binary-classification" or task == "regression": if task == "binary-classification" or task == "regression":
init_score = np.full_like(y, np.mean(y)) init_score = np.full_like(y, np.mean(y))
@ -1487,13 +1487,13 @@ def test_classification_and_regression_minimally_work_with_all_all_accepted_data
@pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"]) @pytest.mark.parametrize("X_type", ["dt_DataTable", "list2d", "numpy", "scipy_csc", "scipy_csr", "pd_DataFrame"])
@pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_DataFrame", "pd_Series"]) @pytest.mark.parametrize("y_type", ["list1d", "numpy", "pd_DataFrame", "pd_Series"])
@pytest.mark.parametrize("g_type", ["list1d_float", "list1d_int", "numpy", "pd_Series"]) @pytest.mark.parametrize("g_type", ["list1d_float", "list1d_int", "numpy", "pd_Series"])
def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type): def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type, rng):
if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED: if any(t.startswith("pd_") for t in [X_type, y_type, g_type]) and not PANDAS_INSTALLED:
pytest.skip("pandas is not installed") pytest.skip("pandas is not installed")
if any(t.startswith("dt_") for t in [X_type, y_type, g_type]) and not DATATABLE_INSTALLED: if any(t.startswith("dt_") for t in [X_type, y_type, g_type]) and not DATATABLE_INSTALLED:
pytest.skip("datatable is not installed") pytest.skip("datatable is not installed")
X, y, g = _create_data(task="ranking", n_samples=1_000) X, y, g = _create_data(task="ranking", n_samples=1_000)
weights = np.abs(np.random.randn(y.shape[0])) weights = np.abs(rng.standard_normal(size=(y.shape[0],)))
init_score = np.full_like(y, np.mean(y)) init_score = np.full_like(y, np.mean(y))
X_valid = X * 2 X_valid = X * 2