[python] add parameter object_hook to method dump_model (#4533)

* add parameter object_hook to function dump_model (python API)

* eol

* fix syntax

* lint

* better documentation

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

Co-authored-by: xavier dupré <xavier.dupre@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Xavier Dupré 2021-08-24 00:48:16 +02:00 коммит произвёл GitHub
Родитель 4db10d86dc
Коммит 11d7608f2d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 31 добавлений и 2 удалений

Просмотреть файл

@@ -3342,7 +3342,7 @@ class Booster:
ret += _dump_pandas_categorical(self.pandas_categorical)
return ret
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'):
def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split', object_hook=None):
"""Dump Booster to JSON format.
Parameters
@@ -3357,6 +3357,15 @@ class Booster:
What type of feature importance should be dumped.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
object_hook : callable or None, optional (default=None)
If not None, ``object_hook`` is a function called while parsing the json
string returned by the C API. It may be used to alter the json, to store
specific values while building the json structure. It avoids
walking through the structure again. It saves a significant amount
of time if the number of trees is huge.
Signature is ``def object_hook(node: dict) -> dict``.
None is equivalent to ``lambda node: node``.
See documentation of ``json.loads()`` for further details.
Returns
-------
@@ -3391,7 +3400,7 @@ class Booster:
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
ret = json.loads(string_buffer.value.decode('utf-8'))
ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook)
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
default=json_default_with_numpy))
return ret

Просмотреть файл

@@ -2846,3 +2846,23 @@ def test_dump_model():
assert "leaf_const" in dumped_model_str
assert "leaf_value" in dumped_model_str
assert "leaf_count" in dumped_model_str
def test_dump_model_hook():
    """Check that ``dump_model`` applies a user-supplied ``object_hook``.

    A small binary classifier is trained for 5 boosting rounds on the
    breast cancer data, then dumped with a hook that renames every
    ``leaf_value`` key to ``LV``.  The string form of the dumped model
    must contain the new key and must no longer contain the old one.
    """

    def hook(obj):
        # Called with each decoded JSON object; move the leaf value under
        # the new key and always hand the (possibly modified) dict back.
        if 'leaf_value' in obj:
            obj['LV'] = obj.pop('leaf_value')
        return obj

    features, labels = load_breast_cancer(return_X_y=True)
    dataset = lgb.Dataset(features, label=labels)
    booster = lgb.train(
        {
            "objective": "binary",
            "verbose": -1
        },
        dataset,
        num_boost_round=5
    )

    dumped = booster.dump_model(num_iteration=5, start_iteration=0, object_hook=hook)
    dumped_text = str(dumped)

    assert "leaf_value" not in dumped_text
    assert "LV" in dumped_text