Fix StructuredColumnTransformer and add tests for it (#3762)

The new test also shows how the output of StructuredColumnTransformer differs from that of ColumnTransformer.
This commit is contained in:
Marco Castelluccio 2023-10-27 17:18:38 +02:00 коммит произвёл GitHub
Родитель 798f215e42
Коммит efabd210fd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 48 добавлений и 1 удалений

Просмотреть файл

@ -79,7 +79,7 @@ class StructuredColumnTransformer(ColumnTransformer):
for i, (f, transformer_name) in enumerate(zip(Xs, transformer_names)):
types.append((transformer_name, result.dtype, (f.shape[1],)))
return result.todense().view(np.dtype(types))
return result.view(np.dtype(types))
class DictExtractor(BaseEstimator, TransformerMixin):

Просмотреть файл

@ -8,10 +8,14 @@ import os
import pickle
from datetime import datetime
import numpy as np
import pandas as pd
import pytest
import requests
import responses
import urllib3
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import CountVectorizer
from bugbug import utils
@ -419,3 +423,46 @@ def test_extract_private_url_empty() -> None:
body = """<p>Test content</p> """
result = utils.extract_private(body)
assert result is None
def test_StructuredColumnTransformer() -> None:
    """Check StructuredColumnTransformer against a plain ColumnTransformer.

    Both transformers are fitted on the same two-row, two-column frame.
    The plain ColumnTransformer is expected to produce a flat 2x4 count
    matrix. The structured variant is expected to produce the same counts
    grouped into one named field per transformer, and viewing that result
    as int64 must give back the flat matrix.
    """
    column_steps = [
        ("feat1_transformed", CountVectorizer(), "feat1"),
        ("feat2_transformed", CountVectorizer(), "feat2"),
    ]
    frame = pd.DataFrame(
        {
            "feat1": ["First", "Third"],
            "feat2": ["Second", "Fourth"],
        }
    )

    # Plain ColumnTransformer: the per-column counts sit side by side in
    # one unnamed 2-D array.
    flat_counts = ColumnTransformer(column_steps).fit_transform(frame)
    expected_flat = np.array([[1, 0, 0, 1], [0, 1, 1, 0]])
    np.testing.assert_array_equal(flat_counts, expected_flat)

    # Structured variant: same values, but each transformer's counts are
    # addressable as a named field holding two int64 entries.
    structured_counts = utils.StructuredColumnTransformer(
        column_steps
    ).fit_transform(frame)
    field_layout = [
        ("feat1_transformed", "<i8", (2,)),
        ("feat2_transformed", "<i8", (2,)),
    ]
    expected_structured = np.array(
        [[([1, 0], [0, 1])], [([0, 1], [1, 0])]],
        dtype=field_layout,
    )
    np.testing.assert_array_equal(structured_counts, expected_structured)

    # Reinterpreting the structured result as raw int64 must match what
    # the plain ColumnTransformer produces for the same input.
    structured_as_int64 = (
        utils.StructuredColumnTransformer(column_steps)
        .fit_transform(frame)
        .view(np.dtype("int64"))
    )
    plain_counts = ColumnTransformer(column_steps).fit_transform(frame)
    np.testing.assert_array_equal(structured_as_int64, plain_counts)