diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 560a9a438..5c3a32a4c 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -786,7 +786,7 @@ def _data_from_pandas( feature_name: _LGBM_FeatureNameConfiguration, categorical_feature: _LGBM_CategoricalFeatureConfiguration, pandas_categorical: Optional[List[List]] -) -> Tuple[np.ndarray, List[str], List[str], List[List]]: +) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]: if len(data.shape) != 2 or data.shape[0] < 1: raise ValueError('Input data must be 2 dimensional and non empty.') @@ -800,7 +800,7 @@ def _data_from_pandas( # determine categorical features cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] - cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] + cat_cols_not_ordered: List[str] = [col for col in cat_cols if not data[col].cat.ordered] if pandas_categorical is None: # train dataset pandas_categorical = [list(data[col].cat.categories) for col in cat_cols] else: @@ -811,10 +811,10 @@ def _data_from_pandas( data[col] = data[col].cat.set_categories(category) if len(cat_cols): # cat_cols is list data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan}) - if categorical_feature == 'auto': # use cat cols from DataFrame + + # use cat cols from DataFrame + if categorical_feature == 'auto': categorical_feature = cat_cols_not_ordered - else: # use cat cols specified by user - categorical_feature = list(categorical_feature) # type: ignore[assignment] df_dtypes = [dtype.type for dtype in data.dtypes] # so that the target dtype considers floats