зеркало из https://github.com/microsoft/landcover.git
fix undo bug
This commit is contained in:
Родитель
05f6b54970
Коммит
92e5a5f57b
|
@ -155,6 +155,8 @@ class TorchSmoothingCycleFineTune(BackendModel):
|
||||||
|
|
||||||
self.init_model()
|
self.init_model()
|
||||||
|
|
||||||
|
self.num_corrections_since_retrain.append([0 for _ in range(self.num_models)])
|
||||||
|
|
||||||
for model, corr_features, corr_labels in zip(self.augment_models, self.corr_features, self.corr_labels):
|
for model, corr_features, corr_labels in zip(self.augment_models, self.corr_features, self.corr_labels):
|
||||||
batch_x = T.from_numpy(np.array(corr_features)).float().to(self.device)
|
batch_x = T.from_numpy(np.array(corr_features)).float().to(self.device)
|
||||||
batch_y = T.from_numpy(np.array(corr_labels)).to(self.device)
|
batch_y = T.from_numpy(np.array(corr_labels)).to(self.device)
|
||||||
|
@ -195,9 +197,10 @@ class TorchSmoothingCycleFineTune(BackendModel):
|
||||||
num_undone = sum(self.num_corrections_since_retrain[-1])
|
num_undone = sum(self.num_corrections_since_retrain[-1])
|
||||||
message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1])))
|
message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1])))
|
||||||
for i in range(self.num_models):
|
for i in range(self.num_models):
|
||||||
self.corr_features[i] = self.corr_features[i][:-self.num_corrections_since_retrain[-1][i]]
|
self.corr_features[i] = self.corr_features[i][:len(self.corr_features[i])-self.num_corrections_since_retrain[-1][i]]
|
||||||
self.corr_labels[i] = self.corr_labels[i][:-self.num_corrections_since_retrain[-1][i]]
|
self.corr_labels[i] = self.corr_labels[i][:len(self.corr_labels[i])-self.num_corrections_since_retrain[-1][i]]
|
||||||
self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1]
|
self.num_corrections_since_retrain[-1][i] = 0
|
||||||
|
if num_undone == 0: self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1]
|
||||||
if len(self.num_corrections_since_retrain) == 0:
|
if len(self.num_corrections_since_retrain) == 0:
|
||||||
self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ]
|
self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ]
|
||||||
return True, message, num_undone
|
return True, message, num_undone
|
||||||
|
|
Загрузка…
Ссылка в новой задаче