diff --git a/include/infiniop.h b/include/infiniop.h index f0d75abc9..abf0ea0ba 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -10,6 +10,7 @@ #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" #include "infiniop/ops/gemm.h" +#include "infiniop/ops/layer_norm.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" @@ -21,5 +22,11 @@ #include "infiniop/ops/swiglu.h" #include "infiniop/ops/topkrouter.h" #include "infiniop/tensor_descriptor.h" +#include "infiniop/ops/softmax.h" +#include "infiniop/ops/sigmoid.h" +#include "infiniop/ops/gelu.h" +#include "infiniop/ops/tanh.h" +#include "infiniop/ops/quickgelu.h" +#include "infiniop/ops/gelutanh.h" #endif // __INFINIOP_API_H__ diff --git a/include/infiniop/ops/gelu.h b/include/infiniop/ops/gelu.h new file mode 100644 index 000000000..444092b6a --- /dev/null +++ b/include/infiniop/ops/gelu.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_GELU_API_H__ +#define __INFINIOP_GELU_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopGeluDescriptor_t; + +__C __export infiniStatus_t infiniopCreateGeluDescriptor(infiniopHandle_t handle, +                                                         infiniopGeluDescriptor_t *desc_ptr, +                                                         infiniopTensorDescriptor_t output, +                                                         infiniopTensorDescriptor_t input); + +__C __export infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopGelu(infiniopGeluDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream); + +__C __export infiniStatus_t infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/gelutanh.h b/include/infiniop/ops/gelutanh.h new file mode 100644 index 000000000..e8eb005fe --- /dev/null +++ b/include/infiniop/ops/gelutanh.h @@ -0,0 +1,43 @@ +#ifndef __INFINIOP_GELUTANH_API_H__ +#define __INFINIOP_GELUTANH_API_H__ + +#include 
"../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopGeluTanhDescriptor_t; + +/** + * Create GELU-Tanh descriptor + * + * y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + */ +__C __export infiniStatus_t infiniopCreateGeluTanhDescriptor( + infiniopHandle_t handle, + infiniopGeluTanhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +/** + * Query workspace size + */ +__C __export infiniStatus_t infiniopGetGeluTanhWorkspaceSize( + infiniopGeluTanhDescriptor_t desc, + size_t *size); + +/** + * Launch GELU-Tanh operator + */ +__C __export infiniStatus_t infiniopGeluTanh( + infiniopGeluTanhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +/** + * Destroy descriptor + */ +__C __export infiniStatus_t infiniopDestroyGeluTanhDescriptor( + infiniopGeluTanhDescriptor_t desc); + +#endif // __INFINIOP_GELUTANH_API_H__ diff --git a/include/infiniop/ops/layer_norm.h b/include/infiniop/ops/layer_norm.h new file mode 100644 index 000000000..5f852a9db --- /dev/null +++ b/include/infiniop/ops/layer_norm.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_LAYER_NORM_API_H__ +#define __INFINIOP_LAYER_NORM_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLayerNormDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLayerNormDescriptor( + infiniopHandle_t handle, + infiniopLayerNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps); + +__C __export infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLayerNorm(infiniopLayerNormDescriptor_t desc, + void *workspace, + size_t 
workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream); + +__C __export infiniStatus_t infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/infiniop/ops/quickgelu.h b/include/infiniop/ops/quickgelu.h new file mode 100644 index 000000000..1ea19ccf1 --- /dev/null +++ b/include/infiniop/ops/quickgelu.h @@ -0,0 +1,42 @@ +#ifndef __INFINIOP_QUICKGELU_API_H__ +#define __INFINIOP_QUICKGELU_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopQuickGeluDescriptor_t; + +/** + * Create QuickGELU descriptor + * y = x * sigmoid(1.702 * x) + */ +__C __export infiniStatus_t infiniopCreateQuickGeluDescriptor( + infiniopHandle_t handle, + infiniopQuickGeluDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +/** + * Query workspace size + */ +__C __export infiniStatus_t infiniopGetQuickGeluWorkspaceSize( + infiniopQuickGeluDescriptor_t desc, + size_t *size); + +/** + * Launch QuickGELU operator + */ +__C __export infiniStatus_t infiniopQuickGelu( + infiniopQuickGeluDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +/** + * Destroy descriptor + */ +__C __export infiniStatus_t infiniopDestroyQuickGeluDescriptor( + infiniopQuickGeluDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/relu.h b/include/infiniop/ops/relu.h index 9fdbffbd5..7aeef7dac 100644 --- a/include/infiniop/ops/relu.h +++ b/include/infiniop/ops/relu.h @@ -10,6 +10,8 @@ __C __export infiniStatus_t infiniopCreateReluDescriptor(infiniopHandle_t handle infiniopTensorDescriptor_t y, infiniopTensorDescriptor_t x); +__C __export infiniStatus_t infiniopGetReluWorkspaceSize(infiniopReluDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopRelu(infiniopReluDescriptor_t desc, void 
*workspace, size_t workspace_size, diff --git a/include/infiniop/ops/sigmoid.h b/include/infiniop/ops/sigmoid.h new file mode 100644 index 000000000..4fa0f6604 --- /dev/null +++ b/include/infiniop/ops/sigmoid.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SIGMOID_API_H__ +#define __INFINIOP_SIGMOID_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(infiniopHandle_t handle, + infiniopSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSigmoid(infiniopSigmoidDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/softmax.h b/include/infiniop/ops/softmax.h new file mode 100644 index 000000000..6c8b3c936 --- /dev/null +++ b/include/infiniop/ops/softmax.h @@ -0,0 +1,27 @@ +#ifndef __INFINIOP_SOFTMAX_API_H__ +#define __INFINIOP_SOFTMAX_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSoftmaxDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSoftmaxDescriptor( + infiniopHandle_t handle, + infiniopSoftmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + int axis); + +__C __export infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSoftmax( + infiniopSoftmaxDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc); + +#endif diff --git 
a/include/infiniop/ops/tanh.h b/include/infiniop/ops/tanh.h new file mode 100644 index 000000000..742dba860 --- /dev/null +++ b/include/infiniop/ops/tanh.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_TANH_API_H__ +#define __INFINIOP_TANH_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopTanhDescriptor_t; + +__C __export infiniStatus_t infiniopCreateTanhDescriptor(infiniopHandle_t handle, + infiniopTanhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output, + infiniopTensorDescriptor_t input); + +__C __export infiniStatus_t infiniopGetTanhWorkspaceSize(infiniopTanhDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopTanh(infiniopTanhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream); + +__C __export infiniStatus_t infiniopDestroyTanhDescriptor(infiniopTanhDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/add/moore/add_moore.h b/src/infiniop/ops/add/moore/add_moore.h new file mode 100644 index 000000000..db774c252 --- /dev/null +++ b/src/infiniop/ops/add/moore/add_moore.h @@ -0,0 +1,8 @@ +#ifndef __ADD_MOORE_API_H__ +#define __ADD_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(add, moore) + +#endif // __ADD_MOORE_API_H__ diff --git a/src/infiniop/ops/add/moore/add_moore.mu b/src/infiniop/ops/add/moore/add_moore.mu new file mode 100644 index 000000000..84df6bcb8 --- /dev/null +++ b/src/infiniop/ops/add/moore/add_moore.mu @@ -0,0 +1,66 @@ +#include "add_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "add_moore_kernel.h" + +namespace op::add::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto 
&a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, moore::AddOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, moore::AddOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, moore::AddOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, moore::AddOp, double>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_I32: + return _device_info->calculate<256, moore::AddOp, int32_t>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_I64: + return _device_info->calculate<256, moore::AddOp, int64_t>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add::moore diff --git a/src/infiniop/ops/add/moore/add_moore_kernel.h b/src/infiniop/ops/add/moore/add_moore_kernel.h new file mode 100644 index 000000000..9957e5d03 --- /dev/null +++ b/src/infiniop/ops/add/moore/add_moore_kernel.h @@ -0,0 +1,38 @@ +#ifndef __ADD_MOORE_KERNEL_H__ +#define 
__ADD_MOORE_KERNEL_H__ + +/* + * This file contains the Add operation implementation for the MUSA backend. + * + * It uses the 'op::add::moore' namespace to maintain a consistent code structure + * and interface with the CUDA implementation, ensuring code alignment across different + * hardware platforms. + */ + +namespace op::add::moore { +typedef struct AddOp { +public: + static constexpr size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + if constexpr (std::is_same_v) { + return __hadd2(a, b); + } else if constexpr (std::is_same_v) { + return __hadd(a, b); + } else if constexpr (std::is_same_v) { + // On MUSA platform, convert to float, add, then convert back to avoid ambiguous conversion + // from int (returned by __hadd) to __mt_bfloat16 + float a_f = __bfloat162float(a); + float b_f = __bfloat162float(b); + return __float2bfloat16_rn(a_f + b_f); + } else if constexpr (std::is_same_v) { + // Use __fadd_rn instead of __fadd_rd for moore platform compatibility + return __fadd_rn(a, b); + } else { + return a + b; + } + } +} AddOp; +} // namespace op::add::moore + +#endif // __ADD_MOORE_KERNEL_H__ diff --git a/src/infiniop/ops/add/operator.cc b/src/infiniop/ops/add/operator.cc index 52d19e501..861773fd0 100644 --- a/src/infiniop/ops/add/operator.cc +++ b/src/infiniop/ops/add/operator.cc @@ -5,7 +5,8 @@ #ifdef ENABLE_CPU_API #include "cpu/add_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +// #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/add_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -17,6 +18,9 @@ #ifdef ENABLE_CAMBRICON_API #include "bang/add_bang.h" #endif +#ifdef ENABLE_MOORE_API +#include "moore/add_moore.h" +#endif __C infiniStatus_t infiniopCreateAddDescriptor( infiniopHandle_t handle, @@ -45,6 +49,9 @@ __C infiniStatus_t 
infiniopCreateAddDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif @@ -54,6 +61,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor( #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); #endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -79,6 +89,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); #endif @@ -87,6 +100,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz #endif #ifdef ENABLE_CAMBRICON_API GET(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -121,6 +137,9 @@ __C infiniStatus_t infiniopAdd( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif @@ -130,6 +149,9 @@ __C infiniStatus_t infiniopAdd( #ifdef ENABLE_CAMBRICON_API CALCULATE(INFINI_DEVICE_CAMBRICON, bang); #endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -157,6 +179,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) { #ifdef ENABLE_ILUVATAR_API DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API DELETE(INFINI_DEVICE_METAX, metax); #endif @@ -166,6 +191,9 @@ 
infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) { #ifdef ENABLE_CAMBRICON_API DELETE(INFINI_DEVICE_CAMBRICON, bang); #endif +#ifdef ENABLE_MOORE_API + DELETE(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/conv/moore/conv_moore.h b/src/infiniop/ops/conv/moore/conv_moore.h new file mode 100644 index 000000000..082a8de84 --- /dev/null +++ b/src/infiniop/ops/conv/moore/conv_moore.h @@ -0,0 +1,85 @@ +#ifndef __CONV_MOORE_H__ +#define __CONV_MOORE_H__ + +#include "conv_mudnn.h" + +namespace op::conv::moore { + +// Descriptor class for CONV operations on Moore devices. +// This class acts as a wrapper to select mudnn backend. +// It encapsulates the backend-specific Descriptor implementation and provides +// a unified interface for workspace query and CONV calculation. +class Descriptor final : public InfiniopDescriptor { +public: + // Destructor: deletes the backend-specific descriptor. + ~Descriptor() { + delete reinterpret_cast(_impl); + } + + // Returns the required workspace size for the CONV operation. + size_t workspaceSize() const { + return reinterpret_cast(_impl)->workspaceSize(); + } + + // Static factory method to create a Descriptor instance. + // This method chooses the backend (mudnn) and constructs + // the corresponding implementation internally. + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + const void *pads, + const void *strides, + const void *dilations, + size_t n) { + auto desc = new Descriptor(handle->device, handle->device_id); + + // Backend selection strategy: + // Currently defaulting to MUDNN. + // Can be modified to choose based on environment variables or runtime parameters. 
+ desc->_backend = Backend::MUDNN; + + mudnn::Descriptor *impl; + auto status = mudnn::Descriptor::create(handle, &impl, y_desc, x_desc, w_desc, b_desc, pads, strides, dilations, n); + if (status != INFINI_STATUS_SUCCESS) { + delete desc; + return status; + } + desc->_impl = impl; + + *desc_ptr = desc; + return INFINI_STATUS_SUCCESS; + } + + // Unified CONV calculation interface. + // Calls the corresponding backend's calculate function internally. + infiniStatus_t calculate( + void *workspace, size_t workspace_size, + void *y, + const void *x, + const void *w, + const void *bias, + void *stream) const { + return reinterpret_cast(_impl) + ->calculate(workspace, workspace_size, y, x, w, bias, stream); + } + +private: + // Private constructor: ensures users cannot directly instantiate Descriptor. + // Instances must be created via the static create() factory method. + Descriptor(infiniDevice_t device_type, int device_id) + : InfiniopDescriptor{device_type, device_id}, _impl(nullptr) {} + + // Enum to indicate which backend is being used internally. 
+ enum class Backend { MUDNN }; + + Backend _backend; // Currently selected MUDNN backend + void *_impl; // Pointer to backend-specific descriptor (mudnn::Descriptor*) +}; + +} // namespace op::conv::moore + +#endif // __CONV_MOORE_H__ \ No newline at end of file diff --git a/src/infiniop/ops/conv/moore/conv_mudnn.h b/src/infiniop/ops/conv/moore/conv_mudnn.h new file mode 100644 index 000000000..a9c0fc50f --- /dev/null +++ b/src/infiniop/ops/conv/moore/conv_mudnn.h @@ -0,0 +1,8 @@ +#ifndef __CONV_MUDNN_H__ +#define __CONV_MUDNN_H__ + +#include "../conv.h" + +DESCRIPTOR(mudnn) + +#endif // __CONV_MUDNN_H__ diff --git a/src/infiniop/ops/conv/moore/conv_mudnn.mu b/src/infiniop/ops/conv/moore/conv_mudnn.mu new file mode 100644 index 000000000..ce110382d --- /dev/null +++ b/src/infiniop/ops/conv/moore/conv_mudnn.mu @@ -0,0 +1,268 @@ +#include "../../../devices/moore/moore_common.h" +#include "../../../devices/moore/moore_handle.h" +#include "conv_mudnn.h" + +#include + +namespace op::conv::mudnn { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + const void *pads, + const void *strides, + const void *dilations, + size_t n) { + + auto handle = reinterpret_cast(handle_); + auto dtype = y_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + + auto result = ConvInfo::create(handle_, y_desc, x_desc, w_desc, b_desc, pads, strides, dilations, n); + CHECK_RESULT(result); + + auto info = result.take(); + + *desc_ptr = new Descriptor( + dtype, info, 0, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculate( + const ConvInfo &info, + std::shared_ptr 
&_internal, + void *y, + const void *x, + const void *w, + const void *bias, + void *stream) { + + auto conv_operator = std::make_unique<::musa::dnn::Convolution>(); + conv_operator->SetComputeMode(::musa::dnn::Convolution::ComputeMode::TENSOR); + + // Use muDNN handle management + return _internal->useMudnn((musaStream_t)stream, [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t { + + // 3. Create Tensor + ::musa::dnn::Tensor input_tensor, output_tensor, weight_tensor, bias_tensor; + + if constexpr (std::is_same::value) { + input_tensor.SetType(::musa::dnn::Tensor::Type::HALF); + output_tensor.SetType(::musa::dnn::Tensor::Type::HALF); + weight_tensor.SetType(::musa::dnn::Tensor::Type::HALF); + bias_tensor.SetType(::musa::dnn::Tensor::Type::HALF); + } else if constexpr (std::is_same::value) { + input_tensor.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + output_tensor.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + weight_tensor.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + bias_tensor.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + } else { + input_tensor.SetType(::musa::dnn::Tensor::Type::FLOAT); + output_tensor.SetType(::musa::dnn::Tensor::Type::FLOAT); + weight_tensor.SetType(::musa::dnn::Tensor::Type::FLOAT); + bias_tensor.SetType(::musa::dnn::Tensor::Type::FLOAT); + } + + // 4. Bind Tensor addr + input_tensor.SetAddr(const_cast(x)); + output_tensor.SetAddr(y); + weight_tensor.SetAddr(const_cast(w)); + bias_tensor.SetAddr(const_cast(bias)); +{ + // 5. Config Tensor input_tensor: [N, C, spatial...] 
+ const size_t ndim = info.ndim(); + std::vector x_dims; + x_dims.reserve(ndim + 2); + + x_dims.push_back(static_cast(info.batch())); + x_dims.push_back(static_cast(info.in_channels())); + for (size_t i = 0; i < ndim; ++i) { + x_dims.push_back(static_cast(info.input_dim(i))); + } + + // contiguous stride + std::vector x_stride(x_dims.size()); + x_stride.back() = 1; + for (int i = static_cast(x_dims.size()) - 2; i >= 0; --i) { + x_stride[i] = x_stride[i + 1] * x_dims[i + 1]; + } + + input_tensor.SetNdInfo( + static_cast(x_dims.size()), + x_dims.data(), + x_stride.data() + ); + +} +{ + // 6. Config Tensor weight_tensor: [Cout, Cin, kernel...] + const size_t ndim = info.ndim(); + std::vector w_dims; + w_dims.reserve(ndim + 2); + + w_dims.push_back(static_cast(info.out_channels())); + w_dims.push_back(static_cast(info.in_channels())); // groups=1 + for (size_t i = 0; i < ndim; ++i) { + w_dims.push_back(static_cast(info.kernel_dim(i))); + } + + std::vector w_stride(w_dims.size()); + w_stride.back() = 1; + for (int i = static_cast(w_dims.size()) - 2; i >= 0; --i) { + w_stride[i] = w_stride[i + 1] * w_dims[i + 1]; + } + + weight_tensor.SetNdInfo( + static_cast(w_dims.size()), + w_dims.data(), + w_stride.data() + ); + +} +{ + // 7. Config Tensor output_tensor: [N, Cout, spatial...] + const size_t ndim = info.ndim(); + std::vector y_dims; + y_dims.reserve(ndim + 2); + + y_dims.push_back(static_cast(info.batch())); + y_dims.push_back(static_cast(info.out_channels())); + for (size_t i = 0; i < ndim; ++i) { + y_dims.push_back(static_cast(info.output_dim(i))); + } + + std::vector y_stride(y_dims.size()); + y_stride.back() = 1; + for (int i = static_cast(y_dims.size()) - 2; i >= 0; --i) { + y_stride[i] = y_stride[i + 1] * y_dims[i + 1]; + } + + output_tensor.SetNdInfo( + static_cast(y_dims.size()), + y_dims.data(), + y_stride.data() + ); +} + + // 8. 
Bias tensor (if exists) + if (bias != nullptr) { + std::array b_dims = { + static_cast(info.out_channels()) + }; + std::array b_stride = {1}; + bias_tensor.SetNdInfo(1, b_dims.data(), b_stride.data()); + } + + // 9. Configure convolution descriptor (from ConvInfo) + std::vector pad_dims(info.ndim()); + std::vector stride_dims(info.ndim()); + std::vector dilation_dims(info.ndim()); + + for (size_t i = 0; i < info.ndim(); ++i) { + pad_dims[i] = static_cast(info.pad_info(i)); + stride_dims[i] = static_cast(info.stride_info(i)); + dilation_dims[i] = static_cast(info.dilation_info(i)); + } + + // Current infiniop ConvInfo implies groups == 1 + conv_operator->SetGroups(1); + + // muDNN convolution configuration + conv_operator->SetNdInfo( + static_cast(info.ndim()), + pad_dims.data(), + stride_dims.data(), + dilation_dims.data() + ); + + + // 10. Select algorithm (simple version: always query) + ::musa::dnn::Convolution::Algorithm algo; + conv_operator->GetRecommendForwardAlgorithm( + mudnn_handle, + algo, + output_tensor, + input_tensor, + weight_tensor + ); + + // 11. Workspace memory handler + ::musa::dnn::MemoryMaintainer maintainer = + [](size_t size) -> ::musa::dnn::MemoryHandler { + void* ptr = nullptr; + musaMalloc(&ptr, size); + return ::musa::dnn::MemoryHandler( + ptr, + [](void* p) { if (p) musaFree(p); } + ); + }; + + // 12. Run convolution (no fused activation) + ::musa::dnn::Tensor add_tensor; // unused + ::musa::dnn::Convolution::FusedActivationDesc act; + act.SetMode(::musa::dnn::Convolution::FusedActivationDesc::Mode::IDENTITY); + + conv_operator->RunFusion( + mudnn_handle, + output_tensor, + input_tensor, + weight_tensor, + bias != nullptr ? 
bias_tensor : ::musa::dnn::Tensor(), + add_tensor, + act, + algo, + maintainer + ); + + return INFINI_STATUS_SUCCESS; + }); +} + + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *y, + const void *x, + const void *w, + const void *bias, + void *stream) const { + + + // Check for null pointers + if (!_opaque) { + return INFINI_STATUS_BAD_PARAM; + } + if (!_opaque->internal) { + return INFINI_STATUS_BAD_PARAM; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return mudnn::calculate(_info, _opaque->internal, y, x, w, bias, stream); + case INFINI_DTYPE_F32: + return mudnn::calculate(_info, _opaque->internal, y, x, w, bias, stream); + case INFINI_DTYPE_BF16: + return mudnn::calculate<__mt_bfloat16>(_info, _opaque->internal, y, x, w, bias, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::conv::mudnn \ No newline at end of file diff --git a/src/infiniop/ops/conv/nvidia/conv_nvidia.cu b/src/infiniop/ops/conv/nvidia/conv_nvidia.cu index f4f8d6d0f..5403966ce 100644 --- a/src/infiniop/ops/conv/nvidia/conv_nvidia.cu +++ b/src/infiniop/ops/conv/nvidia/conv_nvidia.cu @@ -213,10 +213,16 @@ private: infiniStatus_t setupAlgorithmWithBias() { int maxAlgoCount = 0; + + // 为海光DCU提供特殊处理 - 避免使用不支持的API CHECK_STATUS(internal->useCudnn( nullptr, [&](cudnnHandle_t handle) { - CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &maxAlgoCount)); + auto result = cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &maxAlgoCount); + if (result != CUDNN_STATUS_SUCCESS) { + // 如果海光DCU不支持此API,使用默认值 + maxAlgoCount = 8; + } return INFINI_STATUS_SUCCESS; })); @@ -227,11 +233,19 @@ private: std::vector perf_results(maxAlgoCount); int algoCounts = 0; + // 为海光DCU提供特殊处理 - 避免使用可能不支持的API CHECK_STATUS(internal->useCudnn( nullptr, [&](cudnnHandle_t handle) { - CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm( + auto result = cudnnFindConvolutionForwardAlgorithm( handle, x_desc, w_desc, conv_desc, y_desc, - 
maxAlgoCount, &algoCounts, perf_results.data())); + maxAlgoCount, &algoCounts, perf_results.data()); + if (result != CUDNN_STATUS_SUCCESS) { + // 如果海光DCU不支持此API,使用默认算法 + algoCounts = 1; + perf_results[0].algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + perf_results[0].status = CUDNN_STATUS_SUCCESS; + perf_results[0].time = 0.0f; + } return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/conv/operator.cc b/src/infiniop/ops/conv/operator.cc index df033f44f..a5da724d2 100644 --- a/src/infiniop/ops/conv/operator.cc +++ b/src/infiniop/ops/conv/operator.cc @@ -5,9 +5,12 @@ #ifdef ENABLE_CPU_API #include "cpu/conv_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/conv_nvidia.cuh" #endif +#ifdef ENABLE_MOORE_API +#include "moore/conv_moore.h" +#endif __C __export infiniStatus_t infiniopCreateConvDescriptor(infiniopHandle_t handle, infiniopConvDescriptor_t *desc_ptr, @@ -43,6 +46,13 @@ __C __export infiniStatus_t infiniopCreateConvDescriptor(infiniopHandle_t handle CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -70,6 +80,12 @@ infiniopGetConvWorkspaceSize( #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -106,6 +122,12 @@ __C infiniStatus_t infiniopConv( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif default: return 
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -130,6 +152,13 @@ infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc) { #ifdef ENABLE_ILUVATAR_API DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + DELETE(INFINI_DEVICE_MOORE, moore); +#endif + default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/gelu/cpu/gelu_cpu.cc b/src/infiniop/ops/gelu/cpu/gelu_cpu.cc new file mode 100644 index 000000000..a057ca4bc --- /dev/null +++ b/src/infiniop/ops/gelu/cpu/gelu_cpu.cc @@ -0,0 +1,52 @@ +#include "gelu_cpu.h" + +namespace op::gelu::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return 
INFINI_STATUS_SUCCESS; +} +} // namespace op::gelu::cpu diff --git a/src/infiniop/ops/gelu/cpu/gelu_cpu.h b/src/infiniop/ops/gelu/cpu/gelu_cpu.h new file mode 100644 index 000000000..5a2d3fa8b --- /dev/null +++ b/src/infiniop/ops/gelu/cpu/gelu_cpu.h @@ -0,0 +1,23 @@ +#ifndef __GELU_CPU_H__ +#define __GELU_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(gelu, cpu) + +#include + +namespace op::gelu::cpu { +typedef struct GeluOp { +public: + static constexpr size_t num_inputs = 1; + + template + T operator()(const T &x) const { + return static_cast(0.5 * x * (1 + erf(x / sqrt(2.0f)))); + } +} GeluOp; + +} // namespace op::gelu::cpu + +#endif // __GELU_CPU_H__ diff --git a/src/infiniop/ops/gelu/cuda/kernel.cuh b/src/infiniop/ops/gelu/cuda/kernel.cuh new file mode 100644 index 000000000..31fa2b2be --- /dev/null +++ b/src/infiniop/ops/gelu/cuda/kernel.cuh @@ -0,0 +1,35 @@ +#ifndef __GELU_CUDA_H__ +#define __GELU_CUDA_H__ + +#include + +namespace op::gelu::cuda { + +typedef struct GeluOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + + if constexpr (std::is_same_v) { + float x_f = __bfloat162float(x); + float result = 0.5f * x_f * (1.0f + erff(x_f / sqrtf(2.0f))); + + return __float2bfloat16(result); + } else if constexpr (std::is_same_v) { + float x_f = __half2float(x); + float result = 0.5f * x_f * (1.0f + erff(x_f / sqrtf(2.0f))); + + return __float2half(result); + } else if constexpr (std::is_same_v) { + + return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f))); + } else { + return 0.5 * x * (1 + erf(x / sqrt(2.0))); + } + } +} GeluOp; + +} // namespace op::gelu::cuda + +#endif // __GELU_CUDA_H__ diff --git a/src/infiniop/ops/gelu/metax/gelu_meta.maca b/src/infiniop/ops/gelu/metax/gelu_meta.maca new file mode 100644 index 000000000..3a311530a --- /dev/null +++ b/src/infiniop/ops/gelu/metax/gelu_meta.maca @@ -0,0 +1,60 @@ +#include "gelu_metax.h" + +#include 
"../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::gelu::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create METAX elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::GeluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::GeluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::GeluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::GeluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::gelu::metax diff --git a/src/infiniop/ops/gelu/metax/gelu_metax.h b/src/infiniop/ops/gelu/metax/gelu_metax.h new file mode 100644 index 000000000..9385b7a27 --- /dev/null +++ b/src/infiniop/ops/gelu/metax/gelu_metax.h @@ -0,0 +1,8 @@ +#ifndef 
__GELU_METAX_API_H__ +#define __GELU_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(gelu, metax) + +#endif // __GELU_METAX_API_H__ diff --git a/src/infiniop/ops/gelu/moore/gelu_moore.h b/src/infiniop/ops/gelu/moore/gelu_moore.h new file mode 100644 index 000000000..341bfd1f5 --- /dev/null +++ b/src/infiniop/ops/gelu/moore/gelu_moore.h @@ -0,0 +1,8 @@ +#ifndef __GELU_MOORE_API_H__ +#define __GELU_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(gelu, moore) + +#endif // __GELU_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/gelu/moore/gelu_moore.mu b/src/infiniop/ops/gelu/moore/gelu_moore.mu new file mode 100644 index 000000000..6e53be253 --- /dev/null +++ b/src/infiniop/ops/gelu/moore/gelu_moore.mu @@ -0,0 +1,60 @@ +#include "gelu_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "gelu_moore_kernel.h" + +namespace op::gelu::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case 
INFINI_DTYPE_F16: + return _device_info->calculate<256, moore::GeluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, moore::GeluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, moore::GeluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, moore::GeluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::gelu::moore \ No newline at end of file diff --git a/src/infiniop/ops/gelu/moore/gelu_moore_kernel.h b/src/infiniop/ops/gelu/moore/gelu_moore_kernel.h new file mode 100644 index 000000000..cfdc62f17 --- /dev/null +++ b/src/infiniop/ops/gelu/moore/gelu_moore_kernel.h @@ -0,0 +1,43 @@ +#ifndef __GELU_MOORE_KERNEL_H__ +#define __GELU_MOORE_KERNEL_H__ + +/* + * This file contains the GELU operation implementation for the MUSA backend. + * + * It defines the 'op::gelu::moore' namespace, mirroring the structure and interface + * of the CUDA implementation, ensuring code alignment across different + * hardware platforms. 
+ */ + +#include + +namespace op::gelu::moore { + +typedef struct GeluOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + + if constexpr (std::is_same_v) { + float x_f = __bfloat162float(x); + float result = 0.5f * x_f * (1.0f + erff(x_f / sqrtf(2.0f))); + + return __float2bfloat16(result); + } else if constexpr (std::is_same_v) { + float x_f = __half2float(x); + float result = 0.5f * x_f * (1.0f + erff(x_f / sqrtf(2.0f))); + + return __float2half(result); + } else if constexpr (std::is_same_v) { + + return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f))); + } else { + return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0))); + } + } +} GeluOp; + +} // namespace op::gelu::moore + +#endif // __GELU_MOORE_KERNEL_H__ \ No newline at end of file diff --git a/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cu b/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cu new file mode 100644 index 000000000..4d42cf2df --- /dev/null +++ b/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cu @@ -0,0 +1,59 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include "gelu_nvidia.cuh" + +namespace op::gelu::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void 
*workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::GeluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::GeluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::GeluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::GeluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::gelu::nvidia diff --git a/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cuh b/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cuh new file mode 100644 index 000000000..72dbbd4f0 --- /dev/null +++ b/src/infiniop/ops/gelu/nvidia/gelu_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __GELU_CUDA_API_H__ +#define __GELU_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(gelu, nvidia) + +#endif // __GELU_CUDA_API_H__ diff --git a/src/infiniop/ops/gelu/operator.cc b/src/infiniop/ops/gelu/operator.cc new file mode 100644 index 000000000..96daae105 --- /dev/null +++ b/src/infiniop/ops/gelu/operator.cc @@ -0,0 +1,182 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gelu.h" + +#ifdef ENABLE_CPU_API +#include "cpu/gelu_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#include "nvidia/gelu_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/gelu_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/gelu_moore.h" +#endif + +__C infiniStatus_t infiniopCreateGeluDescriptor( + 
infiniopHandle_t handle, + infiniopGeluDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gelu::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + {input_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGelu( + infiniopGeluDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream) { + +#define 
CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, output, {input}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + DELETE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + DELETE(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/gelutanh/cuda/kernel.cuh b/src/infiniop/ops/gelutanh/cuda/kernel.cuh new file mode 100644 index 000000000..a45cb89ba --- /dev/null +++ b/src/infiniop/ops/gelutanh/cuda/kernel.cuh @@ -0,0 +1,58 @@ +#ifndef __GELUTANH_CUDA_H__ +#define __GELUTANH_CUDA_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" +#include +#include +#include + 
+namespace op::gelutanh::cuda { + +typedef struct GeluTanhOp { +public: + static constexpr size_t num_inputs = 1; + + // GELU-Tanh constants + // static constexpr float alpha = std::sqrt(2.0 / M_PI); + // static constexpr float beta = 0.044715f; + static constexpr float alpha = 0.7978845608f; // sqrt(2/pi) + static constexpr float beta = 0.044715f; + // f32 tanh helper + __device__ __forceinline__ float tanh_f32_func(float x) const { + return tanhf(x); + } + + template + __device__ __forceinline__ T operator()(const T &x) const { + if constexpr (std::is_same_v) { + // half2 -> float2 + float2 vf = __half22float2(x); + float inner_x0 = alpha * (vf.x + beta * vf.x * vf.x * vf.x); + float inner_x1 = alpha * (vf.y + beta * vf.y * vf.y * vf.y); + float2 vr = make_float2(tanh_f32_func(inner_x0) * 0.5f + 0.5f, + tanh_f32_func(inner_x1) * 0.5f + 0.5f); + return __hmul2(x, __float22half2_rn(vr)); // y = x * 0.5 * (1 + tanh(...)) + } else if constexpr (std::is_same_v) { + float xf = __half2float(x); + float inner = alpha * (xf + beta * xf * xf * xf); + float yf = xf * 0.5f * (1.0f + tanh_f32_func(inner)); + return __float2half_rn(yf); + } else if constexpr (std::is_same_v) { + float xf = __bfloat162float(x); + float inner = alpha * (xf + beta * xf * xf * xf); + float yf = xf * 0.5f * (1.0f + tanh_f32_func(inner)); + return __float2bfloat16(yf); + } else if constexpr (std::is_same_v) { + float inner = alpha * (x + beta * x * x * x); + return x * 0.5f * (1.0f + tanh_f32_func(inner)); + } else { // double + double inner = alpha * (x + beta * x * x * x); + return x * 0.5 * (1.0 + std::tanh(inner)); + } + } + +} GeluTanhOp; + +} // namespace op::gelutanh::cuda + +#endif // __GELUTANH_CUDA_H__ diff --git a/src/infiniop/ops/gelutanh/moore/gelutanh_moore.h b/src/infiniop/ops/gelutanh/moore/gelutanh_moore.h new file mode 100644 index 000000000..d129bb602 --- /dev/null +++ b/src/infiniop/ops/gelutanh/moore/gelutanh_moore.h @@ -0,0 +1,8 @@ +#ifndef __GELUTANH_MOORE_API_H__ +#define 
__GELUTANH_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(gelutanh, moore) + +#endif // __GELUTANH_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/gelutanh/moore/gelutanh_moore.mu b/src/infiniop/ops/gelutanh/moore/gelutanh_moore.mu new file mode 100644 index 000000000..32fa3248b --- /dev/null +++ b/src/infiniop/ops/gelutanh/moore/gelutanh_moore.mu @@ -0,0 +1,60 @@ +#include "gelutanh_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "gelutanh_moore_kernel.h" + +namespace op::gelutanh::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, moore::GeluTanhOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, moore::GeluTanhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, moore::GeluTanhOp, float>(_info, workspace, output, inputs, stream); + case 
INFINI_DTYPE_F64: + return _device_info->calculate<256, moore::GeluTanhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::gelutanh::moore \ No newline at end of file diff --git a/src/infiniop/ops/gelutanh/moore/gelutanh_moore_kernel.h b/src/infiniop/ops/gelutanh/moore/gelutanh_moore_kernel.h new file mode 100644 index 000000000..61a896ee4 --- /dev/null +++ b/src/infiniop/ops/gelutanh/moore/gelutanh_moore_kernel.h @@ -0,0 +1,64 @@ +#ifndef __GELUTANH_MOORE_KERNEL_H__ +#define __GELUTANH_MOORE_KERNEL_H__ + +/* + * This file contains the GELU-Tanh operation implementation for the MUSA backend. + * + * It defines the 'op::gelutanh::moore' namespace, mirroring the structure and interface + * of the CUDA implementation, ensuring code alignment across different + * hardware platforms. + */ + +#include + +namespace op::gelutanh::moore { + +typedef struct GeluTanhOp { +public: + static constexpr size_t num_inputs = 1; + + // GELU-Tanh constants + // static constexpr float alpha = std::sqrt(2.0 / M_PI); + // static constexpr float beta = 0.044715f; + static constexpr float alpha = 0.7978845608f; // sqrt(2/pi) + static constexpr float beta = 0.044715f; + + // f32 tanh helper + __device__ __forceinline__ float tanh_f32_func(float x) const { + return tanhf(x); + } + + template + __device__ __forceinline__ T operator()(const T &x) const { + if constexpr (std::is_same_v) { + // half2 -> float2 + float2 vf = __half22float2(x); + float inner_x0 = alpha * (vf.x + beta * vf.x * vf.x * vf.x); + float inner_x1 = alpha * (vf.y + beta * vf.y * vf.y * vf.y); + float2 vr = make_float2(tanh_f32_func(inner_x0) * 0.5f + 0.5f, + tanh_f32_func(inner_x1) * 0.5f + 0.5f); + return __hmul2(x, __float22half2_rn(vr)); // y = x * 0.5 * (1 + tanh(...)) + } else if constexpr (std::is_same_v) { + float xf = __half2float(x); + float inner = alpha * (xf + beta * xf * xf * xf); + 
float yf = xf * 0.5f * (1.0f + tanh_f32_func(inner)); + return __float2half_rn(yf); + } else if constexpr (std::is_same_v) { + float xf = __bfloat162float(x); + float inner = alpha * (xf + beta * xf * xf * xf); + float yf = xf * 0.5f * (1.0f + tanh_f32_func(inner)); + return __float2bfloat16(yf); + } else if constexpr (std::is_same_v) { + float inner = alpha * (x + beta * x * x * x); + return x * 0.5f * (1.0f + tanh_f32_func(inner)); + } else { // double + double inner = alpha * (x + beta * x * x * x); + return x * 0.5 * (1.0 + std::tanh(inner)); + } + } + +} GeluTanhOp; + +} // namespace op::gelutanh::moore + +#endif // __GELUTANH_MOORE_KERNEL_H__ \ No newline at end of file diff --git a/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cu b/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cu new file mode 100644 index 000000000..10d8dbeab --- /dev/null +++ b/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cu @@ -0,0 +1,70 @@ +#include "../cuda/kernel.cuh" +#include "gelutanh_nvidia.cuh" + +namespace op::gelutanh::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, + INFINI_DTYPE_F16, + INFINI_DTYPE_F32, + INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + 
switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::GeluTanhOp, half>( + _info, workspace, output, inputs, stream); + + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::GeluTanhOp, __nv_bfloat16>( + _info, workspace, output, inputs, stream); + + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::GeluTanhOp, float>( + _info, workspace, output, inputs, stream); + + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::GeluTanhOp, double>( + _info, workspace, output, inputs, stream); + + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gelutanh::nvidia diff --git a/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cuh b/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cuh new file mode 100644 index 000000000..3155a7af1 --- /dev/null +++ b/src/infiniop/ops/gelutanh/nvidia/gelutanh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __GELUTANH_CUDA_API_H__ +#define __GELUTANH_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(gelutanh, nvidia) + +#endif // __GELUTANH_CUDA_API_H__ diff --git a/src/infiniop/ops/gelutanh/operator.cc b/src/infiniop/ops/gelutanh/operator.cc new file mode 100644 index 000000000..3255e0200 --- /dev/null +++ b/src/infiniop/ops/gelutanh/operator.cc @@ -0,0 +1,152 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gelutanh.h" + +// #ifdef ENABLE_CPU_API +// #include "cpu/gelutanh_cpu.h" +// #endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_HYGON_API) +#include "nvidia/gelutanh_nvidia.cuh" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/gelutanh_moore.h" +#endif + +__C infiniStatus_t infiniopCreateGeluTanhDescriptor( + infiniopHandle_t handle, + infiniopGeluTanhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return 
op::gelutanh::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + + switch (handle->device) { + +// #ifdef ENABLE_CPU_API +// CREATE(INFINI_DEVICE_CPU, cpu); +// #endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +// #ifdef ENABLE_QY_API +// CREATE(INFINI_DEVICE_QY, nvidia); +// #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetGeluTanhWorkspaceSize( + infiniopGeluTanhDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +// #ifdef ENABLE_CPU_API +// GET(INFINI_DEVICE_CPU, cpu) +// #endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +// #ifdef ENABLE_QY_API +// GET(INFINI_DEVICE_QY, nvidia) +// #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__C infiniStatus_t infiniopGeluTanh( + infiniopGeluTanhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream) + + switch (desc->device_type) { + +// #ifdef ENABLE_CPU_API +// CALCULATE(INFINI_DEVICE_CPU, cpu); +// #endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +// #ifdef ENABLE_QY_API +// CALCULATE(INFINI_DEVICE_QY, nvidia); +// #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + 
CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyGeluTanhDescriptor( + infiniopGeluTanhDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +// #ifdef ENABLE_CPU_API +// DELETE(INFINI_DEVICE_CPU, cpu); +// #endif +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +// #ifdef ENABLE_QY_API +// DELETE(INFINI_DEVICE_QY, nvidia); +// #endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_MOORE_API + DELETE(INFINI_DEVICE_MOORE, moore); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc new file mode 100644 index 000000000..58a0030e8 --- /dev/null +++ b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc @@ -0,0 +1,112 @@ +#include "layer_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../reduce/cpu/reduce.h" +#include "../info.h" + +namespace op::layer_norm::cpu { + +template +infiniStatus_t calculate_layer_norm( + const LayerNormInfo &info, + Tdata *output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias) { + +#pragma omp parallel for + for (int b = 0; b < (int)(info.input_shape[0] * info.input_shape[1]); b++) { + int b0 = b / (int)info.input_shape[1], b1 = b % (int)info.input_shape[1]; + auto output_ptr = output + b0 * info.output_strides[0] + b1 * info.output_strides[1]; + auto input_ptr = input + b0 * info.input_strides[0] + b1 * info.input_strides[1]; + auto standard_ptr = input_standardization + b0 * info.input_standardization_strides[0] + b1 * info.input_standardization_strides[1]; + auto std_ptr = 
input_std_deviation + b0 * info.input_std_deviation_strides[0] + b1 * info.input_std_deviation_strides[1]; + float mean = op::common_cpu::reduce_op::sum( + input_ptr, + info.normalized_size, + info.input_strides[2]) + / info.input_shape[2]; + float sum_sq = op::common_cpu::reduce_op::sumSquared( + input_ptr, + info.normalized_size, + info.input_strides[2]); + float var = sum_sq / (info.normalized_size) - mean * mean; + float std_deviation = std::sqrt(var + info.eps); + *std_ptr = utils::cast(std_deviation); + + for (size_t d = 0; d < info.normalized_size; d++) { + float x_standard = (utils::cast(*(input_ptr + d * info.input_strides[2])) - mean) / std_deviation; + *(standard_ptr + d * info.input_standardization_strides[2]) = utils::cast(x_standard); + *(output_ptr + d * info.output_strides[2]) = utils::cast( + x_standard * utils::cast(*(weight + d * info.weight_strides[0])) + (info.bias_exist ? utils::cast(*(bias + d * info.bias_strides[0])) : float(0))); + } + } + + return INFINI_STATUS_SUCCESS; +} + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps) { + auto handle = reinterpret_cast(handle_); + + // --------------------- start: check data type and calculate workspace size ---------------------- + auto dtype = input_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + size_t WorkSpaceSize = 0; + + auto result = LayerNormInfo::createLayerNormInfo( + output_desc, + input_standardization_desc, + input_std_deviation_desc, + input_desc, + weight_desc, + bias_desc, + eps); + CHECK_RESULT(result); + const LayerNormInfo &info = result.take(); + + *desc_ptr = new 
Descriptor( + dtype, std::move(info), WorkSpaceSize, + nullptr, + handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_LAYER_NORM(TDATA) \ + CHECK_STATUS(calculate_layer_norm(_info, \ + (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias)) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream) const { + if (_info.dtype == INFINI_DTYPE_F16) { + CALCULATE_LAYER_NORM(fp16_t); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + CALCULATE_LAYER_NORM(bf16_t); + } else if (_info.dtype == INFINI_DTYPE_F32) { + CALCULATE_LAYER_NORM(float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::layer_norm::cpu \ No newline at end of file diff --git a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h new file mode 100644 index 000000000..51d56bbf7 --- /dev/null +++ b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h @@ -0,0 +1,8 @@ +#ifndef __LAYER_NORM_CPU_H__ +#define __LAYER_NORM_CPU_H__ + +#include "../layer_norm.h" + +DESCRIPTOR(cpu) + +#endif // __LAYER_NORM_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/layer_norm/cuda/kernel.cuh b/src/infiniop/ops/layer_norm/cuda/kernel.cuh new file mode 100644 index 000000000..5c9862ee5 --- /dev/null +++ b/src/infiniop/ops/layer_norm/cuda/kernel.cuh @@ -0,0 +1,161 @@ +#ifndef __LAYER_NORM_KERNEL_CUH__ +#define __LAYER_NORM_KERNEL_CUH__ +#include + +template +__device__ void layerNormKernel( + Tdata *output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, + float eps, + size_t normalized_size, + const ptrdiff_t 
*output_strides, + const ptrdiff_t *input_standardization_strides, + const ptrdiff_t *input_std_deviation_strides, + const ptrdiff_t *input_strides, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + bool bias_exist) { + size_t b0 = blockIdx.x, b1 = blockIdx.y; + + auto output_ptr = output + b0 * output_strides[0] + b1 * output_strides[1]; + auto input_ptr = input + b0 * input_strides[0] + b1 * input_strides[1]; + auto standard_ptr = input_standardization + b0 * input_standardization_strides[0] + b1 * input_standardization_strides[1]; + auto std_ptr = input_std_deviation + b0 * input_std_deviation_strides[0] + b1 * input_std_deviation_strides[1]; + Tcompute mean = op::common_cuda::reduce_op::sum( + input_ptr, + normalized_size) + / normalized_size; + Tcompute sum_squared = op::common_cuda::reduce_op::sumSquared( + input_ptr, + normalized_size); + + Tcompute var = sum_squared / normalized_size - mean * mean; + Tcompute std_deviation = sqrtf(var + Tcompute(eps)); + *std_ptr = std_deviation; + + for (size_t d = 0; d < normalized_size; d++) { + Tcompute x_standard = (Tcompute(input_ptr[d]) - mean) / std_deviation; + standard_ptr[d] = x_standard; + output_ptr[d] = x_standard * Tcompute(*(weight + d * weight_stride)) + (bias_exist ? 
Tcompute(*(bias + d * bias_stride)) : Tcompute(0)); + } +} + +template +__device__ void blockLayernormKernel(T *output, T const *input, T const *weight, T const *bias, float eps, int dimsize, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + int ndim, + bool bias_exist) { + // Only supports normalization over the last axis (axis=-1) + int ind_i = 0; // input id + int ind_o = 0; // output id + int tid = blockIdx.x; + for (int j = ndim - 2; j >= 0; j--) { + ind_i += (tid % (int)shape[j]) * (int)input_strides[j]; + ind_o += (tid % (int)shape[j]) * (int)output_strides[j]; + tid = tid / (int)shape[j]; + } + + float mu_partial = op::common_cuda::reduce_op::sum( + input + ind_i, + dimsize) + / dimsize; + __shared__ float mu; + if (threadIdx.x == 0) { + mu = mu_partial; + } // threadIdx.x == 0 holds the block-wide (global) sum + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + __syncthreads(); + float sigma2_partial = 0.0f; + for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) { + sigma2_partial += (static_cast(input[ind_i + id]) - mu) * (static_cast(input[ind_i + id]) - mu); + } + + // __shared__ float sigma2; + // float sigma2_block = BlockReduce(temp_storage).Reduce(sigma2_partial, cub::Sum()); + + __shared__ float sigma2; + float sigma2_block = BlockReduce(temp_storage).Sum(sigma2_partial); + + if (threadIdx.x == 0) { + float sigma_tmp = sqrt(sigma2_block * __fdividef(1.0F, dimsize) + eps); + sigma2 = __fdividef(1.0F, sigma_tmp); + } + __syncthreads(); + for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) { + output[ind_o + id] = static_cast(static_cast(weight[id * weight_stride]) * (static_cast(input[ind_i + id]) - mu) * sigma2 + (bias_exist ? static_cast(bias[id * bias_stride]) : 0.0f)); + } +} +template +struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } +}; + +template