partially support matmul broadcast.
Parent: 2964f1fde5
Commit: 6ed965924c
@@ -3535,27 +3535,116 @@ namespace CNTK
     // we implement this import step-by-step.
     // 1. Both 2-D inputs.
     // 2. Both N-D inputs.
-    // 3. Add broadcast. (Not supported in ONNX/CNTK yet)
+    // 3. Add broadcast. (Partially supported in ONNX/CNTK.)
     //
     // Note: we don't support free/inferred dimension for N-D MatMul currently.

-    const NDShape& inputShape0 = leftOperand.Shape();
-    const NDShape& inputShape1 = rightOperand.Shape();
-
-    auto isBothNDInputs = [&]() -> bool {
-        if (inputShape0.Rank() != inputShape1.Rank() || inputShape0.Rank() <= 2) return false;
-        // In this case we don't require broadcast, thus prefix dimensions should match.
-        return inputShape0.SubShape(2) == inputShape1.SubShape(2);
-    };
+    NDShape inputShape0 = leftOperand.Shape();
+    NDShape inputShape1 = rightOperand.Shape();

     Variable leftOperandPlaceholder = PlaceholderVariable(inputShape0, L"leftOperand", {});
     Variable rightOperandPlaceholder = PlaceholderVariable(inputShape1, L"rightOperand", {});
+    Variable operand0 = leftOperandPlaceholder;
+    Variable operand1 = rightOperandPlaceholder;
     FunctionPtr cntkFunction;
+    size_t unitTailLengthToAppend = 0;
+
+    ///
+    /// Preprocessing
+    ///
+
+    const bool isBothNDInputs = [&]() -> bool {
+        if (inputShape0.Rank() != inputShape1.Rank() || inputShape0.Rank() <= 2) return false;
+        // In this case we don't require broadcast, thus prefix dimensions should match.
+        return inputShape0.SubShape(2) == inputShape1.SubShape(2);
+    }();
+    if (!isBothNDInputs)
+    {
+        // Check for case 3. If the total size of the longer operand's tail shape is 1,
+        // we can handle the broadcast by removing that tail and re-appending it to the result later.
+        const bool isSimpleBroadcast = [&]() -> bool {
+            // This situation needs broadcast.
+            // Check whether it is a simple broadcast, i.e. one where we only need
+            // to append 1s to the shorter shape.
+            // e.g.
+            //   input_0: [b, a, 2, 1, 1]
+            //   input_1: [c, b, 2]
+            //   ==> broadcast: remove tail shape [1, 1] from input_0
+            //   input_0: [b, a, 2]
+            //
+            if (inputShape0.Rank() < 2 || inputShape1.Rank() < 2) return false;
+            Variable* shortOperand = &operand1;
+            Variable* longOperand = &operand0;
+            if (inputShape0.Rank() < inputShape1.Rank())
+            {
+                std::swap(shortOperand, longOperand);
+            }
+            const NDShape& tailShape = longOperand->Shape().SubShape(shortOperand->Shape().Rank());
+            if (tailShape.TotalSize() != 1) return false;
+            return true;
+        }();
+        if (!isSimpleBroadcast)
+            LogicError("MatMul: Complex form of broadcasting is currently not supported in ONNX/CNTK.");
+
+        // Remove the tail shape of {1, ..., 1} from the longer operand.
+        if (inputShape0.Rank() < inputShape1.Rank())
+        {
+            unitTailLengthToAppend = inputShape1.Rank() - inputShape0.Rank();
+            inputShape1 = inputShape1.SubShape(0, inputShape0.Rank());
+            operand1 = Reshape(operand1, inputShape1)->Output();
+        }
+        else
+        {
+            unitTailLengthToAppend = inputShape0.Rank() - inputShape1.Rank();
+            inputShape0 = inputShape0.SubShape(0, inputShape1.Rank());
+            operand0 = Reshape(operand0, inputShape0)->Output();
+        }
+    }
+
+    // Check for case 2. We can drop the tail dimensions of both operands if they are all 1.
+    const size_t sharedUnitTailLength = [&]() -> size_t {
+        assert(inputShape0.Rank() == inputShape1.Rank());
+        size_t sharedUnitTailLength = 0;
+        for (size_t i = inputShape0.Rank() - 1; i >= 2; --i)
+        {
+            if (inputShape0[i] == 1 && inputShape1[i] == 1)
+                sharedUnitTailLength++;
+            else
+                break;
+        }
+        return sharedUnitTailLength;
+    }();
+    if (sharedUnitTailLength > 0)
+    {
+        // e.g.
+        //   input_0: [b, a, 2, 1, 1]
+        //   input_1: [c, b, 2, 1, 1]
+        //   ==> remove the common tail shape [1, 1] from both inputs
+        //   input_0: [b, a, 2]
+        //   input_1: [c, b, 2]
+        //
+        inputShape0 = inputShape0.SubShape(0, inputShape0.Rank() - sharedUnitTailLength);
+        inputShape1 = inputShape1.SubShape(0, inputShape1.Rank() - sharedUnitTailLength);
+        operand0 = Reshape(operand0, inputShape0)->Output();
+        operand1 = Reshape(operand1, inputShape1)->Output();
+        unitTailLengthToAppend += sharedUnitTailLength;
+    }
+
+    ///
+    /// After preprocessing, inputs are either both 2-D or both N-D.
+    ///
+
     if (inputShape0.Rank() == 2 && inputShape1.Rank() == 2)
     {
         // 1. Both 2-D inputs.
         // CNTK Times takes its inputs in the reverse order of ONNX (numpy) MatMul.
-        cntkFunction = Times(rightOperandPlaceholder, leftOperandPlaceholder);
+        cntkFunction = Times(operand1, operand0);
+        if (unitTailLengthToAppend > 0)
+        {
+            const NDShape& outputShape{ inputShape1[0], inputShape0[1] };
+            const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
+            cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
+        }
     }
-    else if (isBothNDInputs())
+    else
     {
         // 2. Both N-D inputs.
         // Convert inputs into CNTK style:

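An aside on the shape convention before the next hunk: CNTK lists tensor axes in column-major order (fastest-varying first), so an ONNX (numpy) matrix of shape [a, b] arrives with CNTK shape [b, a], which is why Times takes the operands reversed. A minimal sketch of the bookkeeping, using the 2-D shapes from the new test ({3, 4} x {5, 3} -> {5, 4}); all variable names are illustrative only:

    #include <array>
    #include <cassert>
    #include <cstddef>

    int main()
    {
        // ONNX MatMul: [a, b] x [b, c] -> [a, c] in row-major (numpy) order.
        // CNTK stores the same tensors with axes reversed.
        std::array<size_t, 2> onnxLeft = { 4, 3 };   // a = 4, b = 3
        std::array<size_t, 2> onnxRight = { 3, 5 };  // b = 3, c = 5
        std::array<size_t, 2> cntkLeft = { onnxLeft[1], onnxLeft[0] };    // [b, a] = { 3, 4 }
        std::array<size_t, 2> cntkRight = { onnxRight[1], onnxRight[0] }; // [c, b] = { 5, 3 }
        // CNTK Times([c, b], [b, a]) -> [c, a]; hence Times(operand1, operand0) above,
        // and outputShape{ inputShape1[0], inputShape0[1] } in the 2-D branch.
        std::array<size_t, 2> cntkOut = { cntkRight[0], cntkLeft[1] };    // [c, a] = { 5, 4 }
        assert(cntkOut[0] == 5 && cntkOut[1] == 4);
        return 0;
    }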
@@ -3590,9 +3679,9 @@ namespace CNTK
         {
             inputPrefixProd *= inputShape0[i];
         }
-        FunctionPtr input0CNTK = Reshape(leftOperandPlaceholder, { bDim, inputPrefixProd * aDim });
+        FunctionPtr input0CNTK = Reshape(operand0, { bDim, inputPrefixProd * aDim });
         // 2)
-        FunctionPtr input1CNTK = Reshape(rightOperandPlaceholder, { cDim, bDim, inputPrefixProd });
+        FunctionPtr input1CNTK = Reshape(operand1, { cDim, bDim, inputPrefixProd });
         // 3)
         input1CNTK = TransposeAxes(input1CNTK, Axis(1), Axis(2));
         // 4)

@@ -3612,11 +3701,12 @@ namespace CNTK
         // 10)
         const NDShape& outputShape = NDShape({ cDim, aDim }).AppendShape(inputShape0.SubShape(2));
         cntkFunction = Reshape(outputCNTK, outputShape);
-    }
-    else
-    {
-        // 3. broadcast.
-        LogicError("MatMul: broadcasting is currently not supported in ONNX/CNTK.");
+
+        if (unitTailLengthToAppend > 0)
+        {
+            const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
+            cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
+        }
     }

     return AsBlock(std::move(cntkFunction),
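The broadcast preprocessing above is easiest to see on concrete shapes. Below is a minimal standalone sketch of the simple-broadcast check and tail removal, with plain std::vector<size_t> standing in for NDShape; SimplifyBroadcast is a hypothetical helper, not part of the commit:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Shapes are plain dimension vectors here; CNTK's NDShape is the real type.
    using Shape = std::vector<size_t>;

    // Strip the all-1 tail that the longer shape carries beyond the shorter one,
    // and report how many unit dimensions must be re-appended to the result.
    // Returns false for the "complex" broadcast forms the commit rejects.
    bool SimplifyBroadcast(Shape& s0, Shape& s1, size_t& unitTailLengthToAppend)
    {
        if (s0.size() < 2 || s1.size() < 2) return false;
        Shape& shorter = (s0.size() <= s1.size()) ? s0 : s1;
        Shape& longer = (s0.size() <= s1.size()) ? s1 : s0;
        for (size_t i = shorter.size(); i < longer.size(); ++i)
            if (longer[i] != 1) return false; // tail total size != 1
        unitTailLengthToAppend = longer.size() - shorter.size();
        longer.resize(shorter.size());
        return true;
    }

    int main()
    {
        // Mirrors the commit's example: input_0 [b, a, 2, 1, 1], input_1 [c, b, 2].
        Shape s0 = { 7, 5, 2, 1, 1 }, s1 = { 3, 7, 2 };
        size_t tail = 0;
        bool simple = SimplifyBroadcast(s0, s1, tail);
        // Now s0 == { 7, 5, 2 } and tail == 2; after the matmul, the result shape
        // would get the unit tail [1, 1] appended back.
        std::cout << std::boolalpha << simple << " " << tail << "\n";
    }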
@@ -554,6 +554,9 @@ Test module "V2LibraryTests" has passed with:
 Test case "FunctionSuite/TestSettingRandomSeed" has passed with:
     2 assertions out of 2 passed

+Test case "FunctionSuite/MatMul" has passed with:
+    4 assertions out of 4 passed
+
 Test suite "NDArrayViewSuite" has passed with:
     8 test cases out of 8 passed
     66039 assertions out of 66039 passed
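For reference, the new test below drives the imported op directly. A condensed usage sketch of the same calls (assumes the CNTK V2 C++ API from CNTKLibrary.h; shapes taken from the test's simple-broadcast case, error handling omitted):

    #include "CNTKLibrary.h"
    #include <cstdio>

    int main()
    {
        // The unit tail [1] of input0 is dropped, a 2-D Times runs,
        // and the tail is re-appended to give output shape { 2, 4, 1 }.
        auto input0 = CNTK::InputVariable({ 64, 4, 1 }, CNTK::DataType::Float);
        auto input1 = CNTK::InputVariable({ 2, 64 }, CNTK::DataType::Float);
        auto matmul = CNTK::Internal::MatMul(input0, input1);
        auto outShape = matmul->Output().Shape(); // expected: { 2, 4, 1 }
        std::printf("output rank: %zu\n", outShape.Rank());
    }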
@@ -1096,6 +1096,110 @@ void SetRandomSeed(const DeviceDescriptor& device)
     FloatingPointVectorCompare(result2, result4, "SetRandomSeed: output does match the expected after resetting the dropout seed.");
 }

+void TestMatMul(const DeviceDescriptor& device)
+{
+    srand(1);
+    auto diff_size = [](const std::vector<size_t>& a, const std::vector<size_t>& b)
+    {
+        bool foundDifference = false;
+        if (a.size() != b.size()) return true;
+        for (size_t i = 0; !foundDifference && i < a.size() && i < b.size(); ++i)
+        {
+            foundDifference = (a[i] != b[i]);
+        }
+        return foundDifference;
+    };
+
+    std::vector<std::vector<size_t>> inputShapeVec0{ { 3, 4 }, { 3, 4, 2, 2 }, { 64, 4, 1 }, { 64, NDShape::InferredDimension, 1 } };
+    std::vector<std::vector<size_t>> inputShapeVec1{ { 5, 3 }, { 5, 3, 2, 2 }, { 2, 64 }, { 2, 64 } };
+    std::vector<std::vector<size_t>> outShapeVec{ { 5, 4 }, { 5, 4, 2, 2 }, { 2, 4, 1 }, { 2, NDShape::InferredDimension, 1 } };
+    std::vector<std::vector<size_t>> inputValueShapeVec0{ { 3, 4 }, { 3, 4, 2, 2 }, { 64, 4, 1 }, { 64, 4, 1 } };
+    std::vector<std::vector<size_t>> inputValueShapeVec1{ { 5, 3 }, { 5, 3, 2, 2 }, { 2, 64 }, { 2, 64 } };
+    std::vector<std::vector<size_t>> outValueShapeVec{ { 5, 4 }, { 5, 4, 2, 2 }, { 2, 4, 1 }, { 2, 4, 1 } };
+    std::vector<size_t> inputTotalSizeVec0 = { 12, 48, 256, 256 };
+    std::vector<size_t> inputTotalSizeVec1 = { 15, 60, 128, 128 };
+    std::vector<size_t> outputTotalSizeVec = { 20, 80, 8, 8 };
+    std::vector<size_t> inputSubSizeVec0 = { 12, 12, 256, 256 };
+    std::vector<size_t> inputSubSizeVec1 = { 15, 15, 128, 128 };
+    std::vector<size_t> outputSubSizeVec = { 20, 20, 8, 8 };
+
+    size_t testCases = inputShapeVec0.size();
+    for (size_t test_i = 0; test_i < testCases; ++test_i)
+    {
+        auto shape0 = NDShape(inputShapeVec0[test_i]);
+        auto shape1 = NDShape(inputShapeVec1[test_i]);
+        auto valueShape0 = NDShape(inputValueShapeVec0[test_i]);
+        auto valueShape1 = NDShape(inputValueShapeVec1[test_i]);
+        auto outShape = NDShape(outShapeVec[test_i]);
+        auto outValueShape = NDShape(outValueShapeVec[test_i]);
+
+        size_t inputTotalSize0 = inputTotalSizeVec0[test_i];
+        size_t inputTotalSize1 = inputTotalSizeVec1[test_i];
+        size_t outputTotalSize = outputTotalSizeVec[test_i];
+        size_t inputSubSize0 = inputSubSizeVec0[test_i];
+        size_t inputSubSize1 = inputSubSizeVec1[test_i];
+        size_t outputSubSize = outputSubSizeVec[test_i];
+
+        auto input0 = InputVariable(shape0, DataType::Float);
+        auto input1 = InputVariable(shape1, DataType::Float);
+        auto result = ::CNTK::Internal::MatMul(input0, input1);
+
+        std::vector<float> inputData0(inputTotalSize0);
+        std::vector<float> inputData1(inputTotalSize1);
+        for (size_t i = 0; i < inputData0.size(); ++i)
+            inputData0[i] = ((float)rand()) / RAND_MAX;
+        for (size_t i = 0; i < inputData1.size(); ++i)
+            inputData1[i] = ((float)rand()) / RAND_MAX;
+
+        ValuePtr inputValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.AppendShape({ 1, 1 }), inputData0, true));
+        ValuePtr inputValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.AppendShape({ 1, 1 }), inputData1, true));
+
+        NDShape outputShape = result->Output().Shape();
+        BOOST_TEST(!diff_size(outShape.Dimensions(), outputShape.Dimensions()));
+        std::vector<float> outputData(outputTotalSize);
+        ValuePtr outputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.AppendShape({ 1, 1 }), outputData, false));
+
+        std::unordered_map<Variable, ValuePtr> outputs = { { result->Output(), outputValue } };
+        result->Forward({ { input0, inputValue0 }, { input1, inputValue1 } }, outputs, device);
+
+        std::vector<float> expectedOutputValues(outputTotalSize);
+        {
+            for (size_t i = 0; i < outputTotalSize / outputSubSize; i++)
+            {
+                std::vector<float> inputTimesData0(inputSubSize0);
+                std::vector<float> inputTimesData1(inputSubSize1);
+                std::vector<float> outTimesData(outputSubSize);
+                auto inputTimes0 = InputVariable(shape0.SubShape(0, 2), DataType::Float);
+                auto inputTimes1 = InputVariable(shape1.SubShape(0, 2), DataType::Float);
+                auto timesResult = Times(inputTimes1, inputTimes0);
+
+                for (size_t j = 0; j < inputSubSize0; ++j)
+                {
+                    inputTimesData0[j] = inputData0[i * inputSubSize0 + j];
+                }
+                for (size_t j = 0; j < inputSubSize1; ++j)
+                {
+                    inputTimesData1[j] = inputData1[i * inputSubSize1 + j];
+                }
+
+                ValuePtr inputTimesValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.SubShape(0, 2).AppendShape({ 1, 1 }), inputTimesData0, true));
+                ValuePtr inputTimesValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.SubShape(0, 2).AppendShape({ 1, 1 }), inputTimesData1, true));
+                ValuePtr outputTimesValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.SubShape(0, 2).AppendShape({ 1, 1 }), outTimesData, false));
+                std::unordered_map<Variable, ValuePtr> timesOutputs = { { timesResult->Output(), outputTimesValue } };
+                timesResult->Forward({ { inputTimes0, inputTimesValue0 }, { inputTimes1, inputTimesValue1 } }, timesOutputs, device);
+
+                for (size_t j = 0; j < outputSubSize; ++j)
+                {
+                    expectedOutputValues[i * outputSubSize + j] = outTimesData[j];
+                }
+            }
+        }
+
+        FloatingPointVectorCompare(outputData, expectedOutputValues, "TestMatMul: Forward prop results do not match expected results.");
+    }
+}
+
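A note on why the flat indexing in the verification loop is valid: CNTK lays out dense tensors with the first (matrix) axes fastest-varying, so each 2-D slice of an N-D operand occupies one contiguous run of subSize elements, and slice i starts at offset i * subSize. A quick sanity check of the size tables for the second test case (plain C++; values copied from the vectors above):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        // Second case: input0 { 3, 4, 2, 2 }, input1 { 5, 3, 2, 2 }, output { 5, 4, 2, 2 }.
        const size_t inputSubSize0 = 3 * 4;  // 12 floats per [3, 4] slice
        const size_t inputSubSize1 = 5 * 3;  // 15 floats per [5, 3] slice
        const size_t outputSubSize = 5 * 4;  // 20 floats per [5, 4] slice
        const size_t outputTotalSize = outputSubSize * 2 * 2; // 80, matching outputTotalSizeVec
        // The loop therefore runs outputTotalSize / outputSubSize = 4 independent Times products.
        assert(outputTotalSize / outputSubSize == 4);
        assert(inputSubSize0 == 12 && inputSubSize1 == 15);
        return 0;
    }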
 BOOST_AUTO_TEST_SUITE(FunctionSuite)

 BOOST_AUTO_TEST_CASE(FindNameInCPU)
@@ -1222,6 +1326,13 @@ BOOST_AUTO_TEST_CASE(TestSettingRandomSeed)
         SetRandomSeed(DeviceDescriptor::GPUDevice(0));
 }

+BOOST_AUTO_TEST_CASE(MatMul)
+{
+    if (ShouldRunOnCpu())
+        TestMatMul(DeviceDescriptor::CPUDevice());
+    if (ShouldRunOnGpu())
+        TestMatMul(DeviceDescriptor::GPUDevice(0));
+}
+
 BOOST_AUTO_TEST_SUITE_END()