[python-package] consolidate pandas-to-numpy conversion code (#6156)

This commit is contained in:
James Lamb 2023-11-15 22:10:54 -06:00 коммит произвёл GitHub
Родитель e63e54ace0
Коммит 18dbd65e57
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 27 добавлений и 26 удалений

Просмотреть файл

@ -758,6 +758,23 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')
def _pandas_to_numpy(
data: pd_DataFrame,
target_dtype: "np.typing.DTypeLike"
) -> np.ndarray:
_check_for_bad_pandas_dtypes(data.dtypes)
try:
# most common case (no nullable dtypes)
return data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
return data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
return data.to_numpy(dtype=target_dtype, na_value=np.nan)
def _data_from_pandas(
data: pd_DataFrame,
feature_name: _LGBM_FeatureNameConfiguration,
@ -790,22 +807,17 @@ def _data_from_pandas(
else: # use cat cols specified by user
categorical_feature = list(categorical_feature) # type: ignore[assignment]
# get numpy representation of the data
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
# so that the target dtype considers floats
df_dtypes.append(np.float32)
target_dtype = np.result_type(*df_dtypes)
try:
# most common case (no nullable dtypes)
data = data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
data = data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
data = data.to_numpy(dtype=target_dtype, na_value=np.nan)
return data, feature_name, categorical_feature, pandas_categorical
return (
_pandas_to_numpy(data, target_dtype=target_dtype),
feature_name,
categorical_feature,
pandas_categorical
)
def _dump_pandas_categorical(
@ -2805,18 +2817,7 @@ class Dataset:
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
try:
# most common case (no nullable dtypes)
label = label.to_numpy(dtype=np.float32, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
label = label.astype(np.float32, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
elif _is_pyarrow_array(label):
label_array = label
else: