From 8670013d819817e62015f9951eca683a74561d56 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 25 Apr 2023 14:34:07 -0500 Subject: [PATCH] [python-package] [ci] fix mypy errors in Booster.__inner_predict() (#5852) --- python-package/lightgbm/basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 46d9dd772..507aea265 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -3110,7 +3110,7 @@ class Booster: ctypes.byref(out_num_class))) self.__num_class = out_num_class.value # buffer for inner predict - self.__inner_predict_buffer = [None] + self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None] self.__is_predicted_cur_iter = [False] self.__get_eval_info() self.pandas_categorical = train_set.pandas_categorical @@ -4518,16 +4518,16 @@ class Booster: # avoid to predict many time in one iteration if not self.__is_predicted_cur_iter[data_idx]: tmp_out_len = ctypes.c_int64(0) - data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) + data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr] _safe_call(_LIB.LGBM_BoosterGetPredict( self.handle, ctypes.c_int(data_idx), ctypes.byref(tmp_out_len), data_ptr)) - if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): + if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type] raise ValueError(f"Wrong length of predict results for data {data_idx}") self.__is_predicted_cur_iter[data_idx] = True - result = self.__inner_predict_buffer[data_idx] + result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment] if self.__num_class > 1: num_data = result.size // self.__num_class result = result.reshape(num_data, self.__num_class, order='F')