зеркало из https://github.com/microsoft/LightGBM.git
[tests][python] Handle data types more accurately in C API test (#4297)
This commit is contained in:
Родитель
e35ed5f6e5
Коммит
272fedb95a
|
@ -52,10 +52,6 @@ dtype_int32 = 2
|
|||
dtype_int64 = 3
|
||||
|
||||
|
||||
def c_array(ctype, values):
    """Build a ctypes array of ``ctype`` initialised from ``values``.

    Parameters
    ----------
    ctype : ctypes data type
        Element type of the resulting array (e.g. ``ctypes.c_int``).
    values : sequence
        Items copied into the array; must support ``len()`` and iteration.

    Returns
    -------
    ctypes array
        An instance of ``ctype * len(values)`` holding the given items.
    """
    # Multiplying a ctypes type by an integer yields a fixed-length array type.
    array_type = ctype * len(values)
    return array_type(*values)
|
||||
|
||||
|
||||
def c_str(string):
    """Convert a Python string into a ctypes C string.

    Parameters
    ----------
    string : str
        Text to pass to the C library.

    Returns
    -------
    ctypes.c_char_p
        Pointer wrapping the UTF-8 encoded bytes of ``string``.
    """
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
|
||||
|
||||
|
@ -71,9 +67,9 @@ def load_from_file(filename, reference):
|
|||
ref,
|
||||
ctypes.byref(handle))
|
||||
print(LIB.LGBM_GetLastError())
|
||||
num_data = ctypes.c_long()
|
||||
num_data = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
|
||||
num_feature = ctypes.c_long()
|
||||
num_feature = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
|
||||
print(f'#data: {num_data.value} #feature: {num_feature.value}')
|
||||
return handle
|
||||
|
@ -91,7 +87,7 @@ def load_from_csr(filename, reference):
|
|||
values = line.split('\t')
|
||||
data.append([float(x) for x in values[1:]])
|
||||
label.append(float(values[0]))
|
||||
mat = np.array(data)
|
||||
mat = np.array(data, dtype=np.float64)
|
||||
label = np.array(label, dtype=np.float32)
|
||||
csr = sparse.csr_matrix(mat)
|
||||
handle = ctypes.c_void_p()
|
||||
|
@ -100,22 +96,27 @@ def load_from_csr(filename, reference):
|
|||
ref = reference
|
||||
|
||||
LIB.LGBM_DatasetCreateFromCSR(
|
||||
c_array(ctypes.c_int, csr.indptr),
|
||||
dtype_int32,
|
||||
c_array(ctypes.c_int, csr.indices),
|
||||
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
|
||||
dtype_float64,
|
||||
csr.indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
|
||||
ctypes.c_int(dtype_int32),
|
||||
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
|
||||
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
|
||||
ctypes.c_int(dtype_float64),
|
||||
ctypes.c_int64(len(csr.indptr)),
|
||||
ctypes.c_int64(len(csr.data)),
|
||||
ctypes.c_int64(csr.shape[1]),
|
||||
c_str('max_bin=15'),
|
||||
ref,
|
||||
ctypes.byref(handle))
|
||||
num_data = ctypes.c_long()
|
||||
num_data = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
|
||||
num_feature = ctypes.c_long()
|
||||
num_feature = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
|
||||
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
|
||||
LIB.LGBM_DatasetSetField(
|
||||
handle,
|
||||
c_str('label'),
|
||||
label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
ctypes.c_int(len(label)),
|
||||
ctypes.c_int(dtype_float32))
|
||||
print(f'#data: {num_data.value} #feature: {num_feature.value}')
|
||||
return handle
|
||||
|
||||
|
@ -128,31 +129,36 @@ def load_from_csc(filename, reference):
|
|||
values = line.split('\t')
|
||||
data.append([float(x) for x in values[1:]])
|
||||
label.append(float(values[0]))
|
||||
mat = np.array(data)
|
||||
mat = np.array(data, dtype=np.float64)
|
||||
label = np.array(label, dtype=np.float32)
|
||||
csr = sparse.csc_matrix(mat)
|
||||
csc = sparse.csc_matrix(mat)
|
||||
handle = ctypes.c_void_p()
|
||||
ref = None
|
||||
if reference is not None:
|
||||
ref = reference
|
||||
|
||||
LIB.LGBM_DatasetCreateFromCSC(
|
||||
c_array(ctypes.c_int, csr.indptr),
|
||||
dtype_int32,
|
||||
c_array(ctypes.c_int, csr.indices),
|
||||
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
|
||||
dtype_float64,
|
||||
ctypes.c_int64(len(csr.indptr)),
|
||||
ctypes.c_int64(len(csr.data)),
|
||||
ctypes.c_int64(csr.shape[0]),
|
||||
csc.indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
|
||||
ctypes.c_int(dtype_int32),
|
||||
csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
|
||||
csc.data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
|
||||
ctypes.c_int(dtype_float64),
|
||||
ctypes.c_int64(len(csc.indptr)),
|
||||
ctypes.c_int64(len(csc.data)),
|
||||
ctypes.c_int64(csc.shape[0]),
|
||||
c_str('max_bin=15'),
|
||||
ref,
|
||||
ctypes.byref(handle))
|
||||
num_data = ctypes.c_long()
|
||||
num_data = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
|
||||
num_feature = ctypes.c_long()
|
||||
num_feature = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
|
||||
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
|
||||
LIB.LGBM_DatasetSetField(
|
||||
handle,
|
||||
c_str('label'),
|
||||
label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
ctypes.c_int(len(label)),
|
||||
ctypes.c_int(dtype_float32))
|
||||
print(f'#data: {num_data.value} #feature: {num_feature.value}')
|
||||
return handle
|
||||
|
||||
|
@ -165,8 +171,8 @@ def load_from_mat(filename, reference):
|
|||
values = line.split('\t')
|
||||
data.append([float(x) for x in values[1:]])
|
||||
label.append(float(values[0]))
|
||||
mat = np.array(data)
|
||||
data = np.array(mat.reshape(mat.size), copy=False)
|
||||
mat = np.array(data, dtype=np.float64)
|
||||
data = np.array(mat.reshape(mat.size), dtype=np.float64, copy=False)
|
||||
label = np.array(label, dtype=np.float32)
|
||||
handle = ctypes.c_void_p()
|
||||
ref = None
|
||||
|
@ -174,19 +180,24 @@ def load_from_mat(filename, reference):
|
|||
ref = reference
|
||||
|
||||
LIB.LGBM_DatasetCreateFromMat(
|
||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
|
||||
dtype_float64,
|
||||
mat.shape[0],
|
||||
mat.shape[1],
|
||||
1,
|
||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
|
||||
ctypes.c_int(dtype_float64),
|
||||
ctypes.c_int32(mat.shape[0]),
|
||||
ctypes.c_int32(mat.shape[1]),
|
||||
ctypes.c_int(1),
|
||||
c_str('max_bin=15'),
|
||||
ref,
|
||||
ctypes.byref(handle))
|
||||
num_data = ctypes.c_long()
|
||||
num_data = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
|
||||
num_feature = ctypes.c_long()
|
||||
num_feature = ctypes.c_int(0)
|
||||
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
|
||||
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
|
||||
LIB.LGBM_DatasetSetField(
|
||||
handle,
|
||||
c_str('label'),
|
||||
label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
ctypes.c_int(len(label)),
|
||||
ctypes.c_int(dtype_float32))
|
||||
print(f'#data: {num_data.value} #feature: {num_feature.value}')
|
||||
return handle
|
||||
|
||||
|
@ -228,20 +239,25 @@ def test_booster():
|
|||
for i in range(1, 51):
|
||||
LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
|
||||
result = np.array([0.0], dtype=np.float64)
|
||||
out_len = ctypes.c_ulong(0)
|
||||
out_len = ctypes.c_int(0)
|
||||
LIB.LGBM_BoosterGetEval(
|
||||
booster,
|
||||
0,
|
||||
ctypes.c_int(0),
|
||||
ctypes.byref(out_len),
|
||||
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
|
||||
if i % 10 == 0:
|
||||
print(f'{i} iteration test AUC {result[0]:.6f}')
|
||||
LIB.LGBM_BoosterSaveModel(booster, 0, -1, 0, c_str('model.txt'))
|
||||
LIB.LGBM_BoosterSaveModel(
|
||||
booster,
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(-1),
|
||||
ctypes.c_int(0),
|
||||
c_str('model.txt'))
|
||||
LIB.LGBM_BoosterFree(booster)
|
||||
free_dataset(train)
|
||||
free_dataset(test)
|
||||
booster2 = ctypes.c_void_p()
|
||||
num_total_model = ctypes.c_long()
|
||||
num_total_model = ctypes.c_int(0)
|
||||
LIB.LGBM_BoosterCreateFromModelfile(
|
||||
c_str('model.txt'),
|
||||
ctypes.byref(num_total_model),
|
||||
|
@ -251,20 +267,20 @@ def test_booster():
|
|||
'../../examples/binary_classification/binary.test'), 'r') as inp:
|
||||
for line in inp.readlines():
|
||||
data.append([float(x) for x in line.split('\t')[1:]])
|
||||
mat = np.array(data)
|
||||
mat = np.array(data, dtype=np.float64)
|
||||
preb = np.zeros(mat.shape[0], dtype=np.float64)
|
||||
num_preb = ctypes.c_long()
|
||||
data = np.array(mat.reshape(mat.size), copy=False)
|
||||
num_preb = ctypes.c_int64(0)
|
||||
data = np.array(mat.reshape(mat.size), dtype=np.float64, copy=False)
|
||||
LIB.LGBM_BoosterPredictForMat(
|
||||
booster2,
|
||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
|
||||
dtype_float64,
|
||||
mat.shape[0],
|
||||
mat.shape[1],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
25,
|
||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
|
||||
ctypes.c_int(dtype_float64),
|
||||
ctypes.c_int32(mat.shape[0]),
|
||||
ctypes.c_int32(mat.shape[1]),
|
||||
ctypes.c_int(1),
|
||||
ctypes.c_int(1),
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(25),
|
||||
c_str(''),
|
||||
ctypes.byref(num_preb),
|
||||
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
|
||||
|
@ -272,20 +288,20 @@ def test_booster():
|
|||
booster2,
|
||||
c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'../../examples/binary_classification/binary.test')),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
25,
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(25),
|
||||
c_str(''),
|
||||
c_str('preb.txt'))
|
||||
LIB.LGBM_BoosterPredictForFile(
|
||||
booster2,
|
||||
c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'../../examples/binary_classification/binary.test')),
|
||||
0,
|
||||
0,
|
||||
10,
|
||||
25,
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(0),
|
||||
ctypes.c_int(10),
|
||||
ctypes.c_int(25),
|
||||
c_str(''),
|
||||
c_str('preb.txt'))
|
||||
LIB.LGBM_BoosterFree(booster2)
|
||||
|
|
Загрузка…
Ссылка в новой задаче