зеркало из https://github.com/microsoft/LightGBM.git
[python-package] fix mypy error about pandas categorical features (#6253)
This commit is contained in:
Родитель
2bd60c8f08
Коммит
48e3629dc6
|
@ -786,7 +786,7 @@ def _data_from_pandas(
|
|||
feature_name: _LGBM_FeatureNameConfiguration,
|
||||
categorical_feature: _LGBM_CategoricalFeatureConfiguration,
|
||||
pandas_categorical: Optional[List[List]]
|
||||
) -> Tuple[np.ndarray, List[str], List[str], List[List]]:
|
||||
) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]:
|
||||
if len(data.shape) != 2 or data.shape[0] < 1:
|
||||
raise ValueError('Input data must be 2 dimensional and non empty.')
|
||||
|
||||
|
@ -800,7 +800,7 @@ def _data_from_pandas(
|
|||
|
||||
# determine categorical features
|
||||
cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
|
||||
cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
|
||||
cat_cols_not_ordered: List[str] = [col for col in cat_cols if not data[col].cat.ordered]
|
||||
if pandas_categorical is None: # train dataset
|
||||
pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
|
||||
else:
|
||||
|
@ -811,10 +811,10 @@ def _data_from_pandas(
|
|||
data[col] = data[col].cat.set_categories(category)
|
||||
if len(cat_cols): # cat_cols is list
|
||||
data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
|
||||
if categorical_feature == 'auto': # use cat cols from DataFrame
|
||||
|
||||
# use cat cols from DataFrame
|
||||
if categorical_feature == 'auto':
|
||||
categorical_feature = cat_cols_not_ordered
|
||||
else: # use cat cols specified by user
|
||||
categorical_feature = list(categorical_feature) # type: ignore[assignment]
|
||||
|
||||
df_dtypes = [dtype.type for dtype in data.dtypes]
|
||||
# so that the target dtype considers floats
|
||||
|
|
Загрузка…
Ссылка в новой задаче