diff --git a/.gitignore b/.gitignore index 7bbc71c..918a251 100644 --- a/.gitignore +++ b/.gitignore @@ -99,3 +99,9 @@ ENV/ # mypy .mypy_cache/ + +# Pycharm +.idea/ + +################# +job.json \ No newline at end of file diff --git a/Pytorch/Docker/Dockerfile b/Pytorch/Docker/Dockerfile new file mode 100644 index 0000000..ba0cd3a --- /dev/null +++ b/Pytorch/Docker/Dockerfile @@ -0,0 +1,3 @@ +FROM pytorch/pytorch:0.4_cuda9_cudnn7 + +RUN pip install --no-cache-dir h5py scipy jupyter ipykernel numpy toolz pandas scikit-learn pillow \ No newline at end of file diff --git a/Pytorch/Makefile b/Pytorch/Makefile new file mode 100644 index 0000000..7b1788a --- /dev/null +++ b/Pytorch/Makefile @@ -0,0 +1,11 @@ +DATA_DIR:=/mnt/imagenet +PWD:=$(shell pwd) +FAKE:='False' +FAKE_DATA_LENGTH:=1281167 +name_prefix:=iliauk +tag:=latest +image-open:=$(name_prefix)/pytorch_gloo:$(tag) +open-path:=$(PWD)/Docker +script:=\$$AZ_BATCHAI_INPUT_SCRIPTS/imagenet_pytorch_gloo.py +include ../include/build.mk + diff --git a/Pytorch/src/imagenet_pytorch_gloo.py b/Pytorch/src/imagenet_pytorch_gloo.py new file mode 100644 index 0000000..90a7ce5 --- /dev/null +++ b/Pytorch/src/imagenet_pytorch_gloo.py @@ -0,0 +1,283 @@ +import argparse +import logging +import os +from os import path +import numpy as np +import pandas as pd +import multiprocessing +from toolz import pipe +from timer import Timer +from PIL import Image + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torchvision.transforms as transforms +from torch.utils.data import DataLoader, Dataset +import torchvision.models as models +import torch.distributed as dist +import torch.utils.data.distributed + +print("PyTorch: ", torch.__version__) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +# Distributed training settings +parser = argparse.ArgumentParser(description='PyTorch ResNet Example') +parser.add_argument('--world-size', 
default=1, type=int, help='number of distributed processes') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend') +parser.add_argument('--rank', default=-1, type=int, help='rank of the worker') + +_WIDTH = 224 +_HEIGHT = 224 +_LR = 0.001 +_EPOCHS = 1 +_NUM_GPU = int(torch.cuda.device_count()) +_BATCHSIZE = 64*_NUM_GPU +_RGB_MEAN = [0.485, 0.456, 0.406] +_RGB_SD = [0.229, 0.224, 0.225] + + +args = parser.parse_args() + +def _str_to_bool(in_str): + if 't' in in_str.lower(): + return True + else: + return False + +_FAKE = _str_to_bool(os.getenv('FAKE', 'True')) +_DATA_LENGTH = int(os.getenv('FAKE_DATA_LENGTH', 1281167)) # How much fake data to simulate, default to size of imagenet dataset + +#_DISTRIBUTED = _str_to_bool(os.getenv('DISTRIBUTED', 'False')) +_DISTRIBUTED = True +_CPU_COUNT = 8 +logger.info("Distributed mode: ", _DISTRIBUTED) +logger.info("CPU Count: ", _CPU_COUNT) + + +def _append_path_to(data_path, data_series): + return data_series.apply(lambda x: path.join(data_path, x)) + + +def _load_training(data_dir): + train_df = pd.read_csv(path.join(data_dir, 'train.csv')) + return train_df.assign(filenames=_append_path_to(path.join(data_dir, 'train'), + train_df.filenames)) + + +def _load_validation(data_dir): + train_df = pd.read_csv(path.join(data_dir, 'validation.csv')) + return train_df.assign(filenames=_append_path_to(path.join(data_dir, 'validation'), + train_df.filenames)) + + +def _create_data_fn(train_path, test_path): + logger.info('Reading training data info') + train_df = _load_training(train_path) + logger.info('Reading validation data info') + validation_df = _load_validation(test_path) + # File-path + train_X = train_df['filenames'].values + validation_X = validation_df['filenames'].values + # One-hot encoded labels for torch + train_labels = train_df[['num_id']].values.ravel() + 
class FakeData(Dataset):
    """Synthetic dataset serving random images and labels.

    A pool of ``batch_size * num_batches`` random samples and labels is
    generated once. ``length`` virtual indices are then mapped at random
    onto that pool, so ``len(dataset)`` is ``length`` regardless of the
    pool size.
    """

    def __init__(self,
                 batch_size=32,
                 num_batches=20,
                 dim=(224, 224),
                 n_channels=3,
                 n_classes=10,
                 length=_DATA_LENGTH,
                 seed=42,
                 data_transform=None):
        """Generate the sample pool and the index translation table.

        Args:
            batch_size: With ``num_batches``, sets the pool size
                (``batch_size * num_batches`` samples).
            num_batches: See ``batch_size``.
            dim: Two-element size passed to ``_create_data`` as the last
                two array dimensions.
            n_channels: Number of channels per generated sample.
            n_classes: Labels are drawn from ``range(n_classes)``.
            length: Number of items the dataset reports via ``len()``.
            seed: Seed forwarded to ``_create_data``.
            data_transform: Optional callable applied to a sample before
                it is returned from ``__getitem__``.
        """
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.num_batches = num_batches
        # Forward the caller's seed. It used to be accepted but ignored,
        # so _create_data always fell back to its own default of 42.
        self._data = _create_data(batch_size,
                                  self.num_batches,
                                  self.dim,
                                  self.n_channels,
                                  seed=seed)
        self._labels = _create_labels(batch_size, self.num_batches, self.n_classes)
        # Maps each of the `length` virtual indices to a real pool index.
        self.translation_index = np.random.choice(len(self._labels), length)
        self._length = length

        self._data_transform = data_transform
        logger.info("Creating fake data {} labels and {} images".format(n_classes, len(self._data)))

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair mapped to virtual index ``idx``."""
        logger.debug('Retrieving samples')
        logger.debug(str(idx))
        tr_index_array = self.translation_index[idx]

        if self._data_transform is not None:
            data = self._data_transform(self._data[tr_index_array])
        else:
            data = self._data[tr_index_array]

        return data, self._labels[tr_index_array]

    def __len__(self):
        """Return the virtual dataset length given at construction."""
        return self._length
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and log its top-1 accuracy.

    Args:
        val_loader: Iterable yielding ``(input, target)`` batches.
        model: Network to evaluate; it is switched to eval mode here.
        criterion: Not used by this function; kept so the signature
            stays unchanged for callers.
    """
    logger.info("Validating ...")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, target in val_loader:
            # Move both tensors to the GPU, as train() in this file does;
            # previously only the target was moved.
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            # compute output
            output = model(images)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # An empty loader used to end in ZeroDivisionError below.
        logger.warning("Validation loader produced no samples; accuracy not computed")
        return
    logger.info('Top-1 Accuracy: %.2f %%' % (100 * correct / total))
dist.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model) + else: + model = torch.nn.DataParallel(model).cuda() + # Optimisers + criterion = nn.CrossEntropyLoss().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=_LR) + # Data-sets + if _FAKE: + logger.info("Setting up fake loaders") + train_dataset = FakeData(n_classes=1000, data_transform=torch.FloatTensor) + else: + normalize = transforms.Normalize(_RGB_MEAN, _RGB_SD) + train_X, train_y, valid_X, valid_y = _create_data_fn(os.getenv('AZ_BATCHAI_INPUT_TRAIN'), + os.getenv('AZ_BATCHAI_INPUT_TEST')) + train_dataset = ImageNet( + train_X, + train_y, + transforms.Compose([ + transforms.RandomResizedCrop(_WIDTH), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize])) + + + if _DISTRIBUTED: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + # Data-loaders + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=_BATCHSIZE, shuffle=(train_sampler is None), num_workers=_CPU_COUNT, sampler=train_sampler) + + #val_loader = torch.utils.data.DataLoader( + # ImageNet( + # valid_X, + # valid_y, + # transforms.Compose([ + # transforms.Resize(256), + # transforms.CenterCrop(_WIDTH), + # transforms.ToTensor(), + # normalize])), batch_size=_BATCHSIZE, shuffle=False, + # num_workers=_CPU_COUNT) + + # Main training-loop + for epoch in range(_EPOCHS): + if _DISTRIBUTED: + train_sampler.set_epoch(epoch) + # Train + with Timer(output=logger.info, prefix="Training") as t: + train(train_loader, model, criterion, optimizer, epoch) + _log_summary(len(train_dataset), t.elapsed) + + # Validate + #with Timer(output=logger.info, prefix="Testing"): + # validate(val_loader, model, criterion) + + print("Finished") + +if __name__ == '__main__': + print("Pytorch") + main() \ No newline at end 
of file diff --git a/experiments/experiments_config.mk b/experiments/experiments_config.mk index 30e25c3..976d75b 100644 --- a/experiments/experiments_config.mk +++ b/experiments/experiments_config.mk @@ -1,5 +1,5 @@ # Variables for Batch AI - change as necessary -ID:=disdl +ID:=iliadl2 LOCATION:=eastus GROUP_NAME:=batch${ID}rg STORAGE_ACCOUNT_NAME:=batch${ID}st @@ -8,8 +8,8 @@ SELECTED_SUBSCRIPTION:="Team Danielle Internal" WORKSPACE:=workspace VM_SIZE:=Standard_NC24rs_v3 -NUM_NODES:=8 -CLUSTER_NAME:=msv100 +NUM_NODES:=2 +CLUSTER_NAME:=ikv100 GPU_TYPE:=V100 diff --git a/experiments/generate_job_spec.py b/experiments/generate_job_spec.py index 4863ff4..e917f28 100644 --- a/experiments/generate_job_spec.py +++ b/experiments/generate_job_spec.py @@ -8,36 +8,36 @@ logger = logging.getLogger(__name__) # # Config for Intel cmd_for_intel = \ - """source /opt/intel/compilers_and_libraries_2017.4.196/linux/mpi/intel64/bin/mpivars.sh; - echo $AZ_BATCH_HOST_LIST; - mpirun -n {total_processes} -ppn {processes_per_node} {hosts} - -env I_MPI_FABRICS=dapl - -env I_MPI_DAPL_PROVIDER=ofa-v2-ib0 - -env I_MPI_DYNAMIC_CONNECTION=0 - -env I_MPI_DEBUG=6 - -env I_MPI_HYDRA_DEBUG=on - -env DISTRIBUTED=True - {fake} - {fake_length} + """source /opt/intel/compilers_and_libraries_2017.4.196/linux/mpi/intel64/bin/mpivars.sh; + echo $AZ_BATCH_HOST_LIST; + mpirun -n {total_processes} -ppn {processes_per_node} {hosts} + -env I_MPI_FABRICS=dapl + -env I_MPI_DAPL_PROVIDER=ofa-v2-ib0 + -env I_MPI_DYNAMIC_CONNECTION=0 + -env I_MPI_DEBUG=6 + -env I_MPI_HYDRA_DEBUG=on + -env DISTRIBUTED=True + {fake} + {fake_length} python -u {script}""".replace('\n', '') # Config for OpenMPI cmd_for_openmpi = \ - """echo $AZ_BATCH_HOST_LIST; - cat $AZ_BATCHAI_MPI_HOST_FILE; - mpirun -np {total_processes} {hosts} - -bind-to none -map-by slot - -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH - -mca btl_tcp_if_include eth0 - -x NCCL_SOCKET_IFNAME=eth0 - -mca btl ^openib - -x NCCL_IB_DISABLE=1 - -x DISTRIBUTED=True + """echo 
def generate_job_dict_gloo(image_name,
                           script,
                           node_count=2):
    """Build the Batch AI job specification for a PyTorch/Gloo job.

    Args:
        image_name: Container image the job runs in.
        script: Path of the Python training script to execute.
        node_count: Number of nodes requested; also passed to the script
            as ``--world-size``.

    Returns:
        A dict in the Batch AI ``job.json`` layout, ready to be serialised.
    """
    # --world-size was hard-coded to 2 while nodeCount followed node_count,
    # so any job with a different node count got a mismatched world size.
    # NOTE(review): this assumes one process per node; confirm if the
    # job's process count is ever configured separately.
    command_line_args = (
        "--world-size {} "
        "--dist-backend $AZ_BATCHAI_PYTORCH_BACKEND "
        "--dist-url $AZ_BATCHAI_PYTORCH_INIT_METHOD "
        "--rank $AZ_BATCHAI_TASK_INDEX"
    ).format(node_count)
    return {
        "$schema": "https://raw.githubusercontent.com/Azure/BatchAI/master/schemas/2018-05-01/job.json",
        "properties": {
            "pyTorchSettings": {
                "pythonScriptFilePath": script,
                "commandLineArgs": command_line_args,
                "communicationBackend": "gloo"
            },
            "nodeCount": node_count,
            "stdOutErrPathPrefix": "$AZ_BATCHAI_MOUNT_ROOT/extfs",
            "inputDirectories": [
                {
                    "id": "SCRIPTS",
                    "path": "$AZ_BATCHAI_MOUNT_ROOT/extfs/scripts"
                },
                {
                    "id": "TRAIN",
                    "path": "$AZ_BATCHAI_MOUNT_ROOT/nfs/imagenet",
                },
                {
                    "id": "TEST",
                    "path": "$AZ_BATCHAI_MOUNT_ROOT/nfs/imagenet",
                },
            ],
            "outputDirectories": [{
                "id": "MODEL",
                "pathPrefix": "$AZ_BATCHAI_MOUNT_ROOT/extfs",
                "pathSuffix": "Models"
            }],
            "containerSettings": {
                "imageSourceRegistry": {
                    "image": image_name
                }
            }
        }
    }
total_processes + # non-synthetic gloo to add command = _prepare_command(mpitype, total_processes, processes_per_node, diff --git a/experiments/synthetic/Makefile b/experiments/synthetic/Makefile index 56a3d68..2b436a0 100644 --- a/experiments/synthetic/Makefile +++ b/experiments/synthetic/Makefile @@ -45,6 +45,11 @@ define submit_pytorch_local $(call submit_job, $(2)) endef +define submit_pytorch_gloo + $(call generate_job_gloo,iliauk/pytorch_gloo,\$$AZ_BATCHAI_INPUT_SCRIPTS/imagenet_pytorch_gloo.py,$(1),$(2), --synthetic_length ${FAKE_DATA_LENGTH}) + $(call submit_job, $(3)) +endef + define submit_cntk $(call generate_job_openmpi,hoaphumanoid/cntk:distributed,\$$AZ_BATCHAI_INPUT_SCRIPTS/imagenet_cntk.py,$(1),$(2), --synthetic_length ${FAKE_DATA_LENGTH}) $(call submit_job, $(3)) @@ -75,6 +80,7 @@ create-cluster: upload-nodeprep-scripts submit-all: submit-keras-intel32 submit-keras-intel16 submit-keras-intel8 submit-keras-intel4 \ submit-tf-intel32 submit-tf-intel16 submit-tf-intel8 submit-tf-intel4 \ submit-pytorch32 submit-pytorch16 submit-pytorch8 submit-pytorch4 \ +submit-pytorch_gloo32 submit-pytorch_gloo16 submit-pytorch_gloo8 submit-pytorch_gloo4 \ submit-cntk32 submit-cntk16 submit-cntk8 submit-cntk4 \ submit-keras-local submit-tf-local submit-pytorch-local submit_cntk_local @@ -139,4 +145,17 @@ submit-cntk4: $(call submit_cntk,1,$(PROCESSES_PER_NODE),cntk-4) submit-cntk-local: - $(call submit_cntk_local,1,cntk-local) \ No newline at end of file + $(call submit_cntk_local,1,cntk-local) + + +submit-pytorch_gloo32: + $(call submit_pytorch_gloo,8,$(PROCESSES_PER_NODE),pytorch_gloo-32) + +submit-pytorch_gloo16: + $(call submit_pytorch_gloo,4,$(PROCESSES_PER_NODE),pytorch_gloo-16) + +submit-pytorch_gloo8: + $(call submit_pytorch_gloo,2,$(PROCESSES_PER_NODE),pytorch_gloo-8) + +submit-pytorch_gloo4: + $(call submit_pytorch_gloo,1,$(PROCESSES_PER_NODE),pytorch_gloo-4) diff --git a/include/control.mk b/include/control.mk index 4ce6f9a..823b35c 100644 --- 
a/include/control.mk +++ b/include/control.mk @@ -51,6 +51,16 @@ define generate_job_local endef +define generate_job_gloo + python ../generate_job_spec.py $(1) gloo \ + $(2) \ + --filename job.json \ + --node_count $(3) \ + --ppn $(4) \ + $(5) +endef + + define stream_stdout az batchai job file stream -w $(WORKSPACE) -e $(EXPERIMENT) \ --j $(1) --output-directory-id stdouterr -f stdout.txt @@ -113,6 +123,7 @@ upload-scripts: set-storage $(call upload_script, ../../HorovodPytorch/src/imagenet_pytorch_horovod.py) $(call upload_script, ../../CNTK/src/imagenet_cntk.py) $(call upload_script, ../../CNTK/src/resnet_models.py) + $(call upload_script, ../../Pytorch/src/imagenet_pytorch_gloo.py) $(call upload_script, ../../common/timer.py) upload-nodeprep-scripts: set-storage @@ -160,7 +171,7 @@ delete: delete-cluster az group delete --name ${GROUP_NAME} -y -setup: select-subscription create-resource-group create-workspace create-storage set-storage set-az-defaults create-fileshare create-cluster list-clusters +setup: select-subscription create-resource-group create-workspace create-storage set-storage set-az-defaults create-fileshare create-directory upload-scripts create-cluster list-clusters create-experiment @echo "Cluster created" # @@ -169,6 +180,7 @@ setup: select-subscription create-resource-group create-workspace create-storage submit-all: submit-keras-intel32 submit-keras-intel16 submit-keras-intel8 submit-keras-intel4 \ submit-tf-intel32 submit-tf-intel16 submit-tf-intel8 submit-tf-intel4 \ submit-pytorch32 submit-pytorch16 submit-pytorch8 submit-pytorch4 \ +submit-pytorch_gloo32 submit-pytorch_gloo16 submit-pytorch_gloo8 submit-pytorch_gloo4 \ submit-cntk32 submit-cntk16 submit-cntk8 submit-cntk4 \ submit-keras-local submit-tf-local submit-pytorch-local submit_cntk_local @@ -191,6 +203,11 @@ clean-jobs: $(call delete_job, pytorch-16) $(call delete_job, pytorch-32) + $(call delete_job, pytorch_gloo-4) + $(call delete_job, pytorch_gloo-8) + $(call delete_job, 
pytorch_gloo-16) + $(call delete_job, pytorch_gloo-32) + $(call delete_job, cntk-local) $(call delete_job, cntk-4) $(call delete_job, cntk-8) @@ -198,6 +215,7 @@ clean-jobs: $(call delete_job, cntk-32) ####### Gather Results ###### +# TODO for PyTorch_Gloo gather-results:results.json @echo "All results gathered" @@ -205,6 +223,9 @@ gather-results:results.json results.json: pytorch_1gpulocal_$(GPU_TYPE)_local.results pytorch_4gpuopen_$(GPU_TYPE)_open.results \ pytorch_8gpuopen_$(GPU_TYPE)_open.results pytorch_16gpuopen_$(GPU_TYPE)_open.results \ pytorch_32gpuopen_$(GPU_TYPE)_open.results \ + pytorch_gloo_1gpulocal_$(GPU_TYPE)_local.results pytorch_gloo_4gpuopen_$(GPU_TYPE)_open.results \ + pytorch_gloo_8gpuopen_$(GPU_TYPE)_open.results pytorch_gloo_16gpuopen_$(GPU_TYPE)_open.results \ + pytorch_gloo_32gpuopen_$(GPU_TYPE)_open.results \ tf_1gpulocal_$(GPU_TYPE)_local.results tf_4gpuintel_$(GPU_TYPE)_intel.results \ tf_8gpuintel_$(GPU_TYPE)_intel.results tf_16gpuintel_$(GPU_TYPE)_intel.results \ tf_32gpuintel_$(GPU_TYPE)_intel.results \ @@ -233,7 +254,20 @@ pytorch_32gpuopen_$(GPU_TYPE)_open.results: $(call stream_stdout, pytorch-32)>pytorch_32gpuopen_$(GPU_TYPE)_open.results +pytorch_gloo_1gpulocal_$(GPU_TYPE)_local.results: + $(call stream_stdout, pytorch_gloo-local)>pytorch_gloo_1gpulocal_$(GPU_TYPE)_local.results +pytorch_gloo_4gpuopen_$(GPU_TYPE)_open.results: + $(call stream_stdout, pytorch_gloo-4)>pytorch_gloo_4gpuopen_$(GPU_TYPE)_open.results + +pytorch_gloo_8gpuopen_$(GPU_TYPE)_open.results: + $(call stream_stdout, pytorch_gloo-8)>pytorch_gloo_8gpuopen_$(GPU_TYPE)_open.results + +pytorch_gloo_16gpuopen_$(GPU_TYPE)_open.results: + $(call stream_stdout, pytorch_gloo-16)>pytorch_gloo_16gpuopen_$(GPU_TYPE)_open.results + +pytorch_gloo_32gpuopen_$(GPU_TYPE)_open.results: + $(call stream_stdout, pytorch_gloo-32)>pytorch_gloo_32gpuopen_$(GPU_TYPE)_open.results tf_1gpulocal_$(GPU_TYPE)_local.results: $(call stream_stdout, 
tf-local)>tf_1gpulocal_$(GPU_TYPE)_local.results