From dc6e3ca605842bf3c2b74967d8c009e62be637aa Mon Sep 17 00:00:00 2001 From: Sxy-17 Date: Wed, 19 Nov 2025 18:03:16 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=B8=8B=E6=B5=B7=E5=85=89DCU=E4=B8=8D=E8=83=BD=E8=B7=91conv2d?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/conv/nvidia/conv_nvidia.cu | 20 +++++++++++++++++--- src/infiniop/ops/conv/operator.cc | 14 +++++++++++++- test/infiniop/conv.py | 8 +++++++- third_party/spdlog | 1 + xmake/hygon.lua | 1 + 5 files changed, 39 insertions(+), 5 deletions(-) create mode 160000 third_party/spdlog diff --git a/src/infiniop/ops/conv/nvidia/conv_nvidia.cu b/src/infiniop/ops/conv/nvidia/conv_nvidia.cu index f4f8d6d0f..5403966ce 100644 --- a/src/infiniop/ops/conv/nvidia/conv_nvidia.cu +++ b/src/infiniop/ops/conv/nvidia/conv_nvidia.cu @@ -213,10 +213,16 @@ private: infiniStatus_t setupAlgorithmWithBias() { int maxAlgoCount = 0; + + // 为海光DCU提供特殊处理 - 避免使用不支持的API CHECK_STATUS(internal->useCudnn( nullptr, [&](cudnnHandle_t handle) { - CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &maxAlgoCount)); + auto result = cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &maxAlgoCount); + if (result != CUDNN_STATUS_SUCCESS) { + // 如果海光DCU不支持此API,使用默认值 + maxAlgoCount = 8; + } return INFINI_STATUS_SUCCESS; })); @@ -227,11 +233,19 @@ private: std::vector perf_results(maxAlgoCount); int algoCounts = 0; + // 为海光DCU提供特殊处理 - 避免使用可能不支持的API CHECK_STATUS(internal->useCudnn( nullptr, [&](cudnnHandle_t handle) { - CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm( + auto result = cudnnFindConvolutionForwardAlgorithm( handle, x_desc, w_desc, conv_desc, y_desc, - maxAlgoCount, &algoCounts, perf_results.data())); + maxAlgoCount, &algoCounts, perf_results.data()); + if (result != CUDNN_STATUS_SUCCESS) { + // 如果海光DCU不支持此API,使用默认算法 + algoCounts = 1; + perf_results[0].algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + perf_results[0].status = CUDNN_STATUS_SUCCESS; + perf_results[0].time = 0.0f; + } return INFINI_STATUS_SUCCESS; })); diff --git a/src/infiniop/ops/conv/operator.cc b/src/infiniop/ops/conv/operator.cc index df033f44f..5732dee73 100644 --- a/src/infiniop/ops/conv/operator.cc +++ b/src/infiniop/ops/conv/operator.cc @@ -5,7 +5,7 @@ #ifdef ENABLE_CPU_API #include "cpu/conv_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/conv_nvidia.cuh" #endif @@ -43,6 +43,9 @@ __C __export infiniStatus_t infiniopCreateConvDescriptor(infiniopHandle_t handle CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -70,6 +73,9 @@ infiniopGetConvWorkspaceSize( #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -106,6 +112,9 @@ __C infiniStatus_t infiniopConv( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -130,6 +139,9 @@ infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc) { #ifdef ENABLE_ILUVATAR_API DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infiniop/conv.py b/test/infiniop/conv.py index 6cb99da9f..02a8db253 100644 --- a/test/infiniop/conv.py +++ b/test/infiniop/conv.py @@ -264,6 +264,12 @@ def lib_conv(): NUM_PRERUN = args.num_prerun NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): - test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + # 海光DCU不支持bfloat16,只测试F16和F32 + tensor_dtypes = _TENSOR_DTYPES + if InfiniDeviceNames[device] == "Hygon": + tensor_dtypes = [InfiniDtype.F16, InfiniDtype.F32] # 跳过BF16 + print(f"Testing on Hygon DCU, skipping BF16 (unsupported)") + + test_operator(device, test, _TEST_CASES, tensor_dtypes) print("\033[92mTest passed!\033[0m") diff --git a/third_party/spdlog b/third_party/spdlog new file mode 160000 index 000000000..f1d748e5e --- /dev/null +++ b/third_party/spdlog @@ -0,0 +1 @@ +Subproject commit f1d748e5e3edfa4b1778edea003bac94781bc7b7 diff --git a/xmake/hygon.lua b/xmake/hygon.lua index ed4b91f0e..4c36731c1 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -74,6 +74,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/rearrange/nvidia/*.cu") add_files("../src/infiniop/ops/rms_norm/nvidia/*.cu") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") + add_files("../src/infiniop/ops/conv/nvidia/*.cu") if has_config("ninetoothed") then add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) From ec04cde5d90d66a974c548d872cd01bfd8633223 Mon Sep 17 00:00:00 2001 From: Sxy-17 Date: Wed, 3 Dec 2025 19:57:19 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E6=B7=BB=E5=8A=A0add=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/add/operator.cc | 15 ++++++++++++++- test/infiniop/rope.py | 13 ++++++++++++- xmake/hygon.lua | 1 + 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/infiniop/ops/add/operator.cc b/src/infiniop/ops/add/operator.cc index 52d19e501..02d93bd17 100644 --- a/src/infiniop/ops/add/operator.cc +++ b/src/infiniop/ops/add/operator.cc @@ -5,7 +5,8 @@ #ifdef ENABLE_CPU_API #include "cpu/add_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +// #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/add_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -45,6 +46,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif @@ -79,6 +83,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); #endif @@ -121,6 +128,9 @@ __C infiniStatus_t infiniopAdd( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif @@ -157,6 +167,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) { #ifdef ENABLE_ILUVATAR_API DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#ifdef ENABLE_HYGON_API + DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif #ifdef ENABLE_METAX_API DELETE(INFINI_DEVICE_METAX, metax); #endif diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py index 040f386c7..a6c48a9e2 100644 --- a/test/infiniop/rope.py +++ b/test/infiniop/rope.py @@ -82,6 +82,7 @@ class Algorithm(Enum): def rotary_embedding(ans, t, sin, cos, device, algo): def _torch_rope(sin, cos, t1, t2): + # PyTorch的标准RoPE实现 cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2] sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2] if device == InfiniDeviceEnum.CPU: @@ -101,6 +102,7 @@ def _torch_rope(sin, cos, t1, t2): dt = t.dtype assert dh % 2 == 0, "Embedding dimension must be even." + # 根据不同算法(GPT-J/GPT-NeoX)处理输入 if algo == Algorithm.GPT_J: t_even = t[..., 0::2] # [seq_len, n_head, dh // 2] t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2] @@ -109,7 +111,7 @@ def _torch_rope(sin, cos, t1, t2): ans[..., 0::2] = t_out_even.to(dt) ans[..., 1::2] = t_out_odd.to(dt) - else: + else: # GPT_NEOX half_dim = dh // 2 t_first = t[..., :half_dim] t_second = t[..., half_dim:] @@ -141,6 +143,7 @@ def test( dtype=torch.float32, sync=None, ): + # 创建测试tensor x = TestTensor(shape, x_strides, dtype, device) if inplace == Inplace.INPLACE_X: if x_strides != y_strides: @@ -153,11 +156,14 @@ def test( f"Testing Rotary Positional Embedding on {InfiniDeviceNames[device]} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{InfiniDtypeNames[dtype]} inplace:{inplace} algo:{algo}" ) theta = 1e5 + + # 生成sin/cos表 pos = TestTensor.from_torch(torch.arange(0, x.shape[0]), InfiniDtype.I32, device) sin_table, cos_table = sin_cos_table( pos.torch_tensor(), x.shape[2], x.device, theta, dtype ) + # 运行 baseline (PyTorch) rotary_embedding( y.torch_tensor(), x.torch_tensor(), @@ -167,6 +173,7 @@ def test( algo, ) + # 创建InfiniCore算子descriptor descriptor = infiniopOperatorDescriptor_t() if sync is not None: @@ -189,6 +196,7 @@ def test( for tensor in [y, x, pos, sin_table, cos_table]: tensor.destroy_desc() + # 获取workspace大小并分配 workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetRoPEWorkspaceSize( @@ -197,6 +205,7 @@ def test( ) workspace = TestWorkspace(workspace_size.value, x.device) + # 定义InfiniCore算子执行函数 def lib_rope(): check_error( LIBINFINIOP.infiniopRoPE( @@ -212,6 +221,7 @@ def lib_rope(): ) ) + # 执行InfiniCore算子 lib_rope() if sync is not None: @@ -222,6 +232,7 @@ def lib_rope(): debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + # 性能测试 (可选) if PROFILE: profile_operation( "PyTorch", diff --git a/xmake/hygon.lua b/xmake/hygon.lua index 4c36731c1..c49960a70 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -75,6 +75,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/rms_norm/nvidia/*.cu") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") add_files("../src/infiniop/ops/conv/nvidia/*.cu") + add_files("../src/infiniop/ops/add/nvidia/*.cu") if has_config("ninetoothed") then add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) From c6ab0279c75ecb04b7ac530e8c53c44ec0b2bbde Mon Sep 17 00:00:00 2001 From: Sxy-17 Date: Mon, 8 Dec 2025 21:34:13 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E6=B7=BB=E5=8A=A0relu=E7=AE=97=E5=AD=90?= =?UTF-8?q?=EF=BC=8C=E5=B7=B2=E5=8F=AF=E8=B7=91test=EF=BC=9B=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0layer=5Fnorm=E7=AE=97=E5=AD=90=EF=BC=8C=E6=9A=82?= =?UTF-8?q?=E6=9C=89=E6=9C=AA=E5=A4=84=E7=90=86=E7=9A=84=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infiniop.h | 1 + include/infiniop/ops/layer_norm.h | 34 ++ .../ops/layer_norm/cpu/layer_norm_cpu.cc | 112 +++++++ .../ops/layer_norm/cpu/layer_norm_cpu.h | 8 + src/infiniop/ops/layer_norm/cuda/kernel.cuh | 157 +++++++++ src/infiniop/ops/layer_norm/info.h | 82 +++++ src/infiniop/ops/layer_norm/layer_norm.h | 53 ++++ .../layer_norm/nvidia/layer_norm_nvidia.cu | 264 ++++++++++++++++ .../layer_norm/nvidia/layer_norm_nvidia.cuh | 7 + src/infiniop/ops/layer_norm/operator.cc | 172 ++++++++++ src/infiniop/ops/relu/cuda/kernel.cuh | 35 ++ src/infiniop/ops/relu/nvidia/relu_nvidia.cu | 23 +- src/infiniop/ops/relu/nvidia/relu_nvidia.cuh | 4 +- src/infiniop/ops/relu/operator.cc | 60 +++- src/infiniop/reduce/cuda/reduce.cuh | 1 + test/infiniop/layer_norm.py | 298 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 36 +++ xmake/hygon.lua | 2 + 18 files changed, 1330 insertions(+), 19 deletions(-) create mode 100644 include/infiniop/ops/layer_norm.h create mode 100644 src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc create mode 100644 src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h create mode 100644 src/infiniop/ops/layer_norm/cuda/kernel.cuh create mode 100644 src/infiniop/ops/layer_norm/info.h create mode 100644 src/infiniop/ops/layer_norm/layer_norm.h create mode 100644 src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu create mode 100644 src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cuh create mode 100644 src/infiniop/ops/layer_norm/operator.cc create mode 100644 src/infiniop/ops/relu/cuda/kernel.cuh create mode 100644 test/infiniop/layer_norm.py diff --git a/include/infiniop.h b/include/infiniop.h index f0d75abc9..c54986b4b 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -10,6 +10,7 @@ #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" #include "infiniop/ops/gemm.h" +#include "infiniop/ops/layer_norm.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" diff --git a/include/infiniop/ops/layer_norm.h b/include/infiniop/ops/layer_norm.h new file mode 100644 index 000000000..5f852a9db --- /dev/null +++ b/include/infiniop/ops/layer_norm.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_LAYER_NORM_API_H__ +#define __INFINIOP_LAYER_NORM_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLayerNormDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLayerNormDescriptor( + infiniopHandle_t handle, + infiniopLayerNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps); + +__C __export infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLayerNorm(infiniopLayerNormDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream); + +__C __export infiniStatus_t infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc new file mode 100644 index 000000000..58a0030e8 --- /dev/null +++ b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc @@ -0,0 +1,112 @@ +#include "layer_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../reduce/cpu/reduce.h" +#include "../info.h" + +namespace op::layer_norm::cpu { + +template +infiniStatus_t calculate_layer_norm( + const LayerNormInfo &info, + Tdata *output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias) { + +#pragma omp parallel for + for (int b = 0; b < (int)(info.input_shape[0] * info.input_shape[1]); b++) { + int b0 = b / (int)info.input_shape[1], b1 = b % (int)info.input_shape[1]; + auto output_ptr = output + b0 * info.output_strides[0] + b1 * info.output_strides[1]; + auto input_ptr = input + b0 * info.input_strides[0] + b1 * info.input_strides[1]; + auto standard_ptr = input_standardization + b0 * info.input_standardization_strides[0] + b1 * info.input_standardization_strides[1]; + auto std_ptr = input_std_deviation + b0 * info.input_std_deviation_strides[0] + b1 * info.input_std_deviation_strides[1]; + float mean = op::common_cpu::reduce_op::sum( + input_ptr, + info.normalized_size, + info.input_strides[2]) + / info.input_shape[2]; + float sum_sq = op::common_cpu::reduce_op::sumSquared( + input_ptr, + info.normalized_size, + info.input_strides[2]); + float var = sum_sq / (info.normalized_size) - mean * mean; + float std_deviation = std::sqrt(var + info.eps); + *std_ptr = utils::cast(std_deviation); + + for (size_t d = 0; d < info.normalized_size; d++) { + float x_standard = (utils::cast(*(input_ptr + d * info.input_strides[2])) - mean) / std_deviation; + *(standard_ptr + d * info.input_standardization_strides[2]) = utils::cast(x_standard); + *(output_ptr + d * info.output_strides[2]) = utils::cast( + x_standard * utils::cast(*(weight + d * info.weight_strides[0])) + (info.bias_exist ? utils::cast(*(bias + d * info.bias_strides[0])) : float(0))); + } + } + + return INFINI_STATUS_SUCCESS; +} + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps) { + auto handle = reinterpret_cast(handle_); + + // --------------------- start: check data type and calculate workspace size ---------------------- + auto dtype = input_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + size_t WorkSpaceSize = 0; + + auto result = LayerNormInfo::createLayerNormInfo( + output_desc, + input_standardization_desc, + input_std_deviation_desc, + input_desc, + weight_desc, + bias_desc, + eps); + CHECK_RESULT(result); + const LayerNormInfo &info = result.take(); + + *desc_ptr = new Descriptor( + dtype, std::move(info), WorkSpaceSize, + nullptr, + handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_LAYER_NORM(TDATA) \ + CHECK_STATUS(calculate_layer_norm(_info, \ + (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias)) + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream) const { + if (_info.dtype == INFINI_DTYPE_F16) { + CALCULATE_LAYER_NORM(fp16_t); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + CALCULATE_LAYER_NORM(bf16_t); + } else if (_info.dtype == INFINI_DTYPE_F32) { + CALCULATE_LAYER_NORM(float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::layer_norm::cpu \ No newline at end of file diff --git a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h new file mode 100644 index 000000000..51d56bbf7 --- /dev/null +++ b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.h @@ -0,0 +1,8 @@ +#ifndef __LAYER_NORM_CPU_H__ +#define __LAYER_NORM_CPU_H__ + +#include "../layer_norm.h" + +DESCRIPTOR(cpu) + +#endif // __LAYER_NORM_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/layer_norm/cuda/kernel.cuh b/src/infiniop/ops/layer_norm/cuda/kernel.cuh new file mode 100644 index 000000000..120ed203a --- /dev/null +++ b/src/infiniop/ops/layer_norm/cuda/kernel.cuh @@ -0,0 +1,157 @@ +#ifndef __LAYER_NORM_KERNEL_CUH__ +#define __LAYER_NORM_KERNEL_CUH__ +#include + +template +__device__ void layerNormKernel( + Tdata *output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, + float eps, + size_t normalized_size, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_standardization_strides, + const ptrdiff_t *input_std_deviation_strides, + const ptrdiff_t *input_strides, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + bool bias_exist) { + size_t b0 = blockIdx.x, b1 = blockIdx.y; + + auto output_ptr = output + b0 * output_strides[0] + b1 * output_strides[1]; + auto input_ptr = input + b0 * input_strides[0] + b1 * input_strides[1]; + auto standard_ptr = input_standardization + b0 * input_standardization_strides[0] + b1 * input_standardization_strides[1]; + auto std_ptr = input_std_deviation + b0 * input_std_deviation_strides[0] + b1 * input_std_deviation_strides[1]; + Tcompute mean = op::common_cuda::reduce_op::sum( + input_ptr, + normalized_size) + / normalized_size; + Tcompute sum_squared = op::common_cuda::reduce_op::sumSquared( + input_ptr, + normalized_size); + + Tcompute var = sum_squared / normalized_size - mean * mean; + Tcompute std_deviation = sqrtf(var + Tcompute(eps)); + *std_ptr = std_deviation; + + for (size_t d = 0; d < normalized_size; d++) { + Tcompute x_standard = (Tcompute(input_ptr[d]) - mean) / std_deviation; + standard_ptr[d] = x_standard; + output_ptr[d] = x_standard * Tcompute(*(weight + d * weight_stride)) + (bias_exist ? Tcompute(*(bias + d * bias_stride)) : Tcompute(0)); + } +} + +template +__device__ void blockLayernormKernel(T *output, T const *input, T const *weight, T const *bias, float eps, int dimsize, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + int ndim, + bool bias_exist) { + // 只能处理axis=-1 + int ind_i = 0; // input id + int ind_o = 0; // output id + int tid = blockIdx.x; + for (int j = ndim - 2; j >= 0; j--) { + ind_i += (tid % (int)shape[j]) * (int)input_strides[j]; + ind_o += (tid % (int)shape[j]) * (int)output_strides[j]; + tid = tid / (int)shape[j]; + } + + float mu_partial = op::common_cuda::reduce_op::sum( + input + ind_i, + dimsize) + / dimsize; + __shared__ float mu; + if (threadIdx.x == 0) { + mu = mu_partial; + } // threadIdx.x = 0对应的是全局sum + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + __syncthreads(); + float sigma2_partial = 0.0f; + for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) { + sigma2_partial += (static_cast(input[ind_i + id]) - mu) * (static_cast(input[ind_i + id]) - mu); + } + + __shared__ float sigma2; + float sigma2_block = BlockReduce(temp_storage).Reduce(sigma2_partial, cub::Sum()); + if (threadIdx.x == 0) { + float sigma_tmp = sqrt(sigma2_block * __fdividef(1.0F, dimsize) + eps); + sigma2 = __fdividef(1.0F, sigma_tmp); + } + __syncthreads(); + for (int id = threadIdx.x; id < dimsize; id += BLOCK_SIZE) { + output[ind_o + id] = static_cast(static_cast(weight[id * weight_stride]) * (static_cast(input[ind_i + id]) - mu) * sigma2 + (bias_exist ? static_cast(bias[id * bias_stride]) : 0.0f)); + } +} +template +struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } +}; + +template