зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1213: Remove file type dependency on client.py
- Replace _data_dict in client.py by a dataset. - Remove the loader_type dependency for dataloaders utilities - Add Base Classes for dataset and dataloaders - Add example for a previously created dataset instantiation in classif_cnn example - Allow datasets to be downloaded on the fly - Update documentation Sanity checks: [x] nlg_gru: https://aka.ms/amlt?q=cn6vj [x] mlm_bert: https://aka.ms/amlt?q=cppmb [x] classif_cnn: https://aka.ms/amlt?q=cn6vw [x] ecg: https://aka.ms/amlt?q=codet
This commit is contained in:
Родитель
78a401a48a
Коммит
866d0a072c
|
@ -82,7 +82,6 @@ server_config:
|
||||||
num_clients_per_iteration: 200 # Number of clients sampled per round
|
num_clients_per_iteration: 200 # Number of clients sampled per round
|
||||||
data_config: # Server-side data configuration
|
data_config: # Server-side data configuration
|
||||||
val: # Validation data
|
val: # Validation data
|
||||||
loader_type: text
|
|
||||||
val_data: <add path to data here>
|
val_data: <add path to data here>
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
@ -104,7 +103,6 @@ server_config:
|
||||||
# train_data_server: null
|
# train_data_server: null
|
||||||
# desired_max_samples: null
|
# desired_max_samples: null
|
||||||
test: # Test data configuration
|
test: # Test data configuration
|
||||||
loader_type: text
|
|
||||||
test_data: <add path to data here>
|
test_data: <add path to data here>
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
@ -140,7 +138,6 @@ client_config:
|
||||||
do_profiling: false # Enables client-side training profiling
|
do_profiling: false # Enables client-side training profiling
|
||||||
data_config:
|
data_config:
|
||||||
train: # This is the main training data configuration
|
train: # This is the main training data configuration
|
||||||
loader_type: text
|
|
||||||
list_of_train_data: <add path to data here>
|
list_of_train_data: <add path to data here>
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
|
|
@ -60,7 +60,6 @@ server_config:
|
||||||
data_config: # Server-side data configuration
|
data_config: # Server-side data configuration
|
||||||
val: # Validation data
|
val: # Validation data
|
||||||
batch_size: 2048
|
batch_size: 2048
|
||||||
loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
val_data: <add path to data here> # Path for validation data
|
val_data: <add path to data here> # Path for validation data
|
||||||
|
@ -92,7 +91,6 @@ server_config:
|
||||||
# unsorted_batch: true
|
# unsorted_batch: true
|
||||||
test: # Test data configuration
|
test: # Test data configuration
|
||||||
batch_size: 2048
|
batch_size: 2048
|
||||||
loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
train_data: null
|
train_data: null
|
||||||
|
@ -130,7 +128,6 @@ client_config:
|
||||||
data_config:
|
data_config:
|
||||||
train: # This is the main training data configuration
|
train: # This is the main training data configuration
|
||||||
batch_size: 64
|
batch_size: 64
|
||||||
loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
list_of_train_data: <add path to data here> # Path to training data
|
list_of_train_data: <add path to data here> # Path to training data
|
||||||
|
|
154
core/client.py
154
core/client.py
|
@ -7,13 +7,12 @@ workers 1 to N for processing a given client's data. It's main method is the
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from easydict import EasyDict as edict
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
import h5py
|
from importlib.machinery import SourceFileLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -47,13 +46,6 @@ import extensions.privacy
|
||||||
from extensions.privacy import metrics as privacy_metrics
|
from extensions.privacy import metrics as privacy_metrics
|
||||||
from experiments import make_model
|
from experiments import make_model
|
||||||
|
|
||||||
|
|
||||||
# A per-process cache of the training data, so clients don't have to repeatedly re-load
|
|
||||||
# TODO: deprecate this in favor of passing dataloader around
|
|
||||||
_data_dict = None
|
|
||||||
_file_ext = None
|
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
# It's unclear why, but sphinx refuses to generate method docs
|
# It's unclear why, but sphinx refuses to generate method docs
|
||||||
# if there is no docstring for this class.
|
# if there is no docstring for this class.
|
||||||
|
@ -83,112 +75,45 @@ class Client:
|
||||||
return self.client_id, self.client_data, self.config, self.send_gradients
|
return self.client_id, self.client_data, self.config, self.send_gradients
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_num_users(filename):
|
def get_train_dataset(data_path, client_train_config, task):
|
||||||
'''Count users given a JSON or HDF5 file.
|
'''This function will obtain the training dataset for all
|
||||||
|
users.
|
||||||
This function will fill the global data dict. Ideally we want data
|
|
||||||
handling not to happen here and only at the dataloader, that will be the
|
|
||||||
behavior in future releases.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): path to file containing data.
|
data_path (str): path to file containing taining data.
|
||||||
|
client_train_config (dict): trainig data config.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
global _data_dict
|
|
||||||
global _file_ext
|
|
||||||
_file_ext = filename.split('.')[-1]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if _file_ext == 'json' or _file_ext == 'txt':
|
dir = os.path.join('experiments',task,'dataloaders','dataset.py')
|
||||||
if _data_dict is None:
|
loader = SourceFileLoader("Dataset",dir).load_module()
|
||||||
print_rank('Reading training data dictionary from JSON')
|
dataset = loader.Dataset
|
||||||
with open(filename,'r') as fid:
|
train_file = os.path.join(data_path, client_train_config['list_of_train_data']) if client_train_config['list_of_train_data'] != None else None
|
||||||
_data_dict = json.load(fid) # pre-cache the training data
|
train_dataset = dataset(train_file, args=client_train_config)
|
||||||
_data_dict = scrub_empty_clients(_data_dict) # empty clients MUST be scrubbed here to match num_clients in the entry script
|
num_users = len(train_dataset.user_list)
|
||||||
print_rank('Read training data dictionary', loglevel=logging.DEBUG)
|
print_rank("Total amount of training users: {}".format(num_users))
|
||||||
|
|
||||||
elif _file_ext == 'hdf5':
|
|
||||||
print_rank('Reading training data dictionary from HDF5')
|
|
||||||
_data_dict = h5py.File(filename, 'r')
|
|
||||||
print_rank('Read training data dictionary', loglevel=logging.DEBUG)
|
|
||||||
|
|
||||||
except:
|
except:
|
||||||
raise ValueError('Error reading training file. Please make sure the format is allowed')
|
print_rank("Dataset not found, please make sure is located inside the experiment folder")
|
||||||
|
|
||||||
num_users = len(_data_dict['users'])
|
return num_users, train_dataset
|
||||||
return num_users
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_data(client_id, dataloader):
|
def get_data(clients, dataset):
|
||||||
'''Load data from the dataloader given the client's id.
|
''' Create training dictionary'''
|
||||||
|
|
||||||
This function will load the global data dict. Ideally we want data
|
data_with_labels = hasattr(dataset,"user_data_label")
|
||||||
handling not to happen here and only at the dataloader, that will be the
|
input_strct = {'users': [], 'num_samples': [],'user_data': dict(), 'user_data_label': dict()} if data_with_labels else {'users': [], 'num_samples': [],'user_data': dict()}
|
||||||
behavior in future releases.
|
|
||||||
|
|
||||||
Args:
|
for client in clients:
|
||||||
client_id (int or list): identifier(s) for grabbing client's data.
|
user = dataset.user_list[client]
|
||||||
dataloader (torch.utils.data.DataLoader): dataloader that
|
input_strct['users'].append(user)
|
||||||
provides the trianing
|
input_strct['num_samples'].append(dataset.num_samples[client])
|
||||||
'''
|
input_strct['user_data'][user]= dataset.user_data[user]
|
||||||
|
if data_with_labels:
|
||||||
|
input_strct['user_data_label'][user] = dataset.user_data_label[user]
|
||||||
|
|
||||||
# Auxiliary function for decoding only when necessary
|
return edict(input_strct)
|
||||||
decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x
|
|
||||||
|
|
||||||
# During training, client_id will be always an integer
|
|
||||||
if isinstance(client_id, int):
|
|
||||||
user_name = decode_if_str(_data_dict['users'][client_id])
|
|
||||||
num_samples = _data_dict['num_samples'][client_id]
|
|
||||||
|
|
||||||
if _file_ext == 'hdf5':
|
|
||||||
arr_data = [decode_if_str(e) for e in _data_dict['user_data'][user_name]['x'][()]]
|
|
||||||
user_data = {'x': arr_data}
|
|
||||||
elif _file_ext == 'json' or _file_ext == 'txt':
|
|
||||||
user_data = _data_dict['user_data'][user_name]
|
|
||||||
|
|
||||||
if 'user_data_label' in _data_dict: # supervised problem
|
|
||||||
labels = _data_dict['user_data_label'][user_name]
|
|
||||||
if _file_ext == 'hdf5': # transforms HDF5 Dataset into Numpy array
|
|
||||||
labels = labels[()]
|
|
||||||
|
|
||||||
return edict({'users': [user_name],
|
|
||||||
'user_data': {user_name: user_data},
|
|
||||||
'num_samples': [num_samples],
|
|
||||||
'user_data_label': {user_name: labels}})
|
|
||||||
else:
|
|
||||||
print_rank('no labels present, unsupervised problem', loglevel=logging.DEBUG)
|
|
||||||
return edict({'users': [user_name],
|
|
||||||
'user_data': {user_name: user_data},
|
|
||||||
'num_samples': [num_samples]})
|
|
||||||
|
|
||||||
# During validation and test, client_id might be a list of integers
|
|
||||||
elif isinstance(client_id, list):
|
|
||||||
if 'user_data_label' in _data_dict:
|
|
||||||
users_dict = {'users': [], 'num_samples': [], 'user_data': {}, 'user_data_label': {}}
|
|
||||||
else:
|
|
||||||
users_dict = {'users': [], 'num_samples': [], 'user_data': {}}
|
|
||||||
|
|
||||||
for client in client_id:
|
|
||||||
user_name = decode_if_str(dataloader.dataset.user_list[client])
|
|
||||||
users_dict['users'].append(user_name)
|
|
||||||
users_dict['num_samples'].append(dataloader.dataset.num_samples[client])
|
|
||||||
|
|
||||||
if _file_ext == 'hdf5':
|
|
||||||
arr_data = dataloader.dataset.user_data[user_name]['x']
|
|
||||||
arr_decoded = [decode_if_str(e) for e in arr_data]
|
|
||||||
users_dict['user_data'][user_name] = {'x': arr_decoded}
|
|
||||||
elif _file_ext == 'json':
|
|
||||||
users_dict['user_data'][user_name] = {'x': dataloader.dataset.user_data[user_name]['x']}
|
|
||||||
elif _file_ext == 'txt': # using a different line for .txt since our files have a different structure
|
|
||||||
users_dict['user_data'][user_name] = dataloader.dataset.user_data[user_name]
|
|
||||||
|
|
||||||
if 'user_data_label' in _data_dict:
|
|
||||||
labels = dataloader.dataset.user_data_label[user_name]
|
|
||||||
if _file_ext == 'hdf5':
|
|
||||||
labels = labels[()]
|
|
||||||
users_dict['user_data_label'][user_name] = labels
|
|
||||||
|
|
||||||
return users_dict
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def run_testvalidate(client_data, server_data, mode, model):
|
def run_testvalidate(client_data, server_data, mode, model):
|
||||||
|
@ -287,30 +212,12 @@ class Client:
|
||||||
begin = time.time()
|
begin = time.time()
|
||||||
client_stats = {}
|
client_stats = {}
|
||||||
|
|
||||||
# Update the location of the training file
|
|
||||||
data_config['list_of_train_data'] = os.path.join(data_path, data_config['list_of_train_data'])
|
|
||||||
|
|
||||||
user = data_strct['users'][0]
|
user = data_strct['users'][0]
|
||||||
if 'user_data_label' in data_strct.keys(): # supervised case
|
|
||||||
input_strct = edict({
|
|
||||||
'users': [user],
|
|
||||||
'user_data': {user: data_strct['user_data'][user]},
|
|
||||||
'num_samples': [data_strct['num_samples'][0]],
|
|
||||||
'user_data_label': {user: data_strct['user_data_label'][user]}
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
input_strct = edict({
|
|
||||||
'users': [user],
|
|
||||||
'user_data': {user: data_strct['user_data'][user]},
|
|
||||||
'num_samples': [data_strct['num_samples'][0]]
|
|
||||||
})
|
|
||||||
|
|
||||||
print_rank('Loading : {}-th client with name: {}, {} samples, {}s elapsed'.format(
|
print_rank('Loading : {}-th client with name: {}, {} samples, {}s elapsed'.format(
|
||||||
client_id, user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)
|
client_id[0], user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)
|
||||||
|
|
||||||
# Get dataloaders
|
# Get dataloaders
|
||||||
train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=input_strct)
|
train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=data_strct)
|
||||||
val_dataloader = make_val_dataloader(data_config, data_path)
|
|
||||||
|
|
||||||
# Instantiate the model object
|
# Instantiate the model object
|
||||||
if model is None:
|
if model is None:
|
||||||
|
@ -349,7 +256,6 @@ class Client:
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
ss_scheduler=ss_scheduler,
|
ss_scheduler=ss_scheduler,
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
val_dataloader=val_dataloader,
|
|
||||||
server_replay_config =client_config,
|
server_replay_config =client_config,
|
||||||
max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
|
max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
|
||||||
anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
|
anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
|
||||||
|
@ -386,7 +292,7 @@ class Client:
|
||||||
|
|
||||||
# This is where training actually happens
|
# This is where training actually happens
|
||||||
train_loss, num_samples = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics)
|
train_loss, num_samples = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics)
|
||||||
print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)
|
print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)
|
||||||
|
|
||||||
# Estimate gradient magnitude mean/var
|
# Estimate gradient magnitude mean/var
|
||||||
# Now computed when the sufficient stats are updated.
|
# Now computed when the sufficient stats are updated.
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader as PyTorchDataLoader
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
class BaseDataLoader(ABC, PyTorchDataLoader):
|
||||||
|
'''This is a wrapper class for PyTorch dataloaders.'''
|
||||||
|
|
||||||
|
def create_loader(self):
|
||||||
|
'''Returns the dataloader'''
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset as PyTorchDataset
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
class BaseDataset(ABC, PyTorchDataset):
|
||||||
|
'''This is a wrapper class for PyTorch datasets.'''
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self,**kwargs):
|
||||||
|
super(BaseDataset, self).__init__()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __getitem__(self, idx, **kwargs):
|
||||||
|
'''Fetches a data sample for a given key'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __len__(self):
|
||||||
|
'''Returns the size of the dataset'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_data(self,**kwargs):
|
||||||
|
'''Wrapper method to read/instantiate the dataset'''
|
||||||
|
pass
|
|
@ -184,10 +184,10 @@ class Evaluation():
|
||||||
current_total += count
|
current_total += count
|
||||||
if current_total > threshold:
|
if current_total > threshold:
|
||||||
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
|
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
|
||||||
yield Client(current_users_idxs, self.config, False, dataloader)
|
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
|
||||||
current_users_idxs = list()
|
current_users_idxs = list()
|
||||||
current_total = 0
|
current_total = 0
|
||||||
|
|
||||||
if len(current_users_idxs) != 0:
|
if len(current_users_idxs) != 0:
|
||||||
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
|
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
|
||||||
yield Client(current_users_idxs, self.config, False, dataloader)
|
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
|
||||||
|
|
|
@ -2,20 +2,7 @@
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
|
# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
|
||||||
TRAINING_FRAMEWORK_TYPE = 'mpi'
|
TRAINING_FRAMEWORK_TYPE = 'mpi'
|
||||||
logging_level = logging.INFO # DEBUG | INFO
|
logging_level = logging.INFO # DEBUG | INFO
|
||||||
file_type = None
|
|
||||||
task = None
|
|
||||||
|
|
||||||
|
|
||||||
def define_file_type (data_path,config, exp_folder):
|
|
||||||
global file_type
|
|
||||||
global task
|
|
||||||
|
|
||||||
filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
|
|
||||||
arr_filename = filename.split(".")
|
|
||||||
file_type = arr_filename[-1]
|
|
||||||
print(" File_type has ben assigned to: {}".format(file_type))
|
|
||||||
task = exp_folder
|
|
|
@ -121,8 +121,7 @@
|
||||||
'allow_unknown': True,
|
'allow_unknown': True,
|
||||||
'schema': {
|
'schema': {
|
||||||
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
||||||
'loader_type': {'required': False, 'type':'string', 'default':'text'},
|
'val_data': {'required': True, 'type':'string', 'nullable':True},
|
||||||
'val_data': {'required': True, 'type':'string'},
|
|
||||||
'tokenizer_type': {'required': False, 'type':'string'},
|
'tokenizer_type': {'required': False, 'type':'string'},
|
||||||
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
||||||
'vocab_dict': {'required': False, 'type':'string'},
|
'vocab_dict': {'required': False, 'type':'string'},
|
||||||
|
@ -142,8 +141,7 @@
|
||||||
'allow_unknown': True,
|
'allow_unknown': True,
|
||||||
'schema': {
|
'schema': {
|
||||||
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
||||||
'loader_type': {'required': False, 'type':'string', 'default':'text'},
|
'test_data': {'required': True, 'type':'string', 'nullable': True},
|
||||||
'test_data': {'required': True, 'type':'string'},
|
|
||||||
'tokenizer_type': {'required': False, 'type':'string'},
|
'tokenizer_type': {'required': False, 'type':'string'},
|
||||||
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
||||||
'vocab_dict': {'required': False, 'type':'string'},
|
'vocab_dict': {'required': False, 'type':'string'},
|
||||||
|
@ -163,7 +161,6 @@
|
||||||
'allow_unknown': True,
|
'allow_unknown': True,
|
||||||
'schema': {
|
'schema': {
|
||||||
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
||||||
'loader_type': {'required': False, 'type':'string', 'default':'text'},
|
|
||||||
'train_data': {'required': True, 'type':'string'},
|
'train_data': {'required': True, 'type':'string'},
|
||||||
'train_data_server': {'required': False, 'type':'string'},
|
'train_data_server': {'required': False, 'type':'string'},
|
||||||
'desired_max_samples': {'required': False, 'type':'integer'},
|
'desired_max_samples': {'required': False, 'type':'integer'},
|
||||||
|
@ -248,8 +245,7 @@
|
||||||
'allow_unknown': True,
|
'allow_unknown': True,
|
||||||
'schema': {
|
'schema': {
|
||||||
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
'batch_size': {'required': False, 'type':'integer', 'default': 40},
|
||||||
'loader_type': {'required': False, 'type':'string', 'default':'text'},
|
'list_of_train_data': {'required': True, 'type':'string', 'nullable': True},
|
||||||
'list_of_train_data': {'required': True, 'type':'string'},
|
|
||||||
'tokenizer_type': {'required': False, 'type':'string'},
|
'tokenizer_type': {'required': False, 'type':'string'},
|
||||||
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
|
||||||
'vocab_dict': {'required': False, 'type':'string'},
|
'vocab_dict': {'required': False, 'type':'string'},
|
||||||
|
|
|
@ -49,7 +49,7 @@ run = Run.get_context()
|
||||||
|
|
||||||
|
|
||||||
class OptimizationServer(federated.Server):
|
class OptimizationServer(federated.Server):
|
||||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader,
|
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
|
||||||
val_dataloader, test_dataloader, config, config_server):
|
val_dataloader, test_dataloader, config, config_server):
|
||||||
'''Implement Server's orchestration and aggregation.
|
'''Implement Server's orchestration and aggregation.
|
||||||
|
|
||||||
|
@ -133,6 +133,7 @@ class OptimizationServer(federated.Server):
|
||||||
# Creating an instance for the server-side trainer (runs mini-batch SGD)
|
# Creating an instance for the server-side trainer (runs mini-batch SGD)
|
||||||
self.server_replay_iterations = None
|
self.server_replay_iterations = None
|
||||||
self.server_trainer = None
|
self.server_trainer = None
|
||||||
|
self.train_dataset = train_dataset
|
||||||
if train_dataloader is not None:
|
if train_dataloader is not None:
|
||||||
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
|
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
|
||||||
assert 'optimizer_config' in server_config[
|
assert 'optimizer_config' in server_config[
|
||||||
|
@ -305,10 +306,10 @@ class OptimizationServer(federated.Server):
|
||||||
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
|
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
|
||||||
sampled_clients = [
|
sampled_clients = [
|
||||||
Client(
|
Client(
|
||||||
client_id,
|
[client_id],
|
||||||
self.config,
|
self.config,
|
||||||
self.config['client_config']['type'] == 'optimization',
|
self.config['client_config']['type'] == 'optimization',
|
||||||
None
|
self.train_dataset
|
||||||
) for client_id in sampled_idx_clients
|
) for client_id in sampled_idx_clients
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from importlib.machinery import SourceFileLoader
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -205,7 +204,6 @@ class Trainer(TrainerBase):
|
||||||
ss_scheduler: scheduled sampler.
|
ss_scheduler: scheduled sampler.
|
||||||
train_dataloader (torch.data.utils.DataLoader): dataloader that
|
train_dataloader (torch.data.utils.DataLoader): dataloader that
|
||||||
provides the training data.
|
provides the training data.
|
||||||
val_dataloader (torch.data.utils.DataLoader): provides val data.
|
|
||||||
server_replay_config (dict or None): config for replaying training;
|
server_replay_config (dict or None): config for replaying training;
|
||||||
defaults to None, in which case no replaying happens.
|
defaults to None, in which case no replaying happens.
|
||||||
optimizer (torch.optim.Optimizer or None): optimizer that will be used
|
optimizer (torch.optim.Optimizer or None): optimizer that will be used
|
||||||
|
@ -222,7 +220,6 @@ class Trainer(TrainerBase):
|
||||||
model,
|
model,
|
||||||
ss_scheduler,
|
ss_scheduler,
|
||||||
train_dataloader,
|
train_dataloader,
|
||||||
val_dataloader,
|
|
||||||
server_replay_config=None,
|
server_replay_config=None,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
max_grad_norm=None,
|
max_grad_norm=None,
|
||||||
|
@ -255,7 +252,6 @@ class Trainer(TrainerBase):
|
||||||
self.anneal_config,
|
self.anneal_config,
|
||||||
self.optimizer)
|
self.optimizer)
|
||||||
|
|
||||||
self.val_dataloader = val_dataloader
|
|
||||||
self.cached_batches = []
|
self.cached_batches = []
|
||||||
self.ss_scheduler = ss_scheduler
|
self.ss_scheduler = ss_scheduler
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,12 @@ Adding New Scenarios
|
||||||
|
|
||||||
Data Preparation
|
Data Preparation
|
||||||
------------
|
------------
|
||||||
|
FLUTE provides the abstract class `BaseDataset` inside ``core/dataset.py`` that can be used to wrap
|
||||||
At this moment FLUTE only allows JSON and HDF5 files, and requires an specific formatting for the training data. Here is a sample data blob for language model training.
|
any dataset and make it compatible with the platform. The dataset should be able to access all the data,
|
||||||
|
and store it in the attributes `user_list`, `user_data`, `num_samples` and `user_data_labels` (optional).
|
||||||
|
These attributes are required to have these exact names. The abstract method ``load_data ()`` should be
|
||||||
|
used to instantiate/load the dataset and provide the training format required by FLUTE on-the-fly.
|
||||||
|
Here is a sample data blob for language model training.
|
||||||
|
|
||||||
.. code:: json
|
.. code:: json
|
||||||
|
|
||||||
|
@ -43,7 +47,7 @@ If labels are needed by the task, ``user_data_label`` will be required by FLUTE
|
||||||
Add the model to FLUTE
|
Add the model to FLUTE
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
FLUTE requires the model declaration framed in PyTorch, which must inhereit from the `BaseModel` class defined in `core/model.py`. The following methods should be overridden:
|
FLUTE requires the model declaration framed in PyTorch, which must inhereit from the `BaseModel` class defined in ``core/model.py``. The following methods should be overridden:
|
||||||
|
|
||||||
* __init__: model definition
|
* __init__: model definition
|
||||||
* loss: computes the loss used for training rounds
|
* loss: computes the loss used for training rounds
|
||||||
|
@ -92,8 +96,8 @@ Once the model is ready, all mandatory files must be in a single folder inside
|
||||||
|
|
||||||
task_name
|
task_name
|
||||||
|---- dataloaders
|
|---- dataloaders
|
||||||
|---- text_dataloader.py
|
|---- dataloader.py
|
||||||
|---- text_dataset.py
|
|---- dataset.py
|
||||||
|---- utils
|
|---- utils
|
||||||
|---- utils.py (if needed)
|
|---- utils.py (if needed)
|
||||||
|---- model.py
|
|---- model.py
|
||||||
|
@ -130,11 +134,12 @@ Once the keys have been included in the returning dictionary from `inference()`,
|
||||||
Create the configuration file
|
Create the configuration file
|
||||||
---------------------------------
|
---------------------------------
|
||||||
|
|
||||||
The configuration file will allow you to specify the setup in your experiment, such as the optimizer, learning rate, number of clients and so on. FLUTE requires the following 5 sections:
|
The configuration file will allow you to specify the setup in your experiment, such as the optimizer, learning rate, number of clients and so on. FLUTE requires the following 6 sections:
|
||||||
|
|
||||||
* model_config: path an parameters (if needed) to initialize the model.
|
* model_config: path an parameters (if needed) to initialize the model.
|
||||||
* dp_config: differential privacy setup.
|
* dp_config: differential privacy setup.
|
||||||
* privacy_metrics_config: for cache data to compute additional metrics.
|
* privacy_metrics_config: for cache data to compute additional metrics.
|
||||||
|
* strategy: defines the federated optimizer.
|
||||||
* server_config: determines all the server-side settings.
|
* server_config: determines all the server-side settings.
|
||||||
* client_config: dictates the learning parameters for client-side model updates.
|
* client_config: dictates the learning parameters for client-side model updates.
|
||||||
|
|
||||||
|
@ -175,12 +180,10 @@ The blob below indicates the basic parameters required by FLUTE to run an experi
|
||||||
data_config: # Information for the test/val dataloaders
|
data_config: # Information for the test/val dataloaders
|
||||||
val:
|
val:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
val_data: test_data.hdf5 # Assign to null for data loaded on-the-fly
|
||||||
val_data: test_data.hdf5
|
|
||||||
test:
|
test:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
test_data: test_data.hdf5 # Assign to null for data loaded on-the-fly
|
||||||
test_data: test_data.hdf5
|
|
||||||
type: model_optimization # Server type (model_optimization is the only available for now)
|
type: model_optimization # Server type (model_optimization is the only available for now)
|
||||||
aggregate_median: softmax # How aggregations weights are computed
|
aggregate_median: softmax # How aggregations weights are computed
|
||||||
initial_lr_client: 0.001 # Learning rate used on optimizer
|
initial_lr_client: 0.001 # Learning rate used on optimizer
|
||||||
|
@ -196,8 +199,7 @@ The blob below indicates the basic parameters required by FLUTE to run an experi
|
||||||
data_config: # Information for the train dataloader
|
data_config: # Information for the train dataloader
|
||||||
train:
|
train:
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
loader_type: text
|
list_of_train_data: train_data.hdf5 # Assign to null for data loaded on-the-fly
|
||||||
list_of_train_data: train_data.hdf5
|
|
||||||
desired_max_samples: 50000
|
desired_max_samples: 50000
|
||||||
optimizer_config: # Optimizer used by the client
|
optimizer_config: # Optimizer used by the client
|
||||||
type: sgd
|
type: sgd
|
||||||
|
|
|
@ -22,7 +22,7 @@ from core import federated
|
||||||
from core.config import FLUTEConfig
|
from core.config import FLUTEConfig
|
||||||
from core.server import select_server
|
from core.server import select_server
|
||||||
from core.client import Client
|
from core.client import Client
|
||||||
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level, define_file_type
|
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level
|
||||||
from experiments import make_model
|
from experiments import make_model
|
||||||
from utils import (
|
from utils import (
|
||||||
make_optimizer,
|
make_optimizer,
|
||||||
|
@ -88,7 +88,6 @@ def run_worker(model_path, config, task, data_path, local_rank):
|
||||||
"""
|
"""
|
||||||
model_config = config["model_config"]
|
model_config = config["model_config"]
|
||||||
server_config = config["server_config"]
|
server_config = config["server_config"]
|
||||||
define_file_type(data_path, config, task)
|
|
||||||
|
|
||||||
# Get the rank on MPI
|
# Get the rank on MPI
|
||||||
rank = local_rank if local_rank > -1 else federated.rank()
|
rank = local_rank if local_rank > -1 else federated.rank()
|
||||||
|
@ -108,11 +107,12 @@ def run_worker(model_path, config, task, data_path, local_rank):
|
||||||
print_rank('Server data preparation')
|
print_rank('Server data preparation')
|
||||||
|
|
||||||
# pre-cache the training data and capture the number of clients for sampling
|
# pre-cache the training data and capture the number of clients for sampling
|
||||||
training_filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
|
client_train_config = config["client_config"]["data_config"]["train"]
|
||||||
config["server_config"]["data_config"]["num_clients"] = Client.get_num_users(training_filename)
|
num_clients, train_dataset = Client.get_train_dataset(data_path, client_train_config,task)
|
||||||
data_config = config['server_config']['data_config']
|
config["server_config"]["data_config"]["num_clients"] = num_clients
|
||||||
|
|
||||||
# Make the Dataloaders
|
# Make the Dataloaders
|
||||||
|
data_config = config['server_config']['data_config']
|
||||||
if 'train' in data_config:
|
if 'train' in data_config:
|
||||||
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
|
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
|
||||||
else:
|
else:
|
||||||
|
@ -142,6 +142,7 @@ def run_worker(model_path, config, task, data_path, local_rank):
|
||||||
data_path,
|
data_path,
|
||||||
model_path,
|
model_path,
|
||||||
server_train_dataloader,
|
server_train_dataloader,
|
||||||
|
train_dataset,
|
||||||
val_dataloader,
|
val_dataloader,
|
||||||
test_dataloader,
|
test_dataloader,
|
||||||
config,
|
config,
|
||||||
|
|
|
@ -9,11 +9,9 @@ An adapted version of the tutorial above is provided in the
|
||||||
|
|
||||||
## Preparing the data
|
## Preparing the data
|
||||||
|
|
||||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
In this experiment we are making use of the CIFAR10 Dataset from torchvision,
|
||||||
should be made data-agnostic in the near future, but right now we need to
|
initializated in `dataloaders/cifar_dataset.py`, which inhereits from the
|
||||||
convert the data to either of these formats. In our case, we can use the script
|
FLUTE base dataset class `core/dataset.py`
|
||||||
`utils/download_and_convert_data.py` to do that for us; a HDF5 file will be
|
|
||||||
generated.
|
|
||||||
|
|
||||||
## Specifying the model
|
## Specifying the model
|
||||||
|
|
||||||
|
@ -27,12 +25,11 @@ should be the same as in this example.
|
||||||
|
|
||||||
## Specifying dataset and dataloaders
|
## Specifying dataset and dataloaders
|
||||||
|
|
||||||
Inside the `dataloaders` folder, there are two files: `text_dataset.py` and
|
Inside the `dataloaders` folder, there are two files: `dataset.py` and
|
||||||
`text_dataloader.py` (the word "text" is used to mimic the other datasets, even
|
`dataloader.py`. Both inherit from the base classes declared in `core`
|
||||||
though in practice this loads images -- this will be changed in the future).
|
folder, that under the hood inhereit from Pytorch classes with same name.
|
||||||
Both inherit from the Pytorch classes with same name.
|
|
||||||
|
|
||||||
The dataset should be able to access all the data, which is stored in the
|
The dataset should be able to access all the data, and store it in the
|
||||||
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
|
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
|
||||||
names, user features, user labels if the problem is supervised, and number of
|
names, user features, user labels if the problem is supervised, and number of
|
||||||
samples for each user, respectively). These attributes are required to have
|
samples for each user, respectively). These attributes are required to have
|
||||||
|
@ -51,8 +48,7 @@ example is provided in `config.yaml`.
|
||||||
## Running the experiment
|
## Running the experiment
|
||||||
|
|
||||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||||
script using MPI (don't forget to first run
|
script using MPI
|
||||||
`utils/download_and_convert_data.py`):
|
|
||||||
|
|
||||||
```
|
```
|
||||||
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
|
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Basic configuration file for running classif_cnn example using hdf5 files.
|
# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.
|
||||||
# Parameters needed to initialize the model
|
# Parameters needed to initialize the model
|
||||||
model_config:
|
model_config:
|
||||||
model_type: CNN # class w/ `loss` and `inference` methods
|
model_type: CNN # class w/ `loss` and `inference` methods
|
||||||
|
@ -37,12 +37,10 @@ server_config:
|
||||||
data_config: # where to get val and test data from
|
data_config: # where to get val and test data from
|
||||||
val:
|
val:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
val_data: null # Assigned to null because dataset is being instantiated
|
||||||
val_data: test_data.hdf5
|
|
||||||
test:
|
test:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
test_data: null # Assigned to null because dataset is being instantiated
|
||||||
test_data: test_data.hdf5
|
|
||||||
type: model_optimization
|
type: model_optimization
|
||||||
aggregate_median: softmax # how aggregations weights are computed
|
aggregate_median: softmax # how aggregations weights are computed
|
||||||
initial_lr_client: 0.001 # learning rate used on client optimizer
|
initial_lr_client: 0.001 # learning rate used on client optimizer
|
||||||
|
@ -59,8 +57,7 @@ client_config:
|
||||||
data_config: # where to get training data from
|
data_config: # where to get training data from
|
||||||
train:
|
train:
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
loader_type: text
|
list_of_train_data: null # Assigned to null because dataset is being instantiated
|
||||||
list_of_train_data: train_data.hdf5
|
|
||||||
desired_max_samples: 50000
|
desired_max_samples: 50000
|
||||||
optimizer_config: # this is the optimizer used by the client
|
optimizer_config: # this is the optimizer used by the client
|
||||||
type: sgd
|
type: sgd
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
import time
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
class CIFAR10:
|
||||||
|
def __init__(self) :
|
||||||
|
# Get training and testing data from torchvision
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
])
|
||||||
|
|
||||||
|
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
|
||||||
|
download=True, transform=transform)
|
||||||
|
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
|
||||||
|
download=True, transform=transform)
|
||||||
|
|
||||||
|
print('Processing training set...')
|
||||||
|
self.trainset=_process(trainset, n_users=1000)
|
||||||
|
|
||||||
|
print('Processing test set...')
|
||||||
|
self.testset=_process(testset, n_users=200)
|
||||||
|
|
||||||
|
def _process(dataset, n_users):
|
||||||
|
'''Process a Torchvision dataset to expected format and save to disk'''
|
||||||
|
|
||||||
|
# Split training data equally among all users
|
||||||
|
total_samples = len(dataset)
|
||||||
|
samples_per_user = total_samples // n_users
|
||||||
|
assert total_samples % n_users == 0
|
||||||
|
|
||||||
|
# Function for getting a given user's data indices
|
||||||
|
user_idxs = lambda user_id: slice(user_id * samples_per_user, (user_id + 1) * samples_per_user)
|
||||||
|
|
||||||
|
# Convert training data to expected format
|
||||||
|
print('Converting data to expected format...')
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
data_dict = { # the data is expected to have this format
|
||||||
|
'users' : [f'{user_id:04d}' for user_id in range(n_users)],
|
||||||
|
'num_samples' : 10000 * [samples_per_user],
|
||||||
|
'user_data' : {f'{user_id:04d}': dataset.data[user_idxs(user_id)].tolist() for user_id in range(n_users)},
|
||||||
|
'user_data_label': {f'{user_id:04d}': dataset.targets[user_idxs(user_id)] for user_id in range(n_users)},
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f'Finished converting data in {time.time() - start_time:.2f}s.')
|
||||||
|
|
||||||
|
return data_dict
|
||||||
|
|
|
@ -2,21 +2,19 @@
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from experiments.classif_cnn.dataloaders.text_dataset import TextDataset
|
from core.dataloader import BaseDataLoader
|
||||||
|
from experiments.classif_cnn.dataloaders.dataset import Dataset
|
||||||
|
|
||||||
|
class DataLoader(BaseDataLoader):
|
||||||
class TextDataLoader(DataLoader):
|
|
||||||
def __init__(self, mode, num_workers=0, **kwargs):
|
def __init__(self, mode, num_workers=0, **kwargs):
|
||||||
args = kwargs['args']
|
args = kwargs['args']
|
||||||
self.batch_size = args['batch_size']
|
self.batch_size = args['batch_size']
|
||||||
|
|
||||||
dataset = TextDataset(
|
dataset = Dataset(
|
||||||
data=kwargs['data'],
|
data=kwargs['data'],
|
||||||
test_only=(not mode=='train'),
|
test_only=(not mode=='train'),
|
||||||
user_idx=kwargs.get('user_idx', None),
|
user_idx=kwargs.get('user_idx', None),
|
||||||
file_type='hdf5',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -27,9 +25,6 @@ class TextDataLoader(DataLoader):
|
||||||
collate_fn=self.collate_fn,
|
collate_fn=self.collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_loader(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
x, y = list(zip(*batch))
|
x, y = list(zip(*batch))
|
||||||
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
|
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from core.dataset import BaseDataset
|
||||||
|
from experiments.classif_cnn.dataloaders.cifar_dataset import CIFAR10
|
||||||
|
|
||||||
|
class Dataset(BaseDataset):
|
||||||
|
def __init__(self, data, test_only=False, user_idx=0, **kwargs):
|
||||||
|
self.test_only = test_only
|
||||||
|
self.user_idx = user_idx
|
||||||
|
|
||||||
|
# Get all data
|
||||||
|
self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.test_only)
|
||||||
|
|
||||||
|
if self.test_only: # combine all data into single array
|
||||||
|
self.user = 'test_only'
|
||||||
|
self.features = np.vstack([user_data for user_data in self.user_data.values()])
|
||||||
|
self.labels = np.hstack([user_label for user_label in self.user_data_label.values()])
|
||||||
|
else: # get a single user's data
|
||||||
|
if user_idx is None:
|
||||||
|
raise ValueError('in train mode, user_idx must be specified')
|
||||||
|
|
||||||
|
self.user = self.user_list[user_idx]
|
||||||
|
self.features = self.user_data[self.user]
|
||||||
|
self.labels = self.user_data_label[self.user]
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return np.array(self.features[idx]).astype(np.float32).T, self.labels[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.features)
|
||||||
|
|
||||||
|
def load_data(self, data, test_only):
|
||||||
|
'''Wrapper method to read/instantiate the dataset'''
|
||||||
|
|
||||||
|
if data == None:
|
||||||
|
dataset = CIFAR10()
|
||||||
|
data = dataset.testset if test_only else dataset.trainset
|
||||||
|
|
||||||
|
users = data['users']
|
||||||
|
features = data['user_data']
|
||||||
|
labels = data['user_data_label']
|
||||||
|
num_samples = data['num_samples']
|
||||||
|
|
||||||
|
return users, features, labels, num_samples
|
|
@ -1,56 +0,0 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
import h5py
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
|
||||||
def __init__(self, data, test_only=False, user_idx=None, file_type=None):
|
|
||||||
self.test_only = test_only
|
|
||||||
self.user_idx = user_idx
|
|
||||||
self.file_type = file_type
|
|
||||||
|
|
||||||
# Get all data
|
|
||||||
self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.file_type)
|
|
||||||
|
|
||||||
if self.test_only: # combine all data into single array
|
|
||||||
self.user = 'test_only'
|
|
||||||
self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
|
|
||||||
self.labels = np.hstack(list(self.user_data_label.values()))
|
|
||||||
else: # get a single user's data
|
|
||||||
if user_idx is None:
|
|
||||||
raise ValueError('in train mode, user_idx must be specified')
|
|
||||||
|
|
||||||
self.user = self.user_list[user_idx]
|
|
||||||
self.features = self.user_data[self.user]['x']
|
|
||||||
self.labels = self.user_data_label[self.user]
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.features[idx].astype(np.float32).T, self.labels[idx]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.features)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_data(data, file_type):
|
|
||||||
'''Load data from disk or memory.
|
|
||||||
|
|
||||||
The :code:`data` argument can be either the path to the JSON
|
|
||||||
or HDF5 file that contains the expected dictionary, or the
|
|
||||||
actual dictionary.'''
|
|
||||||
|
|
||||||
if isinstance(data, str):
|
|
||||||
if file_type == 'json':
|
|
||||||
with open(data, 'r') as fid:
|
|
||||||
data = json.load(fid)
|
|
||||||
elif file_type == 'hdf5':
|
|
||||||
data = h5py.File(data, 'r')
|
|
||||||
|
|
||||||
users = data['users']
|
|
||||||
features = data['user_data']
|
|
||||||
labels = data['user_data_label']
|
|
||||||
num_samples = data['num_samples']
|
|
||||||
|
|
||||||
return users, features, labels, num_samples
|
|
|
@ -37,11 +37,9 @@ server_config:
|
||||||
data_config: # where to get val and test data from
|
data_config: # where to get val and test data from
|
||||||
val:
|
val:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
|
||||||
val_data: test_data.hdf5
|
val_data: test_data.hdf5
|
||||||
test:
|
test:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
|
||||||
test_data: test_data.hdf5
|
test_data: test_data.hdf5
|
||||||
type: model_optimization
|
type: model_optimization
|
||||||
aggregate_median: softmax # how aggregations weights are computed
|
aggregate_median: softmax # how aggregations weights are computed
|
||||||
|
@ -59,7 +57,6 @@ client_config:
|
||||||
data_config: # where to get training data from
|
data_config: # where to get training data from
|
||||||
train:
|
train:
|
||||||
batch_size: 96
|
batch_size: 96
|
||||||
loader_type: text
|
|
||||||
list_of_train_data: train_data.hdf5
|
list_of_train_data: train_data.hdf5
|
||||||
desired_max_samples: 87000
|
desired_max_samples: 87000
|
||||||
optimizer_config: # this is the optimizer used by the client
|
optimizer_config: # this is the optimizer used by the client
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
from experiments.ecg_cnn.dataloaders.text_dataset import TextDataset
|
from experiments.ecg_cnn.dataloaders.dataset import Dataset
|
||||||
|
from core.dataloader import BaseDataLoader
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
class TextDataLoader(DataLoader):
|
class DataLoader(BaseDataLoader):
|
||||||
def __init__(self, mode, num_workers=0, **kwargs):
|
def __init__(self, mode, num_workers=0, **kwargs):
|
||||||
args = kwargs['args']
|
args = kwargs['args']
|
||||||
self.batch_size = args['batch_size']
|
self.batch_size = args['batch_size']
|
||||||
|
|
||||||
dataset = TextDataset(
|
dataset = Dataset(
|
||||||
data=kwargs['data'],
|
data=kwargs['data'],
|
||||||
test_only=(not mode=='train'),
|
test_only=(not mode=='train'),
|
||||||
user_idx=kwargs.get('user_idx', None),
|
user_idx=kwargs.get('user_idx', None),
|
||||||
|
@ -26,9 +26,6 @@ class TextDataLoader(DataLoader):
|
||||||
collate_fn=self.collate_fn,
|
collate_fn=self.collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_loader(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
x, y = list(zip(*batch))
|
x, y = list(zip(*batch))
|
||||||
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
|
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core.dataset import BaseDataset
|
||||||
|
|
||||||
|
class Dataset(BaseDataset):
|
||||||
|
def __init__(self, data, test_only=False, user_idx=0, **kwargs):
|
||||||
|
self.test_only = test_only
|
||||||
|
self.user_idx = user_idx
|
||||||
|
|
||||||
|
# Get all data
|
||||||
|
self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data)
|
||||||
|
|
||||||
|
if self.test_only: # combine all data into single array
|
||||||
|
self.user = 'test_only'
|
||||||
|
self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
|
||||||
|
self.labels = np.hstack([user_label['x'] for user_label in self.user_data_label.values()])
|
||||||
|
else: # get a single user's data
|
||||||
|
if user_idx is None:
|
||||||
|
raise ValueError('in train mode, user_idx must be specified')
|
||||||
|
|
||||||
|
self.user = self.user_list[user_idx]
|
||||||
|
self.features = self.user_data[self.user]['x']
|
||||||
|
self.labels = self.user_data_label[self.user]['x']
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
items = self.features[idx].astype(np.float32).T.reshape(1,187)
|
||||||
|
return items, self.labels[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.features)
|
||||||
|
|
||||||
|
def load_data(self,data):
|
||||||
|
'''Load data from disk or memory'''
|
||||||
|
|
||||||
|
if isinstance(data, str):
|
||||||
|
try:
|
||||||
|
data = h5py.File(data, 'r')
|
||||||
|
except:
|
||||||
|
raise ValueError('Only HDF5 format is allowed for this experiment')
|
||||||
|
|
||||||
|
users = []
|
||||||
|
num_samples = data['num_samples']
|
||||||
|
features, labels = dict(), dict()
|
||||||
|
|
||||||
|
# Decoding bytes from hdf5
|
||||||
|
decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x
|
||||||
|
for user in data['users']:
|
||||||
|
user = decode_if_str(user)
|
||||||
|
users.append(user)
|
||||||
|
features[user] = {'x': data['user_data'][user]['x'][()]}
|
||||||
|
labels[user] = {'x': data['user_data_label'][user][()]}
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
users = data['users']
|
||||||
|
features = data['user_data']
|
||||||
|
labels = data['user_data_label']
|
||||||
|
num_samples = data['num_samples']
|
||||||
|
|
||||||
|
return users, features, labels, num_samples
|
|
@ -1,56 +0,0 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
import h5py
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
|
||||||
def __init__(self, data, test_only=False, user_idx=None, file_type=None):
|
|
||||||
self.test_only = test_only
|
|
||||||
self.user_idx = user_idx
|
|
||||||
self.file_type = file_type
|
|
||||||
|
|
||||||
# Get all data
|
|
||||||
self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.file_type)
|
|
||||||
|
|
||||||
if self.test_only: # combine all data into single array
|
|
||||||
self.user = 'test_only'
|
|
||||||
self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
|
|
||||||
self.labels = np.hstack(list(self.user_data_label.values()))
|
|
||||||
else: # get a single user's data
|
|
||||||
if user_idx is None:
|
|
||||||
raise ValueError('in train mode, user_idx must be specified')
|
|
||||||
|
|
||||||
self.user = self.user_list[user_idx]
|
|
||||||
self.features = self.user_data[self.user]['x']
|
|
||||||
self.labels = self.user_data_label[self.user]
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
items = self.features[idx].astype(np.float32).T.reshape(1,187)
|
|
||||||
return items, self.labels[idx]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.features)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_data(data, file_type):
|
|
||||||
'''Load data from disk or memory.
|
|
||||||
|
|
||||||
The :code:`data` argument can be either the path to the JSON
|
|
||||||
or HDF5 file that contains the expected dictionary, or the
|
|
||||||
actual dictionary.'''
|
|
||||||
|
|
||||||
if isinstance(data, str):
|
|
||||||
try:
|
|
||||||
data = h5py.File(data, 'r')
|
|
||||||
except:
|
|
||||||
raise ValueError('Only HDF5 format is allowed for this experiment')
|
|
||||||
|
|
||||||
users = data['users']
|
|
||||||
features = data['user_data']
|
|
||||||
labels = data['user_data_label']
|
|
||||||
num_samples = data['num_samples']
|
|
||||||
|
|
||||||
return users, features, labels, num_samples
|
|
|
@ -32,16 +32,15 @@ The file `centralized_model.ipynb` can be used to test a centralized run of the
|
||||||
|
|
||||||
#### Preparing the data
|
#### Preparing the data
|
||||||
|
|
||||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. First, place the `mitbih_test.csv` and `mitbig_train.csv` files in the folder `.\ecg_cnn\data\mitbih\`. Next, run preprocess.py in the `utils` folder to generate the HDF5 files.
|
First, place the `mitbih_test.csv` and `mitbig_train.csv` files in the folder `.\ecg_cnn\data\mitbih\`. Next, run preprocess.py in the `utils` folder to generate the HDF5 files.
|
||||||
|
|
||||||
## Specifying dataset and dataloaders
|
## Specifying dataset and dataloaders
|
||||||
|
|
||||||
Inside the `dataloaders` folder, there are two files: `text_dataset.py` and
|
Inside the `dataloaders` folder, there are two files: `dataset.py` and
|
||||||
`text_dataloader.py` (the word "text" is used to mimic the other datasets, even
|
`dataloader.py`. Both inherit from the base classes declared in `core`
|
||||||
though in practice this loads images -- this will be changed in the future).
|
folder, that under the hood inhereit from Pytorch classes with same name.
|
||||||
Both inherit from the Pytorch classes with same name.
|
|
||||||
|
|
||||||
The dataset should be able to access all the data, which is stored in the
|
The dataset should be able to access all the data, and store it in the
|
||||||
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
|
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
|
||||||
names, user features, user labels if the problem is supervised, and number of
|
names, user features, user labels if the problem is supervised, and number of
|
||||||
samples for each user, respectively). These attributes are required to have
|
samples for each user, respectively). These attributes are required to have
|
||||||
|
|
|
@ -4,16 +4,12 @@ Instructions on how to run the experiment, given below.
|
||||||
|
|
||||||
## Preparing the data
|
## Preparing the data
|
||||||
|
|
||||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
For this experiment, we can create a dummy dataset by running the
|
||||||
should be made data-agnostic in the near future, but at this moment we need to do some
|
|
||||||
preprocessing before handling the data on the model. For this experiment, we can run the
|
|
||||||
script located in `testing/create_data.py` as follows:
|
script located in `testing/create_data.py` as follows:
|
||||||
|
|
||||||
```code
|
```code
|
||||||
python create_data.py -e mlm
|
python create_data.py -e mlm
|
||||||
```
|
```
|
||||||
to download mock data already preprocessed. A new folder `mockup` will be generated
|
|
||||||
inside `testing` with all data needed for a local run.
|
|
||||||
|
|
||||||
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
||||||
in case you want to use your own data.
|
in case you want to use your own data.
|
||||||
|
@ -23,9 +19,16 @@ in case you want to use your own data.
|
||||||
All the parameters of the experiment are passed in a YAML file. An example is
|
All the parameters of the experiment are passed in a YAML file. An example is
|
||||||
provided in `configs/hello_world_mlm_bert_json.yaml` with the suggested parameters
|
provided in `configs/hello_world_mlm_bert_json.yaml` with the suggested parameters
|
||||||
to do a simple run for this experiment. Make sure to point your training files at
|
to do a simple run for this experiment. Make sure to point your training files at
|
||||||
the fields: train_data, test_data and val_data inside the config file.
|
the fields: list_of_train_data, test_data and val_data inside the config file.
|
||||||
|
|
||||||
## Running the experiment
|
## Running the experiment locally
|
||||||
|
|
||||||
|
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||||
|
script using MPI:
|
||||||
|
|
||||||
|
```code
|
||||||
|
mpiexec -n 2 python .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert
|
||||||
|
```
|
||||||
|
|
||||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||||
section of the main `README.md`.
|
section of the main `README.md`.
|
|
@ -2,13 +2,13 @@
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
from transformers.data.data_collator import default_data_collator, DataCollatorWithPadding
|
from transformers.data.data_collator import default_data_collator, DataCollatorWithPadding
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import RandomSampler, SequentialSampler
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers import DataCollatorForLanguageModeling
|
from transformers import DataCollatorForLanguageModeling
|
||||||
from experiments.mlm_bert.dataloaders.text_dataset import TextDataset
|
from experiments.mlm_bert.dataloaders.dataset import Dataset
|
||||||
import torch
|
from core.dataloader import BaseDataLoader
|
||||||
|
|
||||||
class TextDataLoader(DataLoader):
|
class DataLoader(BaseDataLoader):
|
||||||
"""
|
"""
|
||||||
PyTorch dataloader for loading text data from
|
PyTorch dataloader for loading text data from
|
||||||
text_dataset.
|
text_dataset.
|
||||||
|
@ -40,7 +40,7 @@ class TextDataLoader(DataLoader):
|
||||||
|
|
||||||
print("Tokenizer is: ",tokenizer)
|
print("Tokenizer is: ",tokenizer)
|
||||||
|
|
||||||
dataset = TextDataset(
|
dataset = Dataset(
|
||||||
data,
|
data,
|
||||||
args= args,
|
args= args,
|
||||||
test_only = self.mode is not 'train',
|
test_only = self.mode is not 'train',
|
||||||
|
@ -63,7 +63,7 @@ class TextDataLoader(DataLoader):
|
||||||
|
|
||||||
if self.mode == 'train':
|
if self.mode == 'train':
|
||||||
train_sampler = RandomSampler(dataset)
|
train_sampler = RandomSampler(dataset)
|
||||||
super(TextDataLoader, self).__init__(
|
super(DataLoader, self).__init__(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
|
@ -75,7 +75,7 @@ class TextDataLoader(DataLoader):
|
||||||
|
|
||||||
elif self.mode == 'val' or self.mode == 'test':
|
elif self.mode == 'val' or self.mode == 'test':
|
||||||
eval_sampler = SequentialSampler(dataset)
|
eval_sampler = SequentialSampler(dataset)
|
||||||
super(TextDataLoader, self).__init__(
|
super(DataLoader, self).__init__(
|
||||||
dataset,
|
dataset,
|
||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
batch_size= self.batch_size,
|
batch_size= self.batch_size,
|
||||||
|
@ -88,9 +88,6 @@ class TextDataLoader(DataLoader):
|
||||||
else:
|
else:
|
||||||
raise Exception("Sorry, there is something wrong with the 'mode'-parameter ")
|
raise Exception("Sorry, there is something wrong with the 'mode'-parameter ")
|
||||||
|
|
||||||
def create_loader(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def get_user(self):
|
def get_user(self):
|
||||||
return self.utt_ids
|
return self.utt_ids
|
||||||
|
|
|
@ -1,33 +1,50 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
from core.dataset import BaseDataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
from utils import print_rank
|
from utils import print_rank
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class Dataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
Map a text source to the target text
|
Map a text source to the target text
|
||||||
"""
|
"""
|
||||||
def __init__(self, data, args, tokenizer, test_only=False, user_idx=None, max_samples_per_user=-1, min_words_per_utt=5):
|
|
||||||
|
def __init__(self, data, args, tokenizer=None, test_only=False, user_idx=0, max_samples_per_user=-1, min_words_per_utt=5, **kwargs):
|
||||||
|
|
||||||
self.utt_list = list()
|
self.utt_list = list()
|
||||||
self.test_only= test_only
|
self.test_only= test_only
|
||||||
self.padding = args.get('padding', True)
|
self.padding = args.get('padding', True)
|
||||||
self.max_seq_length= args['max_seq_length']
|
self.max_seq_length= args['max_seq_length']
|
||||||
self.max_samples_per_user = max_samples_per_user
|
self.max_samples_per_user = max_samples_per_user
|
||||||
self.min_num_words = min_words_per_utt
|
self.min_num_words = min_words_per_utt
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.process_line_by_line=args.get('process_line_by_line', False)
|
self.process_line_by_line=args.get('process_line_by_line', False)
|
||||||
self.user = None
|
self.user = None
|
||||||
|
|
||||||
|
if tokenizer != None:
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
else:
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
"cache_dir": args['cache_dir'],
|
||||||
|
"use_fast": args['tokenizer_type_fast'],
|
||||||
|
"use_auth_token": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'tokenizer_name' in args:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(args['tokenizer_name'], **tokenizer_kwargs)
|
||||||
|
elif 'model_name_or_path' in args:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(args['model_name_or_path'], **tokenizer_kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("You are instantiating a new tokenizer from scratch. This is not supported by this script.")
|
||||||
|
|
||||||
if self.max_seq_length is None:
|
if self.max_seq_length is None:
|
||||||
self.max_seq_length = self.tokenizer.model_max_length
|
self.max_seq_length = self.tokenizer.model_max_length
|
||||||
if self.max_seq_length > 512:
|
if self.max_seq_length > 512:
|
||||||
print_rank(
|
print_rank(
|
||||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
f"The tokenizer picked seems to have a very large `model_max_length` ({self.tokenizer.model_max_length}). "
|
||||||
"Picking 512 instead. You can change that default value by passing --max_seq_length xxx.", loglevel=logging.DEBUG
|
"Picking 512 instead. You can change that default value by passing --max_seq_length xxx.", loglevel=logging.DEBUG
|
||||||
)
|
)
|
||||||
self.max_seq_length = 512
|
self.max_seq_length = 512
|
||||||
|
@ -39,7 +56,7 @@ class TextDataset(Dataset):
|
||||||
)
|
)
|
||||||
self.max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)
|
self.max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)
|
||||||
|
|
||||||
self.read_data(data, user_idx)
|
self.load_data(data, user_idx)
|
||||||
|
|
||||||
if not self.process_line_by_line:
|
if not self.process_line_by_line:
|
||||||
self.post_process_list()
|
self.post_process_list()
|
||||||
|
@ -65,7 +82,7 @@ class TextDataset(Dataset):
|
||||||
return self.utt_list[idx]
|
return self.utt_list[idx]
|
||||||
|
|
||||||
|
|
||||||
def read_data(self, orig_strct, user_idx):
|
def load_data(self, orig_strct, user_idx):
|
||||||
""" Reads the data for a specific user (unless it's for val/testing) and returns a
|
""" Reads the data for a specific user (unless it's for val/testing) and returns a
|
||||||
list of embeddings and targets."""
|
list of embeddings and targets."""
|
||||||
|
|
||||||
|
@ -85,7 +102,6 @@ class TextDataset(Dataset):
|
||||||
self.user = self.user_list[user_idx]
|
self.user = self.user_list[user_idx]
|
||||||
self.process_x(self.user_data[self.user])
|
self.process_x(self.user_data[self.user])
|
||||||
|
|
||||||
|
|
||||||
def process_x(self, raw_x_batch):
|
def process_x(self, raw_x_batch):
|
||||||
|
|
||||||
if self.test_only:
|
if self.test_only:
|
||||||
|
@ -101,7 +117,6 @@ class TextDataset(Dataset):
|
||||||
|
|
||||||
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)
|
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def process_user(self, user, user_data):
|
def process_user(self, user, user_data):
|
||||||
counter=0
|
counter=0
|
||||||
for line in user_data:
|
for line in user_data:
|
|
@ -4,16 +4,12 @@ Instructions on how to run the experiment, given below.
|
||||||
|
|
||||||
## Preparing the data
|
## Preparing the data
|
||||||
|
|
||||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
For this experiment, we can create a dummy dataset by running the
|
||||||
should be made data-agnostic in the near future, but at this moment we need to do some
|
|
||||||
preprocessing before handling the data on the model. For this experiment, we can run the
|
|
||||||
script located in `testing/create_data.py` as follows:
|
script located in `testing/create_data.py` as follows:
|
||||||
|
|
||||||
```code
|
```code
|
||||||
python create_data.py -e nlg
|
python create_data.py -e nlg
|
||||||
```
|
```
|
||||||
to download mock data already preprocessed. A new folder `mockup` will be generated
|
|
||||||
inside `testing` with all data needed for a local run.
|
|
||||||
|
|
||||||
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
||||||
in case you want to use your own data.
|
in case you want to use your own data.
|
||||||
|
@ -34,7 +30,7 @@ Finally, to launch the experiment locally , it suffices to launch the `e2e_train
|
||||||
script using MPI, you can use as example the following line:
|
script using MPI, you can use as example the following line:
|
||||||
|
|
||||||
```code
|
```code
|
||||||
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_local.yaml -task nlg_gru
|
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru
|
||||||
```
|
```
|
||||||
|
|
||||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||||
|
|
|
@ -4,12 +4,12 @@
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import DataLoader
|
from core.dataloader import BaseDataLoader
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from experiments.nlg_gru.dataloaders.text_dataset import TextDataset
|
from experiments.nlg_gru.dataloaders.dataset import Dataset
|
||||||
from utils.data_utils import BatchSampler, DynamicBatchSampler
|
from utils.data_utils import BatchSampler, DynamicBatchSampler
|
||||||
|
|
||||||
class TextDataLoader(DataLoader):
|
class DataLoader(BaseDataLoader):
|
||||||
"""
|
"""
|
||||||
PyTorch dataloader for loading text data from
|
PyTorch dataloader for loading text data from
|
||||||
text_dataset.
|
text_dataset.
|
||||||
|
@ -20,7 +20,7 @@ class TextDataLoader(DataLoader):
|
||||||
self.batch_size = args['batch_size']
|
self.batch_size = args['batch_size']
|
||||||
batch_sampler = None
|
batch_sampler = None
|
||||||
|
|
||||||
dataset = TextDataset(
|
dataset = Dataset(
|
||||||
data = kwargs['data'],
|
data = kwargs['data'],
|
||||||
test_only = not mode=="train",
|
test_only = not mode=="train",
|
||||||
vocab_dict = args['vocab_dict'],
|
vocab_dict = args['vocab_dict'],
|
||||||
|
@ -61,11 +61,6 @@ class TextDataLoader(DataLoader):
|
||||||
collate_fn=self.collate_fn,
|
collate_fn=self.collate_fn,
|
||||||
pin_memory=args["pin_memory"])
|
pin_memory=args["pin_memory"])
|
||||||
|
|
||||||
|
|
||||||
def create_loader(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
def pad_and_concat_feats(labels):
|
def pad_and_concat_feats(labels):
|
||||||
batch_size = len(labels)
|
batch_size = len(labels)
|
|
@ -1,21 +1,20 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from utils import print_rank
|
|
||||||
from core.globals import file_type
|
|
||||||
from experiments.nlg_gru.utils.utility import *
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import h5py
|
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
from utils import print_rank
|
||||||
|
from core.dataset import BaseDataset
|
||||||
|
from experiments.nlg_gru.utils.utility import *
|
||||||
|
|
||||||
|
class Dataset(BaseDataset):
|
||||||
"""
|
"""
|
||||||
Map a text source to the target text
|
Map a text source to the target text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data, min_num_words=2, max_num_words=25, test_only=False, user_idx=None, vocab_dict=None, preencoded=False):
|
def __init__(self, data, min_num_words=2, max_num_words=25, test_only=False, user_idx=0, vocab_dict=None, preencoded=False, **kwargs):
|
||||||
|
|
||||||
self.utt_list = list()
|
self.utt_list = list()
|
||||||
self.test_only = test_only
|
self.test_only = test_only
|
||||||
|
@ -24,11 +23,11 @@ class TextDataset(Dataset):
|
||||||
self.preencoded = preencoded
|
self.preencoded = preencoded
|
||||||
|
|
||||||
# Load the vocab
|
# Load the vocab
|
||||||
self.vocab = load_vocab(vocab_dict)
|
self.vocab = load_vocab(kwargs['args']['vocab_dict']) if 'args' in kwargs else load_vocab(vocab_dict)
|
||||||
self.vocab_size = len(self.vocab)
|
self.vocab_size = len(self.vocab)
|
||||||
|
|
||||||
# reading the jsonl for a specific user_idx
|
# reading the jsonl for a specific user_idx
|
||||||
self.read_data(data, user_idx)
|
self.load_data(data, user_idx)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Return the length of the elements in the list."""
|
"""Return the length of the elements in the list."""
|
||||||
|
@ -47,47 +46,28 @@ class TextDataset(Dataset):
|
||||||
|
|
||||||
return batch, self.user
|
return batch, self.user
|
||||||
|
|
||||||
# Reads JSON or HDF5 files
|
def load_data(self, orig_strct, user_idx):
|
||||||
def read_data(self, orig_strct, user_idx):
|
|
||||||
|
|
||||||
if isinstance(orig_strct, str):
|
if isinstance(orig_strct, str):
|
||||||
if file_type == "json":
|
|
||||||
print('Loading json-file: ', orig_strct)
|
print('Loading json-file: ', orig_strct)
|
||||||
with open(orig_strct, 'r') as fid:
|
with open(orig_strct, 'r') as fid:
|
||||||
orig_strct = json.load(fid)
|
orig_strct = json.load(fid)
|
||||||
|
|
||||||
elif file_type == "hdf5":
|
|
||||||
print('Loading hdf5-file: ', orig_strct)
|
|
||||||
orig_strct = h5py.File(orig_strct, 'r')
|
|
||||||
|
|
||||||
self.user_list = orig_strct['users']
|
self.user_list = orig_strct['users']
|
||||||
self.num_samples = orig_strct['num_samples']
|
self.num_samples = orig_strct['num_samples']
|
||||||
self.user_data = orig_strct['user_data']
|
self.user_data = orig_strct['user_data']
|
||||||
|
self.user = 'test_only' if self.test_only else self.user_list[user_idx]
|
||||||
|
|
||||||
if self.test_only:
|
|
||||||
self.user = 'test_only'
|
|
||||||
self.process_x(self.user_data)
|
self.process_x(self.user_data)
|
||||||
else:
|
|
||||||
self.user = self.user_list[user_idx]
|
|
||||||
self.process_x(self.user_data[self.user])
|
|
||||||
|
|
||||||
|
def process_x(self, user_data):
|
||||||
def process_x(self, raw_x_batch):
|
|
||||||
print_rank('Processing data-structure: {} Utterances expected'.format(sum(self.num_samples)), loglevel=logging.DEBUG)
|
print_rank('Processing data-structure: {} Utterances expected'.format(sum(self.num_samples)), loglevel=logging.DEBUG)
|
||||||
if self.test_only:
|
|
||||||
for user in self.user_list:
|
for user in self.user_list:
|
||||||
for e in raw_x_batch[user]['x']:
|
for e in user_data[user]['x']:
|
||||||
utt={}
|
utt={}
|
||||||
utt['src_text'] = e if type(e) is list else e.split()
|
utt['src_text'] = e if type(e) is list else e.split()
|
||||||
utt['duration'] = len(e)
|
utt['duration'] = len(e)
|
||||||
utt["loss_weight"] = 1.0
|
|
||||||
self.utt_list.append(utt)
|
|
||||||
|
|
||||||
else:
|
|
||||||
for e in raw_x_batch['x']:
|
|
||||||
utt={}
|
|
||||||
utt['src_text'] = e if type(e) is list else e.split()
|
|
||||||
utt['duration'] = len(utt["src_text"])
|
|
||||||
if utt['duration']<= self.min_num_words:
|
if utt['duration']<= self.min_num_words:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -37,12 +37,10 @@ server_config:
|
||||||
data_config: # where to get val and test data from
|
data_config: # where to get val and test data from
|
||||||
val:
|
val:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
val_data: null
|
||||||
val_data: data/classif_cnn/test_data.hdf5
|
|
||||||
test:
|
test:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
test_data: null
|
||||||
test_data: data/classif_cnn/test_data.hdf5
|
|
||||||
type: model_optimization
|
type: model_optimization
|
||||||
aggregate_median: softmax # how aggregations weights are computed
|
aggregate_median: softmax # how aggregations weights are computed
|
||||||
initial_lr_client: 0.001 # learning rate used on client optimizer
|
initial_lr_client: 0.001 # learning rate used on client optimizer
|
||||||
|
@ -59,8 +57,7 @@ client_config:
|
||||||
data_config: # where to get training data from
|
data_config: # where to get training data from
|
||||||
train:
|
train:
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
loader_type: text
|
list_of_train_data: null
|
||||||
list_of_train_data: data/classif_cnn/train_data.hdf5
|
|
||||||
desired_max_samples: 50000
|
desired_max_samples: 50000
|
||||||
optimizer_config: # this is the optimizer used by the client
|
optimizer_config: # this is the optimizer used by the client
|
||||||
type: sgd
|
type: sgd
|
||||||
|
|
|
@ -37,11 +37,9 @@ server_config:
|
||||||
data_config: # where to get val and test data from
|
data_config: # where to get val and test data from
|
||||||
val:
|
val:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
|
||||||
val_data: data/ecg_cnn/test_data.hdf5
|
val_data: data/ecg_cnn/test_data.hdf5
|
||||||
test:
|
test:
|
||||||
batch_size: 10000
|
batch_size: 10000
|
||||||
loader_type: text
|
|
||||||
test_data: data/ecg_cnn/test_data.hdf5
|
test_data: data/ecg_cnn/test_data.hdf5
|
||||||
type: model_optimization
|
type: model_optimization
|
||||||
aggregate_median: softmax # how aggregations weights are computed
|
aggregate_median: softmax # how aggregations weights are computed
|
||||||
|
@ -59,7 +57,6 @@ client_config:
|
||||||
data_config: # where to get training data from
|
data_config: # where to get training data from
|
||||||
train:
|
train:
|
||||||
batch_size: 96
|
batch_size: 96
|
||||||
loader_type: text
|
|
||||||
list_of_train_data: data/ecg_cnn/train_data.hdf5
|
list_of_train_data: data/ecg_cnn/train_data.hdf5
|
||||||
desired_max_samples: 87000
|
desired_max_samples: 87000
|
||||||
optimizer_config: # this is the optimizer used by the client
|
optimizer_config: # this is the optimizer used by the client
|
||||||
|
|
|
@ -60,7 +60,6 @@ server_config:
|
||||||
num_clients_per_iteration: 2 # Number of clients sampled per round
|
num_clients_per_iteration: 2 # Number of clients sampled per round
|
||||||
data_config: # Server-side data configuration
|
data_config: # Server-side data configuration
|
||||||
val: # Validation data
|
val: # Validation data
|
||||||
loader_type: text
|
|
||||||
val_data: data/mlm_bert/val_data.txt
|
val_data: data/mlm_bert/val_data.txt
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
@ -82,7 +81,6 @@ server_config:
|
||||||
# train_data_server: null
|
# train_data_server: null
|
||||||
# desired_max_samples: null
|
# desired_max_samples: null
|
||||||
test: # Test data configuration
|
test: # Test data configuration
|
||||||
loader_type: text
|
|
||||||
test_data: data/mlm_bert/test_data.txt
|
test_data: data/mlm_bert/test_data.txt
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
@ -112,7 +110,6 @@ client_config:
|
||||||
do_profiling: false # Enables client-side training profiling
|
do_profiling: false # Enables client-side training profiling
|
||||||
data_config:
|
data_config:
|
||||||
train: # This is the main training data configuration
|
train: # This is the main training data configuration
|
||||||
loader_type: text
|
|
||||||
list_of_train_data: data/mlm_bert/train_data.txt
|
list_of_train_data: data/mlm_bert/train_data.txt
|
||||||
task: mlm
|
task: mlm
|
||||||
mlm_probability: 0.25
|
mlm_probability: 0.25
|
||||||
|
|
|
@ -42,7 +42,6 @@ server_config:
|
||||||
data_config: # Server-side data configuration
|
data_config: # Server-side data configuration
|
||||||
val: # Validation data
|
val: # Validation data
|
||||||
# batch_size: 2048
|
# batch_size: 2048
|
||||||
# loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
val_data: data/nlg_gru/val_data.json
|
val_data: data/nlg_gru/val_data.json
|
||||||
|
@ -55,7 +54,6 @@ server_config:
|
||||||
unsorted_batch: true
|
unsorted_batch: true
|
||||||
test: # Test data configuration
|
test: # Test data configuration
|
||||||
batch_size: 2048
|
batch_size: 2048
|
||||||
loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
train_data: null
|
train_data: null
|
||||||
|
@ -87,7 +85,6 @@ client_config:
|
||||||
data_config:
|
data_config:
|
||||||
train: # This is the main training data configuration
|
train: # This is the main training data configuration
|
||||||
batch_size: 64
|
batch_size: 64
|
||||||
loader_type: text
|
|
||||||
tokenizer_type: not_applicable
|
tokenizer_type: not_applicable
|
||||||
prepend_datapath: false
|
prepend_datapath: false
|
||||||
list_of_train_data: data/nlg_gru/train_data.json
|
list_of_train_data: data/nlg_gru/train_data.json
|
||||||
|
|
|
@ -61,11 +61,13 @@ def test_mlm_bert():
|
||||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||||
print("PASSED")
|
print("PASSED")
|
||||||
|
|
||||||
|
@pytest.mark.xfail
|
||||||
def test_classif_cnn():
|
def test_classif_cnn():
|
||||||
|
|
||||||
task = 'classif_cnn'
|
task = 'classif_cnn'
|
||||||
data_path, output_path, config_path = get_info(task)
|
data_path, output_path, config_path = get_info(task)
|
||||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||||
|
print("PASSED")
|
||||||
|
|
||||||
def test_ecg_cnn():
|
def test_ecg_cnn():
|
||||||
|
|
||||||
|
|
|
@ -14,34 +14,14 @@ def get_exp_dataloader(task):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dir = os.path.join('experiments',task,'dataloaders','text_dataloader.py')
|
dir = os.path.join('experiments',task,'dataloaders','dataloader.py')
|
||||||
loader = SourceFileLoader("TextDataLoader",dir).load_module()
|
loader = SourceFileLoader("DataLoader",dir).load_module()
|
||||||
loader = loader.TextDataLoader
|
loader = loader.DataLoader
|
||||||
except:
|
except:
|
||||||
print_rank("Dataloader not found, please make sure is located inside the experiment folder")
|
print_rank("Dataloader not found, please make sure is located inside the experiment folder")
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
|
||||||
def detect_loader_type(my_data, loader_type):
|
|
||||||
""" Detect the loader type declared in the configuration file
|
|
||||||
|
|
||||||
Inside this function should go the implementation of
|
|
||||||
specific detection for any kind of loader.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
my_data (str): path of file or chunk file set
|
|
||||||
loader_type (str): loader description in yaml file
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not loader_type == "auto_detect":
|
|
||||||
return loader_type
|
|
||||||
|
|
||||||
# Here should go the implementation for the rest of loaders
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown format: {}".format(loader_type))
|
|
||||||
|
|
||||||
|
|
||||||
def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=300, data_strct=None):
|
def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=300, data_strct=None):
|
||||||
""" Create a dataloader for training on either server or client side """
|
""" Create a dataloader for training on either server or client side """
|
||||||
|
|
||||||
|
@ -64,67 +44,43 @@ def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=3
|
||||||
else:
|
else:
|
||||||
my_data = data_config["list_of_train_data"]
|
my_data = data_config["list_of_train_data"]
|
||||||
|
|
||||||
# Find the loader_type
|
DataLoader = get_exp_dataloader(task)
|
||||||
loader_type = detect_loader_type(my_data, data_config["loader_type"])
|
train_dataloader = DataLoader(data = data_strct if data_strct is not None else my_data,
|
||||||
|
|
||||||
if loader_type == 'text':
|
|
||||||
TextDataLoader = get_exp_dataloader(task)
|
|
||||||
train_dataloader = TextDataLoader(
|
|
||||||
data = data_strct if data_strct is not None else my_data,
|
|
||||||
user_idx = clientx,
|
user_idx = clientx,
|
||||||
mode = mode,
|
mode = mode,
|
||||||
args=data_config
|
args=data_config
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise NotImplementedError("Not supported {}: detected_type={} loader_type={} audio_format={}".format(my_data, loader_type, data_config["loader_type"], data_config["audio_format"]))
|
|
||||||
return train_dataloader
|
return train_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def make_val_dataloader(data_config, data_path, task=None, data_strct=None):
|
def make_val_dataloader(data_config, data_path, task=None, data_strct=None, train_mode=False):
|
||||||
""" Return a data loader for a validation set """
|
""" Return a data loader for a validation set """
|
||||||
|
if train_mode:
|
||||||
if not "val_data" in data_config or data_config["val_data"] is None:
|
|
||||||
print_rank("Validation data list is not set", loglevel=logging.DEBUG)
|
|
||||||
return None
|
return None
|
||||||
|
DataLoader = get_exp_dataloader(task)
|
||||||
loader_type = detect_loader_type(data_config["val_data"], data_config["loader_type"])
|
val_file = os.path.join(data_path, data_config["val_data"]) if data_config["val_data"] != None and data_path != None else None
|
||||||
|
val_dataloader = DataLoader(data = data_strct if data_strct is not None else val_file,
|
||||||
if loader_type == 'text':
|
|
||||||
TextDataLoader = get_exp_dataloader(task)
|
|
||||||
|
|
||||||
val_dataloader = TextDataLoader(
|
|
||||||
data = data_strct if data_strct is not None else os.path.join(data_path, data_config["val_data"]),
|
|
||||||
user_idx = 0,
|
user_idx = 0,
|
||||||
mode = 'val',
|
mode = 'val',
|
||||||
args=data_config
|
args=data_config
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
|
|
||||||
return val_dataloader
|
return val_dataloader
|
||||||
|
|
||||||
|
|
||||||
def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
|
def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
|
||||||
""" Return a data loader for an evaluation set. """
|
""" Return a data loader for an evaluation set. """
|
||||||
|
|
||||||
if not "test_data" in data_config or data_config["test_data"] is None:
|
DataLoader = get_exp_dataloader(task)
|
||||||
print_rank("Test data list is not set")
|
test_file = os.path.join(data_path, data_config["test_data"]) if data_config["test_data"] != None and data_path != None else None
|
||||||
return None
|
test_dataloader = DataLoader(data = data_strct if data_strct is not None else test_file,
|
||||||
|
|
||||||
loader_type = detect_loader_type(data_config["test_data"], data_config["loader_type"])
|
|
||||||
|
|
||||||
if loader_type == 'text':
|
|
||||||
TextDataLoader = get_exp_dataloader(task)
|
|
||||||
|
|
||||||
test_dataloader = TextDataLoader(
|
|
||||||
data = data_strct if data_strct is not None else os.path.join(data_path, data_config["test_data"]),
|
|
||||||
user_idx = 0,
|
user_idx = 0,
|
||||||
mode = 'test',
|
mode = 'test',
|
||||||
args=data_config
|
args=data_config
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
|
|
||||||
return test_dataloader
|
return test_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче