Refactor: immutable core classes using dataclasses (#15)

- Rename Module to Function to be consistent with MLIR
- Add a name property to Function
- Unify type and shape inference into a single pass to be performed after the full Function is constructed
- At op creation time, look up the op signature only in terms of the number of inputs and outputs
- Use dataclasses to define the core IR classes, with the frozen option to ensure they are immutable (a minimal sketch of this pattern follows this list)
- Remove any leftover uses of deepcopy
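
For illustration, a minimal sketch of the frozen-dataclass pattern adopted here (field names approximate the new Value class; this is not the full definition):

    from dataclasses import dataclass
    from typing import Any

    @dataclass(frozen=True)
    class Value:
        name: str
        type: Any = None

    v = Value("x")
    # v.name = "y"   # would raise dataclasses.FrozenInstanceError
    d = {v: 0}       # frozen dataclasses are hashable, so Values can key dicts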

* Rename Module to Function

* Add device_var attribute to Pmap op

* Fix prettyprinting of Pmap

* Use frozen dataclasses for Types

* Convert Device to frozen dataclass

* Convert Op to a frozen dataclass

* Convert Function to frozen dataclass; add FunctionMaker

* Convert Value to frozen dataclass

* Make dataclass types hashable and use Value instead of value names (#17)

* Make dataclass types hashable and update SequentialExecutor tests

* Remove main from sequential executor tests

* Fix pipeline parallel scheduler

* Upgrade Python to 3.8, cache pip

* Op: minor simplification

* Update MLIR parser to use refactored IR

* Allow passing output types to Op()

* Fix typo, import order

* First implementation of type inference

* Fix a test

* Add some TODOs

* Replace strings with values in data parallel transform (and tests)

* Add remaining type inference functions

* Add type inference for pmap

* Change names to values for pipeline parallel transform

* Implement pmap's type prop fn, fix some others

* Fix Device constructor arg order

* Misc

* Fix tests

* Move shape inference to type inference

* Remove shape inference module

* Remove uses of deepcopy

* Rename Op.{in,out}_edges to Op.{in,out}puts

* Use ops instead of names for pipeline parallel utils

* Fix for error message when output types and outputs don't match

* Remove TODO for replacing value names with values

* Make type inference functions more robust

* Clean up op register

* Minor improvements to type prop fns

Co-authored-by: Keshav Santhanam <keshav2@stanford.edu>
Siddharth Krishna 2021-01-12 22:16:32 +00:00, committed by GitHub
Parent: 36498ed684
Commit: 3c9905b659
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
41 changed files: 1703 additions and 1774 deletions

.github/workflows/tests.yml

@ -12,12 +12,14 @@ on:
jobs:
build:
runs-on: ubuntu-latest
env:
MLIR_VERSION: 20210111.38
PY_VERSION: 3.8
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:


@ -40,40 +40,40 @@ python -m pytest
# Components
- Executors:
- SequentialExecutor: a reference implementation that runs a DistIR module
- SequentialExecutor: a reference implementation that runs a DistIR function
on a single device. Can be used to check correctness of transforms.
- DistributedSimulator: an executor that uses profile data or flop counts to
simulate the execution of a given DistIR module on a given hardware
simulate the execution of a given DistIR function on a given hardware
configuration (including communication bandwidths and processor speed).
Returns estimated execution time and live memory profile. This can be
split into three subcomponents:
- Shape Inference: a pass that uses the shapes of inputs to calculate
the shapes of all intermediate values.
- Cost Inference: a pass that uses either shape information to compute
(or profiles the module and measures) the runtime and temporary
memory requirement of each op in the module.
(or profiles the function and measures) the runtime and temporary
memory requirement of each op in the function.
This output can be cached.
- Simulator: takes a module and a mapping from op to time/memory
- Simulator: takes a function and a mapping from op to time/memory
consumption and does a simulation to obtain a concurrent trace
(from which total runtime and memory usage plots can be derived).
- Importers:
- ONNX Importer: convert a `.onnx` file to a DistIR module. Can be given an
- ONNX Importer: convert a `.onnx` file to a DistIR function. Can be given an
intermediate graph from ORT (for example, after AD).
- MLIR Importer: import a DistIR module written in MLIR text format to an
in-memory DistIR module object. TODO
- Exporter/Prettyprinter: converts a DistIR module to an MLIR text format string.
- MLIR Importer: import a DistIR function written in MLIR text format to an
in-memory DistIR function object. TODO
- Exporter/Prettyprinter: converts a DistIR function to an MLIR text format string.
- Transforms: a module containing DistIR->DistIR transforms.
Ideally, these modules should be composable and should run on submodules
Ideally, these transforms should be composable and should run on subfunctions
so that we can have nested parallelism (data parallel where a subset of the
layers are horizontal parallel, or pipeline parallel where each stage is
data parallel with a different degree).
- DataParallelTransform: converts a given DistIR module to a data-parallel
- DataParallelTransform: converts a given DistIR function to a data-parallel
version that runs on a given number of devices.
- HorizontalParallelTransform: converts a given DistIR module to a
- HorizontalParallelTransform: converts a given DistIR function to a
horizontal-parallel version (if possible) that runs on a given number of
devices.
- PipelineParallelTransform: converts a given DistIR module to a
- PipelineParallelTransform: converts a given DistIR function to a
pipeline-parallel version that runs on a given number of devices.
- Search: an algorithm to find the best distributed version of a given
sequential DistIR module. Initially, this can be something that searches
sequential DistIR function. Initially, this can be something that searches
the DHP space (i.e. find the optimal parameters to give the D/H/P transforms).
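
For illustration, a hedged sketch of how these components compose after the rename (the import path is an assumption based on the package layout; infer_types, DistributedSimulator, and CostModel appear elsewhere in this diff):

    from dist_ir.executor import DistributedSimulator, infer_types

    # fn: an untyped DistIR Function; typed_inputs: Values with Tensor types
    # (shapes/dtypes/devices) matching fn.inputs; cost_model: a CostModel instance.
    typed_fn = infer_types(fn, typed_inputs)
    simulator = DistributedSimulator(cost_model)
    state = simulator.simulate(typed_fn)  # per-device timestamps, memory, and trace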


@ -1,2 +1,3 @@
from .distributed_simulator import DistributedSimulator
from .sequential_executor import SequentialExecutor
from .type_inference import infer_types


@ -140,7 +140,7 @@ class CostModel:
return costs
def infer_costs(self, op):
inputs = op.get_in_edges()
outputs = op.get_out_edges()
inputs = op.inputs
outputs = op.outputs
return self._op_register[op.op_type](op, inputs, outputs)


@ -2,7 +2,7 @@ from copy import deepcopy
from collections import defaultdict
import json
from ..ir import Module
from ..ir import Function
from . import utils
SECONDS_TO_MICROSECONDS = 1e6
@ -43,11 +43,11 @@ class DistributedSimulator:
def __init__(self, cost_model):
self._cost_model = cost_model
def _simulate(self, module: Module, state: DistributedSimulatorState):
def _simulate(self, function: Function, state: DistributedSimulatorState):
for op_name, op in module.get_ops().items():
in_edges = op.get_in_edges()
out_edges = op.get_out_edges()
for op in function.ops:
in_edges = op.inputs
out_edges = op.outputs
# Synchronize all input and output devices for this op.
input_devices = utils.get_all_devices(in_edges)
@ -63,40 +63,40 @@ class DistributedSimulator:
# Compute the costs for the op.
if op.op_type == "Pmap":
# For Pmap ops we use a fresh state object and update the enclosing
# module state using the Pmap state.
submodule = op.get_submodule(0)
submodule_state = DistributedSimulatorState()
self._simulate(submodule, submodule_state)
device_vars = submodule_state.timestamps.keys()
# function state using the Pmap state.
subfunction = op.subfunctions[0]
subfunction_state = DistributedSimulatorState()
self._simulate(subfunction, subfunction_state)
device_vars = subfunction_state.timestamps.keys()
assert len(device_vars) == 1
# TODO what happens when pmaps are nested?
bound_devices = op.get_attribute("devices")
# Add submodule's trace to trace of all participating devices
bound_devices = op.attributes["devices"]
# Add subfunction's trace to trace of all participating devices
for device in bound_devices:
for event in submodule_state.trace:
for event in subfunction_state.trace:
# Need to add pmap's starting timestamp to event
# since submodule_state started at time 0
# since subfunction_state started at time 0
start_time = event["ts"] + state.timestamps[device]
state.add_trace_event(
event["name"], device, start_time, event["dur"]
)
for device_var in device_vars:
for bound_device in bound_devices:
state.timestamps[bound_device] += submodule_state.timestamps[
device_var
]
state.live_memory[bound_device] += submodule_state.live_memory[
device_var
]
state.peak_memory[bound_device] += submodule_state.peak_memory[
state.timestamps[bound_device] += subfunction_state.timestamps[
device_var
]
state.live_memory[
bound_device
] += subfunction_state.live_memory[device_var]
state.peak_memory[
bound_device
] += subfunction_state.peak_memory[device_var]
# TODO: Update consumers?
else:
costs = self._cost_model.infer_costs(op)
for device in costs:
state.add_trace_event(
op_name,
op.name,
device,
state.timestamps[device],
costs[device],
@ -105,9 +105,7 @@ class DistributedSimulator:
# Update the live memory.
for out_edge in out_edges:
state.consumers[out_edge] = len(
module.get_consumers_for_value(out_edge.name)
)
state.consumers[out_edge] = len(function.get_consumers(out_edge))
# Output value could live on multiple devices (e.g. scatter) so
# update memory on all devices:
output_devices = out_edge.type.get_all_devices()
@ -115,7 +113,9 @@ class DistributedSimulator:
state.live_memory[output_device] += out_edge.type.size()
# TODO: Can we optimize this using a priority queue?
for value in state.consumers:
if state.consumers[value] == 0 and not module.is_input(value.name):
if state.consumers[value] == 0 and all(
value != v for v in function.inputs
):
devices = value.type.get_all_devices()
for device in devices:
state.live_memory[device] -= value.type.size()
@ -126,7 +126,7 @@ class DistributedSimulator:
state.peak_memory[device], state.live_memory[device]
)
def simulate(self, module):
def simulate(self, function):
state = DistributedSimulatorState()
self._simulate(module, state)
self._simulate(function, state)
return state


@ -12,25 +12,33 @@ def allreduce(op, inputs):
def broadcast(op, inputs):
return [inputs[0] for _ in range(len(op.get_attribute("devices")))]
return [inputs[0] for _ in range(len(op.attributes["devices"]))]
def concat(op, inputs):
dim = op.get_attribute("dim")
# assert len(inputs) == 1
# dim = op.attributes["dim"]
# return np.concatenate(inputs[0], axis=dim)
dim = op.attributes["dim"]
return np.concatenate(inputs, axis=dim)
def gather(op, inputs):
dim = op.attributes["dim"]
return np.concatenate(inputs[0], axis=dim)
def identity(op, inputs):
return inputs[0]
def loss(op, inputs):
N = op.get_attribute("N")
N = op.attributes["N"]
return np.square(inputs[0] - inputs[1]) / N
def loss_grad(op, inputs):
N = op.get_attribute("N")
N = op.attributes["N"]
return 2 * (inputs[0] - inputs[1]) / N
@ -47,16 +55,16 @@ def relu(op, inputs):
def select(op, inputs):
dim = op.get_attribute("dim")
dim = op.attributes["dim"]
return inputs[0][dim]
def split(op, inputs):
dim = op.get_attribute("dim")
dim = op.attributes["dim"]
if op.op_type == "Split":
num_splits = op.get_attribute("num_splits")
num_splits = op.attributes["num_splits"]
elif op.op_type == "Scatter":
num_splits = len(op.get_attribute("devices"))
num_splits = len(op.attributes["devices"])
return np.split(inputs[0], num_splits, axis=dim)
@ -66,7 +74,7 @@ NumPyRegister = {
"Allreduce": allreduce,
"Broadcast": broadcast,
"Concat": concat,
"Gather": concat,
"Gather": gather,
"Loss": loss,
"LossGrad": loss_grad,
"MatMul": matmul,


@ -1,7 +1,7 @@
from typing import Any, Dict, List
from .backend_register import BackendRegister
from ..ir import Module, Op
from ..ir import Function, Op, Value
class SequentialExecutor:
@ -19,12 +19,11 @@ class SequentialExecutor:
# Iterate over the inputs
results = []
for inps in inputs:
# Execute submodule with appropriate inputs
inp_names = (e.name for e in op.get_submodule(0).get_inputs())
inp_data = {n: v for n, v in zip(inp_names, inps)}
outs = self.compute(op.get_submodule(0), inp_data)
# Match output names to output data using the module output order.
ordered_outs = [outs[e.name] for e in op.get_submodule(0).get_outputs()]
# Execute subfunction with appropriate inputs
inp_data = {k: v for k, v in zip(op.subfunctions[0].inputs, inps)}
outs = self.compute(op.subfunctions[0], inp_data)
# Match output names to output data using the function output order.
ordered_outs = [outs[e] for e in op.subfunctions[0].outputs]
results.append(ordered_outs)
# Unzip the results
results = tuple(zip(*results))
@ -39,56 +38,51 @@ class SequentialExecutor:
output_data = (output_data,)
return output_data
def compute(self, module: Module, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Executes the module given the specified inputs and returns the final result.
def compute(
self, function: Function, input_data: Dict[Value, Any]
) -> Dict[Value, Any]:
"""Executes the function given the specified inputs and returns the final result.
Args:
module: The module to execute.
input_data: A map from input tensor name to data represented in the
function: The function to execute.
input_data: A map from input value to data represented in the
specified backend.
Returns:
A map from output tensor name to output tensor.
A map from output value to output data.
"""
output_data = {}
consumers = {}
ops = module.get_ops()
# Execute ops in topological order.
for op_name, op in ops.items():
for op in function.ops:
inputs = []
in_edges = op.get_in_edges()
for in_edge in in_edges:
input_name = in_edge.name
if module.is_input(input_name):
input_name = in_edge.name
if input_name not in input_data:
for in_edge in op.inputs:
if in_edge in function.inputs:
if in_edge not in input_data:
raise ValueError(
f"Could not find input {input_name} in input_data"
f"Could not find input {in_edge} in input_data"
)
input_value = input_data[input_name]
elif input_name in output_data:
input_value = output_data[input_name]
consumers[input_name] -= 1
input_value = input_data[in_edge]
elif in_edge in output_data:
input_value = output_data[in_edge]
consumers[in_edge] -= 1
else:
raise ValueError(f"Invalid input {input_name} for op {op_name}")
raise ValueError(f"Invalid input {in_edge} for op {op}")
inputs.append(input_value)
res = self._compute_op(op, inputs)
out_edges = op.get_out_edges()
for i, out_edge in enumerate(out_edges):
output_data[out_edge.name] = res[i]
consumers[out_edge.name] = len(
module.get_consumers_for_value(out_edge.name)
)
for i, out_edge in enumerate(op.outputs):
output_data[out_edge] = res[i]
consumers[out_edge] = len(function.get_consumers(out_edge))
# Garbage collect the fully consumed output tensors.
to_free = []
for output_name in output_data:
if consumers[output_name] == 0 and not module.is_output(output_name):
to_free.append(output_name)
for output_name in to_free:
del output_data[output_name]
for out_edge in output_data:
if consumers[out_edge] == 0 and not out_edge in function.outputs:
to_free.append(out_edge)
for out_edge in to_free:
del output_data[out_edge]
# Return the outputs.
return output_data
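
For illustration, a hedged usage sketch of the new Value-keyed interface (the constructor argument selecting the NumPy backend and the input shapes are assumptions; fn is a typed Function built elsewhere):

    import numpy as np
    from dist_ir.executor import SequentialExecutor

    ex = SequentialExecutor("numpy")
    # input_data is keyed by Value objects rather than value names:
    input_data = {v: np.ones((4, 4), dtype=np.float32) for v in fn.inputs}
    outputs = ex.compute(fn, input_data)  # Dict[Value, Any]
    for out in fn.outputs:
        print(out.name, outputs[out].shape)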


@ -1,202 +0,0 @@
from ..ir.type import Float
from ..ir.type import Tensor, TupleType
from ..ir.value import Value
from ..ir.device import Device
import copy
def _get_shapes(values):
shapes = []
for value in values:
if isinstance(value.type, Tensor):
shapes.append(value.type.shape)
else:
shapes.append(None)
return shapes
def _error_invalid_shapes(op, input_shapes):
raise ValueError(
f"Op {op.name} ({op.op_type}): Incompatible input shapes {input_shapes}"
)
def _infer_shapes_for_add(op, inputs, outputs):
# TODO: Handle input tensors with > 2 dimensions
input_shapes = _get_shapes(inputs)
if input_shapes[0] != input_shapes[1]:
_error_invalid_shapes(op, input_shapes)
output_shape = input_shapes[0]
output_type = Tensor(
dtype=inputs[0].type.dtype, shape=output_shape, device=inputs[0].type.device
)
outputs[0].type = output_type
def _infer_shapes_for_allreduce(op, inputs, outputs):
outputs[0].type = copy.deepcopy(inputs[0].type)
def _infer_shapes_for_broadcast(op, inputs, outputs):
input_type = inputs[0].type
devices = op.get_attribute("devices")
output_types = []
for (output_type, device) in zip(outputs[0].type.types, devices):
if isinstance(output_type, Tensor) and isinstance(input_type, Tensor):
output_type.shape = input_type.shape
output_type.set_device(device)
def _infer_shapes_for_concat(op, inputs, outputs):
input_shapes = _get_shapes(inputs)
dim = op.get_attribute("dim")
for i, (dim0, dim1) in enumerate(zip(input_shapes[0], input_shapes[1])):
if i != dim and dim0 != dim1:
_error_invalid_shapes(op, input_shapes)
output_shape = list(input_shapes[0])
output_shape[dim] += input_shapes[1][dim]
outputs[0].type.shape = output_shape
def _infer_shapes_for_gather(op, inputs, outputs):
dim = op.get_attribute("dim")
device = op.get_attribute("device")
output_shape = list(inputs[0].type.types[0].shape)
for typ in inputs[0].type.types[1:]:
output_shape[dim] += typ.shape[dim]
outputs[0].type.dtype = inputs[0].type.types[0].dtype
outputs[0].type.shape = output_shape
outputs[0].type.set_device(device)
def _infer_shapes_for_matmul(op, inputs, outputs):
# TODO: Handle input tensors with > 2 dimensions
input_shapes = _get_shapes(inputs)
if input_shapes[0][1] != input_shapes[1][0]:
_error_invalid_shapes(op, input_shapes)
output_shape = (input_shapes[0][0], input_shapes[1][1])
outputs[0].type = Tensor(
dtype=inputs[0].type.dtype, shape=output_shape, device=inputs[0].type.device
)
def _infer_shapes_for_matmul_grad(op, inputs, outputs):
for i, output in enumerate(outputs):
output.type = copy.deepcopy(inputs[i].type)
def _infer_shapes_for_loss(op, inputs, outputs):
input_shapes = _get_shapes(inputs)
if input_shapes[0] != input_shapes[1]:
_error_invalid_shapes(op, input_shapes)
outputs[0].type = copy.deepcopy(inputs[0].type)
def _infer_shapes_for_loss_grad(op, inputs, outputs):
input_shapes = _get_shapes(inputs)
if input_shapes[0] != input_shapes[1]:
_error_invalid_shapes(op, input_shapes)
outputs[0].type = copy.deepcopy(inputs[0].type)
def _infer_shapes_for_pmap(op, inputs, outputs):
submodule = op.get_submodule(0)
for (pmap_input, submodule_input) in zip(inputs, submodule.get_inputs()):
assert isinstance(pmap_input.type, TupleType)
# TODO check that all elements of the tuple have the same type and, if
# they are tensors, that they have the same shape
if isinstance(submodule_input.type, Tensor):
submodule_input.type.shape = pmap_input.type.types[0].shape
_infer_shapes(submodule)
for (pmap_output, submodule_output) in zip(outputs, submodule.get_outputs()):
if isinstance(submodule_output.type, Tensor):
assert isinstance(pmap_output.type, TupleType)
for pmap_output_type in pmap_output.type.types:
pmap_output_type.shape = submodule_output.type.shape
pmap_output_type.dtype = submodule_output.type.dtype
def _infer_shapes_for_scatter(op, inputs, outputs):
input_type = inputs[0].type
split_dim = op.get_attribute("dim")
devices = op.get_attribute("devices")
for (output_type, device) in zip(outputs[0].type.types, devices):
if isinstance(output_type, Tensor) and isinstance(input_type, Tensor):
output_shape = list(input_type.shape)
output_shape[split_dim] //= len(devices)
output_type.shape = tuple(output_shape)
output_type.set_device(device)
def _infer_shapes_for_select(op, inputs, outputs):
dim = op.get_attribute("dim")
outputs[0].type.shape = inputs[0].type.types[dim].shape
outputs[0].type.dtype = inputs[0].type.types[dim].dtype
def _infer_shapes_for_send(op, inputs, outputs):
outputs[0].type.shape = inputs[0].type.shape
outputs[0].type.dtype = inputs[0].type.dtype
def _infer_shapes_for_split(op, inputs, outputs):
num_splits = op.get_attribute("num_splits")
split_dim = op.get_attribute("dim")
output_shape = list(inputs[0].type.shape)
output_shape[split_dim] //= num_splits
for typ in outputs[0].type.types:
typ.shape = tuple(output_shape)
typ.dtype = inputs[0].type.dtype
ShapeInferenceRegister = {
"Add": _infer_shapes_for_add,
"Allreduce": _infer_shapes_for_allreduce,
"Broadcast": _infer_shapes_for_broadcast,
"Concat": _infer_shapes_for_concat,
"Gather": _infer_shapes_for_gather,
"Loss": _infer_shapes_for_loss,
"LossGrad": _infer_shapes_for_loss_grad,
"MatMul": _infer_shapes_for_matmul,
"MatMulGrad": _infer_shapes_for_matmul_grad,
"Pmap": _infer_shapes_for_pmap,
"Scatter": _infer_shapes_for_scatter,
"Select": _infer_shapes_for_select,
"Send": _infer_shapes_for_send,
"Split": _infer_shapes_for_split,
}
def _infer_shapes(module):
"""Helper function for inferring shapes.
Inputs:
module: The module to infer shapes for.
"""
for op_name, op in module.get_ops().items():
inputs = op.get_in_edges()
outputs = op.get_out_edges()
# Invariant: types and shapes of input are already inferred
for input in inputs:
assert input.type is not None
if isinstance(input.type, Tensor):
if input.type.shape is None:
raise ValueError(f"Input {input.name} of op {op_name} has no shape")
ShapeInferenceRegister[op.op_type](op, inputs, outputs)
# TODO maybe the register gives back the output types and we can check
# here if they match existing types (if any) and if not, replace them
def infer_shapes(module):
"""Infers shapes for the given module."""
_infer_shapes(module)


@ -0,0 +1,310 @@
"""
A type inference module that converts an untyped DistIR Function into one where
every Value is typed with shape and dtype information, given input types or
example inputs.
Type inference requires a register mapping ops to type propagation functions:
- This is a function foo(op, x1, x2, .., xN), where op is an N-ary Op, and x1 to
xN are Types of the inputs.
- The function should check that the inputs have the expected types.
- The function should return the type of the output/a tuple of types of the
outputs.
(When we say types we also mean shape and device information.)
"""
from typing import Dict, List, Tuple
from ..ir import Device, Function, FunctionMaker, Op, Value
from ..ir.type import Type, Tensor, TupleType
def _raise_type_error(op, *args):
raise ValueError(f"Type error: op\n{op}\nwas given arguments\n{tuple(args)}")
def _allreduce_prop_fn(op, x):
devices = tuple(t.device for t in x.types)
if not (
isinstance(x, TupleType)
and all(isinstance(t, Tensor) for t in x.types)
and len(x.types) > 0
and all(t.shape == x.types[0].shape for t in x.types)
and len(set(devices)) == len(devices)
):
_raise_type_error(op, x)
return x
# TODO update the below prop functions to be as robust as _allreduce_prop_fn
def _broadcast_prop_fn(op, x):
if not isinstance(x, Tensor):
_raise_type_error(op, x)
devices = op.attributes["devices"]
return TupleType(
tuple(Tensor(dtype=x.dtype, shape=x.shape, device=device) for device in devices)
)
def _concat_prop_fn(op, x, y):
if not (
isinstance(x, Tensor)
and isinstance(y, Tensor)
and x.dtype == y.dtype
and x.device == y.device
):
_raise_type_error(op, x, y)
dim = op.attributes["dim"]
for i, (d0, d1) in enumerate(zip(x.shape, y.shape)):
if not i != dim and d0 != d1:
_raise_type_error(op, x, y)
output_shape = tuple(
n + (y.shape[i] if i == dim else 0) for i, n in enumerate(x.shape)
)
return Tensor(dtype=x.dtype, shape=output_shape, device=x.device)
def _elementwise_tensor_op_prop_fn(op, x, y):
if not (
isinstance(x, Tensor)
and isinstance(y, Tensor)
and x.dtype == y.dtype
and x.shape == y.shape
and x.device == y.device
):
_raise_type_error(op, x, y)
return x
def _gather_prop_fn(op, x):
if not (
isinstance(x, TupleType)
and all(isinstance(t, Tensor) for t in x.types)
and len(set(t.shape for t in x.types)) == 1
and len(set(t.dtype for t in x.types)) == 1
and len(x.types) > 0
):
_raise_type_error(op, x)
dim = op.attributes["dim"]
device = op.attributes["device"]
output_shape = list(x.types[0].shape)
for i in range(1, len(x.types)):
for j in range(len(x.types[i].shape)):
if j == dim:
output_shape[j] += x.types[i].shape[j]
elif x.types[i].shape[j] != x.types[0].shape[j]:
_raise_type_error(op, x)
output_shape = tuple(output_shape)
return Tensor(dtype=x.types[0].dtype, shape=output_shape, device=device)
def _matmul_prop_fn(op, x, y):
if not (
isinstance(x, Tensor)
and isinstance(y, Tensor)
and x.dtype == y.dtype
and x.device == y.device
and x.shape[1] == y.shape[0]
):
_raise_type_error(op, x, y)
return Tensor(dtype=x.dtype, shape=(x.shape[0], y.shape[1]), device=x.device)
def _matmul_grad_prop_fn(op, x, y, z):
# TODO: Check that shapes can be multipled together?
if not (
isinstance(x, Tensor)
and isinstance(y, Tensor)
and isinstance(z, Tensor)
and x.dtype == y.dtype
and x.dtype == z.dtype
and x.device == y.device
and x.device == z.device
):
_raise_type_error(op, x, y, z)
return (x, y)
def _scatter_prop_fn(op, x):
if not isinstance(x, Tensor):
_raise_type_error(op, x)
devices = op.attributes["devices"]
# Check devices is a list of distinct Devices
assert isinstance(devices, list) and all(isinstance(d, Device) for d in devices)
assert len(devices) == len(set(devices))
dim = op.attributes["dim"]
# TODO: Should we add another function to raise an attribute error?
assert x.shape[dim] % len(devices) == 0
output_shape = list(x.shape)
output_shape[dim] //= len(devices)
output_shape = tuple(output_shape)
return TupleType(
tuple(
Tensor(dtype=x.dtype, shape=output_shape, device=device)
for device in devices
)
)
def _select_prop_fn(op, x):
if not (
isinstance(x, TupleType)
and all(isinstance(t, Tensor) for t in x.types)
and len(x.types) > 0
and all(t.shape == x.types[0].shape for t in x.types)
and len(set(t.device for t in x.types)) == 1
):
_raise_type_error(op, x)
dim = op.attributes["dim"]
return x.types[dim]
def _send_prop_fn(op, x):
if not isinstance(x, Tensor):
_raise_type_error(op, x)
device = op.attributes["device"]
return Tensor(dtype=x.dtype, shape=x.shape, device=device)
def _split_prop_fn(op, x):
if not isinstance(x, Tensor):
_raise_type_error(op, x)
num_splits = op.attributes["num_splits"]
split_dim = op.attributes["dim"]
output_shape = list(x.shape)
# TODO: Move this check to attribute error function?
assert output_shape[split_dim] % num_splits == 0
output_shape[split_dim] //= num_splits
output_shape = tuple(output_shape)
return TupleType(
tuple(
Tensor(dtype=x.dtype, shape=output_shape, device=x.device)
for i in range(num_splits)
)
)
TypePropRegister = {
"Add": _elementwise_tensor_op_prop_fn,
# "Allgather": TODO,
"Allreduce": _allreduce_prop_fn,
"Broadcast": _broadcast_prop_fn,
"Concat": _concat_prop_fn,
"Gather": _gather_prop_fn,
"Loss": _elementwise_tensor_op_prop_fn,
"LossGrad": _elementwise_tensor_op_prop_fn,
"MatMul": _matmul_prop_fn,
"MatMulGrad": _matmul_grad_prop_fn,
"Scatter": _scatter_prop_fn,
"Select": _select_prop_fn,
"Send": _send_prop_fn,
"Split": _split_prop_fn,
}
# Handling pmap specially for now since it needs to return a typed subfunction
def _pmap_prop_fn(op: Op, input_types: Tuple[Type]):
"""pmap maps over a tuple of values, all of the same type and shape, but on
distinct devices. For convenience, to avoid a lot of zipping, we allow
multiple inputs, as long as they are all tuples of the same length and the
list of devices in each tuple are exactly the same.
"""
# Pmap expects 1 or more tuples as input
assert isinstance(input_types, tuple) and len(input_types) > 0
assert all(isinstance(t, TupleType) for t in input_types)
# Check that pmap's arguments all have same length and shapes
assert len(set(len(t.types) for t in input_types)) == 1
for t in input_types:
assert all(isinstance(x, Tensor) for x in t.types)
assert len(set(x.shape for x in t.types)) == 1
assert len(set(x.dtype for x in t.types)) == 1
# Check that pmap's arguments are on distinct devices
devices = tuple(x.device for x in input_types[0].types)
assert len(set(devices)) == len(devices)
# Check that all inputs have same list of devices
for t in input_types:
assert devices == tuple(x.device for x in t.types)
# Subfunction's inputs are given by pmap's arguments, but on device d
subfn_inputs = [
Value(v.name, t.types[0])
for v, t in zip(op.subfunctions[0].inputs, input_types)
]
# Recursively call infer_types on subfunction
assert len(op.subfunctions) == 1
subfunctions = [infer_types(op.subfunctions[0], subfn_inputs)]
# Pmap's output types are given by subfunction's output types
out_types = tuple(
TupleType(
tuple(
Tensor(shape=t.type.shape, dtype=t.type.dtype, device=d)
for d in devices
)
)
for t in subfunctions[0].outputs
)
return out_types, subfunctions
def infer_types(function: Function, inputs: List[Value]) -> Function:
"""Given a function and a list of input values, returns a new function where
all values are typed.
inputs: a list/tuple of Values, of the same length as function.inputs, but
the names are irrelevant.
"""
new_function = FunctionMaker()
# A Map from function's values to new_function's (typed) values:
value_map: Dict[Value, Value] = {}
def assert_is_typed(v: Value):
assert v.type is not None
if isinstance(v.type, Tensor):
if v.type.shape is None:
raise ValueError(f"Expected Value {v} to have a shape")
# Add inputs to new_function
assert len(inputs) == len(function.inputs)
for old_inp, inp in zip(function.inputs, inputs):
assert_is_typed(inp)
new_inp = new_function.add_input_value(old_inp.name, inp.type)
value_map[old_inp] = new_inp
op: Op # https://stackoverflow.com/q/59102038
for op in function.ops:
# Invariant: inputs of op are already typed (as ops are toposorted)
typed_inputs = tuple(value_map[inp] for inp in op.inputs)
input_types = tuple(v.type for v in typed_inputs)
# Infer types of outputs and create output values
if op.op_type == "Pmap":
out_types, subfunctions = _pmap_prop_fn(op, input_types)
else:
out_types = TypePropRegister[op.op_type](op, *input_types)
if not isinstance(out_types, tuple):
assert isinstance(out_types, Type)
out_types = (out_types,)
subfunctions = []
new_op = Op(
op.op_type,
op.name,
typed_inputs,
op.attributes,
subfunctions,
tuple(v.name for v in op.outputs),
out_types,
)
new_function.ops.append(new_op)
# Add op's outputs to value_map
for old_out, out in zip(op.outputs, new_op.outputs):
assert_is_typed(out)
value_map[old_out] = out
return new_function.finalize()
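
For illustration, a hedged sketch of registering an additional type propagation function; "Relu" is used here only as an example op type and is not part of this diff:

    def _relu_prop_fn(op, x):
        # Elementwise unary op: the output type matches the input exactly.
        if not isinstance(x, Tensor):
            _raise_type_error(op, x)
        return Tensor(dtype=x.dtype, shape=x.shape, device=x.device)

    TypePropRegister["Relu"] = _relu_prop_fn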


@ -1,10 +1,10 @@
from dataclasses import dataclass, field
import json
import re
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union
import mlir
from ..ir import cpprint, Module, Value
from ..ir import Function, FunctionMaker, Value
from ..ir.device import Device
from ..ir.type import Float, Tensor
@ -69,8 +69,8 @@ def _parse_type(mlir_type, context: Context):
raise ValueError(f"Unknown MLIR type {mlir_type}")
def _parse_module(mlir_region, context=None):
"""Creates a DistIR Module out of an MLIR region. The region must be either
def _parse_function(mlir_region, context: Context = None) -> Function:
"""Creates a DistIR Function out of an MLIR region. The region must be either
the single region in a function, or the sub-region of a pmap operator.
"""
if context is None:
@ -79,11 +79,11 @@ def _parse_module(mlir_region, context=None):
assert len(mlir_region.blocks) == 1
entry_block = mlir_region.blocks[0]
module = Module()
function = FunctionMaker()
# Find the inputs
for arg in entry_block.arguments:
v = module.add_input_value(
v = function.add_input_value(
f"%arg{arg.arg_number}", _parse_type(arg.type, context)
)
assert str(arg) not in context.values
@ -121,11 +121,11 @@ def _parse_module(mlir_region, context=None):
# Create output names (TODO should be done by Op.__init__ or add_op)
output_names = [_make_fresh_var(context) for _ in op.results]
submodules = []
subfunctions = []
if op_name == "std.return" or op_name == "dist.return":
# Set return values as module outputs
module.set_outputs(args)
# Set return values as function outputs
function.set_outputs(args)
returned = True
continue
if op_name == "dist.pmap":
@ -134,18 +134,18 @@ def _parse_module(mlir_region, context=None):
new_device = Device.get_new_device_variable("gpu")
assert attributes["device_var"] not in context.devices
context.devices[attributes["device_var"]] = new_device
# Parse the submodule
submodules.append(_parse_module(op.regions[0], context))
# Remove device var from context.devices as it is only in scope for submodule
# Parse the subfunction
subfunctions.append(_parse_function(op.regions[0], context))
# Remove device var from context.devices as it is only in scope for subfunction
del context.devices[attributes["device_var"]]
# Create an op and add it to the module
outs = module.add_op(
# Create an op and add it to the function
outs = function.add_op(
op_name,
inputs=args,
attributes=attributes,
output_names=output_names,
submodules=submodules,
subfunctions=subfunctions,
)
if not isinstance(outs, tuple):
outs = (outs,)
@ -156,22 +156,21 @@ def _parse_module(mlir_region, context=None):
assert str(mlirval) not in context.values
context.values[str(mlirval)] = val
module.finalize()
return module
return function.finalize()
def _parse_mlir_module(mlir_module):
modules = []
def _parse_mlir_module(mlir_module) -> List[Function]:
functions = []
for func in mlir_module.body.operations:
if str(func) == "module_terminator":
break
assert len(func.regions) == 1
modules.append(_parse_module(func.regions[0]))
return modules
functions.append(_parse_function(func.regions[0]))
return functions
def parse_mlir_str(mlir_str):
def parse_mlir_str(mlir_str: str) -> List[Function]:
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
@ -179,7 +178,7 @@ def parse_mlir_str(mlir_str):
return _parse_mlir_module(mlir_module)
def parse_mlir_file(filename):
def parse_mlir_file(filename) -> List[Function]:
with open(filename, "r") as fin:
mlir_str = fin.read()


@ -1,6 +1,6 @@
import onnx
from ..ir import Module, Value
from ..ir import FunctionMaker, Value
from ..ir.type import Tensor, Float
@ -8,14 +8,14 @@ def import_from_onnx(onnx_model):
# TODO: Remove prints?
# TODO: Support types beyond Tensor
onnx_model = onnx.load(onnx_model)
dist_ir_module = Module()
dist_ir_function = FunctionMaker("foo") # TODO get name?
inputs = {}
output_src = {}
def add_input(value):
# TODO lookup shape and dtype of input if exists
v = dist_ir_module.add_input_value(value.name, Tensor(Float()))
v = dist_ir_function.add_input_value(value.name, Tensor(Float()))
inputs[value.name] = v
for value in onnx_model.graph.value_info:
@ -41,21 +41,26 @@ def import_from_onnx(onnx_model):
else:
print(f"---> Could not find input {value}!")
# TODO do something better here
v = dist_ir_module.add_input_value(value, Tensor(Float()))
v = dist_ir_function.add_input_value(value, Tensor(Float()))
inputs[value] = v
per_node_inputs.append(v)
print()
output_names = node.output
op = dist_ir_module.add_op(
outputs = dist_ir_function.add_op(
op_type=node.op_type,
name=node.name,
inputs=per_node_inputs,
output_names=output_names,
)
for output in node.output:
# TODO lookup shape and dtype of input if exists
v = Value(output, Tensor(Float()))
output_src[output] = v
# Match node's outputs with the output Values created in op:
if len(node.output) == 1:
assert isinstance(outputs, Value)
outputs = [outputs]
else:
assert len(outputs) == len(node.output)
for out_name, value in zip(node.output, outputs):
assert out_name == value.name
output_src[out_name] = value
print(f"Found output {out_name}")
print()
dist_ir_module.verify_ops_in_topological_order()
return dist_ir_module
return dist_ir_function.finalize()


@ -1,5 +1,5 @@
from .module import Module
from .device import Device
from .function import Function, FunctionMaker
from .op import Op
from .prettyprint import cpprint, pformat
from .topology import Topology


@ -1,38 +1,21 @@
from dataclasses import dataclass
from typing import ClassVar
@dataclass(frozen=True)
class Device:
device_variable_id = 0
device_id: str
device_type: str
is_variable: bool = False
def __init__(self, device_id, device_type, is_variable=False):
self._device_id = device_id
self._device_type = device_type
self._is_variable = is_variable
device_variable_id: ClassVar[int] = 0
def __str__(self):
return f"{self._device_id} ({self._device_type})"
def __repr__(self):
return str(self)
def __eq__(self, other):
return other is not None and self._device_id == other._device_id
return f"{self.device_id} ({self.device_type})"
def __lt__(self, other):
return self._device_id < other._device_id
def __hash__(self):
return hash(str(self))
@property
def device_id(self):
return self._device_id
@property
def device_type(self):
return self._device_type
@property
def is_variable(self):
return self._is_variable
return self.device_id < other.device_id
@classmethod
def get_new_device_variable(cls, device_type):

dist_ir/ir/function.py (new file)

@ -0,0 +1,259 @@
from __future__ import annotations
from collections import OrderedDict, defaultdict
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from frozendict import frozendict
from .op import Op
from .value import Value
@dataclass(frozen=True)
class Function:
"""The core DistIR concept: a function.
A function has a name, a list of input values, a list of operations, and a
list of output values. Functions are immutable.
"""
name: str
ops: Tuple[Op]
inputs: Tuple[Value]
outputs: Tuple[Value]
# Map from Value -> List of Ops that consume it
_consumers: Dict[Value, Tuple[Op]] = field(init=False)
def __post_init__(self):
"""Creates the _consumers map, verifies the function, and performs
type inference. This is called automatically at initialization.
"""
consumers = defaultdict(list)
for op in self.ops:
for in_edge in op.inputs:
consumers[in_edge].append(op)
for out_edge in op.outputs:
consumers[out_edge] = []
for v in consumers:
consumers[v] = tuple(consumers[v])
# Can't assign to frozen field:
object.__setattr__(self, "_consumers", frozendict(consumers))
# Check that ops don't use values from the future
self._verify_ops_in_topological_order()
def _verify_ops_in_topological_order(self):
seen = set()
for inp in self.inputs:
seen.add(inp)
for op in self.ops:
for in_edge in op.inputs:
if in_edge not in seen:
raise ValueError(
f"Ops are not in topological order: op {op.name} has "
f"unseen edge {in_edge}"
)
for out_edge in op.outputs:
seen.add(out_edge)
def get_consumers(self, value: Value) -> List[Op]:
return self._consumers[value]
def __str__(self):
# TODO can we use the prettyprint output as __str__?
return self.get_summary()
def get_summary(self):
output = ""
output += "Function inputs:\n"
for input_value in self.inputs:
output += " " + str(input_value) + "\n"
output += "\n"
output += "Function outputs:\n"
for input_value in self.outputs:
output += " " + str(input_value) + "\n"
output += "\n"
output += "Ops:\n"
for op in self.ops:
output += str(op) + "\n"
return output
def has_input(self, value):
"""Checks whether the given value is an input of this function."""
return value in self.inputs
def has_output(self, value):
"""Checks whether the given value is an output of this function."""
return value in self.outputs
def get_subfunction(
self, op_names: Tuple[str], deepcopy: bool = False, name: Optional[str] = None
) -> Function:
"""Returns a Function comprised of the specified subset of ops."""
subfunction = FunctionMaker(name)
op_names_set = set(op_names)
ops = []
for op in self.ops:
if op.name in op_names_set:
ops.append(op)
value_map = {}
outputs = []
for op in ops:
subfunction_op_inputs = []
for inp in op.inputs:
if inp not in value_map:
if deepcopy:
value_map[inp] = subfunction.add_input_value(inp.name, inp.type)
else:
subfunction.inputs.append(inp)
value_map[inp] = inp
subfunction_op_inputs.append(value_map[inp])
output_names = [output.name for output in op.outputs]
if deepcopy:
subfunction_op_outputs = subfunction.add_op(
op.op_type,
name=op.name,
inputs=subfunction_op_inputs,
attributes=copy.deepcopy(op.attributes),
subfunctions=copy.deepcopy(op.subfunctions),
output_names=output_names,
)
else:
subfunction.ops.append(op)
subfunction_op_outputs = op.outputs
if not isinstance(subfunction_op_outputs, tuple):
subfunction_op_outputs = (subfunction_op_outputs,)
for orig_output, subfunction_output in zip(
op.outputs, subfunction_op_outputs
):
# We need to explicitly set the subfunction outputs because some output
# values might have consumers outside the subfunction (external).
has_external_output = False
if orig_output in self.outputs or any(
[c not in ops for c in self._consumers[orig_output]]
):
outputs.append(subfunction_output)
value_map[orig_output] = subfunction_output
subfunction.set_outputs(outputs)
return subfunction.finalize()
@dataclass
class FunctionMaker:
"""A helper class for creating Functions."""
name: str = "foo"
ops: List[Op] = field(default_factory=list)
inputs: List[Value] = field(default_factory=list)
outputs: List[Value] = field(default_factory=list)
def add_op(
self,
op_type,
name=None,
inputs: List[Value] = None,
attributes: Dict[str, Any] = None,
subfunctions: List["Function"] = None,
output_names: List[str] = None,
) -> Union[None, Value, Tuple[Value, ...]]:
"""Adds an op to the function.
Args:
op_type: The op's type.
name: The op's name.
inputs: The input values for this op.
attributes: Any op-specific attributes.
subfunctions: Any subfunctions this op is wrapping.
output_names: An optional list of output value names.
Returns:
The outputs of the newly created op.
"""
op = Op(
op_type,
name=name,
inputs=None if inputs is None else tuple(inputs),
attributes=None if attributes is None else frozendict(attributes),
subfunctions=None if subfunctions is None else tuple(subfunctions),
output_names=None if output_names is None else tuple(output_names),
)
self.ops.append(op)
# Return the op outputs.
num_out_edges = len(op.outputs)
if num_out_edges == 0:
return None
elif num_out_edges == 1:
return op.outputs[0]
else:
return tuple(op.outputs)
def add_input_value(self, name, value_type):
"""Adds an input value to the function and returns the value."""
value = Value(name, value_type)
if value in self.inputs:
raise ValueError(f"Function already has input value {value}")
self.inputs.append(value)
return value
def set_outputs(self, outputs: Iterable[Value]):
"""Sets the output of this function to be the given values. They must be
valid values, i.e. outputs of some existing op in the function. This clears
any previous outputs registered with this function.
"""
self.outputs.clear()
seen = set()
for output in outputs:
if output in seen:
raise ValueError(f"Function already has output value {output}")
seen.add(output)
self.outputs.append(output)
def set_outputs_auto(self):
"""Marks all sink nodes in the graph as output values."""
is_output = OrderedDict()
self.outputs.clear()
for input_value in self.inputs:
is_output[input_value] = True
for op in self.ops:
for in_edge in op.inputs:
is_output[in_edge] = False
for out_edge in op.outputs:
is_output[out_edge] = True
self.outputs = [v for v in is_output if is_output[v]]
def _get_ops_in_topological_order_helper(self, name, visited, order):
visited.add(name)
out_edges = self.ops[name].outputs
for out_edge in out_edges:
output_name = out_edge
if output_name not in visited:
self._get_ops_in_topological_order_helper(output_name, visited, order)
order.append(name)
def get_ops_in_topological_order(self):
"""Return ops in topological order. DEPRECATED, ops should always be
topologically ordered.
"""
visited = set()
order = []
for name in self.ops:
if name not in visited:
self._get_ops_in_topological_order_helper(name, visited, order)
return order[::-1]
def finalize(self) -> Function:
"""Returns the created Function. Outputs, if unspecified, are the sinks."""
if len(self.outputs) == 0:
self.set_outputs_auto()
return Function(
self.name, tuple(self.ops), tuple(self.inputs), tuple(self.outputs)
)
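
For illustration, a hedged usage sketch of FunctionMaker (the MatMul op type and Tensor(Float()) values follow usages elsewhere in this diff; shapes are omitted, so a later infer_types pass would fill them in):

    from dist_ir.ir import FunctionMaker
    from dist_ir.ir.type import Float, Tensor

    maker = FunctionMaker("mlp")
    x = maker.add_input_value("x", Tensor(Float()))
    w = maker.add_input_value("w", Tensor(Float()))
    y = maker.add_op("MatMul", inputs=[x, w], output_names=["y"])
    fn = maker.finalize()  # immutable Function; outputs default to the sink values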


@ -1,294 +0,0 @@
from collections import OrderedDict, defaultdict
import copy
from typing import Any, Dict, Iterable, List, Tuple, Union
from .op import Op
from .value import Value
class Module:
def __init__(self, name=None):
self._ops = OrderedDict()
self._inputs = OrderedDict()
self._outputs = OrderedDict()
self._op_counter = defaultdict(int)
self._consumers = defaultdict(list)
self._name = name
self._hash = None
def __str__(self):
if self._name is not None:
return self._name
else:
return self.get_summary()
def __repr__(self):
return self.get_summary()
def __hash__(self):
if self._hash is None:
raise RuntimeError("Cannot hash unfinalized module!")
return self._hash
def __eq__(self, other):
for op_name in self._ops:
if op_name not in other._ops or self._ops[op_name] != other._ops[op_name]:
return False
for input_name in self._inputs:
if (
input_name not in other._inputs
or self._inputs[input_name] != other._inputs[input_name]
):
return False
for output_name in self._outputs:
if (
output_name not in other._outputs
or self._outputs[output_name] != other._outputs[output_name]
):
return False
return True
def get_summary(self):
output = ""
output += "Module inputs:\n"
for input_value in self._inputs.values():
output += " " + str(input_value) + "\n"
output += "\n"
output += "Module outputs:\n"
for input_value in self._outputs.values():
output += " " + str(input_value) + "\n"
output += "\n"
output += "Ops:\n"
for op in self._ops.values():
output += str(op) + "\n"
return output
# TODO: Convert to property
def get_ops(self):
"""Returns all ops in the module."""
return self._ops
def is_op(self, name):
"""Checks whether a op exists with the specified name."""
return name in self._ops
def is_input(self, name):
"""Checks whether an input value exists with the specified name."""
return name in self._inputs
def is_output(self, name):
"""Checks whether an output value exists with the specified name."""
return name in self._outputs
def get_op(self, name):
"""Returns the op with the specified name if it exists."""
if name not in self._ops:
return None
return self._ops[name]
def get_input(self, name):
"""Returns the input value with the specified name if it exists."""
if name not in self._inputs:
return None
return self._inputs[name]
def get_inputs(self):
"""Returns the module inputs."""
return self._inputs.values()
def get_outputs(self):
"""Returns the module outputs."""
return self._outputs.values()
def add_op(
self,
op_type,
name=None,
inputs: List[Value] = None,
attributes: Dict[str, Any] = None,
submodules: List["Module"] = None,
output_names: List[str] = None,
) -> Union[None, Value, Tuple[Value, ...]]:
"""Adds an op to the graph.
Args:
op_type: The op's type.
name: The op's name.
inputs: The input values for this op.
attributes: Any op-specific attributes.
submodules: Any submodules this op is wrapping.
output_names: An optinal list of output value names.
Returns:
The outputs of the newly created op.
"""
if name in self._ops:
raise ValueError(f"op with name {name} already exists!")
elif name is None or name == "":
name = f"{op_type}_#{self._op_counter[op_type]}"
op = Op(
name,
op_type,
in_edges=inputs,
attributes=attributes,
submodules=submodules,
output_names=output_names,
)
self._ops[name] = op
self._op_counter[op_type] += 1
# Update _consumers.
out_edges = op.get_out_edges()
for in_edge in inputs:
self._consumers[in_edge.name].append(op.name)
for out_edge in out_edges:
self._consumers[out_edge.name] = []
# Return the op outputs.
num_out_edges = len(out_edges)
if num_out_edges == 0:
return None
elif num_out_edges == 1:
return out_edges[0]
else:
return tuple(out_edges)
def add_input_value(self, name, value_type):
"""Adds an input value to the graph and returns the value."""
value = Value(name=name, value_type=value_type)
if value.name in self._inputs:
raise ValueError(f"Module already has input value with name {value.name}")
self._inputs[value.name] = value
return value
def get_consumers_for_value(self, name):
return self._consumers[name]
def set_outputs(self, outputs: Iterable[Value]):
"""Sets the output of this module to be the given values. They must be
valid values, i.e. outputs of some existing op in the module. This clears
any previous outputs registered with this module.
"""
for output in outputs:
# NOTE: Using consumers as a proxy for valid values
if output.name not in self._consumers:
raise ValueError(f"Module has no value {output}")
self._outputs.clear()
for output in outputs:
if output.name in self._outputs:
raise ValueError(
f"Module already has output value with name {output.name}"
)
self._outputs[output.name] = output
def set_outputs_auto(self):
"""Marks all sink nodes in the graph as output values."""
all_values = OrderedDict()
is_output = OrderedDict()
self._outputs.clear()
for input_value_name, input_value in self._inputs.items():
all_values[input_value_name] = input_value
is_output[input_value_name] = True
for op in self._ops.values():
for in_edge in op.get_in_edges():
is_output[in_edge.name] = False
for out_edge in op.get_out_edges():
all_values[out_edge.name] = out_edge
is_output[out_edge.name] = True
for output_value_name in is_output:
if is_output[output_value_name]:
self._outputs[output_value_name] = all_values[output_value_name]
def _get_ops_in_topological_order_helper(self, name, visited, order):
visited.add(name)
out_edges = self._ops[name].get_out_edges()
for out_edge in out_edges:
output_name = out_edge
if output_name not in visited:
self._get_ops_in_topological_order_helper(output_name, visited, order)
order.append(name)
def get_ops_in_topological_order(self):
"""Return ops in topological order. DEPRECATED, ops should always be
topologically ordered.
"""
visited = set()
order = []
for name in self._ops:
if name not in visited:
self._get_ops_in_topological_order_helper(name, visited, order)
return order[::-1]
def verify_ops_in_topological_order(self):
seen = set()
for input_name in self._inputs:
seen.add(input_name)
for op_name, op in self._ops.items():
for in_edge in op.get_in_edges():
if in_edge.name not in seen:
raise ValueError(
f"Ops are not in topological order: op {op_name} has "
f"unseen edge {in_edge}"
)
for out_edge in op.get_out_edges():
seen.add(out_edge.name)
def finalize(self):
"""Performs some standard verification and inference passes. Use at the
end whenever creating a module. Assumes that the module will no longer be
modified after this function is called.
"""
# Putting this import at the top level causes an import loop
from ..executor.shape_inference import infer_shapes
self.verify_ops_in_topological_order()
if len(self._outputs) == 0:
self.set_outputs_auto()
infer_shapes(self)
self._hash = hash(tuple(self._ops.keys()))
def get_submodule(self, op_names, name=None):
"""Returns a submodule comprised of the specified subset of ops."""
submodule = Module(name)
value_map = {}
outputs = []
op_names_set = set(op_names)
for op_name in op_names:
op = self._ops[op_name]
submodule_op_inputs = []
for input in op.get_in_edges():
if input.name not in value_map:
value_map[input.name] = submodule.add_input_value(
input.name, input.type
)
submodule_op_inputs.append(value_map[input.name])
output_names = [output.name for output in op.get_out_edges()]
submodule_op_outputs = submodule.add_op(
op.op_type,
name=op.name,
inputs=submodule_op_inputs,
attributes=op._attributes,
submodules=copy.deepcopy(op._submodules),
output_names=output_names,
)
if not isinstance(submodule_op_outputs, tuple):
submodule_op_outputs = (submodule_op_outputs,)
for output in submodule_op_outputs:
# We need to explicitly set the submodule outputs because some output
# values might have consumers outside the submodule (external).
consumers = self._consumers[output.name]
has_external_output = any([c not in op_names_set for c in consumers])
if (
output.name in self._outputs or has_external_output
) and output.name not in op_names:
outputs.append(output)
value_map[output.name] = output
submodule.set_outputs(outputs)
submodule.finalize()
return submodule


@ -1,86 +1,59 @@
from dataclasses import dataclass, field, InitVar
from typing import Any, Dict, List, Tuple
from frozendict import frozendict
from .op_register import OpRegister
from .type import *
from .value import Value
from .type import Type
@dataclass(frozen=True)
class Op:
def __init__(
self,
name,
op_type,
in_edges=None,
attributes=None,
submodules=None,
output_names=None,
):
if op_type not in OpRegister:
raise ValueError(f"Invalid op type {op_type}")
self._name = name
self._op_type = op_type
if in_edges is None:
self._in_edges = []
op_type: str
name: str = ""
inputs: Tuple[Value] = field(default_factory=tuple)
attributes: Dict[str, Any] = field(default_factory=frozendict)
subfunctions: Tuple["Function"] = field(default_factory=tuple)
outputs: Tuple[Value] = field(init=False)
# These are not fields, just parameters to init and post_init:
output_names: InitVar[Tuple[str]] = None
output_types: InitVar[Tuple[Type]] = None
def __post_init__(self, output_names, output_types):
if self.op_type == "Pmap":
# Handle pmap specially
assert len(self.subfunctions) == 1
# Number of inputs is arbitrary but positive
assert len(self.inputs) > 0
# Number of inputs matches subfunction
assert len(self.inputs) == len(self.subfunctions[0].inputs)
# Number of outputs is given by subfunction
num_outputs = len(self.subfunctions[0].outputs)
else:
self._in_edges = in_edges
if attributes is None:
self._attributes = {}
if self.op_type not in OpRegister:
raise ValueError(f"Invalid op type {self.op_type}")
# Check that we got the right number of inputs
assert len(self.inputs) == OpRegister[self.op_type].num_inputs
# Number of outputs is given by OpRegister
num_outputs = OpRegister[self.op_type].num_outputs
# Create the correct number of output values with appropriate types
if output_names is None:
output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)]
else:
self._attributes = attributes
if submodules is None:
self._submodules = []
else:
self._submodules = submodules
self._out_edges = []
OpRegister[op_type].infer_types(self, output_names)
def __str__(self):
output = ""
output += f"Name: {self._name}\n"
output += f"Op type: {self._op_type}\n"
output += "Inputs:\n"
for in_edge in self._in_edges:
output += " " + str(in_edge) + "\n"
output += "Outputs:\n"
for out_edge in self._out_edges:
output += " " + str(out_edge) + "\n"
if len(self._submodules) > 0:
output += "Submodules:\n"
for submodule in self._submodules:
output += "\n".join(
[" " + line for line in str(submodule).split("\n")]
)
return output
def __repr__(self):
return str(self)
def add_in_edge(self, in_edge: Value):
"""Adds an input edge."""
self._in_edges.append(in_edge)
def add_out_edge(self, out_edge: Value):
"""Adds an output edge."""
self._out_edges.append(out_edge)
def get_in_edges(self):
"""Returns all input edges."""
return self._in_edges
def get_out_edges(self):
"""Returns all output edges."""
return self._out_edges
def get_attribute(self, attribute_name):
"""Returns the specified attributes, or throws error if it does not exist."""
return self._attributes[attribute_name]
def get_submodule(self, idx):
"""Returns the submodule at the specified index."""
return self._submodules[idx]
@property
def name(self):
return self._name
@property
def op_type(self):
return self._op_type
assert len(output_names) == num_outputs
if output_types is None:
output_types = [None for i in range(num_outputs)]
elif len(output_types) != num_outputs:
raise ValueError(
f"Op {self.name} has {len(output_types)} outputs; "
f"num_outputs expected"
)
outputs = tuple(
Value(out_name, out_type)
for out_name, out_type in zip(output_names, output_types)
)
object.__setattr__(self, "outputs", outputs) # Can't assign to frozen field
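
For illustration, a hedged sketch of constructing an Op directly with explicit output types, per the "Allow passing output types to Op()" change (the shape, device, and the x, w input Values are illustrative; MatMul is assumed to be registered with two inputs and one output):

    from dist_ir.ir import Device, Op
    from dist_ir.ir.type import Float, Tensor

    d = Device("0", "gpu")
    op = Op(
        "MatMul",
        name="matmul_0",
        inputs=(x, w),  # two existing Tensor-typed Values
        output_names=("y",),
        output_types=(Tensor(dtype=Float(), shape=(4, 4), device=d),),
    )
    print(op.outputs[0])  # a Value named "y" carrying the given Tensor type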


@ -1,193 +1,35 @@
from .device import Device
from .type import Tensor, TupleType
from .value import Value
import copy
from dataclasses import dataclass
@dataclass(frozen=True)
class OpRegisterEntry:
def __init__(self, input_types, output_types):
self._input_types = input_types
self._output_types = output_types
def infer_types(self, op, output_names=None):
# Verify that number of inputs and input types match the expected input types.
inputs = op.get_in_edges()
if len(inputs) != len(self._input_types):
raise ValueError(
f"Op {op.name}: Expected {len(self._input_types)} inputs, "
f"got {len(inputs)}"
)
for i, (input, input_type) in enumerate(zip(inputs, self._input_types)):
if not isinstance(input.type, input_type):
raise ValueError(
f"Op {op.name}: Expected input of type {input_type} for "
f"input {i}, got input of type {input.type}"
)
# Verify that the number of output names is correct if specified.
if output_names is not None and len(output_names) != len(self._output_types):
raise ValueError(
f"Op {op.name}: Expected {len(output_names)} outputs, "
f"got {len(self._output_types)}"
)
# Construct the output values and add them to the op's out edge list.
for i, output_type in enumerate(self._output_types):
if output_names is not None and output_names[i] != "":
output_name = output_names[i]
else:
output_name = f"{op.name}/{i}"
op.add_out_edge(Value(output_name, value_type=output_type()))
class AllreduceOpRegisterEntry(OpRegisterEntry):
# TODO: Remove this and handle generic types in OpRegisterEntry
def infer_types(self, op, output_names=None):
inputs = op.get_in_edges()
if len(inputs) != 1:
raise ValueError(f"Op {op.name}: Expected 1 input, got {len(inputs)}")
elif not isinstance(inputs[0].type, TupleType):
raise ValueError(
f"Op {op.name}: Expected input of type {self._input_types[0]}, "
f"got input of type {inputs[0].type}"
)
if output_names is not None:
output_name = output_names[0]
else:
output_name = f"{op.name}/{0}"
output_value_type = copy.deepcopy(inputs[0].type)
op.add_out_edge(Value(name=output_name, value_type=output_value_type))
class BroadcastScatterOpRegisterEntry(OpRegisterEntry):
# TODO: Remove this and handle generic types in OpRegisterEntry
def infer_types(self, op, output_names=None):
inputs = op.get_in_edges()
devices = op.get_attribute("devices")
if output_names is not None and len(output_names) != 1:
raise ValueError(
f"Op {op.name}: Expected 1 output name but got {len(output_names)}"
)
output_types = []
for i, device in enumerate(devices):
output_type = copy.deepcopy(inputs[0].type)
output_type.set_device(device)
if op.op_type == "Scatter":
if isinstance(output_type, Tensor):
output_type.shape = None
output_types.append(output_type)
output_value = Value(output_names[0], value_type=TupleType(output_types))
op.add_out_edge(output_value)
class PmapOpRegisterEntry(OpRegisterEntry):
def infer_types(self, op, output_names=None):
devices = op.get_attribute("devices")
submodule = op.get_submodule(0)
submodule_inputs = submodule.get_inputs()
submodule_outputs = submodule.get_outputs()
# TODO: If we want a more robust solution for nested pmaps, move the
# parameterization over device variable to the module code
# TODO: Handle multiple device types?
d = Device.get_new_device_variable(devices[0].device_type)
for in_edge in submodule_inputs:
in_edge.type.set_device(d)
# TODO: Change the submodule input names to indicate they are
# parameterized over the devices
for i, out_edge in enumerate(submodule_outputs):
output_types = []
for device in devices:
output_type = copy.deepcopy(out_edge.type)
output_type.set_device(device)
output_types.append(output_type)
if output_names is None:
output_name = f"{out_edge.name}is"
else:
output_name = output_names[i]
output_value = Value(output_name, value_type=TupleType(output_types))
op.add_out_edge(output_value)
class SelectOpRegisterEntry(OpRegisterEntry):
def infer_types(self, op, output_names=None):
inputs = op.get_in_edges()
dim = op.get_attribute("dim")
output_value = Value(
output_names[0], value_type=copy.deepcopy(inputs[0].type.types[dim])
)
op.add_out_edge(output_value)
class SendOpRegisterEntry(OpRegisterEntry):
def infer_types(self, op, output_names=None):
inputs = op.get_in_edges()
device = op.get_attribute("device")
output_value_type = copy.deepcopy(inputs[0].type)
output_value_type.set_device(device)
output_value = Value(output_names[0], value_type=output_value_type)
op.add_out_edge(output_value)
class SplitOpRegisterEntry(OpRegisterEntry):
def infer_types(self, op, output_names=None):
inputs = op.get_in_edges()
num_splits = op.get_attribute("num_splits")
output_types = []
for i in range(num_splits):
output_type = copy.deepcopy(inputs[0].type)
output_type.shape = None
output_types.append(output_type)
output_value = Value(output_names[0], value_type=TupleType(output_types))
op.add_out_edge(output_value)
num_inputs: int
num_outputs: int
OpRegister = {
"Add": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"Allreduce": AllreduceOpRegisterEntry(
input_types=[TupleType], output_types=[TupleType]
),
"Broadcast": BroadcastScatterOpRegisterEntry(
input_types=[Tensor], output_types=[TupleType]
),
"BroadcastGradientArgs": OpRegisterEntry(
input_types=[Tensor, Tensor], output_types=[Tensor, Tensor]
),
"Concat": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"Gather": OpRegisterEntry(input_types=[TupleType], output_types=[Tensor]),
"Gemm": OpRegisterEntry(
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor]
),
"Loss": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"LossGrad": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"MatMul": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"MatMulGrad": OpRegisterEntry(
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor, Tensor]
),
"ReduceSumTraining": OpRegisterEntry(
input_types=[Tensor, Tensor], output_types=[Tensor]
),
"Relu": OpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
"ReluGrad": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"Reshape": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"Opt": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
"Pmap": PmapOpRegisterEntry(input_types=None, output_types=None),
"Scatter": BroadcastScatterOpRegisterEntry(
input_types=[Tensor], output_types=[TupleType]
),
"Select": SelectOpRegisterEntry(input_types=[TupleType], output_types=[Tensor]),
"Send": SendOpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
"SGDOptimizer": OpRegisterEntry(
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor, Tensor]
),
"Shape": OpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
"SoftmaxCrossEntropy": OpRegisterEntry(
input_types=[Tensor, Tensor], output_types=[Tensor, Tensor]
),
"SoftmaxCrossEntropyGrad": OpRegisterEntry(
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor]
),
"Split": SplitOpRegisterEntry(input_types=[Tensor], output_types=[TupleType]),
"Add": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Allreduce": OpRegisterEntry(num_inputs=1, num_outputs=1),
"Broadcast": OpRegisterEntry(num_inputs=1, num_outputs=1),
"BroadcastGradientArgs": OpRegisterEntry(num_inputs=2, num_outputs=2),
"Concat": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Gather": OpRegisterEntry(num_inputs=1, num_outputs=1),
"Gemm": OpRegisterEntry(num_inputs=3, num_outputs=1),
"Loss": OpRegisterEntry(num_inputs=2, num_outputs=1),
"LossGrad": OpRegisterEntry(num_inputs=2, num_outputs=1),
"MatMul": OpRegisterEntry(num_inputs=2, num_outputs=1),
"MatMulGrad": OpRegisterEntry(num_inputs=3, num_outputs=2),
"ReduceSumTraining": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Relu": OpRegisterEntry(num_inputs=1, num_outputs=1),
"ReluGrad": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Reshape": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Opt": OpRegisterEntry(num_inputs=2, num_outputs=1),
"Scatter": OpRegisterEntry(num_inputs=1, num_outputs=1),
"Select": OpRegisterEntry(num_inputs=1, num_outputs=1),
"Send": OpRegisterEntry(num_inputs=1, num_outputs=1),
"SGDOptimizer": OpRegisterEntry(num_inputs=3, num_outputs=2),
"Shape": OpRegisterEntry(num_inputs=1, num_outputs=1),
"SoftmaxCrossEntropy": OpRegisterEntry(num_inputs=2, num_outputs=2),
"SoftmaxCrossEntropyGrad": OpRegisterEntry(num_inputs=3, num_outputs=1),
"Split": OpRegisterEntry(num_inputs=1, num_outputs=1),
}
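For orientation, the register above is now purely an arity table; a minimal hedged sketch of consulting it (the `dist_ir.ir.op_register` module path is an assumption, not confirmed by this diff):

```python
from dist_ir.ir.op_register import OpRegister  # import path assumed

# Op construction only checks input/output counts; dtypes, shapes and devices
# are filled in later by the separate type-inference pass.
entry = OpRegister["MatMulGrad"]
assert (entry.num_inputs, entry.num_outputs) == (3, 2)
```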


@ -47,7 +47,7 @@ from prettyprinter.prettyprinter import (
)
from prettyprinter.utils import intersperse
from .module import Module
from .function import Function
from .value import Value
from .type import Type, Int, Float, Tensor, TupleType
from .device import Device
@ -97,10 +97,10 @@ def interline(*docs):
# ----------------------------------------
def _pprint_module_body(module: Module, ctx):
ops = [pretty_dispatch(op, ctx) for op in module.get_ops().values()]
def _pprint_function_body(function: Function, ctx):
ops = [pretty_dispatch(op, ctx) for op in function.ops]
# Include the outputs as a final "return" op
outputs = concat(_join(*(r.name for r in module.get_outputs())))
outputs = concat(_join(*(r.name for r in function.outputs)))
return_line = group(
nest(ctx.indent, concat([pp_reserved("return"), LINE, outputs]))
)
@ -108,12 +108,12 @@ def _pprint_module_body(module: Module, ctx):
return ops
@register_pretty(Module)
def _(module: Module, ctx):
ops = _pprint_module_body(module, ctx)
@register_pretty(Function)
def _(function: Function, ctx):
ops = _pprint_function_body(function, ctx)
return concat(
[
pretty_call(ctx, pp_fnname("Module"), *module.get_inputs()),
pretty_call(ctx, pp_fnname("Function"), *function.inputs),
nest(ctx.indent, concat([COLON, HARDLINE, interline(*ops)])),
]
)
@ -121,15 +121,15 @@ def _(module: Module, ctx):
@register_pretty(Op)
def _(op: Op, ctx):
results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.get_out_edges())))
args = concat(_join(*(v.name for v in op.get_in_edges())))
results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs)))
args = concat(_join(*(v.name for v in op.inputs)))
if op.op_type == "Pmap":
lambda_args = _join(
*(pretty_dispatch(i, ctx) for i in op.get_submodule(0).get_inputs())
*(pretty_dispatch(i, ctx) for i in op.subfunctions[0].inputs)
)
lambda_args = concat([LPAREN, nest(ctx.indent, concat(lambda_args)), RPAREN])
lambda_body = _pprint_module_body(op.get_submodule(0), ctx)
lambda_body = _pprint_function_body(op.subfunctions[0], ctx)
actual_args = group(
concat(
[
@ -139,7 +139,8 @@ def _(op: Op, ctx):
]
)
)
d = str(op.get_attribute("device_var").device_id)
# TODO: Also print out the list of devices this pmaps over
d = str(op.attributes["device_var"].device_id)
pmap_args = nest(
ctx.indent,
concat(


@ -1,24 +1,26 @@
from abc import ABC
from dataclasses import dataclass
from functools import reduce
from operator import add, mul
from typing import Optional, Tuple, TypeVar
from typing import Optional, Set, Tuple
from .device import Device
from .utils import singleton
@dataclass(frozen=True)
class Type:
def __init__(self, has_device=False):
self._has_device = has_device
self._device = None
"""Base class for all types."""
def set_device(self, device):
if self._has_device:
self._device = device
device: Optional[Device] = None
def get_all_devices(self):
if self._has_device and self._device is not None:
return set([self._device])
def get_all_devices(self) -> Set[Device]:
"""Returns all devices that a value of this type lives on. For example,
a tuple can have different elements live on different devices.
Subclasses should override this default implementation.
"""
if self.device is not None:
return set([self.device])
return set()
@ -29,11 +31,8 @@ class Type:
class Int(Type):
"""The integer type. A singleton class."""
def __str__(self):
return "Int"
def __repr__(self):
return str(self)
return "Int"
@property
def size(self):
@ -44,90 +43,61 @@ class Int(Type):
class Float(Type):
"""The float type. A singleton class."""
def __str__(self):
return f"Float"
def __repr__(self):
return str(self)
return "Float"
@property
def size(self):
return 4
@dataclass(frozen=True)
class Tensor(Type):
"""The tensor type. Contains a shape and an element type (dtype)."""
# TODO have a global cache to avoid creating multiple objects of same type?
def __init__(
self,
dtype: Type = None,
shape: Optional[Tuple[int]] = None,
device: Device = None,
):
Type.__init__(self, has_device=True)
self._device = device
self._shape = shape
self._dtype = dtype
dtype: Type = None  # Unable to make this non-optional, but it should always be provided (see the TODO below)
shape: Optional[Tuple[int]] = None
def __init__(self, dtype=None, shape=None, device=None):
# TODO make dtype a required argument?
assert dtype is None or isinstance(dtype, Type)
assert shape is None or (
isinstance(shape, tuple) and all(isinstance(n, int) for n in shape)
)
object.__setattr__(self, "dtype", dtype) # Can't assign to frozen field
object.__setattr__(self, "shape", shape)
Type.__init__(self, device=device)
def __repr__(self):
return (
f"Tensor[shape={self._shape}, dtype={self._dtype}, device={self._device}]"
)
def __eq__(self, other):
return (
self._dtype == other._dtype
and self._shape == other._shape
and self._device == other._device
)
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
self._shape = shape
@property
def dtype(self):
return self._dtype
@dtype.setter
def dtype(self, dtype):
self._dtype = dtype
@property
def device(self):
return self._device
return f"Tensor[shape={self.shape}, dtype={self.dtype}, device={self.device}]"
def size(self):
return reduce(mul, self._shape) * self._dtype.size
return reduce(mul, self.shape) * self.dtype.size
@dataclass(frozen=True)
class TupleType(Type):
def __init__(self, types):
Type.__init__(self)
self._types = types
def __str__(self):
elems_str = ", ".join(str(t) for t in self._types)
return f"Tuple[{elems_str}]"
types: Tuple[Type] = None
def __init__(self, types):
# Override __init__ because it doesn't make sense for a tuple to have a
# device. Devices are stored in each tuple element.
object.__setattr__(self, "types", types) # Can't assign to frozen field
assert isinstance(types, tuple) and all(isinstance(t, Type) for t in types)
assert self.device is None
def __repr__(self):
return str(self)
elems_str = ", ".join(str(t) for t in self.types)
return f"Tuple[{elems_str}]"
@property
def types(self):
return self._types
def get_all_devices(self):
def get_all_devices(self) -> Set[Device]:
devices = set()
for typ in self._types:
for typ in self.types:
devices.update(typ.get_all_devices())
return devices
def size(self):
return reduce(add, [typ.size() for typ in self._types])
return reduce(add, [typ.size() for typ in self.types])
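A short, hedged illustration of the frozen type classes above; the import paths mirror the tests elsewhere in this diff:

```python
from dist_ir.ir import Device
from dist_ir.ir.type import Float, Tensor, TupleType

d0, d1 = Device(0, "gpu"), Device(1, "gpu")
t0 = Tensor(dtype=Float(), shape=(4, 4), device=d0)
t1 = Tensor(dtype=Float(), shape=(4, 4), device=d1)
tup = TupleType((t0, t1))

assert t0.size() == 4 * 4 * Float().size   # bytes = num elements * dtype size
assert tup.size() == t0.size() + t1.size()
assert tup.get_all_devices() == {d0, d1}   # devices come from the tuple's elements
# Frozen dataclasses compare and hash structurally, so equal types collapse in sets:
assert len({t0, Tensor(dtype=Float(), shape=(4, 4), device=d0)}) == 1
```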


@ -1,28 +1,14 @@
from dataclasses import dataclass
from .type import Type
@dataclass(frozen=True, eq=False)
class Value:
def __init__(self, name, value_type):
self._name = name
self._type = value_type
"""A DistIR value. While values have names, DistIR makes no attempt to ensure
value names are unique in a function. Therefore Value equality is object
equality. (TODO correct terminology for this?)
"""
def __str__(self):
return f"{self._name}: type={str(self._type)}"
def __repr__(self):
return str(self)
def __hash__(self):
return hash(repr(self))
def __eq__(self, other):
return self._name == other._name and self._type == other._type
@property
def name(self):
return self._name
@property
def type(self):
return self._type
@type.setter
def type(self, typ):
self._type = typ
name: str
type: Type
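A small sketch of what `eq=False` buys here: two values with identical fields stay distinct objects, so both can serve as separate dictionary keys (the `dist_ir.ir.value` import path is an assumption):

```python
from dist_ir.ir.type import Float
from dist_ir.ir.value import Value  # import path assumed

v1 = Value("x", Float())
v2 = Value("x", Float())

assert v1 == v1 and v1 != v2   # identity, not structural, equality
assert len({v1, v2}) == 2      # both coexist as distinct set/dict keys
```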


@ -1,17 +1,15 @@
from ..ir.module import Module
import copy
from ..ir.function import FunctionMaker
class DataParallelTransform:
"""Partitions a module using data parallelism.
"""Partitions a function using data parallelism.
Replicates the given model across devices by instantiating an identical version
of the model on each device. The user specifies which input values to
partition between each device as well as the dimension to partition for each input
(e.g. selecting the first dimension for the input minibatch would partition
along the batch dimension). The selected input values are scattered between
each device, while the remaining input values are broadcasted. The module will
each device, while the remaining input values are broadcasted. The function will
be replicated using a Pmap operator. The original output values are retrieved
from each replica through Allreduce operators.
@ -26,33 +24,31 @@ class DataParallelTransform:
self._reduction_params = reduction_params
self._devices = devices
def apply(self, module):
"""Applies the transformation to the given module and returns the transformed module."""
transformed_module = Module()
def apply(self, function):
"""Applies the transformation to the given function and returns the transformed function."""
transformed_function = FunctionMaker()
# Either scatter or broadcast each input value depending on what the user
# has requested.
# TODO: Add explicit Send ops if the source device is not one of the
# destination devices.
input_values = module.get_inputs()
input_values = function.inputs
pmap_input_values = []
for input_value in input_values:
v = transformed_module.add_input_value(
input_value.name, copy.deepcopy(input_value.type)
)
if input_value.name in self._batch_dims:
vs = transformed_module.add_op(
v = transformed_function.add_input_value(input_value.name, input_value.type)
if input_value in self._batch_dims:
vs = transformed_function.add_op(
"Scatter",
name=f"Scatter/{v.name}",
inputs=[v],
attributes={
"devices": self._devices,
"dim": self._batch_dims[input_value.name],
"dim": self._batch_dims[input_value],
},
output_names=[f"{v.name}s"],
)
else:
vs = transformed_module.add_op(
vs = transformed_function.add_op(
"Broadcast",
name=f"Broadcast/{v.name}",
inputs=[v],
@ -61,18 +57,18 @@ class DataParallelTransform:
)
pmap_input_values.append(vs)
# Add the Pmap operator to the transformed module. The Pmap operator will
# encapsulate the original module.
output_values = module.get_outputs()
# Add the Pmap operator to the transformed function. The Pmap operator will
# encapsulate the original function.
output_values = function.outputs
pmap_output_names = []
for i, output_value in enumerate(output_values):
pmap_output_name = f"{output_value.name}is"
pmap_output_names.append(pmap_output_name)
pmap_output_values = transformed_module.add_op(
pmap_output_values = transformed_function.add_op(
"Pmap",
inputs=pmap_input_values,
attributes={"devices": self._devices},
submodules=[module],
subfunctions=[function],
output_names=pmap_output_names,
)
@ -83,23 +79,23 @@ class DataParallelTransform:
# TODO: Add explicit Send ops if the destination device is not one of the
# source devices.
for i, output_value in enumerate(output_values):
reduction_op_type = self._reduction_params[output_value.name]["op_type"]
reduction_op_type = self._reduction_params[output_value]["op_type"]
if reduction_op_type == "Allreduce":
transformed_module.add_op(
transformed_function.add_op(
"Allreduce",
name=f"Allreduce/{output_value.name}",
inputs=[pmap_output_values[i]],
output_names=[f"{output_value.name}s"],
)
elif reduction_op_type == "Gather":
dim = self._reduction_params[output_value.name]["dim"]
device = self._reduction_params[output_value.name]["device"]
transformed_module.add_op(
dim = self._reduction_params[output_value]["dim"]
device = self._reduction_params[output_value]["device"]
transformed_function.add_op(
"Gather",
name=f"Gather/{output_value.name}",
inputs=[pmap_output_values[i]],
attributes={"dim": dim, "device": device},
output_names=[f"{output_value.name}s"],
output_names=[f"{output_value.name}"],
)
else:
raise ValueError(
@ -107,4 +103,4 @@ class DataParallelTransform:
f"output value {output_value}"
)
return transformed_module
return transformed_function.finalize()
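A hedged usage sketch of the transform above; it mirrors `test_single_variable_partition` later in this diff, so the call signatures come from those tests rather than from this file alone:

```python
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor
from dist_ir.transforms import DataParallelTransform

d0, d1 = Device(0, "gpu"), Device(1, "gpu")

fn = FunctionMaker()
a = fn.add_input_value("a", Tensor(Float(), (4, 4)))
b = fn.add_input_value("b", Tensor(Float(), (4, 4)))
fn.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
fn = fn.finalize()

transform = DataParallelTransform(
    batch_dims={fn.inputs[0]: 0},  # scatter `a` along its first dimension
    reduction_params={fn.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}},
    devices=[d0, d1],
)
dp_fn = transform.apply(fn)  # finalized Function: Scatter/Broadcast -> Pmap -> Gather
```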


@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Dict, Set, Tuple
from ..ir import Device, Module
from ..ir import Device, Function
from .pipeline_parallel_scheduler import PipelineParallelScheduler
@ -9,13 +9,12 @@ class FIFOScheduler(PipelineParallelScheduler):
"""Implements a FIFO schedule where all forward pass stages are executed before
backward pass stages."""
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
ready_stages_by_type = defaultdict(list)
for ready_stage in self._ready_stages[device]:
# TODO: Use a more robust method to identify backwards pass stages.
(stage, microbatch) = ready_stage
ops = stage.get_ops()
if "Grad" in list(ops.keys())[0]:
if "Grad" in stage.ops[0].name:
ready_stages_by_type["bw"].append(ready_stage)
else:
ready_stages_by_type["fw"].append(ready_stage)


@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Dict, Set, Tuple
from ..ir import Device, Module
from ..ir import Device, Function
from .pipeline_parallel_scheduler import PipelineParallelScheduler
@ -12,13 +12,12 @@ class PipeDreamScheduler(PipelineParallelScheduler):
PipelineParallelScheduler.__init__(self, num_microbatches)
self._prev_stage_types = defaultdict(lambda: "bw")
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
ready_stages_by_type = defaultdict(list)
for ready_stage in self._ready_stages[device]:
# TODO: Use a more robust method to identify backwards pass stages.
(stage, microbatch) = ready_stage
ops = stage.get_ops()
if "Grad" in list(ops.keys())[0]:
if "Grad" in stage.ops[0].name:
ready_stages_by_type["bw"].append(ready_stage)
else:
ready_stages_by_type["fw"].append(ready_stage)


@ -2,14 +2,14 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, Set, Tuple
from ..ir import Module, Device, Op
from ..ir import Function, Device, Op
from . import utils
class PipelineParallelScheduler(ABC):
"""Interface for a pipeline parallel scheduler.
Pipeline parallel schedulers take as input a DistIR module, the number of
Pipeline parallel schedulers take as input a DistIR function, the number of
microbatches to partition each minibatch into, and a partition map which
captures the explicit placement of each stage onto corresponding devices.
The scheduler will return a time-ordered list of stages to execute on each
@ -24,13 +24,13 @@ class PipelineParallelScheduler(ABC):
self._remaining_inputs = defaultdict(lambda: 0)
self._ready_stages = defaultdict(list)
def _prepare_stages_to_schedule(self, module, partition_map):
def _prepare_stages_to_schedule(self, function, partition_map):
"""Enumerates the stages to schedule on each device across all microbatches."""
for stage, device in partition_map.items():
inputs = stage.get_inputs()
inputs = stage.inputs
remaining_inputs = len(inputs)
for input in inputs:
if module.is_input(input.name):
if input in function.inputs:
remaining_inputs -= 1
for i in range(self._num_microbatches):
self._remaining_inputs[(stage, i)] = remaining_inputs
@ -38,11 +38,11 @@ class PipelineParallelScheduler(ABC):
self._ready_stages[device].append((stage, i))
@abstractmethod
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
raise NotImplementedError()
def schedule(self, module, partition_map):
self._prepare_stages_to_schedule(module, partition_map)
def schedule(self, function, partition_map):
self._prepare_stages_to_schedule(function, partition_map)
op_to_stage = utils.get_op_to_stage_map(list(partition_map.keys()))
num_scheduled_stages = 0
total_stages_to_schedule = len(partition_map) * self._num_microbatches
@ -58,10 +58,10 @@ class PipelineParallelScheduler(ABC):
# TODO: Optimize this so it isn't an O(N) call?
self._ready_stages[device].remove(stage_to_schedule)
num_scheduled_stages += 1
outputs = stage.get_outputs()
outputs = stage.outputs
for output in outputs:
consumer_ops = module.get_consumers_for_value(output.name)
consumer_stages = utils.get_stages_from_op_names(
consumer_ops = function.get_consumers(output)
consumer_stages = utils.get_stages_from_ops(
op_to_stage, consumer_ops
)
for consumer_stage in consumer_stages:


@ -1,20 +1,19 @@
import copy
from collections import defaultdict
from ..ir.module import Module
from ..ir.function import FunctionMaker
from ..ir.value import Value
from . import utils
class PipelineParallelTransform:
"""Partitions a module using pipeline parallelism.
"""Partitions a function using pipeline parallelism.
Attributes:
num_microbatches: The number of microbatches per pipeline iteration.
batch_dims: A map from input value name to partition dimension.
reduction_ops: A map from output value name to a map of reduction op params.
partition_map: A map from op name to device.
schedule: A list of maps from device to a tuple of (op_name, microbatch).
batch_dims: A map from input value to partition dimension.
reduction_params: A map from output value to a map of reduction op params.
partition_map: A map from stage (subfunction) to device.
schedule: A list of maps from device to a tuple of (stage, microbatch).
"""
def __init__(
@ -29,9 +28,9 @@ class PipelineParallelTransform:
list(self._partition_map.keys())
)
def _forward_value(self, transformed_module, value, device):
def _forward_value(self, transformed_function, value, device):
"""Forwards the specified value to the specified device by adding a Send op."""
forwarded_value = transformed_module.add_op(
forwarded_value = transformed_function.add_op(
"Send",
name=f"Send/{value.name}@{device}",
inputs=[value],
@ -40,27 +39,24 @@ class PipelineParallelTransform:
)
return forwarded_value
def _partition_inputs(self, module, transformed_module, pipelined_value_map):
def _partition_inputs(self, function, transformed_function, pipelined_value_map):
"""Splits the input values according to the number of specified microbatches."""
input_values = module.get_inputs()
for input_value in input_values:
v = transformed_module.add_input_value(
input_value.name, copy.deepcopy(input_value.type)
)
pipelined_input_map = pipelined_value_map[input_value.name]
if input_value.name in self._batch_dims:
vs = transformed_module.add_op(
for input_value in function.inputs:
v = transformed_function.add_input_value(input_value.name, input_value.type)
pipelined_input_map = pipelined_value_map[input_value]
if input_value in self._batch_dims:
vs = transformed_function.add_op(
"Split",
name=f"Split/{v.name}",
inputs=[v],
attributes={
"num_splits": self._num_microbatches,
"dim": self._batch_dims[input_value.name],
"dim": self._batch_dims[input_value],
},
output_names=[f"{v.name}s"],
)
for i in range(self._num_microbatches):
v_i = transformed_module.add_op(
v_i = transformed_function.add_op(
"Select",
name=f"Select/{v.name}_{i}",
attributes={"dim": i},
@ -75,8 +71,8 @@ class PipelineParallelTransform:
# Forward the input value(s) if the destination device(s) are not
# the same as the source device.
input_device = input_value.type.device
consumer_ops = module.get_consumers_for_value(input_value.name)
consumer_stages = utils.get_stages_from_op_names(
consumer_ops = function.get_consumers(input_value)
consumer_stages = utils.get_stages_from_ops(
self._op_to_stage_map, consumer_ops
)
consumer_devices = set([self._partition_map[c] for c in consumer_stages])
@ -92,9 +88,9 @@ class PipelineParallelTransform:
# TODO: Propagate these values alongside activations instead of sending them
# ahead of time to be consistent with ORT?
for i in range(self._num_microbatches):
if input_value.name in self._batch_dims or i == 0:
if input_value in self._batch_dims or i == 0:
forwarded_input = self._forward_value(
transformed_module,
transformed_function,
pipelined_input_map[i],
consumer_device,
)
@ -106,7 +102,7 @@ class PipelineParallelTransform:
def _aggregate_outputs(
self,
transformed_module,
transformed_function,
orig_output,
pipelined_output,
merged_output_map,
@ -115,27 +111,21 @@ class PipelineParallelTransform:
"""Aggregates the specified output according to the user-provided reduction parameters.
Args:
transformed_module: The transformed module.
transformed_function: The transformed function.
orig_output: The original version of the output value.
pipelined_output: The transformed (i.e. partitioned) version of the output value.
merged_output_map: A map from original output value name to aggregated output value.
merged_output_map: A map from original output value to aggregated output value.
num_completed_microbatches: The number of microbatches completed so far.
"""
if self._reduction_params[orig_output.name] is None:
if self._reduction_params[orig_output] is None:
# This output does not need to be aggregated.
return
reduction_op_type = self._reduction_params[orig_output.name]["op_type"]
reduction_op_type = self._reduction_params[orig_output]["op_type"]
if num_completed_microbatches == 1:
merged_output_map[orig_output.name] = pipelined_output
merged_output_map[orig_output] = pipelined_output
else:
merged_output = merged_output_map[orig_output.name]
# Forward the output value if necessary.
if merged_output.type.device != pipelined_output.type.device:
pipelined_output = self._forward_value(
transformed_module, pipelined_output, merged_output.type.device
)
merged_output = merged_output_map[orig_output]
# Prepare the reduction op name and output value name.
op_name = (
@ -146,17 +136,17 @@ class PipelineParallelTransform:
else:
output_name = f"{orig_output.name}/merged_{num_completed_microbatches}"
# Add the requested reduction op to the transformed module.
# Add the requested reduction op to the transformed function.
if reduction_op_type == "Add":
merged_output_map[orig_output.name] = transformed_module.add_op(
merged_output_map[orig_output] = transformed_function.add_op(
"Add",
name=op_name,
inputs=[merged_output, pipelined_output],
output_names=[output_name],
)
elif reduction_op_type == "Concat":
dim = self._reduction_params[orig_output.name]["dim"]
merged_output_map[orig_output.name] = transformed_module.add_op(
dim = self._reduction_params[orig_output]["dim"]
merged_output_map[orig_output] = transformed_function.add_op(
"Concat",
attributes={"dim": dim},
name=op_name,
@ -169,49 +159,47 @@ class PipelineParallelTransform:
f"for output {orig_output}"
)
def apply(self, module):
"""Applies the transformation to the module and returns a transformed module."""
def apply(self, function):
"""Applies the transformation to the function and returns a transformed function."""
transformed_module = Module()
transformed_function = FunctionMaker()
# A map from original value name to another map from microbatch number to
# A map from original value to another map from microbatch number to
# pipelined value.
pipelined_value_map = defaultdict(lambda: defaultdict(Value))
pipelined_value_map = defaultdict(lambda: {})
# A map from original output value name to merged output value.
# A map from original output value to merged output value.
merged_output_map = defaultdict(Value)
# Partition the input values for each microbatch.
self._partition_inputs(module, transformed_module, pipelined_value_map)
# Partition the input values according to the number of microbatches.
self._partition_inputs(function, transformed_function, pipelined_value_map)
# Schedule stages on each device in order of increasing timestep.
for timestep in range(len(self._schedule)):
for device in self._schedule[timestep]:
# Look up the next stage to execute according to the schedule
# and add each op in the stage to the transformed module.
# and add each op in the stage to the transformed function.
(stage, microbatch) = self._schedule[timestep][device]
stage_outputs = set([v.name for v in stage.get_outputs()])
for op_name, orig_op in stage.get_ops().items():
orig_inputs = orig_op.get_in_edges()
orig_outputs = orig_op.get_out_edges()
for orig_op in stage.ops:
orig_inputs = orig_op.inputs
orig_outputs = orig_op.outputs
# Collect the pipelined input values for this op.
pipelined_inputs = []
for orig_input in orig_inputs:
pipelined_input_map = pipelined_value_map[orig_input.name]
pipelined_input = pipelined_input_map[microbatch]
pipelined_input_map = pipelined_value_map[orig_input]
pipelined_inputs.append(pipelined_input_map[microbatch])
# Add the pipelined version of the op for the given microbatch to
# the transformed module.
# the transformed function.
pipelined_output_names = [
f"{orig_output.name}_{microbatch}"
for orig_output in orig_outputs
]
pipelined_outputs = transformed_module.add_op(
pipelined_outputs = transformed_function.add_op(
orig_op.op_type,
name=f"{orig_op.name}_{microbatch}",
attributes=orig_op._attributes,
attributes=orig_op.attributes,
inputs=pipelined_inputs,
output_names=pipelined_output_names,
)
@ -223,19 +211,19 @@ class PipelineParallelTransform:
for (orig_output, pipelined_output) in zip(
orig_outputs, pipelined_outputs
):
pipelined_output_map = pipelined_value_map[orig_output.name]
pipelined_output_map = pipelined_value_map[orig_output]
pipelined_output_map[microbatch] = pipelined_output
if orig_output.name not in stage_outputs:
if orig_output not in stage.outputs:
# This output is an intermediate output *within* a stage which does not
# require any additional processing.
continue
elif module.is_output(orig_output.name):
# This output is a module output, which means we need to aggregate it
elif orig_output in function.outputs:
# This output is a function output, which means we need to aggregate it
# with all other corresponding partitioned outputs for each microbatch.
num_completed_microbatches = len(pipelined_output_map)
self._aggregate_outputs(
transformed_module,
transformed_function,
orig_output,
pipelined_output,
merged_output_map,
@ -245,23 +233,21 @@ class PipelineParallelTransform:
# This output is an intermediate stage output, which means we need to
# forward the output to the next stage if the next stage is located on
# a different device.
consumer_ops = module.get_consumers_for_value(
orig_output.name
)
consumer_stages = utils.get_stages_from_op_names(
consumer_ops = function.get_consumers(orig_output)
consumer_stages = utils.get_stages_from_ops(
self._op_to_stage_map, consumer_ops
)
consumer_devices = set(
[self._partition_map[c] for c in consumer_stages]
)
consumer_devices = {
self._partition_map[c] for c in consumer_stages
}
for consumer_device in consumer_devices:
if device != consumer_device:
pipelined_output_map[
microbatch
] = self._forward_value(
transformed_module,
transformed_function,
pipelined_output,
consumer_device,
)
return transformed_module
return transformed_function.finalize()
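A hedged end-to-end sketch of the transform above. It reuses `construct_function_and_partition_map` from the test utilities later in this diff (so it must run from the test directory), and it assumes `FIFOScheduler` is importable from `dist_ir.transforms` and that the scheduler's output matches the `schedule` argument expected here:

```python
import pipeline_parallel_utils as utils  # test helper defined later in this diff
from dist_ir.transforms import PipelineParallelTransform, FIFOScheduler  # scheduler path assumed

function, partition_map = utils.construct_function_and_partition_map()
scheduler = FIFOScheduler(num_microbatches=2)
schedule = scheduler.schedule(function, partition_map)  # [{device: (stage, microbatch)}, ...]

transform = PipelineParallelTransform(
    num_microbatches=2,
    batch_dims={function.inputs[0]: 0, function.inputs[1]: 0},  # x and z
    reduction_params={
        function.outputs[0]: {"op_type": "Concat", "dim": 0},  # l
        function.outputs[1]: {"op_type": "Add"},               # dwB
        function.outputs[2]: {"op_type": "Concat", "dim": 0},  # dx
        function.outputs[3]: {"op_type": "Add"},               # dwA
    },
    partition_map=partition_map,
    schedule=schedule,
)
pipelined = transform.apply(function)  # finalized Function with Split/Select/Send ops
```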


@ -1,27 +1,27 @@
from typing import Dict, Iterable, List
from ..ir import Module
from ..ir import Function, Op
def get_op_to_stage_map(stages: Iterable[Module]) -> Dict[str, Module]:
"""Given a list of stages, returns a map from individual op name to
def get_op_to_stage_map(stages: Iterable[Function]) -> Dict[Op, Function]:
"""Given a list of stages, returns a map from individual op to
encompassing stage."""
op_to_stage = {}
for stage in stages:
for op_name in stage.get_ops():
op_to_stage[op_name] = stage
for op in stage.ops:
op_to_stage[op] = stage
return op_to_stage
def get_stages_from_op_names(
op_to_stage: Dict[str, Module], op_names: Iterable[str]
) -> List[Module]:
"""Given a list of op names and a map from op name to encompassing stage,
def get_stages_from_ops(
op_to_stage: Dict[Op, Function], ops: Iterable[Op]
) -> List[Function]:
"""Given a list of ops and a map from op to encompassing stage,
returns a list of encompassing stages."""
seen = set()
stages = []
for op_name in op_names:
stage = op_to_stage[op_name]
for op in ops:
stage = op_to_stage[op]
if stage not in seen:
seen.add(stage)
stages.append(stage)
return stages


@ -1,7 +1,7 @@
# The DistIR Internal Representation
```
Module = {
Function = {
# Invariant: ops are in topological order
ops: List[Op]
inputs: List[Value]
@ -17,16 +17,16 @@ Device = {
Op = {
name: String
# Do we need this to be unique in a module? Across all modules?
# Do we need this to be unique in a function? Across all functions?
op_type: OpType
# The type of operator
inputs: List[Value]
# Pointer to a Value object either in Module.inputs or another Op.out_edges
# Pointer to a Value object either in Function.inputs or another Op.outputs
outputs: List[Value]
# To support ops that have more than one output
attributes: Dict[String, Any]
# Constant data for ops, e.g. stride of convolution or devices to scatter
submodules: List[Module]
subfunctions: List[Function]
}
OpType =
@ -37,7 +37,7 @@ OpType =
Value = {
name: String
# Again, does it need to be unique in a module?
# Again, does it need to be unique in a function?
type: Type
device: DeviceID
# Which device this value lives on
@ -366,7 +366,7 @@ def mlp(
```
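To make the refactored structures above concrete, here is a hedged construction sketch (mirroring the tests elsewhere in this diff): `FunctionMaker` accumulates ops, and `finalize()` returns a frozen `Function` whose `ops`, `inputs`, and `outputs` are read-only.

```python
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor

d0 = Device(0, "gpu")
fn = FunctionMaker()
x = fn.add_input_value("x", Tensor(Float(), (8, 4), device=d0))
w = fn.add_input_value("w", Tensor(Float(), (4, 2), device=d0))
y = fn.add_op("MatMul", "MatMul0", inputs=[x, w], output_names=["y"])
fn = fn.finalize()  # immutable Function from here on

assert [op.op_type for op in fn.ops] == ["MatMul"]
assert fn.inputs[0].name == "x" and fn.outputs[0].name == "y"
```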
TODO do we want explicit free?
If so, then need to add it to input module after say importing from onnx,
If so, then need to add it to input function after say importing from onnx,
and need a free after last occurrence of every value.
If not, then simulator has to do a liveness analysis. This isn't hard/expensive.
But how do we deal with inplace operations? Have an attribute on those ops?
@ -398,16 +398,16 @@ two layer FF or attention layers?
- Can we do without a special BroadcastScatterOpRegisterEntry?
- Why do we even need types at op creation time?
- Think about best way to create pmap device variable: before/after submodule creation?
- Think about best way to create pmap device variable: before/after subfunction creation?
- Create `HardwareConfiguration` that has all device speeds as well as topology and bandwidth information
- Have a validation pass that checks module is valid (e.g. that inputs are on the same device for matmul)
- Add test that `DistributedSimulator` doesn't modify module
- Have a validation pass that checks function is valid (e.g. that inputs are on the same device for matmul)
- Add test that `DistributedSimulator` doesn't modify function
- Maybe one problem with the return value of a pmap or scatter is that we are expecting things to have types but not shapes at module creation time, and then run a shape inference pass later to fill in the shapes. Would it be cleaner to just infer shapes and fill in a "full" type at op-creation time? Will we ever need to run shape inference later?
- Maybe one problem with the return value of a pmap or scatter is that we are expecting things to have types but not shapes at function creation time, and then run a shape inference pass later to fill in the shapes. Would it be cleaner to just infer shapes and fill in a "full" type at op-creation time? Will we ever need to run shape inference later?
## Problems
- Should a tensor never have a device? What if a pmap's submodule expects a tensor? Then how do we enforce that it is on device d? Is it better to go back to all values have an `Option[Device]`?
- Should a tensor never have a device? What if a pmap's subfunction expects a tensor? Then how do we enforce that it is on device d? Is it better to go back to all values have an `Option[Device]`?


@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"from dist_ir.ir import Module\n",
"from dist_ir.ir import Function\n",
"from dist_ir.ir import Topology\n",
"from dist_ir.ir.type import Float\n",
"from dist_ir.ir.type import Tensor\n",
@ -70,7 +70,7 @@
"outputs": [],
"source": [
"def construct_mlp_module(topology, batch_size, num_classes, input_dim, hidden_dims):\n",
" module = Module()\n",
" module = FunctionMaker()\n",
"\n",
" device = topology.devices[0]\n",
" x = module.add_input_value(\n",
@ -180,7 +180,7 @@
" transform = DataParallelTransform(\n",
" batch_dims={\"x\": 0, \"z\": 0},\n",
" reduction_params={\n",
" f\"{dw.name}\": {\"op_type\": \"Allreduce\"} for dw in module.get_outputs()\n",
" f\"{dw.name}\": {\"op_type\": \"Allreduce\"} for dw in module.outputs\n",
" },\n",
" devices=topology.devices,\n",
" )\n",
@ -365,7 +365,7 @@
" num_layers = len(hidden_dims) + 1\n",
" layers_per_device = num_layers // num_devices\n",
" partition_map = OrderedDict()\n",
" op_names = list(module.get_ops().keys())\n",
" op_names = list(module.ops.keys())\n",
" # Add forward pass stages\n",
" for i in range(num_devices):\n",
" idxs = [i*layers_per_device, (i+1)*layers_per_device] \n",
@ -388,7 +388,7 @@
" num_microbatches=num_microbatches,\n",
" batch_dims={\"x\": 0, \"z\": 0},\n",
" reduction_params={\n",
" f\"{dw.name}\": {\"op_type\": \"Add\"} for dw in module.get_outputs()\n",
" f\"{dw.name}\": {\"op_type\": \"Add\"} for dw in module.outputs\n",
" },\n",
" partition_map=partition_map,\n",
" schedule=schedule,\n",


@ -1,3 +1,4 @@
frozendict >= 1.2
numpy >= 1.19
onnx >= 1.7.0
torch >= 1.6.0


@ -1,53 +1,52 @@
from collections import OrderedDict
from dist_ir.ir import Device, Module
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor
def construct_module_and_partition_map():
module = Module()
def construct_function_and_partition_map():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
batch_size = 16
x = module.add_input_value(
x = function.add_input_value(
"x", Tensor(dtype=Float(), shape=(batch_size, 4), device=d0)
)
z = module.add_input_value(
z = function.add_input_value(
"z", Tensor(dtype=Float(), shape=(batch_size, 1), device=d0)
)
wA = module.add_input_value("wA", Tensor(dtype=Float(), shape=(4, 2), device=d0))
wB = module.add_input_value("wB", Tensor(dtype=Float(), shape=(2, 1), device=d0))
a = module.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
y = module.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
l = module.add_op(
wA = function.add_input_value("wA", Tensor(dtype=Float(), shape=(4, 2), device=d0))
wB = function.add_input_value("wB", Tensor(dtype=Float(), shape=(2, 1), device=d0))
a = function.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
y = function.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
l = function.add_op(
"Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
)
dl = module.add_op(
dl = function.add_op(
"LossGrad",
"LossGrad",
inputs=[y, z],
attributes={"N": batch_size},
output_names=["dl"],
)
da, dwB = module.add_op(
da, dwB = function.add_op(
"MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
)
_, dwA = module.add_op(
_, dwA = function.add_op(
"MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
)
module.set_outputs([l, dwA, dwB])
module.finalize()
function = function.finalize()
stages = [
module.get_submodule(("MatMul0",), name="f0"),
module.get_submodule(("MatMul1", "Loss"), name="f1"),
module.get_submodule(("LossGrad", "MatMul1Grad"), name="b1"),
module.get_submodule(("MatMul0Grad",), name="b0"),
function.get_subfunction(("MatMul0",), name="f0"),
function.get_subfunction(("MatMul1", "Loss"), name="f1"),
function.get_subfunction(("LossGrad", "MatMul1Grad"), name="b1"),
function.get_subfunction(("MatMul0Grad",), name="b0"),
]
partition_map = OrderedDict(
[(stages[0], d0), (stages[1], d1), (stages[2], d1), (stages[3], d0)]
)
return (module, partition_map)
return (function, partition_map)


@ -1,153 +1,52 @@
import numpy as np
from dist_ir.ir import Device, Module
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor
from dist_ir.transforms import DataParallelTransform
from dist_ir.executor import SequentialExecutor
# TODO test on actual inputs using sequential executor
def test_single_variable_partition():
module = Module()
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
module.finalize()
a = function.add_input_value("a", Tensor(Float(), (4, 4)))
b = function.add_input_value("b", Tensor(Float(), (4, 4)))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
function = function.finalize()
transform = DataParallelTransform(
batch_dims={"a": 0},
reduction_params={"x": {"op_type": "Gather", "dim": 0, "device": d0}},
devices=[d0, d1],
)
transformed_module = transform.apply(module)
print("-" * 88)
print("Original module")
print("-" * 88)
print(module)
print()
print("-" * 88)
print("Transformed module")
print("-" * 88)
print(transformed_module)
assert transformed_module.is_op("Scatter/a")
assert transformed_module.is_op("Broadcast/b")
assert transformed_module.is_op("Pmap_#0")
assert transformed_module.is_op("Gather/x")
def test_double_variable_partition():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
c = module.add_input_value("c", Tensor(Float(), (4, 4)))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
module.finalize()
transform = DataParallelTransform(
batch_dims={"a": 0, "c": 0},
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
devices=[d0, d1],
)
transformed_module = transform.apply(module)
print("-" * 88)
print("Original module")
print("-" * 88)
print(module)
print()
print("-" * 88)
print("Transformed module")
print("-" * 88)
print(transformed_module)
assert transformed_module.is_op("Scatter/a")
assert transformed_module.is_op("Broadcast/b")
assert transformed_module.is_op("Scatter/c")
assert transformed_module.is_op("Pmap_#0")
assert transformed_module.is_op("Gather/y")
def test_mnist():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
batch_size = 16
x = module.add_input_value("x", Tensor(Float(), (batch_size, 4)))
z = module.add_input_value("z", Tensor(Float(), (batch_size, 1)))
wA = module.add_input_value("wA", Tensor(Float(), (4, 2)))
wB = module.add_input_value("wB", Tensor(Float(), (2, 1)))
a = module.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
y = module.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
l = module.add_op(
"Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
)
dl = module.add_op(
"LossGrad",
"LossGrad",
inputs=[y, z],
attributes={"N": batch_size},
output_names=["dl"],
)
da, dwB = module.add_op(
"MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
)
dx, dwA = module.add_op(
"MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
)
module.set_outputs([l, dwA, dwB])
module.finalize()
transform = DataParallelTransform(
batch_dims={"x": 0, "z": 0},
batch_dims={function.inputs[0]: 0},
reduction_params={
"l": {"op_type": "Gather", "dim": 0, "device": d0},
"dx": {"op_type": "Gather", "dim": 0, "device": d0},
"dwA": {"op_type": "Allreduce"},
"dwB": {"op_type": "Allreduce"},
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
},
devices=[d0, d1],
)
transformed_module = transform.apply(module)
transformed_module.finalize()
transformed_function = transform.apply(function)
print("-" * 88)
print("Original module")
print("Original function")
print("-" * 88)
print(module)
print(function)
print()
print("-" * 88)
print("Transformed module")
print("Transformed function")
print("-" * 88)
print(transformed_module)
print(transformed_function)
ex = SequentialExecutor("numpy")
_x = np.arange(batch_size * 4).reshape((batch_size, 4))
_z = np.ones((batch_size, 1))
_wA = np.ones((4, 2))
_wB = np.ones((2, 1))
orig_res = ex.compute(
module,
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
)
_a = np.ones((4, 4))
_b = np.ones((4, 4))
orig_res = ex.compute(function, {function.inputs[0]: _a, function.inputs[1]: _b})
transformed_res = ex.compute(
transformed_module,
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
transformed_function,
{transformed_function.inputs[0]: _a, transformed_function.inputs[1]: _b},
)
print("-" * 88)
print("Original module results")
print("Original function results")
print("-" * 88)
for k, v in orig_res.items():
print(k)
@ -155,29 +54,193 @@ def test_mnist():
print()
print()
print("-" * 88)
print("Transformed module results")
print("Transformed function results")
print("-" * 88)
for k, v in transformed_res.items():
print(k)
print(v)
print()
assert np.array_equal(orig_res["l"], np.concatenate(transformed_res["ls"], axis=0))
assert np.array_equal(orig_res["dwA"], transformed_res["dwAs"][0])
assert np.array_equal(orig_res["dwB"], transformed_res["dwBs"][0])
"""
assert transformed_module.is_op("Scatter/x")
assert transformed_module.is_op("Scatter/z")
assert transformed_module.is_op("Broadcast/wA")
assert transformed_module.is_op("Broadcast/wB")
assert transformed_module.is_op("Pmap_#0")
assert transformed_module.is_op("Gather/l")
assert transformed_module.is_op("Gather/dx")
assert transformed_module.is_op("Allreduce/dwA")
assert transformed_module.is_op("Allreduce/dwB")
"""
np.testing.assert_array_almost_equal(
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
)
if __name__ == "__main__":
test_mnist()
def test_double_variable_partition():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
a = function.add_input_value("a", Tensor(Float(), (4, 4)))
b = function.add_input_value("b", Tensor(Float(), (4, 4)))
c = function.add_input_value("c", Tensor(Float(), (4, 4)))
x = function.add_op("MatMul", "MatMul", inputs=[a, b], output_names=["x"])
y = function.add_op("Add", "Add", inputs=[x, c], output_names=["y"])
function = function.finalize()
transform = DataParallelTransform(
batch_dims={function.inputs[0]: 0, function.inputs[2]: 0},
reduction_params={
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
},
devices=[d0, d1],
)
transformed_function = transform.apply(function)
print("-" * 88)
print("Original function")
print("-" * 88)
print(function)
print()
print("-" * 88)
print("Transformed function")
print("-" * 88)
print(transformed_function)
ex = SequentialExecutor("numpy")
_a = np.ones((4, 4))
_b = np.ones((4, 4))
_c = np.ones((4, 4))
orig_res = ex.compute(
function,
{function.inputs[0]: _a, function.inputs[1]: _b, function.inputs[2]: _c},
)
transformed_res = ex.compute(
transformed_function,
{
transformed_function.inputs[0]: _a,
transformed_function.inputs[1]: _b,
transformed_function.inputs[2]: _c,
},
)
print("-" * 88)
print("Original function results")
print("-" * 88)
for k, v in orig_res.items():
print(k)
print(v)
print()
print()
print("-" * 88)
print("Transformed function results")
print("-" * 88)
for k, v in transformed_res.items():
print(k)
print(v)
print()
np.testing.assert_array_almost_equal(
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
)
def test_mnist():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
batch_size = 16
x = function.add_input_value("x", Tensor(Float(), (batch_size, 4)))
z = function.add_input_value("z", Tensor(Float(), (batch_size, 1)))
wA = function.add_input_value("wA", Tensor(Float(), (4, 2)))
wB = function.add_input_value("wB", Tensor(Float(), (2, 1)))
a = function.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
y = function.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
l = function.add_op(
"Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
)
dl = function.add_op(
"LossGrad",
"LossGrad",
inputs=[y, z],
attributes={"N": batch_size},
output_names=["dl"],
)
da, dwB = function.add_op(
"MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
)
dx, dwA = function.add_op(
"MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
)
function = function.finalize()
transform = DataParallelTransform(
batch_dims={function.inputs[0]: 0, function.inputs[1]: 0},
reduction_params={
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0},
function.outputs[1]: {"op_type": "Allreduce"},
function.outputs[2]: {"op_type": "Gather", "dim": 0, "device": d0},
function.outputs[3]: {"op_type": "Allreduce"},
},
devices=[d0, d1],
)
transformed_function = transform.apply(function)
print("-" * 88)
print("Original function")
print("-" * 88)
print(function)
print()
print("-" * 88)
print("Transformed function")
print("-" * 88)
print(transformed_function)
ex = SequentialExecutor("numpy")
_x = np.arange(batch_size * 4).reshape((batch_size, 4))
_z = np.ones((batch_size, 1))
_wA = np.ones((4, 2))
_wB = np.ones((2, 1))
orig_res = ex.compute(
function,
{
function.inputs[0]: _x,
function.inputs[1]: _z,
function.inputs[2]: _wA,
function.inputs[3]: _wB,
},
)
transformed_res = ex.compute(
transformed_function,
{
transformed_function.inputs[0]: _x,
transformed_function.inputs[1]: _z,
transformed_function.inputs[2]: _wA,
transformed_function.inputs[3]: _wB,
},
)
print("-" * 88)
print("Original function results")
print("-" * 88)
for k, v in orig_res.items():
print(k)
print(v)
print()
print()
print("-" * 88)
print("Transformed function results")
print("-" * 88)
for k, v in transformed_res.items():
print(k)
print(v)
print()
np.testing.assert_array_almost_equal(
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
)
np.testing.assert_array_almost_equal(
orig_res[function.outputs[1]],
transformed_res[transformed_function.outputs[1]][0],
)
np.testing.assert_array_almost_equal(
orig_res[function.outputs[2]],
transformed_res[transformed_function.outputs[2]],
)
np.testing.assert_array_almost_equal(
orig_res[function.outputs[3]],
transformed_res[transformed_function.outputs[3]][0],
)


@ -1,58 +1,66 @@
from dist_ir.ir import Module
from dist_ir.ir import FunctionMaker
from dist_ir.ir import Topology
from dist_ir.ir.type import Float
from dist_ir.ir.type import Tensor
from dist_ir.executor.cost_inference import CostModel
from dist_ir.executor.type_inference import infer_types
from dist_ir.executor import DistributedSimulator
from dist_ir.transforms import DataParallelTransform
def test_single_device():
module = Module()
function = FunctionMaker()
topology = Topology()
d = topology.add_device("gpu")
a = module.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
b = module.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b])
module.finalize()
a = function.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
b = function.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b])
function = function.finalize()
function = infer_types(function, [a, b])
device_speeds = {"gpu": 1.0e13}
# TODO shouldn't device_speeds be set in the topology?
cost_model = CostModel(topology, device_speeds)
simulator = DistributedSimulator(cost_model)
simulator_state = simulator.simulate(module)
simulator_state = simulator.simulate(function)
assert d in simulator_state.timestamps
assert d in simulator_state.peak_memory
# TODO: Check specific values
def test_data_parallel():
module = Module()
function = FunctionMaker()
topology = Topology()
d0 = topology.add_device("gpu")
d1 = topology.add_device("gpu")
topology.set_bandwidth(d0, d1, 2)
a = module.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
b = module.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
c = module.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
module.finalize()
a = function.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
b = function.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
c = function.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
y = function.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
function = function.finalize()
function = infer_types(function, [a, b, c])
transform = DataParallelTransform(
batch_dims={"a": 0},
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
batch_dims={function.inputs[0]: 0},
reduction_params={
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
},
devices=[d0, d1],
)
transformed_module = transform.apply(module)
transformed_function = transform.apply(function)
transformed_function = infer_types(
transformed_function, transformed_function.inputs
)
transformed_module.finalize()
print(transformed_module)
print(transformed_function)
device_speeds = {"gpu": 1.0e13}
cost_model = CostModel(topology, device_speeds)
simulator = DistributedSimulator(cost_model)
simulator_state = simulator.simulate(transformed_module)
simulator_state = simulator.simulate(transformed_function)
assert d0 in simulator_state.timestamps
assert d1 in simulator_state.timestamps
assert d0 in simulator_state.peak_memory
@ -61,31 +69,36 @@ def test_data_parallel():
def test_chrome_trace():
module = Module()
function = FunctionMaker()
topology = Topology()
d0 = topology.add_device("gpu")
d1 = topology.add_device("gpu")
topology.set_bandwidth(d0, d1, 2)
a = module.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
b = module.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
c = module.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
module.finalize()
a = function.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
b = function.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
c = function.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
y = function.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
function = function.finalize()
function = infer_types(function, [a, b, c])
device_speeds = {"gpu": 1.0e13}
cost_model = CostModel(topology, device_speeds)
simulator = DistributedSimulator(cost_model)
transform = DataParallelTransform(
batch_dims={"a": 0},
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
batch_dims={function.inputs[0]: 0},
reduction_params={
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
},
devices=[d0, d1],
)
transformed_module = transform.apply(module)
transformed_module.finalize()
transformed_function = transform.apply(function)
transformed_function = infer_types(
transformed_function, transformed_function.inputs
)
simulation = simulator.simulate(transformed_module)
simulation = simulator.simulate(transformed_function)
simulation.dump_chrome_trace("test/trace.json")
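    # The dump is in Chrome trace-event JSON format (as the method name suggests),
    # so test/trace.json can be inspected by loading it into chrome://tracing or
    # the Perfetto UI.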


@@ -37,15 +37,15 @@ def test_parser():
return %y: !dist.tensor<8x6xf32, 0>
}
"""
modules = mlir_parser.parse_mlir_str(mlir_str)
assert len(modules) == 1
module = modules[0]
cpprint(module)
functions = mlir_parser.parse_mlir_str(mlir_str)
assert len(functions) == 1
function = functions[0]
cpprint(function)
ex = SequentialExecutor("numpy")
_wA = np.ones((4, 6))
_x = np.arange(8 * 4).reshape((8, 4))
res = ex.compute(module, {"%arg1": _x, "%arg0": _wA})
res = ex.compute(function, {function.inputs[0]: _wA, function.inputs[1]: _x})
# TODO fix concat's implementation in numpy register for this:
# assert np.array_equal(res["%var4"], np.matmul(_x, _wA))
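    # With the refactored value-keyed results, that check would presumably index
    # res by the Value object for %var4 rather than by its string name.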


@@ -3,10 +3,10 @@ import pipeline_parallel_utils as utils
def test_fifo_scheduler():
(module, partition_map) = utils.construct_module_and_partition_map()
(function, partition_map) = utils.construct_function_and_partition_map()
(d0, d1) = sorted(set(partition_map.values()))
scheduler = FIFOScheduler(num_microbatches=2)
schedule = scheduler.schedule(module, partition_map)
schedule = scheduler.schedule(function, partition_map)
stages = list(partition_map.keys())
ref_schedule = [
@@ -22,10 +22,10 @@ def test_fifo_scheduler():
def test_pipedream_scheduler():
(module, partition_map) = utils.construct_module_and_partition_map()
(function, partition_map) = utils.construct_function_and_partition_map()
(d0, d1) = sorted(set(partition_map.values()))
scheduler = PipeDreamScheduler(num_microbatches=2)
schedule = scheduler.schedule(module, partition_map)
schedule = scheduler.schedule(function, partition_map)
stages = list(partition_map.keys())
ref_schedule = [
@@ -38,3 +38,7 @@ def test_pipedream_scheduler():
]
assert schedule == ref_schedule
if __name__ == "__main__":
test_fifo_scheduler()


@@ -1,6 +1,6 @@
import numpy as np
from dist_ir.ir import Device, Module
from dist_ir.ir import Device, Function
from dist_ir.ir.type import Float, Tensor
from dist_ir.transforms import PipelineParallelTransform
from dist_ir.executor import SequentialExecutor
@@ -8,9 +8,8 @@ import pipeline_parallel_utils as utils
def test_mnist_fw_bw():
(module, partition_map) = utils.construct_module_and_partition_map()
(function, partition_map) = utils.construct_function_and_partition_map()
(d0, d1) = sorted(set(partition_map.values()))
stages = list(partition_map.keys())
schedule = [
{d0: (stages[0], 0)},
@@ -22,27 +21,27 @@ def test_mnist_fw_bw():
]
transform = PipelineParallelTransform(
num_microbatches=2,
batch_dims={"x": 0, "z": 0},
batch_dims={function.inputs[0]: 0, function.inputs[1]: 0},
reduction_params={
"dwB": {"op_type": "Add"},
"dwA": {"op_type": "Add"},
"l": {"op_type": "Concat", "dim": 0},
function.outputs[0]: {"op_type": "Concat", "dim": 0}, # l
function.outputs[1]: {"op_type": "Add"}, # dwB
function.outputs[2]: {"op_type": "Concat", "dim": 0}, # dx
function.outputs[3]: {"op_type": "Add"}, # dwA
},
partition_map=partition_map,
schedule=schedule,
)
transformed_module = transform.apply(module)
transformed_module.finalize()
transformed_function = transform.apply(function)
print("-" * 88)
print("Original module")
print("Original function")
print("-" * 88)
print(module)
print(function)
print()
print("-" * 88)
print("Transformed module")
print("Transformed function")
print("-" * 88)
print(transformed_module)
print(transformed_function)
batch_size = 16
ex = SequentialExecutor("numpy")
@@ -51,17 +50,27 @@ def test_mnist_fw_bw():
_wA = np.ones((4, 2))
_wB = np.ones((2, 1))
orig_res = ex.compute(
module,
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
function,
{
function.inputs[0]: _x,
function.inputs[1]: _z,
function.inputs[2]: _wA,
function.inputs[3]: _wB,
},
)
transformed_res = ex.compute(
transformed_module,
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
transformed_function,
{
transformed_function.inputs[0]: _x,
transformed_function.inputs[1]: _z,
transformed_function.inputs[2]: _wA,
transformed_function.inputs[3]: _wB,
},
)
print("-" * 88)
print("Original module results")
print("Original function results")
print("-" * 88)
for k, v in orig_res.items():
print(k)
@@ -69,16 +78,18 @@ def test_mnist_fw_bw():
print()
print()
print("-" * 88)
print("Transformed module results")
print("Transformed function results")
print("-" * 88)
for k, v in transformed_res.items():
print(k)
print(v)
print()
assert np.array_equal(orig_res["l"], transformed_res["l"])
assert np.array_equal(orig_res["dwA"], transformed_res["dwA"])
assert np.array_equal(orig_res["dwB"], transformed_res["dwB"])
for i in range(len(function.outputs)):
np.testing.assert_array_almost_equal(
orig_res[function.outputs[i]],
transformed_res[transformed_function.outputs[i]],
)
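    # The tolerance-based comparison (rather than the exact equality used before)
    # is presumably needed because microbatching can change the floating-point
    # accumulation order of the reduced outputs.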
if __name__ == "__main__":


@@ -1,27 +1,27 @@
from pathlib import Path
from dist_ir.importer import import_from_onnx
from dist_ir.ir import Module, Topology
from dist_ir.ir import FunctionMaker, Topology
from dist_ir.ir.type import Float, Tensor
from dist_ir.ir import cpprint
def test_cpprint():
module = Module()
function = FunctionMaker()
topology = Topology()
d = topology.add_device("gpu")
a = module.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
b = module.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b])
y = module.add_op("MatMul", "MatMul1", inputs=[x, b])
module.finalize()
a = function.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
b = function.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b])
y = function.add_op("MatMul", "MatMul1", inputs=[x, b])
function.finalize()
cpprint(module)
cpprint(function)
def test_import_from_onnx():
onnx_model_path = Path(__file__).parent / "mnist_gemm_bw_running.onnx"
module = import_from_onnx(onnx_model_path)
cpprint(module)
function = import_from_onnx(onnx_model_path)
cpprint(function)


@@ -2,20 +2,19 @@ import numpy as np
import pytest
import torch
from dist_ir.ir import Device, Module
from dist_ir.ir import Device, FunctionMaker, cpprint
from dist_ir.ir.type import Float, Tensor, TupleType
from dist_ir.executor import SequentialExecutor
from dist_ir.executor.shape_inference import infer_shapes
class Helper:
def __init__(self, backend):
self.backend = backend
self.executor = SequentialExecutor(self.backend)
self.module = Module()
self.t1 = self.module.add_input_value("a", Tensor(Float(), (4, 4)))
self.t2 = self.module.add_input_value("b", Tensor(Float(), (4, 4)))
self.t3 = self.module.add_input_value("c", Tensor(Float(), (4, 4)))
self.function = FunctionMaker()
self.a = self.function.add_input_value("a", Tensor(Float(), (4, 4)))
self.b = self.function.add_input_value("b", Tensor(Float(), (4, 4)))
self.c = self.function.add_input_value("c", Tensor(Float(), (4, 4)))
if self.backend == "numpy":
a = np.random.normal(size=(4, 4))
b = np.random.normal(size=(4, 4))
@@ -27,9 +26,9 @@ class Helper:
else:
raise ValueError(f"Unknown backend {self.backend}")
self.input_data = {
"a": a,
"b": b,
"c": c,
self.a: a,
self.b: b,
self.c: c,
}
print(f"Backend: {self.backend}")
@@ -41,118 +40,122 @@ def backend(request):
def test_single_add(backend):
h = Helper(backend)
h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["Add_0/0"]
res = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(result, np.add(h.input_data["a"], h.input_data["b"]))
assert np.array_equal(result, np.add(h.input_data[h.a], h.input_data[h.b]))
elif h.backend == "torch":
assert result.equal(torch.add(h.input_data["a"], h.input_data["b"]))
assert result.equal(torch.add(h.input_data[h.a], h.input_data[h.b]))
def test_double_add(backend):
h = Helper(backend)
x = h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
h.module.add_op("Add", "Add_1", inputs=[h.t3, x])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["Add_1/0"]
x = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
res = h.function.add_op("Add", "Add_1", inputs=[h.c, x])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(
result,
np.add(h.input_data["c"], np.add(h.input_data["a"], h.input_data["b"])),
np.add(h.input_data[h.c], np.add(h.input_data[h.a], h.input_data[h.b])),
)
elif h.backend == "torch":
assert result.equal(
torch.add(
h.input_data["c"],
torch.add(h.input_data["a"], h.input_data["b"]),
h.input_data[h.c],
torch.add(h.input_data[h.a], h.input_data[h.b]),
)
)
def test_double_add_inverted(backend):
h = Helper(backend)
x = h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
h.module.add_op("Add", "Add_1", inputs=[x, h.t3])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["Add_1/0"]
x = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
res = h.function.add_op("Add", "Add_1", inputs=[x, h.c])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(
result,
np.add(np.add(h.input_data["a"], h.input_data["b"]), h.input_data["c"]),
np.add(np.add(h.input_data[h.a], h.input_data[h.b]), h.input_data[h.c]),
)
elif h.backend == "torch":
assert result.equal(
torch.add(
torch.add(h.input_data["a"], h.input_data["b"]),
h.input_data["c"],
torch.add(h.input_data[h.a], h.input_data[h.b]),
h.input_data[h.c],
)
)
def test_single_matmul(backend):
h = Helper(backend)
h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["MatMul_0/0"]
res = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(result, np.matmul(h.input_data["a"], h.input_data["b"]))
assert np.array_equal(result, np.matmul(h.input_data[h.a], h.input_data[h.b]))
elif h.backend == "torch":
assert result.equal(torch.matmul(h.input_data["a"], h.input_data["b"]))
assert result.equal(torch.matmul(h.input_data[h.a], h.input_data[h.b]))
def test_double_matmul(backend):
h = Helper(backend)
x = h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
h.module.add_op("MatMul", "MatMul_1", inputs=[h.t3, x])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["MatMul_1/0"]
x = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
res = h.function.add_op("MatMul", "MatMul_1", inputs=[h.c, x])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(
result,
np.matmul(
h.input_data["c"], np.matmul(h.input_data["a"], h.input_data["b"])
h.input_data[h.c], np.matmul(h.input_data[h.a], h.input_data[h.b])
),
)
elif h.backend == "torch":
assert result.equal(
torch.matmul(
h.input_data["c"],
torch.matmul(h.input_data["a"], h.input_data["b"]),
h.input_data[h.c],
torch.matmul(h.input_data[h.a], h.input_data[h.b]),
)
)
def test_double_matmul_inverted(backend):
h = Helper(backend)
x = h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
h.module.add_op("MatMul", "MatMul_1", inputs=[x, h.t3])
h.module.finalize()
output_data = h.executor.compute(h.module, h.input_data)
result = output_data["MatMul_1/0"]
x = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
res = h.function.add_op("MatMul", "MatMul_1", inputs=[x, h.c])
h.function = h.function.finalize()
output_data = h.executor.compute(h.function, h.input_data)
result = output_data[res]
if h.backend == "numpy":
assert np.array_equal(
result,
np.matmul(
np.matmul(h.input_data["a"], h.input_data["b"]), h.input_data["c"]
np.matmul(h.input_data[h.a], h.input_data[h.b]), h.input_data[h.c]
),
)
elif h.backend == "torch":
assert result.equal(
torch.matmul(
torch.matmul(h.input_data["a"], h.input_data["b"]),
h.input_data["c"],
torch.matmul(h.input_data[h.a], h.input_data[h.b]),
h.input_data[h.c],
)
)
# TODO: Add test for op with multiple outputs
# TODO for all pmap tests, make a FunctionMaker helper function to add pmap
# which also creates the device var and sets the attributes etc appropriately.
# This should also be used by transforms/parsers that create pmap ops.
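# A hedged sketch of the helper described in the TODO above; the name and signature
# are hypothetical, not part of this change, but the calls it makes
# (Device.get_new_device_variable and FunctionMaker.add_op with a device_var
# attribute) both appear elsewhere in this diff.
def add_pmap(function, inputs, devices, subfunction, output_names):
    # Create the device variable alongside the op so that every caller need not.
    device_var = Device.get_new_device_variable("gpu")
    return function.add_op(
        "Pmap",
        inputs=inputs,
        attributes={"devices": devices, "device_var": device_var},
        subfunctions=[subfunction],
        output_names=output_names,
    )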
def test_pmap_on_executor():
d0 = Device(0, "gpu")
@@ -169,137 +172,142 @@ def test_pmap_on_executor():
_y_0, _y_1 = _y[:4], _y[4:]
# A pmap with 1 input and 1 output
module = Module()
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
submodule = Module()
x = submodule.add_input_value("x", x_type(None))
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["z"])
# submodule.set_outputs()
submodule.finalize()
_ = module.add_op(
function = FunctionMaker()
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", x_type(None))
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["z"])
# subfunction.set_outputs()
subfunction = subfunction.finalize()
zis = function.add_op(
"Pmap",
inputs=[xs],
attributes={"devices": [d0, d1]},
submodules=[submodule],
subfunctions=[subfunction],
output_names=["zis"],
)
module.finalize()
function = function.finalize()
res = ex.compute(module, {"xs": (_x_0, _x_1)})
assert np.array_equal(res["zis"][0], _x_0 + _x_0)
assert np.array_equal(res["zis"][1], _x_1 + _x_1)
cpprint(function)
res = ex.compute(function, {xs: (_x_0, _x_1)})
assert np.array_equal(res[zis][0], _x_0 + _x_0)
assert np.array_equal(res[zis][1], _x_1 + _x_1)
# A pmap with 2 inputs and 1 output
module = Module()
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
ys = module.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
submodule = Module()
x = submodule.add_input_value("x", x_type(None))
y = submodule.add_input_value("y", y_type(None))
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
submodule.finalize()
_ = module.add_op(
function = FunctionMaker()
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
ys = function.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", x_type(None))
y = subfunction.add_input_value("y", y_type(None))
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
subfunction = subfunction.finalize()
zis = function.add_op(
"Pmap",
inputs=[xs, ys],
attributes={"devices": [d0, d1]},
submodules=[submodule],
subfunctions=[subfunction],
output_names=["zis"],
)
module.finalize()
function = function.finalize()
res = ex.compute(module, {"xs": (_x_0, _x_1), "ys": (_y_0, _y_1)})
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
assert np.array_equal(res["zis"][1], np.matmul(_x_1, _y_1))
cpprint(function)
res = ex.compute(function, {xs: (_x_0, _x_1), ys: (_y_0, _y_1)})
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
assert np.array_equal(res[zis][1], np.matmul(_x_1, _y_1))
# A pmap with 2 inputs and 2 outputs
module = Module()
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
ys = module.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
submodule = Module()
x = submodule.add_input_value("x", x_type(None))
y = submodule.add_input_value("y", y_type(None))
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
submodule.finalize()
_ = module.add_op(
function = FunctionMaker()
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
ys = function.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", x_type(None))
y = subfunction.add_input_value("y", y_type(None))
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
subfunction = subfunction.finalize()
(wis, zis) = function.add_op(
"Pmap",
inputs=[xs, ys],
attributes={"devices": [d0, d1]},
submodules=[submodule],
subfunctions=[subfunction],
output_names=["wis", "zis"],
)
module.finalize()
function = function.finalize()
res = ex.compute(module, {"xs": (_x_0, _x_1), "ys": (_y_0, _y_1)})
assert np.array_equal(res["wis"][0], _x_0 + _x_0)
assert np.array_equal(res["wis"][1], _x_1 + _x_1)
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
assert np.array_equal(res["zis"][1], np.matmul(_x_1, _y_1))
cpprint(function)
res = ex.compute(function, {xs: (_x_0, _x_1), ys: (_y_0, _y_1)})
assert np.array_equal(res[wis][0], _x_0 + _x_0)
assert np.array_equal(res[wis][1], _x_1 + _x_1)
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
assert np.array_equal(res[zis][1], np.matmul(_x_1, _y_1))
# A pmap with a single device
module = Module()
xs = module.add_input_value("xs", TupleType((x_type(None),)))
ys = module.add_input_value("ys", TupleType((y_type(None),)))
submodule = Module()
x = submodule.add_input_value("x", x_type(None))
y = submodule.add_input_value("y", y_type(None))
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
submodule.finalize()
_ = module.add_op(
function = FunctionMaker()
xs = function.add_input_value("xs", TupleType((x_type(None),)))
ys = function.add_input_value("ys", TupleType((y_type(None),)))
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", x_type(None))
y = subfunction.add_input_value("y", y_type(None))
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
subfunction = subfunction.finalize()
(wis, zis) = function.add_op(
"Pmap",
inputs=[xs, ys],
attributes={"devices": [d0]},
submodules=[submodule],
subfunctions=[subfunction],
output_names=["wis", "zis"],
)
module.finalize()
function = function.finalize()
res = ex.compute(module, {"xs": (_x_0,), "ys": (_y_0,)})
assert np.array_equal(res["wis"][0], _x_0 + _x_0)
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
cpprint(function)
res = ex.compute(function, {xs: (_x_0,), ys: (_y_0,)})
assert np.array_equal(res[wis][0], _x_0 + _x_0)
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
def test_pmap_dp():
module = Module()
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
xs = module.add_input_value(
xs = function.add_input_value(
"xs",
TupleType(
(Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
),
)
wAs = module.add_input_value(
wAs = function.add_input_value(
"wAs",
TupleType(
(Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
),
)
wBs = module.add_input_value(
wBs = function.add_input_value(
"wBs",
TupleType(
(Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
),
)
submodule = Module()
x = submodule.add_input_value("x", Tensor(Float(), (8, 4)))
wA = submodule.add_input_value("wA", Tensor(Float(), (4, 2)))
wB = submodule.add_input_value("wB", Tensor(Float(), (2, 1)))
y = submodule.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
_ = submodule.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
submodule.finalize()
_ = module.add_op(
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", Tensor(Float(), (8, 4)))
wA = subfunction.add_input_value("wA", Tensor(Float(), (4, 2)))
wB = subfunction.add_input_value("wB", Tensor(Float(), (2, 1)))
y = subfunction.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
_ = subfunction.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
subfunction = subfunction.finalize()
zis = function.add_op(
"Pmap",
inputs=[xs, wAs, wBs],
attributes={"devices": [d0, d1]},
submodules=[submodule],
subfunctions=[subfunction],
output_names=["zis"],
)
module.finalize()
function = function.finalize()
cpprint(function)
ex = SequentialExecutor("numpy")
_x = np.arange(16 * 4).reshape((16, 4))
@@ -307,12 +315,8 @@ def test_pmap_dp():
_wA = np.ones((4, 2))
_wB = np.ones((2, 1))
res = ex.compute(
module,
{"xs": (x_0, x_1), "wAs": (_wA, _wA), "wBs": (_wB, _wB)},
function,
{xs: (x_0, x_1), wAs: (_wA, _wA), wBs: (_wB, _wB)},
)
assert np.array_equal(res["zis"][0], np.matmul(np.matmul(x_0, _wA), _wB))
assert np.array_equal(res["zis"][1], np.matmul(np.matmul(x_1, _wA), _wB))
if __name__ == "__main__":
test_pmap_on_executor()
assert np.array_equal(res[zis][0], np.matmul(np.matmul(x_0, _wA), _wB))
assert np.array_equal(res[zis][1], np.matmul(np.matmul(x_1, _wA), _wB))


@@ -1,190 +0,0 @@
import pytest
from dist_ir.ir import Device, Module
from dist_ir.executor.shape_inference import infer_shapes
from dist_ir.ir.type import Float, Tensor, TupleType
def test_add_valid():
module = Module()
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
x = module.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
infer_shapes(module)
assert x.type.shape == (4, 4)
def test_add_invalid():
module = Module()
a = module.add_input_value("a", Tensor(Float(), (8, 4)))
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
x = module.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
with pytest.raises(ValueError):
infer_shapes(module)
def test_allreduce():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
xis = module.add_input_value(
"xis",
TupleType(
(Tensor(Float(), (4, 4), device=d0), Tensor(Float(), (4, 4), device=d1))
),
)
xs = module.add_op(
"Allreduce",
"Allreduces/xis",
inputs=[xis],
output_names=["xs"],
)
infer_shapes(module)
assert isinstance(xs.type, TupleType)
for i, value_type in enumerate(xis.type.types):
assert value_type.shape == xs.type.types[i].shape
assert value_type.device == xs.type.types[i].device
def test_broadcast():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
x = module.add_input_value("x", Tensor(Float(), (4, 4)))
xs = module.add_op(
"Broadcast",
"Broadcast/x",
inputs=[x],
attributes={"devices": [d0, d1]},
output_names=["xs"],
)
infer_shapes(module)
assert isinstance(xs.type, TupleType)
assert xs.type.types[0].shape == (4, 4)
assert xs.type.types[0].device == d0
assert xs.type.types[1].shape == (4, 4)
assert xs.type.types[1].device == d1
def test_matmul_valid():
module = Module()
a = module.add_input_value("a", Tensor(Float(), (8, 4)))
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
infer_shapes(module)
assert x.type.shape == (8, 2)
def test_matmul_invalid():
module = Module()
a = module.add_input_value("a", Tensor(Float(), (8, 8)))
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
with pytest.raises(ValueError):
infer_shapes(module)
def test_matmul_grad():
module = Module()
x = module.add_input_value("x", Tensor(Float(), (8, 4)))
w = module.add_input_value("w", Tensor(Float(), (4, 2)))
l = module.add_input_value("l", Tensor(Float(), (8,)))
dx, dw = module.add_op(
"MatMulGrad", "MatMulGrad0", inputs=[x, w, l], output_names=["dx", "dw"]
)
infer_shapes(module)
assert dx.type.shape == x.type.shape
assert dw.type.shape == w.type.shape
def test_pmap():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
xs = module.add_input_value(
"xs",
TupleType(
(Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
),
)
wAs = module.add_input_value(
"wAs",
TupleType(
(Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
),
)
wBs = module.add_input_value(
"wBs",
TupleType(
(Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
),
)
submodule = Module()
# TODO: Add a check in shape inference to not overwrite if there is already a shape
# If there is an existing shape, validate that it is what we expect, otherwise throw error
x = submodule.add_input_value("x", Tensor(Float(), (8, 4)))
wA = submodule.add_input_value("wA", Tensor(Float(), (4, 2)))
wB = submodule.add_input_value("wB", Tensor(Float(), (2, 1)))
y = submodule.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
z = submodule.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
submodule.finalize()
zis = module.add_op(
"Pmap",
inputs=[xs, wAs, wBs],
attributes={"devices": [d0, d1]},
submodules=[submodule],
output_names=["zis"],
)
infer_shapes(module)
print(module)
# TODO: Verify submodule shapes and devices
assert zis.type.types[0].shape == (8, 1)
assert zis.type.types[0].device == d0
assert zis.type.types[1].shape == (8, 1)
assert zis.type.types[1].device == d1
def test_scatter():
module = Module()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
x = module.add_input_value("x", Tensor(Float(), (4, 4)))
xs = module.add_op(
"Scatter",
"Scatter/x",
inputs=[x],
attributes={"dim": 0, "devices": [d0, d1]},
output_names=["xs"],
)
infer_shapes(module)
assert isinstance(xs.type, TupleType)
assert xs.type.types[0].shape == (2, 4)
assert xs.type.types[0].device == d0
assert xs.type.types[1].shape == (2, 4)
assert xs.type.types[1].device == d1
if __name__ == "__main__":
test_pmap()

test/test_subfunction.py (new file)

@@ -0,0 +1,45 @@
from dist_ir.ir import FunctionMaker
from dist_ir.ir.type import Tensor, Float
def test_subfunction():
function = FunctionMaker()
inputs = []
outputs = []
num_ops = 9
for i in range(num_ops + 1):
inputs.append(function.add_input_value(f"x{i}", Tensor(Float(), (4, 4))))
for i in range(num_ops):
if i == 0:
input_values = inputs[:2]
else:
input_values = [outputs[-1], inputs[i + 1]]
outputs.append(
function.add_op(
"Add", f"Add{i}", inputs=input_values, output_names=[f"a{i}"]
)
)
function = function.finalize()
subfunction = function.get_subfunction(("Add0", "Add1", "Add2"))
subfunction_inputs = subfunction.inputs
subfunction_outputs = subfunction.outputs
assert [v.name for v in subfunction_inputs] == ["x0", "x1", "x2", "x3"]
assert [v.name for v in subfunction_outputs] == ["a2"]
subfunction = function.get_subfunction(("Add3", "Add4", "Add5"))
subfunction_inputs = subfunction.inputs
subfunction_outputs = subfunction.outputs
assert [v.name for v in subfunction_inputs] == ["a2", "x4", "x5", "x6"]
assert [v.name for v in subfunction_outputs] == ["a5"]
subfunction = function.get_subfunction(("Add6", "Add7", "Add8"))
subfunction_inputs = subfunction.inputs
subfunction_outputs = subfunction.outputs
assert [v.name for v in subfunction_inputs] == ["a5", "x7", "x8", "x9"]
assert [v.name for v in subfunction_outputs] == ["a8"]
if __name__ == "__main__":
test_subfunction()


@@ -1,43 +0,0 @@
from dist_ir.ir import Module
from dist_ir.ir.type import Tensor, Float
def test_submodule():
module = Module()
inputs = []
outputs = []
num_ops = 9
for i in range(num_ops + 1):
inputs.append(module.add_input_value(f"x{i}", Tensor(Float(), (4, 4))))
for i in range(num_ops):
if i == 0:
input_values = inputs[:2]
else:
input_values = [outputs[-1], inputs[i + 1]]
outputs.append(
module.add_op("Add", f"Add{i}", inputs=input_values, output_names=[f"a{i}"])
)
module.finalize()
submodule = module.get_submodule(("Add0", "Add1", "Add2"))
submodule_inputs = submodule.get_inputs()
submodule_outputs = submodule.get_outputs()
assert [v.name for v in submodule_inputs] == ["x0", "x1", "x2", "x3"]
assert [v.name for v in submodule_outputs] == ["a2"]
submodule = module.get_submodule(("Add3", "Add4", "Add5"))
submodule_inputs = submodule.get_inputs()
submodule_outputs = submodule.get_outputs()
assert [v.name for v in submodule_inputs] == ["a2", "x4", "x5", "x6"]
assert [v.name for v in submodule_outputs] == ["a5"]
submodule = module.get_submodule(("Add6", "Add7", "Add8"))
submodule_inputs = submodule.get_inputs()
submodule_outputs = submodule.get_outputs()
assert [v.name for v in submodule_inputs] == ["a5", "x7", "x8", "x9"]
assert [v.name for v in submodule_outputs] == ["a8"]
if __name__ == "__main__":
test_submodule()

test/test_type_inference.py (new file)

@@ -0,0 +1,205 @@
import pytest
from dist_ir.ir import cpprint, Device, Function, FunctionMaker, Op, Value
from dist_ir.executor.type_inference import infer_types
from dist_ir.ir.type import Float, Tensor, TupleType
def test_add_valid():
function = FunctionMaker()
a = function.add_input_value("a", Tensor(Float(), (4, 4)))
b = function.add_input_value("b", Tensor(Float(), (4, 4)))
x = function.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
function = function.finalize()
typed_function = infer_types(function, [a, b])
assert typed_function.outputs[0].type.shape == (4, 4)
def test_add_invalid():
function = FunctionMaker()
a = function.add_input_value("a", Tensor(Float(), (8, 4)))
b = function.add_input_value("b", Tensor(Float(), (4, 2)))
x = function.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
function = function.finalize()
with pytest.raises(ValueError):
infer_types(function, [a, b])
def test_allreduce():
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
xis = Value(
"xis",
TupleType(
(Tensor(Float(), (4, 4), device=d0), Tensor(Float(), (4, 4), device=d1))
),
)
op1 = Op(
"Allreduce",
"Allreduces/xis",
inputs=[xis],
output_names=["xs"],
)
function = Function("foo", (op1,), (xis,), (op1.outputs[0],))
function = infer_types(function, [xis])
xs = function.outputs[0]
assert isinstance(xs.type, TupleType)
for i, value_type in enumerate(xis.type.types):
assert value_type.shape == xs.type.types[i].shape
assert value_type.device == xs.type.types[i].device
def test_broadcast():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
x = function.add_input_value("x", Tensor(Float(), (4, 4)))
xs = function.add_op(
"Broadcast",
"Broadcast/x",
inputs=[x],
attributes={"devices": [d0, d1]},
output_names=["xs"],
)
function = function.finalize()
function = infer_types(function, [x])
xs = function.outputs[0]
assert isinstance(xs.type, TupleType)
assert xs.type.types[0].shape == (4, 4)
assert xs.type.types[0].device == d0
assert xs.type.types[1].shape == (4, 4)
assert xs.type.types[1].device == d1
def test_matmul_valid():
function = FunctionMaker()
a = function.add_input_value("a", Tensor(Float(), (8, 4)))
b = function.add_input_value("b", Tensor(Float(), (4, 2)))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
function = function.finalize()
function = infer_types(function, [a, b])
assert function.outputs[0].type.shape == (8, 2)
def test_matmul_invalid():
function = FunctionMaker()
a = function.add_input_value("a", Tensor(Float(), (8, 8)))
b = function.add_input_value("b", Tensor(Float(), (4, 2)))
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
function = function.finalize()
with pytest.raises(ValueError):
function = infer_types(function, [a, b])
def test_matmul_grad():
function = FunctionMaker()
x = function.add_input_value("x", Tensor(Float(), (8, 4)))
w = function.add_input_value("w", Tensor(Float(), (4, 2)))
l = function.add_input_value("l", Tensor(Float(), (8,)))
dx, dw = function.add_op(
"MatMulGrad", "MatMulGrad0", inputs=[x, w, l], output_names=["dx", "dw"]
)
function = function.finalize()
function = infer_types(function, [x, w, l])
dx, dw = function.outputs
assert dx.type.shape == x.type.shape
assert dw.type.shape == w.type.shape
def test_pmap():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
xs = function.add_input_value(
"xs",
TupleType(
(Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
),
)
wAs = function.add_input_value(
"wAs",
TupleType(
(Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
),
)
wBs = function.add_input_value(
"wBs",
TupleType(
(Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
),
)
subfunction = FunctionMaker()
x = subfunction.add_input_value("x", None)
wA = subfunction.add_input_value("wA", None)
wB = subfunction.add_input_value("wB", None)
y = subfunction.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
z = subfunction.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
subfunction = subfunction.finalize()
zis = function.add_op(
"Pmap",
inputs=[xs, wAs, wBs],
attributes={
"devices": [d0, d1],
"device_var": Device.get_new_device_variable(
"gpu"
), # TODO where best to do this?
},
subfunctions=[subfunction],
output_names=["zis"],
)
function = function.finalize()
cpprint(function)
function = infer_types(function, [xs, wAs, wBs])
cpprint(function)
# TODO: Verify subfunction shapes and devices
zis = function.outputs[0]
assert zis.type.types[0].shape == (8, 1)
assert zis.type.types[0].device == d0
assert zis.type.types[1].shape == (8, 1)
assert zis.type.types[1].device == d1
def test_scatter():
function = FunctionMaker()
d0 = Device(0, "gpu")
d1 = Device(1, "gpu")
x = function.add_input_value("x", Tensor(Float(), (4, 4)))
xs = function.add_op(
"Scatter",
"Scatter/x",
inputs=[x],
attributes={"dim": 0, "devices": [d0, d1]},
output_names=["xs"],
)
function = function.finalize()
function = infer_types(function, [x])
xs = function.outputs[0]
assert isinstance(xs.type, TupleType)
assert xs.type.types[0].shape == (2, 4)
assert xs.type.types[0].device == d0
assert xs.type.types[1].shape == (2, 4)
assert xs.type.types[1].device == d1
if __name__ == "__main__":
test_pmap()