Antares IR Syntax
Einstein Expression Parsing:
1) Transform-based Operator Format: output[D1, D2, ..] = f(input[D1, D2, ..]) where D2 in 8
Step-1: Auto-fill the einsum dummy axes, yielding: output[D1, D2, ..] = f(input[D1, D2, ..]) where D2 in 8, D1 in input.shape[0], ..
Step-2: Construct the basic C code:
for (int D1 = 0; D1 < input.shape[0]; D1++)
for (int D2 = 0; D2 < 8; D2++)
for (..)
output[D1, D2, ..] = f(input[D1, D2, ..]);
* Specific Example 1 (OneHot Op): output0[N, F] = const(1.0).when([input0[N] == F], const(0.0)) where F in 128
Step-1/Step-2 as above, finally yielding:
for (int N = 0; N < input0.shape[0]; ++N)
for (int F = 0; F < 128; ++F)
output0[N, F] = (input0[N] == F) ? 1.0f : 0.0f;
* Specific Example 2 (Arange Op): output0[N] = N.cast(`float32`) where N in 1024
Step-1/Step-2 as above, finally yielding:
for (int N = 0; N < 1024; ++N)
output0[N] = static_cast<float>(N);
2) Aggregation-based Operator Format: output[D1, D2, ..] +=! f(input[D1, D2, .., R1, R2, ..])
Step-1: Auto-fill the einsum dummy axes, yielding: output[D1, D2, ..] +=! f(input[D1, D2, .., R1, R2, ..]) where D1 in input.shape[0], D2 in input.shape[1], .., R1 in input.shape[..], R2 in input.shape[..], ..
Step-2: Construct the basic C code:
for (int D1 = 0; D1 < input.shape[0]; D1++)
for (int D2 = 0; D2 < input.shape[1]; D2++)
for (..)
output[D1, D2, ..] = 0;
for (int D1 = 0; D1 < input.shape[0]; D1++)
for (int D2 = 0; D2 < input.shape[1]; D2++)
for (..)
for (int R1 = 0; R1 < input.shape[..]; R1++)
for (int R2 = 0; R2 < input.shape[..]; R2++)
for (..)
output[D1, D2, ..] += f(input[D1, D2, .., R1, R2, ..]);
* Specific Example (Conv2D Op): output0[N, F, HO, WO] +=! input0[N, C, HO + KH, WO + KW] * input1[F, C, KH, KW] where HO in 30, WO in 30
Step-1/Step-2 as above, finally yielding:
for (int N = 0; N < input0.shape[0]; ++N)
for (int F = 0; F < input1.shape[0]; ++F)
for (int HO = 0; HO < 30; ++HO)
for (int WO = 0; WO < 30; ++WO)
output0[N, F, HO, WO] = 0;
for (int N = 0; N < input0.shape[0]; ++N)
for (int F = 0; F < input1.shape[0]; ++F)
for (int HO = 0; HO < 30; ++HO)
for (int WO = 0; WO < 30; ++WO)
for (int C = 0; C < input1.shape[1]; ++C) // R1
for (int KH = 0; KH < input1.shape[2]; ++KH) // R2
for (int KW = 0; KW < input1.shape[3]; ++KW) // R3
output0[N, F, HO, WO] += input0[N, C, HO + KH, WO + KW] * input1[F, C, KH, KW];
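Besides `+=!` (sum), the examples further below also use the reduction assignments `<=!` (min), `>=!` (max), `&=!` (all) and `|=!` (any). These plausibly follow the same two-phase lowering as `+=!`; for example, a max-reduction such as output0[N] >=! input0[N, K] would sketch out as below, where the -FLT_MAX initializer is an assumption (the dtype's lowest value) rather than something the rules above spell out:
for (int N = 0; N < input0.shape[0]; ++N)
output0[N] = -FLT_MAX; // assumed initializer for a float32 max-reduction
for (int N = 0; N < input0.shape[0]; ++N)
for (int K = 0; K < input0.shape[1]; ++K)
output0[N] = max(output0[N], input0[N, K]);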
Antares Built-in Primitive Mapping:

| Primitive Type | Antares IR Format | C Code Format |
|---|---|---|
| Branch (All) | x.when([c1, c2, ..], y) | (c1 && c2 && ..) ? x : y |
| Branch (Any) | x.when([c1, c2, ..], y, merge_op=`any`) | (c1 \|\| c2 \|\| ..) ? x : y |
| Type Cast | x.cast(`int8`) | static_cast<char>(x) |
| Function Call (Single Arg) | x.call(`exp`) | exp(x) |
| Function Call (Multiple Args) | x.call(`max`, [y, ..]) | max(x, y, ..) |
| Logical Ops | x & ~(y \| z) | x && !(y \|\| z) |
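As an illustrative combination of the rows above (a hypothetical expression composed for this purpose, not taken from the examples below), a clamp-then-cast such as
output0[N] = input0[N].call(`max`, [0.0]).when([input0[N] < 100.0], 100.0).cast(`int32`)
would map row by row to:
output0[N] = static_cast<int>((input0[N] < 100.0) ? max(input0[N], 0.0) : 100.0);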
Antares Built-in Data Type Mapping:

| Antares Type | C Type |
|---|---|
| float64 | double |
| float32 | float |
| float16 | half |
| int32 | int |
| int16 | short |
| int8 | char |
Antares Built-in Functions:

| Function Name | Proto | Explanation |
|---|---|---|
| max | max(T, T) -> T | The max value of the two inputs |
| min | min(T, T) -> T | The min value of the two inputs |
| log | log(T) -> T | The natural logarithm of the input |
| exp | exp(T) -> T | The exponential of the input |
| sqrt | sqrt(T) -> T | The square root of the input |
| pow | pow(T, T) -> T | The power of the inputs: a ^ b |
| floor | floor(float32/float64) -> int32/int64 | The floor of the input, returned as an integer type |
| ceil | ceil(float32/float64) -> int32/int64 | The ceiling of the input, returned as an integer type |
| rfloor | rfloor(T) -> T | The floor of the input, returned in the input's dtype |
| rceil | rceil(T) -> T | The ceiling of the input, returned in the input's dtype |
| remainder | remainder(T) -> T | The fractional part of the input |
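For instance, chaining the built-ins above, a softplus-style expression (composed here purely for illustration):
output0[N] = (input0[N].call(`exp`) + 1.0).call(`log`)
would lower to:
output0[N] = log(exp(input0[N]) + 1.0);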
Detailed Examples:
# Broadcast
COMPUTE_V1='- einstein_v2("output0[N, F, HO, WO] = input0[N] where F in 32, HO in 2, WO in 2", input_dict={"input0": {"dtype": "float32", "shape": [16]}})' antares
# BroadcastAll
COMPUTE_V1='- einstein_v2("output0[N, F, HO, WO] = input0[0] where N in 8, F in 32, HO in 2, WO in 2", input_dict={"input0": {"dtype": "float32", "shape": [1]}})' antares
# MatMul
COMPUTE_V1='- einstein_v2("output0[N, M] +=! input0[N, K] * input1[K, M]", { "input0": {"dtype": "float32", "shape": [1024, 512]}, "input1": {"dtype": "float32", "shape": [512, 512]}})' antares
# MatMulBiasAdd (input2[M] is divided by K.val() so that accumulating over the K axis restores the bias exactly once)
COMPUTE_V1='- einstein_v2("output0[N, M] +=! input0[N, K] * input1[K, M] + input2[M] / K.val()", { "input0": {"dtype": "float32", "shape": [1024, 512]}, "input1": {"dtype": "float32", "shape": [512, 512]}, "input2": {"dtype": "float32", "shape": [512]} })' antares
# BatchMatMul
COMPUTE_V1='- einstein_v2("output0[B, N, M] +=! input0[B, N, K] * input1[B, K, M]", input_dict={"input0": {"dtype": "float32", "shape": [3, 1024, 512]}, "input1": {"dtype": "float32", "shape": [3, 512, 512]}})' antares
# Elementwise
COMPUTE_V1='- einstein_v2("output0[N] = input0[N] + input1[N]", input_dict={"input0": {"dtype": "float32", "shape": [1024 * 512]}, "input1": {"dtype": "float32", "shape": [1024 * 512]}})' antares
# Scalar Compute
COMPUTE_V1='- einstein_v2("output0[] = input0[] + input1[]", input_dict={"input0": {"dtype": "float32", "shape": []}, "input1": {"dtype": "float32", "shape": []}})' antares
# Multiple Outputs (of the same shape)
COMPUTE_V1='- einstein_v2("output0[N] = input0[N] + input1[N]; output1[N] = input0[N] * 2; output2[N] = input1[N] + output1[N];", input_dict={"input0": {"dtype": "float32", "shape": [1024 * 512]}, "input1": {"dtype": "float32", "shape": [1024 * 512]}}, extra_outputs=["output0", "output1", "output2"])' antares
# Transpose
COMPUTE_V1='- einstein_v2("output0[N, C, H, W] = input0[N, H, W, C]", input_dict={"input0": {"dtype": "float32", "shape": [32, 229, 229, 3]}})' antares
# Reshape
COMPUTE_V1='- einstein_v2("output0[A, B, C] = input0[A, B, C / 64, C % 64] where C in 128", input_dict={"input0": {"dtype": "float32", "shape": [3, 3, 2, 64]}})' antares
# ReduceSum
COMPUTE_V1='- einstein_v2("output0[N] +=! input0[N, C]", input_dict={"input0": {"dtype": "float32", "shape": [32, 1024]}})' antares
# ReduceMin
COMPUTE_V1='- einstein_v2("output0[N] <=! input0[N, C]", input_dict={"input0": {"dtype": "float32", "shape": [32, 1024]}})' antares
# ReduceAll
COMPUTE_V1='- einstein_v2("output0[N] &=! input0[N, C]", input_dict={"input0": {"dtype": "int8", "shape": [32, 1024]}})' antares
# ReduceAny
COMPUTE_V1='- einstein_v2("output0[N] |=! input0[N, C]", input_dict={"input0": {"dtype": "int8", "shape": [32, 1024]}})' antares
# Cast
COMPUTE_V1='- einstein_v2("output0[N] = N.cast(`float32`) where N in 1024", {})' antares
# Condition Relu
COMPUTE_V1='- einstein_v2("output0[N, C] = input0[N, C].when([input0[N, C] > 0.0], 0.0)", input_dict={"input0": {"dtype": "float32", "shape": [1024, 512]}})' antares
# Condition Relu for dynamic data type
COMPUTE_V1='- einstein_v2("output0[N, C] = input0[N, C].when([input0[N, C] > const(0.0).cast(input0[N, C].dtype())], const(0.0).cast(input0[N, C].dtype()))", input_dict={"input0": {"dtype": "float32", "shape": [1024, 512]}})' antares
# `Range + Tanh` using External Function
COMPUTE_V1='- einstein_v2("output0[N] = N.cast(`float32`).call(`tanh`) where N in 1024", {})' antares
# ConvolutionNoPad
COMPUTE_V1='- einstein_v2("output0[N, F, HO, WO] +=! input0[N, C, HO + KH, WO + KW] * input1[F, C, KH, KW] where HO in 30, WO in 30", { "input0": {"dtype": "float32", "shape": [16, 64, 32, 32]}, "input1": {"dtype": "float32", "shape": [256, 64, 3, 3]}})' antares
# ConvolutionWithPad
COMPUTE_V1='- _N, _CI, _H, _W, _CO, _KH, _KW, _SH, _SW, _PH, _PW = 16, 64, 32, 32, 256, 3, 3, 1, 1, 0, 0; _HO, _WO = (_H - _KH + _PH * 2) // _SH + 1, (_W - _KW + _PW * 2) // _SW + 1; einstein_v2(f"output0[N, F, HO, WO] +=! input0[N, C, HO * {_SH} + KH - {_PH}, WO * {_SW} + KW - {_PW}].when([HO * {_SH} + KH - {_PH} >= 0, HO * {_SH} + KH - {_PH} < {_H}, WO * {_SW} + KW - {_PW} >= 0, WO * {_SW} + KW - {_PW} < {_W}], 0.0) * input1[F, C, KH, KW] where HO in {_HO}, WO in {_WO}", { "input0": {"dtype": "float32", "shape": [_N, _CI, _H, _W]}, "input1": {"dtype": "float32", "shape": [_CO, _CI, _KH, _KW]}})' antares
# ConvolutionWithPad (the C, KH, KW reduce axes are fused into one axis CKHKW and decoded via division/modulo)
COMPUTE_V1='- _N, _CI, _H, _W, _CO, _KH, _KW, _SH, _SW, _PH, _PW = 16, 64, 32, 32, 256, 3, 3, 1, 1, 0, 0; \
_HO, _WO = (_H - _KH + _PH * 2) // _SH + 1, (_W - _KW + _PW * 2) // _SW + 1; \
einstein_v2(f" \
output0[N, F, HO, WO] +=! input0[N, CKHKW // {_KW * _KH}, HO * {_SH} + ((CKHKW % {_KW * _KH}) // {_KW}) - {_PH}, WO * {_SW} + ((CKHKW % {_KW * _KH}) % {_KW}) - {_PW}].when([HO * {_SH} + ((CKHKW % {_KW * _KH}) // {_KW}) - {_PH} >= 0, HO * {_SH} + ((CKHKW % {_KW * _KH}) // {_KW}) - {_PH} < {_H}, WO * {_SW} + ((CKHKW % {_KW * _KH}) % {_KW}) - {_PW} >= 0, WO * {_SW} + ((CKHKW % {_KW * _KH}) % {_KW}) - {_PW} < {_W}], 0.0) * input1[F, CKHKW] where HO in {_HO}, WO in {_WO} \
", { "input0": {"dtype": "float32", "shape": [_N, _CI, _H, _W]}, "input1": {"dtype": "float32", "shape": [_CO, _CI * _KH * _KW]}})' antares
# ConvWinograd_3x3 (_KH = _KW = 3, _SH = _SW = 1, _PH = _PW = 0)
COMPUTE_V1='- _N, _CI, _H, _W, _CO = 16, 64, 32, 32, 256; _HO, _WO = _H - 2, _W - 2; _nH, _nW = (_HO + 1) // 2, (_WO + 1) // 2; _P = _N * _nH * _nW; einstein_v2(f"helper4x3[N, M] = const(1.0).when([N * 3 + M == 0, N * 3 + M == 11], const(0.0).when([N * 3 + M == 1, N * 3 + M == 2, N * 3 + M == 9, N * 3 + M == 10], const(-0.5).when([N * 3 + M == 4], 0.5, merge_op=`any`), merge_op=`any`), merge_op=`any`) where N in 4, M in 3; transform_filter[EPS, NU, CI, CO] +=! ((input1[CO, CI, Rkh, Rkw] * helper4x3[EPS, Rkh] * helper4x3[NU, Rkw])); input_tile[C, B, EPS, NU] = input0[B // ({_nH} * {_nW}), C, B // {_nW} % {_nH} * 2 + EPS, B % {_nW} * 2 + NU] where C in {_CI}, B in {_P}, EPS in 4, NU in 4; helper4x4[N, M] = const(1.0).when([N * 4 + M == 0, N * 4 + M == 6, N * 4 + M == 9, N * 4 + M == 10, N * 4 + M == 15], const(-1.0).when([N * 4 + M == 5, N * 4 + M == 7, N * 4 + M == 8], 0.0, merge_op=`any`), merge_op=`any`) where N in 4, M in 4; transform_input[EPS, NU, C, B] +=! input_tile[C, B, K1, K2] * helper4x4[K1, EPS] * helper4x4[K2, NU] where EPS in 4, NU in 4, C in {_CI}, B in {_P}; batch_gemm[EPS, NU, K, B] +=! transform_filter[EPS, NU, CI, K] * transform_input[EPS, NU, CI, B] where EPS in 4, NU in 4, K in {_CO}, B in {_P}; helper4x2[N, M] = const(0.0) .when([N * 2 + M == 1, N * 2 + M == 6], const(-1.0).when([N * 2 + M == 3], 1.0, merge_op=`any`), merge_op=`any`) where N in 4, M in 2; inverse[K, B, VH, VW] +=! batch_gemm[K1, K2, K, B] * helper4x2[K1, VH] * helper4x2[K2, VW] where K in {_CO}, B in {_P}, VH in 2, VW in 2; output0[N, K, H, W] = inverse[K, N * {_nH} * {_nW} + H // 2 * {_nW} + W // 2, H % 2, W % 2] where N in {_N}, K in {_CO}, H in {_HO}, W in {_WO}", {"input0": {"dtype": "float32", "shape": [_N, _CI, _H, _W]}, "input1": {"dtype": "float32", "shape": [_CO, _CI, 3, 3]}})' antares
# ConvWinograd_3x3 with external helper matrix
COMPUTE_V1='- _N, _CI, _H, _W, _CO = 16, 64, 32, 32, 256; _HO, _WO = _H - 2, _W - 2; _nH, _nW = (_HO + 1) // 2, (_WO + 1) // 2; _P = _N * _nH * _nW; einstein_v2(f""" \
transform_filter[EPS, NU, CI, CO] +=! ((input1[CO, CI, Rkh, Rkw] * helper4x3[EPS, Rkh] * helper4x3[NU, Rkw])); \
input_tile[C, B, EPS, NU] = input0[B // ({_nH} * {_nW}), C, B // {_nW} % {_nH} * 2 + EPS, B % {_nW} * 2 + NU] where C in {_CI}, B in {_P}, EPS in 4, NU in 4; \
transform_input[EPS, NU, C, B] +=! input_tile[C, B, K1, K2] * helper4x4[K1, EPS] * helper4x4[K2, NU] where EPS in 4, NU in 4, C in {_CI}, B in {_P}; \
batch_gemm[EPS, NU, K, B] +=! transform_filter[EPS, NU, CI, K] * transform_input[EPS, NU, CI, B] where EPS in 4, NU in 4, K in {_CO}, B in {_P}; \
inverse[K, B, VH, VW] +=! batch_gemm[K1, K2, K, B] * helper4x2[K1, VH] * helper4x2[K2, VW] where K in {_CO}, B in {_P}, VH in 2, VW in 2; \
output0[N, K, H, W] = inverse[K, N * {_nH} * {_nW} + H // 2 * {_nW} + W // 2, H % 2, W % 2] where N in {_N}, K in {_CO}, H in {_HO}, W in {_WO} \
""", {"input0": {"dtype": "float32", "shape": [_N, _CI, _H, _W]}, "input1": {"dtype": "float32", "shape": [_CO, _CI, 3, 3]}, "helper4x2": {"dtype": "float32", "shape": [4, 2]}, "helper4x3": {"dtype": "float32", "shape": [4, 3]}, "helper4x4": {"dtype": "float32", "shape": [4, 4]}})' antares
# DepthToSpace
COMPUTE_V1='- einstein_v2("output0[N, H, C0, W, C1, C2] = input0[N, H, W, C0, C1, C2]", input_dict={"input0": {"dtype": "float32", "shape": [1, 256, 256, 2, 2, 4]}})' antares
# DepthwiseConv
COMPUTE_V1='- einstein_v2("output0[N, C, HO, WO] +=! input0[N, C, HO + KH, WO + KW] * input1[KH, KW, C, 0] where HO in 30, WO in 30", input_dict={"input0": {"dtype": "float32", "shape": [32, 16, 32, 32]}, "input1": {"dtype": "float32", "shape": [3, 3, 16, 1]}})' antares
# Slice
COMPUTE_V1='- einstein_v2("output0[N, F] = input0[N, F, 2]", input_dict={"input0": {"dtype": "float32", "shape": [1, 16, 32]}})' antares
# Concat
COMPUTE_V1='- einstein_v2("output0[N, F] = input0[N, F].when([F < 128], input1[N, F - 128]) where F in 256", input_dict={"input0": {"dtype": "float32", "shape": [4, 128]}, "input1": {"dtype": "float32", "shape": [4, 128]}})' antares
# OneHot
COMPUTE_V1='- einstein_v2("output0[N, F] = const(1.0).when([input0[N] == F], const(0.0)) where F in 128", input_dict={"input0": {"dtype": "int32", "shape": [4]}})' antares
# Take
COMPUTE_V1='- einstein_v2("output0[F, C] = input0[input1[F], C]", input_dict={"input0": {"dtype": "float32", "shape": [30528, 1024]}, "input1": {"dtype": "int32", "shape": [3072]}})' antares
# Gather
COMPUTE_V1='- einstein_v2("output0[N, F] = input0[input1[N, F]]", input_dict={"input0": {"dtype": "float32", "shape": [65536]}, "input1": {"dtype": "int32", "shape": [4, 64]}})' antares
# Pad
COMPUTE_V1='- einstein_v2("output0[N, C, HO, WO] = input0[N, C, -1 + HO, -1 + WO].when([-1 + HO >= 0, -1 + HO < 32, -1 + WO >= 0, -1 + WO < 32], 0.0) where HO in 34, WO in 34", input_dict={"input0": {"dtype": "float32", "shape": [32, 3, 32, 32]}})' antares
# DivNoNan
COMPUTE_V1='- einstein_v2("output0[N] = (input0[N] / input1[N]).when([input1[N] != 0], 0.0)", input_dict={"input0": {"dtype": "float32", "shape": [32 * 1024]}, "input1": {"dtype": "float32", "shape": [32 * 1024]}})' antares
# MaxPool
COMPUTE_V1='- einstein_v2("output0[N, C, HO, WO] >=! input0[N, C, HO * 2 + KH, WO * 2 + KW] where HO in 6, WO in 6, KW in 2, KH in 2", input_dict={"input0": {"dtype": "float32", "shape": [32, 3, 12, 12]}})' antares
# AvgPool
COMPUTE_V1='- einstein_v2("output0[NC, HO, WO] +=! input0[NC, HO * 3 + KH, WO * 3 + KW] / 9.0 where HO in 85, WO in 85, KW in 3, KH in 3", input_dict={"input0": {"dtype": "float32", "shape": [1024, 255, 255]}})' antares
# Tile
COMPUTE_V1='- einstein_v2("output0[ON, OC] = input0[ON % 2, OC % 16] where ON in 1024, OC in 4096", input_dict={"input0": {"dtype": "float32", "shape": [2, 16]}})' antares
# Softmax
COMPUTE_V1='- einstein_v2("temp0[N] >=! input0[N, C]; temp1[N] +=! (input0[N, C] - temp0[N]).call(`exp`); output0[N, C] = (input0[N, C] - temp0[N]).call(`exp`) / temp1[N]", { "input0": {"dtype": "float32", "shape": [32, 1024]} })' antares
# BatchNorm Inference
COMPUTE_V1='- einstein_v2("output0[N, C, H, W] = bias[C] + scale[C] * (input0[N, C, H, W] - mean[C]) * (variance[C] + 1e-5).call(`rsqrt`)", input_dict={"input0": {"dtype": "float32", "shape": [16, 256, 16, 16]}, "mean": {"dtype": "float32", "shape": [256]}, "variance": {"dtype": "float32", "shape": [256]}, "scale": {"dtype": "float32", "shape": [256]}, "bias": {"dtype": "float32", "shape": [256]} })' antares
# LayerNorm
COMPUTE_V1='- einstein_v2("temp0[N] +=! input0[N, C]; temp1[N] +=! input0[N, C] * input0[N, C]; output0[N, C] = (input0[N, C] * C.val() - temp0[N]) * (temp1[N] * C.val() - temp0[N] * temp0[N] + 1e-5).call(`rsqrt`)", { "input0": {"dtype": "float32", "shape": [32, 1024]} })' antares
# InstanceNorm
COMPUTE_V1='- einstein_v2("mediate0[N, C] +=! input0[N, C, I]; mediate1[N, C] +=! input0[N, C, I] * input0[N, C, I]; output0[N, C, I] = input2[C] + input1[C] * (input0[N, C, I] * I.val() - mediate0[N, C]) / (mediate1[N, C] * I.val() - mediate0[N, C] * mediate0[N, C] + 1e-5).call(`sqrt`)", input_dict={"input0" : { "dtype" : "float32", "shape" : [2, 32, 40960]} , "input1" : { "dtype" : "float32", "shape" : [32]} , "input2" : { "dtype" : "float32", "shape" : [32]}})' antares
# Logical Bool Operation
COMPUTE_V1='- einstein_v2("output0[N, M] = input0[N, M] & ~input1[N, M]", { "input0": {"dtype": "int8", "shape": [1024, 512]}, "input1": {"dtype": "int8", "shape": [1024, 512]} })' antares
# Sigmoid
COMPUTE_V1='- einstein_v2("output0[N, M] = 1.0 / (1.0 + (-input0[N, M]).call(`exp`))", { "input0": {"dtype": "float32", "shape": [1024, 512]} })' antares
# Conv2D Transpose
COMPUTE_V1='- _N, _CI, _CO, _H, _W, _KH, _KW, _PH, _PW = 1, 4, 8, 5, 5, 3, 3, 0, 0; _HO, _WO = (_H + _KH - _PH * 2) - 1, (_W + _KW - _PW * 2) - 1; ACCESS_H, ACCESS_W = f"((HO - KH) + {_PH})", f"((WO - KW) + {_PW})"; einstein_v2(f"output0[N, F, HO, WO] +=! input0[N, C, {ACCESS_H}, {ACCESS_W}].when([{ACCESS_H} >= 0, {ACCESS_H} < {_H}, {ACCESS_W} >= 0, {ACCESS_W} < {_W}], const(0, input0.dtype())) * input1[C, F, KH, KW] where HO in {_HO}, WO in {_WO}", { "input0": {"dtype": "float32", "shape": [_N, _CI, _H, _W]}, "input1": {"dtype": "float32", "shape": [_CI, _CO, _KH, _KW]}})' antares
# AddMatMul Head Fusion
COMPUTE_V1='- einstein_v2("temp0[K, N] = input0[N, K] + 100; output0[N, M] +=! temp0[K, N] * input1[K, M] where K in 10", { "input0": {"dtype": "float32", "shape": [1024, 512]}, "input1": {"dtype": "float32", "shape": [512, 512]}})' antares
# ConvBiasRelu Tail Fusion
COMPUTE_V1='- einstein_v2("conv_out[N, F, HO, WO] +=! input0[N, C, HO + KH, WO + KW] * input1[KH, KW, C, F] where HO in 256, WO in 256; conv_bias[N, F, HO, WO] = conv_out[N, F, HO, WO] + input2[0, 0, 0, F]; output0[N, F, HO, WO] = conv_bias[N, F, HO, WO].when(conv_bias[N, F, HO, WO] > 0.0, 0.0)", input_dict={"input0": {"dtype": "float32", "shape": [1, 16, 256, 256]}, "input1": {"dtype": "float32", "shape": [1, 1, 16, 16]}, "input2": {"dtype": "float32", "shape": [1, 1, 1, 16]}})' antares
# Scatter4D (the `=.` operator performs an in-place scatter assignment: only the indexed positions of data are written)
COMPUTE_V1='- _B, _M = 2, 8; einstein_v2("data[indices[B, 0], indices[B, 1], indices[B, 2], indices[B, 3], M] =. updates[B, M]", input_dict={"data": {"dtype": "float32", "shape": [32, 32, 32, 32, _M]}, "indices": {"dtype": "int32", "shape": [_B, 4]}, "updates": {"dtype": "float32", "shape": [_B, _M]}})' antares
# Vectorized Boolean Check
COMPUTE_V1='- einstein_v2("output0[N] = ((input0[N, 0] == 0).cast(`int32`) << 24) + ((input0[N, 1] == 0).cast(`int32`) << 16) + ((input0[N, 2] == 0).cast(`int32`) << 8) + (input0[N, 3] == 0).cast(`int32`)", { "input0": {"dtype": "int32", "shape": [1000, 4]} })' antares