Caleb Robinson 2021-06-08 16:48:57 +00:00
Parent 43c01922b4
Commit 4fff2bfd73
3 changed files with 86 additions and 55 deletions

training/models/fcn.py (new file, 46 additions)

@@ -0,0 +1,46 @@
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F


class FCN(nn.Module):
    """A simple five-layer fully convolutional network.

    Every 3x3 convolution uses stride 1 with padding 1, so spatial
    dimensions are preserved from input to output.
    """

    def __init__(self, num_input_channels, num_output_classes, num_filters=64):
        super(FCN, self).__init__()
        self.conv1 = nn.Conv2d(
            num_input_channels, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv4 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv5 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        # 1x1 convolution mapping the feature maps to per-pixel class logits
        self.last = nn.Conv2d(
            num_filters, num_output_classes, kernel_size=1, stride=1, padding=0
        )

    def forward(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.last(x)
        return x

    def forward_features(self, inputs):
        # Same as forward(), but also returns the penultimate feature maps,
        # which the fine-tuning session uses as per-pixel embeddings.
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        y = self.last(x)
        return y, x
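Not part of the commit, but a quick smoke test of the new module, assuming illustrative shapes (4 input bands, 9 classes, a 256x256 tile):

    # Minimal sketch: every layer is stride 1 with "same" padding (or a 1x1
    # kernel), so the spatial size of the input is preserved end to end.
    import torch
    from training.models.fcn import FCN

    model = FCN(num_input_channels=4, num_output_classes=9, num_filters=64)
    x = torch.randn(1, 4, 256, 256)        # (batch, channels, height, width)

    logits = model(x)                      # per-pixel class scores
    _, features = model.forward_features(x)

    print(logits.shape)    # torch.Size([1, 9, 256, 256])
    print(features.shape)  # torch.Size([1, 64, 256, 256]) -- penultimate features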

View file

@@ -19,6 +19,7 @@ from sklearn.preprocessing import LabelBinarizer

 from .ModelSessionAbstract import ModelSession
 from training.models.unet_solar import UnetModel
+from training.models.fcn import FCN

 import torch
 import torch.nn as nn
@@ -27,14 +28,6 @@ import torch.nn.functional as F


 class TorchFineTuning(ModelSession):

-    # AUGMENT_MODEL = SGDClassifier(
-    #     loss="log",
-    #     shuffle=True,
-    #     n_jobs=-1,
-    #     learning_rate="constant",
-    #     eta0=0.001,
-    #     warm_start=True
-    # )
     AUGMENT_MODEL = MLPClassifier(
         hidden_layer_sizes=(),
         alpha=0.0001,
@@ -50,42 +43,41 @@ class TorchFineTuning(ModelSession):
         self.model_fn = kwargs["fn"]
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-        self.output_channels = 3 #kwargs["num_classes"]
+        self.output_channels = kwargs["num_classes"]
         self.output_features = 64
         self.input_size = kwargs["input_size"]
+        self.input_channels = kwargs["input_channels"]

         self.down_weight_padding = 10
         self.stride_x = self.input_size - self.down_weight_padding*2
         self.stride_y = self.input_size - self.down_weight_padding*2

-        model_opts = types.SimpleNamespace(
-            **kwargs
-        )
-        self.model = UnetModel(model_opts)
+        self.model = FCN(self.input_channels, num_output_classes=self.output_channels, num_filters=64)
         self._init_model()

         for param in self.model.parameters():
             param.requires_grad = False

-        self.initial_weights = self.model.seg_layer.weight.cpu().detach().numpy().squeeze()
-        self.initial_biases = self.model.seg_layer.bias.cpu().detach().numpy()
-        print(self.initial_weights.shape)
-        print(self.initial_biases.shape)
-
         self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
-        #self.augment_model.coef_ = self.initial_weights.astype(np.float64)
-        #self.augment_model.intercept_ = self.initial_biases.astype(np.float64)
-        #self.augment_model.classes_ = np.array(list(range(self.output_channels)))
-        #self.augment_model.n_features_in_ = self.output_features
-        #self.augment_model.n_features = self.output_features

         self._last_tile = None

+        with np.load(kwargs["seed_data_fn"]) as f:
+            embeddings = f["embeddings"].copy()
+            labels = f["labels"].copy()
+        idxs = np.random.choice(embeddings.shape[0], size=500)
+        self.augment_x_base = embeddings[idxs]
+        self.augment_y_base = labels[idxs]
+
         self.augment_x_train = []
         self.augment_y_train = []
+        for row in self.augment_x_base:
+            self.augment_x_train.append(row)
+        for row in self.augment_y_base:
+            self.augment_y_train.append(row)

     @property
     def last_tile(self):
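The commit reads seed embeddings and labels from kwargs["seed_data_fn"] but does not show how that file is built; a minimal sketch of the assumed .npz layout (key names from the code above, array sizes illustrative):

    # Hypothetical layout of the seed-data file:
    #   "embeddings" -- one 64-dimensional feature vector per seed sample
    #   "labels"     -- the corresponding integer class label per sample
    import numpy as np

    num_samples, num_features = 10000, 64   # illustrative sizes
    embeddings = np.random.rand(num_samples, num_features).astype(np.float32)
    labels = np.random.randint(0, 9, size=num_samples)

    np.savez("seed_data.npz", embeddings=embeddings, labels=labels)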
@@ -95,14 +87,16 @@ class TorchFineTuning(ModelSession):
         checkpoint = torch.load(self.model_fn, map_location=self.device)
         self.model.load_state_dict(checkpoint)
         self.model.eval()
-        self.model.seg_layer = nn.Conv2d(64, 3, kernel_size=1)
         self.model = self.model.to(self.device)

     def run(self, tile, inference_mode=False):
-        means = np.array([660.5929,812.9481,1080.6552,1398.3968,1662.5913,1899.4804,2061.932,2100.2792,2214.9325,2230.5973,2443.3014,1968.1885])
-        stdevs = np.array([137.4943,195.3494,241.2698,378.7495,383.0338,449.3187,511.3159,547.6335,563.8937,501.023,624.041,478.9655])
-        tile = (tile - means) / stdevs
+        if tile.shape[2] == 3:  # If we get a 3 channel image, then pretend it is 4 channel by duplicating the first band
+            tile = np.concatenate([
+                tile,
+                tile[:,:,0][:,:,np.newaxis]
+            ], axis=2)
+
+        tile = tile / 255.0
         tile = tile.astype(np.float32)

         output, output_features = self.run_model_on_tile(tile)
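A standalone illustration of the new input path on dummy data (the old code z-normalized a 12-band tile with hardcoded per-band statistics; the new code expects 4-band imagery scaled to [0, 1], with a band-duplication shim for 3-band tiles):

    import numpy as np

    tile = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)

    if tile.shape[2] == 3:
        # Fake a 4th band by repeating band 0 so the 4-channel FCN accepts it.
        tile = np.concatenate([tile, tile[:, :, 0][:, :, np.newaxis]], axis=2)

    tile = tile / 255.0              # scale to [0, 1] instead of z-normalizing
    tile = tile.astype(np.float32)
    print(tile.shape)                # (256, 256, 4)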
@@ -131,8 +125,8 @@ class TorchFineTuning(ModelSession):
         new_weights = new_weights.to(self.device)
         new_biases = new_biases.to(self.device)

-        self.model.seg_layer.weight.data = new_weights
-        self.model.seg_layer.bias.data = new_biases
+        self.model.last.weight.data = new_weights
+        self.model.last.bias.data = new_biases

         return {
             "message": "Fine-tuning accuracy on data: %0.2f" % (score),
@@ -174,23 +168,14 @@ class TorchFineTuning(ModelSession):
     def reset(self):
         self._init_model()
-        self.augment_x_train = []
-        self.augment_y_train = []
         self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
-
-        label_binarizer = LabelBinarizer()
-        label_binarizer.fit(range(self.output_channels))
-
-        self.augment_model.coefs_ = [self.initial_weights]
-        self.augment_model.intercepts_ = [self.initial_biases]
-
-        self.augment_model.classes_ = np.array(list(range(self.output_channels)))
-        self.augment_model.n_features_in_ = self.output_features
-        self.augment_model.n_outputs_ = self.output_channels
-        self.augment_model.n_layers_ = 2
-        self.augment_model.out_activation_ = 'softmax'
-        self.augment_model._label_binarizer = label_binarizer
+        self.augment_x_train = []
+        self.augment_y_train = []
+        for row in self.augment_x_base:
+            self.augment_x_train.append(row)
+        for row in self.augment_y_base:
+            self.augment_y_train.append(row)

         return {
             "message": "Model reset successfully",

View file

@@ -84,7 +84,7 @@
                 "name": "Grid"
             }
         ],
-        "validModels": ["naip_demo"]
+        "validModels": ["naip_demo", "devseed_9class"]
     },
     "naip_all": {
         "metadata": {
@@ -110,7 +110,7 @@
             }
         ],
         "shapeLayers": null,
-        "validModels": ["naip_demo"]
+        "validModels": ["naip_demo", "devseed_9class"]
     },
     "esri_all": {
         "metadata": {
@@ -137,7 +137,7 @@
             }
         ],
         "shapeLayers": null,
-        "validModels": ["naip_demo"]
+        "validModels": ["naip_demo", "devseed_9class"]
     },
     "lc_all": {
         "metadata": {