Change default parameter of SVC model in the GCM module
Before, the Support Vector Classifier did not produce probabilities, which are required for different algorithms in the GCM module. This changes the 'probability' parameter to True. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
c88cc83d83
Коммит
bd23532342
|
@ -76,7 +76,7 @@ def create_ada_boost_classifier(**kwargs) -> SklearnClassificationModel:
|
|||
|
||||
|
||||
def create_support_vector_classifier(**kwargs) -> SklearnClassificationModel:
|
||||
return SklearnClassificationModel(SVC(**kwargs))
|
||||
return SklearnClassificationModel(SVC(**kwargs, probability=True))
|
||||
|
||||
|
||||
def create_knn_classifier(**kwargs) -> SklearnClassificationModel:
|
||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
from flaky import flaky
|
||||
|
||||
from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_polynom_logistic_regression_classifier
|
||||
from dowhy.gcm.ml.classification import create_support_vector_classifier
|
||||
|
||||
|
||||
@flaky(max_runs=3)
|
||||
|
@ -58,3 +59,9 @@ def test_given_categorical_training_data_with_many_categories_when_fit_classific
|
|||
mdl.fit(X_training, Y_training)
|
||||
|
||||
assert np.sum(mdl.predict(X_test).reshape(-1) == Y_test) > 950
|
||||
|
||||
|
||||
def test_given_svc_model_then_supports_predict_probabilities():
|
||||
mdl = create_support_vector_classifier()
|
||||
mdl.fit(np.random.normal(0, 1, 100), np.random.choice(2, 100).astype(str))
|
||||
mdl.predict_probabilities(np.random.normal(0, 1, 10))
|
||||
|
|
Загрузка…
Ссылка в новой задаче