[python-package] fix mypy error about pandas categorical features (#6253)

This commit is contained in:
James Lamb 2024-01-03 10:21:21 -06:00 коммит произвёл GitHub
Родитель 2bd60c8f08
Коммит 48e3629dc6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 5 добавлений и 5 удалений

Просмотреть файл

@ -786,7 +786,7 @@ def _data_from_pandas(
feature_name: _LGBM_FeatureNameConfiguration,
categorical_feature: _LGBM_CategoricalFeatureConfiguration,
pandas_categorical: Optional[List[List]]
) -> Tuple[np.ndarray, List[str], List[str], List[List]]:
) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]:
if len(data.shape) != 2 or data.shape[0] < 1:
raise ValueError('Input data must be 2 dimensional and non empty.')
@ -800,7 +800,7 @@ def _data_from_pandas(
# determine categorical features
cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
cat_cols_not_ordered: List[str] = [col for col in cat_cols if not data[col].cat.ordered]
if pandas_categorical is None: # train dataset
pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
else:
@ -811,10 +811,10 @@ def _data_from_pandas(
data[col] = data[col].cat.set_categories(category)
if len(cat_cols): # cat_cols is list
data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
if categorical_feature == 'auto': # use cat cols from DataFrame
# use cat cols from DataFrame
if categorical_feature == 'auto':
categorical_feature = cat_cols_not_ordered
else: # use cat cols specified by user
categorical_feature = list(categorical_feature) # type: ignore[assignment]
df_dtypes = [dtype.type for dtype in data.dtypes]
# so that the target dtype considers floats