diff --git a/web_tool/ServerModelsTorchSmoothingCycle.py b/web_tool/ServerModelsTorchSmoothingCycle.py index db161df..99561e8 100644 --- a/web_tool/ServerModelsTorchSmoothingCycle.py +++ b/web_tool/ServerModelsTorchSmoothingCycle.py @@ -155,6 +155,8 @@ class TorchSmoothingCycleFineTune(BackendModel): self.init_model() + self.num_corrections_since_retrain.append([0 for _ in range(self.num_models)]) + for model, corr_features, corr_labels in zip(self.augment_models, self.corr_features, self.corr_labels): batch_x = T.from_numpy(np.array(corr_features)).float().to(self.device) batch_y = T.from_numpy(np.array(corr_labels)).to(self.device) @@ -195,9 +197,10 @@ class TorchSmoothingCycleFineTune(BackendModel): num_undone = sum(self.num_corrections_since_retrain[-1]) message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1]))) for i in range(self.num_models): - self.corr_features[i] = self.corr_features[i][:-self.num_corrections_since_retrain[-1][i]] - self.corr_labels[i] = self.corr_labels[i][:-self.num_corrections_since_retrain[-1][i]] - self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1] + self.corr_features[i] = self.corr_features[i][:len(self.corr_features[i])-self.num_corrections_since_retrain[-1][i]] + self.corr_labels[i] = self.corr_labels[i][:len(self.corr_labels[i])-self.num_corrections_since_retrain[-1][i]] + self.num_corrections_since_retrain[-1][i] = 0 + if num_undone == 0: self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1] if len(self.num_corrections_since_retrain) == 0: self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ] return True, message, num_undone