Fix issue with categorical inputs to gcm ProductRegressor

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Patrick Bloebaum 2022-10-31 18:34:35 -07:00 коммит произвёл Patrick Blöbaum
Родитель 560b3460aa
Коммит a136ed41bf
2 изменённых файла: 16 добавлений и 3 удаления

Просмотреть файл

@@ -144,11 +144,14 @@ class InvertibleLogarithmicFunction(InvertibleFunction):
class ProductRegressor(PredictionModel):
def __init__(self):
    """Create a regressor that has no one-hot encoders registered yet."""
    # Starts out empty; fit() overwrites it with the result of
    # fit_one_hot_encoders(X) and predict() passes it to apply_one_hot_encoding.
    self._one_hot_encoders = dict()
def fit(self, X, Y):
    """Fit the one-hot encoders that predict() applies to its inputs.

    The only state stored by this method is the encoder mapping returned by
    fit_one_hot_encoders; nothing else is learned from the data.

    Args:
        X: Input samples, forwarded unchanged to fit_one_hot_encoders.
        Y: Not used by this method (presumably kept to satisfy the
            PredictionModel interface -- confirm against the base class).
    """
    self._one_hot_encoders = fit_one_hot_encoders(X)
def predict(self, X):
    """Return the row-wise product of the inputs as a single column.

    X is first passed through apply_one_hot_encoding together with the
    encoders stored on this instance; the entries of each row of the result
    are then multiplied together.

    Args:
        X: Input samples; must be accepted by apply_one_hot_encoding.

    Returns:
        The per-row products, reshaped to one column via reshape(-1, 1).
    """
    encoded_inputs = apply_one_hot_encoding(X, self._one_hot_encoders)
    row_products = np.prod(encoded_inputs, axis=1)
    return row_products.reshape(-1, 1)
def clone(self):

Просмотреть файл

@@ -4,9 +4,19 @@ from _pytest.python_api import approx
from dowhy.gcm.ml.regression import create_product_regressor
def test_when_use_product_regressor_then_computes_correct_values():
    """The product regressor multiplies the entries of every input row."""
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    mdl = create_product_regressor()

    # No fit needed: the input is purely numeric and fit() is never called here.
    # Expected values are the row products 1*2*3, 4*5*6 and 7*8*9.
    assert mdl.predict(X).reshape(-1) == approx(np.array([6, 120, 504]))
def test_when_input_is_categorical_when_use_product_regressor_then_computes_correct_values():
    """The product regressor handles a string-valued column after fitting.

    The expected result [0, 2] presumably means the two-class column is
    encoded as 0 for "Class 1" and 1 for "Class 2" before multiplying with
    the numeric column -- confirm against the one-hot encoding helpers.
    """
    class_labels = np.array(["Class 1", "Class 2"]).astype(object)
    numeric_values = np.array([1, 2])
    X = np.column_stack([class_labels, numeric_values]).astype(object)

    regressor = create_product_regressor()
    # Need to fit first so the one-hot encoder for the string column exists.
    regressor.fit(X, np.zeros(2))

    predictions = regressor.predict(X).reshape(-1)
    assert predictions == approx(np.array([0, 2]))