зеркало из https://github.com/microsoft/LightGBM.git
[python-package] fix some warnings from mypy (#3891)
* minor mypy type errors fixed * fix some warnings from mypy * fix 3 mypy warnings * selectively ignored some mypy errors * minor mypy type errors fixed * minor mypy type errors fixed * minor mypy type errors fixed * added import * Update python-package/lightgbm/callback.py * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: James Lamb <jaylamb20@gmail.com> Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Родитель
e5eafad2ba
Коммит
eda1effb52
|
@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
|
||||||
import socket
|
import socket
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, Set
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -77,7 +77,6 @@ def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Itera
|
||||||
A free port on the machine referenced by ``worker_ip``.
|
A free port on the machine referenced by ``worker_ip``.
|
||||||
"""
|
"""
|
||||||
max_tries = 1000
|
max_tries = 1000
|
||||||
out_port = None
|
|
||||||
found_port = False
|
found_port = False
|
||||||
for i in range(max_tries):
|
for i in range(max_tries):
|
||||||
out_port = local_listen_port + i
|
out_port = local_listen_port + i
|
||||||
|
@ -117,7 +116,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
|
||||||
result : Dict[str, int]
|
result : Dict[str, int]
|
||||||
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
|
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
|
||||||
"""
|
"""
|
||||||
lightgbm_ports = set()
|
lightgbm_ports: Set[int] = set()
|
||||||
worker_ip_to_port = {}
|
worker_ip_to_port = {}
|
||||||
for worker_address in worker_addresses:
|
for worker_address in worker_addresses:
|
||||||
port = client.submit(
|
port = client.submit(
|
||||||
|
@ -306,11 +305,11 @@ def _train(
|
||||||
wait(parts)
|
wait(parts)
|
||||||
|
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if part.status == 'error':
|
if part.status == 'error': # type: ignore
|
||||||
return part # trigger error locally
|
return part # trigger error locally
|
||||||
|
|
||||||
# Find locations of all parts and map them to particular Dask workers
|
# Find locations of all parts and map them to particular Dask workers
|
||||||
key_to_part_dict = {part.key: part for part in parts}
|
key_to_part_dict = {part.key: part for part in parts} # type: ignore
|
||||||
who_has = client.who_has(parts)
|
who_has = client.who_has(parts)
|
||||||
worker_map = defaultdict(list)
|
worker_map = defaultdict(list)
|
||||||
for key, workers in who_has.items():
|
for key, workers in who_has.items():
|
||||||
|
|
|
@ -824,8 +824,8 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
_base_doc = LGBMModel.fit.__doc__
|
_base_doc = LGBMModel.fit.__doc__
|
||||||
_base_doc = (_base_doc[:_base_doc.find('group :')]
|
_base_doc = (_base_doc[:_base_doc.find('group :')] # type: ignore
|
||||||
+ _base_doc[_base_doc.find('eval_set :'):])
|
+ _base_doc[_base_doc.find('eval_set :'):]) # type: ignore
|
||||||
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')]
|
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')]
|
||||||
+ _base_doc[_base_doc.find('eval_init_score :'):])
|
+ _base_doc[_base_doc.find('eval_init_score :'):])
|
||||||
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
|
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
|
||||||
|
@ -897,8 +897,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
_base_doc = LGBMModel.fit.__doc__
|
_base_doc = LGBMModel.fit.__doc__
|
||||||
_base_doc = (_base_doc[:_base_doc.find('group :')]
|
_base_doc = (_base_doc[:_base_doc.find('group :')] # type: ignore
|
||||||
+ _base_doc[_base_doc.find('eval_set :'):])
|
+ _base_doc[_base_doc.find('eval_set :'):]) # type: ignore
|
||||||
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
|
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
|
||||||
+ _base_doc[_base_doc.find('eval_metric :'):])
|
+ _base_doc[_base_doc.find('eval_metric :'):])
|
||||||
|
|
||||||
|
@ -989,8 +989,8 @@ class LGBMRanker(LGBMModel):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
_base_doc = LGBMModel.fit.__doc__
|
_base_doc = LGBMModel.fit.__doc__
|
||||||
fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')]
|
fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')] # type: ignore
|
||||||
+ _base_doc[_base_doc.find('eval_init_score :'):])
|
+ _base_doc[_base_doc.find('eval_init_score :'):]) # type: ignore
|
||||||
_base_doc = fit.__doc__
|
_base_doc = fit.__doc__
|
||||||
_before_early_stop, _early_stop, _after_early_stop = _base_doc.partition('early_stopping_rounds :')
|
_before_early_stop, _early_stop, _after_early_stop = _base_doc.partition('early_stopping_rounds :')
|
||||||
fit.__doc__ = (_before_early_stop
|
fit.__doc__ = (_before_early_stop
|
||||||
|
|
|
@ -322,7 +322,7 @@ if __name__ == "__main__":
|
||||||
if os.path.isfile(os.path.join(CURRENT_DIR, os.path.pardir, 'VERSION.txt')):
|
if os.path.isfile(os.path.join(CURRENT_DIR, os.path.pardir, 'VERSION.txt')):
|
||||||
copy_file(os.path.join(CURRENT_DIR, os.path.pardir, 'VERSION.txt'),
|
copy_file(os.path.join(CURRENT_DIR, os.path.pardir, 'VERSION.txt'),
|
||||||
os.path.join(CURRENT_DIR, 'lightgbm', 'VERSION.txt'),
|
os.path.join(CURRENT_DIR, 'lightgbm', 'VERSION.txt'),
|
||||||
verbose=0)
|
verbose=0) # type:ignore
|
||||||
version = open(os.path.join(CURRENT_DIR, 'lightgbm', 'VERSION.txt'), encoding='utf-8').read().strip()
|
version = open(os.path.join(CURRENT_DIR, 'lightgbm', 'VERSION.txt'), encoding='utf-8').read().strip()
|
||||||
readme = open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8').read()
|
readme = open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8').read()
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче