diff --git a/PyTorchClassification/data_loader.py b/PyTorchClassification/data_loader.py index 7e06635..1252776 100644 --- a/PyTorchClassification/data_loader.py +++ b/PyTorchClassification/data_loader.py @@ -346,6 +346,10 @@ class JSONDataset(data.Dataset): if (dataFormat2017): # self.tax_levels = ['id', 'name', 'supercategory'] self.tax_levels = ['id', 'name'] + if label_smoothing > 0: + assert len(self.tax_levels) == 3, "Please comment in the line above to include the taxonomy " + \ + "level 'supercategory' in order for label smoothing to work. It should look like this: " + \ + "self.tax_levels = ['id', 'name', 'supercategory']" else: self.tax_levels = ['id', 'genus', 'family', 'order', 'class', 'phylum', 'kingdom'] #8142, 4412, 1120, 273, 57, 25, 6