This commit is contained in:
Ubuntu 2020-07-22 19:53:30 +00:00 коммит произвёл Caleb Robinson
Родитель 05f6b54970
Коммит 92e5a5f57b
1 изменённых файлов: 6 добавлений и 3 удалений

Просмотреть файл

@ -155,6 +155,8 @@ class TorchSmoothingCycleFineTune(BackendModel):
self.init_model() self.init_model()
self.num_corrections_since_retrain.append([0 for _ in range(self.num_models)])
for model, corr_features, corr_labels in zip(self.augment_models, self.corr_features, self.corr_labels): for model, corr_features, corr_labels in zip(self.augment_models, self.corr_features, self.corr_labels):
batch_x = T.from_numpy(np.array(corr_features)).float().to(self.device) batch_x = T.from_numpy(np.array(corr_features)).float().to(self.device)
batch_y = T.from_numpy(np.array(corr_labels)).to(self.device) batch_y = T.from_numpy(np.array(corr_labels)).to(self.device)
@ -195,9 +197,10 @@ class TorchSmoothingCycleFineTune(BackendModel):
num_undone = sum(self.num_corrections_since_retrain[-1]) num_undone = sum(self.num_corrections_since_retrain[-1])
message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1]))) message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1])))
for i in range(self.num_models): for i in range(self.num_models):
self.corr_features[i] = self.corr_features[i][:-self.num_corrections_since_retrain[-1][i]] self.corr_features[i] = self.corr_features[i][:len(self.corr_features[i])-self.num_corrections_since_retrain[-1][i]]
self.corr_labels[i] = self.corr_labels[i][:-self.num_corrections_since_retrain[-1][i]] self.corr_labels[i] = self.corr_labels[i][:len(self.corr_labels[i])-self.num_corrections_since_retrain[-1][i]]
self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1] self.num_corrections_since_retrain[-1][i] = 0
if num_undone == 0: self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1]
if len(self.num_corrections_since_retrain) == 0: if len(self.num_corrections_since_retrain) == 0:
self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ] self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ]
return True, message, num_undone return True, message, num_undone