CNTK/Source/Math/cudalattice.cpp

266 lines
16 KiB
C++

// cudalattice.cpp -- lattice forward/backward functions for CUDA execution (glue code)
//
// F. Seide, V-hansu
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#define DLLEXPORT
#define __kernel_emulation__ // allow the compilation of CUDA kernels on the CPU
#include "latticefunctionskernels.h" // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h
#undef __kernel_emulation__
#include "cudalattice.h" // this exports the class
#include "cudalatticeops.h" // brings in the actual lattice functions/kernels
#include "cudalib.h" // generic CUDA helpers
#include "cudadevice.h"
#include <math.h>
#include <memory> // for auto_ptr
#include <assert.h>
#include <float.h>
namespace msra { namespace cuda {
// Declaration only: an operator|| overload taking a CUDA status code and a message string.
// Its definition is not part of this file (see the TODO below for where it currently lives).
// NOTE(review): presumably meant to be used as 'cudaXxx(...) || "message"' to report a failed call -- confirm against the definition.
extern void operator||(cudaError_t rc, const char *msg); // TODO: imported from cudamatrix.cpp --better move to cudalib.h
// vectorbaseimpl -- implements the basic operations of the exported interface vectorbase<>, from which all vectors derive.
// It combines three bases:
//  - VECTORTYPE: the user-visible interface type
//  - OPSTYPE:    the class carrying the kernels (or plain vectorref<> if there are none); provides get()/reset()/size()
//  - objectondevice: remembers which device the storage belongs to (deviceid)
// The object owns the buffer it obtains through malloc<elemtype>() and releases it in the destructor.
// Because of that ownership it is non-copyable: an implicit copy would make two objects free the same buffer.
// TODO: This really should not be in cudalattice, since it is more general; we need a cudavector.cpp/h
template <typename VECTORTYPE, typename OPSTYPE>
class vectorbaseimpl : public /*interface*/ VECTORTYPE, // user-type interface; must derive from vectorbase<VECTORBASE::elemtype>
                       public OPSTYPE,                  // type of class that implements the kernels; must derive from vectorref<VECTORBASE::elemtype>
                       public objectondevice            // setdevice()
{
    typedef typename VECTORTYPE::elemtype elemtype; // (for convenience)
    size_t capacity;                                // amount of allocated storage (like capacity() vs. vectorref::n = size())

    // Detach the current buffer (pointer becomes NULL, size 0) and free it.
    void release()
    {
        ondevice no(deviceid);                       // switch to desired CUDA card
        cuda_ptr<elemtype> p = this->reset(NULL, 0); // detach first so the object never refers to freed storage
        capacity = 0;                                // nothing is allocated anymore
        free(p);
    }

    // Owning class: copying would lead to the same buffer being freed twice.
    vectorbaseimpl(const vectorbaseimpl &) = delete;
    vectorbaseimpl &operator=(const vectorbaseimpl &) = delete;

public:
    // Create an empty vector bound to the given device; no storage is allocated yet.
    vectorbaseimpl(size_t deviceid)
        : objectondevice(deviceid), capacity(0) // listed in actual initialization order (bases before members)
    {
    }

    ~vectorbaseimpl()
    {
        release();
    }

    // Set the vector size to sz elements.
    // Storage is only reallocated when sz exceeds the current capacity; the old buffer is then
    // freed without copying its content over. Otherwise the existing buffer is kept and only the size changes.
    void allocate(size_t sz)
    {
        if (sz > capacity) // need to grow
        {
            ondevice no(deviceid);                          // switch to desired CUDA card
            cuda_ptr<elemtype> pnew = malloc<elemtype>(sz); // allocate memory inside CUDA device (or throw)
            capacity = sz;                                  // if succeeded then: remember
            cuda_ptr<elemtype> p = this->reset(pnew, sz);   // and swap the pointers and update n
            free(p);                                        // then release the old one
        }
        else // not growing: keep same allocation
            this->reset(this->get(), sz);
    }

    // Current number of elements (not the capacity).
    size_t size() const throw()
    {
        return vectorref<elemtype>::size();
    }

    // Resize to nelem elements and copy nelem elements from p into the vector.
    // If 'synchronize' is set, join() is called before returning.
    // NOTE(review): join() presumably blocks until the transfer has completed -- confirm in cudalib.h.
    void assign(const elemtype *p, size_t nelem, bool synchronize)
    {
        allocate(nelem);       // assign will resize the target appropriately
        ondevice no(deviceid); // switch to desired CUDA card
        if (nelem > 0)
            memcpy(this->get(), 0, p, nelem); // NOTE(review): the 0 is presumably the element offset into the buffer -- confirm
        if (synchronize)
            join();
    }

    // Copy the vector content into p, which must have room for nelem elements.
    // nelem must equal size(); otherwise LogicError() is invoked.
    // If 'synchronize' is set, join() is called before returning.
    void fetch(elemtype *p, size_t nelem, bool synchronize) const
    {
        if (nelem != size()) // fetch() cannot resize the target; caller must do that
            LogicError("fetch: vector size mismatch");
        ondevice no(deviceid); // switch to desired CUDA card
        if (nelem > 0)
            memcpy(p, this->get(), 0, nelem);
        if (synchronize)
            join();
    }
};
// ---------------------------------------------------------------------------
// glue code for lattice-related classes
// The XXXvectorimpl classes must derive from vectorbaseimpl<XXXvector,XXXvectorops>.
// For classes without kernels that operate on the vector, no XXXvectorimpl is
// needed; use vectorbaseimpl<XXXvector,vectorref<XXX>> instead, where
// XXXvector is an alias for vectorbase<XXX> (it is better to keep that alias in cudalattice.h
// to document which vectors are implemented).
// ---------------------------------------------------------------------------
// Wrap a CNTK float matrix in a matrixref<float>.
// The matrixref is built from the matrix's data pointer, its row count, its column count,
// and the row count once more as the fourth argument.
// NOTE(review): the fourth argument is presumably the column stride of a column-major matrix -- confirm against matrixref<>.
matrixref<float> tomatrixref(const Microsoft::MSR::CNTK::Matrix<float> &m)
{
    const auto numrows = m.GetNumRows();
    const auto numcols = m.GetNumCols();
    return matrixref<float>(m.Data(), numrows, numcols, numrows);
}
// latticefunctionsimpl -- glue between the latticefunctions interface and the latticefunctionsops functions.
// Every method follows the same pattern:
//  1. select the device this object is bound to (ondevice)
//  2. wrap any CNTK matrix argument in a matrixref<float> (tomatrixref)
//  3. re-type each interface-level vector argument to its vectorbaseimpl<> type via dynamic_cast
//  4. forward everything to the latticefunctionsops function of the same name
class latticefunctionsimpl : public vectorbaseimpl<latticefunctions, latticefunctionsops>
{
    // Re-type a read-only interface vector as vectorbaseimpl<VEC, vectorref<ELEM>>.
    // This is a dynamic_cast to a reference type, so a mismatch throws std::bad_cast.
    template <typename ELEM, typename VEC>
    static const vectorbaseimpl<VEC, vectorref<ELEM>> &asinput(const VEC &v)
    {
        return dynamic_cast<const vectorbaseimpl<VEC, vectorref<ELEM>> &>(v);
    }

