зеркало из https://github.com/microsoft/LightGBM.git
[python-package] [ci] fix mypy errors in Booster.__inner_predict() (#5852)
This commit is contained in:
Родитель
ef5acfb423
Коммит
8670013d81
|
@ -3110,7 +3110,7 @@ class Booster:
|
|||
ctypes.byref(out_num_class)))
|
||||
self.__num_class = out_num_class.value
|
||||
# buffer for inner predict
|
||||
self.__inner_predict_buffer = [None]
|
||||
self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
|
||||
self.__is_predicted_cur_iter = [False]
|
||||
self.__get_eval_info()
|
||||
self.pandas_categorical = train_set.pandas_categorical
|
||||
|
@ -4518,16 +4518,16 @@ class Booster:
|
|||
# avoid to predict many time in one iteration
|
||||
if not self.__is_predicted_cur_iter[data_idx]:
|
||||
tmp_out_len = ctypes.c_int64(0)
|
||||
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))
|
||||
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr]
|
||||
_safe_call(_LIB.LGBM_BoosterGetPredict(
|
||||
self.handle,
|
||||
ctypes.c_int(data_idx),
|
||||
ctypes.byref(tmp_out_len),
|
||||
data_ptr))
|
||||
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
|
||||
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type]
|
||||
raise ValueError(f"Wrong length of predict results for data {data_idx}")
|
||||
self.__is_predicted_cur_iter[data_idx] = True
|
||||
result = self.__inner_predict_buffer[data_idx]
|
||||
result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment]
|
||||
if self.__num_class > 1:
|
||||
num_data = result.size // self.__num_class
|
||||
result = result.reshape(num_data, self.__num_class, order='F')
|
||||
|
|
Загрузка…
Ссылка в новой задаче