Updated code to work with the requirement of each matrix owning its storage in order to resize.

This commit is contained in:
thhoens 2016-04-04 18:57:55 -07:00
Родитель e432c455ae
Коммит d02109ac1b
1 изменённых файлов: 15 добавлений и 2 удалений

Просмотреть файл

@ -14,6 +14,7 @@
#include "cudalatticeops.h"
#include <numeric> // for debug
#include "cudalib.h"
#include <memory>
#define TWO_CHANNEL // [v-hansu]
using namespace msra::cuda;
@ -623,16 +624,28 @@ struct parallelstateimpl
if (errorsignalgpustorage->GetNumRows() != 0 && errorsignalgpustorage->GetNumRows() != errorsignal.rows())
throw ::logic_error("gpumatrixstorage->rows() shall be fixed once allocated");
if (errorsignalgpustorage->GetNumCols() < errorsignal.cols())
{
// Note: This is required because otherwise errorsignalgpustorage will be a view of the storage object in
// errorsignalgpustorage, and thuse it can't resize. This is perhaps not the optimal way to do this, but
// how else? Why do these two matrices exist? Why not just one?
errorsignalgpu = nullptr;
errorsignalgpustorage->Resize(errorsignal.rows(), errorsignal.cols());
*errorsignalgpu = errorsignalgpustorage->ColumnSlice(0, errorsignal.cols());
}
//*errorsignalgpu = errorsignalgpustorage->ColumnSlice(0, errorsignal.cols());
errorsignalgpu = make_unique<Microsoft::MSR::CNTK::Matrix<float>>(errorsignalgpustorage->ColumnSlice(0, errorsignal.cols()));
if (cacheerrsignalneg)
{
if (errorsignalneggpustorage->GetNumRows() != 0 && errorsignalneggpustorage->GetNumRows() != errorsignal.rows())
throw ::logic_error("gpumatrixstorage->rows() shall be fixed once allocated");
if (errorsignalneggpustorage->GetNumCols() < errorsignal.cols())
{
// Same as above.
errorsignalneggpu = nullptr;
errorsignalneggpustorage->Resize(errorsignal.rows(), errorsignal.cols());
*errorsignalneggpu = errorsignalneggpustorage->ColumnSlice(0, errorsignal.cols());
}
//*errorsignalneggpu = errorsignalneggpustorage->ColumnSlice(0, errorsignal.cols());
errorsignalneggpu = make_unique<Microsoft::MSR::CNTK::Matrix<float>>(errorsignalneggpustorage->ColumnSlice(0, errorsignal.cols()));
}
}