    // Same as asinput(), but for vectors that the callee may modify.
    template <typename ELEM, typename VEC>
    static vectorbaseimpl<VEC, vectorref<ELEM>> &asoutput(VEC &v)
    {
        return dynamic_cast<vectorbaseimpl<VEC, vectorref<ELEM>> &>(v);
    }

public:
    latticefunctionsimpl(size_t deviceid)
        : vectorbaseimpl(deviceid)
    {
    }

private:
    // Forwards to latticefunctionsops::edgealignment().
    // backptrstorage, alignresult and edgeacscores are passed as modifiable vectors; all others are read-only.
    void edgealignment(const lrhmmdefvector &hmms, const lr3transPvector &transPs, const size_t spalignunitid,
                       const size_t silalignunitid, const Microsoft::MSR::CNTK::Matrix<float> &logLLs, const nodeinfovector &nodes,
                       const edgeinfowithscoresvector &edges, const aligninfovector &aligns,
                       const uintvector &alignoffsets, ushortvector &backptrstorage, const sizetvector &backptroffsets,
                       ushortvector &alignresult, floatvector &edgeacscores) // output
    {
        ondevice devicescope(deviceid);
        matrixref<float> logLLsref = tomatrixref(logLLs);
        latticefunctionsops::edgealignment(asinput<lrhmmdef>(hmms),
                                           asinput<lr3transP>(transPs),
                                           spalignunitid, silalignunitid, logLLsref,
                                           asinput<msra::lattices::nodeinfo>(nodes),
                                           asinput<msra::lattices::edgeinfowithscores>(edges),
                                           asinput<msra::lattices::aligninfo>(aligns),
                                           asinput<unsigned int>(alignoffsets),
                                           asoutput<unsigned short>(backptrstorage),
                                           asinput<size_t>(backptroffsets),
                                           asoutput<unsigned short>(alignresult),
                                           asoutput<float>(edgeacscores));
    }

