[python] migrate to pathlib in python tests (#4435)

This commit is contained in:
Nikita Titov 2021-07-04 07:31:41 +03:00 committed by GitHub
Parent 3594f36937
Commit cff80442e1
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 81 additions and 90 deletions

View file

@@ -1,6 +1,7 @@
# coding: utf-8
import ctypes
import os
from os import environ
from pathlib import Path
from platform import system
import numpy as np
@@ -8,28 +9,27 @@ from scipy import sparse
def find_lib_path():
if os.environ.get('LIGHTGBM_BUILD_DOC', False):
if environ.get('LIGHTGBM_BUILD_DOC', False):
# we don't need lib_lightgbm while building docs
return []
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
curr_path = Path(__file__).parent.absolute()
dll_path = [curr_path,
os.path.join(curr_path, '../../'),
os.path.join(curr_path, '../../python-package/lightgbm/compile'),
os.path.join(curr_path, '../../python-package/compile'),
os.path.join(curr_path, '../../lib/')]
curr_path.parents[1],
curr_path.parents[1] / 'python-package' / 'lightgbm' / 'compile',
curr_path.parents[1] / 'python-package' / 'compile',
curr_path.parents[1] / 'lib']
if system() in ('Windows', 'Microsoft'):
dll_path.append(os.path.join(curr_path, '../../python-package/compile/Release/'))
dll_path.append(os.path.join(curr_path, '../../python-package/compile/windows/x64/DLL/'))
dll_path.append(os.path.join(curr_path, '../../Release/'))
dll_path.append(os.path.join(curr_path, '../../windows/x64/DLL/'))
dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path]
dll_path.append(curr_path.parents[1] / 'python-package' / 'compile' / 'Release/')
dll_path.append(curr_path.parents[1] / 'python-package' / 'compile' / 'windows' / 'x64' / 'DLL')
dll_path.append(curr_path.parents[1] / 'Release')
dll_path.append(curr_path.parents[1] / 'windows' / 'x64' / 'DLL')
dll_path = [p / 'lib_lightgbm.dll' for p in dll_path]
else:
dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
dll_path = [p / 'lib_lightgbm.so' for p in dll_path]
lib_path = [str(p) for p in dll_path if p.is_file()]
if not lib_path:
dll_path = [os.path.realpath(p) for p in dll_path]
dll_path_joined = '\n'.join(dll_path)
dll_path_joined = '\n'.join(map(str, dll_path))
raise Exception(f'Cannot find lightgbm library file in following paths:\n{dll_path_joined}')
return lib_path
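For readers less used to pathlib, here is a minimal sketch (not part of the diff) of how the calls introduced above map onto the os.path idioms they replace; the variable names follow the test code.

```python
from pathlib import Path

# os.path.dirname(os.path.abspath(__file__))  ->  Path(__file__).parent.absolute()
curr_path = Path(__file__).parent.absolute()

# os.path.join(curr_path, '../../lib/')       ->  curr_path.parents[1] / 'lib'
# parents[0] is one directory up, parents[1] two directories up, and so on,
# so the '../..' hops disappear from the joined paths.
lib_dir = curr_path.parents[1] / 'lib'

# os.path.exists(p) and os.path.isfile(p)     ->  p.is_file()
candidates = [lib_dir / 'lib_lightgbm.so']
existing = [str(p) for p in candidates if p.is_file()]
```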
@@ -62,7 +62,7 @@ def load_from_file(filename, reference):
ref = reference
handle = ctypes.c_void_p()
LIB.LGBM_DatasetCreateFromFile(
c_str(filename),
c_str(str(filename)),
c_str('max_bin=15'),
ref,
ctypes.byref(handle))
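The extra str() around filename exists because the test's c_str helper encodes a Python str for the C API, and a Path object has no .encode(). A small sketch, assuming c_str looks roughly like the helper in these tests:

```python
import ctypes
from pathlib import Path

def c_str(string):
    # Assumed shape of the helper used by these tests: encode a str for the C API.
    return ctypes.c_char_p(string.encode('utf-8'))

filename = Path('binary.train')
# c_str(filename) would raise AttributeError (Path has no .encode()),
# so callers convert explicitly: c_str(str(filename)).
handle_arg = c_str(str(filename))
```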
@@ -207,16 +207,13 @@ def free_dataset(handle):
def test_dataset():
train = load_from_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.train'), None)
test = load_from_mat(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test'), train)
binary_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'binary_classification'
train = load_from_file(binary_example_dir / 'binary.train', None)
test = load_from_mat(binary_example_dir / 'binary.test', train)
free_dataset(test)
test = load_from_csr(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test'), train)
test = load_from_csr(binary_example_dir / 'binary.test', train)
free_dataset(test)
test = load_from_csc(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test'), train)
test = load_from_csc(binary_example_dir / 'binary.test', train)
free_dataset(test)
save_to_binary(train, 'train.binary.bin')
free_dataset(train)
@@ -225,10 +222,9 @@ def test_dataset():
def test_booster():
train = load_from_mat(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.train'), None)
test = load_from_mat(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test'), train)
binary_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'binary_classification'
train = load_from_mat(binary_example_dir / 'binary.train', None)
test = load_from_mat(binary_example_dir / 'binary.test', train)
booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(
train,
@@ -263,8 +259,7 @@ def test_booster():
ctypes.byref(num_total_model),
ctypes.byref(booster2))
data = []
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test'), 'r') as inp:
with open(binary_example_dir / 'binary.test', 'r') as inp:
for line in inp.readlines():
data.append([float(x) for x in line.split('\t')[1:]])
mat = np.array(data, dtype=np.float64)
@@ -286,8 +281,7 @@ def test_booster():
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(
booster2,
c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test')),
c_str(str(binary_example_dir / 'binary.test')),
ctypes.c_int(0),
ctypes.c_int(0),
ctypes.c_int(0),
@@ -296,8 +290,7 @@ def test_booster():
c_str('preb.txt'))
LIB.LGBM_BoosterPredictForFile(
booster2,
c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/binary_classification/binary.test')),
c_str(str(binary_example_dir / 'binary.test')),
ctypes.c_int(0),
ctypes.c_int(0),
ctypes.c_int(10),
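By contrast, the plain open() call earlier in this file can take the Path directly, since open() supports os.PathLike objects (PEP 519, Python 3.6+); only the ctypes helpers need an explicit str(). A short illustration with a made-up path:

```python
from pathlib import Path

binary_test = Path('examples') / 'binary_classification' / 'binary.test'  # illustrative path

# open() accepts path-like objects, so the Path can be passed as-is.
# The ctypes-based helpers above cannot, which is why they get str(...) wrappers.
if binary_test.is_file():
    with open(binary_test, 'r') as inp:
        first_line = inp.readline()
```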

View file

@@ -1,7 +1,7 @@
# coding: utf-8
import glob
from pathlib import Path
import numpy as np
preds = [np.loadtxt(name) for name in glob.glob('*.pred')]
preds = [np.loadtxt(str(name)) for name in Path(__file__).parent.absolute().glob('*.pred')]
np.testing.assert_allclose(preds[0], preds[1])
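Two things change here besides the import: Path.glob searches the directory containing the script rather than the current working directory that a bare glob.glob('*.pred') would use, and it yields Path objects, which are wrapped in str() as a defensive conversion for np.loadtxt (recent NumPy also accepts Path). A minimal sketch:

```python
from pathlib import Path
import numpy as np

script_dir = Path(__file__).parent.absolute()

# Path.glob yields Path objects and searches script_dir, not the current
# working directory that a bare glob.glob('*.pred') would have used.
preds = [np.loadtxt(str(p)) for p in script_dir.glob('*.pred')]
```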

View file

@@ -1,7 +1,7 @@
# coding: utf-8
import filecmp
import numbers
import os
from pathlib import Path
import numpy as np
import pytest
@@ -153,8 +153,8 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
X = data[:, :-1]
Y = data[:, -1]
npy_bin_fname = os.path.join(tmpdir, 'data_from_npy.bin')
seq_bin_fname = os.path.join(tmpdir, 'data_from_seq.bin')
npy_bin_fname = str(tmpdir / 'data_from_npy.bin')
seq_bin_fname = str(tmpdir / 'data_from_seq.bin')
# Create dataset from numpy array directly.
ds = lgb.Dataset(X, label=Y, params=params)
@@ -175,9 +175,9 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
valid_X = valid_data[:, :-1]
valid_Y = valid_data[:, -1]
valid_npy_bin_fname = os.path.join(tmpdir, 'valid_data_from_npy.bin')
valid_seq_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq.bin')
valid_seq2_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq2.bin')
valid_npy_bin_fname = str(tmpdir / 'valid_data_from_npy.bin')
valid_seq_bin_fname = str(tmpdir / 'valid_data_from_seq.bin')
valid_seq2_bin_fname = str(tmpdir / 'valid_data_from_seq2.bin')
valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds)
valid_ds.save_binary(valid_npy_bin_fname)
@@ -222,10 +222,9 @@ def test_chunked_dataset_linear():
def test_subset_group():
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train.query'))
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
assert len(lgb_train.get_group()) == 201
subset = lgb_train.subset(list(range(10))).construct()
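Note that tmpdir in the hunks above is pytest's py.path.local fixture, not a pathlib.Path: it supports '/' joining, but the tests convert the result to str because the filename is ultimately handed to the C library as a string. A hypothetical test sketch (data and names are made up):

```python
import numpy as np
import lightgbm as lgb

def test_save_binary_sketch(tmpdir):
    # tmpdir is a py.path.local; '/' joining works like pathlib, but the result
    # is converted to str because Dataset.save_binary passes the name to the
    # C library as an encoded string.
    X = np.random.rand(100, 2)
    y = np.random.rand(100)
    bin_fname = str(tmpdir / 'data_from_npy.bin')
    lgb.Dataset(X, label=y).save_binary(bin_fname)
    # With the newer tmp_path fixture this would be a pathlib.Path instead:
    #   bin_fname = tmp_path / 'data_from_npy.bin'
```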

View file

@@ -1,20 +1,21 @@
# coding: utf-8
import os
from pathlib import Path
import numpy as np
from sklearn.datasets import load_svmlight_file
import lightgbm as lgb
EXAMPLES_DIR = Path(__file__).absolute().parents[2] / 'examples'
class FileLoader:
def __init__(self, directory, prefix, config_file='train.conf'):
directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), directory)
self.directory = directory
self.prefix = prefix
self.params = {'gpu_use_dp': True}
with open(os.path.join(directory, config_file), 'r') as f:
with open(self.directory / config_file, 'r') as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith('#'):
@@ -32,10 +33,10 @@ class FileLoader:
return mat[:, 1:], mat[:, 0], filename
def load_field(self, suffix):
return np.loadtxt(os.path.join(self.directory, f'{self.prefix}{suffix}'))
return np.loadtxt(str(self.directory / f'{self.prefix}{suffix}'))
def load_cpp_result(self, result_file='LightGBM_predict_result.txt'):
return np.loadtxt(os.path.join(self.directory, result_file))
return np.loadtxt(str(self.directory / result_file))
def train_predict_check(self, lgb_train, X_test, X_test_fn, sk_pred):
params = dict(self.params)
@@ -61,11 +62,11 @@ class FileLoader:
assert a == b, f
def path(self, suffix):
return os.path.join(self.directory, f'{self.prefix}{suffix}')
return str(self.directory / f'{self.prefix}{suffix}')
def test_binary():
fd = FileLoader('../../examples/binary_classification', 'binary')
fd = FileLoader(EXAMPLES_DIR / 'binary_classification', 'binary')
X_train, y_train, _ = fd.load_dataset('.train')
X_test, _, X_test_fn = fd.load_dataset('.test')
weight_train = fd.load_field('.train.weight')
@@ -78,7 +79,7 @@ def test_binary():
def test_binary_linear():
fd = FileLoader('../../examples/binary_classification', 'binary', 'train_linear.conf')
fd = FileLoader(EXAMPLES_DIR / 'binary_classification', 'binary', 'train_linear.conf')
X_train, y_train, _ = fd.load_dataset('.train')
X_test, _, X_test_fn = fd.load_dataset('.test')
weight_train = fd.load_field('.train.weight')
@@ -91,7 +92,7 @@ def test_binary_linear():
def test_multiclass():
fd = FileLoader('../../examples/multiclass_classification', 'multiclass')
fd = FileLoader(EXAMPLES_DIR / 'multiclass_classification', 'multiclass')
X_train, y_train, _ = fd.load_dataset('.train')
X_test, _, X_test_fn = fd.load_dataset('.test')
lgb_train = lgb.Dataset(X_train, y_train)
@@ -103,7 +104,7 @@ def test_multiclass():
def test_regression():
fd = FileLoader('../../examples/regression', 'regression')
fd = FileLoader(EXAMPLES_DIR / 'regression', 'regression')
X_train, y_train, _ = fd.load_dataset('.train')
X_test, _, X_test_fn = fd.load_dataset('.test')
init_score_train = fd.load_field('.train.init')
@@ -116,7 +117,7 @@ def test_regression():
def test_lambdarank():
fd = FileLoader('../../examples/lambdarank', 'rank')
fd = FileLoader(EXAMPLES_DIR / 'lambdarank', 'rank')
X_train, y_train, _ = fd.load_dataset('.train', is_sparse=True)
X_test, _, X_test_fn = fd.load_dataset('.test', is_sparse=True)
group_train = fd.load_field('.train.query')
@@ -131,7 +132,7 @@ def test_lambdarank():
def test_xendcg():
fd = FileLoader('../../examples/xendcg', 'rank')
fd = FileLoader(EXAMPLES_DIR / 'xendcg', 'rank')
X_train, y_train, _ = fd.load_dataset('.train', is_sparse=True)
X_test, _, X_test_fn = fd.load_dataset('.test', is_sparse=True)
group_train = fd.load_field('.train.query')
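With the module-level EXAMPLES_DIR constant, FileLoader now receives a ready-made Path and just stores it. A rough sketch of how the reworked pieces fit together; the constructor is trimmed to the path handling and omits the config parsing shown in the diff:

```python
from pathlib import Path
import numpy as np

EXAMPLES_DIR = Path(__file__).absolute().parents[2] / 'examples'

class FileLoader:
    def __init__(self, directory, prefix, config_file='train.conf'):
        # directory already arrives as a Path (e.g. EXAMPLES_DIR / 'regression'),
        # so no os.path.join against the test file's location is needed anymore.
        self.directory = directory
        self.prefix = prefix

    def load_field(self, suffix):
        # str() keeps np.loadtxt happy on NumPy versions without path-like support.
        return np.loadtxt(str(self.directory / f'{self.prefix}{suffix}'))

    def path(self, suffix):
        return str(self.directory / f'{self.prefix}{suffix}')
```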

View file

@@ -1090,7 +1090,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
tmp_file = str(tmp_path / "model-1.pkl")
tmp_file = tmp_path / "model-1.pkl"
_pickle(
obj=dask_model,
filepath=tmp_file,
@@ -1101,7 +1101,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer
)
local_tmp_file = str(tmp_path / "local-model-1.pkl")
local_tmp_file = tmp_path / "local-model-1.pkl"
_pickle(
obj=local_model,
filepath=local_tmp_file,
@@ -1146,7 +1146,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
local_model.client
local_model.client_
tmp_file2 = str(tmp_path / "model-2.pkl")
tmp_file2 = tmp_path / "model-2.pkl"
_pickle(
obj=dask_model,
filepath=tmp_file2,
@@ -1157,7 +1157,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer
)
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
local_tmp_file2 = tmp_path / "local-model-2.pkl"
_pickle(
obj=local_model,
filepath=local_tmp_file2,
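tmp_path is already a pathlib.Path, so the str() wrappers can go: both open() (used by pickle) and joblib.dump accept path-like filenames. A sketch of what the _pickle helper referenced above presumably looks like (the real body is not shown in this diff):

```python
import pickle
from pathlib import Path

import joblib

def _pickle(obj, filepath, serializer):
    # Assumed shape of the helper used in test_dask.py (a sketch, not the real body).
    # open() accepts a pathlib.Path directly (PEP 519), and joblib.dump accepts
    # path-like filenames, so tmp_path / "model-1.pkl" no longer needs str().
    if serializer == 'pickle':
        with open(filepath, 'wb') as f:
            pickle.dump(obj, f)
    elif serializer == 'joblib':
        joblib.dump(obj, filepath)
    else:
        raise ValueError(f'Unrecognized serializer type: {serializer}')
```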

View file

@@ -2,10 +2,10 @@
import copy
import itertools
import math
import os
import pickle
import platform
import random
from pathlib import Path
import numpy as np
import psutil
@@ -568,8 +568,9 @@ def test_auc_mu():
lgb.train(params, lgb_X, num_boost_round=100, valid_sets=[lgb_X], evals_result=results)
assert results['training']['auc_mu'][-1] == pytest.approx(1)
# test loading class weights
Xy = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/multiclass_classification/multiclass.train'))
Xy = np.loadtxt(
str(Path(__file__).absolute().parents[2] / 'examples' / 'multiclass_classification' / 'multiclass.train')
)
y = Xy[:, 0]
X = Xy[:, 1:]
lgb_X = lgb.Dataset(X, label=y)
@@ -646,7 +647,6 @@ def test_continue_train():
assert ret < 2.0
assert evals_result['valid_0']['l1'][-1] == pytest.approx(ret)
np.testing.assert_allclose(evals_result['valid_0']['l1'], evals_result['valid_0']['custom_mae'])
os.remove(model_name)
def test_continue_train_reused_dataset():
@@ -748,10 +748,9 @@ def test_cv():
verbose_eval=False)
np.testing.assert_allclose(cv_res_gen['l2-mean'], cv_res_obj['l2-mean'])
# LambdaRank
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train.query'))
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
params_lambdarank = {'objective': 'lambdarank', 'verbose': -1, 'eval_at': 3}
lgb_train = lgb.Dataset(X_train, y_train, group=q_train)
# ... with l2 metric
@@ -2262,8 +2261,9 @@ def test_forced_bins():
x[:, 0] = np.arange(0, 1, 0.01)
x[:, 1] = -np.arange(0, 1, 0.01)
y = np.arange(0, 1, 0.01)
forcedbins_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/regression/forced_bins.json')
forcedbins_filename = str(
Path(__file__).absolute().parents[2] / 'examples' / 'regression' / 'forced_bins.json'
)
params = {'objective': 'regression_l1',
'max_bin': 5,
'forcedbins_filename': forcedbins_filename,
@@ -2285,8 +2285,9 @@ def test_forced_bins():
est = lgb.train(params, lgb_x, num_boost_round=20)
predicted = est.predict(new_x)
assert len(np.unique(predicted)) == 3
params['forcedbins_filename'] = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/regression/forced_bins2.json')
params['forcedbins_filename'] = str(
Path(__file__).absolute().parents[2] / 'examples' / 'regression' / 'forced_bins2.json'
)
params['max_bin'] = 11
lgb_x = lgb.Dataset(x[:, :1], label=y)
est = lgb.train(params, lgb_x, num_boost_round=50)
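The remaining str() conversions in this file are defensive: np.loadtxt on older NumPy releases may not accept Path, and forcedbins_filename is kept as a plain string, presumably to avoid relying on how a Path would be stringified when the params dict is flattened into the parameter string passed to the C library. For example:

```python
from pathlib import Path

examples_dir = Path(__file__).absolute().parents[2] / 'examples'

# Kept as str so the value lands in the parameter string as a plain path,
# without depending on how a Path object would be converted along the way.
params = {
    'objective': 'regression_l1',
    'max_bin': 5,
    'forcedbins_filename': str(examples_dir / 'regression' / 'forced_bins.json'),
}
```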

View file

@@ -1,7 +1,7 @@
# coding: utf-8
import itertools
import math
import os
from pathlib import Path
import joblib
import numpy as np
@@ -113,14 +113,11 @@ def test_multiclass():
def test_lambdarank():
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train'))
X_test, y_test = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.test'))
q_train = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.train.query'))
q_test = np.loadtxt(os.path.join(os.path.dirname(os.path.realpath(__file__)),
'../../examples/lambdarank/rank.test.query'))
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
X_test, y_test = load_svmlight_file(str(rank_example_dir / 'rank.test'))
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
q_test = np.loadtxt(str(rank_example_dir / 'rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=50)
gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)],
eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10, verbose=False,
@@ -131,11 +128,11 @@ def test_lambdarank():
def test_xendcg():
dir_path = os.path.dirname(os.path.realpath(__file__))
X_train, y_train = load_svmlight_file(os.path.join(dir_path, '../../examples/xendcg/rank.train'))
X_test, y_test = load_svmlight_file(os.path.join(dir_path, '../../examples/xendcg/rank.test'))
q_train = np.loadtxt(os.path.join(dir_path, '../../examples/xendcg/rank.train.query'))
q_test = np.loadtxt(os.path.join(dir_path, '../../examples/xendcg/rank.test.query'))
xendcg_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'xendcg'
X_train, y_train = load_svmlight_file(str(xendcg_example_dir / 'rank.train'))
X_test, y_test = load_svmlight_file(str(xendcg_example_dir / 'rank.test'))
q_train = np.loadtxt(str(xendcg_example_dir / 'rank.train.query'))
q_test = np.loadtxt(str(xendcg_example_dir / 'rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=50, objective='rank_xendcg', random_state=5, n_jobs=1)
gbm.fit(X_train, y_train, group=q_train, eval_set=[(X_test, y_test)],
eval_group=[q_test], eval_at=[1, 3], early_stopping_rounds=10, verbose=False,
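The str() wrappers stay in place here as well, a reasonable guess being that older scikit-learn and NumPy releases expect plain strings rather than path-like objects. The consolidation itself is the main win; a sketch:

```python
from pathlib import Path
import numpy as np
from sklearn.datasets import load_svmlight_file

# One Path anchors the data files that the old code located with repeated
# os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/...') calls.
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'

# str() is a defensive conversion for library versions that expect plain strings.
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
```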

View file

@@ -10,7 +10,7 @@ def test_register_logger(tmp_path):
logger = logging.getLogger("LightGBM")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s | %(message)s')
log_filename = str(tmp_path / "LightGBM_test_logger.log")
log_filename = tmp_path / "LightGBM_test_logger.log"
file_handler = logging.FileHandler(log_filename, mode="w", encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
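Finally, the log filename can stay a Path because logging.FileHandler calls os.fspath() on its filename argument (Python 3.6+). A hypothetical helper, not taken from the test, showing the same pattern:

```python
import logging
from pathlib import Path

def build_test_logger(tmp_path: Path) -> Path:
    # logging.FileHandler accepts os.PathLike filenames, so the pathlib.Path
    # produced by the tmp_path fixture can be passed without str().
    log_filename = tmp_path / "LightGBM_test_logger.log"
    handler = logging.FileHandler(log_filename, mode="w", encoding="utf-8")
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(levelname)s | %(message)s'))
    logging.getLogger("LightGBM").addHandler(handler)
    return log_filename
```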