CNTK v2 library: Clone constants to right device when user supplied value is not on the target compute device

This commit is contained in:
Amit Agarwal 2016-08-22 11:54:21 -07:00
Родитель 37b6897e94
Коммит 8493f118da
1 изменённых файлов: 9 добавлений и 2 удалений

Просмотреть файл

@ -137,8 +137,15 @@ namespace CNTK
computationNodePtr->SetLearningRateMultiplier(0.0);
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
auto matrix = variable.IsConstant() ? value->GetMatrix<ElementType>()->AsReference() : value->GetWritableMatrix<ElementType>()->AsReference();
computationNodePtr->Value() = std::move(matrix);
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
computationNodePtr->Value() = valueMatrix->AsReference();
else
{
Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
clonedMatrix.AssignValuesOf(*valueMatrix);
computationNodePtr->Value() = std::move(clonedMatrix);
}
}
else if (variable.IsInput())
{