GPUMatrix.cu split: moved TensorView support to separate compilation unit (GPUMatrix had gotten too large for MSVC to compile the fatbin file). No code change otherwise

This commit is contained in:
Frank Seide 2015-12-28 19:45:36 -08:00
Родитель d27753d33a
Коммит 22bab75d5f
8 изменённых файлов: 747 добавлений и 577 удалений

Просмотреть файл

@ -5,29 +5,26 @@
//
#include "stdafx.h"
#include "Basics.h"
#include "BestGpu.h"
#include "DebugUtil.h"
#ifndef CPUONLY
#include "cublas_v2.h"
#include "Basics.h"
#include "GPUMatrix.h"
#include "GPUMatrixCUDAKernels.cuh"
#include "GPUSparseMatrix.h"
#include "GPUTensor.h"
#include "CommonMatrix.h"
#define TENSOR_OPS_DECL __device__ __host__
#include "TensorOps.h"
#include "device_launch_parameters.h"
#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#ifndef let
#define let const auto
#endif
#include "cublas_v2.h"
#include <assert.h>
#pragma comment (lib, "cudart.lib") // instruct linker to reference these libs
#pragma comment (lib, "cublas.lib")
@ -82,9 +79,9 @@ cudaStream_t MATH_API GetStream()
performElementWiseFunction(ElementWiseOperator::op##f, a.m_pArray); \
return *this; }
static const char * CudaErrString(cudaError_t x) { cudaDeviceSynchronize(); return cudaGetErrorString(x); }
static const char * CudaErrString(cublasStatus_t) { cudaDeviceSynchronize(); return "(see cublas_api.h & look for cublasStatus_t or CUBLAS_STATUS_xxx)"; }
static const char * CudaErrString(curandStatus) { cudaDeviceSynchronize(); return "(see curand.h & look for curandStatus or CURAND_STATUS_xxx)"; }
const char * CudaErrString(cudaError_t x) { cudaDeviceSynchronize(); return cudaGetErrorString(x); }
const char * CudaErrString(cublasStatus_t) { cudaDeviceSynchronize(); return "(see cublas_api.h & look for cublasStatus_t or CUBLAS_STATUS_xxx)"; }
const char * CudaErrString(curandStatus) { cudaDeviceSynchronize(); return "(see curand.h & look for curandStatus or CURAND_STATUS_xxx)"; }
namespace Microsoft { namespace MSR { namespace CNTK {
@ -4440,541 +4437,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
CUDA_CALL(cudaFree(d_zeta));
};
// =======================================================================
// TensorView support
// =======================================================================
// BUGBUG: This is a stub that currently is just the CPU code. This is not functional yet.
// To save time, this makes extensive use of templates and macros.
// -----------------------------------------------------------------------
// simple fixed-size arrays for passing dimension information by value
// since CUDA can't just take our std::array and std::vector
// -----------------------------------------------------------------------
template<typename T, size_t N>
struct FixedArray
{
T m_data[N];
__device__ __host__ size_t size() const { return N; }
__device__ __host__ T & operator[](size_t n) { return m_data[n]; }
__device__ __host__ T operator[](size_t n) const { return m_data[n]; }
template<class VEC> FixedArray(const VEC & data) // construct from CPU-side STL array or vector
{
assert(data.size() == N);
for (size_t n = 0; n < N; n++)
{
m_data[n] = (T)data[n];
if (m_data[n] != data[n]) // overflow check
InvalidArgument("FixedArray: Dimensions out of range, too few bits.");
}
}
};
template<typename T> // specialized version for 0 elements
struct FixedArray<T, 0>
{
__device__ __host__ size_t size() const { return 0; }
template<class VEC> FixedArray(const VEC & data) { assert(data.size() == 0); UNUSED(data); }
};
template<typename T, size_t N, size_t K> // N = which input/output; K = index depth
struct FixedMatrix
{
T m_data[N][K];
__device__ __host__ size_t getNumRows() const { return N; }
__device__ __host__ size_t getNumCols() const { return K; }
__device__ __host__ T & operator()(size_t n, size_t k) { return m_data[n][k]; }
__device__ __host__ T operator()(size_t n, size_t k) const { return m_data[n][k]; }
template<typename U> FixedMatrix(const array<SmallVector<U>, N> & data) // construct from CPU-side array of vectors
{
assert(data.size() == N);
for (size_t n = 0; n < N; n++)
{
assert(data[n].size() == K);
for (size_t k = 0; k < K; k++)
{
m_data[n][k] = (T)data[n][k];
if (m_data[n][k] != data[n][k]) // overflow check
InvalidArgument("FixedArray: Dimensions out of range, too few bits.");
}
}
}
};
template<typename T, size_t N> // specialized version for 0 elements
struct FixedMatrix<T, N, 0>
{
__device__ __host__ size_t getNumRows() const { return N; }
__device__ __host__ size_t getNumCols() const { return 0; }
template<typename U> FixedMatrix(const array<SmallVector<U>, N> & data) { assert(data.size() == N); for (size_t n = 0; n < N; n++) assert(data[n].size() == 0); UNUSED(data); }
};
// -----------------------------------------------------------------------
// function to actually compute a function of (N-1) inputs based on the opcode
// -----------------------------------------------------------------------
template<class ElemType>
struct TensorOps
{
static __device__ ElemType Compute(const FixedArray<ElemType*, 1> & pointers, ElementWiseOperator op)
{
#define CaseNullaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper<ElemType>()
switch (op)
{
ForAllNullaryOps(CaseNullaryTensorOp);
default: return OpConstOne<ElemType>(); // (failure--we only have one nullary op, so use the same, maybe it will eliminate the switch altogether)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 2> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
#define CaseUnaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a)
switch (op)
{
ForAllUnaryOps(CaseUnaryTensorOp);
default: return 0; // (failure)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 3> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
ElemType b = *(pointers[1]);
#define CaseBinaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a,b)
switch (op)
{
ForAllBinaryOps(CaseBinaryTensorOp); // note: this costs about 6% compared to having only a single case
default: return 0; // (failure)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 4> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
ElemType b = *(pointers[1]);
ElemType c = *(pointers[2]);
#define CaseTernaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a,b,c)
switch (op)
{
ForAllTernaryOps(CaseTernaryTensorOp);
default: return 0; // (failure)
}
}
};
#define C_size_t CUDA_LONG
#define C_int CUDA_LONG
#define C_unsigned_int CUDA_LONG
// -----------------------------------------------------------------------
// function to compute the value for a given output location (this version performs reduction if needed)
// -----------------------------------------------------------------------
template<class ElemType, C_size_t N, C_int M, C_int m>
struct TensorOpReduce
{
// this version for m >= 0
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// start with index 0
// Using 'double' since we are memory-bound anyway.
double/*ElemType*/ aggregate = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
// apply this index to the pointers
C_size_t dim = reducingOpDims[m];
for (C_size_t k = 1/*done with k=0 already*/; k < dim; k++)
{
// bump the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += reducingStrides(i,(C_size_t)m);
ElemType val = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
aggregate += val;
}
return (ElemType)aggregate;
}
};
// this one terminates the template recursion over reduction dimensions
// The pointers are pointing to the input element.
template<class ElemType, C_size_t N, C_int M>
struct TensorOpReduce<ElemType, N, M, /*m=*/-1>
{
// this version for m = -1
// the pointers are pointing to the right location(s) to take the operation over
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & /*reducingOpDims*/, const FixedMatrix<C_int, N, M> & /*reducingStrides*/)
{
return TensorOps<ElemType>::Compute(pointers, op); // finally computing something!
}
};
// -----------------------------------------------------------------------
// function to compute one constituent of the value for a given output location (this version has reduction done outside)
// -----------------------------------------------------------------------
template<class ElemType, C_size_t N, C_int M, C_int m>
struct TensorOpParallelReduce
{
// this version for m >= 0
static __device__ ElemType Compute(CUDA_LONG id, FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t stride = 1; // compute the stride. This seems expensive, but since we we only currently support M <= 2, this is just compile-time selection between 1 and reducingOpDims[0].
for (int i = 0; i < m; i++)
stride *= reducingOpDims[(C_size_t)i];
C_size_t index = id / stride; // this dimension. For m=0, the stride is 1 and hence the division will be removed at compile time.
id = id % stride; // remaining dimensions inside this. For m=0 this value is ignored and hence not even computed.
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * reducingStrides(i, (C_size_t)m); // now this dimension is taken care of
return TensorOpParallelReduce<ElemType, N, M, m - 1>::Compute(id, pointers, op, reducingOpDims, reducingStrides);
}
};
// this one terminates the template recursion over reduction dimensions
// The pointers are pointing to the input element.
template<class ElemType, C_size_t N, C_int M>
struct TensorOpParallelReduce<ElemType, N, M, /*m=*/-1>
{
// this version for m = -1
// the pointers are pointing to the right location(s) to take the operation over
static __device__ ElemType Compute(CUDA_LONG /*id*/, FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & /*reducingOpDims*/, const FixedMatrix<C_int, N, M> & /*reducingStrides*/)
{
return TensorOps<ElemType>::Compute(pointers, op); // finally computing something!
}
};
// -----------------------------------------------------------------------
// perform loop over regular index k for N-nary operations (N counting the output)
// -----------------------------------------------------------------------
// The canonical case, vector op without reduction, is this PTX function:
// _ZN9Microsoft3MSR4CNTK15_launchTensorOpIfLi3ELi0ELi1EEEvT_NS1_10FixedArrayIPS3_XT0_EEES3_NS1_19ElementWiseOperatorENS4_IiXT2_EEENS1_11FixedMatrixIiXT0_EXT2_EEENS4_IiXT1_EEENS9_IiXT0_EXT1_EEEi
// float ^ ^ aggregate loop
// args? ^ ^ input dims
// _ZN9Microsoft3MSR4CNTK15_launchTensorOpIfLi2ELi0ELi1EEEvT_NS1_10FixedArrayIPS3_XT0_EEES3_NS1_19ElementWiseOperatorENS4_IiXT2_EEENS1_11FixedMatrixIiXT0_EXT2_EEENS4_IiXT1_EEENS9_IiXT0_EXT1_EEEi
// The 'pointers' only refer to a single element, so we will bump them in-place to perform indexing.
template<class ElemType, C_size_t N, C_int M, C_int K, bool parallelReduce, C_int k>
struct TensorOpElement
{
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t stride = regularOpStrides[(C_size_t)k];
C_size_t index = id / stride; // this dimension
id = id % stride; // remaining dimensions inside this
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,(C_size_t)k); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, k - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
};
// specialization for k=0 where op stride is guaranteed to be 1
template<class ElemType, C_size_t N, C_int M, C_int K, bool parallelReduce>
struct TensorOpElement<ElemType, N, M, K, parallelReduce, /*k=*/0>
{
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t index = id; // this dimension
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,0); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, -1>::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
};
// apply beta and alpha and save
template<class ElemType, class PointersType>
static __device__ void SetFinalValue(ElemType val, ElemType beta, const PointersType & pointers, ElemType alpha)
{
// scale
val *= alpha;
// combine with previous value in target matrix, then write it out
auto * pout = pointers[pointers.size() - 1];
if (beta != 0)
val += beta * *pout;
// save
*pout = val;
}
// specialization for k = -1 terminates the template recursion, and computes reductions in a for loop
template<class ElemType, C_size_t N, C_int M, C_int K>
struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/false, /*k=*/-1>
{
// template-recursion-teminating version computes the actual value for this output location
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// compute the operation for this output coordinate
// This may still involve a reduction over inverse-broadcasting dimensions.
ElemType val = TensorOpReduce<ElemType, N, M, M - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
// and save the final value
SetFinalValue(val, beta, pointers, alpha);
}
};
// specialization for k = -1 terminates the template recursion, and computes reductions in parallel
template<class ElemType, C_size_t N, C_int M, C_int K>
struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
{
// template-recursion-teminating version computes the actual value for this output location
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
CUDA_LONG reductionDim = blockDim.x;
CUDA_LONG redId = threadIdx.x;
// accumulator
__shared__ double accumulators[GridDim::maxThreadsPerBlock];
// compute the operation for this input coordinate
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
// reduce
__syncthreads(); // watch out to do this outside 'if (inRange)'
if (redId == 0)
{
double sum = 0;
for (int i = 0; i < reductionDim; i++)
sum += accumulators[i];
accumulators[0] = sum;
}
__syncthreads();
// now set final value to output coordinate
if (redId == 0)
{
ElemType val = (ElemType)accumulators[0];
SetFinalValue(val, beta, pointers, alpha);
}
}
};
// -----------------------------------------------------------------------
// kernel and launch
// -----------------------------------------------------------------------
// the top-level kernel
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOp(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id < numElements)
TensorOpElement<ElemType, N, M, K, false, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOpParallelReduction(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
{
CUDA_LONG id = gridDim.y * blockIdx.x + blockIdx.y; // input dimensions are Y dimension of blocks in this case, so we can use thread dim for shared-memory/parallelization
if (id < numElements)
TensorOpElement<ElemType, N, M, K, true, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
// launch tensor op with CUDA
// All dimensions (N-ariness, number of input dimensions K and number of reduction dimensions M) are bound to template parameters now.
template<class ElemType, C_size_t N, C_int M, C_int K>
static void LaunchTensorOp(ElemType beta, array<ElemType*, N> pointerVector, ElemType alpha, ElementWiseOperator op,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrideVectors,
const SmallVector<size_t> & reducingOpDimVector, const array<SmallVector<ptrdiff_t>, N> & reducingStrideVectors)
{
// copy all parameters to CUDA-compatible data structures
FixedArray<ElemType*, N> pointers(pointerVector);
SmallVector<C_size_t> regularOpStrideVector; // kernel needs the strides for converting thread index back to multi-dimensional tensor index
C_size_t numElements = 1;
for (C_size_t k = 0; k < regularOpDims.size(); k++)
{
regularOpStrideVector.push_back(numElements);
numElements *= (C_size_t)regularOpDims[k];
}
FixedArray<C_unsigned_int, K> regularOpStrides(regularOpStrideVector);
FixedMatrix<C_int, N, K> regularStrides(regularStrideVectors);
FixedArray<C_unsigned_int, M> reducingOpDims(reducingOpDimVector);
FixedMatrix<C_int, N, M> reducingStrides(reducingStrideVectors);
// launch the kernel
CUDA_LONG NN = (CUDA_LONG)numElements; // linear space identifying each individual input element
cudaEvent_t done = nullptr;
if (do_sync) CUDA_CALL(cudaEventCreate(&done));
// do some optimization for reductions
// Cases:
// - input elements >> GPU procs --> do reduction in inner loop
// - reduction dimension fits into a single kernel --> launch it that way
// - reduction dimension requires multiple kernels --> use atomic add, to avoid temp mem alloc --is this any good?
// - PlusNode: reducing to a bias for small matrices
// - ScaleNode: big elementwise product reduced to a scalar (dot product)
#if 1
C_size_t reductionDim = 1; // number of elements to reduce over
for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
reductionDim *= (C_size_t)reducingOpDimVector[k];
let & props = GridDim::GetDeviceProps();
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
{
if (reductionDim <= GridDim::maxThreadsPerBlock)
{
// one thread block per reduction is sufficient
// TODO: In the special case where reduction dim <= 16 (half warp size), we could fit more than one reduction.
GridDim grid(NN);
let blocksPerGrid = dim3(grid.m_blocksPerGrid, grid.m_threadsPerBlock); // block Y is element dimension
let threadsPerBlock = dim3(reductionDim); // X dimension is reduction dimension
_launchTensorOpParallelReduction<ElemType, N, M, K> << <blocksPerGrid, threadsPerBlock, reductionDim * sizeof(double), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
}
else
{
// we need more than one block for each reduction
// Temporary memory is required.
#if 0
// We have too few elements to do reduction in inner loop. Need to be more clever.
CUDA_LONG numReductionThreads = min(reductionDim, GridDim::maxThreadsPerBlock);
// TODO: special case: <= half warp size: We can fit more than one.
// round up to multiples of warp size
numReductionThreads = (numReductionThreads + props.warpSize - 1) / props.warpSize * props.warpSize;
// atomicAdd mode: need to zero out/pre-multiply first
if (reductionDim > numReductionThreads)
{
LaunchTensorOp<ElemType, /*N=*/1, /*M=*/0, K>(beta, array<ElemType*, 1> { pointerVector.back() }, /*alpha=*/0, ElementWiseOperator::opConstOne,
regularOpDims, array<SmallVector<ptrdiff_t>, 1> { regularStrideVectors.back() },
SmallVector<size_t>(), array<SmallVector<ptrdiff_t>, 1>());
beta = 1; // and actual operation now adds with weight 1 since we already initialized/pre-multiplied
}
size_t numReductionDoubles = numReductionThreads; // using 4k shared mem out of 48k, should be OK
GridDim grid(NN);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, numReductionDoubles * sizeof(ElemType), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
#endif
}
}
else
#endif
{
GridDim grid(NN);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
}
if (do_sync) CUDA_CALL(cudaEventRecord(done));
if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
if (do_sync) CUDA_CALL(cudaEventDestroy(done));
}
// for linear unary ops, we need to define a functor for every function for use as a template parameter (lambda syntax doesn't work in CUDA 7)
#define DefineUnaryTensorFunctor(oper) \
struct Functor ## oper { template<class ElemType> static __device__ ElemType f(ElemType a) { return Op ## oper(a); } };
ForAllUnaryOps(DefineUnaryTensorFunctor);
// the top-level kernel for linear unary ops
// Note: If we have a beta, we have 2 memory accesses, so this optimization may no longer be needed as we are memory-bound.
template<class ElemType, class FN>
__global__ void _launchUnaryTensorOp(ElemType beta, const ElemType * pa, ElemType * pb, ElemType alpha, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
ElemType a = pa[id];
ElemType val = FN::f(a);
val *= alpha;
if (beta != 0)
val += beta * pb[id];
pb[id] = val;
}
// version without beta and alpha
template<class ElemType, class FN>
__global__ void _launchUnaryTensorOp(const ElemType * pa, ElemType * pb, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
ElemType a = pa[id];
ElemType val = FN::f(a);
pb[id] = val;
}
// special case of linear unary operation
template<class ElemType>
static void LaunchUnaryTensorOp(ElemType beta, const ElemType * pa, ElemType * pb, ElemType alpha, ElementWiseOperator op, size_t regularOpDim)
{
CUDA_LONG NN = (CUDA_LONG)regularOpDim;
#define CaseLaunchUnaryTensorOp(oper) case ElementWiseOperator::op ## oper: \
if (beta == 0 && alpha == 1) \
return _launchUnaryTensorOp<ElemType,Functor ## oper> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(pa, pb, NN); \
else \
return _launchUnaryTensorOp<ElemType,Functor ## oper> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pa, pb, alpha, NN);
cudaEvent_t done = nullptr;
if (do_sync) CUDA_CALL(cudaEventCreate(&done));
GridDim grid(NN);
switch (op)
{
ForAllUnaryOps(CaseLaunchUnaryTensorOp);
default: LogicError("LaunchTensorOp1: Unknown op code %d.", (int)op);
}
if (do_sync) CUDA_CALL(cudaEventRecord(done));
if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
if (do_sync) CUDA_CALL(cudaEventDestroy(done));
}
// -----------------------------------------------------------------------
// map runtime parameters N to template parameters
// -----------------------------------------------------------------------
// tensor operation with k+1 dimensions (-1 means scalar)
template<class ElemType, C_size_t N, C_int K>
static void TensorOpWithRegularLoop(ElemType beta, const array<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, N> & reducingStrides)
{
size_t dims = reducingOpDims.size();
switch (dims)
{
case 2: return LaunchTensorOp<ElemType, N, 2, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 1: return LaunchTensorOp<ElemType, N, 1, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 0: return LaunchTensorOp<ElemType, N, 0, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
default: LogicError("TensorOp: %d non-flattened reduction dimensions are not supported.", (C_int)dims);
}
}
// tensor operation, generalized in number of arguments
// This function now expands into different k. It also eliminates the offsets by adding them to the pointers.
template<class ElemType, C_size_t N>
static void TensorOpN(ElemType beta, array<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
const array<size_t, N> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, N> & reducingStrides)
{
for (C_size_t i = 0; i < N; i++) // N = a small constant, this will be unrolled
pointers[i] += offsets[i];
size_t dims = regularOpDims.size();
switch (dims)
{
case 4: return TensorOpWithRegularLoop<ElemType, N, 4>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 3: return TensorOpWithRegularLoop<ElemType, N, 3>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 2: return TensorOpWithRegularLoop<ElemType, N, 2>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 1: return TensorOpWithRegularLoop<ElemType, N, 1>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 0: return TensorOpWithRegularLoop<ElemType, N, 0>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
default: LogicError("TensorOp: %d non-flattened input dimensions are not supported.", (C_int)dims);
}
}
// -----------------------------------------------------------------------
// entry points from Matrix.cpp
// TensorView entry points from Matrix.cpp
// -----------------------------------------------------------------------
// perform unary operation 'op' on a giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides
@ -5026,7 +4490,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
return TensorOpN<ElemType, 4>(beta, array<ElemType*, 4> { a.m_pArray, b.m_pArray, c.m_pArray, m_pArray }, alpha, op, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
}
// =======================================================================
// explicit instantiations business
// =======================================================================
@ -5037,10 +4500,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template class DeviceBoundNumber<double>;
template<class ElemType>
cublasHandle_t GPUMatrix<ElemType>::s_cuHandle[GPUMatrix<ElemType>::MaxGpus]={0};
cublasHandle_t GPUMatrix<ElemType>::s_cuHandle[GPUMatrix<ElemType>::MaxGpus] = { 0 };
template<class ElemType>
void* GPUMatrix<ElemType>::s_curandGenerator=NULL;
void* GPUMatrix<ElemType>::s_curandGenerator = NULL;
// We use Matrix<char> as the backing store for QuantizedMatrix
// Let's explicitly instantiate the methods we need for that purpose

Просмотреть файл

@ -47,9 +47,7 @@ typedef struct CUstream_st *cudaStream_t;
void MATH_API SetStream(cudaStream_t stream);
cudaStream_t MATH_API GetStream();
namespace Microsoft {
namespace MSR {
namespace CNTK {
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// DeviceBoundNumber -- This class represents a number which resides on a particular device. Use it to avoid unnecessary transfers between CPU and GPU
@ -506,7 +504,7 @@ namespace Microsoft {
}}}
// Error handling
template<typename ERRTYPE> static const char * CudaErrString(ERRTYPE x);
template<typename ERRTYPE> const char * CudaErrString(ERRTYPE x);
template<typename ERRTYPE> static void CudaCall(ERRTYPE retCode, const char * exprString, const char * libName, ERRTYPE successCode)
{
if (retCode != successCode)

Просмотреть файл

@ -4,16 +4,18 @@
// </copyright>
//
#pragma once
#include "BestGpu.h"
#ifndef CPUONLY
#include <float.h>
#include <cuda_runtime.h>
#include "CommonMatrix.h"
#include "GPUMatrix.h"
#include "device_functions.h"
#include <cuda_runtime.h>
#include <assert.h>
#include <float.h>
// REVIEW alexeyk: disable warnings properly for GCC/clang
#ifdef _MSC_VER

609
Source/Math/GPUTensor.cu Normal file
Просмотреть файл

@ -0,0 +1,609 @@
//
// <copyright file="GPUMatrix.cu" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#include "stdafx.h"
#include "Basics.h"
#include "BestGpu.h"
//#include "DebugUtil.h"
#ifndef CPUONLY
#include "GPUTensor.h"
#include "GPUMatrix.h"
#include "GPUMatrixCUDAKernels.cuh"
#include "CommonMatrix.h"
#define TENSOR_OPS_DECL __device__ __host__
#include "TensorOps.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#include <assert.h>
#ifndef let
#define let const auto
#endif
#pragma comment (lib, "cudart.lib") // instruct linker to reference these libs
#pragma comment (lib, "cublas.lib")
#pragma warning (disable: 4267) // conversion from 'size_t' to 'unsigned int'; happens in CUDA <<<a,b>>> syntax if a and b are size_t
#pragma warning (disable: 4127) // conditional expression is constant; "if (sizeof(ElemType)==sizeof(float))" triggers this
#pragma warning (disable: 4702) // unreachable code; triggered for unknown reasons
extern bool do_sync;
#ifdef _WIN32
// thread local storage to access the current stream, initalize to default stream
__declspec (thread)
#else
static
#endif
extern cudaStream_t t_stream;
namespace Microsoft { namespace MSR { namespace CNTK {
// =======================================================================
// TensorView support
// =======================================================================
// To save time, this makes extensive use of templates and macros.
// -----------------------------------------------------------------------
// simple fixed-size arrays for passing dimension information by value
// since CUDA can't just take our std::array and std::vector
// -----------------------------------------------------------------------
template<typename T, size_t N>
struct FixedArray
{
T m_data[N];
__device__ __host__ size_t size() const { return N; }
__device__ __host__ T & operator[](size_t n) { return m_data[n]; }
__device__ __host__ T operator[](size_t n) const { return m_data[n]; }
template<class VEC> FixedArray(const VEC & data) // construct from CPU-side STL array or vector
{
assert(data.size() == N);
for (size_t n = 0; n < N; n++)
{
m_data[n] = (T)data[n];
if (m_data[n] != data[n]) // overflow check
InvalidArgument("FixedArray: Dimensions out of range, too few bits.");
}
}
};
template<typename T> // specialized version for 0 elements
struct FixedArray<T, 0>
{
__device__ __host__ size_t size() const { return 0; }
template<class VEC> FixedArray(const VEC & data) { assert(data.size() == 0); UNUSED(data); }
};
template<typename T, size_t N, size_t K> // N = which input/output; K = index depth
struct FixedMatrix
{
T m_data[N][K];
__device__ __host__ size_t getNumRows() const { return N; }
__device__ __host__ size_t getNumCols() const { return K; }
__device__ __host__ T & operator()(size_t n, size_t k) { return m_data[n][k]; }
__device__ __host__ T operator()(size_t n, size_t k) const { return m_data[n][k]; }
template<typename U> FixedMatrix(const array<SmallVector<U>, N> & data) // construct from CPU-side array of vectors
{
assert(data.size() == N);
for (size_t n = 0; n < N; n++)
{
assert(data[n].size() == K);
for (size_t k = 0; k < K; k++)
{
m_data[n][k] = (T)data[n][k];
if (m_data[n][k] != data[n][k]) // overflow check
InvalidArgument("FixedArray: Dimensions out of range, too few bits.");
}
}
}
};
template<typename T, size_t N> // specialized version for 0 elements
struct FixedMatrix<T, N, 0>
{
__device__ __host__ size_t getNumRows() const { return N; }
__device__ __host__ size_t getNumCols() const { return 0; }
template<typename U> FixedMatrix(const array<SmallVector<U>, N> & data) { assert(data.size() == N); for (size_t n = 0; n < N; n++) assert(data[n].size() == 0); UNUSED(data); }
};
// -----------------------------------------------------------------------
// function to actually compute a function of (N-1) inputs based on the opcode
// -----------------------------------------------------------------------
template<class ElemType>
struct TensorOps
{
static __device__ ElemType Compute(const FixedArray<ElemType*, 1> & pointers, ElementWiseOperator op)
{
#define CaseNullaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper<ElemType>()
switch (op)
{
ForAllNullaryOps(CaseNullaryTensorOp);
default: return OpConstOne<ElemType>(); // (failure--we only have one nullary op, so use the same, maybe it will eliminate the switch altogether)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 2> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
#define CaseUnaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a)
switch (op)
{
ForAllUnaryOps(CaseUnaryTensorOp);
default: return 0; // (failure)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 3> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
ElemType b = *(pointers[1]);
#define CaseBinaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a,b)
switch (op)
{
ForAllBinaryOps(CaseBinaryTensorOp); // note: this costs about 6% compared to having only a single case
default: return 0; // (failure)
}
}
static __device__ ElemType Compute(const FixedArray<ElemType*, 4> & pointers, ElementWiseOperator op)
{
ElemType a = *(pointers[0]);
ElemType b = *(pointers[1]);
ElemType c = *(pointers[2]);
#define CaseTernaryTensorOp(oper) case ElementWiseOperator::op ## oper: return Op ## oper(a,b,c)
switch (op)
{
ForAllTernaryOps(CaseTernaryTensorOp);
default: return 0; // (failure)
}
}
};
// -----------------------------------------------------------------------
// function to compute the value for a given output location (this version performs reduction if needed)
// -----------------------------------------------------------------------
template<class ElemType, C_size_t N, C_int M, C_int m>
struct TensorOpReduce
{
// this version for m >= 0
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// start with index 0
// Using 'double' since we are memory-bound anyway.
double/*ElemType*/ aggregate = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
// apply this index to the pointers
C_size_t dim = reducingOpDims[m];
for (C_size_t k = 1/*done with k=0 already*/; k < dim; k++)
{
// bump the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += reducingStrides(i,(C_size_t)m);
ElemType val = TensorOpReduce<ElemType, N, M, m - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
aggregate += val;
}
return (ElemType)aggregate;
}
};
// this one terminates the template recursion over reduction dimensions
// The pointers are pointing to the input element.
template<class ElemType, C_size_t N, C_int M>
struct TensorOpReduce<ElemType, N, M, /*m=*/-1>
{
// this version for m = -1
// the pointers are pointing to the right location(s) to take the operation over
static __device__ ElemType Compute(FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & /*reducingOpDims*/, const FixedMatrix<C_int, N, M> & /*reducingStrides*/)
{
return TensorOps<ElemType>::Compute(pointers, op); // finally computing something!
}
};
// -----------------------------------------------------------------------
// function to compute one constituent of the value for a given output location (this version has reduction done outside)
// -----------------------------------------------------------------------
template<class ElemType, C_size_t N, C_int M, C_int m>
struct TensorOpParallelReduce
{
// this version for m >= 0
static __device__ ElemType Compute(CUDA_LONG id, FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t stride = 1; // compute the stride. This seems expensive, but since we we only currently support M <= 2, this is just compile-time selection between 1 and reducingOpDims[0].
for (int i = 0; i < m; i++)
stride *= reducingOpDims[(C_size_t)i];
C_size_t index = id / stride; // this dimension. For m=0, the stride is 1 and hence the division will be removed at compile time.
id = id % stride; // remaining dimensions inside this. For m=0 this value is ignored and hence not even computed.
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * reducingStrides(i, (C_size_t)m); // now this dimension is taken care of
return TensorOpParallelReduce<ElemType, N, M, m - 1>::Compute(id, pointers, op, reducingOpDims, reducingStrides);
}
};
// this one terminates the template recursion over reduction dimensions
// The pointers are pointing to the input element.
template<class ElemType, C_size_t N, C_int M>
struct TensorOpParallelReduce<ElemType, N, M, /*m=*/-1>
{
// this version for m = -1
// the pointers are pointing to the right location(s) to take the operation over
static __device__ ElemType Compute(CUDA_LONG /*id*/, FixedArray<ElemType*, N> pointers, ElementWiseOperator op,
const FixedArray<C_unsigned_int, M> & /*reducingOpDims*/, const FixedMatrix<C_int, N, M> & /*reducingStrides*/)
{
return TensorOps<ElemType>::Compute(pointers, op); // finally computing something!
}
};
// -----------------------------------------------------------------------
// perform loop over regular index k for N-nary operations (N counting the output)
// -----------------------------------------------------------------------
// The canonical case, vector op without reduction, is this PTX function:
// _ZN9Microsoft3MSR4CNTK15_launchTensorOpIfLi3ELi0ELi1EEEvT_NS1_10FixedArrayIPS3_XT0_EEES3_NS1_19ElementWiseOperatorENS4_IiXT2_EEENS1_11FixedMatrixIiXT0_EXT2_EEENS4_IiXT1_EEENS9_IiXT0_EXT1_EEEi
// float ^ ^ aggregate loop
// args? ^ ^ input dims
// _ZN9Microsoft3MSR4CNTK15_launchTensorOpIfLi2ELi0ELi1EEEvT_NS1_10FixedArrayIPS3_XT0_EEES3_NS1_19ElementWiseOperatorENS4_IiXT2_EEENS1_11FixedMatrixIiXT0_EXT2_EEENS4_IiXT1_EEENS9_IiXT0_EXT1_EEEi
// The 'pointers' only refer to a single element, so we will bump them in-place to perform indexing.
template<class ElemType, C_size_t N, C_int M, C_int K, bool parallelReduce, C_int k>
struct TensorOpElement
{
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t stride = regularOpStrides[(C_size_t)k];
C_size_t index = id / stride; // this dimension
id = id % stride; // remaining dimensions inside this
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,(C_size_t)k); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, k - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
};
// specialization for k=0 where op stride is guaranteed to be 1
template<class ElemType, C_size_t N, C_int M, C_int K, bool parallelReduce>
struct TensorOpElement<ElemType, N, M, K, parallelReduce, /*k=*/0>
{
// template-recursive version loops over indices
static __device__ void Compute(CUDA_LONG id, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & regularOpStrides, const FixedMatrix<C_int, N, K> & regularStrides,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// map id (location on grid) to index[k]
C_size_t index = id; // this dimension
// apply this index to the pointers
for (C_size_t i = 0; i < N; i++)
pointers[i] += index * regularStrides(i,0); // now this dimension is taken care of
// process the previous index
TensorOpElement<ElemType, N, M, K, parallelReduce, -1>::Compute(/*id*/0, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
};
// apply beta and alpha and save
template<class ElemType, class PointersType>
static __device__ void SetFinalValue(ElemType val, ElemType beta, const PointersType & pointers, ElemType alpha)
{
// scale
val *= alpha;
// combine with previous value in target matrix, then write it out
auto * pout = pointers[pointers.size() - 1];
if (beta != 0)
val += beta * *pout;
// save
*pout = val;
}
// specialization for k = -1 terminates the template recursion, and computes reductions in a for loop
template<class ElemType, C_size_t N, C_int M, C_int K>
struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/false, /*k=*/-1>
{
// template-recursion-teminating version computes the actual value for this output location
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
// compute the operation for this output coordinate
// This may still involve a reduction over inverse-broadcasting dimensions.
ElemType val = TensorOpReduce<ElemType, N, M, M - 1>::Compute(pointers, op, reducingOpDims, reducingStrides);
// and save the final value
SetFinalValue(val, beta, pointers, alpha);
}
};
// specialization for k = -1 terminates the template recursion, and computes reductions in parallel
template<class ElemType, C_size_t N, C_int M, C_int K>
struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
{
// template-recursion-teminating version computes the actual value for this output location
// now the output pointers point to the right element (input pointers may still iterate for reduction)
static __device__ void Compute(CUDA_LONG /*id*/, ElemType beta, FixedArray<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const FixedArray<C_unsigned_int, K> & /*regularOpStrides*/, const FixedMatrix<C_int, N, K> & /*regularStrides*/,
const FixedArray<C_unsigned_int, M> & reducingOpDims, const FixedMatrix<C_int, N, M> & reducingStrides)
{
CUDA_LONG reductionDim = blockDim.x;
CUDA_LONG redId = threadIdx.x;
// accumulator
__shared__ double accumulators[GridDim::maxThreadsPerBlock];
// compute the operation for this input coordinate
accumulators[redId] = TensorOpParallelReduce<ElemType, N, M, M - 1>::Compute(redId, pointers, op, reducingOpDims, reducingStrides);
// reduce
__syncthreads(); // watch out to do this outside 'if (inRange)'
if (redId == 0)
{
double sum = 0;
for (int i = 0; i < reductionDim; i++)
sum += accumulators[i];
accumulators[0] = sum;
}
__syncthreads();
// now set final value to output coordinate
if (redId == 0)
{
ElemType val = (ElemType)accumulators[0];
SetFinalValue(val, beta, pointers, alpha);
}
}
};
// -----------------------------------------------------------------------
// kernel and launch
// -----------------------------------------------------------------------
// the top-level kernel
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOp(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id < numElements)
TensorOpElement<ElemType, N, M, K, false, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
template<class ElemType, C_size_t N, C_int M, C_int K>
__global__ void _launchTensorOpParallelReduction(ElemType beta, FixedArray<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
FixedArray<C_unsigned_int, K> regularOpStrides, FixedMatrix<C_int, N, K> regularStrides,
FixedArray<C_unsigned_int, M> reducingOpDims, FixedMatrix<C_int, N, M> reducingStrides, CUDA_LONG numElements)
{
CUDA_LONG id = gridDim.y * blockIdx.x + blockIdx.y; // input dimensions are Y dimension of blocks in this case, so we can use thread dim for shared-memory/parallelization
if (id < numElements)
TensorOpElement<ElemType, N, M, K, true, K - 1>::Compute(id, beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides);
}
// launch tensor op with CUDA
// All dimensions (N-ariness, number of input dimensions K and number of reduction dimensions M) are bound to template parameters now.
template<class ElemType, C_size_t N, C_int M, C_int K>
static void LaunchTensorOp(ElemType beta, array<ElemType*, N> pointerVector, ElemType alpha, ElementWiseOperator op,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrideVectors,
const SmallVector<size_t> & reducingOpDimVector, const array<SmallVector<ptrdiff_t>, N> & reducingStrideVectors)
{
// copy all parameters to CUDA-compatible data structures
FixedArray<ElemType*, N> pointers(pointerVector);
SmallVector<C_size_t> regularOpStrideVector; // kernel needs the strides for converting thread index back to multi-dimensional tensor index
C_size_t numElements = 1;
for (C_size_t k = 0; k < regularOpDims.size(); k++)
{
regularOpStrideVector.push_back(numElements);
numElements *= (C_size_t)regularOpDims[k];
}
FixedArray<C_unsigned_int, K> regularOpStrides(regularOpStrideVector);
FixedMatrix<C_int, N, K> regularStrides(regularStrideVectors);
FixedArray<C_unsigned_int, M> reducingOpDims(reducingOpDimVector);
FixedMatrix<C_int, N, M> reducingStrides(reducingStrideVectors);
// launch the kernel
CUDA_LONG NN = (CUDA_LONG)numElements; // linear space identifying each individual input element
cudaEvent_t done = nullptr;
if (do_sync) CUDA_CALL(cudaEventCreate(&done));
// do some optimization for reductions
// Cases:
// - input elements >> GPU procs --> do reduction in inner loop
// - reduction dimension fits into a single kernel --> launch it that way
// - reduction dimension requires multiple kernels --> use atomic add, to avoid temp mem alloc --is this any good?
// - PlusNode: reducing to a bias for small matrices
// - ScaleNode: big elementwise product reduced to a scalar (dot product)
#if 1
C_size_t reductionDim = 1; // number of elements to reduce over
for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
reductionDim *= (C_size_t)reducingOpDimVector[k];
let & props = GridDim::GetDeviceProps();
if (reductionDim > 1 && NN < props.multiProcessorCount * props.warpSize && reductionDim <= GridDim::maxThreadsPerBlock)
{
if (reductionDim <= GridDim::maxThreadsPerBlock)
{
// one thread block per reduction is sufficient
// TODO: In the special case where reduction dim <= 16 (half warp size), we could fit more than one reduction.
GridDim grid(NN);
let blocksPerGrid = dim3(grid.m_blocksPerGrid, grid.m_threadsPerBlock); // block Y is element dimension
let threadsPerBlock = dim3(reductionDim); // X dimension is reduction dimension
_launchTensorOpParallelReduction<ElemType, N, M, K> << <blocksPerGrid, threadsPerBlock, reductionDim * sizeof(double), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
}
else
{
// we need more than one block for each reduction
// Temporary memory is required.
#if 0
// We have too few elements to do reduction in inner loop. Need to be more clever.
CUDA_LONG numReductionThreads = min(reductionDim, GridDim::maxThreadsPerBlock);
// TODO: special case: <= half warp size: We can fit more than one.
// round up to multiples of warp size
numReductionThreads = (numReductionThreads + props.warpSize - 1) / props.warpSize * props.warpSize;
// atomicAdd mode: need to zero out/pre-multiply first
if (reductionDim > numReductionThreads)
{
LaunchTensorOp<ElemType, /*N=*/1, /*M=*/0, K>(beta, array<ElemType*, 1> { pointerVector.back() }, /*alpha=*/0, ElementWiseOperator::opConstOne,
regularOpDims, array<SmallVector<ptrdiff_t>, 1> { regularStrideVectors.back() },
SmallVector<size_t>(), array<SmallVector<ptrdiff_t>, 1>());
beta = 1; // and actual operation now adds with weight 1 since we already initialized/pre-multiplied
}
size_t numReductionDoubles = numReductionThreads; // using 4k shared mem out of 48k, should be OK
GridDim grid(NN);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, numReductionDoubles * sizeof(ElemType), t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
#endif
}
}
else
#endif
{
GridDim grid(NN);
_launchTensorOp<ElemType, N, M, K> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pointers, alpha, op, regularOpStrides, regularStrides, reducingOpDims, reducingStrides, grid.m_N);
}
if (do_sync) CUDA_CALL(cudaEventRecord(done));
if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
if (do_sync) CUDA_CALL(cudaEventDestroy(done));
}
// for linear unary ops, we need to define a functor for every function for use as a template parameter (lambda syntax doesn't work in CUDA 7)
#define DefineUnaryTensorFunctor(oper) \
struct Functor ## oper { template<class ElemType> static __device__ ElemType f(ElemType a) { return Op ## oper(a); } };
ForAllUnaryOps(DefineUnaryTensorFunctor);
// the top-level kernel for linear unary ops
// Note: If we have a beta, we have 2 memory accesses, so this optimization may no longer be needed as we are memory-bound.
template<class ElemType, class FN>
__global__ void _launchUnaryTensorOp(ElemType beta, const ElemType * pa, ElemType * pb, ElemType alpha, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
ElemType a = pa[id];
ElemType val = FN::f(a);
val *= alpha;
if (beta != 0)
val += beta * pb[id];
pb[id] = val;
}
// version without beta and alpha
template<class ElemType, class FN>
__global__ void _launchUnaryTensorOp(const ElemType * pa, ElemType * pb, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
ElemType a = pa[id];
ElemType val = FN::f(a);
pb[id] = val;
}
// special case of linear unary operation
template<class ElemType>
static void LaunchUnaryTensorOp(ElemType beta, const ElemType * pa, ElemType * pb, ElemType alpha, ElementWiseOperator op, size_t regularOpDim)
{
CUDA_LONG NN = (CUDA_LONG)regularOpDim;
#define CaseLaunchUnaryTensorOp(oper) case ElementWiseOperator::op ## oper: \
if (beta == 0 && alpha == 1) \
return _launchUnaryTensorOp<ElemType,Functor ## oper> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(pa, pb, NN); \
else \
return _launchUnaryTensorOp<ElemType,Functor ## oper> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(beta, pa, pb, alpha, NN);
cudaEvent_t done = nullptr;
if (do_sync) CUDA_CALL(cudaEventCreate(&done));
GridDim grid(NN);
switch (op)
{
ForAllUnaryOps(CaseLaunchUnaryTensorOp);
default: LogicError("LaunchTensorOp1: Unknown op code %d.", (int)op);
}
if (do_sync) CUDA_CALL(cudaEventRecord(done));
if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
if (do_sync) CUDA_CALL(cudaEventDestroy(done));
}
// -----------------------------------------------------------------------
// map runtime parameters N to template parameters
// -----------------------------------------------------------------------
// tensor operation with k+1 dimensions (-1 means scalar)
template<class ElemType, C_size_t N, C_int K>
static void TensorOpWithRegularLoop(ElemType beta, const array<ElemType*, N> & pointers, ElemType alpha, ElementWiseOperator op,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, N> & reducingStrides)
{
size_t dims = reducingOpDims.size();
switch (dims)
{
case 2: return LaunchTensorOp<ElemType, N, 2, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 1: return LaunchTensorOp<ElemType, N, 1, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 0: return LaunchTensorOp<ElemType, N, 0, K>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
default: LogicError("TensorOp: %d non-flattened reduction dimensions are not supported.", (C_int)dims);
}
}
// tensor operation, generalized in number of arguments
// This function now expands into different k. It also eliminates the offsets by adding them to the pointers.
template<class ElemType, C_size_t N>
static void TensorOpN(ElemType beta, array<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
const array<size_t, N> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, N> & reducingStrides)
{
for (C_size_t i = 0; i < N; i++) // N = a small constant, this will be unrolled
pointers[i] += offsets[i];
size_t dims = regularOpDims.size();
switch (dims)
{
case 4: return TensorOpWithRegularLoop<ElemType, N, 4>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 3: return TensorOpWithRegularLoop<ElemType, N, 3>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 2: return TensorOpWithRegularLoop<ElemType, N, 2>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 1: return TensorOpWithRegularLoop<ElemType, N, 1>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
case 0: return TensorOpWithRegularLoop<ElemType, N, 0>(beta, pointers, alpha, op, regularOpDims, regularStrides, reducingOpDims, reducingStrides);
default: LogicError("TensorOp: %d non-flattened input dimensions are not supported.", (C_int)dims);
}
}
//------------------------------------------------------------------------
// explicit instantiations--these are being called from GPUMatrix.cu
//------------------------------------------------------------------------
template void TensorOpN(float beta, array<float*, 2> pointers, float alpha, ElementWiseOperator op,
const array<size_t, 2> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 2> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 2> & reducingStrides);
template void TensorOpN(float beta, array<float*, 3> pointers, float alpha, ElementWiseOperator op,
const array<size_t, 3> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 3> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 3> & reducingStrides);
template void TensorOpN(float beta, array<float*, 4> pointers, float alpha, ElementWiseOperator op,
const array<size_t, 4> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 4> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 4> & reducingStrides);
template void TensorOpN(double beta, array<double*, 2> pointers, double alpha, ElementWiseOperator op,
const array<size_t, 2> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 2> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 2> & reducingStrides);
template void TensorOpN(double beta, array<double*, 3> pointers, double alpha, ElementWiseOperator op,
const array<size_t, 3> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 3> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 3> & reducingStrides);
template void TensorOpN(double beta, array<double*, 4> pointers, double alpha, ElementWiseOperator op,
const array<size_t, 4> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, 4> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, 4> & reducingStrides);
template void LaunchUnaryTensorOp(float beta, const float * pa, float * pb, float alpha, ElementWiseOperator op, size_t regularOpDim);
template void LaunchUnaryTensorOp(double beta, const double * pa, double * pb, double alpha, ElementWiseOperator op, size_t regularOpDim);
}}}
#endif // CPUONLY

30
Source/Math/GPUTensor.h Normal file
Просмотреть файл

@ -0,0 +1,30 @@
//
// <copyright file="GPUTensor.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
#pragma once
#include "CommonMatrix.h"
#include "DataTensor.h" // only for SmallVector; I was hoping to keep this out
#include "GPUMatrixCUDAKernels.cuh"
#include <array>
namespace Microsoft { namespace MSR { namespace CNTK {
// GPUMatrix::TensorOp() interfaces with actual tensor code through these two functions, which are independent of the GPUMatrix class
#define C_size_t CUDA_LONG
#define C_int CUDA_LONG
#define C_unsigned_int CUDA_LONG
template<class ElemType, C_size_t N>
void TensorOpN(ElemType beta, array<ElemType*, N> pointers, ElemType alpha, ElementWiseOperator op,
const array<size_t, N> & offsets,
const SmallVector<size_t> & regularOpDims, const array<SmallVector<ptrdiff_t>, N> & regularStrides,
const SmallVector<size_t> & reducingOpDims, const array<SmallVector<ptrdiff_t>, N> & reducingStrides);
template<class ElemType>
void LaunchUnaryTensorOp(ElemType beta, const ElemType * pa, ElemType * pb, ElemType alpha, ElementWiseOperator op, size_t regularOpDim);
}}}

Просмотреть файл

@ -1,9 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="Matrix.cpp" />
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="..\Common\File.cpp">
<Filter>Common</Filter>
</ClCompile>
@ -25,22 +23,31 @@
<ClCompile Include="MatrixQuantizerCPU.cpp">
<Filter>CPU\1bitSGD</Filter>
</ClCompile>
<ClCompile Include="MatrixQuantizer.cpp" />
<ClCompile Include="QuantizedMatrix.cpp" />
<ClCompile Include="CUDAPageLockedMemAllocator.cpp">
<Filter>GPU\1bitSGD</Filter>
</ClCompile>
<ClCompile Include="ConvolutionEngine.cpp" />
<ClCompile Include="TensorView.cpp">
<Filter>Tensors</Filter>
</ClCompile>
<ClCompile Include="dllmain.cpp">
<Filter>Misc</Filter>
</ClCompile>
<ClCompile Include="ConvolutionEngine.cpp">
<Filter>Convolution</Filter>
</ClCompile>
<ClCompile Include="stdafx.cpp">
<Filter>Misc</Filter>
</ClCompile>
<ClCompile Include="QuantizedMatrix.cpp">
<Filter>1bitSGD</Filter>
</ClCompile>
<ClCompile Include="MatrixQuantizer.cpp">
<Filter>1bitSGD</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="CommonMatrix.h" />
<ClInclude Include="Helpers.h" />
<ClInclude Include="Matrix.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="..\Common\Include\File.h">
<Filter>Common\Include</Filter>
</ClInclude>
@ -59,14 +66,10 @@
<ClInclude Include="MatrixQuantizerCPU.h">
<Filter>CPU\1bitSGD</Filter>
</ClInclude>
<ClInclude Include="MatrixQuantizer.h" />
<ClInclude Include="QuantizedMatrix.h" />
<ClInclude Include="MemAllocator.h" />
<ClInclude Include="CUDAPageLockedMemAllocator.h">
<Filter>GPU\1bitSGD</Filter>
</ClInclude>
<ClInclude Include="..\Common\Include\DebugUtil.h" />
<ClInclude Include="ConvolutionEngine.h" />
<ClInclude Include="TensorView.h">
<Filter>Tensors</Filter>
</ClInclude>
@ -76,6 +79,27 @@
<ClInclude Include="..\Common\Include\DataTensor.h">
<Filter>Common\Include</Filter>
</ClInclude>
<ClInclude Include="Helpers.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="..\Common\Include\DebugUtil.h">
<Filter>Common\Include</Filter>
</ClInclude>
<ClInclude Include="ConvolutionEngine.h">
<Filter>Convolution</Filter>
</ClInclude>
<ClInclude Include="stdafx.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="targetver.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="QuantizedMatrix.h">
<Filter>1bitSGD</Filter>
</ClInclude>
<ClInclude Include="MatrixQuantizer.h">
<Filter>1bitSGD</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<None Include="GPUMatrix.h">
@ -113,5 +137,14 @@
<Filter Include="Tensors">
<UniqueIdentifier>{70fb07cf-603e-4444-bc10-f0add4920fd2}</UniqueIdentifier>
</Filter>
<Filter Include="Misc">
<UniqueIdentifier>{62b92193-92d0-4e5b-8c3e-67ffd01a98c0}</UniqueIdentifier>
</Filter>
<Filter Include="Convolution">
<UniqueIdentifier>{3a49e94d-14ee-4ca1-a56e-a1472206a076}</UniqueIdentifier>
</Filter>
<Filter Include="1bitSGD">
<UniqueIdentifier>{546cacbd-253e-485b-8c8c-8b9ee0e2f631}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>

Просмотреть файл

@ -157,7 +157,9 @@ if exist "$(CuDnnDll)" (xcopy /Y "$(CuDnnDll)" $(OutputPath))
<ClInclude Include="cudalatticeops.h" />
<ClInclude Include="cudalib.h" />
<ClInclude Include="CuDnnConvolutionEngine.h" />
<ClInclude Include="GPUTensor.h" />
<ClInclude Include="latticefunctionskernels.h" />
<ClInclude Include="TensorOps.h" />
<ClInclude Include="ValueQuantizer.h" />
<None Include="GPUWatcher.h">
<FileType>CppHeader</FileType>
@ -171,6 +173,9 @@ if exist "$(CuDnnDll)" (xcopy /Y "$(CuDnnDll)" $(OutputPath))
<ClInclude Include="targetver.h" />
</ItemGroup>
<ItemGroup>
<CudaCompile Include="GPUTensor.cu">
<InterleaveSourceInPTX Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</InterleaveSourceInPTX>
</CudaCompile>
<CudaCompile Include="cudalatticeops.cu">
<FileType>CppCode</FileType>
</CudaCompile>
@ -202,7 +207,7 @@ if exist "$(CuDnnDll)" (xcopy /Y "$(CuDnnDll)" $(OutputPath))
<CudaCompile Include="GPUMatrix.cu">
<FileType>CppCode</FileType>
<Keep Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</Keep>
<InterleaveSourceInPTX Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</InterleaveSourceInPTX>
<InterleaveSourceInPTX Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</InterleaveSourceInPTX>
</CudaCompile>
<CudaCompile Include="GPUMatrixCUDAKernels.cuh">
<ExcludedFromBuild>true</ExcludedFromBuild>

Просмотреть файл

@ -22,25 +22,28 @@
<CudaCompile Include="GPUMatrixCUDAKernels.cuh">
<Filter>GPU</Filter>
</CudaCompile>
<CudaCompile Include="GPUTensor.cu">
<Filter>GPU\Tensors</Filter>
</CudaCompile>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp" />
<ClCompile Include="cudalattice.cpp">
<Filter>GPU\SequenceTraining</Filter>
</ClCompile>
<ClCompile Include="cudalib.cpp">
<Filter>GPU\SequenceTraining</Filter>
</ClCompile>
<ClCompile Include="..\Common\DebugUtil.cpp" />
<ClCompile Include="..\Common\DebugUtil.cpp">
<Filter>Misc</Filter>
</ClCompile>
<ClCompile Include="stdafx.cpp">
<Filter>Misc</Filter>
</ClCompile>
<ClCompile Include="CuDnnConvolutionEngine.cpp">
<Filter>GPU</Filter>
<Filter>GPU\Convolution</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="CommonMatrix.h" />
<ClInclude Include="Helpers.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" />
<ClInclude Include="..\Common\Include\File.h">
<Filter>Common\Include</Filter>
</ClInclude>
@ -80,8 +83,26 @@
<ClInclude Include="latticefunctionskernels.h">
<Filter>GPU\SequenceTraining</Filter>
</ClInclude>
<ClInclude Include="GPUTensor.h">
<Filter>GPU\Tensors</Filter>
</ClInclude>
<ClInclude Include="Helpers.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="stdafx.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="targetver.h">
<Filter>Misc</Filter>
</ClInclude>
<ClInclude Include="CommonMatrix.h">
<Filter>from Math</Filter>
</ClInclude>
<ClInclude Include="CuDnnConvolutionEngine.h">
<Filter>GPU</Filter>
<Filter>GPU\Convolution</Filter>
</ClInclude>
<ClInclude Include="TensorOps.h">
<Filter>from Math</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
@ -105,14 +126,23 @@
<Filter Include="GPU">
<UniqueIdentifier>{cc9a219d-d8ab-484a-b253-fd2a29ad7c7c}</UniqueIdentifier>
</Filter>
<Filter Include="Include">
<UniqueIdentifier>{3c982109-64b1-469a-8d85-2abdf12d636a}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\1bitSGD">
<UniqueIdentifier>{3415233d-9ef7-41c6-abbb-cec1b4f8d14c}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\SequenceTraining">
<UniqueIdentifier>{6a3569b1-6c9e-47b3-870f-bb581349e75e}</UniqueIdentifier>
</Filter>
<Filter Include="Misc">
<UniqueIdentifier>{3c982109-64b1-469a-8d85-2abdf12d636a}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\Tensors">
<UniqueIdentifier>{16214e65-2d24-4e4c-a0dd-c37e505bda32}</UniqueIdentifier>
</Filter>
<Filter Include="from Math">
<UniqueIdentifier>{b1b59e2e-5c54-4e40-ad0a-1523ddeb63ba}</UniqueIdentifier>
</Filter>
<Filter Include="GPU\Convolution">
<UniqueIdentifier>{3155488f-128f-494e-858d-459b4cc9fab7}</UniqueIdentifier>
</Filter>
</ItemGroup>
</Project>