[python-package] [ci] fix mypy errors in Booster.__inner_predict() (#5852)

This commit is contained in:
James Lamb 2023-04-25 14:34:07 -05:00 коммит произвёл GitHub
Родитель ef5acfb423
Коммит 8670013d81
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -3110,7 +3110,7 @@ class Booster:
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
# buffer for inner predict
self.__inner_predict_buffer = [None]
self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
@ -4518,16 +4518,16 @@ class Booster:
# avoid to predict many time in one iteration
if not self.__is_predicted_cur_iter[data_idx]:
tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr]
_safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle,
ctypes.c_int(data_idx),
ctypes.byref(tmp_out_len),
data_ptr))
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type]
raise ValueError(f"Wrong length of predict results for data {data_idx}")
self.__is_predicted_cur_iter[data_idx] = True
result = self.__inner_predict_buffer[data_idx]
result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment]
if self.__num_class > 1:
num_data = result.size // self.__num_class
result = result.reshape(num_data, self.__num_class, order='F')