This commit is contained in:
Caleb Robinson 2021-06-08 16:48:57 +00:00
Parent 43c01922b4
Commit 4fff2bfd73
3 changed files with 86 additions and 55 deletions

training/models/fcn.py (new file, 46 lines)

@@ -0,0 +1,46 @@
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F


class FCN(nn.Module):
    def __init__(self, num_input_channels, num_output_classes, num_filters=64):
        super(FCN, self).__init__()
        self.conv1 = nn.Conv2d(
            num_input_channels, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv4 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv5 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.last = nn.Conv2d(
            num_filters, num_output_classes, kernel_size=1, stride=1, padding=0
        )

    def forward(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.last(x)
        return x

    def forward_features(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        y = self.last(x)
        return y, x
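A quick smoke test for the new module (a sketch, not part of the commit; the 12-band input, 9 classes, and 256x256 tile size are illustrative assumptions):

import torch
from training.models.fcn import FCN

model = FCN(num_input_channels=12, num_output_classes=9, num_filters=64)
x = torch.randn(1, 12, 256, 256)               # batch of one 12-band tile
logits = model(x)                              # per-pixel class scores
logits2, features = model.forward_features(x)  # also returns the 64-channel feature map
assert logits.shape == (1, 9, 256, 256)
assert features.shape == (1, 64, 256, 256)     # every conv is stride 1 / padding 1, so spatial size is preserved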

View file

@@ -19,6 +19,7 @@ from sklearn.preprocessing import LabelBinarizer
from .ModelSessionAbstract import ModelSession
from training.models.unet_solar import UnetModel
+from training.models.fcn import FCN

import torch
import torch.nn as nn
@@ -27,14 +28,6 @@ import torch.nn.functional as F

class TorchFineTuning(ModelSession):
-    # AUGMENT_MODEL = SGDClassifier(
-    #     loss="log",
-    #     shuffle=True,
-    #     n_jobs=-1,
-    #     learning_rate="constant",
-    #     eta0=0.001,
-    #     warm_start=True
-    # )
    AUGMENT_MODEL = MLPClassifier(
        hidden_layer_sizes=(),
        alpha=0.0001,
@@ -50,42 +43,41 @@ class TorchFineTuning(ModelSession):
        self.model_fn = kwargs["fn"]
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-        self.output_channels = 3 #kwargs["num_classes"]
+        self.output_channels = kwargs["num_classes"]
        self.output_features = 64

        self.input_size = kwargs["input_size"]
        self.input_channels = kwargs["input_channels"]

        self.down_weight_padding = 10
        self.stride_x = self.input_size - self.down_weight_padding*2
        self.stride_y = self.input_size - self.down_weight_padding*2

-        model_opts = types.SimpleNamespace(
-            **kwargs
-        )
-        self.model = UnetModel(model_opts)
+        self.model = FCN(self.input_channels, num_output_classes=self.output_channels, num_filters=64)
        self._init_model()

        for param in self.model.parameters():
            param.requires_grad = False

        self.initial_weights = self.model.seg_layer.weight.cpu().detach().numpy().squeeze()
        self.initial_biases = self.model.seg_layer.bias.cpu().detach().numpy()
        print(self.initial_weights.shape)
        print(self.initial_biases.shape)

        self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
        #self.augment_model.coef_ = self.initial_weights.astype(np.float64)
        #self.augment_model.intercept_ = self.initial_biases.astype(np.float64)
        #self.augment_model.classes_ = np.array(list(range(self.output_channels)))
        #self.augment_model.n_features_in_ = self.output_features
        #self.augment_model.n_features = self.output_features

        self._last_tile = None

        with np.load(kwargs["seed_data_fn"]) as f:
            embeddings = f["embeddings"].copy()
            labels = f["labels"].copy()

        idxs = np.random.choice(embeddings.shape[0], size=500)
        self.augment_x_base = embeddings[idxs]
        self.augment_y_base = labels[idxs]

        self.augment_x_train = []
        self.augment_y_train = []
        for row in self.augment_x_base:
            self.augment_x_train.append(row)
        for row in self.augment_y_base:
            self.augment_y_train.append(row)
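The seed file loaded above is read as an .npz archive with embeddings and labels arrays; a sketch of producing a compatible file (the shapes are assumptions: 64 feature dimensions to match output_features, integer class labels):

import numpy as np

embeddings = np.random.randn(10000, 64).astype(np.float32)  # per-pixel feature vectors
labels = np.random.randint(0, 9, size=(10000,))             # matching class labels
np.savez("seed_data.npz", embeddings=embeddings, labels=labels)

Note that np.random.choice with the default replace=True can pick the same row more than once when subsampling the 500 seed examples.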
    @property
    def last_tile(self):
@@ -95,14 +87,16 @@ class TorchFineTuning(ModelSession):
        checkpoint = torch.load(self.model_fn, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()
-        self.model.seg_layer = nn.Conv2d(64, 3, kernel_size=1)
        self.model = self.model.to(self.device)

    def run(self, tile, inference_mode=False):
+        means = np.array([660.5929, 812.9481, 1080.6552, 1398.3968, 1662.5913, 1899.4804, 2061.932, 2100.2792, 2214.9325, 2230.5973, 2443.3014, 1968.1885])
+        stdevs = np.array([137.4943, 195.3494, 241.2698, 378.7495, 383.0338, 449.3187, 511.3159, 547.6335, 563.8937, 501.023, 624.041, 478.9655])
+        tile = (tile - means) / stdevs
        if tile.shape[2] == 3:  # If we get a 3 channel image, then pretend it is 4 channel by duplicating the first band
            tile = np.concatenate([
                tile,
                tile[:, :, 0][:, :, np.newaxis]
            ], axis=2)
        tile = tile / 255.0
        tile = tile.astype(np.float32)

        output, output_features = self.run_model_on_tile(tile)
@@ -113,14 +107,14 @@ class TorchFineTuning(ModelSession):
    def retrain(self, **kwargs):
        x_train = np.array(self.augment_x_train)
        y_train = np.array(self.augment_y_train)

        if x_train.shape[0] == 0:
            return {
                "message": "Need to add training samples in order to train",
                "success": False
            }

        try:
            self.augment_model.fit(x_train, y_train)
            score = self.augment_model.score(x_train, y_train)
@@ -131,8 +125,8 @@ class TorchFineTuning(ModelSession):
            new_weights = new_weights.to(self.device)
            new_biases = new_biases.to(self.device)

-            self.model.seg_layer.weight.data = new_weights
-            self.model.seg_layer.bias.data = new_biases
+            self.model.last.weight.data = new_weights
+            self.model.last.bias.data = new_biases

            return {
                "message": "Fine-tuning accuracy on data: %0.2f" % (score),
@@ -174,23 +168,14 @@ class TorchFineTuning(ModelSession):
    def reset(self):
        self._init_model()
-        self.augment_x_train = []
-        self.augment_y_train = []
        self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)

-        label_binarizer = LabelBinarizer()
-        label_binarizer.fit(range(self.output_channels))
-        self.augment_model.coefs_ = [self.initial_weights]
-        self.augment_model.intercepts_ = [self.initial_biases]
-        self.augment_model.classes_ = np.array(list(range(self.output_channels)))
-        self.augment_model.n_features_in_ = self.output_features
-        self.augment_model.n_outputs_ = self.output_channels
-        self.augment_model.n_layers_ = 2
-        self.augment_model.out_activation_ = 'softmax'
-        self.augment_model._label_binarizer = label_binarizer
+        self.augment_x_train = []
+        self.augment_y_train = []
+        for row in self.augment_x_base:
+            self.augment_x_train.append(row)
+        for row in self.augment_y_base:
+            self.augment_y_train.append(row)

        return {
            "message": "Model reset successfully",
@@ -231,7 +216,7 @@ class TorchFineTuning(ModelSession):
            t_batch = batch[i:i+batch_size]
            t_batch = np.rollaxis(t_batch, 3, 1)
            t_batch = torch.from_numpy(t_batch).to(self.device)

            with torch.no_grad():
                predictions, features = self.model.forward_features(t_batch)
                predictions = F.softmax(predictions, dim=1)
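np.rollaxis(t_batch, 3, 1) converts the batch from the NHWC layout the tiles are stored in to the NCHW layout PyTorch expects; a quick equivalence check (sketch, sizes illustrative):

import numpy as np

batch = np.zeros((8, 256, 256, 12), dtype=np.float32)  # NHWC
nchw = np.rollaxis(batch, 3, 1)                        # same as np.transpose(batch, (0, 3, 1, 2))
assert nchw.shape == (8, 12, 256, 256)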
@@ -247,7 +232,7 @@ class TorchFineTuning(ModelSession):
        model_output = np.concatenate(model_output, axis=0)
        model_feature_output = np.concatenate(model_feature_output, axis=0)

        for i, (y, x) in enumerate(batch_indices):
            output[y:y+self.input_size, x:x+self.input_size] += model_output[i] * kernel[..., np.newaxis]
            output_features[y:y+self.input_size, x:x+self.input_size] += model_feature_output[i] * kernel[..., np.newaxis]
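The accumulation above blends overlapping tile predictions with a spatial weight kernel; the kernel itself is defined elsewhere in the class, so the following is only a sketch of the scheme, assuming borders are down-weighted inside the down_weight_padding margin used to derive stride_x / stride_y:

import numpy as np

input_size, pad = 256, 10                 # mirrors input_size and down_weight_padding
stride = input_size - 2 * pad             # mirrors stride_x / stride_y
n_classes, height, width = 9, 1024, 1024  # illustrative output geometry

kernel = np.ones((input_size, input_size), dtype=np.float32)
kernel[:pad, :] *= 0.01                   # down-weight all four tile borders
kernel[-pad:, :] *= 0.01
kernel[:, :pad] *= 0.01
kernel[:, -pad:] *= 0.01

output = np.zeros((height, width, n_classes), dtype=np.float32)
counts = np.zeros((height, width, 1), dtype=np.float32)
for y in range(0, height - input_size + 1, stride):
    for x in range(0, width - input_size + 1, stride):
        pred = np.random.rand(input_size, input_size, n_classes)  # stand-in for model output
        output[y:y+input_size, x:x+input_size] += pred * kernel[..., np.newaxis]
        counts[y:y+input_size, x:x+input_size] += kernel[..., np.newaxis]
output /= np.maximum(counts, 1e-8)        # normalize where tiles overlapped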

View file

@@ -84,7 +84,7 @@
"name": "Grid"
}
],
"validModels": ["naip_demo"]
"validModels": ["naip_demo", "devseed_9class"]
},
"naip_all": {
"metadata": {
@@ -110,7 +110,7 @@
}
],
"shapeLayers": null,
"validModels": ["naip_demo"]
"validModels": ["naip_demo", "devseed_9class"]
},
"esri_all": {
"metadata": {
@@ -137,7 +137,7 @@
}
],
"shapeLayers": null,
"validModels": ["naip_demo"]
"validModels": ["naip_demo", "devseed_9class"]
},
"lc_all": {
"metadata": {