update freeform2 parser and add python tests
This commit is contained in:
Родитель
6de49eff7e
Коммит
f6fe40b0bb
|
@ -3,6 +3,8 @@
|
|||
#include <boost/algorithm/string/split.hpp>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
#include <boost/lexical_cast.hpp>
|
||||
#include <boost/property_tree/ptree.hpp>
|
||||
#include <boost/property_tree/json_parser.hpp>
|
||||
#include "TransformProcessor.h"
|
||||
using namespace std;
|
||||
|
||||
|
@ -11,25 +13,25 @@ namespace LightGBM {
|
|||
class FreeForm2Parser: public Parser {
|
||||
public:
|
||||
FreeForm2Parser(std::string config_str) {
|
||||
string label_key = "labelId:";
|
||||
string expr_key = "transform:\n";
|
||||
string header_key = "header:\n";
|
||||
|
||||
size_t start_pos = config_str.find(label_key);
|
||||
config_str.erase(0, start_pos);
|
||||
size_t end_pos = config_str.find("\n");
|
||||
string label_line = config_str.substr(label_key.size(), end_pos);
|
||||
int label_idx = std::stod(label_line);
|
||||
|
||||
start_pos = config_str.find(expr_key);
|
||||
config_str.erase(0, start_pos);
|
||||
end_pos = config_str.find("end of transform");
|
||||
string transform_str = config_str.substr(expr_key.size(), end_pos);
|
||||
|
||||
start_pos = config_str.find(header_key);
|
||||
config_str.erase(0, start_pos);
|
||||
end_pos = config_str.find("end of header");
|
||||
string header_str = config_str.substr(header_key.size(), end_pos);
|
||||
// config should follow json format.
|
||||
std::stringstream config_ss(config_str);
|
||||
boost::property_tree::ptree ptree;
|
||||
boost::property_tree::read_json(config_ss, ptree);
|
||||
int label_idx = -1;
|
||||
auto label_node = ptree.get_child_optional("labelId");
|
||||
if (label_node) {
|
||||
label_idx = label_node->get_value<int>();
|
||||
}
|
||||
std::string transform_str = "";
|
||||
auto transform_node = ptree.get_child_optional("transform");
|
||||
if (transform_node) {
|
||||
transform_str = transform_node->get_value<std::string>();
|
||||
}
|
||||
std::string header_str = "";
|
||||
auto header_node = ptree.get_child_optional("header");
|
||||
if (header_node) {
|
||||
header_str = header_node->get_value<std::string>();
|
||||
}
|
||||
Log::Info("Initializing transform processor.");
|
||||
transform_.reset(new TransformProcessor(transform_str, header_str, label_idx));
|
||||
}
|
||||
|
|
|
@ -3,5 +3,4 @@ from pathlib import Path
|
|||
|
||||
CUSTOM_PARSER_LIB_NAME = 'lib_custom_parser.so'
# Load the shared libraries that sit in the same directory as this file,
# in the order listed.
for p in ['lib_transform.so', 'lib_lightgbm.so', CUSTOM_PARSER_LIB_NAME]:
    ctypes.cdll.LoadLibrary(str(Path(__file__).resolve().parent / p))
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# run command `sh ./scripts/publish_python_package.sh` in repo root dir.
|
||||
lgb_python_pkg_dir="./external_libs/LightGBM/python-package"
|
||||
# compile transformation, lightgbm, and customized parser libs.
|
||||
# rm -rf build && mkdir build &&
|
||||
cd build && cmake ../ && make -j4 && cd ../ || exit -1
|
||||
rm -rf build && mkdir build && cd build && cmake ../ && make -j4 && cd ../ || exit -1
|
||||
# copy all shared libs to lightgbm python package directory.
|
||||
cp ./lib_custom_parser.so ${lgb_python_pkg_dir}/lightgbm && \
|
||||
cp ./src/lib_transform.so ${lgb_python_pkg_dir}/lightgbm && \
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -1,33 +0,0 @@
|
|||
className:FreeForm2Parser
|
||||
labelId:10
|
||||
transform:
|
||||
[Input:1]
|
||||
Line1=(+ feature_1 feature_2)
|
||||
Transform=FreeForm2
|
||||
Slope=1
|
||||
Intercept=0
|
||||
|
||||
[Input:2]
|
||||
Transform=FreeForm2
|
||||
Line1=(* feature_1 feature_3)
|
||||
|
||||
[Input:3]
|
||||
Transform=FreeForm2
|
||||
Line1=(max feature_6 feature_7)
|
||||
|
||||
[Input:4]
|
||||
Transform=Linear
|
||||
Name=feature_8
|
||||
Intercept=0
|
||||
Slope=1
|
||||
|
||||
[Input:5]
|
||||
Transform=Linear
|
||||
Name=feature_9
|
||||
Intercept=0
|
||||
Slope=1
|
||||
end of transform
|
||||
|
||||
header:
|
||||
feature_0 feature_1 feature_2 feature_3 feature_4 feature_5 feature_6 feature_7 feature_8 feature_9 labels
|
||||
end of header
|
|
@ -0,0 +1 @@
|
|||
{"className": "FreeForm2Parser", "transform": "[Input:1]\nLine1=(+ feature_1 feature_2)\nTransform=FreeForm2\nSlope=1\nIntercept=0\n\n[Input:2]\nTransform=FreeForm2\nLine1=(* feature_1 feature_3)\n\n[Input:3]\nTransform=FreeForm2\nLine1=(max feature_6 feature_7)\n\n[Input:4]\nTransform=Linear\nName=feature_8\nIntercept=0\nSlope=1\n\n[Input:5]\nTransform=Linear\nName=feature_9\nIntercept=0\nSlope=1\n", "header": "feature_0\tfeature_1\tfeature_2\tfeature_3\tfeature_4\tfeature_5\tfeature_6\tfeature_7\tfeature_8\tfeature_9\tlabels"}
|
|
@ -0,0 +1,152 @@
|
|||
import json
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
# Pair of paths describing one test dataset: the data file and its parser config file.
Dataset = namedtuple('Dataset', ['data', 'parser_config'])
|
||||
|
||||
|
||||
def root_path():
    """Return the ``data`` directory located one level above this file's directory."""
    return Path(__file__).parent.parent / "data"
|
||||
|
||||
|
||||
def init_dataset(data_dir):
    """Build a ``Dataset`` for ``data_dir``.

    Args:
        data_dir: directory path (must support ``/``) holding the dataset files.

    Returns:
        ``Dataset`` whose ``data`` is ``data_dir / "input.tsv"`` and whose
        ``parser_config`` is ``data_dir / "freeform_config.json"``.
    """
    return Dataset(data_dir / "input.tsv", data_dir / "freeform_config.json")
|
||||
|
||||
|
||||
# Dataset used by the lambdarank tests in this file.
rank_ds = init_dataset(root_path() / "transform_rank_data")
|
||||
# Dataset used by the binary-objective tests in this file.
simple_ds = init_dataset(root_path() / "transform_simple_data")
|
||||
|
||||
|
||||
def generate_ds_with_header(ds, out_dir):
    """Write a copy of ``ds`` whose header is in the data file instead of the parser config.

    The "header" entry is removed from the parser config and its tab-separated
    names become the first row of the copied TSV data file.

    Args:
        ds: source ``Dataset``; its data file is read as a header-less TSV.
        out_dir: directory path receiving ``input.tsv`` and ``freeform_config.json``.

    Returns:
        ``Dataset`` pointing at the two files written into ``out_dir``.

    Raises:
        KeyError: if the source parser config has no "header" entry.
    """
    data_path, config_path = out_dir / "input.tsv", out_dir / "freeform_config.json"
    with open(ds.parser_config) as fin:
        config = json.load(fin)

    # Move the header out of the config; the remaining keys are written unchanged.
    header_str = config.pop("header")
    with open(config_path, 'w') as fout:
        json.dump(config, fout)

    df = pd.read_csv(ds.data, sep='\t', header=None)
    df.to_csv(data_path, index=False, header=header_str.strip().split('\t'), sep='\t')
    return init_dataset(out_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def simple_ds_with_header(tmp_path):
    """Copy of ``simple_ds`` written under ``tmp_path`` with the header moved into the data file."""
    return generate_ds_with_header(simple_ds, tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
def simple_data_no_label(simple_ds_with_header, tmp_path):
    """Path to a TSV copy of the header-carrying simple data with the "labels" column removed."""
    out_path = tmp_path / "simple_input_no_label.tsv"
    df = pd.read_csv(simple_ds_with_header.data, sep='\t', header=0)
    df.drop(columns="labels").to_csv(out_path, index=False, sep='\t')
    return out_path
|
||||
|
||||
|
||||
@pytest.fixture
def trained_model_path(tmp_path):
    """Path under ``tmp_path`` where a test may save its trained model."""
    return tmp_path / "model.txt"
|
||||
|
||||
|
||||
@pytest.fixture
def params():
    """Training parameters for the lambdarank tests ('label': 3, 'query': 0, 10 trees)."""
    return {
        'boosting': 'gbdt',
        'learning_rate': 0.1,
        'label': 3,
        'query': 0,
        'objective': 'lambdarank',
        'metric': 'ndcg',
        'num_trees': 10,
        'num_leaves': 31,
        # Comma-separated gains "0,1,...,100".
        'label_gain': ','.join(str(i) for i in range(101)),
        'force_col_wise': True,
        'deterministic': True
    }
|
||||
|
||||
|
||||
@pytest.fixture
def binary_params():
    """Training parameters for the binary-objective tests ('label': 10, AUC metric, 10 trees)."""
    return {
        'boosting': 'gbdt',
        'learning_rate': 0.1,
        'label': 10,
        'objective': 'binary',
        'metric': 'auc',
        'num_trees': 10,
        'num_leaves': 31,
        'deterministic': True
    }
|
||||
|
||||
|
||||
def test_e2e(params, trained_model_path):
    """Train on the rank dataset, then check predictions before and after a save/load round trip."""
    expected_pred = np.array([0.83267298, 0.388454, 0.35369267, 0.60330376, -1.24218415])
    train_data = lgb.Dataset(rank_ds.data, params={"parser_config_file": rank_ds.parser_config})
    # train and predict.
    bst = lgb.train(params, train_data, valid_sets=[train_data])
    pred = bst.predict(rank_ds.data)
    np.testing.assert_allclose(pred[:5], expected_pred)
    # save model.
    bst.save_model(trained_model_path)
    # load model and predict again; the reloaded booster must give the same output.
    bst = lgb.Booster(model_file=trained_model_path)
    pred = bst.predict(rank_ds.data)
    np.testing.assert_allclose(pred[:5], expected_pred)
|
||||
|
||||
|
||||
def test_train_data_no_header(binary_params, simple_ds_with_header, trained_model_path):
    """Train on header-less data, validate on the header-carrying copy, and
    check that the first prediction is the same for both input forms."""
    train_data = lgb.Dataset(simple_ds.data, params={"parser_config_file": simple_ds.parser_config})
    valid_data = lgb.Dataset(simple_ds_with_header.data, params={
        "parser_config_file": simple_ds_with_header.parser_config, "header": True})
    bst = lgb.train(binary_params, train_data, valid_sets=[valid_data])
    expected_pred = 0.4894574
    # predict data with no header.
    pred = bst.predict(simple_ds.data)
    np.testing.assert_allclose(pred[:1], expected_pred)
    # predict data with header.
    pred = bst.predict(simple_ds_with_header.data, data_has_header=True)
    np.testing.assert_allclose(pred[:1], expected_pred)
|
||||
|
||||
|
||||
def test_train_data_with_header(binary_params, simple_ds_with_header):
    """Train on the header-carrying data and check that the first prediction
    is the same for header-less and header-carrying prediction input."""
    train_data = lgb.Dataset(simple_ds_with_header.data, params={
        "parser_config_file": simple_ds_with_header.parser_config, "header": True})
    bst = lgb.train(binary_params, train_data, valid_sets=[train_data])
    expected_pred = 0.4894574
    # predict data with no header.
    pred = bst.predict(simple_ds.data)
    np.testing.assert_allclose(pred[:1], expected_pred)
    # predict data with header.
    pred = bst.predict(simple_ds_with_header.data, data_has_header=True)
    np.testing.assert_allclose(pred[:1], expected_pred)
|
||||
|
||||
|
||||
def test_set_label_by_name(params, capsys):
    """Select the label with "name:Rating" and check both the log message
    written to stdout and the first five predictions."""
    train_data = lgb.Dataset(rank_ds.data, params={"parser_config_file": rank_ds.parser_config})
    # Overrides the 'label' entry of the dict returned by the params fixture.
    params['label'] = "name:Rating"
    bst = lgb.train(params, train_data, valid_sets=[train_data])
    captured = capsys.readouterr()
    assert "Using column Rating as label" in captured.out
    pred = bst.predict(rank_ds.data)
    np.testing.assert_allclose(pred[:5], np.array([0.83267298, 0.388454, 0.35369267, 0.60330376, -1.24218415]))
|
||||
|
||||
|
||||
def test_predict_data_no_label(simple_data_no_label, binary_params):
    """Train on the simple dataset, then predict on a header-carrying file
    that has no "labels" column and check the first five predictions."""
    train_data = lgb.Dataset(simple_ds.data,
                             params={"parser_config_file": simple_ds.parser_config})
    bst = lgb.train(binary_params, train_data, valid_sets=[train_data])
    pred = bst.predict(simple_data_no_label, data_has_header=True)
    np.testing.assert_allclose(pred[:5], np.array([0.4894574, 0.43920928, 0.71112129, 0.43920928, 0.39602784]))
|
||||
|
||||
|
||||
def test_train_label_id_less_than_transformed_feature_num(binary_params):
    """Train and predict on the simple dataset with ``binary_params`` and
    check the first five predictions."""
    train_data = lgb.Dataset(simple_ds.data,
                             params={"parser_config_file": simple_ds.parser_config})
    bst = lgb.train(binary_params, train_data, valid_sets=[train_data])
    pred = bst.predict(simple_ds.data)
    np.testing.assert_allclose(pred[:5], np.array([0.4894574, 0.43920928, 0.71112129, 0.43920928, 0.39602784]))
|
Загрузка…
Ссылка в новой задаче