CNTK v2 library: Clone constants to right device when user supplied value is not on the target compute device
This commit is contained in:
Родитель
37b6897e94
Коммит
8493f118da
|
@ -137,8 +137,15 @@ namespace CNTK
|
|||
computationNodePtr->SetLearningRateMultiplier(0.0);
|
||||
|
||||
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
|
||||
auto matrix = variable.IsConstant() ? value->GetMatrix<ElementType>()->AsReference() : value->GetWritableMatrix<ElementType>()->AsReference();
|
||||
computationNodePtr->Value() = std::move(matrix);
|
||||
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
|
||||
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
|
||||
computationNodePtr->Value() = valueMatrix->AsReference();
|
||||
else
|
||||
{
|
||||
Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
|
||||
clonedMatrix.AssignValuesOf(*valueMatrix);
|
||||
computationNodePtr->Value() = std::move(clonedMatrix);
|
||||
}
|
||||
}
|
||||
else if (variable.IsInput())
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче