зеркало из https://github.com/microsoft/landcover.git
Added DevSeed 9 class model
This commit is contained in:
Родитель
43c01922b4
Коммит
4fff2bfd73
|
@ -0,0 +1,46 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class FCN(nn.Module):
    """Simple fully convolutional network: five 3x3 conv+ReLU layers at a
    constant spatial resolution, followed by a 1x1 classification head.

    Every convolution uses stride 1 and "same" padding, so the output
    height/width always match the input's.
    """

    def __init__(self, num_input_channels, num_output_classes, num_filters=64):
        """
        Args:
            num_input_channels: number of channels in the input tensor.
            num_output_classes: number of per-pixel class channels produced.
            num_filters: channel width of the five intermediate conv layers.
        """
        super().__init__()
        # NOTE: attribute names (conv1..conv5, last) are part of the public
        # interface -- callers load state_dicts into this module and patch
        # `last.weight` / `last.bias` directly, so do not rename them.
        self.conv1 = nn.Conv2d(
            num_input_channels, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv4 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.conv5 = nn.Conv2d(
            num_filters, num_filters, kernel_size=3, stride=1, padding=1
        )
        self.last = nn.Conv2d(
            num_filters, num_output_classes, kernel_size=1, stride=1, padding=0
        )

    def forward(self, inputs):
        """Return per-pixel class logits, shape (N, num_output_classes, H, W)."""
        # Delegate to forward_features and drop the feature map; this keeps
        # the two forward paths from drifting apart (the original duplicated
        # the whole body here).
        logits, _ = self.forward_features(inputs)
        return logits

    def forward_features(self, inputs):
        """Run the network and also expose the penultimate feature map.

        Returns:
            (logits, features): logits of shape (N, num_output_classes, H, W)
            and the last post-ReLU feature map of shape (N, num_filters, H, W)
            that feeds the 1x1 head.
        """
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        y = self.last(x)
        return y, x
|
|
@ -19,6 +19,7 @@ from sklearn.preprocessing import LabelBinarizer
|
||||||
|
|
||||||
from .ModelSessionAbstract import ModelSession
|
from .ModelSessionAbstract import ModelSession
|
||||||
from training.models.unet_solar import UnetModel
|
from training.models.unet_solar import UnetModel
|
||||||
|
from training.models.fcn import FCN
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -27,14 +28,6 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
class TorchFineTuning(ModelSession):
|
class TorchFineTuning(ModelSession):
|
||||||
|
|
||||||
# AUGMENT_MODEL = SGDClassifier(
|
|
||||||
# loss="log",
|
|
||||||
# shuffle=True,
|
|
||||||
# n_jobs=-1,
|
|
||||||
# learning_rate="constant",
|
|
||||||
# eta0=0.001,
|
|
||||||
# warm_start=True
|
|
||||||
# )
|
|
||||||
AUGMENT_MODEL = MLPClassifier(
|
AUGMENT_MODEL = MLPClassifier(
|
||||||
hidden_layer_sizes=(),
|
hidden_layer_sizes=(),
|
||||||
alpha=0.0001,
|
alpha=0.0001,
|
||||||
|
@ -50,42 +43,41 @@ class TorchFineTuning(ModelSession):
|
||||||
self.model_fn = kwargs["fn"]
|
self.model_fn = kwargs["fn"]
|
||||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
self.output_channels = 3 #kwargs["num_classes"]
|
self.output_channels = kwargs["num_classes"]
|
||||||
self.output_features = 64
|
self.output_features = 64
|
||||||
self.input_size = kwargs["input_size"]
|
self.input_size = kwargs["input_size"]
|
||||||
|
self.input_channels = kwargs["input_channels"]
|
||||||
|
|
||||||
self.down_weight_padding = 10
|
self.down_weight_padding = 10
|
||||||
|
|
||||||
self.stride_x = self.input_size - self.down_weight_padding*2
|
self.stride_x = self.input_size - self.down_weight_padding*2
|
||||||
self.stride_y = self.input_size - self.down_weight_padding*2
|
self.stride_y = self.input_size - self.down_weight_padding*2
|
||||||
|
|
||||||
model_opts = types.SimpleNamespace(
|
self.model = FCN(self.input_channels, num_output_classes=self.output_channels, num_filters=64)
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
self.model = UnetModel(model_opts)
|
|
||||||
self._init_model()
|
self._init_model()
|
||||||
|
|
||||||
for param in self.model.parameters():
|
for param in self.model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
self.initial_weights = self.model.seg_layer.weight.cpu().detach().numpy().squeeze()
|
|
||||||
self.initial_biases = self.model.seg_layer.bias.cpu().detach().numpy()
|
|
||||||
|
|
||||||
print(self.initial_weights.shape)
|
|
||||||
print(self.initial_biases.shape)
|
|
||||||
|
|
||||||
self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
|
self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
|
||||||
|
|
||||||
#self.augment_model.coef_ = self.initial_weights.astype(np.float64)
|
|
||||||
#self.augment_model.intercept_ = self.initial_biases.astype(np.float64)
|
|
||||||
#self.augment_model.classes_ = np.array(list(range(self.output_channels)))
|
|
||||||
#self.augment_model.n_features_in_ = self.output_features
|
|
||||||
#self.augment_model.n_features = self.output_features
|
|
||||||
|
|
||||||
self._last_tile = None
|
self._last_tile = None
|
||||||
|
|
||||||
|
with np.load(kwargs["seed_data_fn"]) as f:
|
||||||
|
embeddings = f["embeddings"].copy()
|
||||||
|
labels = f["labels"].copy()
|
||||||
|
|
||||||
|
idxs = np.random.choice(embeddings.shape[0], size=500)
|
||||||
|
|
||||||
|
self.augment_x_base = embeddings[idxs]
|
||||||
|
self.augment_y_base = labels[idxs]
|
||||||
|
|
||||||
self.augment_x_train = []
|
self.augment_x_train = []
|
||||||
self.augment_y_train = []
|
self.augment_y_train = []
|
||||||
|
for row in self.augment_x_base:
|
||||||
|
self.augment_x_train.append(row)
|
||||||
|
for row in self.augment_y_base:
|
||||||
|
self.augment_y_train.append(row)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_tile(self):
|
def last_tile(self):
|
||||||
|
@ -95,14 +87,16 @@ class TorchFineTuning(ModelSession):
|
||||||
checkpoint = torch.load(self.model_fn, map_location=self.device)
|
checkpoint = torch.load(self.model_fn, map_location=self.device)
|
||||||
self.model.load_state_dict(checkpoint)
|
self.model.load_state_dict(checkpoint)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.seg_layer = nn.Conv2d(64, 3, kernel_size=1)
|
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
|
|
||||||
def run(self, tile, inference_mode=False):
|
def run(self, tile, inference_mode=False):
|
||||||
means = np.array([660.5929,812.9481,1080.6552,1398.3968,1662.5913,1899.4804,2061.932,2100.2792,2214.9325,2230.5973,2443.3014,1968.1885])
|
if tile.shape[2] == 3: # If we get a 3 channel image, then pretend it is 4 channel by duplicating the first band
|
||||||
stdevs = np.array([137.4943,195.3494,241.2698,378.7495,383.0338,449.3187,511.3159,547.6335,563.8937,501.023,624.041,478.9655])
|
tile = np.concatenate([
|
||||||
tile = (tile - means) / stdevs
|
tile,
|
||||||
|
tile[:,:,0][:,:,np.newaxis]
|
||||||
|
], axis=2)
|
||||||
|
|
||||||
|
tile = tile / 255.0
|
||||||
tile = tile.astype(np.float32)
|
tile = tile.astype(np.float32)
|
||||||
|
|
||||||
output, output_features = self.run_model_on_tile(tile)
|
output, output_features = self.run_model_on_tile(tile)
|
||||||
|
@ -113,14 +107,14 @@ class TorchFineTuning(ModelSession):
|
||||||
def retrain(self, **kwargs):
|
def retrain(self, **kwargs):
|
||||||
x_train = np.array(self.augment_x_train)
|
x_train = np.array(self.augment_x_train)
|
||||||
y_train = np.array(self.augment_y_train)
|
y_train = np.array(self.augment_y_train)
|
||||||
|
|
||||||
if x_train.shape[0] == 0:
|
if x_train.shape[0] == 0:
|
||||||
return {
|
return {
|
||||||
"message": "Need to add training samples in order to train",
|
"message": "Need to add training samples in order to train",
|
||||||
"success": False
|
"success": False
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.augment_model.fit(x_train, y_train)
|
self.augment_model.fit(x_train, y_train)
|
||||||
score = self.augment_model.score(x_train, y_train)
|
score = self.augment_model.score(x_train, y_train)
|
||||||
|
@ -131,8 +125,8 @@ class TorchFineTuning(ModelSession):
|
||||||
new_weights = new_weights.to(self.device)
|
new_weights = new_weights.to(self.device)
|
||||||
new_biases = new_biases.to(self.device)
|
new_biases = new_biases.to(self.device)
|
||||||
|
|
||||||
self.model.seg_layer.weight.data = new_weights
|
self.model.last.weight.data = new_weights
|
||||||
self.model.seg_layer.bias.data = new_biases
|
self.model.last.bias.data = new_biases
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "Fine-tuning accuracy on data: %0.2f" % (score),
|
"message": "Fine-tuning accuracy on data: %0.2f" % (score),
|
||||||
|
@ -174,23 +168,14 @@ class TorchFineTuning(ModelSession):
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._init_model()
|
self._init_model()
|
||||||
self.augment_x_train = []
|
|
||||||
self.augment_y_train = []
|
|
||||||
self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
|
self.augment_model = sklearn.base.clone(TorchFineTuning.AUGMENT_MODEL)
|
||||||
|
|
||||||
label_binarizer = LabelBinarizer()
|
self.augment_x_train = []
|
||||||
label_binarizer.fit(range(self.output_channels))
|
self.augment_y_train = []
|
||||||
|
for row in self.augment_x_base:
|
||||||
self.augment_model.coefs_ = [self.initial_weights]
|
self.augment_x_train.append(row)
|
||||||
self.augment_model.intercepts_ = [self.initial_biases]
|
for row in self.augment_y_base:
|
||||||
|
self.augment_y_train.append(row)
|
||||||
self.augment_model.classes_ = np.array(list(range(self.output_channels)))
|
|
||||||
self.augment_model.n_features_in_ = self.output_features
|
|
||||||
self.augment_model.n_outputs_ = self.output_channels
|
|
||||||
self.augment_model.n_layers_ = 2
|
|
||||||
self.augment_model.out_activation_ = 'softmax'
|
|
||||||
|
|
||||||
self.augment_model._label_binarizer = label_binarizer
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "Model reset successfully",
|
"message": "Model reset successfully",
|
||||||
|
@ -231,7 +216,7 @@ class TorchFineTuning(ModelSession):
|
||||||
t_batch = batch[i:i+batch_size]
|
t_batch = batch[i:i+batch_size]
|
||||||
t_batch = np.rollaxis(t_batch, 3, 1)
|
t_batch = np.rollaxis(t_batch, 3, 1)
|
||||||
t_batch = torch.from_numpy(t_batch).to(self.device)
|
t_batch = torch.from_numpy(t_batch).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
predictions, features = self.model.forward_features(t_batch)
|
predictions, features = self.model.forward_features(t_batch)
|
||||||
predictions = F.softmax(predictions)
|
predictions = F.softmax(predictions)
|
||||||
|
@ -247,7 +232,7 @@ class TorchFineTuning(ModelSession):
|
||||||
|
|
||||||
model_output = np.concatenate(model_output, axis=0)
|
model_output = np.concatenate(model_output, axis=0)
|
||||||
model_feature_output = np.concatenate(model_feature_output, axis=0)
|
model_feature_output = np.concatenate(model_feature_output, axis=0)
|
||||||
|
|
||||||
for i, (y, x) in enumerate(batch_indices):
|
for i, (y, x) in enumerate(batch_indices):
|
||||||
output[y:y+self.input_size, x:x+self.input_size] += model_output[i] * kernel[..., np.newaxis]
|
output[y:y+self.input_size, x:x+self.input_size] += model_output[i] * kernel[..., np.newaxis]
|
||||||
output_features[y:y+self.input_size, x:x+self.input_size] += model_feature_output[i] * kernel[..., np.newaxis]
|
output_features[y:y+self.input_size, x:x+self.input_size] += model_feature_output[i] * kernel[..., np.newaxis]
|
||||||
|
|
|
@ -84,7 +84,7 @@
|
||||||
"name": "Grid"
|
"name": "Grid"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"validModels": ["naip_demo"]
|
"validModels": ["naip_demo", "devseed_9class"]
|
||||||
},
|
},
|
||||||
"naip_all": {
|
"naip_all": {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -110,7 +110,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"shapeLayers": null,
|
"shapeLayers": null,
|
||||||
"validModels": ["naip_demo"]
|
"validModels": ["naip_demo", "devseed_9class"]
|
||||||
},
|
},
|
||||||
"esri_all": {
|
"esri_all": {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -137,7 +137,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"shapeLayers": null,
|
"shapeLayers": null,
|
||||||
"validModels": ["naip_demo"]
|
"validModels": ["naip_demo", "devseed_9class"]
|
||||||
},
|
},
|
||||||
"lc_all": {
|
"lc_all": {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче