зеркало из https://github.com/microsoft/LightGBM.git
[python] add parameter object_hook to method dump_model (#4533)
* add parameter object_hook to function dump_model (python API) * eol * fix syntax * lint * better documentation * Update python-package/lightgbm/basic.py Co-authored-by: Nikita Titov <nekit94-08@mail.ru> Co-authored-by: xavier dupré <xavier.dupre@gmail.com> Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Родитель
4db10d86dc
Коммит
11d7608f2d
|
@ -3342,7 +3342,7 @@ class Booster:
|
|||
ret += _dump_pandas_categorical(self.pandas_categorical)
|
||||
return ret
|
||||
|
||||
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'):
|
||||
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split', object_hook=None):
|
||||
"""Dump Booster to JSON format.
|
||||
|
||||
Parameters
|
||||
|
@ -3357,6 +3357,15 @@ class Booster:
|
|||
What type of feature importance should be dumped.
|
||||
If "split", result contains numbers of times the feature is used in a model.
|
||||
If "gain", result contains total gains of splits which use the feature.
|
||||
object_hook : callable or None, optional (default=None)
|
||||
If not None, ``object_hook`` is a function called while parsing the json
|
||||
string returned by the C API. It may be used to alter the json, to store
|
||||
specific values while building the json structure. It avoids
|
||||
walking through the structure again. It saves a significant amount
|
||||
of time if the number of trees is huge.
|
||||
Signature is ``def object_hook(node: dict) -> dict``.
|
||||
None is equivalent to ``lambda node: node``.
|
||||
See documentation of ``json.loads()`` for further details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -3391,7 +3400,7 @@ class Booster:
|
|||
ctypes.c_int64(actual_len),
|
||||
ctypes.byref(tmp_out_len),
|
||||
ptr_string_buffer))
|
||||
ret = json.loads(string_buffer.value.decode('utf-8'))
|
||||
ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook)
|
||||
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
|
||||
default=json_default_with_numpy))
|
||||
return ret
|
||||
|
|
|
@ -2846,3 +2846,23 @@ def test_dump_model():
|
|||
assert "leaf_const" in dumped_model_str
|
||||
assert "leaf_value" in dumped_model_str
|
||||
assert "leaf_count" in dumped_model_str
|
||||
|
||||
|
||||
def test_dump_model_hook():
|
||||
|
||||
def hook(obj):
|
||||
if 'leaf_value' in obj:
|
||||
obj['LV'] = obj['leaf_value']
|
||||
del obj['leaf_value']
|
||||
return obj
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
train_data = lgb.Dataset(X, label=y)
|
||||
params = {
|
||||
"objective": "binary",
|
||||
"verbose": -1
|
||||
}
|
||||
bst = lgb.train(params, train_data, num_boost_round=5)
|
||||
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
|
||||
assert "leaf_value" not in dumped_model_str
|
||||
assert "LV" in dumped_model_str
|
||||
|
|
Загрузка…
Ссылка в новой задаче