From 98f9e8ac397c1081317c02a5edf7aeeda64fbfd3 Mon Sep 17 00:00:00 2001 From: Thilo Will Date: Mon, 11 Jul 2016 15:49:13 +0200 Subject: [PATCH] Passing reduction op furhter down --- Source/Math/GPUTensor.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Source/Math/GPUTensor.cu b/Source/Math/GPUTensor.cu index 7f7230026..4b3ad0275 100644 --- a/Source/Math/GPUTensor.cu +++ b/Source/Math/GPUTensor.cu @@ -496,7 +496,7 @@ struct TensorOpElement // launch tensor op with CUDA template -__global__ void _launchTensorOp(ElemType beta, FixedArray pointers, ElemType alpha, ElementWiseOperator op, +__global__ void _launchTensorOp(ElemType beta, FixedArray pointers, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, FixedArray regularOpStrides, FixedMatrix regularStrides, CUDA_LONG numElements, FixedArray reducingOpDims, FixedMatrix reducingStrides) { @@ -527,7 +527,7 @@ static void LaunchTensorOp(ElemType beta, array pointerVector, Ele CUDA_LONG NN = (CUDA_LONG) numElements; // linear space identifying each individual input element SyncGuard syncGuard; GridDim grid(NN); - _launchTensorOp<<>>(beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides); + _launchTensorOp << > >(beta, pointers, alpha, op, ElementWiseOperator::opSum /* dummy reductionOp */, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides); } // ----------------------------------------------------------------------- @@ -631,7 +631,7 @@ static void LaunchTensorOpWithReduction(ElemType beta, array point { // we got enough elements to generate: do one element per thread, and reduction inside _launchTensorOp<<>>( - beta, pointers, alpha, op, + beta, pointers, alpha, op, reductionOp, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides); } @@ -745,7 +745,7 @@ static void LaunchTensorOpWithReduction(ElemType beta, array point #else _launchTensorOp<<>>( - beta, pointers, alpha, op, + beta, pointers, alpha, op, reductionOp, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides); //for (size_t z = 0; z < numBlocksZ; z++)