[python-package] replace .values usage with .to_numpy() (#5612)

This commit is contained in:
superlaut 2022-12-29 09:18:19 +01:00 коммит произвёл GitHub
Родитель 73531662e0
Коммит 46278af56d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 21 добавлений и 2 удалений

Просмотреть файл

@ -602,7 +602,16 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
target_dtype = np.find_common_type(df_dtypes, [])
data = data.astype(target_dtype, copy=False).values
try:
# most common case (no nullable dtypes)
data = data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
data = data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
data = data.to_numpy(dtype=target_dtype, na_value=np.nan)
else:
if feature_name == 'auto':
feature_name = None
@ -2291,7 +2300,17 @@ class Dataset:
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
label_array = np.ravel(label.values.astype(np.float32, copy=False))
try:
# most common case (no nullable dtypes)
label = label.to_numpy(dtype=np.float32, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
label = label.astype(np.float32, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
else:
label_array = _list_to_1d_numpy(label, name='label')
self.set_field('label', label_array)