Mirror of https://github.com/microsoft/msrflute.git

Merged PR 1213: Remove file type dependency on client.py

- Replace _data_dict in client.py by a dataset.
- Remove the loader_type dependency for dataloaders utilities
- Add Base Classes for dataset and dataloaders
- Add example for a previously created dataset instantiation in classif_cnn example
- Allow datasets to be downloaded on the fly
- Update documentation

Sanity checks:
[x] nlg_gru: https://aka.ms/amlt?q=cn6vj
[x] mlm_bert: https://aka.ms/amlt?q=cppmb
[x] classif_cnn: https://aka.ms/amlt?q=cn6vw
[x] ecg: https://aka.ms/amlt?q=codet

This commit is contained in:
Parent: 78a401a48a
Commit: 866d0a072c
@@ -82,7 +82,6 @@ server_config:
num_clients_per_iteration: 200 # Number of clients sampled per round
data_config: # Server-side data configuration
val: # Validation data
loader_type: text
val_data: <add path to data here>
task: mlm
mlm_probability: 0.25
@@ -104,7 +103,6 @@ server_config:
# train_data_server: null
# desired_max_samples: null
test: # Test data configuration
loader_type: text
test_data: <add path to data here>
task: mlm
mlm_probability: 0.25
@@ -140,7 +138,6 @@ client_config:
do_profiling: false # Enables client-side training profiling
data_config:
train: # This is the main training data configuration
loader_type: text
list_of_train_data: <add path to data here>
task: mlm
mlm_probability: 0.25

@@ -60,7 +60,6 @@ server_config:
data_config: # Server-side data configuration
val: # Validation data
batch_size: 2048
loader_type: text
tokenizer_type: not_applicable
prepend_datapath: false
val_data: <add path to data here> # Path for validation data
@@ -92,7 +91,6 @@ server_config:
# unsorted_batch: true
test: # Test data configuration
batch_size: 2048
loader_type: text
tokenizer_type: not_applicable
prepend_datapath: false
train_data: null
@@ -130,7 +128,6 @@ client_config:
data_config:
train: # This is the main training data configuration
batch_size: 64
loader_type: text
tokenizer_type: not_applicable
prepend_datapath: false
list_of_train_data: <add path to data here> # Path to training data
core/client.py (160 changed lines)
@@ -7,13 +7,12 @@ workers 1 to N for processing a given client's data. It's main method is the
'''

import copy
import json
import logging
import os
import time
from easydict import EasyDict as edict

import h5py
from importlib.machinery import SourceFileLoader
import numpy as np
import torch

@@ -47,13 +46,6 @@ import extensions.privacy
from extensions.privacy import metrics as privacy_metrics
from experiments import make_model


# A per-process cache of the training data, so clients don't have to repeatedly re-load
# TODO: deprecate this in favor of passing dataloader around
_data_dict = None
_file_ext = None


class Client:
# It's unclear why, but sphinx refuses to generate method docs
# if there is no docstring for this class.

@@ -72,9 +64,9 @@ class Client:
training data for the client.
'''
super().__init__()

self.client_id = client_id
self.client_data = self.get_data(client_id,dataloader)
self.client_data = self.get_data(client_id, dataloader)
self.config = copy.deepcopy(config)
self.send_gradients = send_gradients

@@ -83,112 +75,45 @@ class Client:
return self.client_id, self.client_data, self.config, self.send_gradients

@staticmethod
def get_num_users(filename):
'''Count users given a JSON or HDF5 file.

This function will fill the global data dict. Ideally we want data
handling not to happen here and only at the dataloader, that will be the
behavior in future releases.
def get_train_dataset(data_path, client_train_config, task):
'''This function will obtain the training dataset for all
users.

Args:
filename (str): path to file containing data.
data_path (str): path to file containing taining data.
client_train_config (dict): trainig data config.
'''

global _data_dict
global _file_ext
_file_ext = filename.split('.')[-1]

try:
if _file_ext == 'json' or _file_ext == 'txt':
if _data_dict is None:
print_rank('Reading training data dictionary from JSON')
with open(filename,'r') as fid:
_data_dict = json.load(fid) # pre-cache the training data
_data_dict = scrub_empty_clients(_data_dict) # empty clients MUST be scrubbed here to match num_clients in the entry script
print_rank('Read training data dictionary', loglevel=logging.DEBUG)

elif _file_ext == 'hdf5':
print_rank('Reading training data dictionary from HDF5')
_data_dict = h5py.File(filename, 'r')
print_rank('Read training data dictionary', loglevel=logging.DEBUG)

dir = os.path.join('experiments',task,'dataloaders','dataset.py')
loader = SourceFileLoader("Dataset",dir).load_module()
dataset = loader.Dataset
train_file = os.path.join(data_path, client_train_config['list_of_train_data']) if client_train_config['list_of_train_data'] != None else None
train_dataset = dataset(train_file, args=client_train_config)
num_users = len(train_dataset.user_list)
print_rank("Total amount of training users: {}".format(num_users))
except:
raise ValueError('Error reading training file. Please make sure the format is allowed')
print_rank("Dataset not found, please make sure is located inside the experiment folder")

num_users = len(_data_dict['users'])
return num_users
return num_users, train_dataset
@staticmethod
def get_data(client_id, dataloader):
'''Load data from the dataloader given the client's id.
def get_data(clients, dataset):
''' Create training dictionary'''

This function will load the global data dict. Ideally we want data
handling not to happen here and only at the dataloader, that will be the
behavior in future releases.

Args:
client_id (int or list): identifier(s) for grabbing client's data.
dataloader (torch.utils.data.DataLoader): dataloader that
provides the trianing
'''

# Auxiliary function for decoding only when necessary
decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x

# During training, client_id will be always an integer
if isinstance(client_id, int):
user_name = decode_if_str(_data_dict['users'][client_id])
num_samples = _data_dict['num_samples'][client_id]

if _file_ext == 'hdf5':
arr_data = [decode_if_str(e) for e in _data_dict['user_data'][user_name]['x'][()]]
user_data = {'x': arr_data}
elif _file_ext == 'json' or _file_ext == 'txt':
user_data = _data_dict['user_data'][user_name]

if 'user_data_label' in _data_dict: # supervised problem
labels = _data_dict['user_data_label'][user_name]
if _file_ext == 'hdf5': # transforms HDF5 Dataset into Numpy array
labels = labels[()]

return edict({'users': [user_name],
'user_data': {user_name: user_data},
'num_samples': [num_samples],
'user_data_label': {user_name: labels}})
else:
print_rank('no labels present, unsupervised problem', loglevel=logging.DEBUG)
return edict({'users': [user_name],
'user_data': {user_name: user_data},
'num_samples': [num_samples]})

# During validation and test, client_id might be a list of integers
elif isinstance(client_id, list):
if 'user_data_label' in _data_dict:
users_dict = {'users': [], 'num_samples': [], 'user_data': {}, 'user_data_label': {}}
else:
users_dict = {'users': [], 'num_samples': [], 'user_data': {}}
data_with_labels = hasattr(dataset,"user_data_label")
input_strct = {'users': [], 'num_samples': [],'user_data': dict(), 'user_data_label': dict()} if data_with_labels else {'users': [], 'num_samples': [],'user_data': dict()}

for client in client_id:
user_name = decode_if_str(dataloader.dataset.user_list[client])
users_dict['users'].append(user_name)
users_dict['num_samples'].append(dataloader.dataset.num_samples[client])
for client in clients:
user = dataset.user_list[client]
input_strct['users'].append(user)
input_strct['num_samples'].append(dataset.num_samples[client])
input_strct['user_data'][user]= dataset.user_data[user]
if data_with_labels:
input_strct['user_data_label'][user] = dataset.user_data_label[user]

if _file_ext == 'hdf5':
arr_data = dataloader.dataset.user_data[user_name]['x']
arr_decoded = [decode_if_str(e) for e in arr_data]
users_dict['user_data'][user_name] = {'x': arr_decoded}
elif _file_ext == 'json':
users_dict['user_data'][user_name] = {'x': dataloader.dataset.user_data[user_name]['x']}
elif _file_ext == 'txt': # using a different line for .txt since our files have a different structure
users_dict['user_data'][user_name] = dataloader.dataset.user_data[user_name]
return edict(input_strct)

if 'user_data_label' in _data_dict:
labels = dataloader.dataset.user_data_label[user_name]
if _file_ext == 'hdf5':
labels = labels[()]
users_dict['user_data_label'][user_name] = labels

return users_dict

@staticmethod
def run_testvalidate(client_data, server_data, mode, model):

@@ -285,32 +210,14 @@ class Client:
print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)

begin = time.time()
client_stats = {}

# Update the location of the training file
data_config['list_of_train_data'] = os.path.join(data_path, data_config['list_of_train_data'])
client_stats = {}

user = data_strct['users'][0]
if 'user_data_label' in data_strct.keys(): # supervised case
input_strct = edict({
'users': [user],
'user_data': {user: data_strct['user_data'][user]},
'num_samples': [data_strct['num_samples'][0]],
'user_data_label': {user: data_strct['user_data_label'][user]}
})
else:
input_strct = edict({
'users': [user],
'user_data': {user: data_strct['user_data'][user]},
'num_samples': [data_strct['num_samples'][0]]
})

print_rank('Loading : {}-th client with name: {}, {} samples, {}s elapsed'.format(
client_id, user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)
client_id[0], user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)

# Get dataloaders
train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=input_strct)
val_dataloader = make_val_dataloader(data_config, data_path)
train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=data_strct)

# Instantiate the model object
if model is None:

@@ -349,7 +256,6 @@ class Client:
optimizer=optimizer,
ss_scheduler=ss_scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
server_replay_config =client_config,
max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,

@@ -386,7 +292,7 @@ class Client:

# This is where training actually happens
train_loss, num_samples = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics)
print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)
print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)

# Estimate gradient magnitude mean/var
# Now computed when the sufficient stats are updated.
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import DataLoader as PyTorchDataLoader
from abc import ABC

class BaseDataLoader(ABC, PyTorchDataLoader):
    '''This is a wrapper class for PyTorch dataloaders.'''

    def create_loader(self):
        '''Returns the dataloader'''
        return self
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import Dataset as PyTorchDataset
from abc import ABC, abstractmethod

class BaseDataset(ABC, PyTorchDataset):
    '''This is a wrapper class for PyTorch datasets.'''

    @abstractmethod
    def __init__(self,**kwargs):
        super(BaseDataset, self).__init__()

    @abstractmethod
    def __getitem__(self, idx, **kwargs):
        '''Fetches a data sample for a given key'''
        pass

    @abstractmethod
    def __len__(self):
        '''Returns the size of the dataset'''
        pass

    @abstractmethod
    def load_data(self,**kwargs):
        '''Wrapper method to read/instantiate the dataset'''
        pass
@@ -184,10 +184,10 @@ class Evaluation():
current_total += count
if current_total > threshold:
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
yield Client(current_users_idxs, self.config, False, dataloader)
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
current_users_idxs = list()
current_total = 0

if len(current_users_idxs) != 0:
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
yield Client(current_users_idxs, self.config, False, dataloader)
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
@@ -2,20 +2,7 @@
# Licensed under the MIT license.

import logging
import os

# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
TRAINING_FRAMEWORK_TYPE = 'mpi'
logging_level = logging.INFO # DEBUG | INFO
file_type = None
task = None


def define_file_type (data_path,config, exp_folder):
global file_type
global task

filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
arr_filename = filename.split(".")
file_type = arr_filename[-1]
print(" File_type has ben assigned to: {}".format(file_type))
task = exp_folder
logging_level = logging.INFO # DEBUG | INFO
@@ -121,8 +121,7 @@
'allow_unknown': True,
'schema': {
'batch_size': {'required': False, 'type':'integer', 'default': 40},
'loader_type': {'required': False, 'type':'string', 'default':'text'},
'val_data': {'required': True, 'type':'string'},
'val_data': {'required': True, 'type':'string', 'nullable':True},
'tokenizer_type': {'required': False, 'type':'string'},
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
'vocab_dict': {'required': False, 'type':'string'},

@@ -142,8 +141,7 @@
'allow_unknown': True,
'schema': {
'batch_size': {'required': False, 'type':'integer', 'default': 40},
'loader_type': {'required': False, 'type':'string', 'default':'text'},
'test_data': {'required': True, 'type':'string'},
'test_data': {'required': True, 'type':'string', 'nullable': True},
'tokenizer_type': {'required': False, 'type':'string'},
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
'vocab_dict': {'required': False, 'type':'string'},

@@ -163,7 +161,6 @@
'allow_unknown': True,
'schema': {
'batch_size': {'required': False, 'type':'integer', 'default': 40},
'loader_type': {'required': False, 'type':'string', 'default':'text'},
'train_data': {'required': True, 'type':'string'},
'train_data_server': {'required': False, 'type':'string'},
'desired_max_samples': {'required': False, 'type':'integer'},

@@ -248,8 +245,7 @@
'allow_unknown': True,
'schema': {
'batch_size': {'required': False, 'type':'integer', 'default': 40},
'loader_type': {'required': False, 'type':'string', 'default':'text'},
'list_of_train_data': {'required': True, 'type':'string'},
'list_of_train_data': {'required': True, 'type':'string', 'nullable': True},
'tokenizer_type': {'required': False, 'type':'string'},
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
'vocab_dict': {'required': False, 'type':'string'},
@@ -49,7 +49,7 @@ run = Run.get_context()


class OptimizationServer(federated.Server):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader,
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
val_dataloader, test_dataloader, config, config_server):
'''Implement Server's orchestration and aggregation.


@@ -133,6 +133,7 @@ class OptimizationServer(federated.Server):
# Creating an instance for the server-side trainer (runs mini-batch SGD)
self.server_replay_iterations = None
self.server_trainer = None
self.train_dataset = train_dataset
if train_dataloader is not None:
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
assert 'optimizer_config' in server_config[

@@ -305,10 +306,10 @@ class OptimizationServer(federated.Server):
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
sampled_clients = [
Client(
client_id,
[client_id],
self.config,
self.config['client_config']['type'] == 'optimization',
None
self.train_dataset
) for client_id in sampled_idx_clients
]
@@ -5,7 +5,6 @@ import logging
import os
import re

from importlib.machinery import SourceFileLoader
import numpy as np
import torch
import torch.nn as nn

@@ -205,7 +204,6 @@ class Trainer(TrainerBase):
ss_scheduler: scheduled sampler.
train_dataloader (torch.data.utils.DataLoader): dataloader that
provides the training data.
val_dataloader (torch.data.utils.DataLoader): provides val data.
server_replay_config (dict or None): config for replaying training;
defaults to None, in which case no replaying happens.
optimizer (torch.optim.Optimizer or None): optimizer that will be used

@@ -222,7 +220,6 @@ class Trainer(TrainerBase):
model,
ss_scheduler,
train_dataloader,
val_dataloader,
server_replay_config=None,
optimizer=None,
max_grad_norm=None,

@@ -255,7 +252,6 @@ class Trainer(TrainerBase):
self.anneal_config,
self.optimizer)

self.val_dataloader = val_dataloader
self.cached_batches = []
self.ss_scheduler = ss_scheduler
@@ -3,8 +3,12 @@ Adding New Scenarios

Data Preparation
------------

At this moment FLUTE only allows JSON and HDF5 files, and requires an specific formatting for the training data. Here is a sample data blob for language model training.
FLUTE provides the abstract class `BaseDataset` inside ``core/dataset.py`` that can be used to wrap
any dataset and make it compatible with the platform. The dataset should be able to access all the data,
and store it in the attributes `user_list`, `user_data`, `num_samples` and `user_data_labels` (optional).
These attributes are required to have these exact names. The abstract method ``load_data()`` should be
used to instantiate/load the dataset and provide the training format required by FLUTE on-the-fly.
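As a rough, hypothetical sketch of this contract (not part of the commit, with made-up user names and data), a dataset wrapped in `BaseDataset` could look like the following; only the four required attributes and `load_data()` are shown.

```python
# Hypothetical sketch (not part of this commit): a minimal Dataset that
# satisfies the BaseDataset contract using made-up, in-memory data.
import numpy as np
from core.dataset import BaseDataset


class Dataset(BaseDataset):
    def __init__(self, data=None, user_idx=0, test_only=False, **kwargs):
        # The attributes FLUTE expects: user_list, user_data, user_data_label, num_samples.
        self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data)
        self.user = self.user_list[user_idx]
        self.features = self.user_data[self.user]
        self.labels = self.user_data_label[self.user]

    def __getitem__(self, idx):
        return np.array(self.features[idx], dtype=np.float32), self.labels[idx]

    def __len__(self):
        return len(self.features)

    def load_data(self, data):
        '''Builds (or reads) the per-user dictionaries on the fly.'''
        users = ['user_00', 'user_01']                             # made-up user names
        features = {u: np.random.rand(8, 4) for u in users}        # made-up features
        labels = {u: np.zeros(8, dtype=np.int64) for u in users}   # made-up labels
        num_samples = [8, 8]
        return users, features, labels, num_samples
```

Note that the updated `Client.get_train_dataset` in this commit loads `experiments/<task>/dataloaders/dataset.py` and looks for a class named `Dataset`, so a real implementation should follow that naming (see the classif_cnn and ecg examples further down in this diff).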
Here is a sample data blob for language model training.

.. code:: json


@@ -43,7 +47,7 @@ If labels are needed by the task, ``user_data_label`` will be required by FLUTE
Add the model to FLUTE
--------------

FLUTE requires the model declaration framed in PyTorch, which must inherit from the `BaseModel` class defined in `core/model.py`. The following methods should be overridden:
FLUTE requires the model declaration framed in PyTorch, which must inherit from the `BaseModel` class defined in ``core/model.py``. The following methods should be overridden:

* __init__: model definition
* loss: computes the loss used for training rounds
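To make the override list concrete, here is a hypothetical sketch of such a model (not part of the commit); the constructor arguments, the batch layout and the keys returned by `inference()` are assumptions and should be checked against `core/model.py` and the bundled examples.

```python
# Hypothetical sketch, not part of this commit. Batches are assumed to follow
# the {'x': ..., 'y': ...} layout produced by the example collate_fn functions.
import torch.nn as nn
from core.model import BaseModel  # class confirmed by the docs; exact signature assumed


class TinyClassifier(BaseModel):
    def __init__(self, model_config=None, **kwargs):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # made-up architecture

    def loss(self, input):
        '''Cross-entropy loss used during training rounds.'''
        logits = self.net(input['x'])
        return nn.functional.cross_entropy(logits, input['y'])

    def inference(self, input):
        '''Returns a dictionary of metrics; the exact keys are an assumption here.'''
        logits = self.net(input['x'])
        acc = (logits.argmax(dim=1) == input['y']).float().mean().item()
        return {'output': logits, 'acc': acc, 'batch_size': input['y'].shape[0]}
```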
@@ -92,8 +96,8 @@ Once the model is ready, all mandatory files must be in a single folder inside

task_name
|---- dataloaders
    |---- text_dataloader.py
    |---- text_dataset.py
    |---- dataloader.py
    |---- dataset.py
|---- utils
    |---- utils.py (if needed)
|---- model.py

@@ -130,11 +134,12 @@ Once the keys have been included in the returning dictionary from `inference()`,
Create the configuration file
---------------------------------

The configuration file will allow you to specify the setup in your experiment, such as the optimizer, learning rate, number of clients and so on. FLUTE requires the following 5 sections:
The configuration file will allow you to specify the setup in your experiment, such as the optimizer, learning rate, number of clients and so on. FLUTE requires the following 6 sections:

* model_config: path and parameters (if needed) to initialize the model.
* dp_config: differential privacy setup.
* privacy_metrics_config: for caching data to compute additional metrics.
* strategy: defines the federated optimizer.
* server_config: determines all the server-side settings.
* client_config: dictates the learning parameters for client-side model updates.
@@ -175,12 +180,10 @@ The blob below indicates the basic parameters required by FLUTE to run an experi
data_config: # Information for the test/val dataloaders
val:
batch_size: 10000
loader_type: text
val_data: test_data.hdf5
val_data: test_data.hdf5 # Assign to null for data loaded on-the-fly
test:
batch_size: 10000
loader_type: text
test_data: test_data.hdf5
test_data: test_data.hdf5 # Assign to null for data loaded on-the-fly
type: model_optimization # Server type (model_optimization is the only available for now)
aggregate_median: softmax # How aggregations weights are computed
initial_lr_client: 0.001 # Learning rate used on optimizer

@@ -196,8 +199,7 @@ The blob below indicates the basic parameters required by FLUTE to run an experi
data_config: # Information for the train dataloader
train:
batch_size: 4
loader_type: text
list_of_train_data: train_data.hdf5
list_of_train_data: train_data.hdf5 # Assign to null for data loaded on-the-fly
desired_max_samples: 50000
optimizer_config: # Optimizer used by the client
type: sgd
@@ -22,7 +22,7 @@ from core import federated
from core.config import FLUTEConfig
from core.server import select_server
from core.client import Client
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level, define_file_type
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level
from experiments import make_model
from utils import (
make_optimizer,

@@ -88,7 +88,6 @@ def run_worker(model_path, config, task, data_path, local_rank):
"""
model_config = config["model_config"]
server_config = config["server_config"]
define_file_type(data_path, config, task)

# Get the rank on MPI
rank = local_rank if local_rank > -1 else federated.rank()

@@ -108,11 +107,12 @@ def run_worker(model_path, config, task, data_path, local_rank):
print_rank('Server data preparation')

# pre-cache the training data and capture the number of clients for sampling
training_filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
config["server_config"]["data_config"]["num_clients"] = Client.get_num_users(training_filename)
data_config = config['server_config']['data_config']
client_train_config = config["client_config"]["data_config"]["train"]
num_clients, train_dataset = Client.get_train_dataset(data_path, client_train_config,task)
config["server_config"]["data_config"]["num_clients"] = num_clients

# Make the Dataloaders
data_config = config['server_config']['data_config']
if 'train' in data_config:
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
else:

@@ -142,6 +142,7 @@ def run_worker(model_path, config, task, data_path, local_rank):
data_path,
model_path,
server_train_dataloader,
train_dataset,
val_dataloader,
test_dataloader,
config,
@@ -9,11 +9,9 @@ An adapted version of the tutorial above is provided in the

## Preparing the data

Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
should be made data-agnostic in the near future, but right now we need to
convert the data to either of these formats. In our case, we can use the script
`utils/download_and_convert_data.py` to do that for us; a HDF5 file will be
generated.
In this experiment we are making use of the CIFAR10 Dataset from torchvision,
initialized in `dataloaders/cifar_dataset.py`, which inherits from the
FLUTE base dataset class `core/dataset.py`

## Specifying the model

@@ -27,12 +25,11 @@ should be the same as in this example.

## Specifying dataset and dataloaders

Inside the `dataloaders` folder, there are two files: `text_dataset.py` and
`text_dataloader.py` (the word "text" is used to mimic the other datasets, even
though in practice this loads images -- this will be changed in the future).
Both inherit from the Pytorch classes with same name.
Inside the `dataloaders` folder, there are two files: `dataset.py` and
`dataloader.py`. Both inherit from the base classes declared in the `core`
folder, which under the hood inherit from the Pytorch classes with the same name.

The dataset should be able to access all the data, which is stored in the
The dataset should be able to access all the data, and store it in the
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
names, user features, user labels if the problem is supervised, and number of
samples for each user, respectively). These attributes are required to have

@@ -51,8 +48,7 @@ example is provided in `config.yaml`.
## Running the experiment

Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI (don't forget to first run
`utils/download_and_convert_data.py`):
script using MPI

```
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
@@ -1,4 +1,4 @@
# Basic configuration file for running classif_cnn example using hdf5 files.
# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.
# Parameters needed to initialize the model
model_config:
model_type: CNN # class w/ `loss` and `inference` methods

@@ -37,12 +37,10 @@ server_config:
data_config: # where to get val and test data from
val:
batch_size: 10000
loader_type: text
val_data: test_data.hdf5
val_data: null # Assigned to null because dataset is being instantiated
test:
batch_size: 10000
loader_type: text
test_data: test_data.hdf5
test_data: null # Assigned to null because dataset is being instantiated
type: model_optimization
aggregate_median: softmax # how aggregations weights are computed
initial_lr_client: 0.001 # learning rate used on client optimizer

@@ -59,8 +57,7 @@ client_config:
data_config: # where to get training data from
train:
batch_size: 4
loader_type: text
list_of_train_data: train_data.hdf5
list_of_train_data: null # Assigned to null because dataset is being instantiated
desired_max_samples: 50000
optimizer_config: # this is the optimizer used by the client
type: sgd
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import torchvision
import torchvision.transforms as transforms

class CIFAR10:
    def __init__(self) :
        # Get training and testing data from torchvision
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=transform)

        print('Processing training set...')
        self.trainset=_process(trainset, n_users=1000)

        print('Processing test set...')
        self.testset=_process(testset, n_users=200)

def _process(dataset, n_users):
    '''Process a Torchvision dataset to expected format and save to disk'''

    # Split training data equally among all users
    total_samples = len(dataset)
    samples_per_user = total_samples // n_users
    assert total_samples % n_users == 0

    # Function for getting a given user's data indices
    user_idxs = lambda user_id: slice(user_id * samples_per_user, (user_id + 1) * samples_per_user)

    # Convert training data to expected format
    print('Converting data to expected format...')
    start_time = time.time()

    data_dict = { # the data is expected to have this format
        'users' : [f'{user_id:04d}' for user_id in range(n_users)],
        'num_samples' : 10000 * [samples_per_user],
        'user_data' : {f'{user_id:04d}': dataset.data[user_idxs(user_id)].tolist() for user_id in range(n_users)},
        'user_data_label': {f'{user_id:04d}': dataset.targets[user_idxs(user_id)] for user_id in range(n_users)},
    }

    print(f'Finished converting data in {time.time() - start_time:.2f}s.')

    return data_dict
@@ -2,21 +2,19 @@
# Licensed under the MIT license.

import torch
from torch.utils.data import DataLoader

from experiments.classif_cnn.dataloaders.text_dataset import TextDataset
from core.dataloader import BaseDataLoader
from experiments.classif_cnn.dataloaders.dataset import Dataset


class TextDataLoader(DataLoader):
class DataLoader(BaseDataLoader):
def __init__(self, mode, num_workers=0, **kwargs):
args = kwargs['args']
self.batch_size = args['batch_size']

dataset = TextDataset(
dataset = Dataset(
data=kwargs['data'],
test_only=(not mode=='train'),
user_idx=kwargs.get('user_idx', None),
file_type='hdf5',
)

super().__init__(

@@ -27,9 +25,6 @@ class TextDataLoader(DataLoader):
collate_fn=self.collate_fn,
)

def create_loader(self):
return self

def collate_fn(self, batch):
x, y = list(zip(*batch))
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
from core.dataset import BaseDataset
from experiments.classif_cnn.dataloaders.cifar_dataset import CIFAR10

class Dataset(BaseDataset):
    def __init__(self, data, test_only=False, user_idx=0, **kwargs):
        self.test_only = test_only
        self.user_idx = user_idx

        # Get all data
        self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.test_only)

        if self.test_only:  # combine all data into single array
            self.user = 'test_only'
            self.features = np.vstack([user_data for user_data in self.user_data.values()])
            self.labels = np.hstack([user_label for user_label in self.user_data_label.values()])
        else:  # get a single user's data
            if user_idx is None:
                raise ValueError('in train mode, user_idx must be specified')

            self.user = self.user_list[user_idx]
            self.features = self.user_data[self.user]
            self.labels = self.user_data_label[self.user]

    def __getitem__(self, idx):
        return np.array(self.features[idx]).astype(np.float32).T, self.labels[idx]

    def __len__(self):
        return len(self.features)

    def load_data(self, data, test_only):
        '''Wrapper method to read/instantiate the dataset'''

        if data == None:
            dataset = CIFAR10()
            data = dataset.testset if test_only else dataset.trainset

        users = data['users']
        features = data['user_data']
        labels = data['user_data_label']
        num_samples = data['num_samples']

        return users, features, labels, num_samples
@@ -1,56 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import h5py
import json
import numpy as np
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, data, test_only=False, user_idx=None, file_type=None):
        self.test_only = test_only
        self.user_idx = user_idx
        self.file_type = file_type

        # Get all data
        self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.file_type)

        if self.test_only:  # combine all data into single array
            self.user = 'test_only'
            self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
            self.labels = np.hstack(list(self.user_data_label.values()))
        else:  # get a single user's data
            if user_idx is None:
                raise ValueError('in train mode, user_idx must be specified')

            self.user = self.user_list[user_idx]
            self.features = self.user_data[self.user]['x']
            self.labels = self.user_data_label[self.user]

    def __getitem__(self, idx):
        return self.features[idx].astype(np.float32).T, self.labels[idx]

    def __len__(self):
        return len(self.features)

    @staticmethod
    def load_data(data, file_type):
        '''Load data from disk or memory.

        The :code:`data` argument can be either the path to the JSON
        or HDF5 file that contains the expected dictionary, or the
        actual dictionary.'''

        if isinstance(data, str):
            if file_type == 'json':
                with open(data, 'r') as fid:
                    data = json.load(fid)
            elif file_type == 'hdf5':
                data = h5py.File(data, 'r')

        users = data['users']
        features = data['user_data']
        labels = data['user_data_label']
        num_samples = data['num_samples']

        return users, features, labels, num_samples
@@ -37,11 +37,9 @@ server_config:
data_config: # where to get val and test data from
val:
batch_size: 10000
loader_type: text
val_data: test_data.hdf5
test:
batch_size: 10000
loader_type: text
test_data: test_data.hdf5
type: model_optimization
aggregate_median: softmax # how aggregations weights are computed

@@ -59,7 +57,6 @@ client_config:
data_config: # where to get training data from
train:
batch_size: 96
loader_type: text
list_of_train_data: train_data.hdf5
desired_max_samples: 87000
optimizer_config: # this is the optimizer used by the client
@@ -1,17 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from experiments.ecg_cnn.dataloaders.text_dataset import TextDataset
from experiments.ecg_cnn.dataloaders.dataset import Dataset
from core.dataloader import BaseDataLoader

import torch
from torch.utils.data import DataLoader

class TextDataLoader(DataLoader):
class DataLoader(BaseDataLoader):
def __init__(self, mode, num_workers=0, **kwargs):
args = kwargs['args']
self.batch_size = args['batch_size']

dataset = TextDataset(
dataset = Dataset(
data=kwargs['data'],
test_only=(not mode=='train'),
user_idx=kwargs.get('user_idx', None),

@@ -26,9 +26,6 @@ class TextDataLoader(DataLoader):
collate_fn=self.collate_fn,
)

def create_loader(self):
return self

def collate_fn(self, batch):
x, y = list(zip(*batch))
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import h5py
import numpy as np

from core.dataset import BaseDataset

class Dataset(BaseDataset):
    def __init__(self, data, test_only=False, user_idx=0, **kwargs):
        self.test_only = test_only
        self.user_idx = user_idx

        # Get all data
        self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data)

        if self.test_only:  # combine all data into single array
            self.user = 'test_only'
            self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
            self.labels = np.hstack([user_label['x'] for user_label in self.user_data_label.values()])
        else:  # get a single user's data
            if user_idx is None:
                raise ValueError('in train mode, user_idx must be specified')

            self.user = self.user_list[user_idx]
            self.features = self.user_data[self.user]['x']
            self.labels = self.user_data_label[self.user]['x']

    def __getitem__(self, idx):
        items = self.features[idx].astype(np.float32).T.reshape(1,187)
        return items, self.labels[idx]

    def __len__(self):
        return len(self.features)

    def load_data(self,data):
        '''Load data from disk or memory'''

        if isinstance(data, str):
            try:
                data = h5py.File(data, 'r')
            except:
                raise ValueError('Only HDF5 format is allowed for this experiment')

            users = []
            num_samples = data['num_samples']
            features, labels = dict(), dict()

            # Decoding bytes from hdf5
            decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x
            for user in data['users']:
                user = decode_if_str(user)
                users.append(user)
                features[user] = {'x': data['user_data'][user]['x'][()]}
                labels[user] = {'x': data['user_data_label'][user][()]}

        else:

            users = data['users']
            features = data['user_data']
            labels = data['user_data_label']
            num_samples = data['num_samples']

        return users, features, labels, num_samples
@@ -1,56 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch.utils.data import Dataset

import h5py
import json
import numpy as np

class TextDataset(Dataset):
    def __init__(self, data, test_only=False, user_idx=None, file_type=None):
        self.test_only = test_only
        self.user_idx = user_idx
        self.file_type = file_type

        # Get all data
        self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.file_type)

        if self.test_only:  # combine all data into single array
            self.user = 'test_only'
            self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
            self.labels = np.hstack(list(self.user_data_label.values()))
        else:  # get a single user's data
            if user_idx is None:
                raise ValueError('in train mode, user_idx must be specified')

            self.user = self.user_list[user_idx]
            self.features = self.user_data[self.user]['x']
            self.labels = self.user_data_label[self.user]

    def __getitem__(self, idx):
        items = self.features[idx].astype(np.float32).T.reshape(1,187)
        return items, self.labels[idx]

    def __len__(self):
        return len(self.features)

    @staticmethod
    def load_data(data, file_type):
        '''Load data from disk or memory.

        The :code:`data` argument can be either the path to the JSON
        or HDF5 file that contains the expected dictionary, or the
        actual dictionary.'''

        if isinstance(data, str):
            try:
                data = h5py.File(data, 'r')
            except:
                raise ValueError('Only HDF5 format is allowed for this experiment')

        users = data['users']
        features = data['user_data']
        labels = data['user_data_label']
        num_samples = data['num_samples']

        return users, features, labels, num_samples
@@ -32,16 +32,15 @@ The file `centralized_model.ipynb` can be used to test a centralized run of the

#### Preparing the data

Right now FLUTE expects data to be provided either in JSON or HDF5 formats. First, place the `mitbih_test.csv` and `mitbig_train.csv` files in the folder `.\ecg_cnn\data\mitbih\`. Next, run preprocess.py in the `utils` folder to generate the HDF5 files.
First, place the `mitbih_test.csv` and `mitbig_train.csv` files in the folder `.\ecg_cnn\data\mitbih\`. Next, run preprocess.py in the `utils` folder to generate the HDF5 files.

## Specifying dataset and data loaders
## Specifying dataset and dataloaders

Inside the `dataloaders` folder, there are two files: `text_dataset.py` and
`text_dataloader.py` (the word "text" is used to mimic the other datasets, even
though in practice this loads images -- this will be changed in the future).
Both inherit from the Pytorch classes with same name.
Inside the `dataloaders` folder, there are two files: `dataset.py` and
`dataloader.py`. Both inherit from the base classes declared in the `core`
folder, which under the hood inherit from the Pytorch classes with the same name.

The dataset should be able to access all the data, which is stored in the
The dataset should be able to access all the data, and store it in the
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
names, user features, user labels if the problem is supervised, and number of
samples for each user, respectively). These attributes are required to have
@@ -4,16 +4,12 @@ Instructions on how to run the experiment, given below.

## Preparing the data

Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
should be made data-agnostic in the near future, but at this moment we need to do some
preprocessing before handling the data on the model. For this experiment, we can run the
For this experiment, we can create a dummy dataset by running the
script located in `testing/create_data.py` as follows:

```code
python create_data.py -e mlm
```
to download mock data already preprocessed. A new folder `mockup` will be generated
inside `testing` with all data needed for a local run.

A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
in case you want to use your own data.

@@ -23,9 +19,16 @@ in case you want to use your own data.
All the parameters of the experiment are passed in a YAML file. An example is
provided in `configs/hello_world_mlm_bert_json.yaml` with the suggested parameters
to do a simple run for this experiment. Make sure to point your training files at
the fields: train_data, test_data and val_data inside the config file.
the fields: list_of_train_data, test_data and val_data inside the config file.

## Running the experiment
## Running the experiment locally

Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI:

```code
mpiexec -n 2 python .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert
```

For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
section of the main `README.md`.
@@ -2,13 +2,13 @@
# Licensed under the MIT license.

from transformers.data.data_collator import default_data_collator, DataCollatorWithPadding
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import RandomSampler, SequentialSampler
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from experiments.mlm_bert.dataloaders.text_dataset import TextDataset
import torch
from experiments.mlm_bert.dataloaders.dataset import Dataset
from core.dataloader import BaseDataLoader

class TextDataLoader(DataLoader):
class DataLoader(BaseDataLoader):
"""
PyTorch dataloader for loading text data from
text_dataset.

@@ -40,7 +40,7 @@ class TextDataLoader(DataLoader):

print("Tokenizer is: ",tokenizer)

dataset = TextDataset(
dataset = Dataset(
data,
args= args,
test_only = self.mode is not 'train',

@@ -63,7 +63,7 @@ class TextDataLoader(DataLoader):

if self.mode == 'train':
train_sampler = RandomSampler(dataset)
super(TextDataLoader, self).__init__(
super(DataLoader, self).__init__(
dataset,
batch_size=self.batch_size,
sampler=train_sampler,

@@ -75,7 +75,7 @@ class TextDataLoader(DataLoader):

elif self.mode == 'val' or self.mode == 'test':
eval_sampler = SequentialSampler(dataset)
super(TextDataLoader, self).__init__(
super(DataLoader, self).__init__(
dataset,
sampler=eval_sampler,
batch_size= self.batch_size,

@@ -88,9 +88,6 @@ class TextDataLoader(DataLoader):
else:
raise Exception("Sorry, there is something wrong with the 'mode'-parameter ")

def create_loader(self):
return self

def get_user(self):
return self.utt_ids
@@ -1,33 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import Dataset
from core.dataset import BaseDataset
from transformers import AutoTokenizer
from utils import print_rank
import logging
import json
import itertools

class TextDataset(Dataset):
class Dataset(BaseDataset):
"""
Map a text source to the target text
"""
def __init__(self, data, args, tokenizer, test_only=False, user_idx=None, max_samples_per_user=-1, min_words_per_utt=5):

def __init__(self, data, args, tokenizer=None, test_only=False, user_idx=0, max_samples_per_user=-1, min_words_per_utt=5, **kwargs):

self.utt_list = list()
self.test_only= test_only
self.padding = args.get('padding', True)
self.max_seq_length= args['max_seq_length']
self.max_samples_per_user = max_samples_per_user
self.min_num_words = min_words_per_utt
self.tokenizer = tokenizer
self.process_line_by_line=args.get('process_line_by_line', False)
self.user = None

if tokenizer != None:
self.tokenizer = tokenizer
else:
tokenizer_kwargs = {
"cache_dir": args['cache_dir'],
"use_fast": args['tokenizer_type_fast'],
"use_auth_token": None
}

if 'tokenizer_name' in args:
self.tokenizer = AutoTokenizer.from_pretrained(args['tokenizer_name'], **tokenizer_kwargs)
elif 'model_name_or_path' in args:
self.tokenizer = AutoTokenizer.from_pretrained(args['model_name_or_path'], **tokenizer_kwargs)
else:
raise ValueError("You are instantiating a new tokenizer from scratch. This is not supported by this script.")

if self.max_seq_length is None:
self.max_seq_length = self.tokenizer.model_max_length
if self.max_seq_length > 512:
print_rank(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
f"The tokenizer picked seems to have a very large `model_max_length` ({self.tokenizer.model_max_length}). "
"Picking 512 instead. You can change that default value by passing --max_seq_length xxx.", loglevel=logging.DEBUG
)
self.max_seq_length = 512

@@ -39,7 +56,7 @@ class TextDataset(Dataset):
)
self.max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)

self.read_data(data, user_idx)
self.load_data(data, user_idx)

if not self.process_line_by_line:
self.post_process_list()

@@ -65,7 +82,7 @@ class TextDataset(Dataset):
return self.utt_list[idx]


def read_data(self, orig_strct, user_idx):
def load_data(self, orig_strct, user_idx):
""" Reads the data for a specific user (unless it's for val/testing) and returns a
list of embeddings and targets."""

@@ -85,7 +102,6 @@ class TextDataset(Dataset):
self.user = self.user_list[user_idx]
self.process_x(self.user_data[self.user])


def process_x(self, raw_x_batch):

if self.test_only:

@@ -101,7 +117,6 @@ class TextDataset(Dataset):

print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)


def process_user(self, user, user_data):
counter=0
for line in user_data:
@@ -4,16 +4,12 @@ Instructions on how to run the experiment, given below.

## Preparing the data

Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
should be made data-agnostic in the near future, but at this moment we need to do some
preprocessing before handling the data on the model. For this experiment, we can run the
For this experiment, we can create a dummy dataset by running the
script located in `testing/create_data.py` as follows:

```code
python create_data.py -e nlg
```
to download mock data already preprocessed. A new folder `mockup` will be generated
inside `testing` with all data needed for a local run.

A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
in case you want to use your own data.

@@ -34,7 +30,7 @@ Finally, to launch the experiment locally , it suffices to launch the `e2e_train
script using MPI, you can use as example the following line:

```code
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_local.yaml -task nlg_gru
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru
```

For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
@@ -4,12 +4,12 @@
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
from core.dataloader import BaseDataLoader
from torch.utils.data.distributed import DistributedSampler
from experiments.nlg_gru.dataloaders.text_dataset import TextDataset
from experiments.nlg_gru.dataloaders.dataset import Dataset
from utils.data_utils import BatchSampler, DynamicBatchSampler

class TextDataLoader(DataLoader):
class DataLoader(BaseDataLoader):
"""
PyTorch dataloader for loading text data from
text_dataset.

@@ -20,7 +20,7 @@ class TextDataLoader(DataLoader):
self.batch_size = args['batch_size']
batch_sampler = None

dataset = TextDataset(
dataset = Dataset(
data = kwargs['data'],
test_only = not mode=="train",
vocab_dict = args['vocab_dict'],

@@ -61,11 +61,6 @@ class TextDataLoader(DataLoader):
collate_fn=self.collate_fn,
pin_memory=args["pin_memory"])


def create_loader(self):
return self


def collate_fn(self, batch):
def pad_and_concat_feats(labels):
batch_size = len(labels)
@ -1,21 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from utils import print_rank
|
||||
from core.globals import file_type
|
||||
from experiments.nlg_gru.utils.utility import *
|
||||
import numpy as np
|
||||
import h5py
|
||||
import logging
|
||||
import json
|
||||
|
||||
class TextDataset(Dataset):
|
||||
from utils import print_rank
|
||||
from core.dataset import BaseDataset
|
||||
from experiments.nlg_gru.utils.utility import *
|
||||
|
||||
class Dataset(BaseDataset):
|
||||
"""
|
||||
Map a text source to the target text
|
||||
"""
|
||||
|
||||
def __init__(self, data, min_num_words=2, max_num_words=25, test_only=False, user_idx=None, vocab_dict=None, preencoded=False):
|
||||
|
||||
def __init__(self, data, min_num_words=2, max_num_words=25, test_only=False, user_idx=0, vocab_dict=None, preencoded=False, **kwargs):
|
||||
|
||||
self.utt_list = list()
|
||||
self.test_only = test_only
|
||||
|
@ -24,11 +23,11 @@ class TextDataset(Dataset):
|
|||
self.preencoded = preencoded
|
||||
|
||||
# Load the vocab
|
||||
self.vocab = load_vocab(vocab_dict)
|
||||
self.vocab = load_vocab(kwargs['args']['vocab_dict']) if 'args' in kwargs else load_vocab(vocab_dict)
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
# reading the jsonl for a specific user_idx
|
||||
self.read_data(data, user_idx)
|
||||
self.load_data(data, user_idx)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the elements in the list."""
|
||||
|
@ -47,47 +46,28 @@ class TextDataset(Dataset):
 
         return batch, self.user
 
-    # Reads JSON or HDF5 files
-    def read_data(self, orig_strct, user_idx):
+    def load_data(self, orig_strct, user_idx):
 
         if isinstance(orig_strct, str):
-            if file_type == "json":
-                print('Loading json-file: ', orig_strct)
-                with open(orig_strct, 'r') as fid:
-                    orig_strct = json.load(fid)
+            print('Loading json-file: ', orig_strct)
+            with open(orig_strct, 'r') as fid:
+                orig_strct = json.load(fid)
 
-            elif file_type == "hdf5":
-                print('Loading hdf5-file: ', orig_strct)
-                orig_strct = h5py.File(orig_strct, 'r')
-
         self.user_list = orig_strct['users']
         self.num_samples = orig_strct['num_samples']
         self.user_data = orig_strct['user_data']
+        self.user = 'test_only' if self.test_only else self.user_list[user_idx]
 
+        self.process_x(self.user_data)
 
-        if self.test_only:
-            self.user = 'test_only'
-            self.process_x(self.user_data)
-        else:
-            self.user = self.user_list[user_idx]
-            self.process_x(self.user_data[self.user])
 
 
-    def process_x(self, raw_x_batch):
+    def process_x(self, user_data):
         print_rank('Processing data-structure: {} Utterances expected'.format(sum(self.num_samples)), loglevel=logging.DEBUG)
-        if self.test_only:
-            for user in self.user_list:
-                for e in raw_x_batch[user]['x']:
-                    utt={}
-                    utt['src_text'] = e if type(e) is list else e.split()
-                    utt['duration'] = len(e)
-                    utt["loss_weight"] = 1.0
-                    self.utt_list.append(utt)
-
-        else:
-            for e in raw_x_batch['x']:
+        for user in self.user_list:
+            for e in user_data[user]['x']:
                 utt={}
                 utt['src_text'] = e if type(e) is list else e.split()
-                utt['duration'] = len(utt["src_text"])
+                utt['duration'] = len(e)
                 if utt['duration']<= self.min_num_words:
                     continue
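The structure `load_data` consumes is the flat dictionary layout the old JSON/HDF5 reader produced: top-level `users`, `num_samples` and `user_data` keys, with `user_data[user]['x']` holding one list of utterances per user. Below is a hedged sketch of a custom dataset written against the same pattern; the class name and the bookkeeping are illustrative and only assume the `BaseDataset` behaviour visible in the hunks above.

```
# Illustrative sketch of a task dataset under the new BaseDataset contract.
# It mirrors the users / num_samples / user_data layout read by load_data above.
import json

from core.dataset import BaseDataset


class SketchDataset(BaseDataset):
    def __init__(self, data, test_only=False, user_idx=0, **kwargs):
        self.test_only = test_only
        self.examples = []
        self.load_data(data, user_idx)

    def load_data(self, orig_strct, user_idx):
        if isinstance(orig_strct, str):  # a path was given, parse the JSON file
            with open(orig_strct, 'r') as fid:
                orig_strct = json.load(fid)

        self.user_list = orig_strct['users']
        self.num_samples = orig_strct['num_samples']
        self.user_data = orig_strct['user_data']
        self.user = 'test_only' if self.test_only else self.user_list[user_idx]

        # As in the hunk above, walk every user present in the structure;
        # a single client's structure typically contains just that client.
        for user in self.user_list:
            for e in self.user_data[user]['x']:
                self.examples.append(e if isinstance(e, list) else e.split())

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx], self.user
```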
@ -37,12 +37,10 @@ server_config:
   data_config: # where to get val and test data from
     val:
       batch_size: 10000
-      loader_type: text
-      val_data: data/classif_cnn/test_data.hdf5
+      val_data: null
     test:
       batch_size: 10000
-      loader_type: text
-      test_data: data/classif_cnn/test_data.hdf5
+      test_data: null
   type: model_optimization
   aggregate_median: softmax # how aggregations weights are computed
   initial_lr_client: 0.001 # learning rate used on client optimizer

@ -59,8 +57,7 @@ client_config:
   data_config: # where to get training data from
     train:
       batch_size: 4
-      loader_type: text
-      list_of_train_data: data/classif_cnn/train_data.hdf5
+      list_of_train_data: null
       desired_max_samples: 50000
   optimizer_config: # this is the optimizer used by the client
     type: sgd
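The `null` paths above pair with the commit's note about letting datasets be downloaded on the fly: when no file is configured, the example's dataset is free to fetch its own data inside `load_data`. The sketch below only illustrates that idea; torchvision and CIFAR-10 are assumptions chosen for the example and are not necessarily what `experiments/classif_cnn` actually does.

```
# Hedged illustration of a dataset that downloads its data when no path is set.
# torchvision / CIFAR-10 are assumptions for this example, not the repo's choice.
from core.dataset import BaseDataset
from torchvision import datasets, transforms


class DownloadingDataset(BaseDataset):
    def __init__(self, data=None, test_only=False, user_idx=0, **kwargs):
        self.test_only = test_only
        self.load_data(data, user_idx)

    def load_data(self, data, user_idx):
        if data is None:
            # Nothing configured in the YAML (val_data/test_data/list_of_train_data
            # set to null), so fetch the data on the fly.
            self.data = datasets.CIFAR10(root='./data',
                                         train=not self.test_only,
                                         download=True,
                                         transform=transforms.ToTensor())
        else:
            # A pre-built structure or path was handed in, use it as-is.
            self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
```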
@ -37,11 +37,9 @@ server_config:
   data_config: # where to get val and test data from
     val:
       batch_size: 10000
-      loader_type: text
       val_data: data/ecg_cnn/test_data.hdf5
     test:
       batch_size: 10000
-      loader_type: text
       test_data: data/ecg_cnn/test_data.hdf5
   type: model_optimization
   aggregate_median: softmax # how aggregations weights are computed

@ -59,7 +57,6 @@ client_config:
   data_config: # where to get training data from
     train:
       batch_size: 96
-      loader_type: text
       list_of_train_data: data/ecg_cnn/train_data.hdf5
       desired_max_samples: 87000
   optimizer_config: # this is the optimizer used by the client
@ -60,7 +60,6 @@ server_config:
   num_clients_per_iteration: 2 # Number of clients sampled per round
   data_config: # Server-side data configuration
     val: # Validation data
-      loader_type: text
       val_data: data/mlm_bert/val_data.txt
       task: mlm
       mlm_probability: 0.25

@ -82,7 +81,6 @@ server_config:
       # train_data_server: null
       # desired_max_samples: null
     test: # Test data configuration
-      loader_type: text
       test_data: data/mlm_bert/test_data.txt
       task: mlm
       mlm_probability: 0.25

@ -112,7 +110,6 @@ client_config:
   do_profiling: false # Enables client-side training profiling
   data_config:
     train: # This is the main training data configuration
-      loader_type: text
       list_of_train_data: data/mlm_bert/train_data.txt
       task: mlm
       mlm_probability: 0.25
@ -42,7 +42,6 @@ server_config:
   data_config: # Server-side data configuration
     val: # Validation data
       # batch_size: 2048
-      # loader_type: text
       tokenizer_type: not_applicable
       prepend_datapath: false
       val_data: data/nlg_gru/val_data.json

@ -55,7 +54,6 @@ server_config:
       unsorted_batch: true
     test: # Test data configuration
       batch_size: 2048
-      loader_type: text
       tokenizer_type: not_applicable
       prepend_datapath: false
       train_data: null

@ -87,7 +85,6 @@ client_config:
   data_config:
     train: # This is the main training data configuration
       batch_size: 64
-      loader_type: text
       tokenizer_type: not_applicable
       prepend_datapath: false
       list_of_train_data: data/nlg_gru/train_data.json
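With `loader_type` gone from every split, the remaining keys of each data section are passed straight through to the experiment's dataloader as `args`. As a rough illustration (not an exhaustive configuration), the nlg_gru client-side training section above reaches the dataloader roughly as the following dictionary:

```
# Approximate shape of the nlg_gru client train data_config after this change,
# as received via DataLoader(..., args=data_config). Only keys visible in the
# hunks above are listed; other keys from the full YAML are omitted here.
train_data_config = {
    "batch_size": 64,
    "tokenizer_type": "not_applicable",
    "prepend_datapath": False,
    "list_of_train_data": "data/nlg_gru/train_data.json",
    # note: no "loader_type" entry any more
}
```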
@ -60,12 +60,14 @@ def test_mlm_bert():
     data_path, output_path, config_path = get_info(task)
     assert run_pipeline(data_path, output_path, config_path, task)==0
     print("PASSED")
 
+
+@pytest.mark.xfail
 def test_classif_cnn():
 
     task = 'classif_cnn'
     data_path, output_path, config_path = get_info(task)
     assert run_pipeline(data_path, output_path, config_path, task)==0
     print("PASSED")
 
 def test_ecg_cnn():
 
@ -14,34 +14,14 @@ def get_exp_dataloader(task):
     """
 
     try:
-        dir = os.path.join('experiments',task,'dataloaders','text_dataloader.py')
-        loader = SourceFileLoader("TextDataLoader",dir).load_module()
-        loader = loader.TextDataLoader
+        dir = os.path.join('experiments',task,'dataloaders','dataloader.py')
+        loader = SourceFileLoader("DataLoader",dir).load_module()
+        loader = loader.DataLoader
     except:
         print_rank("Dataloader not found, please make sure is located inside the experiment folder")
 
     return loader
 
 
-def detect_loader_type(my_data, loader_type):
-    """ Detect the loader type declared in the configuration file
-
-    Inside this function should go the implementation of
-    specific detection for any kind of loader.
-
-    Args:
-        my_data (str): path of file or chunk file set
-        loader_type (str): loader description in yaml file
-    """
-
-    if not loader_type == "auto_detect":
-        return loader_type
-
-    # Here should go the implementation for the rest of loaders
-    else:
-        raise ValueError("Unknown format: {}".format(loader_type))
-
-
 def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=300, data_strct=None):
     """ Create a dataloader for training on either server or client side """
 
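With `detect_loader_type` removed, `get_exp_dataloader` becomes the single lookup: it loads `experiments/<task>/dataloaders/dataloader.py` by path and returns the `DataLoader` class defined there. A hedged usage sketch follows; the import path and the extra keys in `args` are assumptions, since each task's dataloader decides which configuration entries it reads.

```
# Hedged usage sketch; the import path is an assumption, use whichever module
# in this repo defines get_exp_dataloader (the file shown in this hunk).
from utils.dataloaders_utils import get_exp_dataloader

DataLoader = get_exp_dataloader('nlg_gru')  # loads experiments/nlg_gru/dataloaders/dataloader.py
val_loader = DataLoader(data='data/nlg_gru/val_data.json',
                        user_idx=0,
                        mode='val',
                        args={'batch_size': 2048,
                              'vocab_dict': '<add path to vocab here>',
                              'pin_memory': False})  # plus any other keys the task's dataloader reads
```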
@ -64,67 +44,43 @@ def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=3
     else:
         my_data = data_config["list_of_train_data"]
 
-    # Find the loader_type
-    loader_type = detect_loader_type(my_data, data_config["loader_type"])
-
-    if loader_type == 'text':
-        TextDataLoader = get_exp_dataloader(task)
-        train_dataloader = TextDataLoader(
-                data = data_strct if data_strct is not None else my_data,
+    DataLoader = get_exp_dataloader(task)
+    train_dataloader = DataLoader(data = data_strct if data_strct is not None else my_data,
                 user_idx = clientx,
                 mode = mode,
                 args=data_config
                 )
-    else:
-        raise NotImplementedError("Not supported {}: detected_type={} loader_type={} audio_format={}".format(my_data, loader_type, data_config["loader_type"], data_config["audio_format"]))
 
     return train_dataloader
 
 
 
-def make_val_dataloader(data_config, data_path, task=None, data_strct=None):
+def make_val_dataloader(data_config, data_path, task=None, data_strct=None, train_mode=False):
     """ Return a data loader for a validation set """
 
     if not "val_data" in data_config or data_config["val_data"] is None:
         print_rank("Validation data list is not set", loglevel=logging.DEBUG)
+        if train_mode:
             return None
 
-    loader_type = detect_loader_type(data_config["val_data"], data_config["loader_type"])
-
-    if loader_type == 'text':
-        TextDataLoader = get_exp_dataloader(task)
-
-        val_dataloader = TextDataLoader(
-                data = data_strct if data_strct is not None else os.path.join(data_path, data_config["val_data"]),
+    DataLoader = get_exp_dataloader(task)
+    val_file = os.path.join(data_path, data_config["val_data"]) if data_config["val_data"] != None and data_path != None else None
+    val_dataloader = DataLoader(data = data_strct if data_strct is not None else val_file,
                 user_idx = 0,
                 mode = 'val',
                 args=data_config
                 )
-    else:
-        raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
 
     return val_dataloader
 
 
 def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
     """ Return a data loader for an evaluation set. """
 
     if not "test_data" in data_config or data_config["test_data"] is None:
         print_rank("Test data list is not set")
         return None
 
-    loader_type = detect_loader_type(data_config["test_data"], data_config["loader_type"])
-
-    if loader_type == 'text':
-        TextDataLoader = get_exp_dataloader(task)
-
-        test_dataloader = TextDataLoader(
-                data = data_strct if data_strct is not None else os.path.join(data_path, data_config["test_data"]),
+    DataLoader = get_exp_dataloader(task)
+    test_file = os.path.join(data_path, data_config["test_data"]) if data_config["test_data"] != None and data_path != None else None
+    test_dataloader = DataLoader(data = data_strct if data_strct is not None else test_file,
                 user_idx = 0,
                 mode = 'test',
                 args=data_config
                 )
 
-    else:
-        raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
     return test_dataloader
 
 
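Taken together, the three helpers now resolve the task's `DataLoader` dynamically and hand it the split's configuration unchanged. The sketch below shows one plausible way server-side code could build all three loaders; the import path, the placeholder vocabulary path and the exact set of keys each split needs are assumptions, not guarantees from this diff.

```
# Hedged end-to-end sketch of building the three loaders with the helpers above.
# Import path and config contents are assumptions based on the hunks in this diff.
from utils.dataloaders_utils import (make_train_dataloader,
                                     make_val_dataloader,
                                     make_test_dataloader)

data_path = './testing/mockup'
data_config = {
    'train': {'batch_size': 64, 'list_of_train_data': 'data/nlg_gru/train_data.json',
              'vocab_dict': '<add path to vocab here>', 'pin_memory': False},
    'val':   {'batch_size': 2048, 'val_data': 'data/nlg_gru/val_data.json',
              'vocab_dict': '<add path to vocab here>', 'pin_memory': False},
    'test':  {'batch_size': 2048, 'test_data': None,
              'vocab_dict': '<add path to vocab here>', 'pin_memory': False},
}

train_loader = make_train_dataloader(data_config['train'], data_path, clientx=0, task='nlg_gru')
val_loader   = make_val_dataloader(data_config['val'], data_path, task='nlg_gru')
test_loader  = make_test_dataloader(data_config['test'], data_path, task='nlg_gru')  # returns None, test_data is null
```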