    // Forwards to latticefunctionsops::forwardbackwardlattice().
    // All doublevector arguments are passed as modifiable vectors; the two double references
    // (logEframescorrecttotal, totalfwscore) are passed through unchanged.
    void forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward,
                                const size_t numlaunchforward, const size_t numlaunchbackward,
                                const size_t spalignunitid, const size_t silalignunitid,
                                const floatvector &edgeacscores, const edgeinfowithscoresvector &edges,
                                const nodeinfovector &nodes, const aligninfovector &aligns,
                                const ushortvector &alignments, const uintvector &alignoffsets,
                                doublevector &logpps, doublevector &logalphas, doublevector &logbetas,
                                const float lmf, const float wp, const float amf, const float boostingfactor, const bool returnEframescorrect,
                                const ushortvector &uids, const ushortvector &senone2classmap, doublevector &logaccalphas,
                                doublevector &logaccbetas, doublevector &logframescorrectedge,
                                doublevector &logEframescorrect, doublevector &Eframescorrectbuf, double &logEframescorrecttotal, double &totalfwscore)
    {
        ondevice devicescope(deviceid);
        latticefunctionsops::forwardbackwardlattice(batchsizeforward, batchsizebackward, numlaunchforward, numlaunchbackward,
                                                    spalignunitid, silalignunitid,
                                                    asinput<float>(edgeacscores),
                                                    asinput<msra::lattices::edgeinfowithscores>(edges),
                                                    asinput<msra::lattices::nodeinfo>(nodes),
                                                    asinput<msra::lattices::aligninfo>(aligns),
                                                    asinput<unsigned short>(alignments),
                                                    asinput<unsigned int>(alignoffsets),
                                                    asoutput<double>(logpps),
                                                    asoutput<double>(logalphas),
                                                    asoutput<double>(logbetas),
                                                    lmf, wp, amf, boostingfactor, returnEframescorrect,
                                                    asinput<unsigned short>(uids),
                                                    asinput<unsigned short>(senone2classmap),
                                                    asoutput<double>(logaccalphas),
                                                    asoutput<double>(logaccbetas),
                                                    asoutput<double>(logframescorrectedge),
                                                    asoutput<double>(logEframescorrect),
                                                    asoutput<double>(Eframescorrectbuf),
                                                    logEframescorrecttotal, totalfwscore);
    }

    // Forwards to latticefunctionsops::sMBRerrorsignal().
    // dengammas and dengammasbuf are handed over as matrixref<float> wrappers; all vectors are read-only.
    void sMBRerrorsignal(const ushortvector &alignstateids,
                         const uintvector &alignoffsets,
                         const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                         const doublevector &logpps, const float amf, const doublevector &logEframescorrect,
                         const double logEframescorrecttotal, Microsoft::MSR::CNTK::Matrix<float> &dengammas, Microsoft::MSR::CNTK::Matrix<float> &dengammasbuf)
    {
        ondevice devicescope(deviceid);
        matrixref<float> dengammasref = tomatrixref(dengammas);
        matrixref<float> dengammasbufref = tomatrixref(dengammasbuf);
        latticefunctionsops::sMBRerrorsignal(asinput<unsigned short>(alignstateids),
                                             asinput<unsigned int>(alignoffsets),
                                             asinput<msra::lattices::edgeinfowithscores>(edges),
                                             asinput<msra::lattices::nodeinfo>(nodes),
                                             asinput<double>(logpps),
                                             amf,
                                             asinput<double>(logEframescorrect),
                                             logEframescorrecttotal, dengammasref, dengammasbufref);
    }

    // Forwards to latticefunctionsops::mmierrorsignal().
    // dengammas is handed over as a matrixref<float> wrapper; all vectors are read-only.
    void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets,
                        const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                        const doublevector &logpps, Microsoft::MSR::CNTK::Matrix<float> &dengammas)
    {
        ondevice devicescope(deviceid);
        matrixref<float> dengammasref = tomatrixref(dengammas);
        latticefunctionsops::mmierrorsignal(asinput<unsigned short>(alignstateids),
                                            asinput<unsigned int>(alignoffsets),
                                            asinput<msra::lattices::edgeinfowithscores>(edges),
                                            asinput<msra::lattices::nodeinfo>(nodes),
                                            asinput<double>(logpps),
                                            dengammasref);
    }

