This commit is contained in:
Ubuntu 2020-07-22 19:37:57 +00:00 коммит произвёл Caleb Robinson
Родитель 9f4159f0ed
Коммит 05f6b54970
1 изменённых файлов: 12 добавлений и 5 удалений

Просмотреть файл

@ -65,6 +65,7 @@ class TorchSmoothingCycleFineTune(BackendModel):
self.corr_features = [[] for _ in range(num_models) ]
self.corr_labels = [[] for _ in range(num_models) ]
self.num_corrections_since_retrain = [ [ 0 for _ in range(num_models) ] ]
def run(self, naip_data, naip_fn, extent):
print(naip_data.shape)
@ -191,11 +192,15 @@ class TorchSmoothingCycleFineTune(BackendModel):
return success, message
def undo(self):
pass
#if len(self.corr_features)>0:
# self.corr_features = self.corr_features[:-1]
# self.corr_labels = self.corr_labels[:-1]
#print('undoing; now there are %d samples' % len(self.corr_features))
num_undone = sum(self.num_corrections_since_retrain[-1])
message = 'Removed {} labels'.format(' '.join(map(str,self.num_corrections_since_retrain[-1])))
for i in range(self.num_models):
self.corr_features[i] = self.corr_features[i][:-self.num_corrections_since_retrain[-1][i]]
self.corr_labels[i] = self.corr_labels[i][:-self.num_corrections_since_retrain[-1][i]]
self.num_corrections_since_retrain = self.num_corrections_since_retrain[:-1]
if len(self.num_corrections_since_retrain) == 0:
self.num_corrections_since_retrain = [ [ 0 for _ in range(self.num_models)] ]
return True, message, num_undone
def add_sample(self, tdst_row, bdst_row, tdst_col, bdst_col, class_idx, model_idx):
print("adding sample: class %d (incremented to %d) at (%d, %d), model %d" % (class_idx, class_idx+1 , tdst_row, tdst_col, model_idx))
@ -204,6 +209,7 @@ class TorchSmoothingCycleFineTune(BackendModel):
for j in range(tdst_col,bdst_col+1):
self.corr_labels[model_idx].append(class_idx+1)
self.corr_features[model_idx].append(self.features[0,:,i,j])
self.num_corrections_since_retrain[-1][model_idx] += 1
def init_model(self):
checkpoint = T.load(self.model_fn, map_location=self.device)
@ -217,6 +223,7 @@ class TorchSmoothingCycleFineTune(BackendModel):
def reset(self):
self.init_model()
for i in range(self.num_models): self.num_corrections_since_retrain[i] = 0
def run_core_model_on_tile(self, naip_tile):