increasing the precision for Naive Bayes method (#360)

This commit is contained in:
Supun Nakandala 2020-11-09 10:36:47 -08:00 коммит произвёл GitHub
Родитель 50f1f5c598
Коммит dac08f4ff4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 17 добавлений и 16 удалений

Просмотреть файл

@ -20,10 +20,10 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
self.classification = True self.classification = True
self.binarize = binarize self.binarize = binarize
self.jll_calc_bias = torch.nn.Parameter( self.jll_calc_bias = torch.nn.Parameter(
torch.from_numpy(jll_calc_bias.astype("float32")).view(-1), requires_grad=False torch.from_numpy(jll_calc_bias.astype("float64")).view(-1), requires_grad=False
) )
self.feature_log_prob_minus_neg_prob = torch.nn.Parameter( self.feature_log_prob_minus_neg_prob = torch.nn.Parameter(
torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float32")), requires_grad=False torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float64")), requires_grad=False
) )
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False) self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
self.perform_class_select = False self.perform_class_select = False
@ -31,13 +31,14 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
self.perform_class_select = True self.perform_class_select = True
def forward(self, x): def forward(self, x):
x = x.double()
if self.binarize is not None: if self.binarize is not None:
x = torch.gt(x, self.binarize).float() x = torch.gt(x, self.binarize).double()
jll = torch.addmm(self.jll_calc_bias, x, self.feature_log_prob_minus_neg_prob) jll = torch.addmm(self.jll_calc_bias, x, self.feature_log_prob_minus_neg_prob)
log_prob_x = torch.logsumexp(jll, dim=1) log_prob_x = torch.logsumexp(jll, dim=1)
log_prob_x = jll - log_prob_x.view(-1, 1) log_prob_x = jll - log_prob_x.view(-1, 1)
prob_x = torch.exp(log_prob_x) prob_x = torch.exp(log_prob_x).float()
if self.perform_class_select: if self.perform_class_select:
return torch.index_select(self.classes, 0, torch.argmax(jll, dim=1)), prob_x return torch.index_select(self.classes, 0, torch.argmax(jll, dim=1)), prob_x

Просмотреть файл

@ -117,62 +117,62 @@ class TestSklearnNBClassifier(unittest.TestCase):
# MultinomialNB binary # MultinomialNB binary
def test_multinomialnb_classifer_bi(self): def test_multinomialnb_classifer_bi(self):
self._test_bernoulinb_classifer(2) self._test_multinomialnb_classifer(2)
# MultinomialNB multi-class # MultinomialNB multi-class
def test_multinomialnb_classifer_multi(self): def test_multinomialnb_classifer_multi(self):
self._test_bernoulinb_classifer(3) self._test_multinomialnb_classifer(3)
# MultinomialNB multi-class w/ modified alpha # MultinomialNB multi-class w/ modified alpha
def test_multinomialnb_classifer_multi_alpha(self): def test_multinomialnb_classifer_multi_alpha(self):
self._test_bernoulinb_classifer(3, alpha=0.5) self._test_multinomialnb_classifer(3, alpha=0.5)
# MultinomialNB multi-class w/ fir prior # MultinomialNB multi-class w/ fir prior
def test_multinomialnb_classifer_multi_fit_prior(self): def test_multinomialnb_classifer_multi_fit_prior(self):
self._test_bernoulinb_classifer(3, fit_prior=True) self._test_multinomialnb_classifer(3, fit_prior=True)
# MultinomialNB multi-class w/ class prior # MultinomialNB multi-class w/ class prior
def test_multinomialnb_classifer_multi_class_prior(self): def test_multinomialnb_classifer_multi_class_prior(self):
np.random.seed(0) np.random.seed(0)
class_prior = np.random.rand(3) class_prior = np.random.rand(3)
self._test_bernoulinb_classifer(3, class_prior=class_prior) self._test_multinomialnb_classifer(3, class_prior=class_prior)
# BernoulliNB multi-class w/ labels shift # BernoulliNB multi-class w/ labels shift
def test_multinomialnb_classifer_multi_labels_shift(self): def test_multinomialnb_classifer_multi_labels_shift(self):
self._test_bernoulinb_classifer(3, labels_shift=3) self._test_multinomialnb_classifer(3, labels_shift=3)
# TVM Backend # TVM Backend
# MultinomialNB binary # MultinomialNB binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_bi_tvm(self): def test_multinomialnb_classifer_bi_tvm(self):
self._test_bernoulinb_classifer(2, backend="tvm") self._test_multinomialnb_classifer(2, backend="tvm")
# MultinomialNB multi-class # MultinomialNB multi-class
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_tvm(self): def test_multinomialnb_classifer_multi_tvm(self):
self._test_bernoulinb_classifer(3, backend="tvm") self._test_multinomialnb_classifer(3, backend="tvm")
# MultinomialNB multi-class w/ modified alpha # MultinomialNB multi-class w/ modified alpha
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_alpha_tvm(self): def test_multinomialnb_classifer_multi_alpha_tvm(self):
self._test_bernoulinb_classifer(3, alpha=0.5, backend="tvm") self._test_multinomialnb_classifer(3, alpha=0.5, backend="tvm")
# MultinomialNB multi-class w/ fir prior # MultinomialNB multi-class w/ fir prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_fit_prior_tvm(self): def test_multinomialnb_classifer_multi_fit_prior_tvm(self):
self._test_bernoulinb_classifer(3, fit_prior=True, backend="tvm") self._test_multinomialnb_classifer(3, fit_prior=True, backend="tvm")
# MultinomialNB multi-class w/ class prior # MultinomialNB multi-class w/ class prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_class_prior_tvm(self): def test_multinomialnb_classifer_multi_class_prior_tvm(self):
np.random.seed(0) np.random.seed(0)
class_prior = np.random.rand(3) class_prior = np.random.rand(3)
self._test_bernoulinb_classifer(3, class_prior=class_prior, backend="tvm") self._test_multinomialnb_classifer(3, class_prior=class_prior, backend="tvm")
# BernoulliNB multi-class w/ labels shift # BernoulliNB multi-class w/ labels shift
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_labels_shift_tvm(self): def test_multinomialnb_classifer_multi_labels_shift_tvm(self):
self._test_bernoulinb_classifer(3, labels_shift=3, backend="tvm") self._test_multinomialnb_classifer(3, labels_shift=3, backend="tvm")
# GaussianNB test function to be parameterized # GaussianNB test function to be parameterized
def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0, backend="torch"): def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0, backend="torch"):