Partially support matmul broadcast.
Parent: 2964f1fde5
Commit: 6ed965924c
@@ -3535,27 +3535,116 @@ namespace CNTK
     // we implement this import step-by-step.
     // 1. Both 2-D inputs.
     // 2. Both N-D inputs.
-    // 3. Add broadcast. (Not supported in ONNX/CNTK yet)
+    // 3. Add broadcast. (Partially supported in ONNX/CNTK.)
+    //
+    // Note: we don't support free/inferred dimensions for N-D MatMul currently.
+
-    const NDShape& inputShape0 = leftOperand.Shape();
-    const NDShape& inputShape1 = rightOperand.Shape();
+    NDShape inputShape0 = leftOperand.Shape();
+    NDShape inputShape1 = rightOperand.Shape();

-    auto isBothNDInputs = [&]() -> bool {
-        if (inputShape0.Rank() != inputShape1.Rank() || inputShape0.Rank() <= 2) return false;
-        // In this case we don't require broadcast, thus the prefix dimensions should match.
-        return inputShape0.SubShape(2) == inputShape1.SubShape(2);
-    };

     Variable leftOperandPlaceholder = PlaceholderVariable(inputShape0, L"leftOperand", {});
     Variable rightOperandPlaceholder = PlaceholderVariable(inputShape1, L"rightOperand", {});
+    Variable operand0 = leftOperandPlaceholder;
+    Variable operand1 = rightOperandPlaceholder;
     FunctionPtr cntkFunction;
+    size_t unitTailLengthToAppend = 0;
+
+    ///
+    /// Preprocessing
+    ///
+
+    const bool isBothNDInputs = [&]() -> bool {
+        if (inputShape0.Rank() != inputShape1.Rank() || inputShape0.Rank() <= 2) return false;
+        // In this case we don't require broadcast, thus the prefix dimensions should match.
+        return inputShape0.SubShape(2) == inputShape1.SubShape(2);
+    }();
+    if (!isBothNDInputs)
+    {
+        // Check for case 3. If the total size of the tail shape of the longer operand is 1,
+        // we can handle this broadcast case by removing the tail shape and appending it to the result later.
+        const bool isSimpleBroadcast = [&]() -> bool {
+            // Some situations need broadcast.
+            // Check whether this is a simple broadcast: we only need to append 1s to the shorter shape.
+            // e.g.
+            //   input_0: [b, a, 2, 1, 1]
+            //   input_1: [c, b, 2]
+            //   ==> broadcast: remove tail shape [1, 1] from input_0
+            //   input_0: [b, a, 2]
+            //
+            if (inputShape0.Rank() < 2 || inputShape1.Rank() < 2) return false;
+            Variable *shortOperand = &operand1;
+            Variable *longOperand = &operand0;
+            if (inputShape0.Rank() < inputShape1.Rank()) {
+                std::swap(shortOperand, longOperand);
+            }
+            const NDShape& tailShape = longOperand->Shape().SubShape(shortOperand->Shape().Rank());
+            if (tailShape.TotalSize() != 1) return false;
+            return true;
+        }();
+        if (!isSimpleBroadcast)
+            LogicError("MatMul: Complex forms of broadcasting are currently not supported in ONNX/CNTK.");
+
+        // Remove tail shape of {1, ..., 1}.
+        if (inputShape0.Rank() < inputShape1.Rank())
+        {
+            unitTailLengthToAppend = inputShape1.Rank() - inputShape0.Rank();
+            inputShape1 = inputShape1.SubShape(0, inputShape0.Rank());
+            operand1 = Reshape(operand1, inputShape1)->Output();
+        }
+        else
+        {
+            unitTailLengthToAppend = inputShape0.Rank() - inputShape1.Rank();
+            inputShape0 = inputShape0.SubShape(0, inputShape1.Rank());
+            operand0 = Reshape(operand0, inputShape0)->Output();
+        }
+    }
+
+    // Check for case 2. We can drop the tail dimensions of both operands if they are both 1.
+    const size_t sharedUnitTailLength = [&]() -> size_t {
+        assert(inputShape0.Rank() == inputShape1.Rank());
+        size_t sharedUnitTailLength = 0;
+        for (size_t i = inputShape0.Rank() - 1; i >= 2; --i)
+        {
+            if (inputShape0[i] == 1 && inputShape1[i] == 1)
+                sharedUnitTailLength++;
+            else
+                break;
+        }
+        return sharedUnitTailLength;
+    }();
+    if (sharedUnitTailLength > 0)
+    {
+        // e.g.
+        //   input_0: [b, a, 2, 1, 1]
+        //   input_1: [c, b, 2, 1, 1]
+        //   ==> remove common tail shape [1, 1] from both inputs
+        //   input_0: [b, a, 2]
+        //   input_1: [c, b, 2]
+        //
+        inputShape0 = inputShape0.SubShape(0, inputShape0.Rank() - sharedUnitTailLength);
+        inputShape1 = inputShape1.SubShape(0, inputShape1.Rank() - sharedUnitTailLength);
+        operand0 = Reshape(operand0, inputShape0)->Output();
+        operand1 = Reshape(operand1, inputShape1)->Output();
+        unitTailLengthToAppend += sharedUnitTailLength;
+    }
+
+    ///
+    /// After preprocessing, inputs are either both 2-D or both N-D.
+    ///
+
     if (inputShape0.Rank() == 2 && inputShape1.Rank() == 2)
     {
         // 1. Both 2-D inputs.
         // CNTK Times has the reverse input order compared to ONNX (numpy) MatMul.
-        cntkFunction = Times(rightOperandPlaceholder, leftOperandPlaceholder);
+        cntkFunction = Times(operand1, operand0);
+        if (unitTailLengthToAppend > 0)
+        {
+            const NDShape& outputShape{ inputShape1[0], inputShape0[1] };
+            const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
+            cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
+        }
     }
-    else if (isBothNDInputs())
+    else
     {
         // 2. Both N-D inputs.
         // Convert inputs into CNTK style:
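Note: the preprocessing above reduces a broadcasted MatMul to the non-broadcast cases. When the rank mismatch consists solely of trailing unit dimensions on the longer operand, those dimensions are trimmed before the product and re-appended to the output afterwards (the Reshape with tailShape above). Below is a minimal, CNTK-free sketch of that shape bookkeeping, using plain std::vector<size_t> in place of NDShape; SimpleBroadcastPlan and PlanSimpleBroadcast are illustrative names, not part of the commit.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct SimpleBroadcastPlan
    {
        std::vector<size_t> shape0, shape1; // trimmed shapes fed to the core MatMul
        size_t unitTailLengthToAppend = 0;  // number of 1s to re-append to the output shape
    };

    // Mirrors the isSimpleBroadcast check: the rank mismatch must consist solely of
    // trailing dimensions of size 1 on the longer operand, e.g. [b, a, 2, 1, 1] vs [c, b, 2].
    bool PlanSimpleBroadcast(std::vector<size_t> s0, std::vector<size_t> s1, SimpleBroadcastPlan& plan)
    {
        if (s0.size() < 2 || s1.size() < 2)
            return false;
        const std::vector<size_t>& longer = (s0.size() >= s1.size()) ? s0 : s1;
        const size_t shortRank = std::min(s0.size(), s1.size());
        for (size_t i = shortRank; i < longer.size(); ++i)
            if (longer[i] != 1) // tail total size must be 1, i.e. every trailing dim is 1
                return false;
        plan.unitTailLengthToAppend = longer.size() - shortRank;
        s0.resize(shortRank); // trimming is a no-op for the already-shorter operand
        s1.resize(shortRank);
        plan.shape0 = std::move(s0);
        plan.shape1 = std::move(s1);
        return true;
    }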
@@ -3590,9 +3679,9 @@ namespace CNTK
         {
             inputPrefixProd *= inputShape0[i];
         }
-        FunctionPtr input0CNTK = Reshape(leftOperandPlaceholder, { bDim, inputPrefixProd * aDim });
+        FunctionPtr input0CNTK = Reshape(operand0, { bDim, inputPrefixProd * aDim });
         // 2)
-        FunctionPtr input1CNTK = Reshape(rightOperandPlaceholder, { cDim, bDim, inputPrefixProd });
+        FunctionPtr input1CNTK = Reshape(operand1, { cDim, bDim, inputPrefixProd });
         // 3)
         input1CNTK = TransposeAxes(input1CNTK, Axis(1), Axis(2));
         // 4)
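Note: only steps 1)-4) and 10) of the N-D conversion are visible in this excerpt; steps 5)-9) are elided from the diff. Assuming column-major CNTK shapes operand0 = [bDim, aDim, d...] and operand1 = [cDim, bDim, d...] (consistent with the 2-D case above), the visible reshapes evolve the shapes as in this illustrative trace:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Example: operand0 = [3, 4, 2, 2], operand1 = [5, 3, 2, 2] (CNTK column-major).
        const size_t aDim = 4, bDim = 3, cDim = 5;
        const std::vector<size_t> batchDims = { 2, 2 }; // inputShape0.SubShape(2)

        size_t inputPrefixProd = 1;
        for (size_t d : batchDims)
            inputPrefixProd *= d; // 4 in this example

        // 1) Reshape(operand0, { bDim, inputPrefixProd * aDim })
        std::printf("input0CNTK: [%zu, %zu]\n", bDim, inputPrefixProd * aDim);     // [3, 16]
        // 2) Reshape(operand1, { cDim, bDim, inputPrefixProd })
        std::printf("input1CNTK: [%zu, %zu, %zu]\n", cDim, bDim, inputPrefixProd); // [5, 3, 4]
        // 3) TransposeAxes(input1CNTK, Axis(1), Axis(2)) swaps the two trailing axes
        std::printf("transposed: [%zu, %zu, %zu]\n", cDim, inputPrefixProd, bDim); // [5, 4, 3]
        // 10) the final Reshape restores the batch dimensions on the product
        std::printf("output:     [%zu, %zu, 2, 2]\n", cDim, aDim);                 // [5, 4, 2, 2]
        return 0;
    }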
@@ -3612,11 +3701,12 @@ namespace CNTK
         // 10)
         const NDShape& outputShape = NDShape({ cDim, aDim }).AppendShape(inputShape0.SubShape(2));
         cntkFunction = Reshape(outputCNTK, outputShape);
-    }
-    else
-    {
-        // 3. broadcast.
-        LogicError("MatMul: broadcasting is currently not supported in ONNX/CNTK.");
+        if (unitTailLengthToAppend > 0)
+        {
+            const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
+            cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
+        }
     }

     return AsBlock(std::move(cntkFunction),
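Note on the argument order in Times(operand1, operand0): ONNX/numpy list matrix axes row-major while CNTK lists them in reverse, so MatMul(P(m, k), Q(k, n)) = Z(m, n) maps to Times(Q_cntk[n, k], P_cntk[k, m]) = Z_cntk[n, m]. A minimal shape-level check (plain C++, no CNTK; Shape, ReverseOnnx, and TimesShape are illustrative helpers), using the same shapes as the first test case below ({3, 4} x {5, 3} -> {5, 4}):

    #include <cassert>
    #include <cstddef>

    struct Shape { size_t d0, d1; };                  // CNTK-style [d0, d1]

    Shape ReverseOnnx(size_t rows, size_t cols) { return { cols, rows }; }

    Shape TimesShape(Shape left, Shape right)
    {
        assert(left.d1 == right.d0);                  // contracted dimensions must agree
        return { left.d0, right.d1 };
    }

    int main()
    {
        const size_t m = 4, k = 3, n = 5;
        Shape p = ReverseOnnx(m, k);                  // ONNX P(4, 3) -> CNTK [3, 4]
        Shape q = ReverseOnnx(k, n);                  // ONNX Q(3, 5) -> CNTK [5, 3]
        Shape z = TimesShape(q, p);                   // [5, 4], i.e. ONNX Z(4, 5) reversed
        assert(z.d0 == n && z.d1 == m);
        return 0;
    }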
@@ -554,6 +554,9 @@ Test module "V2LibraryTests" has passed with:
 Test case "FunctionSuite/TestSettingRandomSeed" has passed with:
     2 assertions out of 2 passed

+Test case "FunctionSuite/MatMul" has passed with:
+    4 assertions out of 4 passed
+
 Test suite "NDArrayViewSuite" has passed with:
     8 test cases out of 8 passed
     66039 assertions out of 66039 passed
@@ -1096,6 +1096,110 @@ void SetRandomSeed(const DeviceDescriptor& device)
     FloatingPointVectorCompare(result2, result4, "SetRandomSeed: output does match the expected after resetting the dropout seed.");
 }

+void TestMatMul(const DeviceDescriptor& device)
+{
+    srand(1);
+    auto diff_size = [](const std::vector<size_t>& a, const std::vector<size_t>& b)
+    {
+        bool foundDifference = false;
+        if (a.size() != b.size()) return true;
+        for (size_t i = 0; !foundDifference && i < a.size() && i < b.size(); ++i)
+        {
+            foundDifference = (a[i] != b[i]);
+        }
+        return foundDifference;
+    };
+
+    std::vector<std::vector<size_t>> inputShapeVec0{ { 3, 4 }, { 3, 4, 2, 2 }, { 64, 4, 1 }, { 64, NDShape::InferredDimension, 1 } };
+    std::vector<std::vector<size_t>> inputShapeVec1{ { 5, 3 }, { 5, 3, 2, 2 }, { 2, 64 }, { 2, 64 } };
+    std::vector<std::vector<size_t>> outShapeVec{ { 5, 4 }, { 5, 4, 2, 2 }, { 2, 4, 1 }, { 2, NDShape::InferredDimension, 1 } };
+    std::vector<std::vector<size_t>> inputValueShapeVec0{ { 3, 4 }, { 3, 4, 2, 2 }, { 64, 4, 1 }, { 64, 4, 1 } };
+    std::vector<std::vector<size_t>> inputValueShapeVec1{ { 5, 3 }, { 5, 3, 2, 2 }, { 2, 64 }, { 2, 64 } };
+    std::vector<std::vector<size_t>> outValueShapeVec{ { 5, 4 }, { 5, 4, 2, 2 }, { 2, 4, 1 }, { 2, 4, 1 } };
+    std::vector<size_t> inputTotalSizeVec0 = { 12, 48, 256, 256 };
+    std::vector<size_t> inputTotalSizeVec1 = { 15, 60, 128, 128 };
+    std::vector<size_t> outputTotalSizeVec = { 20, 80, 8, 8 };
+    std::vector<size_t> inputSubSizeVec0 = { 12, 12, 256, 256 };
+    std::vector<size_t> inputSubSizeVec1 = { 15, 15, 128, 128 };
+    std::vector<size_t> outputSubSizeVec = { 20, 20, 8, 8 };
+
+    size_t testCases = inputShapeVec0.size();
+    for (size_t test_i = 0; test_i < testCases; ++test_i)
+    {
+        auto shape0 = NDShape(inputShapeVec0[test_i]);
+        auto shape1 = NDShape(inputShapeVec1[test_i]);
+        auto valueShape0 = NDShape(inputValueShapeVec0[test_i]);
+        auto valueShape1 = NDShape(inputValueShapeVec1[test_i]);
+        auto outShape = NDShape(outShapeVec[test_i]);
+        auto outValueShape = NDShape(outValueShapeVec[test_i]);
+
+        size_t inputTotalSize0 = inputTotalSizeVec0[test_i];
+        size_t inputTotalSize1 = inputTotalSizeVec1[test_i];
+        size_t outputTotalSize = outputTotalSizeVec[test_i];
+        size_t inputSubSize0 = inputSubSizeVec0[test_i];
+        size_t inputSubSize1 = inputSubSizeVec1[test_i];
+        size_t outputSubSize = outputSubSizeVec[test_i];
+
+        auto input0 = InputVariable(shape0, DataType::Float);
+        auto input1 = InputVariable(shape1, DataType::Float);
+        auto result = ::CNTK::Internal::MatMul(input0, input1);
+
+        std::vector<float> inputData0(inputTotalSize0);
+        std::vector<float> inputData1(inputTotalSize1);
+        for (size_t i = 0; i < inputData0.size(); ++i)
+            inputData0[i] = ((float)rand()) / RAND_MAX;
+        for (size_t i = 0; i < inputData1.size(); ++i)
+            inputData1[i] = ((float)rand()) / RAND_MAX;
+
+        ValuePtr inputValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.AppendShape({ 1, 1 }), inputData0, true));
+        ValuePtr inputValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.AppendShape({ 1, 1 }), inputData1, true));
+
+        NDShape outputShape = result->Output().Shape();
+        BOOST_TEST(!diff_size(outShape.Dimensions(), outputShape.Dimensions()));
+        std::vector<float> outputData(outputTotalSize);
+        ValuePtr outputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.AppendShape({ 1, 1 }), outputData, false));
+
+        std::unordered_map<Variable, ValuePtr> outputs = { { result->Output(), outputValue } };
+        result->Forward({ { input0, inputValue0 }, { input1, inputValue1 } }, outputs, device);
+
+        std::vector<float> expectedOutputValues(outputTotalSize);
+        {
+            for (size_t i = 0; i < outputTotalSize / outputSubSize; i++)
+            {
+                std::vector<float> inputTimesData0(inputSubSize0);
+                std::vector<float> inputTimesData1(inputSubSize1);
+                std::vector<float> outTimesData(outputSubSize);
+                auto inputTimes0 = InputVariable(shape0.SubShape(0, 2), DataType::Float);
+                auto inputTimes1 = InputVariable(shape1.SubShape(0, 2), DataType::Float);
+                auto timesResult = Times(inputTimes1, inputTimes0);
+
+                for (size_t j = 0; j < inputSubSize0; ++j)
+                {
+                    inputTimesData0[j] = inputData0[i * inputSubSize0 + j];
+                }
+                for (size_t j = 0; j < inputSubSize1; ++j)
+                {
+                    inputTimesData1[j] = inputData1[i * inputSubSize1 + j];
+                }
+
+                ValuePtr inputTimesValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.SubShape(0, 2).AppendShape({ 1, 1 }), inputTimesData0, true));
+                ValuePtr inputTimesValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.SubShape(0, 2).AppendShape({ 1, 1 }), inputTimesData1, true));
+                ValuePtr outputTimesValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.SubShape(0, 2).AppendShape({ 1, 1 }), outTimesData, false));
+                std::unordered_map<Variable, ValuePtr> timesOutputs = { { timesResult->Output(), outputTimesValue } };
+                timesResult->Forward({ { inputTimes0, inputTimesValue0 }, { inputTimes1, inputTimesValue1 } }, timesOutputs, device);
+
+                for (size_t j = 0; j < outputSubSize; ++j)
+                {
+                    expectedOutputValues[i * outputSubSize + j] = outTimesData[j];
+                }
+            }
+        }
+
+        FloatingPointVectorCompare(outputData, expectedOutputValues, "TestMatMul: Forward prop results do not match expected results.");
+    }
+}
+
 BOOST_AUTO_TEST_SUITE(FunctionSuite)

 BOOST_AUTO_TEST_CASE(FindNameInCPU)
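Note: the test above validates batched MatMul by slicing each batch element out of the flat input buffers, running a plain 2-D Times on the slice, and comparing against the batched output. A CNTK-free sketch of the same reference computation, assuming the per-slice column-major layout implied by inputSubSize0/1 and outputSubSize (the function name and layout convention are illustrative assumptions):

    #include <cstddef>
    #include <vector>

    // in0: batch slices of [b, a], in1: batch slices of [c, b], out: batch slices of [c, a],
    // each slice stored contiguously; element (i, j) of a column-major [rows, cols]
    // matrix lives at j * rows + i.
    std::vector<float> BatchedMatMulReference(const std::vector<float>& in0,
                                              const std::vector<float>& in1,
                                              size_t aDim, size_t bDim, size_t cDim,
                                              size_t batch)
    {
        const size_t sub0 = bDim * aDim, sub1 = cDim * bDim, subOut = cDim * aDim;
        std::vector<float> out(batch * subOut, 0.0f);
        for (size_t t = 0; t < batch; ++t)                // one 2-D Times per batch element
            for (size_t i = 0; i < cDim; ++i)             // row of the [c, a] result
                for (size_t j = 0; j < aDim; ++j)         // column of the [c, a] result
                    for (size_t k = 0; k < bDim; ++k)     // contracted dimension
                        out[t * subOut + j * cDim + i] +=
                            in1[t * sub1 + k * cDim + i] * in0[t * sub0 + j * bDim + k];
        return out;
    }

For the second test case (sub-sizes 12/15/20, batch 4) this mirrors the expectedOutputValues loop above, just without going through CNTK's Times.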
@@ -1222,6 +1326,13 @@ BOOST_AUTO_TEST_CASE(TestSettingRandomSeed)
         SetRandomSeed(DeviceDescriptor::GPUDevice(0));
 }

+BOOST_AUTO_TEST_CASE(MatMul)
+{
+    if (ShouldRunOnCpu())
+        TestMatMul(DeviceDescriptor::CPUDevice());
+    if (ShouldRunOnGpu())
+        TestMatMul(DeviceDescriptor::GPUDevice(0));
+}
+
 BOOST_AUTO_TEST_SUITE_END()