Refactor: immutable core classes using dataclasses (#15)
- Rename Module to Function to be consistent with MLIR
- Add a name property to Function
- Unify type and shape inference into a single pass to be performed after the full Function is constructed
- At op creation time, only look up the op signature in terms of number of inputs and outputs
- Use dataclasses to define core IR classes and use the frozen option to ensure they are immutable
- Remove any leftover uses of deepcopy

* Rename Module to Function
* Add device_var attribute to Pmap op
* Fix prettyprinting of Pmap
* Use frozen dataclasses for Types
* Convert Device to frozen dataclass
* Convert Op to a frozen dataclass
* Convert Function to frozen dataclass; add FunctionMaker
* Convert Value to frozen dataclass
* Make dataclass types hashable and use Value instead of value names (#17)
* Make dataclass types hashable and update SequentialExecutor tests
* Remove main from sequential executor tests
* Fix pipeline parallel scheduler
* Upgrade Python to 3.8, cache pip
* Op: minor simplification
* Update MLIR parser to use refactored IR
* Allow passing output types to Op()
* Fix typo, import order
* First implementation of type inference
* Fix a test
* Add some TODOs
* Replace strings with values in data parallel transform (and tests)
* Add remaining type inference functions
* Add type inference for pmap
* Change names to values for pipeline parallel transform
* Implement pmap's type prop fn, fix some others
* Fix Device constructor arg order
* Misc
* Fix tests
* Move shape inference to type inference
* Remove shape inference module
* Remove uses of deepcopy
* Rename Op.{in,out}_edges to Op.{in,out}puts
* Use ops instead of names for pipeline parallel utils
* Fix for error message when output types and outputs don't match
* Remove TODO for replacing value names with values
* Make type inference functions more robust
* Clean up op register
* Minor improvements to type prop fns

Co-authored-by: Keshav Santhanam <keshav2@stanford.edu>
This commit is contained in:
Parent: 36498ed684
Commit: 3c9905b659
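The commit message above describes converting the core IR classes to frozen dataclasses and introducing a FunctionMaker builder that is finalized into an immutable Function. As a rough illustration of that pattern, here is a minimal sketch with assumed names and fields, not the exact DistIR API:

```python
# Minimal sketch (assumed names, not the exact DistIR API): immutable IR
# objects via frozen dataclasses, plus a mutable "maker" that accumulates
# ops and then freezes them into an immutable Function.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class Op:
    op_type: str
    name: str


@dataclass(frozen=True)
class Function:
    name: str
    ops: Tuple[Op, ...]


@dataclass
class FunctionMaker:
    name: str = "foo"
    ops: List[Op] = field(default_factory=list)

    def add_op(self, op_type: str, name: str) -> Op:
        op = Op(op_type, name)
        self.ops.append(op)
        return op

    def finalize(self) -> Function:
        # Freeze the accumulated ops into an immutable Function.
        return Function(self.name, tuple(self.ops))


maker = FunctionMaker("example")
maker.add_op("MatMul", "matmul_0")
fn = maker.finalize()  # immutable: assigning to fn.ops now raises FrozenInstanceError
```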
@@ -12,12 +12,14 @@ on:
jobs:
  build:
    runs-on: ubuntu-latest

    env:
      MLIR_VERSION: 20210111.38
      PY_VERSION: 3.8

    steps:
      - uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
README.md

@@ -40,40 +40,40 @@ python -m pytest

# Components

- Executors:
  - SequentialExecutor: a reference implementation that runs a DistIR module
  - SequentialExecutor: a reference implementation that runs a DistIR function
    on a single device. Can be used to check correctness of transforms.
  - DistributedSimulator: an executor that uses profile data or flop counts to
    simulate the execution of a given DistIR module on a given hardware
    simulate the execution of a given DistIR function on a given hardware
    configuration (including communication bandwidths and processor speed).
    Returns estimated execution time and live memory profile. This can be
    split into three subcomponents:
    - Shape Inference: a pass that uses the shapes of inputs to calculate
      the shapes of all intermediate values.
    - Cost Inference: a pass that uses either shape information to compute
      (or profiles the module and measures) the runtime and temporary
      memory requirement of each op in the module.
      (or profiles the function and measures) the runtime and temporary
      memory requirement of each op in the function.
      This output can be cached.
    - Simulator: takes a module and a mapping from op to time/memory
    - Simulator: takes a function and a mapping from op to time/memory
      consumption and does a simulation to obtain a concurrent trace
      (from which total runtime and memory usage plots can be derived).
- Importers:
  - ONNX Importer: convert a `.onnx` file to a DistIR module. Can be given an
  - ONNX Importer: convert a `.onnx` file to a DistIR function. Can be given an
    intermediate graph from ORT (for example, after AD).
  - MLIR Importer: import a DistIR module written in MLIR text format to an
    in-memory DistIR module object. TODO
  - Exporter/Prettyprinter: converts a DistIR module to an MLIR text format string.
  - MLIR Importer: import a DistIR function written in MLIR text format to an
    in-memory DistIR function object. TODO
  - Exporter/Prettyprinter: converts a DistIR function to an MLIR text format string.
- Transforms: a module containing DistIR->DistIR transforms.
  Ideally, these modules should be composable and should run on submodules
  Ideally, these transforms should be composable and should run on subfunctions
  so that we can have nested parallelism (data parallel where a subset of the
  layers are horizontal parallel, or pipeline parallel where each stage is
  data parallel with a different degree).
  - DataParallelTransform: converts a given DistIR module to a data-parallel
  - DataParallelTransform: converts a given DistIR function to a data-parallel
    version that runs on a given number of devices.
  - HorizontalParallelTransform: converts a given DistIR module to a
  - HorizontalParallelTransform: converts a given DistIR function to a
    horizontal-parallel version (if possible) that runs on a given number of
    devices.
  - PipelineParallelTransform: converts a given DistIR module to a
  - PipelineParallelTransform: converts a given DistIR function to a
    pipeline-parallel version that runs on a given number of devices.
- Search: an algorithm to find the best distributed version of a given
  sequential DistIR module. Initially, this can be something that searches
  sequential DistIR function. Initially, this can be something that searches
  the DHP space (i.e. find the optimal parameters to give the D/H/P transforms).
@@ -1,2 +1,3 @@
from .distributed_simulator import DistributedSimulator
from .sequential_executor import SequentialExecutor
from .type_inference import infer_types
@@ -140,7 +140,7 @@ class CostModel:
        return costs

    def infer_costs(self, op):
        inputs = op.get_in_edges()
        outputs = op.get_out_edges()
        inputs = op.inputs
        outputs = op.outputs

        return self._op_register[op.op_type](op, inputs, outputs)
@@ -2,7 +2,7 @@ from copy import deepcopy
from collections import defaultdict
import json

from ..ir import Module
from ..ir import Function
from . import utils

SECONDS_TO_MICROSECONDS = 1e6
@@ -43,11 +43,11 @@ class DistributedSimulator:
    def __init__(self, cost_model):
        self._cost_model = cost_model

    def _simulate(self, module: Module, state: DistributedSimulatorState):
    def _simulate(self, function: Function, state: DistributedSimulatorState):

        for op_name, op in module.get_ops().items():
            in_edges = op.get_in_edges()
            out_edges = op.get_out_edges()
        for op in function.ops:
            in_edges = op.inputs
            out_edges = op.outputs

            # Synchronize all input and output devices for this op.
            input_devices = utils.get_all_devices(in_edges)
@@ -63,40 +63,40 @@ class DistributedSimulator:
            # Compute the costs for the op.
            if op.op_type == "Pmap":
                # For Pmap ops we use a fresh state object and update the enclosing
                # module state using the Pmap state.
                submodule = op.get_submodule(0)
                submodule_state = DistributedSimulatorState()
                self._simulate(submodule, submodule_state)
                device_vars = submodule_state.timestamps.keys()
                # function state using the Pmap state.
                subfunction = op.subfunctions[0]
                subfunction_state = DistributedSimulatorState()
                self._simulate(subfunction, subfunction_state)
                device_vars = subfunction_state.timestamps.keys()
                assert len(device_vars) == 1
                # TODO what happens when pmaps are nested?
                bound_devices = op.get_attribute("devices")
                # Add submodule's trace to trace of all participating devices
                bound_devices = op.attributes["devices"]
                # Add subfunction's trace to trace of all participating devices
                for device in bound_devices:
                    for event in submodule_state.trace:
                    for event in subfunction_state.trace:
                        # Need to add pmap's starting timestamp to event
                        # since submodule_state started at time 0
                        # since subfunction_state started at time 0
                        start_time = event["ts"] + state.timestamps[device]
                        state.add_trace_event(
                            event["name"], device, start_time, event["dur"]
                        )
                for device_var in device_vars:
                    for bound_device in bound_devices:
                        state.timestamps[bound_device] += submodule_state.timestamps[
                            device_var
                        ]
                        state.live_memory[bound_device] += submodule_state.live_memory[
                            device_var
                        ]
                        state.peak_memory[bound_device] += submodule_state.peak_memory[
                        state.timestamps[bound_device] += subfunction_state.timestamps[
                            device_var
                        ]
                        state.live_memory[
                            bound_device
                        ] += subfunction_state.live_memory[device_var]
                        state.peak_memory[
                            bound_device
                        ] += subfunction_state.peak_memory[device_var]
                # TODO: Update consumers?
            else:
                costs = self._cost_model.infer_costs(op)
                for device in costs:
                    state.add_trace_event(
                        op_name,
                        op.name,
                        device,
                        state.timestamps[device],
                        costs[device],
@@ -105,9 +105,7 @@ class DistributedSimulator:

            # Update the live memory.
            for out_edge in out_edges:
                state.consumers[out_edge] = len(
                    module.get_consumers_for_value(out_edge.name)
                )
                state.consumers[out_edge] = len(function.get_consumers(out_edge))
                # Output value could live on multiple devices (e.g. scatter) so
                # update memory on all devices:
                output_devices = out_edge.type.get_all_devices()

@@ -115,7 +113,9 @@ class DistributedSimulator:
                    state.live_memory[output_device] += out_edge.type.size()
            # TODO: Can we optimize this using a priority queue?
            for value in state.consumers:
                if state.consumers[value] == 0 and not module.is_input(value.name):
                if state.consumers[value] == 0 and all(
                    value != v for v in function.inputs
                ):
                    devices = value.type.get_all_devices()
                    for device in devices:
                        state.live_memory[device] -= value.type.size()

@@ -126,7 +126,7 @@ class DistributedSimulator:
                        state.peak_memory[device], state.live_memory[device]
                    )

    def simulate(self, module):
    def simulate(self, function):
        state = DistributedSimulatorState()
        self._simulate(module, state)
        self._simulate(function, state)
        return state
@@ -12,25 +12,33 @@ def allreduce(op, inputs):


def broadcast(op, inputs):
    return [inputs[0] for _ in range(len(op.get_attribute("devices")))]
    return [inputs[0] for _ in range(len(op.attributes["devices"]))]


def concat(op, inputs):
    dim = op.get_attribute("dim")
    # assert len(inputs) == 1
    # dim = op.attributes["dim"]
    # return np.concatenate(inputs[0], axis=dim)
    dim = op.attributes["dim"]
    return np.concatenate(inputs, axis=dim)


def gather(op, inputs):
    dim = op.attributes["dim"]
    return np.concatenate(inputs[0], axis=dim)


def identity(op, inputs):
    return inputs[0]


def loss(op, inputs):
    N = op.get_attribute("N")
    N = op.attributes["N"]
    return np.square(inputs[0] - inputs[1]) / N


def loss_grad(op, inputs):
    N = op.get_attribute("N")
    N = op.attributes["N"]
    return 2 * (inputs[0] - inputs[1]) / N


@@ -47,16 +55,16 @@ def relu(op, inputs):


def select(op, inputs):
    dim = op.get_attribute("dim")
    dim = op.attributes["dim"]
    return inputs[0][dim]


def split(op, inputs):
    dim = op.get_attribute("dim")
    dim = op.attributes["dim"]
    if op.op_type == "Split":
        num_splits = op.get_attribute("num_splits")
        num_splits = op.attributes["num_splits"]
    elif op.op_type == "Scatter":
        num_splits = len(op.get_attribute("devices"))
        num_splits = len(op.attributes["devices"])

    return np.split(inputs[0], num_splits, axis=dim)


@@ -66,7 +74,7 @@ NumPyRegister = {
    "Allreduce": allreduce,
    "Broadcast": broadcast,
    "Concat": concat,
    "Gather": concat,
    "Gather": gather,
    "Loss": loss,
    "LossGrad": loss_grad,
    "MatMul": matmul,
@@ -1,7 +1,7 @@
from typing import Any, Dict, List

from .backend_register import BackendRegister
from ..ir import Module, Op
from ..ir import Function, Op, Value


class SequentialExecutor:

@@ -19,12 +19,11 @@ class SequentialExecutor:
        # Iterate over the inputs
        results = []
        for inps in inputs:
            # Execute submodule with appropriate inputs
            inp_names = (e.name for e in op.get_submodule(0).get_inputs())
            inp_data = {n: v for n, v in zip(inp_names, inps)}
            outs = self.compute(op.get_submodule(0), inp_data)
            # Match output names to output data using the module output order.
            ordered_outs = [outs[e.name] for e in op.get_submodule(0).get_outputs()]
            # Execute subfunction with appropriate inputs
            inp_data = {k: v for k, v in zip(op.subfunctions[0].inputs, inps)}
            outs = self.compute(op.subfunctions[0], inp_data)
            # Match output names to output data using the function output order.
            ordered_outs = [outs[e] for e in op.subfunctions[0].outputs]
            results.append(ordered_outs)
        # Unzip the results
        results = tuple(zip(*results))

@@ -39,56 +38,51 @@ class SequentialExecutor:
            output_data = (output_data,)
        return output_data

    def compute(self, module: Module, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Executes the module given the specified inputs and returns the final result.
    def compute(
        self, function: Function, input_data: Dict[Value, Any]
    ) -> Dict[Value, Any]:
        """Executes the function given the specified inputs and returns the final result.

        Args:
            module: The module to execute.
            input_data: A map from input tensor name to data represented in the
            function: The function to execute.
            input_data: A map from input value to data represented in the
                specified backend.

        Returns:
            A map from output tensor name to output tensor.
            A map from output value to output data.
        """
        output_data = {}
        consumers = {}
        ops = module.get_ops()

        # Execute ops in topological order.
        for op_name, op in ops.items():
        for op in function.ops:
            inputs = []
            in_edges = op.get_in_edges()
            for in_edge in in_edges:
                input_name = in_edge.name
                if module.is_input(input_name):
                    input_name = in_edge.name
                    if input_name not in input_data:
            for in_edge in op.inputs:
                if in_edge in function.inputs:
                    if in_edge not in input_data:
                        raise ValueError(
                            f"Could not find input {input_name} in input_data"
                            f"Could not find input {in_edge} in input_data"
                        )
                    input_value = input_data[input_name]
                elif input_name in output_data:
                    input_value = output_data[input_name]
                    consumers[input_name] -= 1
                    input_value = input_data[in_edge]
                elif in_edge in output_data:
                    input_value = output_data[in_edge]
                    consumers[in_edge] -= 1
                else:
                    raise ValueError(f"Invalid input {input_name} for op {op_name}")
                    raise ValueError(f"Invalid input {in_edge} for op {op}")
                inputs.append(input_value)

            res = self._compute_op(op, inputs)
            out_edges = op.get_out_edges()
            for i, out_edge in enumerate(out_edges):
                output_data[out_edge.name] = res[i]
                consumers[out_edge.name] = len(
                    module.get_consumers_for_value(out_edge.name)
                )
            for i, out_edge in enumerate(op.outputs):
                output_data[out_edge] = res[i]
                consumers[out_edge] = len(function.get_consumers(out_edge))

            # Garbage collect the fully consumed output tensors.
            to_free = []
            for output_name in output_data:
                if consumers[output_name] == 0 and not module.is_output(output_name):
                    to_free.append(output_name)
            for output_name in to_free:
                del output_data[output_name]
            for out_edge in output_data:
                if consumers[out_edge] == 0 and not out_edge in function.outputs:
                    to_free.append(out_edge)
            for out_edge in to_free:
                del output_data[out_edge]

        # Return the outputs.
        return output_data
@ -1,202 +0,0 @@
|
|||
from ..ir.type import Float
|
||||
from ..ir.type import Tensor, TupleType
|
||||
from ..ir.value import Value
|
||||
from ..ir.device import Device
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
def _get_shapes(values):
|
||||
shapes = []
|
||||
for value in values:
|
||||
if isinstance(value.type, Tensor):
|
||||
shapes.append(value.type.shape)
|
||||
else:
|
||||
shapes.append(None)
|
||||
return shapes
|
||||
|
||||
|
||||
def _error_invalid_shapes(op, input_shapes):
|
||||
raise ValueError(
|
||||
f"Op {op.name} ({op.op_type}): Incompatible input shapes {input_shapes}"
|
||||
)
|
||||
|
||||
|
||||
def _infer_shapes_for_add(op, inputs, outputs):
|
||||
# TODO: Handle input tensors with > 2 dimensions
|
||||
input_shapes = _get_shapes(inputs)
|
||||
if input_shapes[0] != input_shapes[1]:
|
||||
_error_invalid_shapes(op, input_shapes)
|
||||
|
||||
output_shape = input_shapes[0]
|
||||
output_type = Tensor(
|
||||
dtype=inputs[0].type.dtype, shape=output_shape, device=inputs[0].type.device
|
||||
)
|
||||
outputs[0].type = output_type
|
||||
|
||||
|
||||
def _infer_shapes_for_allreduce(op, inputs, outputs):
|
||||
outputs[0].type = copy.deepcopy(inputs[0].type)
|
||||
|
||||
|
||||
def _infer_shapes_for_broadcast(op, inputs, outputs):
|
||||
input_type = inputs[0].type
|
||||
devices = op.get_attribute("devices")
|
||||
output_types = []
|
||||
for (output_type, device) in zip(outputs[0].type.types, devices):
|
||||
if isinstance(output_type, Tensor) and isinstance(input_type, Tensor):
|
||||
output_type.shape = input_type.shape
|
||||
output_type.set_device(device)
|
||||
|
||||
|
||||
def _infer_shapes_for_concat(op, inputs, outputs):
|
||||
input_shapes = _get_shapes(inputs)
|
||||
dim = op.get_attribute("dim")
|
||||
for i, (dim0, dim1) in enumerate(zip(input_shapes[0], input_shapes[1])):
|
||||
if i != dim and dim0 != dim1:
|
||||
_error_invalid_shapes(op, input_shapes)
|
||||
output_shape = list(input_shapes[0])
|
||||
output_shape[dim] += input_shapes[1][dim]
|
||||
outputs[0].type.shape = output_shape
|
||||
|
||||
|
||||
def _infer_shapes_for_gather(op, inputs, outputs):
|
||||
dim = op.get_attribute("dim")
|
||||
device = op.get_attribute("device")
|
||||
output_shape = list(inputs[0].type.types[0].shape)
|
||||
for typ in inputs[0].type.types[1:]:
|
||||
output_shape[dim] += typ.shape[dim]
|
||||
outputs[0].type.dtype = inputs[0].type.types[0].dtype
|
||||
outputs[0].type.shape = output_shape
|
||||
outputs[0].type.set_device(device)
|
||||
|
||||
|
||||
def _infer_shapes_for_matmul(op, inputs, outputs):
|
||||
# TODO: Handle input tensors with > 2 dimensions
|
||||
input_shapes = _get_shapes(inputs)
|
||||
if input_shapes[0][1] != input_shapes[1][0]:
|
||||
_error_invalid_shapes(op, input_shapes)
|
||||
output_shape = (input_shapes[0][0], input_shapes[1][1])
|
||||
outputs[0].type = Tensor(
|
||||
dtype=inputs[0].type.dtype, shape=output_shape, device=inputs[0].type.device
|
||||
)
|
||||
|
||||
|
||||
def _infer_shapes_for_matmul_grad(op, inputs, outputs):
|
||||
for i, output in enumerate(outputs):
|
||||
output.type = copy.deepcopy(inputs[i].type)
|
||||
|
||||
|
||||
def _infer_shapes_for_loss(op, inputs, outputs):
|
||||
input_shapes = _get_shapes(inputs)
|
||||
if input_shapes[0] != input_shapes[1]:
|
||||
_error_invalid_shapes(op, input_shapes)
|
||||
|
||||
outputs[0].type = copy.deepcopy(inputs[0].type)
|
||||
|
||||
|
||||
def _infer_shapes_for_loss_grad(op, inputs, outputs):
|
||||
input_shapes = _get_shapes(inputs)
|
||||
if input_shapes[0] != input_shapes[1]:
|
||||
_error_invalid_shapes(op, input_shapes)
|
||||
|
||||
outputs[0].type = copy.deepcopy(inputs[0].type)
|
||||
|
||||
|
||||
def _infer_shapes_for_pmap(op, inputs, outputs):
|
||||
submodule = op.get_submodule(0)
|
||||
|
||||
for (pmap_input, submodule_input) in zip(inputs, submodule.get_inputs()):
|
||||
assert isinstance(pmap_input.type, TupleType)
|
||||
# TODO check that all elements of the tuple have the same type and, if
|
||||
# they are tensors, that they have the same shape
|
||||
if isinstance(submodule_input.type, Tensor):
|
||||
submodule_input.type.shape = pmap_input.type.types[0].shape
|
||||
|
||||
_infer_shapes(submodule)
|
||||
|
||||
for (pmap_output, submodule_output) in zip(outputs, submodule.get_outputs()):
|
||||
if isinstance(submodule_output.type, Tensor):
|
||||
assert isinstance(pmap_output.type, TupleType)
|
||||
for pmap_output_type in pmap_output.type.types:
|
||||
pmap_output_type.shape = submodule_output.type.shape
|
||||
pmap_output_type.dtype = submodule_output.type.dtype
|
||||
|
||||
|
||||
def _infer_shapes_for_scatter(op, inputs, outputs):
|
||||
input_type = inputs[0].type
|
||||
split_dim = op.get_attribute("dim")
|
||||
devices = op.get_attribute("devices")
|
||||
for (output_type, device) in zip(outputs[0].type.types, devices):
|
||||
if isinstance(output_type, Tensor) and isinstance(input_type, Tensor):
|
||||
output_shape = list(input_type.shape)
|
||||
output_shape[split_dim] //= len(devices)
|
||||
output_type.shape = tuple(output_shape)
|
||||
output_type.set_device(device)
|
||||
|
||||
|
||||
def _infer_shapes_for_select(op, inputs, outputs):
|
||||
dim = op.get_attribute("dim")
|
||||
outputs[0].type.shape = inputs[0].type.types[dim].shape
|
||||
outputs[0].type.dtype = inputs[0].type.types[dim].dtype
|
||||
|
||||
|
||||
def _infer_shapes_for_send(op, inputs, outputs):
|
||||
outputs[0].type.shape = inputs[0].type.shape
|
||||
outputs[0].type.dtype = inputs[0].type.dtype
|
||||
|
||||
|
||||
def _infer_shapes_for_split(op, inputs, outputs):
|
||||
num_splits = op.get_attribute("num_splits")
|
||||
split_dim = op.get_attribute("dim")
|
||||
output_shape = list(inputs[0].type.shape)
|
||||
output_shape[split_dim] //= num_splits
|
||||
for typ in outputs[0].type.types:
|
||||
typ.shape = tuple(output_shape)
|
||||
typ.dtype = inputs[0].type.dtype
|
||||
|
||||
|
||||
ShapeInferenceRegister = {
|
||||
"Add": _infer_shapes_for_add,
|
||||
"Allreduce": _infer_shapes_for_allreduce,
|
||||
"Broadcast": _infer_shapes_for_broadcast,
|
||||
"Concat": _infer_shapes_for_concat,
|
||||
"Gather": _infer_shapes_for_gather,
|
||||
"Loss": _infer_shapes_for_loss,
|
||||
"LossGrad": _infer_shapes_for_loss_grad,
|
||||
"MatMul": _infer_shapes_for_matmul,
|
||||
"MatMulGrad": _infer_shapes_for_matmul_grad,
|
||||
"Pmap": _infer_shapes_for_pmap,
|
||||
"Scatter": _infer_shapes_for_scatter,
|
||||
"Select": _infer_shapes_for_select,
|
||||
"Send": _infer_shapes_for_send,
|
||||
"Split": _infer_shapes_for_split,
|
||||
}
|
||||
|
||||
|
||||
def _infer_shapes(module):
|
||||
"""Helper function for inferring shapes.
|
||||
|
||||
Inputs:
|
||||
module: The module to infer shapes for.
|
||||
"""
|
||||
|
||||
for op_name, op in module.get_ops().items():
|
||||
inputs = op.get_in_edges()
|
||||
outputs = op.get_out_edges()
|
||||
|
||||
# Invariant: types and shapes of input are already inferred
|
||||
for input in inputs:
|
||||
assert input.type is not None
|
||||
if isinstance(input.type, Tensor):
|
||||
if input.type.shape is None:
|
||||
raise ValueError(f"Input {input.name} of op {op_name} has no shape")
|
||||
|
||||
ShapeInferenceRegister[op.op_type](op, inputs, outputs)
|
||||
# TODO maybe the register gives back the output types and we can check
|
||||
# here if they match existing types (if any) and if not, replace them
|
||||
|
||||
|
||||
def infer_shapes(module):
|
||||
"""Infers shapes for the given module."""
|
||||
_infer_shapes(module)
|
|
@ -0,0 +1,310 @@
|
|||
"""
|
||||
A type inference module that converts an untyped DistIR Function into one where
|
||||
every Value is typed with shape and dtype information, given input types or
|
||||
example inputs.
|
||||
|
||||
Type inference requires a register mapping ops to type propagation functions:
|
||||
- This is a function foo(op, x1, x2, .., xN), where op is an N-ary Op, and x1 to
|
||||
xN are Types of the inputs.
|
||||
- The function should check that the inputs have the expected types.
|
||||
- The function should return the type of the output/a tuple of types of the
|
||||
outputs.
|
||||
(When we say types we also mean shape and device information.)
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from ..ir import Device, Function, FunctionMaker, Op, Value
|
||||
from ..ir.type import Type, Tensor, TupleType
|
||||
|
||||
|
||||
def _raise_type_error(op, *args):
|
||||
raise ValueError(f"Type error: op\n{op}\nwas given arguments\n{tuple(args)}")
|
||||
|
||||
|
||||
def _allreduce_prop_fn(op, x):
|
||||
devices = tuple(t.device for t in x.types)
|
||||
if not (
|
||||
isinstance(x, TupleType)
|
||||
and all(isinstance(t, Tensor) for t in x.types)
|
||||
and len(x.types) > 0
|
||||
and all(t.shape == x.types[0].shape for t in x.types)
|
||||
and len(set(devices)) == len(devices)
|
||||
):
|
||||
_raise_type_error(op, x)
|
||||
return x
|
||||
|
||||
|
||||
# TODO update the below prop functions to be as robust as _allreduce_prop_fn
|
||||
|
||||
|
||||
def _broadcast_prop_fn(op, x):
|
||||
if not isinstance(x, Tensor):
|
||||
_raise_type_error(op, x)
|
||||
devices = op.attributes["devices"]
|
||||
return TupleType(
|
||||
tuple(Tensor(dtype=x.dtype, shape=x.shape, device=device) for device in devices)
|
||||
)
|
||||
|
||||
|
||||
def _concat_prop_fn(op, x, y):
|
||||
if not (
|
||||
isinstance(x, Tensor)
|
||||
and isinstance(y, Tensor)
|
||||
and x.dtype == y.dtype
|
||||
and x.device == y.device
|
||||
):
|
||||
_raise_type_error(op, x, y)
|
||||
dim = op.attributes["dim"]
|
||||
for i, (d0, d1) in enumerate(zip(x.shape, y.shape)):
|
||||
if not i != dim and d0 != d1:
|
||||
_raise_type_error(op, x, y)
|
||||
output_shape = tuple(
|
||||
n + (y.shape[i] if i == dim else 0) for i, n in enumerate(x.shape)
|
||||
)
|
||||
return Tensor(dtype=x.dtype, shape=output_shape, device=x.device)
|
||||
|
||||
|
||||
def _elementwise_tensor_op_prop_fn(op, x, y):
|
||||
if not (
|
||||
isinstance(x, Tensor)
|
||||
and isinstance(y, Tensor)
|
||||
and x.dtype == y.dtype
|
||||
and x.shape == y.shape
|
||||
and x.device == y.device
|
||||
):
|
||||
_raise_type_error(op, x, y)
|
||||
return x
|
||||
|
||||
|
||||
def _gather_prop_fn(op, x):
|
||||
if not (
|
||||
isinstance(x, TupleType)
|
||||
and all(isinstance(t, Tensor) for t in x.types)
|
||||
and len(set(t.shape for t in x.types)) == 1
|
||||
and len(set(t.dtype for t in x.types)) == 1
|
||||
and len(x.types) > 0
|
||||
):
|
||||
_raise_type_error(op, x)
|
||||
dim = op.attributes["dim"]
|
||||
device = op.attributes["device"]
|
||||
output_shape = list(x.types[0].shape)
|
||||
for i in range(1, len(x.types)):
|
||||
for j in range(len(x.types[i].shape)):
|
||||
if j == dim:
|
||||
output_shape[j] += x.types[i].shape[j]
|
||||
elif x.types[i].shape[j] != x.types[0].shape[j]:
|
||||
_raise_type_error(op, x)
|
||||
output_shape = tuple(output_shape)
|
||||
return Tensor(dtype=x.types[0].dtype, shape=output_shape, device=device)
|
||||
|
||||
|
||||
def _matmul_prop_fn(op, x, y):
|
||||
if not (
|
||||
isinstance(x, Tensor)
|
||||
and isinstance(y, Tensor)
|
||||
and x.dtype == y.dtype
|
||||
and x.device == y.device
|
||||
and x.shape[1] == y.shape[0]
|
||||
):
|
||||
_raise_type_error(op, x, y)
|
||||
return Tensor(dtype=x.dtype, shape=(x.shape[0], y.shape[1]), device=x.device)
|
||||
|
||||
|
||||
def _matmul_grad_prop_fn(op, x, y, z):
|
||||
# TODO: Check that shapes can be multipled together?
|
||||
if not (
|
||||
isinstance(x, Tensor)
|
||||
and isinstance(y, Tensor)
|
||||
and isinstance(z, Tensor)
|
||||
and x.dtype == y.dtype
|
||||
and x.dtype == z.dtype
|
||||
and x.device == y.device
|
||||
and x.device == z.device
|
||||
):
|
||||
_raise_type_error(op, x, y, z)
|
||||
|
||||
return (x, y)
|
||||
|
||||
|
||||
def _scatter_prop_fn(op, x):
|
||||
if not isinstance(x, Tensor):
|
||||
_raise_type_error(op, x)
|
||||
devices = op.attributes["devices"]
|
||||
# Check devices is a list of distinct Devices
|
||||
assert isinstance(devices, list) and all(isinstance(d, Device) for d in devices)
|
||||
assert len(devices) == len(set(devices))
|
||||
dim = op.attributes["dim"]
|
||||
# TODO: Should we add another function to raise an attribute error?
|
||||
assert x.shape[dim] % len(devices) == 0
|
||||
output_shape = list(x.shape)
|
||||
output_shape[dim] //= len(devices)
|
||||
output_shape = tuple(output_shape)
|
||||
return TupleType(
|
||||
tuple(
|
||||
Tensor(dtype=x.dtype, shape=output_shape, device=device)
|
||||
for device in devices
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _select_prop_fn(op, x):
|
||||
if not (
|
||||
isinstance(x, TupleType)
|
||||
and all(isinstance(t, Tensor) for t in x.types)
|
||||
and len(x.types) > 0
|
||||
and all(t.shape == x.types[0].shape for t in x.types)
|
||||
and len(set(t.device for t in x.types)) == 1
|
||||
):
|
||||
_raise_type_error(op, x)
|
||||
dim = op.attributes["dim"]
|
||||
return x.types[dim]
|
||||
|
||||
|
||||
def _send_prop_fn(op, x):
|
||||
if not isinstance(x, Tensor):
|
||||
_raise_type_error(op, x)
|
||||
device = op.attributes["device"]
|
||||
return Tensor(dtype=x.dtype, shape=x.shape, device=device)
|
||||
|
||||
|
||||
def _split_prop_fn(op, x):
|
||||
if not isinstance(x, Tensor):
|
||||
_raise_type_error(op, x)
|
||||
num_splits = op.attributes["num_splits"]
|
||||
split_dim = op.attributes["dim"]
|
||||
output_shape = list(x.shape)
|
||||
# TODO: Move this check to attribute error function?
|
||||
assert output_shape[split_dim] % num_splits == 0
|
||||
output_shape[split_dim] //= num_splits
|
||||
output_shape = tuple(output_shape)
|
||||
return TupleType(
|
||||
tuple(
|
||||
Tensor(dtype=x.dtype, shape=output_shape, device=x.device)
|
||||
for i in range(num_splits)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
TypePropRegister = {
|
||||
"Add": _elementwise_tensor_op_prop_fn,
|
||||
# "Allgather": TODO,
|
||||
"Allreduce": _allreduce_prop_fn,
|
||||
"Broadcast": _broadcast_prop_fn,
|
||||
"Concat": _concat_prop_fn,
|
||||
"Gather": _gather_prop_fn,
|
||||
"Loss": _elementwise_tensor_op_prop_fn,
|
||||
"LossGrad": _elementwise_tensor_op_prop_fn,
|
||||
"MatMul": _matmul_prop_fn,
|
||||
"MatMulGrad": _matmul_grad_prop_fn,
|
||||
"Scatter": _scatter_prop_fn,
|
||||
"Select": _select_prop_fn,
|
||||
"Send": _send_prop_fn,
|
||||
"Split": _split_prop_fn,
|
||||
}
|
||||
|
||||
# Handling pmap specially for now since it needs to return a typed subfunction
|
||||
|
||||
|
||||
def _pmap_prop_fn(op: Op, input_types: Tuple[Type]):
|
||||
"""pmap maps over a tuple of values, all of the same type and shape, but on
|
||||
distinct devices. For convenience, to avoid a lot of zipping, we allow
|
||||
multiple inputs, as long as they are all tuples of the same length and the
|
||||
list of devices in each tuple are exactly the same.
|
||||
"""
|
||||
# Pmap expects 1 or more tuples as input
|
||||
assert isinstance(input_types, tuple) and len(input_types) > 0
|
||||
assert all(isinstance(t, TupleType) for t in input_types)
|
||||
# Check that pmap's arguments all have same length and shapes
|
||||
assert len(set(len(t.types) for t in input_types)) == 1
|
||||
for t in input_types:
|
||||
assert all(isinstance(x, Tensor) for x in t.types)
|
||||
assert len(set(x.shape for x in t.types)) == 1
|
||||
assert len(set(x.dtype for x in t.types)) == 1
|
||||
# Check that pmap's arguments are on distinct devices
|
||||
devices = tuple(x.device for x in input_types[0].types)
|
||||
assert len(set(devices)) == len(devices)
|
||||
# Check that all inputs have same list of devices
|
||||
for t in input_types:
|
||||
assert devices == tuple(x.device for x in t.types)
|
||||
|
||||
# Subfunction's inputs are given by pmap's arguments, but on device d
|
||||
subfn_inputs = [
|
||||
Value(v.name, t.types[0])
|
||||
for v, t in zip(op.subfunctions[0].inputs, input_types)
|
||||
]
|
||||
|
||||
# Recursively call infer_types on subfunction
|
||||
assert len(op.subfunctions) == 1
|
||||
subfunctions = [infer_types(op.subfunctions[0], subfn_inputs)]
|
||||
|
||||
# Pmap's output types are given by subfunction's output types
|
||||
out_types = tuple(
|
||||
TupleType(
|
||||
tuple(
|
||||
Tensor(shape=t.type.shape, dtype=t.type.dtype, device=d)
|
||||
for d in devices
|
||||
)
|
||||
)
|
||||
for t in subfunctions[0].outputs
|
||||
)
|
||||
return out_types, subfunctions
|
||||
|
||||
|
||||
def infer_types(function: Function, inputs: List[Value]) -> Function:
|
||||
"""Given a function and a list of input values, returns a new function where
|
||||
all values are typed.
|
||||
|
||||
inputs: a list/tuple of Values, of the same length as function.inputs, but
|
||||
the names are irrelevant.
|
||||
"""
|
||||
new_function = FunctionMaker()
|
||||
# A Map from function's values to new_function's (typed) values:
|
||||
value_map: Dict[Value, Value] = {}
|
||||
|
||||
def assert_is_typed(v: Value):
|
||||
assert v.type is not None
|
||||
if isinstance(v.type, Tensor):
|
||||
if v.type.shape is None:
|
||||
raise ValueError(f"Expected Value {v} to have a shape")
|
||||
|
||||
# Add inputs to new_function
|
||||
assert len(inputs) == len(function.inputs)
|
||||
for old_inp, inp in zip(function.inputs, inputs):
|
||||
assert_is_typed(inp)
|
||||
new_inp = new_function.add_input_value(old_inp.name, inp.type)
|
||||
value_map[old_inp] = new_inp
|
||||
|
||||
op: Op # https://stackoverflow.com/q/59102038
|
||||
for op in function.ops:
|
||||
# Invariant: inputs of op are already typed (as ops are toposorted)
|
||||
typed_inputs = tuple(value_map[inp] for inp in op.inputs)
|
||||
input_types = tuple(v.type for v in typed_inputs)
|
||||
|
||||
# Infer types of outputs and create output values
|
||||
if op.op_type == "Pmap":
|
||||
out_types, subfunctions = _pmap_prop_fn(op, input_types)
|
||||
else:
|
||||
out_types = TypePropRegister[op.op_type](op, *input_types)
|
||||
if not isinstance(out_types, tuple):
|
||||
assert isinstance(out_types, Type)
|
||||
out_types = (out_types,)
|
||||
subfunctions = []
|
||||
|
||||
new_op = Op(
|
||||
op.op_type,
|
||||
op.name,
|
||||
typed_inputs,
|
||||
op.attributes,
|
||||
subfunctions,
|
||||
tuple(v.name for v in op.outputs),
|
||||
out_types,
|
||||
)
|
||||
new_function.ops.append(new_op)
|
||||
|
||||
# Add op's outputs to value_map
|
||||
for old_out, out in zip(op.outputs, new_op.outputs):
|
||||
assert_is_typed(out)
|
||||
value_map[old_out] = out
|
||||
|
||||
return new_function.finalize()
|
|
@ -1,10 +1,10 @@
|
|||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import mlir
|
||||
from ..ir import cpprint, Module, Value
|
||||
from ..ir import Function, FunctionMaker, Value
|
||||
from ..ir.device import Device
|
||||
from ..ir.type import Float, Tensor
|
||||
|
||||
|
@ -69,8 +69,8 @@ def _parse_type(mlir_type, context: Context):
|
|||
raise ValueError(f"Unknown MLIR type {mlir_type}")
|
||||
|
||||
|
||||
def _parse_module(mlir_region, context=None):
|
||||
"""Creates a DistIR Module out of an MLIR region. The region must be either
|
||||
def _parse_function(mlir_region, context: Context = None) -> Function:
|
||||
"""Creates a DistIR Function out of an MLIR region. The region must be either
|
||||
the single region in a function, or the sub-region of a pmap operator.
|
||||
"""
|
||||
if context is None:
|
||||
|
@ -79,11 +79,11 @@ def _parse_module(mlir_region, context=None):
|
|||
assert len(mlir_region.blocks) == 1
|
||||
entry_block = mlir_region.blocks[0]
|
||||
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
|
||||
# Find the inputs
|
||||
for arg in entry_block.arguments:
|
||||
v = module.add_input_value(
|
||||
v = function.add_input_value(
|
||||
f"%arg{arg.arg_number}", _parse_type(arg.type, context)
|
||||
)
|
||||
assert str(arg) not in context.values
|
||||
|
@ -121,11 +121,11 @@ def _parse_module(mlir_region, context=None):
|
|||
# Create output names (TODO should be done by Op.__init__ or add_op)
|
||||
output_names = [_make_fresh_var(context) for _ in op.results]
|
||||
|
||||
submodules = []
|
||||
subfunctions = []
|
||||
|
||||
if op_name == "std.return" or op_name == "dist.return":
|
||||
# Set return values as module outputs
|
||||
module.set_outputs(args)
|
||||
# Set return values as function outputs
|
||||
function.set_outputs(args)
|
||||
returned = True
|
||||
continue
|
||||
if op_name == "dist.pmap":
|
||||
|
@ -134,18 +134,18 @@ def _parse_module(mlir_region, context=None):
|
|||
new_device = Device.get_new_device_variable("gpu")
|
||||
assert attributes["device_var"] not in context.devices
|
||||
context.devices[attributes["device_var"]] = new_device
|
||||
# Parse the submodule
|
||||
submodules.append(_parse_module(op.regions[0], context))
|
||||
# Remove device var from context.devices as it is only in scope for submodule
|
||||
# Parse the subfunction
|
||||
subfunctions.append(_parse_function(op.regions[0], context))
|
||||
# Remove device var from context.devices as it is only in scope for subfunction
|
||||
del context.devices[attributes["device_var"]]
|
||||
|
||||
# Create an op and add it to the module
|
||||
outs = module.add_op(
|
||||
# Create an op and add it to the function
|
||||
outs = function.add_op(
|
||||
op_name,
|
||||
inputs=args,
|
||||
attributes=attributes,
|
||||
output_names=output_names,
|
||||
submodules=submodules,
|
||||
subfunctions=subfunctions,
|
||||
)
|
||||
if not isinstance(outs, tuple):
|
||||
outs = (outs,)
|
||||
|
@ -156,22 +156,21 @@ def _parse_module(mlir_region, context=None):
|
|||
assert str(mlirval) not in context.values
|
||||
context.values[str(mlirval)] = val
|
||||
|
||||
module.finalize()
|
||||
return module
|
||||
return function.finalize()
|
||||
|
||||
|
||||
def _parse_mlir_module(mlir_module):
|
||||
modules = []
|
||||
def _parse_mlir_module(mlir_module) -> List[Function]:
|
||||
functions = []
|
||||
for func in mlir_module.body.operations:
|
||||
if str(func) == "module_terminator":
|
||||
break
|
||||
assert len(func.regions) == 1
|
||||
|
||||
modules.append(_parse_module(func.regions[0]))
|
||||
return modules
|
||||
functions.append(_parse_function(func.regions[0]))
|
||||
return functions
|
||||
|
||||
|
||||
def parse_mlir_str(mlir_str):
|
||||
def parse_mlir_str(mlir_str: str) -> List[Function]:
|
||||
ctx = mlir.ir.Context()
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
||||
|
@ -179,7 +178,7 @@ def parse_mlir_str(mlir_str):
|
|||
return _parse_mlir_module(mlir_module)
|
||||
|
||||
|
||||
def parse_mlir_file(filename):
|
||||
def parse_mlir_file(filename) -> List[Function]:
|
||||
with open(filename, "r") as fin:
|
||||
mlir_str = fin.read()
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import onnx
|
||||
|
||||
from ..ir import Module, Value
|
||||
from ..ir import FunctionMaker, Value
|
||||
from ..ir.type import Tensor, Float
|
||||
|
||||
|
||||
|
@ -8,14 +8,14 @@ def import_from_onnx(onnx_model):
|
|||
# TODO: Remove prints?
|
||||
# TODO: Support types beyond Tensor
|
||||
onnx_model = onnx.load(onnx_model)
|
||||
dist_ir_module = Module()
|
||||
dist_ir_function = FunctionMaker("foo") # TODO get name?
|
||||
|
||||
inputs = {}
|
||||
output_src = {}
|
||||
|
||||
def add_input(value):
|
||||
# TODO lookup shape and dtype of input if exists
|
||||
v = dist_ir_module.add_input_value(value.name, Tensor(Float()))
|
||||
v = dist_ir_function.add_input_value(value.name, Tensor(Float()))
|
||||
inputs[value.name] = v
|
||||
|
||||
for value in onnx_model.graph.value_info:
|
||||
|
@ -41,21 +41,26 @@ def import_from_onnx(onnx_model):
|
|||
else:
|
||||
print(f"---> Could not find input {value}!")
|
||||
# TODO do something better here
|
||||
v = dist_ir_module.add_input_value(value, Tensor(Float()))
|
||||
v = dist_ir_function.add_input_value(value, Tensor(Float()))
|
||||
inputs[value] = v
|
||||
per_node_inputs.append(v)
|
||||
print()
|
||||
output_names = node.output
|
||||
op = dist_ir_module.add_op(
|
||||
outputs = dist_ir_function.add_op(
|
||||
op_type=node.op_type,
|
||||
name=node.name,
|
||||
inputs=per_node_inputs,
|
||||
output_names=output_names,
|
||||
)
|
||||
for output in node.output:
|
||||
# TODO lookup shape and dtype of input if exists
|
||||
v = Value(output, Tensor(Float()))
|
||||
output_src[output] = v
|
||||
# Match node's outputs with the output Values created in op:
|
||||
if len(node.output) == 1:
|
||||
assert isinstance(outputs, Value)
|
||||
outputs = [outputs]
|
||||
else:
|
||||
assert len(outputs) == len(node.output)
|
||||
for out_name, value in zip(node.output, outputs):
|
||||
assert out_name == value.name
|
||||
output_src[out_name] = value
|
||||
print(f"Found output {out_name}")
|
||||
print()
|
||||
|
||||
dist_ir_module.verify_ops_in_topological_order()
|
||||
return dist_ir_module
|
||||
return dist_ir_function.finalize()
|
||||
|
|
@@ -1,5 +1,5 @@
from .module import Module
from .device import Device
from .function import Function, FunctionMaker
from .op import Op
from .prettyprint import cpprint, pformat
from .topology import Topology
@@ -1,38 +1,21 @@
from dataclasses import dataclass
from typing import ClassVar


@dataclass(frozen=True)
class Device:

    device_variable_id = 0
    device_id: str
    device_type: str
    is_variable: bool = False

    def __init__(self, device_id, device_type, is_variable=False):
        self._device_id = device_id
        self._device_type = device_type
        self._is_variable = is_variable
    device_variable_id: ClassVar[int] = 0

    def __str__(self):
        return f"{self._device_id} ({self._device_type})"

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        return other is not None and self._device_id == other._device_id
        return f"{self.device_id} ({self.device_type})"

    def __lt__(self, other):
        return self._device_id < other._device_id

    def __hash__(self):
        return hash(str(self))

    @property
    def device_id(self):
        return self._device_id

    @property
    def device_type(self):
        return self._device_type

    @property
    def is_variable(self):
        return self._is_variable
        return self.device_id < other.device_id

    @classmethod
    def get_new_device_variable(cls, device_type):
@ -0,0 +1,259 @@
|
|||
from __future__ import annotations
|
||||
from collections import OrderedDict, defaultdict
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from .op import Op
|
||||
from .value import Value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Function:
|
||||
"""The core DistIR concept: a function.
|
||||
|
||||
A function has a name, a list of input values, a list of operations, and a
|
||||
list of output values. Functions are immutable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
ops: Tuple[Op]
|
||||
inputs: Tuple[Value]
|
||||
outputs: Tuple[Value]
|
||||
|
||||
# Map from Value -> List of Ops that consume it
|
||||
_consumers: Dict[Value, Tuple[Op]] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Creates the _consumers map, verifies the function, and performs
|
||||
type inference. This is called automatically at initialization.
|
||||
"""
|
||||
consumers = defaultdict(list)
|
||||
for op in self.ops:
|
||||
for in_edge in op.inputs:
|
||||
consumers[in_edge].append(op)
|
||||
for out_edge in op.outputs:
|
||||
consumers[out_edge] = []
|
||||
for v in consumers:
|
||||
consumers[v] = tuple(consumers[v])
|
||||
# Can't assign to frozen field:
|
||||
object.__setattr__(self, "_consumers", frozendict(consumers))
|
||||
|
||||
# Check that ops don't use values from the future
|
||||
self._verify_ops_in_topological_order()
|
||||
|
||||
def _verify_ops_in_topological_order(self):
|
||||
seen = set()
|
||||
for inp in self.inputs:
|
||||
seen.add(inp)
|
||||
for op in self.ops:
|
||||
for in_edge in op.inputs:
|
||||
if in_edge not in seen:
|
||||
raise ValueError(
|
||||
f"Ops are not in topological order: op {op.name} has "
|
||||
f"unseen edge {in_edge}"
|
||||
)
|
||||
for out_edge in op.outputs:
|
||||
seen.add(out_edge)
|
||||
|
||||
def get_consumers(self, value: Value) -> List[Op]:
|
||||
return self._consumers[value]
|
||||
|
||||
def __str__(self):
|
||||
# TODO can we use the prettyprint output as __str__?
|
||||
return self.get_summary()
|
||||
|
||||
def get_summary(self):
|
||||
output = ""
|
||||
output += "Function inputs:\n"
|
||||
for input_value in self.inputs:
|
||||
output += " " + str(input_value) + "\n"
|
||||
output += "\n"
|
||||
output += "Function outputs:\n"
|
||||
for input_value in self.outputs:
|
||||
output += " " + str(input_value) + "\n"
|
||||
output += "\n"
|
||||
output += "Ops:\n"
|
||||
for op in self.ops:
|
||||
output += str(op) + "\n"
|
||||
return output
|
||||
|
||||
def has_input(self, value):
|
||||
"""Checks whether the given value is an input of this function."""
|
||||
return value in self.inputs
|
||||
|
||||
def has_output(self, value):
|
||||
"""Checks whether the given value is an output of this function."""
|
||||
return value in self.outputs
|
||||
|
||||
def get_subfunction(
|
||||
self, op_names: Tuple[str], deepcopy: bool = False, name: Optional[str] = None
|
||||
) -> Function:
|
||||
"""Returns a Function comprised of the specified subset of ops."""
|
||||
subfunction = FunctionMaker(name)
|
||||
op_names_set = set(op_names)
|
||||
ops = []
|
||||
for op in self.ops:
|
||||
if op.name in op_names_set:
|
||||
ops.append(op)
|
||||
value_map = {}
|
||||
outputs = []
|
||||
for op in ops:
|
||||
subfunction_op_inputs = []
|
||||
for inp in op.inputs:
|
||||
if inp not in value_map:
|
||||
if deepcopy:
|
||||
value_map[inp] = subfunction.add_input_value(inp.name, inp.type)
|
||||
else:
|
||||
subfunction.inputs.append(inp)
|
||||
value_map[inp] = inp
|
||||
subfunction_op_inputs.append(value_map[inp])
|
||||
output_names = [output.name for output in op.outputs]
|
||||
if deepcopy:
|
||||
subfunction_op_outputs = subfunction.add_op(
|
||||
op.op_type,
|
||||
name=op.name,
|
||||
inputs=subfunction_op_inputs,
|
||||
attributes=copy.deepcopy(op.attributes),
|
||||
subfunctions=copy.deepcopy(op.subfunctions),
|
||||
output_names=output_names,
|
||||
)
|
||||
else:
|
||||
subfunction.ops.append(op)
|
||||
subfunction_op_outputs = op.outputs
|
||||
if not isinstance(subfunction_op_outputs, tuple):
|
||||
subfunction_op_outputs = (subfunction_op_outputs,)
|
||||
for orig_output, subfunction_output in zip(
|
||||
op.outputs, subfunction_op_outputs
|
||||
):
|
||||
# We need to explicitly set the subfunction outputs because some output
|
||||
# values might have consumers outside the subfunction (external).
|
||||
has_external_output = False
|
||||
if orig_output in self.outputs or any(
|
||||
[c not in ops for c in self._consumers[orig_output]]
|
||||
):
|
||||
outputs.append(subfunction_output)
|
||||
value_map[orig_output] = subfunction_output
|
||||
subfunction.set_outputs(outputs)
|
||||
return subfunction.finalize()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionMaker:
|
||||
"""A helper class for creating Functions."""
|
||||
|
||||
name: str = "foo"
|
||||
ops: List[Op] = field(default_factory=list)
|
||||
inputs: List[Value] = field(default_factory=list)
|
||||
outputs: List[Value] = field(default_factory=list)
|
||||
|
||||
def add_op(
|
||||
self,
|
||||
op_type,
|
||||
name=None,
|
||||
inputs: List[Value] = None,
|
||||
attributes: Dict[str, Any] = None,
|
||||
subfunctions: List["Function"] = None,
|
||||
output_names: List[str] = None,
|
||||
) -> Union[None, Value, Tuple[Value, ...]]:
|
||||
"""Adds an op to the function.
|
||||
|
||||
Args:
|
||||
op_type: The op's type.
|
||||
name: The op's name.
|
||||
inputs: The input values for this op.
|
||||
attributes: Any op-specific attributes.
|
||||
subfunctions: Any subfunctions this op is wrapping.
|
||||
output_names: An optional list of output value names.
|
||||
|
||||
Returns:
|
||||
The outputs of the newly created op.
|
||||
"""
|
||||
op = Op(
|
||||
op_type,
|
||||
name=name,
|
||||
inputs=None if inputs is None else tuple(inputs),
|
||||
attributes=None if attributes is None else frozendict(attributes),
|
||||
subfunctions=None if subfunctions is None else tuple(subfunctions),
|
||||
output_names=None if output_names is None else tuple(output_names),
|
||||
)
|
||||
self.ops.append(op)
|
||||
|
||||
# Return the op outputs.
|
||||
num_out_edges = len(op.outputs)
|
||||
if num_out_edges == 0:
|
||||
return None
|
||||
elif num_out_edges == 1:
|
||||
return op.outputs[0]
|
||||
else:
|
||||
return tuple(op.outputs)
|
||||
|
||||
def add_input_value(self, name, value_type):
|
||||
"""Adds an input value to the function and returns the value."""
|
||||
value = Value(name, value_type)
|
||||
if value in self.inputs:
|
||||
raise ValueError(f"Function already has input value {value}")
|
||||
self.inputs.append(value)
|
||||
return value
|
||||
|
||||
def set_outputs(self, outputs: Iterable[Value]):
|
||||
"""Sets the output of this function to be the given values. They must be
|
||||
valid values, i.e. outputs of some existing op in the function. This clears
|
||||
any previous outputs registered with this function.
|
||||
"""
|
||||
self.outputs.clear()
|
||||
seen = set()
|
||||
for output in outputs:
|
||||
if output in seen:
|
||||
raise ValueError(f"Function already has output value {output}")
|
||||
seen.add(output)
|
||||
self.outputs.append(output)
|
||||
|
||||
def set_outputs_auto(self):
|
||||
"""Marks all sink nodes in the graph as output values."""
|
||||
is_output = OrderedDict()
|
||||
|
||||
self.outputs.clear()
|
||||
for input_value in self.inputs:
|
||||
is_output[input_value] = True
|
||||
|
||||
for op in self.ops:
|
||||
for in_edge in op.inputs:
|
||||
is_output[in_edge] = False
|
||||
for out_edge in op.outputs:
|
||||
is_output[out_edge] = True
|
||||
|
||||
self.outputs = [v for v in is_output if is_output[v]]
|
||||
|
||||
def _get_ops_in_topological_order_helper(self, name, visited, order):
|
||||
visited.add(name)
|
||||
|
||||
out_edges = self.ops[name].outputs
|
||||
for out_edge in out_edges:
|
||||
output_name = out_edge
|
||||
if output_name not in visited:
|
||||
self._get_ops_in_topological_order_helper(output_name, visited, order)
|
||||
|
||||
order.append(name)
|
||||
|
||||
def get_ops_in_topological_order(self):
|
||||
"""Return ops in topological order. DEPRECATED, ops should always be
|
||||
topologically ordered.
|
||||
"""
|
||||
visited = set()
|
||||
order = []
|
||||
for name in self.ops:
|
||||
if name not in visited:
|
||||
self._get_ops_in_topological_order_helper(name, visited, order)
|
||||
return order[::-1]
|
||||
|
||||
def finalize(self) -> Function:
|
||||
"""Returns the created Function. Outputs, if unspecified, are the sinks."""
|
||||
if len(self.outputs) == 0:
|
||||
self.set_outputs_auto()
|
||||
|
||||
return Function(
|
||||
self.name, tuple(self.ops), tuple(self.inputs), tuple(self.outputs)
|
||||
)
|
|
@ -1,294 +0,0 @@
|
|||
from collections import OrderedDict, defaultdict
|
||||
import copy
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
from .op import Op
|
||||
from .value import Value
|
||||
|
||||
|
||||
class Module:
|
||||
def __init__(self, name=None):
|
||||
self._ops = OrderedDict()
|
||||
self._inputs = OrderedDict()
|
||||
self._outputs = OrderedDict()
|
||||
self._op_counter = defaultdict(int)
|
||||
self._consumers = defaultdict(list)
|
||||
self._name = name
|
||||
self._hash = None
|
||||
|
||||
def __str__(self):
|
||||
if self._name is not None:
|
||||
return self._name
|
||||
else:
|
||||
return self.get_summary()
|
||||
|
||||
def __repr__(self):
|
||||
return self.get_summary()
|
||||
|
||||
def __hash__(self):
|
||||
if self._hash is None:
|
||||
raise RuntimeError("Cannot hash unfinalized module!")
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other):
|
||||
for op_name in self._ops:
|
||||
if op_name not in other._ops or self._ops[op_name] != other._ops[op_name]:
|
||||
return False
|
||||
for input_name in self._inputs:
|
||||
if (
|
||||
input_name not in other._inputs
|
||||
or self._inputs[input_name] != other._inputs[input_name]
|
||||
):
|
||||
return False
|
||||
for output_name in self._outputs:
|
||||
if (
|
||||
output_name not in other._outputs
|
||||
or self._outputs[output_name] != other._outputs[output_name]
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_summary(self):
|
||||
output = ""
|
||||
output += "Module inputs:\n"
|
||||
for input_value in self._inputs.values():
|
||||
output += " " + str(input_value) + "\n"
|
||||
output += "\n"
|
||||
output += "Module outputs:\n"
|
||||
for input_value in self._outputs.values():
|
||||
output += " " + str(input_value) + "\n"
|
||||
output += "\n"
|
||||
output += "Ops:\n"
|
||||
for op in self._ops.values():
|
||||
output += str(op) + "\n"
|
||||
return output
|
||||
|
||||
# TODO: Convert to property
|
||||
def get_ops(self):
|
||||
"""Returns all ops in the module."""
|
||||
return self._ops
|
||||
|
||||
def is_op(self, name):
|
||||
"""Checks whether a op exists with the specified name."""
|
||||
return name in self._ops
|
||||
|
||||
def is_input(self, name):
|
||||
"""Checks whether an input value exists with the specified name."""
|
||||
return name in self._inputs
|
||||
|
||||
def is_output(self, name):
|
||||
"""Checks whether an output value exists with the specified name."""
|
||||
return name in self._outputs
|
||||
|
||||
def get_op(self, name):
|
||||
"""Returns the op with the specified name if it exists."""
|
||||
if name not in self._ops:
|
||||
return None
|
||||
return self._ops[name]
|
||||
|
||||
def get_input(self, name):
|
||||
"""Returns the input value with the specified name if it exists."""
|
||||
if name not in self._inputs:
|
||||
return None
|
||||
return self._inputs[name]
|
||||
|
||||
def get_inputs(self):
|
||||
"""Returns the module inputs."""
|
||||
return self._inputs.values()
|
||||
|
||||
def get_outputs(self):
|
||||
"""Returns the module outputs."""
|
||||
return self._outputs.values()
|
||||
|
||||
def add_op(
|
||||
self,
|
||||
op_type,
|
||||
name=None,
|
||||
inputs: List[Value] = None,
|
||||
attributes: Dict[str, Any] = None,
|
||||
submodules: List["Module"] = None,
|
||||
output_names: List[str] = None,
|
||||
) -> Union[None, Value, Tuple[Value, ...]]:
|
||||
"""Adds an op to the graph.
|
||||
|
||||
Args:
|
||||
op_type: The op's type.
|
||||
name: The op's name.
|
||||
inputs: The input values for this op.
|
||||
attributes: Any op-specific attributes.
|
||||
submodules: Any submodules this op is wrapping.
|
||||
output_names: An optinal list of output value names.
|
||||
|
||||
Returns:
|
||||
The outputs of the newly created op.
|
||||
"""
|
||||
if name in self._ops:
|
||||
raise ValueError(f"op with name {name} already exists!")
|
||||
elif name is None or name == "":
|
||||
name = f"{op_type}_#{self._op_counter[op_type]}"
|
||||
op = Op(
|
||||
name,
|
||||
op_type,
|
||||
in_edges=inputs,
|
||||
attributes=attributes,
|
||||
submodules=submodules,
|
||||
output_names=output_names,
|
||||
)
|
||||
self._ops[name] = op
|
||||
self._op_counter[op_type] += 1
|
||||
|
||||
# Update _consumers.
|
||||
out_edges = op.get_out_edges()
|
||||
for in_edge in inputs:
|
||||
self._consumers[in_edge.name].append(op.name)
|
||||
for out_edge in out_edges:
|
||||
self._consumers[out_edge.name] = []
|
||||
|
||||
# Return the op outputs.
|
||||
num_out_edges = len(out_edges)
|
||||
if num_out_edges == 0:
|
||||
return None
|
||||
elif num_out_edges == 1:
|
||||
return out_edges[0]
|
||||
else:
|
||||
return tuple(out_edges)
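A hedged usage sketch of this builder API (the `add_op`/`add_input_value` calls mirror the ones used elsewhere in this diff; the shapes and the `MatMul_0` name are made up, and `FunctionMaker` is the post-refactor name for this builder):

```python
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor

d0 = Device(0, "gpu")
fn = FunctionMaker()
x = fn.add_input_value("x", Tensor(dtype=Float(), shape=(16, 4), device=d0))
w = fn.add_input_value("w", Tensor(dtype=Float(), shape=(4, 2), device=d0))
# add_op returns None, a single output Value, or a tuple of Values,
# depending on how many outputs the op produces.
y = fn.add_op("MatMul", name="MatMul_0", inputs=[x, w], output_names=["y"])
```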
|
||||
|
||||
def add_input_value(self, name, value_type):
|
||||
"""Adds an input value to the graph and returns the value."""
|
||||
value = Value(name=name, value_type=value_type)
|
||||
if value.name in self._inputs:
|
||||
raise ValueError(f"Module already has input value with name {value.name}")
|
||||
self._inputs[value.name] = value
|
||||
return value
|
||||
|
||||
def get_consumers_for_value(self, name):
|
||||
return self._consumers[name]
|
||||
|
||||
def set_outputs(self, outputs: Iterable[Value]):
|
||||
"""Sets the output of this module to be the given values. They must be
|
||||
valid values, i.e. outputs of some existing op in the module. This clears
|
||||
any previous outputs registered with this module.
|
||||
"""
|
||||
for output in outputs:
|
||||
# NOTE: Using consumers as a proxy for valid values
|
||||
if output.name not in self._consumers:
|
||||
raise ValueError(f"Module has no value {output}")
|
||||
self._outputs.clear()
|
||||
for output in outputs:
|
||||
if output.name in self._outputs:
|
||||
raise ValueError(
|
||||
f"Module already has output value with name {output.name}"
|
||||
)
|
||||
self._outputs[output.name] = output
|
||||
|
||||
def set_outputs_auto(self):
|
||||
"""Marks all sink nodes in the graph as output values."""
|
||||
all_values = OrderedDict()
|
||||
is_output = OrderedDict()
|
||||
|
||||
self._outputs.clear()
|
||||
for input_value_name, input_value in self._inputs.items():
|
||||
all_values[input_value_name] = input_value
|
||||
is_output[input_value_name] = True
|
||||
|
||||
for op in self._ops.values():
|
||||
for in_edge in op.get_in_edges():
|
||||
is_output[in_edge.name] = False
|
||||
for out_edge in op.get_out_edges():
|
||||
all_values[out_edge.name] = out_edge
|
||||
is_output[out_edge.name] = True
|
||||
|
||||
for output_value_name in is_output:
|
||||
if is_output[output_value_name]:
|
||||
self._outputs[output_value_name] = all_values[output_value_name]
|
||||
|
||||
def _get_ops_in_topological_order_helper(self, name, visited, order):
|
||||
visited.add(name)
|
||||
|
||||
out_edges = self._ops[name].get_out_edges()
|
||||
for out_edge in out_edges:
|
||||
output_name = out_edge
|
||||
if output_name not in visited:
|
||||
self._get_ops_in_topological_order_helper(output_name, visited, order)
|
||||
|
||||
order.append(name)
|
||||
|
||||
def get_ops_in_topological_order(self):
|
||||
"""Return ops in topological order. DEPRECATED, ops should always be
|
||||
topologically ordered.
|
||||
"""
|
||||
visited = set()
|
||||
order = []
|
||||
for name in self._ops:
|
||||
if name not in visited:
|
||||
self._get_ops_in_topological_order_helper(name, visited, order)
|
||||
return order[::-1]
|
||||
|
||||
def verify_ops_in_topological_order(self):
|
||||
seen = set()
|
||||
for input_name in self._inputs:
|
||||
seen.add(input_name)
|
||||
|
||||
for op_name, op in self._ops.items():
|
||||
for in_edge in op.get_in_edges():
|
||||
if in_edge.name not in seen:
|
||||
raise ValueError(
|
||||
f"Ops are not in topological order: op {op_name} has "
|
||||
f"unseen edge {in_edge}"
|
||||
)
|
||||
for out_edge in op.get_out_edges():
|
||||
seen.add(out_edge.name)
|
||||
|
||||
def finalize(self):
|
||||
"""Performs some standard verification and inference passes. Use at the
|
||||
end whenever creating a module. Assumes that the module will no longer be
|
||||
modified after this function is called.
|
||||
"""
|
||||
# Putting this import at the top level causes an import loop
|
||||
from ..executor.shape_inference import infer_shapes
|
||||
|
||||
self.verify_ops_in_topological_order()
|
||||
if len(self._outputs) == 0:
|
||||
self.set_outputs_auto()
|
||||
infer_shapes(self)
|
||||
self._hash = hash(tuple(self._ops.keys()))
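To make the intended lifecycle concrete, a hedged sketch of the build-then-finalize workflow this docstring describes (hypothetical graph; in the refactored API `FunctionMaker.finalize()` returns the immutable `Function`, whereas the legacy `Module.finalize()` mutates in place):

```python
fn = FunctionMaker()
a = fn.add_input_value("a", Tensor(dtype=Float(), shape=(8, 8), device=d0))
b = fn.add_input_value("b", Tensor(dtype=Float(), shape=(8, 8), device=d0))
fn.add_op("Add", inputs=[a, b], output_names=["c"])
# Verifies topological order, auto-detects outputs if none were set explicitly,
# and runs type/shape inference over the completed graph.
f = fn.finalize()
```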
|
||||
|
||||
def get_submodule(self, op_names, name=None):
|
||||
"""Returns a submodule comprised of the specified subset of ops."""
|
||||
submodule = Module(name)
|
||||
value_map = {}
|
||||
outputs = []
|
||||
op_names_set = set(op_names)
|
||||
for op_name in op_names:
|
||||
op = self._ops[op_name]
|
||||
submodule_op_inputs = []
|
||||
for input in op.get_in_edges():
|
||||
if input.name not in value_map:
|
||||
value_map[input.name] = submodule.add_input_value(
|
||||
input.name, input.type
|
||||
)
|
||||
submodule_op_inputs.append(value_map[input.name])
|
||||
output_names = [output.name for output in op.get_out_edges()]
|
||||
submodule_op_outputs = submodule.add_op(
|
||||
op.op_type,
|
||||
name=op.name,
|
||||
inputs=submodule_op_inputs,
|
||||
attributes=op._attributes,
|
||||
submodules=copy.deepcopy(op._submodules),
|
||||
output_names=output_names,
|
||||
)
|
||||
if not isinstance(submodule_op_outputs, tuple):
|
||||
submodule_op_outputs = (submodule_op_outputs,)
|
||||
for output in submodule_op_outputs:
|
||||
# We need to explicitly set the submodule outputs because some output
|
||||
# values might have consumers outside the submodule (external).
|
||||
consumers = self._consumers[output.name]
|
||||
has_external_output = any([c not in op_names_set for c in consumers])
|
||||
if (
|
||||
output.name in self._outputs or has_external_output
|
||||
) and output.name not in op_names:
|
||||
outputs.append(output)
|
||||
value_map[output.name] = output
|
||||
submodule.set_outputs(outputs)
|
||||
submodule.finalize()
|
||||
return submodule
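For orientation, a hedged sketch of extracting a pipeline stage with `get_submodule` (the op names are purely illustrative):

```python
# Carve a two-op stage out of an existing module.
stage = module.get_submodule(["MatMul_0", "Relu_0"], name="stage0")
# Values defined outside the chosen ops are re-created as stage inputs, and
# values consumed outside the stage (or that are module outputs) become
# stage outputs.
```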
|
dist_ir/ir/op.py
|
@ -1,86 +1,59 @@
|
|||
from dataclasses import dataclass, field, InitVar
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from .op_register import OpRegister
|
||||
from .type import *
|
||||
from .value import Value
|
||||
from .type import Type
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Op:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
op_type,
|
||||
in_edges=None,
|
||||
attributes=None,
|
||||
submodules=None,
|
||||
output_names=None,
|
||||
):
|
||||
if op_type not in OpRegister:
|
||||
raise ValueError(f"Invalid op type {op_type}")
|
||||
self._name = name
|
||||
self._op_type = op_type
|
||||
if in_edges is None:
|
||||
self._in_edges = []
|
||||
op_type: str
|
||||
name: str = ""
|
||||
inputs: Tuple[Value] = field(default_factory=tuple)
|
||||
attributes: Dict[str, Any] = field(default_factory=frozendict)
|
||||
subfunctions: Tuple["Function"] = field(default_factory=tuple)
|
||||
outputs: Tuple[Value] = field(init=False)
|
||||
|
||||
# These are not fields, just parameters to init and post_init:
|
||||
output_names: InitVar[Tuple[str]] = None
|
||||
output_types: InitVar[Tuple[Type]] = None
|
||||
|
||||
def __post_init__(self, output_names, output_types):
|
||||
if self.op_type == "Pmap":
|
||||
# Handle pmap specially
|
||||
assert len(self.subfunctions) == 1
|
||||
# Number of inputs is arbitrary but positive
|
||||
assert len(self.inputs) > 0
|
||||
# Number of inputs matches subfunction
|
||||
assert len(self.inputs) == len(self.subfunctions[0].inputs)
|
||||
# Number of outputs is given by subfunction
|
||||
num_outputs = len(self.subfunctions[0].outputs)
|
||||
|
||||
else:
|
||||
self._in_edges = in_edges
|
||||
if attributes is None:
|
||||
self._attributes = {}
|
||||
if self.op_type not in OpRegister:
|
||||
raise ValueError(f"Invalid op type {self.op_type}")
|
||||
# Check that we got the right number of inputs
|
||||
assert len(self.inputs) == OpRegister[self.op_type].num_inputs
|
||||
# Number of outputs is given by OpRegister
|
||||
num_outputs = OpRegister[self.op_type].num_outputs
|
||||
|
||||
# Create the correct number of output values with appropriate types
|
||||
if output_names is None:
|
||||
output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)]
|
||||
else:
|
||||
self._attributes = attributes
|
||||
if submodules is None:
|
||||
self._submodules = []
|
||||
else:
|
||||
self._submodules = submodules
|
||||
self._out_edges = []
|
||||
OpRegister[op_type].infer_types(self, output_names)
|
||||
|
||||
def __str__(self):
|
||||
output = ""
|
||||
output += f"Name: {self._name}\n"
|
||||
output += f"Op type: {self._op_type}\n"
|
||||
output += "Inputs:\n"
|
||||
for in_edge in self._in_edges:
|
||||
output += " " + str(in_edge) + "\n"
|
||||
output += "Outputs:\n"
|
||||
for out_edge in self._out_edges:
|
||||
output += " " + str(out_edge) + "\n"
|
||||
if len(self._submodules) > 0:
|
||||
output += "Submodules:\n"
|
||||
for submodule in self._submodules:
|
||||
output += "\n".join(
|
||||
[" " + line for line in str(submodule).split("\n")]
|
||||
)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def add_in_edge(self, in_edge: Value):
|
||||
"""Adds an input edge."""
|
||||
self._in_edges.append(in_edge)
|
||||
|
||||
def add_out_edge(self, out_edge: Value):
|
||||
"""Adds an output edge."""
|
||||
self._out_edges.append(out_edge)
|
||||
|
||||
def get_in_edges(self):
|
||||
"""Returns all input edges."""
|
||||
return self._in_edges
|
||||
|
||||
def get_out_edges(self):
|
||||
"""Returns all output edges."""
|
||||
return self._out_edges
|
||||
|
||||
def get_attribute(self, attribute_name):
|
||||
"""Returns the specified attributes, or throws error if it does not exist."""
|
||||
return self._attributes[attribute_name]
|
||||
|
||||
def get_submodule(self, idx):
|
||||
"""Returns the submodule at the specified index."""
|
||||
return self._submodules[idx]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def op_type(self):
|
||||
return self._op_type
|
||||
assert len(output_names) == num_outputs
|
||||
if output_types is None:
|
||||
output_types = [None for i in range(num_outputs)]
|
||||
elif len(output_types) != num_outputs:
|
||||
raise ValueError(
|
||||
f"Op {self.name} has {len(output_types)} outputs; "
|
||||
f"num_outputs expected"
|
||||
)
|
||||
outputs = tuple(
|
||||
Value(out_name, out_type)
|
||||
for out_name, out_type in zip(output_names, output_types)
|
||||
)
|
||||
object.__setattr__(self, "outputs", outputs) # Can't assign to frozen field
|
||||
|
|
|
@ -1,193 +1,35 @@
|
|||
from .device import Device
|
||||
from .type import Tensor, TupleType
|
||||
from .value import Value
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpRegisterEntry:
|
||||
def __init__(self, input_types, output_types):
|
||||
self._input_types = input_types
|
||||
self._output_types = output_types
|
||||
|
||||
def infer_types(self, op, output_names=None):
|
||||
# Verify that number of inputs and input types match the expected input types.
|
||||
inputs = op.get_in_edges()
|
||||
if len(inputs) != len(self._input_types):
|
||||
raise ValueError(
|
||||
f"Op {op.name}: Expected {len(self._input_types)} inputs, "
|
||||
f"got {len(inputs)}"
|
||||
)
|
||||
for i, (input, input_type) in enumerate(zip(inputs, self._input_types)):
|
||||
if not isinstance(input.type, input_type):
|
||||
raise ValueError(
|
||||
f"Op {op.name}: Expected input of type {input_type} for "
|
||||
f"input {i}, got input of type {input.type}"
|
||||
)
|
||||
|
||||
# Verify that the number of output names is correct if specified.
|
||||
if output_names is not None and len(output_names) != len(self._output_types):
|
||||
raise ValueError(
|
||||
f"Op {op.name}: Expected {len(output_names)} outputs, "
|
||||
f"got {len(self._output_types)}"
|
||||
)
|
||||
|
||||
# Construct the output values and add them to the op's out edge list.
|
||||
for i, output_type in enumerate(self._output_types):
|
||||
if output_names is not None and output_names[i] != "":
|
||||
output_name = output_names[i]
|
||||
else:
|
||||
output_name = f"{op.name}/{i}"
|
||||
op.add_out_edge(Value(output_name, value_type=output_type()))
|
||||
|
||||
|
||||
class AllreduceOpRegisterEntry(OpRegisterEntry):
|
||||
# TODO: Remove this and handle generic types in OpRegisterEntry
|
||||
def infer_types(self, op, output_names=None):
|
||||
inputs = op.get_in_edges()
|
||||
if len(inputs) != 1:
|
||||
raise ValueError(f"Op {op.name}: Expected 1 input, got {len(inputs)}")
|
||||
elif not isinstance(inputs[0].type, TupleType):
|
||||
raise ValueError(
|
||||
f"Op {op.name}: Expected input of type {self._input_types[0]}, "
|
||||
f"got input of type {inputs[0].type}"
|
||||
)
|
||||
if output_names is not None:
|
||||
output_name = output_names[0]
|
||||
else:
|
||||
output_name = f"{op.name}/{0}"
|
||||
output_value_type = copy.deepcopy(inputs[0].type)
|
||||
op.add_out_edge(Value(name=output_name, value_type=output_value_type))
|
||||
|
||||
|
||||
class BroadcastScatterOpRegisterEntry(OpRegisterEntry):
|
||||
# TODO: Remove this and handle generic types in OpRegisterEntry
|
||||
def infer_types(self, op, output_names=None):
|
||||
inputs = op.get_in_edges()
|
||||
devices = op.get_attribute("devices")
|
||||
if output_names is not None and len(output_names) != 1:
|
||||
raise ValueError(
|
||||
f"Op {op.name}: Expected 1 output name but got {len(output_names)}"
|
||||
)
|
||||
output_types = []
|
||||
for i, device in enumerate(devices):
|
||||
output_type = copy.deepcopy(inputs[0].type)
|
||||
output_type.set_device(device)
|
||||
if op.op_type == "Scatter":
|
||||
if isinstance(output_type, Tensor):
|
||||
output_type.shape = None
|
||||
output_types.append(output_type)
|
||||
output_value = Value(output_names[0], value_type=TupleType(output_types))
|
||||
op.add_out_edge(output_value)
|
||||
|
||||
|
||||
class PmapOpRegisterEntry(OpRegisterEntry):
|
||||
def infer_types(self, op, output_names=None):
|
||||
devices = op.get_attribute("devices")
|
||||
submodule = op.get_submodule(0)
|
||||
submodule_inputs = submodule.get_inputs()
|
||||
submodule_outputs = submodule.get_outputs()
|
||||
# TODO: If we want a more robust solution for nested pmaps, move the
|
||||
# parameterization over device variable to the module code
|
||||
# TODO: Handle multiple device types?
|
||||
d = Device.get_new_device_variable(devices[0].device_type)
|
||||
for in_edge in submodule_inputs:
|
||||
in_edge.type.set_device(d)
|
||||
|
||||
# TODO: Change the submodule input names to indicate they are
|
||||
# parameterized over the devices
|
||||
|
||||
for i, out_edge in enumerate(submodule_outputs):
|
||||
output_types = []
|
||||
for device in devices:
|
||||
output_type = copy.deepcopy(out_edge.type)
|
||||
output_type.set_device(device)
|
||||
output_types.append(output_type)
|
||||
if output_names is None:
|
||||
output_name = f"{out_edge.name}is"
|
||||
else:
|
||||
output_name = output_names[i]
|
||||
output_value = Value(output_name, value_type=TupleType(output_types))
|
||||
op.add_out_edge(output_value)
|
||||
|
||||
|
||||
class SelectOpRegisterEntry(OpRegisterEntry):
|
||||
def infer_types(self, op, output_names=None):
|
||||
inputs = op.get_in_edges()
|
||||
dim = op.get_attribute("dim")
|
||||
output_value = Value(
|
||||
output_names[0], value_type=copy.deepcopy(inputs[0].type.types[dim])
|
||||
)
|
||||
op.add_out_edge(output_value)
|
||||
|
||||
|
||||
class SendOpRegisterEntry(OpRegisterEntry):
|
||||
def infer_types(self, op, output_names=None):
|
||||
inputs = op.get_in_edges()
|
||||
device = op.get_attribute("device")
|
||||
output_value_type = copy.deepcopy(inputs[0].type)
|
||||
output_value_type.set_device(device)
|
||||
output_value = Value(output_names[0], value_type=output_value_type)
|
||||
op.add_out_edge(output_value)
|
||||
|
||||
|
||||
class SplitOpRegisterEntry(OpRegisterEntry):
|
||||
def infer_types(self, op, output_names=None):
|
||||
inputs = op.get_in_edges()
|
||||
num_splits = op.get_attribute("num_splits")
|
||||
output_types = []
|
||||
for i in range(num_splits):
|
||||
output_type = copy.deepcopy(inputs[0].type)
|
||||
output_type.shape = None
|
||||
output_types.append(output_type)
|
||||
output_value = Value(output_names[0], value_type=TupleType(output_types))
|
||||
op.add_out_edge(output_value)
|
||||
num_inputs: int
|
||||
num_outputs: int
|
||||
|
||||
|
||||
OpRegister = {
|
||||
"Add": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"Allreduce": AllreduceOpRegisterEntry(
|
||||
input_types=[TupleType], output_types=[TupleType]
|
||||
),
|
||||
"Broadcast": BroadcastScatterOpRegisterEntry(
|
||||
input_types=[Tensor], output_types=[TupleType]
|
||||
),
|
||||
"BroadcastGradientArgs": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor], output_types=[Tensor, Tensor]
|
||||
),
|
||||
"Concat": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"Gather": OpRegisterEntry(input_types=[TupleType], output_types=[Tensor]),
|
||||
"Gemm": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor]
|
||||
),
|
||||
"Loss": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"LossGrad": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"MatMul": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"MatMulGrad": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor, Tensor]
|
||||
),
|
||||
"ReduceSumTraining": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor], output_types=[Tensor]
|
||||
),
|
||||
"Relu": OpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
|
||||
"ReluGrad": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"Reshape": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"Opt": OpRegisterEntry(input_types=[Tensor, Tensor], output_types=[Tensor]),
|
||||
"Pmap": PmapOpRegisterEntry(input_types=None, output_types=None),
|
||||
"Scatter": BroadcastScatterOpRegisterEntry(
|
||||
input_types=[Tensor], output_types=[TupleType]
|
||||
),
|
||||
"Select": SelectOpRegisterEntry(input_types=[TupleType], output_types=[Tensor]),
|
||||
"Send": SendOpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
|
||||
"SGDOptimizer": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor, Tensor]
|
||||
),
|
||||
"Shape": OpRegisterEntry(input_types=[Tensor], output_types=[Tensor]),
|
||||
"SoftmaxCrossEntropy": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor], output_types=[Tensor, Tensor]
|
||||
),
|
||||
"SoftmaxCrossEntropyGrad": OpRegisterEntry(
|
||||
input_types=[Tensor, Tensor, Tensor], output_types=[Tensor]
|
||||
),
|
||||
"Split": SplitOpRegisterEntry(input_types=[Tensor], output_types=[TupleType]),
|
||||
"Add": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Allreduce": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"Broadcast": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"BroadcastGradientArgs": OpRegisterEntry(num_inputs=2, num_outputs=2),
|
||||
"Concat": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Gather": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"Gemm": OpRegisterEntry(num_inputs=3, num_outputs=1),
|
||||
"Loss": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"LossGrad": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"MatMul": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"MatMulGrad": OpRegisterEntry(num_inputs=3, num_outputs=2),
|
||||
"ReduceSumTraining": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Relu": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"ReluGrad": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Reshape": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Opt": OpRegisterEntry(num_inputs=2, num_outputs=1),
|
||||
"Scatter": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"Select": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"Send": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"SGDOptimizer": OpRegisterEntry(num_inputs=3, num_outputs=2),
|
||||
"Shape": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
"SoftmaxCrossEntropy": OpRegisterEntry(num_inputs=2, num_outputs=2),
|
||||
"SoftmaxCrossEntropyGrad": OpRegisterEntry(num_inputs=3, num_outputs=1),
|
||||
"Split": OpRegisterEntry(num_inputs=1, num_outputs=1),
|
||||
}
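Since each register entry now records only arity, the op-creation check reduces to a count comparison; a hedged sketch mirroring the `__post_init__` logic earlier in this diff (`check_arity` is a hypothetical helper, not part of the codebase):

```python
def check_arity(op_type, inputs, output_names=None):
    """Validate input/output counts against OpRegister (hypothetical helper)."""
    if op_type not in OpRegister:
        raise ValueError(f"Invalid op type {op_type}")
    entry = OpRegister[op_type]
    if len(inputs) != entry.num_inputs:
        raise ValueError(f"{op_type} expects {entry.num_inputs} inputs")
    if output_names is not None and len(output_names) != entry.num_outputs:
        raise ValueError(f"{op_type} produces {entry.num_outputs} outputs")
    return entry.num_outputs

assert check_arity("MatMulGrad", inputs=["x", "w", "dy"]) == 2
```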
|
||||
|
|
|
@ -47,7 +47,7 @@ from prettyprinter.prettyprinter import (
|
|||
)
|
||||
from prettyprinter.utils import intersperse
|
||||
|
||||
from .module import Module
|
||||
from .function import Function
|
||||
from .value import Value
|
||||
from .type import Type, Int, Float, Tensor, TupleType
|
||||
from .device import Device
|
||||
|
@ -97,10 +97,10 @@ def interline(*docs):
|
|||
# ----------------------------------------
|
||||
|
||||
|
||||
def _pprint_module_body(module: Module, ctx):
|
||||
ops = [pretty_dispatch(op, ctx) for op in module.get_ops().values()]
|
||||
def _pprint_function_body(function: Function, ctx):
|
||||
ops = [pretty_dispatch(op, ctx) for op in function.ops]
|
||||
# Include the outputs as a final "return" op
|
||||
outputs = concat(_join(*(r.name for r in module.get_outputs())))
|
||||
outputs = concat(_join(*(r.name for r in function.outputs)))
|
||||
return_line = group(
|
||||
nest(ctx.indent, concat([pp_reserved("return"), LINE, outputs]))
|
||||
)
|
||||
|
@ -108,12 +108,12 @@ def _pprint_module_body(module: Module, ctx):
|
|||
return ops
|
||||
|
||||
|
||||
@register_pretty(Module)
|
||||
def _(module: Module, ctx):
|
||||
ops = _pprint_module_body(module, ctx)
|
||||
@register_pretty(Function)
|
||||
def _(function: Function, ctx):
|
||||
ops = _pprint_function_body(function, ctx)
|
||||
return concat(
|
||||
[
|
||||
pretty_call(ctx, pp_fnname("Module"), *module.get_inputs()),
|
||||
pretty_call(ctx, pp_fnname("Function"), *function.inputs),
|
||||
nest(ctx.indent, concat([COLON, HARDLINE, interline(*ops)])),
|
||||
]
|
||||
)
|
||||
|
@ -121,15 +121,15 @@ def _(module: Module, ctx):
|
|||
|
||||
@register_pretty(Op)
|
||||
def _(op: Op, ctx):
|
||||
results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.get_out_edges())))
|
||||
args = concat(_join(*(v.name for v in op.get_in_edges())))
|
||||
results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs)))
|
||||
args = concat(_join(*(v.name for v in op.inputs)))
|
||||
|
||||
if op.op_type == "Pmap":
|
||||
lambda_args = _join(
|
||||
*(pretty_dispatch(i, ctx) for i in op.get_submodule(0).get_inputs())
|
||||
*(pretty_dispatch(i, ctx) for i in op.subfunctions[0].inputs)
|
||||
)
|
||||
lambda_args = concat([LPAREN, nest(ctx.indent, concat(lambda_args)), RPAREN])
|
||||
lambda_body = _pprint_module_body(op.get_submodule(0), ctx)
|
||||
lambda_body = _pprint_function_body(op.subfunctions[0], ctx)
|
||||
actual_args = group(
|
||||
concat(
|
||||
[
|
||||
|
@ -139,7 +139,8 @@ def _(op: Op, ctx):
|
|||
]
|
||||
)
|
||||
)
|
||||
d = str(op.get_attribute("device_var").device_id)
|
||||
# TODO: Also print out the list of devices this pmaps over
|
||||
d = str(op.attributes["device_var"].device_id)
|
||||
pmap_args = nest(
|
||||
ctx.indent,
|
||||
concat(
|
||||
|
|
|
@ -1,24 +1,26 @@
|
|||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from operator import add, mul
|
||||
from typing import Optional, Tuple, TypeVar
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
from .device import Device
|
||||
from .utils import singleton
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Type:
|
||||
def __init__(self, has_device=False):
|
||||
self._has_device = has_device
|
||||
self._device = None
|
||||
"""Base class for all types."""
|
||||
|
||||
def set_device(self, device):
|
||||
if self._has_device:
|
||||
self._device = device
|
||||
device: Optional[Device] = None
|
||||
|
||||
def get_all_devices(self):
|
||||
if self._has_device and self._device is not None:
|
||||
return set([self._device])
|
||||
def get_all_devices(self) -> Set[Device]:
|
||||
"""Returns all devices that a value of this type lives on. For example,
|
||||
a tuple can have different elements live on different devices.
|
||||
|
||||
Subclasses should override this default implementation.
|
||||
"""
|
||||
if self.device is not None:
|
||||
return set([self.device])
|
||||
return set()
|
||||
|
||||
|
||||
|
@ -29,11 +31,8 @@ class Type:
|
|||
class Int(Type):
|
||||
"""The integer type. A singleton class."""
|
||||
|
||||
def __str__(self):
|
||||
return "Int"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
return "Int"
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
|
@ -44,90 +43,61 @@ class Int(Type):
|
|||
class Float(Type):
|
||||
"""The float type. A singleton class."""
|
||||
|
||||
def __str__(self):
|
||||
return f"Float"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
return "Float"
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Tensor(Type):
|
||||
"""The tensor type. Contains a shape and an element type (dtype)."""
|
||||
|
||||
# TODO have a global cache to avoid creating multiple objects of same type?
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: Type = None,
|
||||
shape: Optional[Tuple[int]] = None,
|
||||
device: Device = None,
|
||||
):
|
||||
Type.__init__(self, has_device=True)
|
||||
self._device = device
|
||||
self._shape = shape
|
||||
self._dtype = dtype
|
||||
dtype: Type = None  # Unable to make this non-optional, but it should be treated as required
|
||||
shape: Optional[Tuple[int]] = None
|
||||
|
||||
def __init__(self, dtype=None, shape=None, device=None):
|
||||
# TODO make dtype a required argument?
|
||||
assert dtype is None or isinstance(dtype, Type)
|
||||
assert shape is None or (
|
||||
isinstance(shape, tuple) and all(isinstance(n, int) for n in shape)
|
||||
)
|
||||
object.__setattr__(self, "dtype", dtype) # Can't assign to frozen field
|
||||
object.__setattr__(self, "shape", shape)
|
||||
Type.__init__(self, device=device)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"Tensor[shape={self._shape}, dtype={self._dtype}, device={self._device}]"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self._dtype == other._dtype
|
||||
and self._shape == other._shape
|
||||
and self._device == other._device
|
||||
)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
@shape.setter
|
||||
def shape(self, shape):
|
||||
self._shape = shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, dtype):
|
||||
self._dtype = dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
return f"Tensor[shape={self.shape}, dtype={self.dtype}, device={self.device}]"
|
||||
|
||||
def size(self):
|
||||
return reduce(mul, self._shape) * self._dtype.size
|
||||
return reduce(mul, self.shape) * self.dtype.size
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TupleType(Type):
|
||||
def __init__(self, types):
|
||||
Type.__init__(self)
|
||||
self._types = types
|
||||
|
||||
def __str__(self):
|
||||
elems_str = ", ".join(str(t) for t in self._types)
|
||||
return f"Tuple[{elems_str}]"
|
||||
types: Tuple[Type] = None
|
||||
|
||||
def __init__(self, types):
|
||||
# Override __init__ because it doesn't make sense for a tuple to have a
|
||||
# device. Devices are stored in each tuple element.
|
||||
object.__setattr__(self, "types", types) # Can't assign to frozen field
|
||||
assert isinstance(types, tuple) and all(isinstance(t, Type) for t in types)
|
||||
assert self.device is None
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
elems_str = ", ".join(str(t) for t in self.types)
|
||||
return f"Tuple[{elems_str}]"
|
||||
|
||||
@property
|
||||
def types(self):
|
||||
return self._types
|
||||
|
||||
def get_all_devices(self):
|
||||
def get_all_devices(self) -> Set[Device]:
|
||||
devices = set()
|
||||
for typ in self._types:
|
||||
for typ in self.types:
|
||||
devices.update(typ.get_all_devices())
|
||||
return devices
|
||||
|
||||
def size(self):
|
||||
return reduce(add, [typ.size() for typ in self._types])
|
||||
return reduce(add, [typ.size() for typ in self.types])
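A small hedged sketch exercising these frozen type classes (assumes the `Device(device_id, device_type)` constructor used in the tests below and `Float().size == 4` as defined above):

```python
d0, d1 = Device(0, "gpu"), Device(1, "gpu")
t0 = Tensor(dtype=Float(), shape=(16, 4), device=d0)
t1 = Tensor(dtype=Float(), shape=(16, 4), device=d1)
tup = TupleType((t0, t1))

assert t0.size() == 16 * 4 * 4              # elements * 4 bytes per Float
assert tup.size() == t0.size() + t1.size()  # summed over the tuple's elements
assert tup.get_all_devices() == {d0, d1}    # devices are collected recursively
```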
|
||||
|
|
|
@ -1,28 +1,14 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .type import Type
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class Value:
|
||||
def __init__(self, name, value_type):
|
||||
self._name = name
|
||||
self._type = value_type
|
||||
"""A DistIR value. While values have names, DistIR makes no attempt to ensure
|
||||
value names are unique in a function. Therefore Value equality is object
|
||||
identity, i.e. reference equality.
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
return f"{self._name}: type={str(self._type)}"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(repr(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._name == other._name and self._type == other._type
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._type
|
||||
|
||||
@type.setter
|
||||
def type(self, typ):
|
||||
self._type = typ
|
||||
name: str
|
||||
type: Type
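Because the dataclass is declared with `eq=False`, equality and hashing fall back to object identity; a hedged sketch of the consequence:

```python
a = Value("x", Tensor(dtype=Float(), shape=(4,), device=None))
b = Value("x", Tensor(dtype=Float(), shape=(4,), device=None))
assert a is not b and a != b   # same name and type, but distinct Values
index = {a: 0, b: 1}           # both can coexist as dictionary keys
```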
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
from ..ir.module import Module
|
||||
|
||||
import copy
|
||||
from ..ir.function import FunctionMaker
|
||||
|
||||
|
||||
class DataParallelTransform:
|
||||
"""Partitions a module using data parallelism.
|
||||
"""Partitions a function using data parallelism.
|
||||
|
||||
Replicates the given model across devices by instantiating an identical version
|
||||
of the model on each device. The user specifies which input values to
|
||||
partition between each device as well as the dimension to partition for each input
|
||||
(e.g. selecting the first dimension for the input minibatch would partition
|
||||
along the batch dimension). The selected input values are scattered between
|
||||
each device, while the remaining input values are broadcasted. The module will
|
||||
each device, while the remaining input values are broadcasted. The function will
|
||||
be replicated using a Pmap operator. The original output values are retrieved
|
||||
from each replica through Allreduce operators.
|
||||
|
||||
|
@ -26,33 +24,31 @@ class DataParallelTransform:
|
|||
self._reduction_params = reduction_params
|
||||
self._devices = devices
|
||||
|
||||
def apply(self, module):
|
||||
"""Applies the transformation to the given module and returns the transformed module."""
|
||||
transformed_module = Module()
|
||||
def apply(self, function):
|
||||
"""Applies the transformation to the given function and returns the transformed function."""
|
||||
transformed_function = FunctionMaker()
|
||||
|
||||
# Either scatter or broadcast each input value depending on what the user
|
||||
# has requested.
|
||||
# TODO: Add explicit Send ops if the source device is not one of the
|
||||
# destination devices.
|
||||
input_values = module.get_inputs()
|
||||
input_values = function.inputs
|
||||
pmap_input_values = []
|
||||
for input_value in input_values:
|
||||
v = transformed_module.add_input_value(
|
||||
input_value.name, copy.deepcopy(input_value.type)
|
||||
)
|
||||
if input_value.name in self._batch_dims:
|
||||
vs = transformed_module.add_op(
|
||||
v = transformed_function.add_input_value(input_value.name, input_value.type)
|
||||
if input_value in self._batch_dims:
|
||||
vs = transformed_function.add_op(
|
||||
"Scatter",
|
||||
name=f"Scatter/{v.name}",
|
||||
inputs=[v],
|
||||
attributes={
|
||||
"devices": self._devices,
|
||||
"dim": self._batch_dims[input_value.name],
|
||||
"dim": self._batch_dims[input_value],
|
||||
},
|
||||
output_names=[f"{v.name}s"],
|
||||
)
|
||||
else:
|
||||
vs = transformed_module.add_op(
|
||||
vs = transformed_function.add_op(
|
||||
"Broadcast",
|
||||
name=f"Broadcast/{v.name}",
|
||||
inputs=[v],
|
||||
|
@ -61,18 +57,18 @@ class DataParallelTransform:
|
|||
)
|
||||
pmap_input_values.append(vs)
|
||||
|
||||
# Add the Pmap operator to the transformed module. The Pmap operator will
|
||||
# encapsulate the original module.
|
||||
output_values = module.get_outputs()
|
||||
# Add the Pmap operator to the transformed function. The Pmap operator will
|
||||
# encapsulate the original function.
|
||||
output_values = function.outputs
|
||||
pmap_output_names = []
|
||||
for i, output_value in enumerate(output_values):
|
||||
pmap_output_name = f"{output_value.name}is"
|
||||
pmap_output_names.append(pmap_output_name)
|
||||
pmap_output_values = transformed_module.add_op(
|
||||
pmap_output_values = transformed_function.add_op(
|
||||
"Pmap",
|
||||
inputs=pmap_input_values,
|
||||
attributes={"devices": self._devices},
|
||||
submodules=[module],
|
||||
subfunctions=[function],
|
||||
output_names=pmap_output_names,
|
||||
)
|
||||
|
||||
|
@ -83,23 +79,23 @@ class DataParallelTransform:
|
|||
# TODO: Add explicit Send ops if the destination device is not one of the
|
||||
# source devices.
|
||||
for i, output_value in enumerate(output_values):
|
||||
reduction_op_type = self._reduction_params[output_value.name]["op_type"]
|
||||
reduction_op_type = self._reduction_params[output_value]["op_type"]
|
||||
if reduction_op_type == "Allreduce":
|
||||
transformed_module.add_op(
|
||||
transformed_function.add_op(
|
||||
"Allreduce",
|
||||
name=f"Allreduce/{output_value.name}",
|
||||
inputs=[pmap_output_values[i]],
|
||||
output_names=[f"{output_value.name}s"],
|
||||
)
|
||||
elif reduction_op_type == "Gather":
|
||||
dim = self._reduction_params[output_value.name]["dim"]
|
||||
device = self._reduction_params[output_value.name]["device"]
|
||||
transformed_module.add_op(
|
||||
dim = self._reduction_params[output_value]["dim"]
|
||||
device = self._reduction_params[output_value]["device"]
|
||||
transformed_function.add_op(
|
||||
"Gather",
|
||||
name=f"Gather/{output_value.name}",
|
||||
inputs=[pmap_output_values[i]],
|
||||
attributes={"dim": dim, "device": device},
|
||||
output_names=[f"{output_value.name}s"],
|
||||
output_names=[f"{output_value.name}"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -107,4 +103,4 @@ class DataParallelTransform:
|
|||
f"output value {output_value}"
|
||||
)
|
||||
|
||||
return transformed_module
|
||||
return transformed_function.finalize()
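A hedged end-to-end sketch of this transform (the keyword arguments follow the constructor assignments above and the notebook snippet later in this diff; `fn`, `x`, `z`, and `topology` are hypothetical, and after this refactor the dictionaries are keyed by `Value` objects rather than names):

```python
transform = DataParallelTransform(
    batch_dims={x: 0, z: 0},  # scatter these inputs along dimension 0
    reduction_params={dw: {"op_type": "Allreduce"} for dw in fn.outputs},
    devices=topology.devices,
)
dp_fn = transform.apply(fn)   # returns a finalized, transformed Function
```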
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from collections import defaultdict
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
from ..ir import Device, Module
|
||||
from ..ir import Device, Function
|
||||
from .pipeline_parallel_scheduler import PipelineParallelScheduler
|
||||
|
||||
|
||||
|
@ -9,13 +9,12 @@ class FIFOScheduler(PipelineParallelScheduler):
|
|||
"""Implements a FIFO schedule where all forward pass stages are executed before
|
||||
backward pass stages."""
|
||||
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
|
||||
ready_stages_by_type = defaultdict(list)
|
||||
for ready_stage in self._ready_stages[device]:
|
||||
# TODO: Use a more robust method to identify backwards pass stages.
|
||||
(stage, microbatch) = ready_stage
|
||||
ops = stage.get_ops()
|
||||
if "Grad" in list(ops.keys())[0]:
|
||||
if "Grad" in stage.ops[0].name:
|
||||
ready_stages_by_type["bw"].append(ready_stage)
|
||||
else:
|
||||
ready_stages_by_type["fw"].append(ready_stage)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from collections import defaultdict
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
from ..ir import Device, Module
|
||||
from ..ir import Device, Function
|
||||
from .pipeline_parallel_scheduler import PipelineParallelScheduler
|
||||
|
||||
|
||||
|
@ -12,13 +12,12 @@ class PipeDreamScheduler(PipelineParallelScheduler):
|
|||
PipelineParallelScheduler.__init__(self, num_microbatches)
|
||||
self._prev_stage_types = defaultdict(lambda: "bw")
|
||||
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
|
||||
ready_stages_by_type = defaultdict(list)
|
||||
for ready_stage in self._ready_stages[device]:
|
||||
# TODO: Use a more robust method to identify backwards pass stages.
|
||||
(stage, microbatch) = ready_stage
|
||||
ops = stage.get_ops()
|
||||
if "Grad" in list(ops.keys())[0]:
|
||||
if "Grad" in stage.ops[0].name:
|
||||
ready_stages_by_type["bw"].append(ready_stage)
|
||||
else:
|
||||
ready_stages_by_type["fw"].append(ready_stage)
|
||||
|
|
|
@ -2,14 +2,14 @@ from abc import ABC, abstractmethod
|
|||
from collections import defaultdict
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
from ..ir import Module, Device, Op
|
||||
from ..ir import Function, Device, Op
|
||||
from . import utils
|
||||
|
||||
|
||||
class PipelineParallelScheduler(ABC):
|
||||
"""Interface for a pipeline parallel scheduler.
|
||||
|
||||
Pipeline parallel schedulers take as input a DistIR module, the number of
|
||||
Pipeline parallel schedulers take as input a DistIR function, the number of
|
||||
microbatches to partition each minibatch into, and a partition map which
|
||||
captures the explicit placement of each stage onto corresponding devices.
|
||||
The scheduler will return a time-ordered list of stages to execute on each
|
||||
|
@ -24,13 +24,13 @@ class PipelineParallelScheduler(ABC):
|
|||
self._remaining_inputs = defaultdict(lambda: 0)
|
||||
self._ready_stages = defaultdict(list)
|
||||
|
||||
def _prepare_stages_to_schedule(self, module, partition_map):
|
||||
def _prepare_stages_to_schedule(self, function, partition_map):
|
||||
"""Enumerates the stages to schedule on each device across all microbatches."""
|
||||
for stage, device in partition_map.items():
|
||||
inputs = stage.get_inputs()
|
||||
inputs = stage.inputs
|
||||
remaining_inputs = len(inputs)
|
||||
for input in inputs:
|
||||
if module.is_input(input.name):
|
||||
if input in function.inputs:
|
||||
remaining_inputs -= 1
|
||||
for i in range(self._num_microbatches):
|
||||
self._remaining_inputs[(stage, i)] = remaining_inputs
|
||||
|
@ -38,11 +38,11 @@ class PipelineParallelScheduler(ABC):
|
|||
self._ready_stages[device].append((stage, i))
|
||||
|
||||
@abstractmethod
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Module, int]:
|
||||
def _get_next_stage_to_schedule(self, device: Device) -> Tuple[Function, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def schedule(self, module, partition_map):
|
||||
self._prepare_stages_to_schedule(module, partition_map)
|
||||
def schedule(self, function, partition_map):
|
||||
self._prepare_stages_to_schedule(function, partition_map)
|
||||
op_to_stage = utils.get_op_to_stage_map(list(partition_map.keys()))
|
||||
num_scheduled_stages = 0
|
||||
total_stages_to_schedule = len(partition_map) * self._num_microbatches
|
||||
|
@ -58,10 +58,10 @@ class PipelineParallelScheduler(ABC):
|
|||
# TODO: Optimize this so it isn't an O(N) call?
|
||||
self._ready_stages[device].remove(stage_to_schedule)
|
||||
num_scheduled_stages += 1
|
||||
outputs = stage.get_outputs()
|
||||
outputs = stage.outputs
|
||||
for output in outputs:
|
||||
consumer_ops = module.get_consumers_for_value(output.name)
|
||||
consumer_stages = utils.get_stages_from_op_names(
|
||||
consumer_ops = function.get_consumers(output)
|
||||
consumer_stages = utils.get_stages_from_ops(
|
||||
op_to_stage, consumer_ops
|
||||
)
|
||||
for consumer_stage in consumer_stages:
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import copy
|
||||
from collections import defaultdict
|
||||
|
||||
from ..ir.module import Module
|
||||
from ..ir.function import FunctionMaker
|
||||
from ..ir.value import Value
|
||||
from . import utils
|
||||
|
||||
|
||||
class PipelineParallelTransform:
|
||||
"""Partitions a module using pipeline parallelism.
|
||||
"""Partitions a function using pipeline parallelism.
|
||||
|
||||
Attributes:
|
||||
num_microbatches: The number of microbatches per pipeline iteration.
|
||||
batch_dims: A map from input value name to partition dimension.
|
||||
reduction_ops: A map from output value name to a map of reduction op params.
|
||||
partition_map: A map from op name to device.
|
||||
schedule: A list of maps from device to a tuple of (op_name, microbatch).
|
||||
batch_dims: A map from input value to partition dimension.
|
||||
reduction_params: A map from output value to a map of reduction op params.
partition_map: A map from pipeline stage (a Function) to device.
schedule: A list of maps from device to a tuple of (stage, microbatch).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -29,9 +28,9 @@ class PipelineParallelTransform:
|
|||
list(self._partition_map.keys())
|
||||
)
|
||||
|
||||
def _forward_value(self, transformed_module, value, device):
|
||||
def _forward_value(self, transformed_function, value, device):
|
||||
"""Forwards the specified value to the specified device by adding a Send op."""
|
||||
forwarded_value = transformed_module.add_op(
|
||||
forwarded_value = transformed_function.add_op(
|
||||
"Send",
|
||||
name=f"Send/{value.name}@{device}",
|
||||
inputs=[value],
|
||||
|
@ -40,27 +39,24 @@ class PipelineParallelTransform:
|
|||
)
|
||||
return forwarded_value
|
||||
|
||||
def _partition_inputs(self, module, transformed_module, pipelined_value_map):
|
||||
def _partition_inputs(self, function, transformed_function, pipelined_value_map):
|
||||
"""Splits the input values according to the number of specified microbatches."""
|
||||
input_values = module.get_inputs()
|
||||
for input_value in input_values:
|
||||
v = transformed_module.add_input_value(
|
||||
input_value.name, copy.deepcopy(input_value.type)
|
||||
)
|
||||
pipelined_input_map = pipelined_value_map[input_value.name]
|
||||
if input_value.name in self._batch_dims:
|
||||
vs = transformed_module.add_op(
|
||||
for input_value in function.inputs:
|
||||
v = transformed_function.add_input_value(input_value.name, input_value.type)
|
||||
pipelined_input_map = pipelined_value_map[input_value]
|
||||
if input_value in self._batch_dims:
|
||||
vs = transformed_function.add_op(
|
||||
"Split",
|
||||
name=f"Split/{v.name}",
|
||||
inputs=[v],
|
||||
attributes={
|
||||
"num_splits": self._num_microbatches,
|
||||
"dim": self._batch_dims[input_value.name],
|
||||
"dim": self._batch_dims[input_value],
|
||||
},
|
||||
output_names=[f"{v.name}s"],
|
||||
)
|
||||
for i in range(self._num_microbatches):
|
||||
v_i = transformed_module.add_op(
|
||||
v_i = transformed_function.add_op(
|
||||
"Select",
|
||||
name=f"Select/{v.name}_{i}",
|
||||
attributes={"dim": i},
|
||||
|
@ -75,8 +71,8 @@ class PipelineParallelTransform:
|
|||
# Forward the input value(s) if the destination device(s) are not
|
||||
# the same as the source device.
|
||||
input_device = input_value.type.device
|
||||
consumer_ops = module.get_consumers_for_value(input_value.name)
|
||||
consumer_stages = utils.get_stages_from_op_names(
|
||||
consumer_ops = function.get_consumers(input_value)
|
||||
consumer_stages = utils.get_stages_from_ops(
|
||||
self._op_to_stage_map, consumer_ops
|
||||
)
|
||||
consumer_devices = set([self._partition_map[c] for c in consumer_stages])
|
||||
|
@ -92,9 +88,9 @@ class PipelineParallelTransform:
|
|||
# TODO: Propagate these values alongside activations instead of sending them
|
||||
# ahead of time to be consistent with ORT?
|
||||
for i in range(self._num_microbatches):
|
||||
if input_value.name in self._batch_dims or i == 0:
|
||||
if input_value in self._batch_dims or i == 0:
|
||||
forwarded_input = self._forward_value(
|
||||
transformed_module,
|
||||
transformed_function,
|
||||
pipelined_input_map[i],
|
||||
consumer_device,
|
||||
)
|
||||
|
@ -106,7 +102,7 @@ class PipelineParallelTransform:
|
|||
|
||||
def _aggregate_outputs(
|
||||
self,
|
||||
transformed_module,
|
||||
transformed_function,
|
||||
orig_output,
|
||||
pipelined_output,
|
||||
merged_output_map,
|
||||
|
@ -115,27 +111,21 @@ class PipelineParallelTransform:
|
|||
"""Aggregates the specified output according to the user-provided reduction parameters.
|
||||
|
||||
Args:
|
||||
transformed_module: The transformed module.
|
||||
transformed_function: The transformed function.
|
||||
orig_output: The original version of the output value.
|
||||
pipelined_output: The transformed (i.e. partitioned) version of the output value.
|
||||
merged_output_map: A map from original output value name to aggregated output value.
|
||||
merged_output_map: A map from original output value to aggregated output value.
|
||||
num_completed_microbatches: The number of microbatches completed so far.
|
||||
"""
|
||||
if self._reduction_params[orig_output.name] is None:
|
||||
if self._reduction_params[orig_output] is None:
|
||||
# This output does not need to be aggregated.
|
||||
return
|
||||
|
||||
reduction_op_type = self._reduction_params[orig_output.name]["op_type"]
|
||||
reduction_op_type = self._reduction_params[orig_output]["op_type"]
|
||||
if num_completed_microbatches == 1:
|
||||
merged_output_map[orig_output.name] = pipelined_output
|
||||
merged_output_map[orig_output] = pipelined_output
|
||||
else:
|
||||
merged_output = merged_output_map[orig_output.name]
|
||||
|
||||
# Forward the output value if necessary.
|
||||
if merged_output.type.device != pipelined_output.type.device:
|
||||
pipelined_output = self._forward_value(
|
||||
transformed_module, pipelined_output, merged_output.type.device
|
||||
)
|
||||
merged_output = merged_output_map[orig_output]
|
||||
|
||||
# Prepare the reduction op name and output value name.
|
||||
op_name = (
|
||||
|
@ -146,17 +136,17 @@ class PipelineParallelTransform:
|
|||
else:
|
||||
output_name = f"{orig_output.name}/merged_{num_completed_microbatches}"
|
||||
|
||||
# Add the requested reduction op to the transformed module.
|
||||
# Add the requested reduction op to the transformed function.
|
||||
if reduction_op_type == "Add":
|
||||
merged_output_map[orig_output.name] = transformed_module.add_op(
|
||||
merged_output_map[orig_output] = transformed_function.add_op(
|
||||
"Add",
|
||||
name=op_name,
|
||||
inputs=[merged_output, pipelined_output],
|
||||
output_names=[output_name],
|
||||
)
|
||||
elif reduction_op_type == "Concat":
|
||||
dim = self._reduction_params[orig_output.name]["dim"]
|
||||
merged_output_map[orig_output.name] = transformed_module.add_op(
|
||||
dim = self._reduction_params[orig_output]["dim"]
|
||||
merged_output_map[orig_output] = transformed_function.add_op(
|
||||
"Concat",
|
||||
attributes={"dim": dim},
|
||||
name=op_name,
|
||||
|
@ -169,49 +159,47 @@ class PipelineParallelTransform:
|
|||
f"for output {orig_output}"
|
||||
)
|
||||
|
||||
def apply(self, module):
|
||||
"""Applies the transformation to the module and returns a transformed module."""
|
||||
def apply(self, function):
|
||||
"""Applies the transformation to the function and returns a transformed function."""
|
||||
|
||||
transformed_module = Module()
|
||||
transformed_function = FunctionMaker()
|
||||
|
||||
# A map from original value name to another map from microbatch number to
|
||||
# A map from original value to another map from microbatch number to
|
||||
# pipelined value.
|
||||
pipelined_value_map = defaultdict(lambda: defaultdict(Value))
|
||||
pipelined_value_map = defaultdict(lambda: {})
|
||||
|
||||
# A map from original output value name to merged output value.
|
||||
# A map from original output value to merged output value.
|
||||
merged_output_map = defaultdict(Value)
|
||||
|
||||
# Partition the input values for each microbatch.
|
||||
self._partition_inputs(module, transformed_module, pipelined_value_map)
|
||||
# Partition the input values according to the number of microbatches.
|
||||
self._partition_inputs(function, transformed_function, pipelined_value_map)
|
||||
|
||||
# Schedule stages on each device in order of increasing timestep.
|
||||
for timestep in range(len(self._schedule)):
|
||||
for device in self._schedule[timestep]:
|
||||
# Look up the next stage to execute according to the schedule
|
||||
# and add each op in the stage to the transformed module.
|
||||
# and add each op in the stage to the transformed function.
|
||||
(stage, microbatch) = self._schedule[timestep][device]
|
||||
stage_outputs = set([v.name for v in stage.get_outputs()])
|
||||
for op_name, orig_op in stage.get_ops().items():
|
||||
orig_inputs = orig_op.get_in_edges()
|
||||
orig_outputs = orig_op.get_out_edges()
|
||||
for orig_op in stage.ops:
|
||||
orig_inputs = orig_op.inputs
|
||||
orig_outputs = orig_op.outputs
|
||||
|
||||
# Collect the pipelined input values for this op.
|
||||
pipelined_inputs = []
|
||||
for orig_input in orig_inputs:
|
||||
pipelined_input_map = pipelined_value_map[orig_input.name]
|
||||
pipelined_input = pipelined_input_map[microbatch]
|
||||
pipelined_input_map = pipelined_value_map[orig_input]
|
||||
pipelined_inputs.append(pipelined_input_map[microbatch])
|
||||
|
||||
# Add the pipelined version of the op for the given microbatch to
|
||||
# the transformed module.
|
||||
# the transformed function.
|
||||
pipelined_output_names = [
|
||||
f"{orig_output.name}_{microbatch}"
|
||||
for orig_output in orig_outputs
|
||||
]
|
||||
pipelined_outputs = transformed_module.add_op(
|
||||
pipelined_outputs = transformed_function.add_op(
|
||||
orig_op.op_type,
|
||||
name=f"{orig_op.name}_{microbatch}",
|
||||
attributes=orig_op._attributes,
|
||||
attributes=orig_op.attributes,
|
||||
inputs=pipelined_inputs,
|
||||
output_names=pipelined_output_names,
|
||||
)
|
||||
|
@ -223,19 +211,19 @@ class PipelineParallelTransform:
|
|||
for (orig_output, pipelined_output) in zip(
|
||||
orig_outputs, pipelined_outputs
|
||||
):
|
||||
pipelined_output_map = pipelined_value_map[orig_output.name]
|
||||
pipelined_output_map = pipelined_value_map[orig_output]
|
||||
pipelined_output_map[microbatch] = pipelined_output
|
||||
|
||||
if orig_output.name not in stage_outputs:
|
||||
if orig_output not in stage.outputs:
|
||||
# This output is an intermediate output *within* a stage which does not
|
||||
# require any additional processing.
|
||||
continue
|
||||
elif module.is_output(orig_output.name):
|
||||
# This output is a module output, which means we need to aggregate it
|
||||
elif orig_output in function.outputs:
|
||||
# This output is a function output, which means we need to aggregate it
|
||||
# with all other corresponding partitioned outputs for each microbatch.
|
||||
num_completed_microbatches = len(pipelined_output_map)
|
||||
self._aggregate_outputs(
|
||||
transformed_module,
|
||||
transformed_function,
|
||||
orig_output,
|
||||
pipelined_output,
|
||||
merged_output_map,
|
||||
|
@ -245,23 +233,21 @@ class PipelineParallelTransform:
|
|||
# This output is an intermediate stage output, which means we need to
|
||||
# forward the output to the next stage if the next stage is located on
|
||||
# a different device.
|
||||
consumer_ops = module.get_consumers_for_value(
|
||||
orig_output.name
|
||||
)
|
||||
consumer_stages = utils.get_stages_from_op_names(
|
||||
consumer_ops = function.get_consumers(orig_output)
|
||||
consumer_stages = utils.get_stages_from_ops(
|
||||
self._op_to_stage_map, consumer_ops
|
||||
)
|
||||
consumer_devices = set(
|
||||
[self._partition_map[c] for c in consumer_stages]
|
||||
)
|
||||
consumer_devices = {
|
||||
self._partition_map[c] for c in consumer_stages
|
||||
}
|
||||
for consumer_device in consumer_devices:
|
||||
if device != consumer_device:
|
||||
pipelined_output_map[
|
||||
microbatch
|
||||
] = self._forward_value(
|
||||
transformed_module,
|
||||
transformed_function,
|
||||
pipelined_output,
|
||||
consumer_device,
|
||||
)
|
||||
|
||||
return transformed_module
|
||||
return transformed_function.finalize()
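For completeness, a hedged usage sketch matching the attributes listed in the class docstring (all variables are hypothetical; the keyword names follow the notebook snippet later in this diff):

```python
transform = PipelineParallelTransform(
    num_microbatches=4,
    batch_dims={x: 0, z: 0},
    reduction_params={dw: {"op_type": "Add"} for dw in fn.outputs},
    partition_map=partition_map,  # maps each pipeline stage to a Device
    schedule=schedule,            # per-timestep {Device: (stage, microbatch)}
)
pp_fn = transform.apply(fn)
```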
|
||||
|
|
|
@ -1,27 +1,27 @@
|
|||
from typing import Dict, Iterable, List
|
||||
|
||||
from ..ir import Module
|
||||
from ..ir import Function, Op
|
||||
|
||||
|
||||
def get_op_to_stage_map(stages: Iterable[Module]) -> Dict[str, Module]:
|
||||
"""Given a list of stages, returns a map from individual op name to
|
||||
def get_op_to_stage_map(stages: Iterable[Function]) -> Dict[Op, Function]:
|
||||
"""Given a list of stages, returns a map from individual op to
|
||||
encompassing stage."""
|
||||
op_to_stage = {}
|
||||
for stage in stages:
|
||||
for op_name in stage.get_ops():
|
||||
op_to_stage[op_name] = stage
|
||||
for op in stage.ops:
|
||||
op_to_stage[op] = stage
|
||||
return op_to_stage
|
||||
|
||||
|
||||
def get_stages_from_op_names(
|
||||
op_to_stage: Dict[str, Module], op_names: Iterable[str]
|
||||
) -> List[Module]:
|
||||
"""Given a list of op names and a map from op name to encompassing stage,
|
||||
def get_stages_from_ops(
|
||||
op_to_stage: Dict[Op, Function], ops: Iterable[Op]
|
||||
) -> List[Function]:
|
||||
"""Given a list of ops and a map from op to encompassing stage,
|
||||
returns a list of encompassing stages."""
|
||||
seen = set()
|
||||
stages = []
|
||||
for op_name in op_names:
|
||||
stage = op_to_stage[op_name]
|
||||
for op in ops:
|
||||
stage = op_to_stage[op]
|
||||
if stage not in seen:
seen.add(stage)
stages.append(stage)
|
||||
return stages
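A hedged sketch of how these two helpers compose (`stages`, `consumer_ops`, and `partition_map` are hypothetical):

```python
# Build the op -> stage map once per partition, then resolve a value's
# consumer ops to the stages (and devices) that contain them.
op_to_stage = get_op_to_stage_map(stages)
consumer_stages = get_stages_from_ops(op_to_stage, consumer_ops)
consumer_devices = {partition_map[s] for s in consumer_stages}
```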
|
||||
|
|
docs/ir.md
|
@ -1,7 +1,7 @@
|
|||
# The DistIR Internal Representation
|
||||
|
||||
```
|
||||
Module = {
|
||||
Function = {
|
||||
# Invariant: ops are in topological order
|
||||
ops: List[Op]
|
||||
inputs: List[Value]
|
||||
|
@ -17,16 +17,16 @@ Device = {
|
|||
|
||||
Op = {
|
||||
name: String
|
||||
# Do we need this to be unique in a module? Across all modules?
|
||||
# Do we need this to be unique in a function? Across all functions?
|
||||
op_type: OpType
|
||||
# The type of operator
|
||||
in_edges: List[Value]
|
||||
# Pointer to a Value object either in Module.inputs or another Op.out_edges
|
||||
# Pointer to a Value object either in Function.inputs or another Op.out_edges
|
||||
out_edges: List[Value]
|
||||
# To support ops that have more than one output
|
||||
attributes: Dict[String, Any]
|
||||
# Constant data for ops, e.g. stride of convolution or devices to scatter
|
||||
submodules: List[Module]
|
||||
subfunctions: List[Function]
|
||||
}
|
||||
|
||||
OpType =
|
||||
|
@ -37,7 +37,7 @@ OpType =
|
|||
|
||||
Value = {
|
||||
name: String
|
||||
# Again, does it need to be unique in a module?
|
||||
# Again, does it need to be unique in a function?
|
||||
type: Type
|
||||
device: DeviceID
|
||||
# Which device this value lives on
|
||||
|
@ -366,7 +366,7 @@ def mlp(
|
|||
```
|
||||
|
||||
TODO do we want explicit free?
|
||||
If so, then need to add it to input module after say importing from onnx,
|
||||
If so, then need to add it to input function after say importing from onnx,
|
||||
and need a free after last occurrence of every value.
|
||||
If not, then simulator has to do a liveness analysis. This isn't hard/expensive.
|
||||
But how do we deal with inplace operations? Have an attribute on those ops?
|
||||
|
@ -398,16 +398,16 @@ two layer FF or attention layers?
|
|||
|
||||
- Can we do without a special BroadcastScatterOpRegisterEntry?
|
||||
- Why do we even need types at op creation time?
|
||||
- Think about best way to create pmap device variable: before/after submodule creation?
|
||||
- Think about best way to create pmap device variable: before/after subfunction creation?
|
||||
|
||||
- Create `HardwareConfiguration` that has all device speeds as well as topology and bandwidth information
|
||||
|
||||
- Have a validation pass that checks module is valid (e.g. that inputs are on the same device for matmul)
|
||||
- Add test that `DistributedSimulator` doesn't modify module
|
||||
- Have a validation pass that checks function is valid (e.g. that inputs are on the same device for matmul)
|
||||
- Add test that `DistributedSimulator` doesn't modify function
|
||||
|
||||
- Maybe one problem with the return value of a pmap or scatter is that we are expecting things to have types but not shapes at module creation time, and then run a shape inference pass later to fill in the shapes. Would it be cleaner to just infer shapes and fill in a "full" type at op-creation time? Will we ever need to run shape inference later?
|
||||
- Maybe one problem with the return value of a pmap or scatter is that we are expecting things to have types but not shapes at function creation time, and then run a shape inference pass later to fill in the shapes. Would it be cleaner to just infer shapes and fill in a "full" type at op-creation time? Will we ever need to run shape inference later?
|
||||
|
||||
|
||||
## Problems
|
||||
|
||||
- Should a tensor never have a device? What if a pmap's submodule expects a tensor? Then how do we enforce that it is on device d? Is it better to go back to all values have an `Option[Device]`?
|
||||
- Should a tensor never have a device? What if a pmap's subfunction expects a tensor? Then how do we enforce that it is on device d? Is it better to go back to all values have an `Option[Device]`?
|
|
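For readers of the IR doc above, a minimal sketch of building a Function with the refactored API, pieced together from the test code later in this commit (the variable naming is illustrative; treat this as an orientation aid, not canonical usage):

```python
# Minimal sketch, assuming the FunctionMaker API exercised by the tests below.
from dist_ir.ir import Device, FunctionMaker
from dist_ir.ir.type import Float, Tensor

d0 = Device(0, "gpu")
maker = FunctionMaker()
a = maker.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d0))
b = maker.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d0))
x = maker.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
function = maker.finalize()  # returns an immutable Function
```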
@@ -26,7 +26,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from dist_ir.ir import Module\n",
+    "from dist_ir.ir import Function\n",
     "from dist_ir.ir import Topology\n",
     "from dist_ir.ir.type import Float\n",
     "from dist_ir.ir.type import Tensor\n",
@@ -70,7 +70,7 @@
    "outputs": [],
    "source": [
     "def construct_mlp_module(topology, batch_size, num_classes, input_dim, hidden_dims):\n",
-    "    module = Module()\n",
+    "    module = FunctionMaker()\n",
     "\n",
     "    device = topology.devices[0]\n",
     "    x = module.add_input_value(\n",
@@ -180,7 +180,7 @@
     "    transform = DataParallelTransform(\n",
     "        batch_dims={\"x\": 0, \"z\": 0},\n",
     "        reduction_params={\n",
-    "            f\"{dw.name}\": {\"op_type\": \"Allreduce\"} for dw in module.get_outputs()\n",
+    "            f\"{dw.name}\": {\"op_type\": \"Allreduce\"} for dw in module.outputs\n",
     "        },\n",
     "        devices=topology.devices,\n",
     "    )\n",
@@ -365,7 +365,7 @@
     "    num_layers = len(hidden_dims) + 1\n",
     "    layers_per_device = num_layers // num_devices\n",
     "    partition_map = OrderedDict()\n",
-    "    op_names = list(module.get_ops().keys())\n",
+    "    op_names = list(module.ops.keys())\n",
     "    # Add forward pass stages\n",
     "    for i in range(num_devices):\n",
     "        idxs = [i*layers_per_device, (i+1)*layers_per_device] \n",
@@ -388,7 +388,7 @@
     "        num_microbatches=num_microbatches,\n",
     "        batch_dims={\"x\": 0, \"z\": 0},\n",
     "        reduction_params={\n",
-    "            f\"{dw.name}\": {\"op_type\": \"Add\"} for dw in module.get_outputs()\n",
+    "            f\"{dw.name}\": {\"op_type\": \"Add\"} for dw in module.outputs\n",
     "        },\n",
     "        partition_map=partition_map,\n",
     "        schedule=schedule,\n",
@@ -1,3 +1,4 @@
+frozendict >= 1.2
 numpy >= 1.19
 onnx >= 1.7.0
 torch >= 1.6.0
@@ -1,53 +1,52 @@
 from collections import OrderedDict

-from dist_ir.ir import Device, Module
+from dist_ir.ir import Device, FunctionMaker
 from dist_ir.ir.type import Float, Tensor


-def construct_module_and_partition_map():
-    module = Module()
+def construct_function_and_partition_map():
+    function = FunctionMaker()

     d0 = Device(0, "gpu")
     d1 = Device(1, "gpu")
     batch_size = 16
-    x = module.add_input_value(
+    x = function.add_input_value(
         "x", Tensor(dtype=Float(), shape=(batch_size, 4), device=d0)
     )
-    z = module.add_input_value(
+    z = function.add_input_value(
         "z", Tensor(dtype=Float(), shape=(batch_size, 1), device=d0)
     )
-    wA = module.add_input_value("wA", Tensor(dtype=Float(), shape=(4, 2), device=d0))
-    wB = module.add_input_value("wB", Tensor(dtype=Float(), shape=(2, 1), device=d0))
-    a = module.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
-    y = module.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
-    l = module.add_op(
+    wA = function.add_input_value("wA", Tensor(dtype=Float(), shape=(4, 2), device=d0))
+    wB = function.add_input_value("wB", Tensor(dtype=Float(), shape=(2, 1), device=d0))
+    a = function.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
+    y = function.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
+    l = function.add_op(
         "Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
     )
-    dl = module.add_op(
+    dl = function.add_op(
         "LossGrad",
         "LossGrad",
         inputs=[y, z],
         attributes={"N": batch_size},
         output_names=["dl"],
     )
-    da, dwB = module.add_op(
+    da, dwB = function.add_op(
         "MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
     )
-    _, dwA = module.add_op(
+    _, dwA = function.add_op(
         "MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
     )
-    module.set_outputs([l, dwA, dwB])
-    module.finalize()
+    function = function.finalize()

     stages = [
-        module.get_submodule(("MatMul0",), name="f0"),
-        module.get_submodule(("MatMul1", "Loss"), name="f1"),
-        module.get_submodule(("LossGrad", "MatMul1Grad"), name="b1"),
-        module.get_submodule(("MatMul0Grad",), name="b0"),
+        function.get_subfunction(("MatMul0",), name="f0"),
+        function.get_subfunction(("MatMul1", "Loss"), name="f1"),
+        function.get_subfunction(("LossGrad", "MatMul1Grad"), name="b1"),
+        function.get_subfunction(("MatMul0Grad",), name="b0"),
     ]

     partition_map = OrderedDict(
         [(stages[0], d0), (stages[1], d1), (stages[2], d1), (stages[3], d0)]
     )

-    return (module, partition_map)
+    return (function, partition_map)
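A brief, hedged sketch of how this helper is consumed by the scheduler tests further down in this diff (the `FIFOScheduler` import path is not visible here and is assumed):

```python
# Mirrors test_pipeline_parallel_scheduler.py below; FIFOScheduler is assumed
# to be importable and in scope, since its import line is not shown in this diff.
import pipeline_parallel_utils as utils

(function, partition_map) = utils.construct_function_and_partition_map()
scheduler = FIFOScheduler(num_microbatches=2)
schedule = scheduler.schedule(function, partition_map)
# Each schedule entry maps a device to the (stage, microbatch) it runs at that step.
```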
@ -1,153 +1,52 @@
|
|||
import numpy as np
|
||||
|
||||
from dist_ir.ir import Device, Module
|
||||
from dist_ir.ir import Device, FunctionMaker
|
||||
from dist_ir.ir.type import Float, Tensor
|
||||
from dist_ir.transforms import DataParallelTransform
|
||||
from dist_ir.executor import SequentialExecutor
|
||||
|
||||
# TODO test on actual inputs using sequential executor
|
||||
|
||||
|
||||
def test_single_variable_partition():
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
module.finalize()
|
||||
a = function.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
b = function.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
function = function.finalize()
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={"a": 0},
|
||||
reduction_params={"x": {"op_type": "Gather", "dim": 0, "device": d0}},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module")
|
||||
print("-" * 88)
|
||||
print(module)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module")
|
||||
print("-" * 88)
|
||||
print(transformed_module)
|
||||
|
||||
assert transformed_module.is_op("Scatter/a")
|
||||
assert transformed_module.is_op("Broadcast/b")
|
||||
assert transformed_module.is_op("Pmap_#0")
|
||||
assert transformed_module.is_op("Gather/x")
|
||||
|
||||
|
||||
def test_double_variable_partition():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
c = module.add_input_value("c", Tensor(Float(), (4, 4)))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
|
||||
module.finalize()
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={"a": 0, "c": 0},
|
||||
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module")
|
||||
print("-" * 88)
|
||||
print(module)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module")
|
||||
print("-" * 88)
|
||||
print(transformed_module)
|
||||
|
||||
assert transformed_module.is_op("Scatter/a")
|
||||
assert transformed_module.is_op("Broadcast/b")
|
||||
assert transformed_module.is_op("Scatter/c")
|
||||
assert transformed_module.is_op("Pmap_#0")
|
||||
assert transformed_module.is_op("Gather/y")
|
||||
|
||||
|
||||
def test_mnist():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
batch_size = 16
|
||||
x = module.add_input_value("x", Tensor(Float(), (batch_size, 4)))
|
||||
z = module.add_input_value("z", Tensor(Float(), (batch_size, 1)))
|
||||
wA = module.add_input_value("wA", Tensor(Float(), (4, 2)))
|
||||
wB = module.add_input_value("wB", Tensor(Float(), (2, 1)))
|
||||
a = module.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
|
||||
y = module.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
|
||||
l = module.add_op(
|
||||
"Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
|
||||
)
|
||||
dl = module.add_op(
|
||||
"LossGrad",
|
||||
"LossGrad",
|
||||
inputs=[y, z],
|
||||
attributes={"N": batch_size},
|
||||
output_names=["dl"],
|
||||
)
|
||||
da, dwB = module.add_op(
|
||||
"MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
|
||||
)
|
||||
dx, dwA = module.add_op(
|
||||
"MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
|
||||
)
|
||||
module.set_outputs([l, dwA, dwB])
|
||||
module.finalize()
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={"x": 0, "z": 0},
|
||||
batch_dims={function.inputs[0]: 0},
|
||||
reduction_params={
|
||||
"l": {"op_type": "Gather", "dim": 0, "device": d0},
|
||||
"dx": {"op_type": "Gather", "dim": 0, "device": d0},
|
||||
"dwA": {"op_type": "Allreduce"},
|
||||
"dwB": {"op_type": "Allreduce"},
|
||||
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
|
||||
},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
transformed_module.finalize()
|
||||
transformed_function = transform.apply(function)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module")
|
||||
print("Original function")
|
||||
print("-" * 88)
|
||||
print(module)
|
||||
print(function)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module")
|
||||
print("Transformed function")
|
||||
print("-" * 88)
|
||||
print(transformed_module)
|
||||
print(transformed_function)
|
||||
|
||||
ex = SequentialExecutor("numpy")
|
||||
_x = np.arange(batch_size * 4).reshape((batch_size, 4))
|
||||
_z = np.ones((batch_size, 1))
|
||||
_wA = np.ones((4, 2))
|
||||
_wB = np.ones((2, 1))
|
||||
orig_res = ex.compute(
|
||||
module,
|
||||
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
|
||||
)
|
||||
_a = np.ones((4, 4))
|
||||
_b = np.ones((4, 4))
|
||||
orig_res = ex.compute(function, {function.inputs[0]: _a, function.inputs[1]: _b})
|
||||
|
||||
transformed_res = ex.compute(
|
||||
transformed_module,
|
||||
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
|
||||
transformed_function,
|
||||
{transformed_function.inputs[0]: _a, transformed_function.inputs[1]: _b},
|
||||
)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module results")
|
||||
print("Original function results")
|
||||
print("-" * 88)
|
||||
for k, v in orig_res.items():
|
||||
print(k)
|
||||
|
@ -155,29 +54,193 @@ def test_mnist():
|
|||
print()
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module results")
|
||||
print("Transformed function results")
|
||||
print("-" * 88)
|
||||
for k, v in transformed_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
|
||||
assert np.array_equal(orig_res["l"], np.concatenate(transformed_res["ls"], axis=0))
|
||||
assert np.array_equal(orig_res["dwA"], transformed_res["dwAs"][0])
|
||||
assert np.array_equal(orig_res["dwB"], transformed_res["dwBs"][0])
|
||||
|
||||
"""
|
||||
assert transformed_module.is_op("Scatter/x")
|
||||
assert transformed_module.is_op("Scatter/z")
|
||||
assert transformed_module.is_op("Broadcast/wA")
|
||||
assert transformed_module.is_op("Broadcast/wB")
|
||||
assert transformed_module.is_op("Pmap_#0")
|
||||
assert transformed_module.is_op("Gather/l")
|
||||
assert transformed_module.is_op("Gather/dx")
|
||||
assert transformed_module.is_op("Allreduce/dwA")
|
||||
assert transformed_module.is_op("Allreduce/dwB")
|
||||
"""
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mnist()
|
||||
def test_double_variable_partition():
|
||||
function = FunctionMaker()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
a = function.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
b = function.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
c = function.add_input_value("c", Tensor(Float(), (4, 4)))
|
||||
x = function.add_op("MatMul", "MatMul", inputs=[a, b], output_names=["x"])
|
||||
y = function.add_op("Add", "Add", inputs=[x, c], output_names=["y"])
|
||||
function = function.finalize()
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={function.inputs[0]: 0, function.inputs[2]: 0},
|
||||
reduction_params={
|
||||
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
|
||||
},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_function = transform.apply(function)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original function")
|
||||
print("-" * 88)
|
||||
print(function)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed function")
|
||||
print("-" * 88)
|
||||
print(transformed_function)
|
||||
|
||||
ex = SequentialExecutor("numpy")
|
||||
_a = np.ones((4, 4))
|
||||
_b = np.ones((4, 4))
|
||||
_c = np.ones((4, 4))
|
||||
orig_res = ex.compute(
|
||||
function,
|
||||
{function.inputs[0]: _a, function.inputs[1]: _b, function.inputs[2]: _c},
|
||||
)
|
||||
|
||||
transformed_res = ex.compute(
|
||||
transformed_function,
|
||||
{
|
||||
transformed_function.inputs[0]: _a,
|
||||
transformed_function.inputs[1]: _b,
|
||||
transformed_function.inputs[2]: _c,
|
||||
},
|
||||
)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original function results")
|
||||
print("-" * 88)
|
||||
for k, v in orig_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed function results")
|
||||
print("-" * 88)
|
||||
for k, v in transformed_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
|
||||
)
|
||||
|
||||
|
||||
def test_mnist():
|
||||
function = FunctionMaker()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
batch_size = 16
|
||||
x = function.add_input_value("x", Tensor(Float(), (batch_size, 4)))
|
||||
z = function.add_input_value("z", Tensor(Float(), (batch_size, 1)))
|
||||
wA = function.add_input_value("wA", Tensor(Float(), (4, 2)))
|
||||
wB = function.add_input_value("wB", Tensor(Float(), (2, 1)))
|
||||
a = function.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["a"])
|
||||
y = function.add_op("MatMul", "MatMul1", inputs=[a, wB], output_names=["y"])
|
||||
l = function.add_op(
|
||||
"Loss", "Loss", inputs=[y, z], attributes={"N": batch_size}, output_names=["l"]
|
||||
)
|
||||
dl = function.add_op(
|
||||
"LossGrad",
|
||||
"LossGrad",
|
||||
inputs=[y, z],
|
||||
attributes={"N": batch_size},
|
||||
output_names=["dl"],
|
||||
)
|
||||
da, dwB = function.add_op(
|
||||
"MatMulGrad", "MatMul1Grad", inputs=[a, wB, dl], output_names=["da", "dwB"]
|
||||
)
|
||||
dx, dwA = function.add_op(
|
||||
"MatMulGrad", "MatMul0Grad", inputs=[x, wA, da], output_names=["dx", "dwA"]
|
||||
)
|
||||
function = function.finalize()
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={function.inputs[0]: 0, function.inputs[1]: 0},
|
||||
reduction_params={
|
||||
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0},
|
||||
function.outputs[1]: {"op_type": "Allreduce"},
|
||||
function.outputs[2]: {"op_type": "Gather", "dim": 0, "device": d0},
|
||||
function.outputs[3]: {"op_type": "Allreduce"},
|
||||
},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_function = transform.apply(function)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original function")
|
||||
print("-" * 88)
|
||||
print(function)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed function")
|
||||
print("-" * 88)
|
||||
print(transformed_function)
|
||||
|
||||
ex = SequentialExecutor("numpy")
|
||||
_x = np.arange(batch_size * 4).reshape((batch_size, 4))
|
||||
_z = np.ones((batch_size, 1))
|
||||
_wA = np.ones((4, 2))
|
||||
_wB = np.ones((2, 1))
|
||||
orig_res = ex.compute(
|
||||
function,
|
||||
{
|
||||
function.inputs[0]: _x,
|
||||
function.inputs[1]: _z,
|
||||
function.inputs[2]: _wA,
|
||||
function.inputs[3]: _wB,
|
||||
},
|
||||
)
|
||||
|
||||
transformed_res = ex.compute(
|
||||
transformed_function,
|
||||
{
|
||||
transformed_function.inputs[0]: _x,
|
||||
transformed_function.inputs[1]: _z,
|
||||
transformed_function.inputs[2]: _wA,
|
||||
transformed_function.inputs[3]: _wB,
|
||||
},
|
||||
)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original function results")
|
||||
print("-" * 88)
|
||||
for k, v in orig_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed function results")
|
||||
print("-" * 88)
|
||||
for k, v in transformed_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[0]], transformed_res[transformed_function.outputs[0]]
|
||||
)
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[1]],
|
||||
transformed_res[transformed_function.outputs[1]][0],
|
||||
)
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[2]],
|
||||
transformed_res[transformed_function.outputs[2]],
|
||||
)
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[3]],
|
||||
transformed_res[transformed_function.outputs[3]][0],
|
||||
)
|
||||
|
|
|
@ -1,58 +1,66 @@
|
|||
from dist_ir.ir import Module
|
||||
from dist_ir.ir import FunctionMaker
|
||||
from dist_ir.ir import Topology
|
||||
from dist_ir.ir.type import Float
|
||||
from dist_ir.ir.type import Tensor
|
||||
from dist_ir.executor.cost_inference import CostModel
|
||||
from dist_ir.executor.type_inference import infer_types
|
||||
from dist_ir.executor import DistributedSimulator
|
||||
from dist_ir.transforms import DataParallelTransform
|
||||
|
||||
|
||||
def test_single_device():
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
topology = Topology()
|
||||
|
||||
d = topology.add_device("gpu")
|
||||
|
||||
a = module.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
|
||||
b = module.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b])
|
||||
module.finalize()
|
||||
a = function.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
|
||||
b = function.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
|
||||
x = function.add_op("MatMul", "MatMul0", inputs=[a, b])
|
||||
function = function.finalize()
|
||||
function = infer_types(function, [a, b])
|
||||
device_speeds = {"gpu": 1.0e13}
|
||||
# TODO shouldn't device_speeds be set in the topology?
|
||||
cost_model = CostModel(topology, device_speeds)
|
||||
simulator = DistributedSimulator(cost_model)
|
||||
simulator_state = simulator.simulate(module)
|
||||
simulator_state = simulator.simulate(function)
|
||||
assert d in simulator_state.timestamps
|
||||
assert d in simulator_state.peak_memory
|
||||
# TODO: Check specific values
|
||||
|
||||
|
||||
def test_data_parallel():
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
topology = Topology()
|
||||
|
||||
d0 = topology.add_device("gpu")
|
||||
d1 = topology.add_device("gpu")
|
||||
topology.set_bandwidth(d0, d1, 2)
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
|
||||
c = module.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
|
||||
module.finalize()
|
||||
a = function.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
|
||||
b = function.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
|
||||
c = function.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
|
||||
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
y = function.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
|
||||
function = function.finalize()
|
||||
function = infer_types(function, [a, b, c])
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={"a": 0},
|
||||
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
|
||||
batch_dims={function.inputs[0]: 0},
|
||||
reduction_params={
|
||||
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
|
||||
},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
transformed_function = transform.apply(function)
|
||||
transformed_function = infer_types(
|
||||
transformed_function, transformed_function.inputs
|
||||
)
|
||||
|
||||
transformed_module.finalize()
|
||||
print(transformed_module)
|
||||
print(transformed_function)
|
||||
device_speeds = {"gpu": 1.0e13}
|
||||
cost_model = CostModel(topology, device_speeds)
|
||||
simulator = DistributedSimulator(cost_model)
|
||||
simulator_state = simulator.simulate(transformed_module)
|
||||
simulator_state = simulator.simulate(transformed_function)
|
||||
assert d0 in simulator_state.timestamps
|
||||
assert d1 in simulator_state.timestamps
|
||||
assert d0 in simulator_state.peak_memory
|
||||
|
@ -61,31 +69,36 @@ def test_data_parallel():
|
|||
|
||||
|
||||
def test_chrome_trace():
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
|
||||
topology = Topology()
|
||||
d0 = topology.add_device("gpu")
|
||||
d1 = topology.add_device("gpu")
|
||||
topology.set_bandwidth(d0, d1, 2)
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
|
||||
c = module.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
y = module.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
|
||||
module.finalize()
|
||||
a = function.add_input_value("a", Tensor(Float(), (4, 4), device=d0))
|
||||
b = function.add_input_value("b", Tensor(Float(), (4, 4), device=d0))
|
||||
c = function.add_input_value("c", Tensor(Float(), (4, 4), device=d0))
|
||||
x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
y = function.add_op("MatMul", "MatMul1", inputs=[x, c], output_names=["y"])
|
||||
function = function.finalize()
|
||||
function = infer_types(function, [a, b, c])
|
||||
|
||||
device_speeds = {"gpu": 1.0e13}
|
||||
cost_model = CostModel(topology, device_speeds)
|
||||
simulator = DistributedSimulator(cost_model)
|
||||
|
||||
transform = DataParallelTransform(
|
||||
batch_dims={"a": 0},
|
||||
reduction_params={"y": {"op_type": "Gather", "dim": 0, "device": d0}},
|
||||
batch_dims={function.inputs[0]: 0},
|
||||
reduction_params={
|
||||
function.outputs[0]: {"op_type": "Gather", "dim": 0, "device": d0}
|
||||
},
|
||||
devices=[d0, d1],
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
transformed_module.finalize()
|
||||
transformed_function = transform.apply(function)
|
||||
transformed_function = infer_types(
|
||||
transformed_function, transformed_function.inputs
|
||||
)
|
||||
|
||||
simulation = simulator.simulate(transformed_module)
|
||||
simulation = simulator.simulate(transformed_function)
|
||||
simulation.dump_chrome_trace("test/trace.json")
|
||||
|
|
|
@@ -37,15 +37,15 @@ def test_parser():
         return %y: !dist.tensor<8x6xf32, 0>
     }
     """
-    modules = mlir_parser.parse_mlir_str(mlir_str)
-    assert len(modules) == 1
-    module = modules[0]
-    cpprint(module)
+    functions = mlir_parser.parse_mlir_str(mlir_str)
+    assert len(functions) == 1
+    function = functions[0]
+    cpprint(function)

     ex = SequentialExecutor("numpy")
     _wA = np.ones((4, 6))
     _x = np.arange(8 * 4).reshape((8, 4))
-    res = ex.compute(module, {"%arg1": _x, "%arg0": _wA})
+    res = ex.compute(function, {function.inputs[0]: _wA, function.inputs[1]: _x})

     # TODO fix concat's implementation in numpy register for this:
     # assert np.array_equal(res["%var4"], np.matmul(_x, _wA))
@@ -3,10 +3,10 @@ import pipeline_parallel_utils as utils


 def test_fifo_scheduler():
-    (module, partition_map) = utils.construct_module_and_partition_map()
+    (function, partition_map) = utils.construct_function_and_partition_map()
     (d0, d1) = sorted(set(partition_map.values()))
     scheduler = FIFOScheduler(num_microbatches=2)
-    schedule = scheduler.schedule(module, partition_map)
+    schedule = scheduler.schedule(function, partition_map)

     stages = list(partition_map.keys())
     ref_schedule = [
@@ -22,10 +22,10 @@ def test_fifo_scheduler():


 def test_pipedream_scheduler():
-    (module, partition_map) = utils.construct_module_and_partition_map()
+    (function, partition_map) = utils.construct_function_and_partition_map()
     (d0, d1) = sorted(set(partition_map.values()))
     scheduler = PipeDreamScheduler(num_microbatches=2)
-    schedule = scheduler.schedule(module, partition_map)
+    schedule = scheduler.schedule(function, partition_map)

     stages = list(partition_map.keys())
     ref_schedule = [
@@ -38,3 +38,7 @@ def test_pipedream_scheduler():
     ]

     assert schedule == ref_schedule
+
+
+if __name__ == "__main__":
+    test_fifo_scheduler()
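For orientation, the shape of the reference schedules these tests compare against, reconstructed from the schedule literal in the pipeline-parallel transform test later in this diff (the second timestep here is illustrative, not taken from the source):

```python
# Each timestep is a dict from device to (stage subfunction, microbatch index).
ref_schedule = [
    {d0: (stages[0], 0)},
    {d0: (stages[0], 1), d1: (stages[1], 0)},  # illustrative second step
    # ... remaining timesteps elided
]
```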
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
|
||||
from dist_ir.ir import Device, Module
|
||||
from dist_ir.ir import Device, Function
|
||||
from dist_ir.ir.type import Float, Tensor
|
||||
from dist_ir.transforms import PipelineParallelTransform
|
||||
from dist_ir.executor import SequentialExecutor
|
||||
|
@ -8,9 +8,8 @@ import pipeline_parallel_utils as utils
|
|||
|
||||
|
||||
def test_mnist_fw_bw():
|
||||
(module, partition_map) = utils.construct_module_and_partition_map()
|
||||
(function, partition_map) = utils.construct_function_and_partition_map()
|
||||
(d0, d1) = sorted(set(partition_map.values()))
|
||||
|
||||
stages = list(partition_map.keys())
|
||||
schedule = [
|
||||
{d0: (stages[0], 0)},
|
||||
|
@ -22,27 +21,27 @@ def test_mnist_fw_bw():
|
|||
]
|
||||
transform = PipelineParallelTransform(
|
||||
num_microbatches=2,
|
||||
batch_dims={"x": 0, "z": 0},
|
||||
batch_dims={function.inputs[0]: 0, function.inputs[1]: 0},
|
||||
reduction_params={
|
||||
"dwB": {"op_type": "Add"},
|
||||
"dwA": {"op_type": "Add"},
|
||||
"l": {"op_type": "Concat", "dim": 0},
|
||||
function.outputs[0]: {"op_type": "Concat", "dim": 0}, # l
|
||||
function.outputs[1]: {"op_type": "Add"}, # dwB
|
||||
function.outputs[2]: {"op_type": "Concat", "dim": 0}, # dx
|
||||
function.outputs[3]: {"op_type": "Add"}, # dwA
|
||||
},
|
||||
partition_map=partition_map,
|
||||
schedule=schedule,
|
||||
)
|
||||
transformed_module = transform.apply(module)
|
||||
transformed_module.finalize()
|
||||
transformed_function = transform.apply(function)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module")
|
||||
print("Original function")
|
||||
print("-" * 88)
|
||||
print(module)
|
||||
print(function)
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module")
|
||||
print("Transformed function")
|
||||
print("-" * 88)
|
||||
print(transformed_module)
|
||||
print(transformed_function)
|
||||
|
||||
batch_size = 16
|
||||
ex = SequentialExecutor("numpy")
|
||||
|
@ -51,17 +50,27 @@ def test_mnist_fw_bw():
|
|||
_wA = np.ones((4, 2))
|
||||
_wB = np.ones((2, 1))
|
||||
orig_res = ex.compute(
|
||||
module,
|
||||
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
|
||||
function,
|
||||
{
|
||||
function.inputs[0]: _x,
|
||||
function.inputs[1]: _z,
|
||||
function.inputs[2]: _wA,
|
||||
function.inputs[3]: _wB,
|
||||
},
|
||||
)
|
||||
|
||||
transformed_res = ex.compute(
|
||||
transformed_module,
|
||||
{"x": _x, "z": _z, "wA": _wA, "wB": _wB},
|
||||
transformed_function,
|
||||
{
|
||||
transformed_function.inputs[0]: _x,
|
||||
transformed_function.inputs[1]: _z,
|
||||
transformed_function.inputs[2]: _wA,
|
||||
transformed_function.inputs[3]: _wB,
|
||||
},
|
||||
)
|
||||
|
||||
print("-" * 88)
|
||||
print("Original module results")
|
||||
print("Original function results")
|
||||
print("-" * 88)
|
||||
for k, v in orig_res.items():
|
||||
print(k)
|
||||
|
@ -69,16 +78,18 @@ def test_mnist_fw_bw():
|
|||
print()
|
||||
print()
|
||||
print("-" * 88)
|
||||
print("Transformed module results")
|
||||
print("Transformed function results")
|
||||
print("-" * 88)
|
||||
for k, v in transformed_res.items():
|
||||
print(k)
|
||||
print(v)
|
||||
print()
|
||||
|
||||
assert np.array_equal(orig_res["l"], transformed_res["l"])
|
||||
assert np.array_equal(orig_res["dwA"], transformed_res["dwA"])
|
||||
assert np.array_equal(orig_res["dwB"], transformed_res["dwB"])
|
||||
for i in range(len(function.outputs)):
|
||||
np.testing.assert_array_almost_equal(
|
||||
orig_res[function.outputs[i]],
|
||||
transformed_res[transformed_function.outputs[i]],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@@ -1,27 +1,27 @@
 from pathlib import Path

 from dist_ir.importer import import_from_onnx
-from dist_ir.ir import Module, Topology
+from dist_ir.ir import FunctionMaker, Topology
 from dist_ir.ir.type import Float, Tensor
 from dist_ir.ir import cpprint


 def test_cpprint():
-    module = Module()
+    function = FunctionMaker()
     topology = Topology()

     d = topology.add_device("gpu")

-    a = module.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
-    b = module.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
-    x = module.add_op("MatMul", "MatMul0", inputs=[a, b])
-    y = module.add_op("MatMul", "MatMul1", inputs=[x, b])
-    module.finalize()
+    a = function.add_input_value("a", Tensor(dtype=Float(), shape=(4, 4), device=d))
+    b = function.add_input_value("b", Tensor(dtype=Float(), shape=(4, 4), device=d))
+    x = function.add_op("MatMul", "MatMul0", inputs=[a, b])
+    y = function.add_op("MatMul", "MatMul1", inputs=[x, b])
+    function.finalize()

-    cpprint(module)
+    cpprint(function)


 def test_import_from_onnx():
     onnx_model_path = Path(__file__).parent / "mnist_gemm_bw_running.onnx"
-    module = import_from_onnx(onnx_model_path)
-    cpprint(module)
+    function = import_from_onnx(onnx_model_path)
+    cpprint(function)
@ -2,20 +2,19 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from dist_ir.ir import Device, Module
|
||||
from dist_ir.ir import Device, FunctionMaker, cpprint
|
||||
from dist_ir.ir.type import Float, Tensor, TupleType
|
||||
from dist_ir.executor import SequentialExecutor
|
||||
from dist_ir.executor.shape_inference import infer_shapes
|
||||
|
||||
|
||||
class Helper:
|
||||
def __init__(self, backend):
|
||||
self.backend = backend
|
||||
self.executor = SequentialExecutor(self.backend)
|
||||
self.module = Module()
|
||||
self.t1 = self.module.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
self.t2 = self.module.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
self.t3 = self.module.add_input_value("c", Tensor(Float(), (4, 4)))
|
||||
self.function = FunctionMaker()
|
||||
self.a = self.function.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
self.b = self.function.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
self.c = self.function.add_input_value("c", Tensor(Float(), (4, 4)))
|
||||
if self.backend == "numpy":
|
||||
a = np.random.normal(size=(4, 4))
|
||||
b = np.random.normal(size=(4, 4))
|
||||
|
@ -27,9 +26,9 @@ class Helper:
|
|||
else:
|
||||
raise ValueError(f"Unknown backend {self.backend}")
|
||||
self.input_data = {
|
||||
"a": a,
|
||||
"b": b,
|
||||
"c": c,
|
||||
self.a: a,
|
||||
self.b: b,
|
||||
self.c: c,
|
||||
}
|
||||
print(f"Backend: {self.backend}")
|
||||
|
||||
|
@ -41,118 +40,122 @@ def backend(request):
|
|||
|
||||
def test_single_add(backend):
|
||||
h = Helper(backend)
|
||||
h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["Add_0/0"]
|
||||
res = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(result, np.add(h.input_data["a"], h.input_data["b"]))
|
||||
assert np.array_equal(result, np.add(h.input_data[h.a], h.input_data[h.b]))
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(torch.add(h.input_data["a"], h.input_data["b"]))
|
||||
assert result.equal(torch.add(h.input_data[h.a], h.input_data[h.b]))
|
||||
|
||||
|
||||
def test_double_add(backend):
|
||||
h = Helper(backend)
|
||||
x = h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
|
||||
h.module.add_op("Add", "Add_1", inputs=[h.t3, x])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["Add_1/0"]
|
||||
x = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
|
||||
res = h.function.add_op("Add", "Add_1", inputs=[h.c, x])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(
|
||||
result,
|
||||
np.add(h.input_data["c"], np.add(h.input_data["a"], h.input_data["b"])),
|
||||
np.add(h.input_data[h.c], np.add(h.input_data[h.a], h.input_data[h.b])),
|
||||
)
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(
|
||||
torch.add(
|
||||
h.input_data["c"],
|
||||
torch.add(h.input_data["a"], h.input_data["b"]),
|
||||
h.input_data[h.c],
|
||||
torch.add(h.input_data[h.a], h.input_data[h.b]),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_double_add_inverted(backend):
|
||||
h = Helper(backend)
|
||||
x = h.module.add_op("Add", "Add_0", inputs=[h.t1, h.t2])
|
||||
h.module.add_op("Add", "Add_1", inputs=[x, h.t3])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["Add_1/0"]
|
||||
x = h.function.add_op("Add", "Add_0", inputs=[h.a, h.b])
|
||||
res = h.function.add_op("Add", "Add_1", inputs=[x, h.c])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(
|
||||
result,
|
||||
np.add(np.add(h.input_data["a"], h.input_data["b"]), h.input_data["c"]),
|
||||
np.add(np.add(h.input_data[h.a], h.input_data[h.b]), h.input_data[h.c]),
|
||||
)
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(
|
||||
torch.add(
|
||||
torch.add(h.input_data["a"], h.input_data["b"]),
|
||||
h.input_data["c"],
|
||||
torch.add(h.input_data[h.a], h.input_data[h.b]),
|
||||
h.input_data[h.c],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_single_matmul(backend):
|
||||
h = Helper(backend)
|
||||
h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["MatMul_0/0"]
|
||||
res = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(result, np.matmul(h.input_data["a"], h.input_data["b"]))
|
||||
assert np.array_equal(result, np.matmul(h.input_data[h.a], h.input_data[h.b]))
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(torch.matmul(h.input_data["a"], h.input_data["b"]))
|
||||
assert result.equal(torch.matmul(h.input_data[h.a], h.input_data[h.b]))
|
||||
|
||||
|
||||
def test_double_matmul(backend):
|
||||
h = Helper(backend)
|
||||
x = h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
|
||||
h.module.add_op("MatMul", "MatMul_1", inputs=[h.t3, x])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["MatMul_1/0"]
|
||||
x = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
|
||||
res = h.function.add_op("MatMul", "MatMul_1", inputs=[h.c, x])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(
|
||||
result,
|
||||
np.matmul(
|
||||
h.input_data["c"], np.matmul(h.input_data["a"], h.input_data["b"])
|
||||
h.input_data[h.c], np.matmul(h.input_data[h.a], h.input_data[h.b])
|
||||
),
|
||||
)
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(
|
||||
torch.matmul(
|
||||
h.input_data["c"],
|
||||
torch.matmul(h.input_data["a"], h.input_data["b"]),
|
||||
h.input_data[h.c],
|
||||
torch.matmul(h.input_data[h.a], h.input_data[h.b]),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_double_matmul_inverted(backend):
|
||||
h = Helper(backend)
|
||||
x = h.module.add_op("MatMul", "MatMul_0", inputs=[h.t1, h.t2])
|
||||
h.module.add_op("MatMul", "MatMul_1", inputs=[x, h.t3])
|
||||
h.module.finalize()
|
||||
output_data = h.executor.compute(h.module, h.input_data)
|
||||
result = output_data["MatMul_1/0"]
|
||||
x = h.function.add_op("MatMul", "MatMul_0", inputs=[h.a, h.b])
|
||||
res = h.function.add_op("MatMul", "MatMul_1", inputs=[x, h.c])
|
||||
h.function = h.function.finalize()
|
||||
output_data = h.executor.compute(h.function, h.input_data)
|
||||
result = output_data[res]
|
||||
if h.backend == "numpy":
|
||||
assert np.array_equal(
|
||||
result,
|
||||
np.matmul(
|
||||
np.matmul(h.input_data["a"], h.input_data["b"]), h.input_data["c"]
|
||||
np.matmul(h.input_data[h.a], h.input_data[h.b]), h.input_data[h.c]
|
||||
),
|
||||
)
|
||||
elif h.backend == "torch":
|
||||
assert result.equal(
|
||||
torch.matmul(
|
||||
torch.matmul(h.input_data["a"], h.input_data["b"]),
|
||||
h.input_data["c"],
|
||||
torch.matmul(h.input_data[h.a], h.input_data[h.b]),
|
||||
h.input_data[h.c],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add test for op with multiple outputs
|
||||
|
||||
# TODO for all pmap tests, make a FunctionMaker helper function to add pmap
|
||||
# which also creates the device var and sets the attributes etc appropriately.
|
||||
# This should also be used by transforms/parsers that create pmap ops.
|
||||
|
||||
|
||||
def test_pmap_on_executor():
|
||||
d0 = Device(0, "gpu")
|
||||
|
@ -169,137 +172,142 @@ def test_pmap_on_executor():
|
|||
_y_0, _y_1 = _y[:4], _y[4:]
|
||||
|
||||
# A pmap with 1 input and 1 output
|
||||
module = Module()
|
||||
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
submodule = Module()
|
||||
x = submodule.add_input_value("x", x_type(None))
|
||||
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["z"])
|
||||
# submodule.set_outputs()
|
||||
submodule.finalize()
|
||||
_ = module.add_op(
|
||||
function = FunctionMaker()
|
||||
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
subfunction = FunctionMaker()
|
||||
x = subfunction.add_input_value("x", x_type(None))
|
||||
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["z"])
|
||||
# subfunction.set_outputs()
|
||||
subfunction = subfunction.finalize()
|
||||
zis = function.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs],
|
||||
attributes={"devices": [d0, d1]},
|
||||
submodules=[submodule],
|
||||
subfunctions=[subfunction],
|
||||
output_names=["zis"],
|
||||
)
|
||||
module.finalize()
|
||||
function = function.finalize()
|
||||
|
||||
res = ex.compute(module, {"xs": (_x_0, _x_1)})
|
||||
assert np.array_equal(res["zis"][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res["zis"][1], _x_1 + _x_1)
|
||||
cpprint(function)
|
||||
res = ex.compute(function, {xs: (_x_0, _x_1)})
|
||||
assert np.array_equal(res[zis][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res[zis][1], _x_1 + _x_1)
|
||||
|
||||
# A pmap with 2 inputs and 1 output
|
||||
module = Module()
|
||||
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
ys = module.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
|
||||
submodule = Module()
|
||||
x = submodule.add_input_value("x", x_type(None))
|
||||
y = submodule.add_input_value("y", y_type(None))
|
||||
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
submodule.finalize()
|
||||
_ = module.add_op(
|
||||
function = FunctionMaker()
|
||||
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
ys = function.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
|
||||
subfunction = FunctionMaker()
|
||||
x = subfunction.add_input_value("x", x_type(None))
|
||||
y = subfunction.add_input_value("y", y_type(None))
|
||||
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
subfunction = subfunction.finalize()
|
||||
zis = function.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs, ys],
|
||||
attributes={"devices": [d0, d1]},
|
||||
submodules=[submodule],
|
||||
subfunctions=[subfunction],
|
||||
output_names=["zis"],
|
||||
)
|
||||
module.finalize()
|
||||
function = function.finalize()
|
||||
|
||||
res = ex.compute(module, {"xs": (_x_0, _x_1), "ys": (_y_0, _y_1)})
|
||||
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
|
||||
assert np.array_equal(res["zis"][1], np.matmul(_x_1, _y_1))
|
||||
cpprint(function)
|
||||
res = ex.compute(function, {xs: (_x_0, _x_1), ys: (_y_0, _y_1)})
|
||||
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
|
||||
assert np.array_equal(res[zis][1], np.matmul(_x_1, _y_1))
|
||||
|
||||
# A pmap with 2 inputs and 2 outputs
|
||||
module = Module()
|
||||
xs = module.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
ys = module.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
|
||||
submodule = Module()
|
||||
x = submodule.add_input_value("x", x_type(None))
|
||||
y = submodule.add_input_value("y", y_type(None))
|
||||
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
|
||||
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
submodule.finalize()
|
||||
_ = module.add_op(
|
||||
function = FunctionMaker()
|
||||
xs = function.add_input_value("xs", TupleType((x_type(d0), x_type(d1))))
|
||||
ys = function.add_input_value("ys", TupleType((y_type(d0), y_type(d1))))
|
||||
subfunction = FunctionMaker()
|
||||
x = subfunction.add_input_value("x", x_type(None))
|
||||
y = subfunction.add_input_value("y", y_type(None))
|
||||
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
|
||||
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
subfunction = subfunction.finalize()
|
||||
(wis, zis) = function.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs, ys],
|
||||
attributes={"devices": [d0, d1]},
|
||||
submodules=[submodule],
|
||||
subfunctions=[subfunction],
|
||||
output_names=["wis", "zis"],
|
||||
)
|
||||
module.finalize()
|
||||
function = function.finalize()
|
||||
|
||||
res = ex.compute(module, {"xs": (_x_0, _x_1), "ys": (_y_0, _y_1)})
|
||||
assert np.array_equal(res["wis"][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res["wis"][1], _x_1 + _x_1)
|
||||
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
|
||||
assert np.array_equal(res["zis"][1], np.matmul(_x_1, _y_1))
|
||||
cpprint(function)
|
||||
res = ex.compute(function, {xs: (_x_0, _x_1), ys: (_y_0, _y_1)})
|
||||
assert np.array_equal(res[wis][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res[wis][1], _x_1 + _x_1)
|
||||
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
|
||||
assert np.array_equal(res[zis][1], np.matmul(_x_1, _y_1))
|
||||
|
||||
# A pmap with a single device
|
||||
module = Module()
|
||||
xs = module.add_input_value("xs", TupleType((x_type(None),)))
|
||||
ys = module.add_input_value("ys", TupleType((y_type(None),)))
|
||||
submodule = Module()
|
||||
x = submodule.add_input_value("x", x_type(None))
|
||||
y = submodule.add_input_value("y", y_type(None))
|
||||
_ = submodule.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
|
||||
_ = submodule.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
submodule.finalize()
|
||||
_ = module.add_op(
|
||||
function = FunctionMaker()
|
||||
xs = function.add_input_value("xs", TupleType((x_type(None),)))
|
||||
ys = function.add_input_value("ys", TupleType((y_type(None),)))
|
||||
subfunction = FunctionMaker()
|
||||
x = subfunction.add_input_value("x", x_type(None))
|
||||
y = subfunction.add_input_value("y", y_type(None))
|
||||
_ = subfunction.add_op("Add", "Add0", inputs=[x, x], output_names=["w"])
|
||||
_ = subfunction.add_op("MatMul", "MatMul0", inputs=[x, y], output_names=["z"])
|
||||
subfunction = subfunction.finalize()
|
||||
(wis, zis) = function.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs, ys],
|
||||
attributes={"devices": [d0]},
|
||||
submodules=[submodule],
|
||||
subfunctions=[subfunction],
|
||||
output_names=["wis", "zis"],
|
||||
)
|
||||
module.finalize()
|
||||
function = function.finalize()
|
||||
|
||||
res = ex.compute(module, {"xs": (_x_0,), "ys": (_y_0,)})
|
||||
assert np.array_equal(res["wis"][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res["zis"][0], np.matmul(_x_0, _y_0))
|
||||
cpprint(function)
|
||||
res = ex.compute(function, {xs: (_x_0,), ys: (_y_0,)})
|
||||
assert np.array_equal(res[wis][0], _x_0 + _x_0)
|
||||
assert np.array_equal(res[zis][0], np.matmul(_x_0, _y_0))
|
||||
|
||||
|
||||
def test_pmap_dp():
|
||||
module = Module()
|
||||
function = FunctionMaker()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
xs = module.add_input_value(
|
||||
xs = function.add_input_value(
|
||||
"xs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
|
||||
),
|
||||
)
|
||||
wAs = module.add_input_value(
|
||||
wAs = function.add_input_value(
|
||||
"wAs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
|
||||
),
|
||||
)
|
||||
wBs = module.add_input_value(
|
||||
wBs = function.add_input_value(
|
||||
"wBs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
|
||||
),
|
||||
)
|
||||
|
||||
submodule = Module()
|
||||
x = submodule.add_input_value("x", Tensor(Float(), (8, 4)))
|
||||
wA = submodule.add_input_value("wA", Tensor(Float(), (4, 2)))
|
||||
wB = submodule.add_input_value("wB", Tensor(Float(), (2, 1)))
|
||||
y = submodule.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
|
||||
_ = submodule.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
|
||||
submodule.finalize()
|
||||
_ = module.add_op(
|
||||
subfunction = FunctionMaker()
|
||||
x = subfunction.add_input_value("x", Tensor(Float(), (8, 4)))
|
||||
wA = subfunction.add_input_value("wA", Tensor(Float(), (4, 2)))
|
||||
wB = subfunction.add_input_value("wB", Tensor(Float(), (2, 1)))
|
||||
y = subfunction.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
|
||||
_ = subfunction.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
|
||||
subfunction = subfunction.finalize()
|
||||
zis = function.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs, wAs, wBs],
|
||||
attributes={"devices": [d0, d1]},
|
||||
submodules=[submodule],
|
||||
subfunctions=[subfunction],
|
||||
output_names=["zis"],
|
||||
)
|
||||
module.finalize()
|
||||
function = function.finalize()
|
||||
cpprint(function)
|
||||
|
||||
ex = SequentialExecutor("numpy")
|
||||
_x = np.arange(16 * 4).reshape((16, 4))
|
||||
|
@ -307,12 +315,8 @@ def test_pmap_dp():
|
|||
_wA = np.ones((4, 2))
|
||||
_wB = np.ones((2, 1))
|
||||
res = ex.compute(
|
||||
module,
|
||||
{"xs": (x_0, x_1), "wAs": (_wA, _wA), "wBs": (_wB, _wB)},
|
||||
function,
|
||||
{xs: (x_0, x_1), wAs: (_wA, _wA), wBs: (_wB, _wB)},
|
||||
)
|
||||
assert np.array_equal(res["zis"][0], np.matmul(np.matmul(x_0, _wA), _wB))
|
||||
assert np.array_equal(res["zis"][1], np.matmul(np.matmul(x_1, _wA), _wB))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pmap_on_executor()
|
||||
assert np.array_equal(res[zis][0], np.matmul(np.matmul(x_0, _wA), _wB))
|
||||
assert np.array_equal(res[zis][1], np.matmul(np.matmul(x_1, _wA), _wB))
|
||||
|
|
|
@ -1,190 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from dist_ir.ir import Device, Module
|
||||
from dist_ir.executor.shape_inference import infer_shapes
|
||||
from dist_ir.ir.type import Float, Tensor, TupleType
|
||||
|
||||
|
||||
def test_add_valid():
|
||||
module = Module()
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (4, 4)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 4)))
|
||||
x = module.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
|
||||
infer_shapes(module)
|
||||
assert x.type.shape == (4, 4)
|
||||
|
||||
|
||||
def test_add_invalid():
|
||||
module = Module()
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (8, 4)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
|
||||
x = module.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
|
||||
with pytest.raises(ValueError):
|
||||
infer_shapes(module)
|
||||
|
||||
|
||||
def test_allreduce():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
xis = module.add_input_value(
|
||||
"xis",
|
||||
TupleType(
|
||||
(Tensor(Float(), (4, 4), device=d0), Tensor(Float(), (4, 4), device=d1))
|
||||
),
|
||||
)
|
||||
xs = module.add_op(
|
||||
"Allreduce",
|
||||
"Allreduces/xis",
|
||||
inputs=[xis],
|
||||
output_names=["xs"],
|
||||
)
|
||||
infer_shapes(module)
|
||||
|
||||
assert isinstance(xs.type, TupleType)
|
||||
for i, value_type in enumerate(xis.type.types):
|
||||
assert value_type.shape == xs.type.types[i].shape
|
||||
assert value_type.device == xs.type.types[i].device
|
||||
|
||||
|
||||
def test_broadcast():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
x = module.add_input_value("x", Tensor(Float(), (4, 4)))
|
||||
xs = module.add_op(
|
||||
"Broadcast",
|
||||
"Broadcast/x",
|
||||
inputs=[x],
|
||||
attributes={"devices": [d0, d1]},
|
||||
output_names=["xs"],
|
||||
)
|
||||
infer_shapes(module)
|
||||
|
||||
assert isinstance(xs.type, TupleType)
|
||||
assert xs.type.types[0].shape == (4, 4)
|
||||
assert xs.type.types[0].device == d0
|
||||
assert xs.type.types[1].shape == (4, 4)
|
||||
assert xs.type.types[1].device == d1
|
||||
|
||||
|
||||
def test_matmul_valid():
|
||||
module = Module()
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (8, 4)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
infer_shapes(module)
|
||||
assert x.type.shape == (8, 2)
|
||||
|
||||
|
||||
def test_matmul_invalid():
|
||||
module = Module()
|
||||
|
||||
a = module.add_input_value("a", Tensor(Float(), (8, 8)))
|
||||
b = module.add_input_value("b", Tensor(Float(), (4, 2)))
|
||||
x = module.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
|
||||
with pytest.raises(ValueError):
|
||||
infer_shapes(module)
|
||||
|
||||
|
||||
def test_matmul_grad():
|
||||
module = Module()
|
||||
|
||||
x = module.add_input_value("x", Tensor(Float(), (8, 4)))
|
||||
w = module.add_input_value("w", Tensor(Float(), (4, 2)))
|
||||
l = module.add_input_value("l", Tensor(Float(), (8,)))
|
||||
dx, dw = module.add_op(
|
||||
"MatMulGrad", "MatMulGrad0", inputs=[x, w, l], output_names=["dx", "dw"]
|
||||
)
|
||||
infer_shapes(module)
|
||||
assert dx.type.shape == x.type.shape
|
||||
assert dw.type.shape == w.type.shape
|
||||
|
||||
|
||||
def test_pmap():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
xs = module.add_input_value(
|
||||
"xs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
|
||||
),
|
||||
)
|
||||
wAs = module.add_input_value(
|
||||
"wAs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
|
||||
),
|
||||
)
|
||||
wBs = module.add_input_value(
|
||||
"wBs",
|
||||
TupleType(
|
||||
(Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
|
||||
),
|
||||
)
|
||||
|
||||
submodule = Module()
|
||||
# TODO: Add a check in shape inference to not overwrite if there is already a shape
|
||||
# If there is an existing shape, validate that it is what we expect, otherwise throw error
|
||||
x = submodule.add_input_value("x", Tensor(Float(), (8, 4)))
|
||||
wA = submodule.add_input_value("wA", Tensor(Float(), (4, 2)))
|
||||
wB = submodule.add_input_value("wB", Tensor(Float(), (2, 1)))
|
||||
y = submodule.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
|
||||
z = submodule.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
|
||||
submodule.finalize()
|
||||
|
||||
zis = module.add_op(
|
||||
"Pmap",
|
||||
inputs=[xs, wAs, wBs],
|
||||
attributes={"devices": [d0, d1]},
|
||||
submodules=[submodule],
|
||||
output_names=["zis"],
|
||||
)
|
||||
|
||||
infer_shapes(module)
|
||||
|
||||
print(module)
|
||||
|
||||
# TODO: Verify submodule shapes and devices
|
||||
|
||||
assert zis.type.types[0].shape == (8, 1)
|
||||
assert zis.type.types[0].device == d0
|
||||
assert zis.type.types[1].shape == (8, 1)
|
||||
assert zis.type.types[1].device == d1
|
||||
|
||||
|
||||
def test_scatter():
|
||||
module = Module()
|
||||
|
||||
d0 = Device(0, "gpu")
|
||||
d1 = Device(1, "gpu")
|
||||
|
||||
x = module.add_input_value("x", Tensor(Float(), (4, 4)))
|
||||
xs = module.add_op(
|
||||
"Scatter",
|
||||
"Scatter/x",
|
||||
inputs=[x],
|
||||
attributes={"dim": 0, "devices": [d0, d1]},
|
||||
output_names=["xs"],
|
||||
)
|
||||
infer_shapes(module)
|
||||
|
||||
assert isinstance(xs.type, TupleType)
|
||||
assert xs.type.types[0].shape == (2, 4)
|
||||
assert xs.type.types[0].device == d0
|
||||
assert xs.type.types[1].shape == (2, 4)
|
||||
assert xs.type.types[1].device == d1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pmap()
|
|
@ -0,0 +1,45 @@
from dist_ir.ir import FunctionMaker
from dist_ir.ir.type import Tensor, Float


def test_subfunction():
    function = FunctionMaker()

    inputs = []
    outputs = []
    num_ops = 9
    for i in range(num_ops + 1):
        inputs.append(function.add_input_value(f"x{i}", Tensor(Float(), (4, 4))))
    for i in range(num_ops):
        if i == 0:
            input_values = inputs[:2]
        else:
            input_values = [outputs[-1], inputs[i + 1]]
        outputs.append(
            function.add_op(
                "Add", f"Add{i}", inputs=input_values, output_names=[f"a{i}"]
            )
        )
    function = function.finalize()
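
    # Each three-op window should take the previous window's output (if any) plus
    # the x inputs it consumes as its inputs, and expose only its last Add's output.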
    subfunction = function.get_subfunction(("Add0", "Add1", "Add2"))
    subfunction_inputs = subfunction.inputs
    subfunction_outputs = subfunction.outputs
    assert [v.name for v in subfunction_inputs] == ["x0", "x1", "x2", "x3"]
    assert [v.name for v in subfunction_outputs] == ["a2"]

    subfunction = function.get_subfunction(("Add3", "Add4", "Add5"))
    subfunction_inputs = subfunction.inputs
    subfunction_outputs = subfunction.outputs
    assert [v.name for v in subfunction_inputs] == ["a2", "x4", "x5", "x6"]
    assert [v.name for v in subfunction_outputs] == ["a5"]

    subfunction = function.get_subfunction(("Add6", "Add7", "Add8"))
    subfunction_inputs = subfunction.inputs
    subfunction_outputs = subfunction.outputs
    assert [v.name for v in subfunction_inputs] == ["a5", "x7", "x8", "x9"]
    assert [v.name for v in subfunction_outputs] == ["a8"]


if __name__ == "__main__":
    test_subfunction()

@@ -1,43 +0,0 @@
from dist_ir.ir import Module
from dist_ir.ir.type import Tensor, Float


def test_submodule():
    module = Module()

    inputs = []
    outputs = []
    num_ops = 9
    for i in range(num_ops + 1):
        inputs.append(module.add_input_value(f"x{i}", Tensor(Float(), (4, 4))))
    for i in range(num_ops):
        if i == 0:
            input_values = inputs[:2]
        else:
            input_values = [outputs[-1], inputs[i + 1]]
        outputs.append(
            module.add_op("Add", f"Add{i}", inputs=input_values, output_names=[f"a{i}"])
        )
    module.finalize()

    submodule = module.get_submodule(("Add0", "Add1", "Add2"))
    submodule_inputs = submodule.get_inputs()
    submodule_outputs = submodule.get_outputs()
    assert [v.name for v in submodule_inputs] == ["x0", "x1", "x2", "x3"]
    assert [v.name for v in submodule_outputs] == ["a2"]

    submodule = module.get_submodule(("Add3", "Add4", "Add5"))
    submodule_inputs = submodule.get_inputs()
    submodule_outputs = submodule.get_outputs()
    assert [v.name for v in submodule_inputs] == ["a2", "x4", "x5", "x6"]
    assert [v.name for v in submodule_outputs] == ["a5"]

    submodule = module.get_submodule(("Add6", "Add7", "Add8"))
    submodule_inputs = submodule.get_inputs()
    submodule_outputs = submodule.get_outputs()
    assert [v.name for v in submodule_inputs] == ["a5", "x7", "x8", "x9"]
    assert [v.name for v in submodule_outputs] == ["a8"]


if __name__ == "__main__":
    test_submodule()

@@ -0,0 +1,205 @@
import pytest

from dist_ir.ir import cpprint, Device, Function, FunctionMaker, Op, Value
from dist_ir.executor.type_inference import infer_types
from dist_ir.ir.type import Float, Tensor, TupleType


def test_add_valid():
    function = FunctionMaker()

    a = function.add_input_value("a", Tensor(Float(), (4, 4)))
    b = function.add_input_value("b", Tensor(Float(), (4, 4)))
    x = function.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
    function = function.finalize()
    typed_function = infer_types(function, [a, b])
    assert typed_function.outputs[0].type.shape == (4, 4)
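
# Adding an (8, 4) tensor to a (4, 2) tensor has no valid output shape, so type
# inference is expected to raise a ValueError.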
def test_add_invalid():
    function = FunctionMaker()

    a = function.add_input_value("a", Tensor(Float(), (8, 4)))
    b = function.add_input_value("b", Tensor(Float(), (4, 2)))
    x = function.add_op("Add", "Add0", inputs=[a, b], output_names=["x"])
    function = function.finalize()
    with pytest.raises(ValueError):
        infer_types(function, [a, b])
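
# This test builds the Value, Op, and Function objects directly instead of going
# through FunctionMaker; Allreduce should leave each tuple element's shape and
# device unchanged.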
def test_allreduce():
    d0 = Device(0, "gpu")
    d1 = Device(1, "gpu")

    xis = Value(
        "xis",
        TupleType(
            (Tensor(Float(), (4, 4), device=d0), Tensor(Float(), (4, 4), device=d1))
        ),
    )
    op1 = Op(
        "Allreduce",
        "Allreduces/xis",
        inputs=[xis],
        output_names=["xs"],
    )
    function = Function("foo", (op1,), (xis,), (op1.outputs[0],))
    function = infer_types(function, [xis])
    xs = function.outputs[0]

    assert isinstance(xs.type, TupleType)
    for i, value_type in enumerate(xis.type.types):
        assert value_type.shape == xs.type.types[i].shape
        assert value_type.device == xs.type.types[i].device
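
# Broadcast over [d0, d1] should produce a tuple with a (4, 4) copy of x on each
# device.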
def test_broadcast():
    function = FunctionMaker()

    d0 = Device(0, "gpu")
    d1 = Device(1, "gpu")

    x = function.add_input_value("x", Tensor(Float(), (4, 4)))
    xs = function.add_op(
        "Broadcast",
        "Broadcast/x",
        inputs=[x],
        attributes={"devices": [d0, d1]},
        output_names=["xs"],
    )
    function = function.finalize()
    function = infer_types(function, [x])
    xs = function.outputs[0]

    assert isinstance(xs.type, TupleType)
    assert xs.type.types[0].shape == (4, 4)
    assert xs.type.types[0].device == d0
    assert xs.type.types[1].shape == (4, 4)
    assert xs.type.types[1].device == d1


def test_matmul_valid():
    function = FunctionMaker()

    a = function.add_input_value("a", Tensor(Float(), (8, 4)))
    b = function.add_input_value("b", Tensor(Float(), (4, 2)))
    x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
    function = function.finalize()
    function = infer_types(function, [a, b])
    assert function.outputs[0].type.shape == (8, 2)
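
# The inner dimensions of (8, 8) x (4, 2) do not match, so type inference is
# expected to raise a ValueError.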
def test_matmul_invalid():
    function = FunctionMaker()

    a = function.add_input_value("a", Tensor(Float(), (8, 8)))
    b = function.add_input_value("b", Tensor(Float(), (4, 2)))
    x = function.add_op("MatMul", "MatMul0", inputs=[a, b], output_names=["x"])
    function = function.finalize()
    with pytest.raises(ValueError):
        function = infer_types(function, [a, b])


def test_matmul_grad():
    function = FunctionMaker()

    x = function.add_input_value("x", Tensor(Float(), (8, 4)))
    w = function.add_input_value("w", Tensor(Float(), (4, 2)))
    l = function.add_input_value("l", Tensor(Float(), (8,)))
    dx, dw = function.add_op(
        "MatMulGrad", "MatMulGrad0", inputs=[x, w, l], output_names=["dx", "dw"]
    )
    function = function.finalize()
    function = infer_types(function, [x, w, l])
    dx, dw = function.outputs
    assert dx.type.shape == x.type.shape
    assert dw.type.shape == w.type.shape
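
# The subfunction's inputs are declared with type None; infer_types is expected
# to fill in the per-device types for the Pmap op (whose device_var attribute
# stands in for each concrete device) and yield an (8, 1) shard on d0 and d1.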
def test_pmap():
    function = FunctionMaker()

    d0 = Device(0, "gpu")
    d1 = Device(1, "gpu")

    xs = function.add_input_value(
        "xs",
        TupleType(
            (Tensor(Float(), (8, 4), device=d0), Tensor(Float(), (8, 4), device=d1))
        ),
    )
    wAs = function.add_input_value(
        "wAs",
        TupleType(
            (Tensor(Float(), (4, 2), device=d0), Tensor(Float(), (4, 2), device=d1))
        ),
    )
    wBs = function.add_input_value(
        "wBs",
        TupleType(
            (Tensor(Float(), (2, 1), device=d0), Tensor(Float(), (2, 1), device=d1))
        ),
    )

    subfunction = FunctionMaker()
    x = subfunction.add_input_value("x", None)
    wA = subfunction.add_input_value("wA", None)
    wB = subfunction.add_input_value("wB", None)
    y = subfunction.add_op("MatMul", "MatMul0", inputs=[x, wA], output_names=["y"])
    z = subfunction.add_op("MatMul", "MatMul1", inputs=[y, wB], output_names=["z"])
    subfunction = subfunction.finalize()

    zis = function.add_op(
        "Pmap",
        inputs=[xs, wAs, wBs],
        attributes={
            "devices": [d0, d1],
            "device_var": Device.get_new_device_variable(
                "gpu"
            ),  # TODO where best to do this?
        },
        subfunctions=[subfunction],
        output_names=["zis"],
    )

    function = function.finalize()
    cpprint(function)
    function = infer_types(function, [xs, wAs, wBs])
    cpprint(function)

    # TODO: Verify subfunction shapes and devices

    zis = function.outputs[0]
    assert zis.type.types[0].shape == (8, 1)
    assert zis.type.types[0].device == d0
    assert zis.type.types[1].shape == (8, 1)
    assert zis.type.types[1].device == d1
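
# As in the Broadcast test, but Scatter splits x along dim 0, so each device
# should receive a (2, 4) shard.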
def test_scatter():
    function = FunctionMaker()

    d0 = Device(0, "gpu")
    d1 = Device(1, "gpu")

    x = function.add_input_value("x", Tensor(Float(), (4, 4)))
    xs = function.add_op(
        "Scatter",
        "Scatter/x",
        inputs=[x],
        attributes={"dim": 0, "devices": [d0, d1]},
        output_names=["xs"],
    )
    function = function.finalize()
    function = infer_types(function, [x])
    xs = function.outputs[0]

    assert isinstance(xs.type, TupleType)
    assert xs.type.types[0].shape == (2, 4)
    assert xs.type.types[0].device == d0
    assert xs.type.types[1].shape == (2, 4)
    assert xs.type.types[1].device == d1


if __name__ == "__main__":
    test_pmap()