Partially support MatMul broadcast.

Bowen Bao 2018-07-29 18:06:01 -07:00
Parent 2964f1fde5
Commit 6ed965924c
3 changed files: 222 additions and 18 deletions

View file

@@ -3535,27 +3535,116 @@ namespace CNTK
// we implement this import step-by-step.
// 1. Both 2-D inputs.
// 2. Both N-D inputs.
// 3. Add broadcast. (Partially supported in ONNX/CNTK.)
//
// Note: we don't currently support free/inferred dimensions for N-D MatMul.
NDShape inputShape0 = leftOperand.Shape();
NDShape inputShape1 = rightOperand.Shape();
Variable leftOperandPlaceholder = PlaceholderVariable(inputShape0, L"leftOperand", {});
Variable rightOperandPlaceholder = PlaceholderVariable(inputShape1, L"rightOperand", {});
Variable operand0 = leftOperandPlaceholder;
Variable operand1 = rightOperandPlaceholder;
FunctionPtr cntkFunction;
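// Number of trailing unit axes trimmed during preprocessing; they are re-appended to the output shape at the end.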
size_t unitTailLengthToAppend = 0;
///
/// Preprocessing
///
const bool isBothNDInputs = [&]() -> bool {
if (inputShape0.Rank() != inputShape1.Rank() || inputShape0.Rank() <= 2) return false;
// In this case we don't require broadcast, thus prefix dimensions should match.
return inputShape0.SubShape(2) == inputShape1.SubShape(2);
}();
if (!isBothNDInputs)
{
// Check for case 3. If the total size of the longer operand's tail shape is 1,
// we can handle this broadcast by removing the tail shape and appending it to the result later.
const bool isSimpleBroadcast = [&]() -> bool {
// This case requires broadcasting.
// Check whether it is a simple broadcast, i.e. the longer shape only extends the shorter one by trailing 1s.
// e.g.
// input_0: [b, a, 2, 1, 1]
// input_1: [c, b, 2]
// ==> broadcast: remove tail shape [1, 1] from input_0
// input_0: [b, a, 2]
//
if (inputShape0.Rank() < 2 || inputShape1.Rank() < 2) return false;
Variable *shortOperand = &operand1;
Variable *longOperand = &operand0;
if (inputShape0.Rank() < inputShape1.Rank()) {
std::swap(shortOperand, longOperand);
}
const NDShape& tailShape = longOperand->Shape().SubShape(shortOperand->Shape().Rank());
if (tailShape.TotalSize() != 1) return false;
return true;
}();
if (!isSimpleBroadcast)
LogicError("MatMul: Complex form of broadcasting is currently not supported in ONNX/CNTK.");
// Remove tail shape of {1, ..., 1}
if (inputShape0.Rank() < inputShape1.Rank())
{
unitTailLengthToAppend = inputShape1.Rank() - inputShape0.Rank();
inputShape1 = inputShape1.SubShape(0, inputShape0.Rank());
operand1 = Reshape(operand1, inputShape1)->Output();
}
else
{
unitTailLengthToAppend = inputShape0.Rank() - inputShape1.Rank();
inputShape0 = inputShape0.SubShape(0, inputShape1.Rank());
operand0 = Reshape(operand0, inputShape0)->Output();
}
}
// Check for case 2. We can drop trailing axes that are 1 in both operands.
const size_t sharedUnitTailLength = [&]() -> size_t {
assert(inputShape0.Rank() == inputShape1.Rank());
size_t sharedUnitTailLength = 0;
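// Scan from the last axis down to axis 2, so the two matrix axes are never trimmed.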
for (size_t i = inputShape0.Rank() - 1; i >= 2; --i)
{
if (inputShape0[i] == 1 && inputShape1[i] == 1)
sharedUnitTailLength++;
else
break;
}
return sharedUnitTailLength;
}();
if (sharedUnitTailLength > 0)
{
// e.g.
// input_0: [b, a, 2, 1, 1]
// input_1: [c, b, 2, 1, 1]
// ==> remove common tail shape [1, 1] from both input
// input_0: [b, a, 2]
// input_1: [c, b, 2]
//
inputShape0 = inputShape0.SubShape(0, inputShape0.Rank() - sharedUnitTailLength);
inputShape1 = inputShape1.SubShape(0, inputShape1.Rank() - sharedUnitTailLength);
operand0 = Reshape(operand0, inputShape0)->Output();
operand1 = Reshape(operand1, inputShape1)->Output();
unitTailLengthToAppend += sharedUnitTailLength;
}
///
/// After preprocessing, inputs are either both 2-D or both N-D.
///
if (inputShape0.Rank() == 2 && inputShape1.Rank() == 2)
{
// 1. Both 2-D inputs.
// CNTK Times takes its inputs in the reverse order of ONNX (numpy) MatMul.
cntkFunction = Times(operand1, operand0);
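// Re-append the unit axes that were trimmed from the inputs during preprocessing.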
if (unitTailLengthToAppend > 0)
{
const NDShape& outputShape{ inputShape1[0], inputShape0[1] };
const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
}
}
else
{
// 2. Both N-D inputs.
// Convert inputs into CNTK style:
@@ -3590,9 +3679,9 @@ namespace CNTK
{
inputPrefixProd *= inputShape0[i];
}
FunctionPtr input0CNTK = Reshape(operand0, { bDim, inputPrefixProd * aDim });
// 2)
FunctionPtr input1CNTK = Reshape(operand1, { cDim, bDim, inputPrefixProd });
// 3)
input1CNTK = TransposeAxes(input1CNTK, Axis(1), Axis(2));
// 4)
@@ -3612,11 +3701,12 @@ namespace CNTK
// 10)
const NDShape& outputShape = NDShape({ cDim, aDim }).AppendShape(inputShape0.SubShape(2));
cntkFunction = Reshape(outputCNTK, outputShape);
if (unitTailLengthToAppend > 0)
{
const NDShape& tailShape(std::vector<size_t>(unitTailLengthToAppend, 1));
cntkFunction = Reshape(cntkFunction, outputShape.AppendShape(tailShape));
}
}
return AsBlock(std::move(cntkFunction),
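To make the shape bookkeeping easier to follow, here is a minimal standalone sketch of the preprocessing above in plain C++, using std::vector<size_t> in place of NDShape. The names TrimUnitTail and TrimSharedUnitTail are illustrative only and not part of this commit:

    // Illustrative sketch (not part of the commit): the tail-trimming shape
    // arithmetic from the import above. Axes 0 and 1 are the matrix
    // dimensions; axes >= 2 are batch dimensions.
    #include <cassert>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<size_t>;

    // Simple broadcast: if the longer shape only extends the shorter one by
    // trailing 1s, trim it down to the shorter rank and report how many axes
    // were removed so the caller can re-append them to the output shape.
    size_t TrimUnitTail(Shape& s0, Shape& s1)
    {
        Shape& shorter = s0.size() <= s1.size() ? s0 : s1;
        Shape& longer = s0.size() <= s1.size() ? s1 : s0;
        size_t tailLength = longer.size() - shorter.size();
        for (size_t i = shorter.size(); i < longer.size(); ++i)
            assert(longer[i] == 1); // only the simple form is handled
        longer.resize(shorter.size());
        return tailLength;
    }

    // Shared unit tail: drop trailing axes that are 1 in both shapes,
    // mirroring the sharedUnitTailLength step (the two matrix axes are
    // never touched).
    size_t TrimSharedUnitTail(Shape& s0, Shape& s1)
    {
        size_t trimmed = 0;
        while (s0.size() > 2 && s0.back() == 1 && s1.back() == 1)
        {
            s0.pop_back();
            s1.pop_back();
            ++trimmed;
        }
        return trimmed;
    }

    int main()
    {
        // Third test case below: [64, 4, 1] x [2, 64] -> [2, 4, 1].
        Shape a{64, 4, 1}, b{2, 64};
        size_t tail = TrimUnitTail(a, b);  // a becomes [64, 4], tail == 1
        tail += TrimSharedUnitTail(a, b);  // nothing left to trim here
        Shape out{b[0], a[1]};             // Times(b, a) yields [2, 4]
        out.insert(out.end(), tail, 1);    // re-append trimmed axes -> [2, 4, 1]
        for (size_t d : out) std::cout << d << ' ';
        std::cout << '\n';                 // prints: 2 4 1
    }

The re-appended unit axes are what let the simple broadcast case reuse the existing 2-D and N-D code paths unchanged.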

View file

@@ -554,6 +554,9 @@ Test module "V2LibraryTests" has passed with:
Test case "FunctionSuite/TestSettingRandomSeed" has passed with:
2 assertions out of 2 passed
Test case "FunctionSuite/MatMul" has passed with:
4 assertions out of 4 passed
Test suite "NDArrayViewSuite" has passed with:
8 test cases out of 8 passed
66039 assertions out of 66039 passed

View file

@@ -1096,6 +1096,110 @@ void SetRandomSeed(const DeviceDescriptor& device)
FloatingPointVectorCompare(result2, result4, "SetRandomSeed: output does not match the expected value after resetting the dropout seed.");
}
void TestMatMul(const DeviceDescriptor& device)
{
srand(1);
auto diff_size = [](const std::vector<size_t>& a, const std::vector<size_t>& b)
{
if (a.size() != b.size()) return true;
bool foundDifference = false;
for (size_t i = 0; !foundDifference && i < a.size(); ++i)
{
foundDifference = (a[i] != b[i]);
}
return foundDifference;
};
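// Four shape configurations: plain 2-D, batched 4-D, simple broadcast with a trailing unit axis, and simple broadcast with an inferred dimension.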
std::vector<std::vector<size_t>> inputShapeVec0{{3, 4}, {3, 4, 2, 2}, {64, 4, 1}, {64, NDShape::InferredDimension, 1}};
std::vector<std::vector<size_t>> inputShapeVec1{{5, 3}, {5, 3, 2, 2}, {2, 64}, {2, 64}};
std::vector<std::vector<size_t>> outShapeVec{{5, 4}, {5, 4, 2, 2}, {2, 4, 1}, {2, NDShape::InferredDimension, 1}};
std::vector<std::vector<size_t>> inputValueShapeVec0{ { 3, 4 },{ 3, 4, 2, 2 },{ 64, 4, 1 },{ 64, 4, 1 } };
std::vector<std::vector<size_t>> inputValueShapeVec1{ { 5, 3 },{ 5, 3, 2, 2 },{ 2, 64 },{ 2, 64 } };
std::vector<std::vector<size_t>> outValueShapeVec{ { 5, 4 },{ 5, 4, 2, 2 },{ 2, 4, 1 },{ 2, 4, 1 } };
std::vector<size_t> inputTotalSizeVec0 = { 12, 48, 256, 256 };
std::vector<size_t> inputTotalSizeVec1 = { 15, 60, 128, 128 };
std::vector<size_t> outputTotalSizeVec = { 20, 80, 8, 8 };
std::vector<size_t> inputSubSizeVec0 = { 12, 12, 256, 256 };
std::vector<size_t> inputSubSizeVec1 = { 15, 15, 128, 128 };
std::vector<size_t> outputSubSizeVec = { 20, 20, 8, 8 };
size_t testCases = inputShapeVec0.size();
for (size_t test_i = 0; test_i < testCases; ++test_i)
{
auto shape0 = NDShape(inputShapeVec0[test_i]);
auto shape1 = NDShape(inputShapeVec1[test_i]);
auto valueShape0 = NDShape(inputValueShapeVec0[test_i]);
auto valueShape1 = NDShape(inputValueShapeVec1[test_i]);
auto outShape = NDShape(outShapeVec[test_i]);
auto outValueShape = NDShape(outValueShapeVec[test_i]);
size_t inputTotalSize0 = inputTotalSizeVec0[test_i];
size_t inputTotalSize1 = inputTotalSizeVec1[test_i];
size_t outputTotalSize = outputTotalSizeVec[test_i];
size_t inputSubSize0 = inputSubSizeVec0[test_i];
size_t inputSubSize1 = inputSubSizeVec1[test_i];
size_t outputSubSize = outputSubSizeVec[test_i];
auto input0 = InputVariable(shape0, DataType::Float);
auto input1 = InputVariable(shape1, DataType::Float);
auto result = ::CNTK::Internal::MatMul(input0, input1);
std::vector<float> inputData0(inputTotalSize0);
std::vector<float> inputData1(inputTotalSize1);
for (size_t i = 0; i < inputData0.size(); ++i)
inputData0[i] = ((float)rand()) / RAND_MAX;
for (size_t i = 0; i < inputData1.size(); ++i)
inputData1[i] = ((float)rand()) / RAND_MAX;
ValuePtr inputValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.AppendShape({1,1}), inputData0, true));
ValuePtr inputValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.AppendShape({ 1,1 }), inputData1, true));
NDShape outputShape = result->Output().Shape();
BOOST_TEST(!diff_size(outShape.Dimensions(), outputShape.Dimensions()));
std::vector<float> outputData(outputTotalSize);
ValuePtr outputValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.AppendShape({ 1,1 }), outputData, false));
std::unordered_map<Variable, ValuePtr> outputs = {{result->Output(), outputValue}};
result->Forward({{input0, inputValue0}, {input1, inputValue1}}, outputs, device);
std::vector<float> expectedOutputValues(outputTotalSize);
{
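// Compute expected values by running a plain 2-D Times on each matrix slice of the inputs and stitching the per-slice outputs together.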
for (size_t i = 0; i < outputTotalSize / outputSubSize; i++)
{
std::vector<float> inputTimesData0(inputSubSize0);
std::vector<float> inputTimesData1(inputSubSize1);
std::vector<float> outTimesData(outputSubSize);
auto inputTimes0 = InputVariable(shape0.SubShape(0, 2), DataType::Float);
auto inputTimes1 = InputVariable(shape1.SubShape(0, 2), DataType::Float);
auto timesResult = Times(inputTimes1, inputTimes0);
for (size_t j = 0; j < inputSubSize0; ++j)
{
inputTimesData0[j] = inputData0[i * inputSubSize0 + j];
}
for (size_t j = 0; j < inputSubSize1; ++j)
{
inputTimesData1[j] = inputData1[i * inputSubSize1 + j];
}
ValuePtr inputTimesValue0 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape0.SubShape(0, 2).AppendShape({ 1,1 }), inputTimesData0, true));
ValuePtr inputTimesValue1 = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(valueShape1.SubShape(0, 2).AppendShape({ 1,1 }), inputTimesData1, true));
ValuePtr outputTimesValue = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(outValueShape.SubShape(0, 2).AppendShape({ 1,1 }), outTimesData, false));
std::unordered_map<Variable, ValuePtr> timesOutputs = { {timesResult->Output(), outputTimesValue}};
timesResult->Forward({{inputTimes0, inputTimesValue0 }, {inputTimes1, inputTimesValue1 }}, timesOutputs, device);
for (size_t j = 0; j < outputSubSize; ++j)
{
expectedOutputValues[i * outputSubSize + j] = outTimesData[j];
}
}
}
FloatingPointVectorCompare(outputData, expectedOutputValues, "TestMatMul: Forward prop results do not match expected results.");
}
}
BOOST_AUTO_TEST_SUITE(FunctionSuite)
BOOST_AUTO_TEST_CASE(FindNameInCPU)
@@ -1222,6 +1326,13 @@ BOOST_AUTO_TEST_CASE(TestSettingRandomSeed)
SetRandomSeed(DeviceDescriptor::GPUDevice(0));
}
BOOST_AUTO_TEST_CASE(MatMul)
{
if (ShouldRunOnCpu())
TestMatMul(DeviceDescriptor::CPUDevice());
if (ShouldRunOnGpu())
TestMatMul(DeviceDescriptor::GPUDevice(0));
}
BOOST_AUTO_TEST_SUITE_END()