зеркало из https://github.com/mozilla/bugbug.git
Fix StructuredColumnTransformer and add some tests for it (#3762)
The test also shows the difference between StructuredColumnTransformer and ColumnTransformer.
This commit is contained in:
Родитель
798f215e42
Коммит
efabd210fd
|
@ -79,7 +79,7 @@ class StructuredColumnTransformer(ColumnTransformer):
|
|||
for i, (f, transformer_name) in enumerate(zip(Xs, transformer_names)):
|
||||
types.append((transformer_name, result.dtype, (f.shape[1],)))
|
||||
|
||||
return result.todense().view(np.dtype(types))
|
||||
return result.view(np.dtype(types))
|
||||
|
||||
|
||||
class DictExtractor(BaseEstimator, TransformerMixin):
|
||||
|
|
|
@ -8,10 +8,14 @@ import os
|
|||
import pickle
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
import responses
|
||||
import urllib3
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
|
||||
from bugbug import utils
|
||||
|
||||
|
@ -419,3 +423,46 @@ def test_extract_private_url_empty() -> None:
|
|||
body = """<p>Test content</p> """
|
||||
result = utils.extract_private(body)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_StructuredColumnTransformer() -> None:
    """Compare StructuredColumnTransformer output with plain ColumnTransformer.

    Both transformers receive the same two CountVectorizer steps over a
    two-row frame. The plain ColumnTransformer is expected to yield one flat
    2x4 matrix, while StructuredColumnTransformer is expected to yield a
    structured array with one named field per transformer. Viewing the
    structured result as int64 must give back the flat matrix.
    """
    transformers = [
        ("feat1_transformed", CountVectorizer(), "feat1"),
        ("feat2_transformed", CountVectorizer(), "feat2"),
    ]

    # Column-wise construction of the same two-row frame.
    frame = pd.DataFrame(
        {
            "feat1": ["First", "Third"],
            "feat2": ["Second", "Fourth"],
        }
    )

    # Flat layout: the two feat1 token counts followed by the two feat2 ones.
    expected_flat = np.array([[1, 0, 0, 1], [0, 1, 1, 0]])

    # Structured layout: the same counts, grouped under one field per step.
    structured_dtype = [
        ("feat1_transformed", "<i8", (2,)),
        ("feat2_transformed", "<i8", (2,)),
    ]
    expected_structured = np.array(
        [[([1, 0], [0, 1])], [([0, 1], [1, 0])]],
        dtype=structured_dtype,
    )

    flat_output = ColumnTransformer(transformers).fit_transform(frame)
    np.testing.assert_array_equal(flat_output, expected_flat)

    structured_output = utils.StructuredColumnTransformer(
        transformers
    ).fit_transform(frame)
    np.testing.assert_array_equal(structured_output, expected_structured)

    # Reinterpreting the structured result as plain int64 must reproduce
    # what the regular ColumnTransformer emits.
    reinterpreted = (
        utils.StructuredColumnTransformer(transformers)
        .fit_transform(frame)
        .view(np.dtype("int64"))
    )
    np.testing.assert_array_equal(
        reinterpreted,
        ColumnTransformer(transformers).fit_transform(frame),
    )
|
||||
|
|
Загрузка…
Ссылка в новой задаче