    // Forwards to latticefunctionsops::stateposteriors().
    // logacc is handed over as a matrixref<float> wrapper; all vectors are read-only.
    void stateposteriors(const ushortvector &alignstateids, const uintvector &alignoffsets,
                         const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                         const doublevector &logqs, Microsoft::MSR::CNTK::Matrix<float> &logacc)
    {
        ondevice devicescope(deviceid);
        matrixref<float> logaccref = tomatrixref(logacc);
        latticefunctionsops::stateposteriors(asinput<unsigned short>(alignstateids),
                                             asinput<unsigned int>(alignoffsets),
                                             asinput<msra::lattices::edgeinfowithscores>(edges),
                                             asinput<msra::lattices::nodeinfo>(nodes),
                                             asinput<double>(logqs),
                                             logaccref);
    }
};
// Factory: create a latticefunctionsimpl bound to the given device and return it through its interface type.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
latticefunctions *newlatticefunctions(size_t deviceid)
{
    latticefunctions *instance = new latticefunctionsimpl(deviceid);
    return instance;
}
// Factory for lrhmmdefvector.
// This vector type has no vector-level member functions, so vectorbaseimpl<> is instantiated
// directly over vectorref<lrhmmdef> rather than through a dedicated impl class.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
lrhmmdefvector *newlrhmmdefvector(size_t deviceid)
{
    typedef vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>> impltype;
    return new impltype(deviceid);
}
// Factory for lr3transPvector: a vectorbaseimpl<> over vectorref<lr3transP>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
lr3transPvector *newlr3transPvector(size_t deviceid)
{
    typedef vectorbaseimpl<lr3transPvector, vectorref<lr3transP>> impltype;
    return new impltype(deviceid);
}
// Factory for ushortvector: a vectorbaseimpl<> over vectorref<unsigned short>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
ushortvector *newushortvector(size_t deviceid)
{
    typedef vectorbaseimpl<ushortvector, vectorref<unsigned short>> impltype;
    return new impltype(deviceid);
}
// Factory for uintvector: a vectorbaseimpl<> over vectorref<unsigned int>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
uintvector *newuintvector(size_t deviceid)
{
    typedef vectorbaseimpl<uintvector, vectorref<unsigned int>> impltype;
    return new impltype(deviceid);
}
// Factory for floatvector: a vectorbaseimpl<> over vectorref<float>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
floatvector *newfloatvector(size_t deviceid)
{
    typedef vectorbaseimpl<floatvector, vectorref<float>> impltype;
    return new impltype(deviceid);
}
// Factory for doublevector: a vectorbaseimpl<> over vectorref<double>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
doublevector *newdoublevector(size_t deviceid)
{
    typedef vectorbaseimpl<doublevector, vectorref<double>> impltype;
    return new impltype(deviceid);
}
// Factory for sizetvector: a vectorbaseimpl<> over vectorref<size_t>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
sizetvector *newsizetvector(size_t deviceid)
{
    typedef vectorbaseimpl<sizetvector, vectorref<size_t>> impltype;
    return new impltype(deviceid);
}
// Factory for nodeinfovector: a vectorbaseimpl<> over vectorref<nodeinfo>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
nodeinfovector *newnodeinfovector(size_t deviceid)
{
    typedef vectorbaseimpl<nodeinfovector, vectorref<nodeinfo>> impltype;
    return new impltype(deviceid);
}
// Factory for edgeinfowithscoresvector: a vectorbaseimpl<> over vectorref<edgeinfowithscores>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
edgeinfowithscoresvector *newedgeinfovector(size_t deviceid)
{
    typedef vectorbaseimpl<edgeinfowithscoresvector, vectorref<edgeinfowithscores>> impltype;
    return new impltype(deviceid);
}
// Factory for aligninfovector: a vectorbaseimpl<> over vectorref<aligninfo>, bound to the given device.
// The object is allocated with 'new'. NOTE(review): ownership presumably passes to the caller -- confirm.
aligninfovector *newaligninfovector(size_t deviceid)
{
    typedef vectorbaseimpl<aligninfovector, vectorref<aligninfo>> impltype;
    return new impltype(deviceid);
}
};
};