Mirror of https://github.com/microsoft/msrflute.git
Merged PR 1139: Abstract class for models
- Include an abstract class for models in core/model.py.
- Update the model classes in each experiment accordingly.
- Remove the abstract class for metrics (it is no longer necessary); new metrics only need to be declared in the dictionary returned by `inference()` and FLUTE will recognize them during the evaluation rounds.
- custom_metrics.py inside each experiment folder is not needed anymore.
- Update the docs for model implementation and metrics.
Parent: 08ac1bb4ed
Commit: 299312e461
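
As a rough illustration of the new contract (this sketch is not part of the commit; the class name, layer sizes, and the batch format with 'x'/'y' keys are invented for the example), an experiment model now subclasses `BaseModel` from core/model.py and declares any extra metric directly in the dictionary returned by `inference()`:

import torch
import torch.nn as nn
import torch.nn.functional as F

from core.model import BaseModel  # abstract class added in this PR


class TinyClassifier(BaseModel):
    '''Hypothetical experiment model following the new contract.'''

    def __init__(self, model_config):
        super().__init__()
        # model_config keys are assumptions for this sketch
        self.fc = nn.Linear(model_config['input_dim'], model_config['num_classes'])

    def loss(self, input):
        # batch format ('x'/'y' keys) is an assumption for this sketch
        features, labels = input['x'], input['y']
        return F.cross_entropy(self.fc(features), labels)

    def inference(self, input):
        features, labels = input['x'], input['y']
        output = self.fc(features)
        accuracy = (output.argmax(dim=1) == labels).float().mean().item()

        # 'output', 'acc' and 'batch_size' are the keys FLUTE expects as a minimum;
        # any additional metric is passed as a {'value', 'higher_is_better'} dict.
        return {'output': output,
                'acc': accuracy,
                'batch_size': features.shape[0],
                'top5_acc': {'value': self._top5(output, labels), 'higher_is_better': True}}

    def _top5(self, output, labels):
        # illustrative helper, not part of the FLUTE API
        top5 = output.topk(min(5, output.shape[1]), dim=1).indices
        return (top5 == labels.unsqueeze(1)).any(dim=1).float().mean().item()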
core/metrics.py:

@@ -4,42 +4,31 @@
In this file we define the wrapper class for
implementing metrics.
'''
from abc import ABC
import logging
from unittest import result

import numpy as np
import torch
import torch.nn as nn

from utils import print_rank


class Metrics(ABC):
class Metrics():

    def __init__(self):
        super().__init__()

    def compute_metrics(self,dataloader, model):
        '''This function is called by ´run_validation_generic´ function
        '''This method is called by ´run_validation_generic´ function
        inside trainer.py .

        This is just a helper function that computes loss and accuracy
        metrics that will be used for all experiments. This function will
        concatenate and return the basic_metrics dict + customized_metrics
        dict.
        This is just a helper function that computes the metrics returned
        in the inference function inside ´model.py´.
        '''

        print_rank("Computing metrics")
        output_to, metrics, inf_results = self.basic_metrics(dataloader,model)
        try:
            metrics.update(self.customized_metrics(inf_results=inf_results))
        except:
            print_rank("File custom_metrics.py not found")

        return output_to, metrics
        return self.call_inference(dataloader,model)

    def basic_metrics(self, dataloader, model):
        val_losses, val_accuracies = list(), list()
    def call_inference(self, dataloader, model):

        metrics, sum_metrics = dict(), dict()
        output_tot = {"probabilities": [], "predictions": [], "labels":[]}
        counter = 0

@@ -47,17 +36,22 @@ class Metrics(ABC):
        for _, batch in enumerate(dataloader):
            val_loss = model.loss(batch).item()
            inf_results = model.inference(batch)
            output = inf_results['output']
            val_acc = inf_results['val_acc']
            batch_size = inf_results['batch_size']
            inf_results ['loss'] = {'value': val_loss,'higher_is_better': False}
            output = inf_results.pop('output')
            batch_size = inf_results.pop('batch_size')

            for key in inf_results.keys():
                if not isinstance(inf_results[key], dict):
                    inf_results[key] = {'value':inf_results[key],'higher_is_better': True}
                sum_metrics[key] = [] if not key in sum_metrics else sum_metrics[key]

            if isinstance(output, dict):
                output_tot["probabilities"].append(output["probabilities"])
                output_tot["predictions"].append(output["predictions"])
                output_tot["labels"].append(output["labels"])

            val_losses.append(val_loss * batch_size)
            val_accuracies.append(val_acc * batch_size)
            for q in inf_results.keys():
                sum_metrics[q].append(inf_results[q]['value']* batch_size)
            counter += batch_size

        output_tot["probabilities"] = np.concatenate(output_tot["probabilities"]) if output_tot["probabilities"] else []

@@ -67,15 +61,11 @@ class Metrics(ABC):
        # Post-processing of metrics
        print_rank(f"validation complete {counter}", loglevel=logging.DEBUG)
        model.set_train()
        avg_val_loss = sum(val_losses) / counter
        avg_val_acc = sum(val_accuracies) / counter

        for k in inf_results.keys():
            metrics[k] = inf_results[k]
            metrics[k]['value'] = sum(sum_metrics[k])/counter

        print_rank(f"validation examples {counter}", loglevel=logging.DEBUG)

        # Create metrics dict
        metrics = {'loss': {'value':avg_val_loss,'higher_is_better': False},
                   'acc': {'value':avg_val_acc,'higher_is_better': True}}

        return output_tot, metrics, inf_results

    def customized_metrics(self, inf_results):
        pass
        return output_tot, metrics
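
Note how `call_inference` above treats the metrics: every scalar entry returned by `inference()` is normalized into a `{'value', 'higher_is_better'}` dictionary and then averaged over batches, weighted by `batch_size`. A minimal standalone sketch of that aggregation logic (function and variable names are assumed; this is not the FLUTE implementation itself):

def aggregate_metrics(batch_results):
    '''Weighted-average per-batch metric dicts, mimicking the normalization
    described above. `batch_results` is a list of dicts as a hypothetical
    inference() would return them: scalars or {'value', 'higher_is_better'}
    dicts, plus a 'batch_size' entry.'''
    sums, flags, total = {}, {}, 0
    for res in batch_results:
        res = dict(res)                      # do not mutate the caller's dict
        batch_size = res.pop('batch_size')
        total += batch_size
        for key, val in res.items():
            if not isinstance(val, dict):    # bare scalars default to higher-is-better
                val = {'value': val, 'higher_is_better': True}
            sums[key] = sums.get(key, 0.0) + val['value'] * batch_size
            flags[key] = val['higher_is_better']
    return {k: {'value': sums[k] / total, 'higher_is_better': flags[k]} for k in sums}


# Example: two batches of different sizes
print(aggregate_metrics([
    {'acc': 0.50, 'batch_size': 10, 'loss': {'value': 1.2, 'higher_is_better': False}},
    {'acc': 1.00, 'batch_size': 30, 'loss': {'value': 0.8, 'higher_is_better': False}},
]))
# -> acc = 0.875, loss = 0.9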
core/model.py (new file):

@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch as T
from abc import ABC, abstractmethod


class BaseModel(ABC, T.nn.Module):
    '''This is a wrapper class for PyTorch models.'''

    @abstractmethod
    def __init__(self,**kwargs):
        super(BaseModel, self).__init__()

    @abstractmethod
    def loss(self, input):
        '''Performs forward step and computes the loss

        Returns:
            torch: Computed loss.
        '''
        pass

    @abstractmethod
    def inference(self, input):
        '''Performs forward step and computes metrics

        Returns:
            dict: The metrics to be computed. The following keys are
            the minimum required by FLUTE during evaluations rounds:
                - output
                - acc
                - batch_size

            More metrics can be computed by adding the key with a
            dictionary that includes the fields ´value´ and
            ´higher_is_better´ as follows:

            {'output':output,
             'acc': accuracy,
             'batch_size': n_samples,
             'f1_score': {'value':f1,'higher_is_better': True}}
        '''
        pass

    def set_eval(self):
        '''Bring the model into evaluation mode'''
        self.eval()

    def set_train(self):
        '''Bring the model into training mode'''
        self.train()
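
Since `BaseModel` combines `ABC` with `T.nn.Module`, a subclass that does not override all of the abstract methods cannot be instantiated. A small illustrative check (a reduced stand-in for `BaseModel` is defined inline so the snippet runs outside the repository; `Incomplete` is a made-up class):

import torch as T
from abc import ABC, abstractmethod

# Stand-in for core.model.BaseModel, reduced to the relevant structure.
class BaseModel(ABC, T.nn.Module):
    @abstractmethod
    def __init__(self, **kwargs):
        super(BaseModel, self).__init__()

    @abstractmethod
    def loss(self, input): ...

    @abstractmethod
    def inference(self, input): ...


class Incomplete(BaseModel):
    def __init__(self):
        super().__init__()
    # loss() and inference() are intentionally missing


try:
    Incomplete()
except TypeError as err:
    # abc refuses to instantiate: abstract methods loss, inference not overridden
    print(err)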
trainer.py:

@@ -467,15 +467,8 @@ def run_validation_generic(model, val_dataloader):
        loglevel=logging.DEBUG
    )

    try:
        from core.globals import task
        loader = SourceFileLoader("CustomMetrics", str("./experiments/"+task+"/custom_metrics.py")).load_module()
        metrics_cl = getattr(loader,"CustomMetrics")()
        print_rank("Loading customized metrics")
    except:
        metrics_cl = Metrics()
        print_rank("Loading default metrics")

    print_rank("Loading metrics ...")
    metrics_cl = Metrics()
    return metrics_cl.compute_metrics(dataloader=val_loader, model=model)


def set_component_wise_lr(model, optimizer_config, updatable_names):
Documentation (model implementation and metrics):

@@ -43,19 +43,19 @@ If labels are needed by the task, ``user_data_label`` will be required by FLUTE

Add the model to FLUTE
--------------

FLUTE requires the model declaration framed in PyTorch, with the following functions:
FLUTE requires the model declaration framed in PyTorch, which must inherit from the `BaseModel` class defined in `core/model.py`. The following methods should be overridden:

* __init__: model definition
* loss: computes the loss used for training rounds
* inference: computes the metrics used during evaluation rounds
* set_eval: brings the model into evaluation mode
* set_train: brings the model into training mode

Please see the example provided below:

.. code:: python

    class CNN(nn.Module):
    from core.model import BaseModel

    class CNN(BaseModel):
        '''This is a PyTorch model with some extra methods'''

        def __init__(self, model_config):

@@ -79,42 +79,54 @@ Please see the example provided below:
            accuracy = torch.mean((torch.argmax(output, dim=1) == labels).float()).item()
            f1 = f1_score(labels.cpu(), torch.argmax(output, dim=1).cpu(), average='micro')

            return {'output':output, 'val_acc': accuracy, 'batch_size': n_samples, 'f1_score':f1}

        def set_eval(self):
            '''Bring the model into evaluation mode'''
            self.eval()
            # NOTE: Only the keys 'output','acc' and 'batch_size' does not require
            # extra fields as 'value' and 'higher is better'. FLUTE requires this
            # format only for customized metrics.

        def set_train(self):
            '''Bring the model into training mode'''
            self.train()
            return {'output':output, 'acc': accuracy, 'batch_size': n_samples, \
                    'f1_score': {'value':f1,'higher_is_better': True}}

The Inference function must return a dictionary with the metrics that will be computed, as follows:

.. code:: bash

    { "output": loss, "val_acc": accuracy, "batch_size": batch_size}

.. note:: FLUTE requires at least loss, accuracy and batch size for the dictionary returned by inference(). More metrics can be added just by including a new key in the same dictionary.

Once the model is ready, all mandatory files must be in a single folder inside /experiments. Please adjust your files with the following naming structure so FLUTE can find all the scripts needed.
Once the model is ready, all mandatory files must be in a single folder inside ´/experiments´. Please adjust your files with the following naming structure so FLUTE can find all the scripts needed.

.. code-block:: bash

    task_name
    |---- dataloaders
        |---- text_dataloader.py
        |---- text_dataset.py
    |---- utils
        |---- utils.py
        |---- utils.py (if needed)
    |---- model.py
    |---- config.yaml
    |---- custom_metrics.py (optional)
    |---- README.txt

.. note:: In case you need to import a module that has not been considered in FLUTE, this can be added in requirements.txt

.. note:: All files must contain only absolute imports, in order to avoid issues when running.

Implement new metrics
--------------

The metrics computed during the evaluation rounds are declared inside `inference()` in the model declaration. FLUTE requires this function to return a dictionary with at least `output`, `acc` and `batch_size`, as follows:

.. code:: bash

    { "output": loss, "acc": accuracy, "batch_size": batch_size}

In order to add a new metric, we just need to add the key inside the same dictionary with the following format:

.. code:: bash

    { "output": loss,
      "acc": accuracy,
      "batch_size": batch_size,
      "custom_metric_1": {"value": value1 ,'higher_is_better': True},
      "custom_metric_2": {"value": value2 ,'higher_is_better': False}}

Once the keys have been included in the dictionary returned by `inference()`, FLUTE will automatically recognize them during the test/val rounds.

.. note:: Only the keys `output`, `acc` and `batch_size` do not require a dictionary.

Create the configuration file
---------------------------------
custom_metrics.py (removed from the experiment folders):

@@ -1,38 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
In this file we define the new metrics to
implement in any experiment.
'''

from core.metrics import Metrics

class CustomMetrics(Metrics):
    def __init__(self):
        super().__init__()

    def customized_metrics(self, inf_results):
        '''This function is called by ´compute_metrics´ function inside
        metrics.py .

        This is just a helper function that computes and fetches customized
        metrics that will be used for any experiment. This function receives
        the loss and accuracy computed previously, so they can be used for
        computing customized metrics. It should return a dictionary where the
        keys are the name of the metric to be logged, with the following form:

        metrics = {'metric_name_1': {'value':metric_value_1,
                                     'higher_is_better': False},
                   'metric_name_2': {'value':metric_value_2,
                                     'higher_is_better': True}
                  }

        Args:
            acc_and_loss (dict): Computations from 'basic_metrics', the only keys
                inside the dict are 'acc' and 'loss'.
        '''

        customized = dict()
        customized = {'f1_score': {'value':inf_results['f1_score'],'higher_is_better': True}}

        return customized
experiments/…/model.py (CNN classifier):

@@ -1,12 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


import torch
from torch import nn
from torch.nn import functional as F
from sklearn.metrics import f1_score

from core.model import BaseModel

class Net(nn.Module):
    '''The standard PyTorch model we want to federate'''

@@ -30,7 +30,7 @@ class Net(nn.Module):
        return x


class CNN(nn.Module):
class CNN(BaseModel):
    '''This is a PyTorch model with some extra methods'''

    def __init__(self, model_config):

@@ -54,12 +54,12 @@ class CNN(nn.Module):
        accuracy = torch.mean((torch.argmax(output, dim=1) == labels).float()).item()
        f1 = f1_score(labels.cpu(), torch.argmax(output, dim=1).cpu(), average='micro')

        return {'output':output, 'val_acc': accuracy, 'batch_size': n_samples, 'f1_score':f1}

    def set_eval(self):
        '''Bring the model into evaluation mode'''
        self.eval()
        # NOTE: Only the keys 'output','acc' and 'batch_size' does not require
        # extra fields as 'value' and 'higher is better'. FLUTE requires this
        # format only for customized metrics.

    def set_train(self):
        '''Bring the model into training mode'''
        self.train()
        return {'output':output, 'acc': accuracy, 'batch_size': n_samples, \
                'f1_score': {'value':f1,'higher_is_better': True}}
experiments/…/model.py (SuperNet):

@@ -9,6 +9,8 @@ import torch
from torch import nn
from torch.nn import functional as F

from core.model import BaseModel

# ReLu alternative
class Swish(nn.Module):
    def forward(self, x):

@@ -148,7 +150,7 @@ class Net(nn.Module):
        x = F.softmax(self.fc(x), dim=-1)
        return x

class SuperNet(nn.Module):
class SuperNet(BaseModel):
    '''This is the parent of the net with some extra methods'''
    def __init__(self, model_config):
        super().__init__()

@@ -168,15 +170,9 @@ class SuperNet(nn.Module):

        accuracy = torch.mean((torch.argmax(output, dim=1) == labels).float()).item()

        return {'output':output, 'val_acc': accuracy, 'batch_size': n_samples}

    def set_eval(self):
        '''Bring the model into evaluation mode'''
        self.eval()

    def set_train(self):
        '''Bring the model into training mode'''
        self.train()
        return {'output':output, 'acc': accuracy, 'batch_size': n_samples}
experiments/…/model.py (BERT):

@@ -31,11 +31,12 @@ from transformers import (
    set_seed,
)
from utils.utils import to_device
from core.model import BaseModel

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

class BERT(T.nn.Module):
class BERT(BaseModel):
    def __init__(self, model_config, **kwargs):
        super(BERT, self).__init__()
        """

@@ -274,7 +275,7 @@ class BERT(T.nn.Module):
            description="Evaluation",
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix)
        return {'output':output['eval_loss'], 'val_acc': output['eval_acc'], 'batch_size': batch_size[0]}
        return {'output':output['eval_loss'], 'acc': output['eval_acc'], 'batch_size': batch_size[0]}
experiments/…/model.py (GRU):

@@ -3,7 +3,9 @@

import torch as T
from torch import Tensor
from typing import Dict, List, Tuple, Optional, NamedTuple
from typing import List, Tuple

from core.model import BaseModel
from utils import softmax, to_device

class GRU2(T.nn.Module):

@@ -52,7 +54,7 @@ class Embedding(T.nn.Module):
        return output


class GRU(T.nn.Module): #DLM_2_0
class GRU(BaseModel): #DLM_2_0
    def __init__(self, model_config, OOV_correct=False, dropout=0.0, topK_results=1, wantLogits=False, **kwargs):
        super(GRU, self).__init__()
        self.vocab_size = model_config['vocab_size']

@@ -128,17 +130,7 @@ class GRU(T.nn.Module): #DLM_2_0
                   'predictions': preds_topK.cpu().detach().numpy(),
                   'labels': targets.cpu().detach().numpy()}

        return {'output':output, 'val_acc': acc.item(), 'batch_size': input.shape[0]}

    def set_eval(self):
        """
        Bring the model into evaluation mode
        """
        self.eval()
        return {'output':output, 'acc': acc.item(), 'batch_size': input.shape[0]}


    def set_train(self):
        """
        Bring the model into train mode
        """
        self.train()