DeepSpeed/csrc/includes/feed_forward.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include <array>  // std::array, used by Config

#include "custom_cuda_layers.h"
template <typename T>
class FeedForward {
public:
    struct Config {
        int batchSize, outputSize;
        int inputSize;
        std::array<int, 3> gemm_algos;

        Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
            : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
        {
        }
    };

    FeedForward(Config config) : config_(config) {}

    ~FeedForward() {}
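
    // Forward pass: out[bsz, outputSize] = input[bsz, inputSize] * W^T, where the
    // weights are stored [outputSize, inputSize] row-major (torch.nn.Linear layout).
    // cuBLAS is column-major, so each row-major array is seen as its transpose: the
    // call below computes C = W * input^T in column-major terms, which read back
    // row-major is exactly input * W^T. No bias is added here.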
    void Forward(int bsz,
                 const T* input_ptr,
                 const T* weights,
                 T* out,
                 cublasHandle_t& _cublasHandle)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_T,
                       CUBLAS_OP_N,
                       config_.outputSize,
                       bsz,
                       config_.inputSize,
                       &alpha,
                       &beta,
                       weights,
                       input_ptr,
                       out,
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                       rocblas_gemm_algo(config_.gemm_algos[0]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[0]));
#endif
    }
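
    // Backward pass for out = input * W^T:
    //   weights_grad = out_grad^T * input   (same [outputSize, inputSize] layout as W)
    //   inp_grad_out = out_grad * W         (gradient w.r.t. the input)
    //   bias_grad    = column-wise sum of out_grad over the bsz rows, computed by
    //                  a fused transpose+reduce kernel on the given stream.
    // Note: out_grad_trans_out is accepted but not used by this implementation, and
    // inp_grad_out must point to valid device memory despite its nullptr default.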
    void Backward(int bsz,
                  const T* out_grad,
                  const T* input_ptr,
                  const T* weights,
                  T* weights_grad,
                  T* bias_grad,
                  cublasHandle_t& _cublasHandle,
                  cudaStream_t& stream,
                  T* inp_grad_out = nullptr,
                  T* out_grad_trans_out = nullptr)
    {
        float alpha = 1.0f, beta = 0.0f;

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_T,
                       config_.inputSize,
                       config_.outputSize,
                       bsz,
                       &alpha,
                       &beta,
                       input_ptr,
                       out_grad,
                       weights_grad,
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                       rocblas_gemm_algo(config_.gemm_algos[1]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[1]));
#endif

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_N,
                       config_.inputSize,
                       bsz,
                       config_.outputSize,
                       &alpha,
                       &beta,
                       weights,
                       out_grad,
                       inp_grad_out,
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                       rocblas_gemm_algo(config_.gemm_algos[2]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[2]));
#endif
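
        // Reduce out_grad over the batch dimension to produce bias_grad.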
        launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
    }

private:
    Config config_;
};
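
// Usage sketch (hypothetical; assumes device buffers and a cuBLAS handle are
// set up elsewhere — none of these names come from this header; ROCm builds
// would pass rocblas_gemm_algo values instead):
//
//   std::array<int, 3> algos = {CUBLAS_GEMM_DEFAULT, CUBLAS_GEMM_DEFAULT, CUBLAS_GEMM_DEFAULT};
//   FeedForward<float>::Config cfg(/*batch*/ 8, /*outputs*/ 1024, /*inputs*/ 768, algos);
//   FeedForward<float> ff(cfg);
//   ff.Forward(8, d_input, d_weights, d_output, cublas_handle);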
#endif