Fix issue with categorical inputs to gcm ProductRegressor
Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
560b3460aa
Коммит
a136ed41bf
|
@ -144,11 +144,14 @@ class InvertibleLogarithmicFunction(InvertibleFunction):
|
|||
|
||||
|
||||
class ProductRegressor(PredictionModel):
|
||||
def __init__(self):
|
||||
self._one_hot_encoders = {}
|
||||
|
||||
def fit(self, X, Y):
|
||||
# Nothing to fit here.
|
||||
pass
|
||||
self._one_hot_encoders = fit_one_hot_encoders(X)
|
||||
|
||||
def predict(self, X):
|
||||
X = apply_one_hot_encoding(X, self._one_hot_encoders)
|
||||
return np.prod(X, axis=1).reshape(-1, 1)
|
||||
|
||||
def clone(self):
|
||||
|
|
|
@ -4,9 +4,19 @@ from _pytest.python_api import approx
|
|||
from dowhy.gcm.ml.regression import create_product_regressor
|
||||
|
||||
|
||||
def test_given_product_regressor_then_computes_correct_values():
|
||||
def test_when_use_product_regressor_then_computes_correct_values():
|
||||
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
|
||||
mdl = create_product_regressor()
|
||||
# No fit needed
|
||||
|
||||
assert mdl.predict(X).reshape(-1) == approx(np.array([6, 120, 504]))
|
||||
|
||||
|
||||
def test_when_input_is_categorical_when_use_product_regressor_then_computes_correct_values():
|
||||
X = np.column_stack([np.array(["Class 1", "Class 2"]).astype(object), np.array([1, 2])]).astype(object)
|
||||
|
||||
mdl = create_product_regressor()
|
||||
mdl.fit(X, np.zeros(2)) # Need to fit one-hot-encoder
|
||||
|
||||
assert mdl.predict(X).reshape(-1) == approx(np.array([0, 2]))
|
||||
|
|
Загрузка…
Ссылка в новой задаче