increasing the precision for Naive Bayes method (#360)
This commit is contained in:
Родитель
50f1f5c598
Коммит
dac08f4ff4
|
@ -20,10 +20,10 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
|
|||
self.classification = True
|
||||
self.binarize = binarize
|
||||
self.jll_calc_bias = torch.nn.Parameter(
|
||||
torch.from_numpy(jll_calc_bias.astype("float32")).view(-1), requires_grad=False
|
||||
torch.from_numpy(jll_calc_bias.astype("float64")).view(-1), requires_grad=False
|
||||
)
|
||||
self.feature_log_prob_minus_neg_prob = torch.nn.Parameter(
|
||||
torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float32")), requires_grad=False
|
||||
torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float64")), requires_grad=False
|
||||
)
|
||||
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
|
||||
self.perform_class_select = False
|
||||
|
@ -31,13 +31,14 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
|
|||
self.perform_class_select = True
|
||||
|
||||
def forward(self, x):
|
||||
x = x.double()
|
||||
if self.binarize is not None:
|
||||
x = torch.gt(x, self.binarize).float()
|
||||
x = torch.gt(x, self.binarize).double()
|
||||
|
||||
jll = torch.addmm(self.jll_calc_bias, x, self.feature_log_prob_minus_neg_prob)
|
||||
log_prob_x = torch.logsumexp(jll, dim=1)
|
||||
log_prob_x = jll - log_prob_x.view(-1, 1)
|
||||
prob_x = torch.exp(log_prob_x)
|
||||
prob_x = torch.exp(log_prob_x).float()
|
||||
|
||||
if self.perform_class_select:
|
||||
return torch.index_select(self.classes, 0, torch.argmax(jll, dim=1)), prob_x
|
||||
|
|
|
@ -117,62 +117,62 @@ class TestSklearnNBClassifier(unittest.TestCase):
|
|||
|
||||
# MultinomialNB binary
|
||||
def test_multinomialnb_classifer_bi(self):
|
||||
self._test_bernoulinb_classifer(2)
|
||||
self._test_multinomialnb_classifer(2)
|
||||
|
||||
# MultinomialNB multi-class
|
||||
def test_multinomialnb_classifer_multi(self):
|
||||
self._test_bernoulinb_classifer(3)
|
||||
self._test_multinomialnb_classifer(3)
|
||||
|
||||
# MultinomialNB multi-class w/ modified alpha
|
||||
def test_multinomialnb_classifer_multi_alpha(self):
|
||||
self._test_bernoulinb_classifer(3, alpha=0.5)
|
||||
self._test_multinomialnb_classifer(3, alpha=0.5)
|
||||
|
||||
# MultinomialNB multi-class w/ fir prior
|
||||
def test_multinomialnb_classifer_multi_fit_prior(self):
|
||||
self._test_bernoulinb_classifer(3, fit_prior=True)
|
||||
self._test_multinomialnb_classifer(3, fit_prior=True)
|
||||
|
||||
# MultinomialNB multi-class w/ class prior
|
||||
def test_multinomialnb_classifer_multi_class_prior(self):
|
||||
np.random.seed(0)
|
||||
class_prior = np.random.rand(3)
|
||||
self._test_bernoulinb_classifer(3, class_prior=class_prior)
|
||||
self._test_multinomialnb_classifer(3, class_prior=class_prior)
|
||||
|
||||
# BernoulliNB multi-class w/ labels shift
|
||||
def test_multinomialnb_classifer_multi_labels_shift(self):
|
||||
self._test_bernoulinb_classifer(3, labels_shift=3)
|
||||
self._test_multinomialnb_classifer(3, labels_shift=3)
|
||||
|
||||
# TVM Backend
|
||||
# MultinomialNB binary
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_bi_tvm(self):
|
||||
self._test_bernoulinb_classifer(2, backend="tvm")
|
||||
self._test_multinomialnb_classifer(2, backend="tvm")
|
||||
|
||||
# MultinomialNB multi-class
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_multi_tvm(self):
|
||||
self._test_bernoulinb_classifer(3, backend="tvm")
|
||||
self._test_multinomialnb_classifer(3, backend="tvm")
|
||||
|
||||
# MultinomialNB multi-class w/ modified alpha
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_multi_alpha_tvm(self):
|
||||
self._test_bernoulinb_classifer(3, alpha=0.5, backend="tvm")
|
||||
self._test_multinomialnb_classifer(3, alpha=0.5, backend="tvm")
|
||||
|
||||
# MultinomialNB multi-class w/ fir prior
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_multi_fit_prior_tvm(self):
|
||||
self._test_bernoulinb_classifer(3, fit_prior=True, backend="tvm")
|
||||
self._test_multinomialnb_classifer(3, fit_prior=True, backend="tvm")
|
||||
|
||||
# MultinomialNB multi-class w/ class prior
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_multi_class_prior_tvm(self):
|
||||
np.random.seed(0)
|
||||
class_prior = np.random.rand(3)
|
||||
self._test_bernoulinb_classifer(3, class_prior=class_prior, backend="tvm")
|
||||
self._test_multinomialnb_classifer(3, class_prior=class_prior, backend="tvm")
|
||||
|
||||
# BernoulliNB multi-class w/ labels shift
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
|
||||
def test_multinomialnb_classifer_multi_labels_shift_tvm(self):
|
||||
self._test_bernoulinb_classifer(3, labels_shift=3, backend="tvm")
|
||||
self._test_multinomialnb_classifer(3, labels_shift=3, backend="tvm")
|
||||
|
||||
# GaussianNB test function to be parameterized
|
||||
def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0, backend="torch"):
|
||||
|
|
Загрузка…
Ссылка в новой задаче