diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 8269812347..e70486768e 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -204,6 +204,25 @@ BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) BGEMVTKERNEL = sbgemv_t_bfdot.c BGEMVNKERNEL = bgemv_n_sve_v3x4.c +ifeq ($(BUILD_HFLOAT16), 1) +SHGEMMKERNEL = shgemm_kernel_$(SHGEMM_UNROLL_M)x$(SHGEMM_UNROLL_N)_neoversen2.c +SHGEMMINCOPY = shgemm_ncopy_$(SHGEMM_UNROLL_M)_neoversen2.c +SHGEMMITCOPY = shgemm_tcopy_$(SHGEMM_UNROLL_M)_neoversen2.c +ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N)) + SHGEMMINCOPY = ../generic/gemm_ncopy_$(SHGEMM_UNROLL_M).c + SHGEMMITCOPY = ../generic/gemm_tcopy_$(SHGEMM_UNROLL_M).c +endif +SHGEMMONCOPY = shgemm_ncopy_$(SHGEMM_UNROLL_N)_neoversen2.c +SHGEMMOTCOPY = shgemm_tcopy_$(SHGEMM_UNROLL_N)_neoversen2.c +SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX) +SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX) +SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX) +SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX) +ifndef SHGEMM_BETA +SHGEMM_BETA = sbgemm_beta_neoversen2.c +endif +endif + SBGEMM_BETA = sbgemm_beta_neoversen2.c SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) diff --git a/kernel/arm64/shgemm_kernel_8x8_neoversen2.c b/kernel/arm64/shgemm_kernel_8x8_neoversen2.c new file mode 100644 index 0000000000..cc52ea3366 --- /dev/null +++ b/kernel/arm64/shgemm_kernel_8x8_neoversen2.c @@ -0,0 +1,887 @@ +/*************************************************************************** + * Copyright (c) 2026 The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/ + +#include + +#include "common.h" + +static inline void kernel_8x8(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0_low = vdupq_n_f32(0.0f); + float32x4_t c0_high = vdupq_n_f32(0.0f); + float32x4_t c1_low = vdupq_n_f32(0.0f); + float32x4_t c1_high = vdupq_n_f32(0.0f); + float32x4_t c2_low = vdupq_n_f32(0.0f); + float32x4_t c2_high = vdupq_n_f32(0.0f); + float32x4_t c3_low = vdupq_n_f32(0.0f); + float32x4_t c3_high = vdupq_n_f32(0.0f); + float32x4_t c4_low = vdupq_n_f32(0.0f); + float32x4_t c4_high = vdupq_n_f32(0.0f); + float32x4_t c5_low = vdupq_n_f32(0.0f); + float32x4_t c5_high = vdupq_n_f32(0.0f); + float32x4_t c6_low = vdupq_n_f32(0.0f); + float32x4_t c6_high = vdupq_n_f32(0.0f); + float32x4_t c7_low = vdupq_n_f32(0.0f); + float32x4_t c7_high = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float16x8_t a_f16 = vld1q_f16(A); + float32x4_t a_low = vcvt_f32_f16(vget_low_f16(a_f16)); + float32x4_t a_high = vcvt_f32_f16(vget_high_f16(a_f16)); + + float16x8_t b_f16 = vld1q_f16(B); + float32x4_t b_low = vcvt_f32_f16(vget_low_f16(b_f16)); + float32x4_t b_high = vcvt_f32_f16(vget_high_f16(b_f16)); + + float32_t b0_lane0 = vgetq_lane_f32(b_low, 0); + c0_low = vfmaq_n_f32(c0_low, a_low, b0_lane0); + c0_high = vfmaq_n_f32(c0_high, a_high, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_low, 1); + c1_low = vfmaq_n_f32(c1_low, a_low, b0_lane1); + c1_high = vfmaq_n_f32(c1_high, a_high, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_low, 2); + c2_low = vfmaq_n_f32(c2_low, a_low, b0_lane2); + c2_high = vfmaq_n_f32(c2_high, a_high, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_low, 3); + c3_low = vfmaq_n_f32(c3_low, a_low, b0_lane3); + c3_high = vfmaq_n_f32(c3_high, a_high, b0_lane3); + + float32_t b1_lane0 = vgetq_lane_f32(b_high, 0); + c4_low = vfmaq_n_f32(c4_low, a_low, b1_lane0); + c4_high = 
vfmaq_n_f32(c4_high, a_high, b1_lane0); + + float32_t b1_lane1 = vgetq_lane_f32(b_high, 1); + c5_low = vfmaq_n_f32(c5_low, a_low, b1_lane1); + c5_high = vfmaq_n_f32(c5_high, a_high, b1_lane1); + + float32_t b1_lane2 = vgetq_lane_f32(b_high, 2); + c6_low = vfmaq_n_f32(c6_low, a_low, b1_lane2); + c6_high = vfmaq_n_f32(c6_high, a_high, b1_lane2); + + float32_t b1_lane3 = vgetq_lane_f32(b_high, 3); + c7_low = vfmaq_n_f32(c7_low, a_low, b1_lane3); + c7_high = vfmaq_n_f32(c7_high, a_high, b1_lane3); + + A += 8; + B += 8; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; + FLOAT *col_4 = C + 4 * ldc; + FLOAT *col_5 = C + 5 * ldc; + FLOAT *col_6 = C + 6 * ldc; + FLOAT *col_7 = C + 7 * ldc; + + float32x4_t t0_l = vld1q_f32(col_0); + float32x4_t t0_h = vld1q_f32(col_0 + 4); + t0_l = vaddq_f32(t0_l, vmulq_n_f32(c0_low, alpha)); + t0_h = vaddq_f32(t0_h, vmulq_n_f32(c0_high, alpha)); + vst1q_f32(col_0, t0_l); + vst1q_f32(col_0 + 4, t0_h); + + float32x4_t t1_l = vld1q_f32(col_1); + float32x4_t t1_h = vld1q_f32(col_1 + 4); + t1_l = vaddq_f32(t1_l, vmulq_n_f32(c1_low, alpha)); + t1_h = vaddq_f32(t1_h, vmulq_n_f32(c1_high, alpha)); + vst1q_f32(col_1, t1_l); + vst1q_f32(col_1 + 4, t1_h); + + float32x4_t t2_l = vld1q_f32(col_2); + float32x4_t t2_h = vld1q_f32(col_2 + 4); + t2_l = vaddq_f32(t2_l, vmulq_n_f32(c2_low, alpha)); + t2_h = vaddq_f32(t2_h, vmulq_n_f32(c2_high, alpha)); + vst1q_f32(col_2, t2_l); + vst1q_f32(col_2 + 4, t2_h); + + float32x4_t t3_l = vld1q_f32(col_3); + float32x4_t t3_h = vld1q_f32(col_3 + 4); + t3_l = vaddq_f32(t3_l, vmulq_n_f32(c3_low, alpha)); + t3_h = vaddq_f32(t3_h, vmulq_n_f32(c3_high, alpha)); + vst1q_f32(col_3, t3_l); + vst1q_f32(col_3 + 4, t3_h); + + float32x4_t t4_l = vld1q_f32(col_4); + float32x4_t t4_h = vld1q_f32(col_4 + 4); + t4_l = vaddq_f32(t4_l, vmulq_n_f32(c4_low, alpha)); + t4_h = vaddq_f32(t4_h, vmulq_n_f32(c4_high, alpha)); + vst1q_f32(col_4, t4_l); + vst1q_f32(col_4 
+ 4, t4_h); + + float32x4_t t5_l = vld1q_f32(col_5); + float32x4_t t5_h = vld1q_f32(col_5 + 4); + t5_l = vaddq_f32(t5_l, vmulq_n_f32(c5_low, alpha)); + t5_h = vaddq_f32(t5_h, vmulq_n_f32(c5_high, alpha)); + vst1q_f32(col_5, t5_l); + vst1q_f32(col_5 + 4, t5_h); + + float32x4_t t6_l = vld1q_f32(col_6); + float32x4_t t6_h = vld1q_f32(col_6 + 4); + t6_l = vaddq_f32(t6_l, vmulq_n_f32(c6_low, alpha)); + t6_h = vaddq_f32(t6_h, vmulq_n_f32(c6_high, alpha)); + vst1q_f32(col_6, t6_l); + vst1q_f32(col_6 + 4, t6_h); + + float32x4_t t7_l = vld1q_f32(col_7); + float32x4_t t7_h = vld1q_f32(col_7 + 4); + t7_l = vaddq_f32(t7_l, vmulq_n_f32(c7_low, alpha)); + t7_h = vaddq_f32(t7_h, vmulq_n_f32(c7_high, alpha)); + vst1q_f32(col_7, t7_l); + vst1q_f32(col_7 + 4, t7_h); +} + +static inline void kernel_4x8(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0 = vdupq_n_f32(0.0f); + float32x4_t c1 = vdupq_n_f32(0.0f); + float32x4_t c2 = vdupq_n_f32(0.0f); + float32x4_t c3 = vdupq_n_f32(0.0f); + float32x4_t c4 = vdupq_n_f32(0.0f); + float32x4_t c5 = vdupq_n_f32(0.0f); + float32x4_t c6 = vdupq_n_f32(0.0f); + float32x4_t c7 = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f16 = vcvt_f32_f16(vld1_f16(A)); + + float16x8_t b_f16 = vld1q_f16(B); + float32x4_t b_low = vcvt_f32_f16(vget_low_f16(b_f16)); + float32x4_t b_high = vcvt_f32_f16(vget_high_f16(b_f16)); + + float32_t b0_lane0 = vgetq_lane_f32(b_low, 0); + c0 = vfmaq_n_f32(c0, a_f16, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_low, 1); + c1 = vfmaq_n_f32(c1, a_f16, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_low, 2); + c2 = vfmaq_n_f32(c2, a_f16, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_low, 3); + c3 = vfmaq_n_f32(c3, a_f16, b0_lane3); + + float32_t b1_lane0 = vgetq_lane_f32(b_high, 0); + c4 = vfmaq_n_f32(c4, a_f16, b1_lane0); + + float32_t b1_lane1 = vgetq_lane_f32(b_high, 1); + c5 = vfmaq_n_f32(c5, a_f16, b1_lane1); + + float32_t 
b1_lane2 = vgetq_lane_f32(b_high, 2); + c6 = vfmaq_n_f32(c6, a_f16, b1_lane2); + + float32_t b1_lane3 = vgetq_lane_f32(b_high, 3); + c7 = vfmaq_n_f32(c7, a_f16, b1_lane3); + + A += 4; + B += 8; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; + FLOAT *col_4 = C + 4 * ldc; + FLOAT *col_5 = C + 5 * ldc; + FLOAT *col_6 = C + 6 * ldc; + FLOAT *col_7 = C + 7 * ldc; + + float32x4_t t0 = vld1q_f32(col_0); + t0 = vaddq_f32(t0, vmulq_n_f32(c0, alpha)); + vst1q_f32(col_0, t0); + + float32x4_t t1 = vld1q_f32(col_1); + t1 = vaddq_f32(t1, vmulq_n_f32(c1, alpha)); + vst1q_f32(col_1, t1); + + float32x4_t t2 = vld1q_f32(col_2); + t2 = vaddq_f32(t2, vmulq_n_f32(c2, alpha)); + vst1q_f32(col_2, t2); + + float32x4_t t3 = vld1q_f32(col_3); + t3 = vaddq_f32(t3, vmulq_n_f32(c3, alpha)); + vst1q_f32(col_3, t3); + + float32x4_t t4 = vld1q_f32(col_4); + t4 = vaddq_f32(t4, vmulq_n_f32(c4, alpha)); + vst1q_f32(col_4, t4); + + float32x4_t t5 = vld1q_f32(col_5); + t5 = vaddq_f32(t5, vmulq_n_f32(c5, alpha)); + vst1q_f32(col_5, t5); + + float32x4_t t6 = vld1q_f32(col_6); + t6 = vaddq_f32(t6, vmulq_n_f32(c6, alpha)); + vst1q_f32(col_6, t6); + + float32x4_t t7 = vld1q_f32(col_7); + t7 = vaddq_f32(t7, vmulq_n_f32(c7, alpha)); + vst1q_f32(col_7, t7); +} + +static inline void kernel_2x8(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x2_t c0 = vdup_n_f32(0.0f); + float32x2_t c1 = vdup_n_f32(0.0f); + float32x2_t c2 = vdup_n_f32(0.0f); + float32x2_t c3 = vdup_n_f32(0.0f); + float32x2_t c4 = vdup_n_f32(0.0f); + float32x2_t c5 = vdup_n_f32(0.0f); + float32x2_t c6 = vdup_n_f32(0.0f); + float32x2_t c7 = vdup_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x2_t a_low = vget_low_f32(a_f32); + + float16x8_t b_f16 = vld1q_f16(B); + float32x4_t b_low = vcvt_f32_f16(vget_low_f16(b_f16)); + float32x4_t b_high = 
vcvt_f32_f16(vget_high_f16(b_f16)); + + float32_t b0_lane0 = vgetq_lane_f32(b_low, 0); + c0 = vfma_n_f32(c0, a_low, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_low, 1); + c1 = vfma_n_f32(c1, a_low, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_low, 2); + c2 = vfma_n_f32(c2, a_low, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_low, 3); + c3 = vfma_n_f32(c3, a_low, b0_lane3); + + float32_t b1_lane0 = vgetq_lane_f32(b_high, 0); + c4 = vfma_n_f32(c4, a_low, b1_lane0); + + float32_t b1_lane1 = vgetq_lane_f32(b_high, 1); + c5 = vfma_n_f32(c5, a_low, b1_lane1); + + float32_t b1_lane2 = vgetq_lane_f32(b_high, 2); + c6 = vfma_n_f32(c6, a_low, b1_lane2); + + float32_t b1_lane3 = vgetq_lane_f32(b_high, 3); + c7 = vfma_n_f32(c7, a_low, b1_lane3); + + A += 2; + B += 8; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; + FLOAT *col_4 = C + 4 * ldc; + FLOAT *col_5 = C + 5 * ldc; + FLOAT *col_6 = C + 6 * ldc; + FLOAT *col_7 = C + 7 * ldc; + + float32x2_t t0 = vld1_f32(col_0); + t0 = vadd_f32(t0, vmul_n_f32(c0, alpha)); + vst1_f32(col_0, t0); + + float32x2_t t1 = vld1_f32(col_1); + t1 = vadd_f32(t1, vmul_n_f32(c1, alpha)); + vst1_f32(col_1, t1); + + float32x2_t t2 = vld1_f32(col_2); + t2 = vadd_f32(t2, vmul_n_f32(c2, alpha)); + vst1_f32(col_2, t2); + + float32x2_t t3 = vld1_f32(col_3); + t3 = vadd_f32(t3, vmul_n_f32(c3, alpha)); + vst1_f32(col_3, t3); + + float32x2_t t4 = vld1_f32(col_4); + t4 = vadd_f32(t4, vmul_n_f32(c4, alpha)); + vst1_f32(col_4, t4); + + float32x2_t t5 = vld1_f32(col_5); + t5 = vadd_f32(t5, vmul_n_f32(c5, alpha)); + vst1_f32(col_5, t5); + + float32x2_t t6 = vld1_f32(col_6); + t6 = vadd_f32(t6, vmul_n_f32(c6, alpha)); + vst1_f32(col_6, t6); + + float32x2_t t7 = vld1_f32(col_7); + t7 = vadd_f32(t7, vmul_n_f32(c7, alpha)); + vst1_f32(col_7, t7); +} + +static inline void kernel_1x8(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT 
alpha) { + FLOAT c0 = 0, c1 = 0, c2 = 0, c3 = 0, c4 = 0, c5 = 0, c6 = 0, c7 = 0; + + for (BLASLONG k = 0; k < K; ++k) { + FLOAT a = A[0]; + c0 += a * B[0]; + c1 += a * B[1]; + c2 += a * B[2]; + c3 += a * B[3]; + c4 += a * B[4]; + c5 += a * B[5]; + c6 += a * B[6]; + c7 += a * B[7]; + + A += 1; + B += 8; + } + + C[0 * ldc] += alpha * c0; + C[1 * ldc] += alpha * c1; + C[2 * ldc] += alpha * c2; + C[3 * ldc] += alpha * c3; + C[4 * ldc] += alpha * c4; + C[5 * ldc] += alpha * c5; + C[6 * ldc] += alpha * c6; + C[7 * ldc] += alpha * c7; +} + +static inline void kernel_8x4(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0_low = vdupq_n_f32(0.0f); + float32x4_t c0_high = vdupq_n_f32(0.0f); + float32x4_t c1_low = vdupq_n_f32(0.0f); + float32x4_t c1_high = vdupq_n_f32(0.0f); + float32x4_t c2_low = vdupq_n_f32(0.0f); + float32x4_t c2_high = vdupq_n_f32(0.0f); + float32x4_t c3_low = vdupq_n_f32(0.0f); + float32x4_t c3_high = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float16x8_t a_f16 = vld1q_f16(A); + float32x4_t a_low = vcvt_f32_f16(vget_low_f16(a_f16)); + float32x4_t a_high = vcvt_f32_f16(vget_high_f16(a_f16)); + + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = vgetq_lane_f32(b_f32, 0); + c0_low = vfmaq_n_f32(c0_low, a_low, b0_lane0); + c0_high = vfmaq_n_f32(c0_high, a_high, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1_low = vfmaq_n_f32(c1_low, a_low, b0_lane1); + c1_high = vfmaq_n_f32(c1_high, a_high, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_f32, 2); + c2_low = vfmaq_n_f32(c2_low, a_low, b0_lane2); + c2_high = vfmaq_n_f32(c2_high, a_high, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_f32, 3); + c3_low = vfmaq_n_f32(c3_low, a_low, b0_lane3); + c3_high = vfmaq_n_f32(c3_high, a_high, b0_lane3); + + A += 8; + B += 4; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; 
+ + float32x4_t t0_l = vld1q_f32(col_0); + float32x4_t t0_h = vld1q_f32(col_0 + 4); + t0_l = vaddq_f32(t0_l, vmulq_n_f32(c0_low, alpha)); + t0_h = vaddq_f32(t0_h, vmulq_n_f32(c0_high, alpha)); + vst1q_f32(col_0, t0_l); + vst1q_f32(col_0 + 4, t0_h); + + float32x4_t t1_l = vld1q_f32(col_1); + float32x4_t t1_h = vld1q_f32(col_1 + 4); + t1_l = vaddq_f32(t1_l, vmulq_n_f32(c1_low, alpha)); + t1_h = vaddq_f32(t1_h, vmulq_n_f32(c1_high, alpha)); + vst1q_f32(col_1, t1_l); + vst1q_f32(col_1 + 4, t1_h); + + float32x4_t t2_l = vld1q_f32(col_2); + float32x4_t t2_h = vld1q_f32(col_2 + 4); + t2_l = vaddq_f32(t2_l, vmulq_n_f32(c2_low, alpha)); + t2_h = vaddq_f32(t2_h, vmulq_n_f32(c2_high, alpha)); + vst1q_f32(col_2, t2_l); + vst1q_f32(col_2 + 4, t2_h); + + float32x4_t t3_l = vld1q_f32(col_3); + float32x4_t t3_h = vld1q_f32(col_3 + 4); + t3_l = vaddq_f32(t3_l, vmulq_n_f32(c3_low, alpha)); + t3_h = vaddq_f32(t3_h, vmulq_n_f32(c3_high, alpha)); + vst1q_f32(col_3, t3_l); + vst1q_f32(col_3 + 4, t3_h); +} + +static inline void kernel_4x4(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0 = vdupq_n_f32(0.0f); + float32x4_t c1 = vdupq_n_f32(0.0f); + float32x4_t c2 = vdupq_n_f32(0.0f); + float32x4_t c3 = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = vgetq_lane_f32(b_f32, 0); + c0 = vfmaq_n_f32(c0, a_f32, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1 = vfmaq_n_f32(c1, a_f32, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_f32, 2); + c2 = vfmaq_n_f32(c2, a_f32, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_f32, 3); + c3 = vfmaq_n_f32(c3, a_f32, b0_lane3); + + A += 4; + B += 4; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; + + float32x4_t t0 = vld1q_f32(col_0); + t0 = vaddq_f32(t0, 
vmulq_n_f32(c0, alpha)); + vst1q_f32(col_0, t0); + + float32x4_t t1 = vld1q_f32(col_1); + t1 = vaddq_f32(t1, vmulq_n_f32(c1, alpha)); + vst1q_f32(col_1, t1); + + float32x4_t t2 = vld1q_f32(col_2); + t2 = vaddq_f32(t2, vmulq_n_f32(c2, alpha)); + vst1q_f32(col_2, t2); + + float32x4_t t3 = vld1q_f32(col_3); + t3 = vaddq_f32(t3, vmulq_n_f32(c3, alpha)); + vst1q_f32(col_3, t3); +} + +static inline void kernel_2x4(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x2_t c0 = vdup_n_f32(0.0f); + float32x2_t c1 = vdup_n_f32(0.0f); + float32x2_t c2 = vdup_n_f32(0.0f); + float32x2_t c3 = vdup_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x2_t a_low = vget_low_f32(a_f32); + + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = vgetq_lane_f32(b_f32, 0); + c0 = vfma_n_f32(c0, a_low, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1 = vfma_n_f32(c1, a_low, b0_lane1); + + float32_t b0_lane2 = vgetq_lane_f32(b_f32, 2); + c2 = vfma_n_f32(c2, a_low, b0_lane2); + + float32_t b0_lane3 = vgetq_lane_f32(b_f32, 3); + c3 = vfma_n_f32(c3, a_low, b0_lane3); + A += 2; + B += 4; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + FLOAT *col_2 = C + 2 * ldc; + FLOAT *col_3 = C + 3 * ldc; + + float32x2_t t0 = vld1_f32(col_0); + t0 = vadd_f32(t0, vmul_n_f32(c0, alpha)); + vst1_f32(col_0, t0); + + float32x2_t t1 = vld1_f32(col_1); + t1 = vadd_f32(t1, vmul_n_f32(c1, alpha)); + vst1_f32(col_1, t1); + + float32x2_t t2 = vld1_f32(col_2); + t2 = vadd_f32(t2, vmul_n_f32(c2, alpha)); + vst1_f32(col_2, t2); + + float32x2_t t3 = vld1_f32(col_3); + t3 = vadd_f32(t3, vmul_n_f32(c3, alpha)); + vst1_f32(col_3, t3); +} + +static inline void kernel_1x4(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + FLOAT c0 = 0, c1 = 0, c2 = 0, c3 = 0; + for (BLASLONG k = 0; k < K; ++k) { + FLOAT a = A[0]; + c0 += a * 
B[0]; + c1 += a * B[1]; + c2 += a * B[2]; + c3 += a * B[3]; + + A += 1; + B += 4; + } + + C[0 * ldc] += alpha * c0; + C[1 * ldc] += alpha * c1; + C[2 * ldc] += alpha * c2; + C[3 * ldc] += alpha * c3; +} + +static inline void kernel_8x2(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0_low = vdupq_n_f32(0.0f); + float32x4_t c0_high = vdupq_n_f32(0.0f); + float32x4_t c1_low = vdupq_n_f32(0.0f); + float32x4_t c1_high = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float16x8_t a_f16 = vld1q_f16(A); + float32x4_t a_low = vcvt_f32_f16(vget_low_f16(a_f16)); + float32x4_t a_high = vcvt_f32_f16(vget_high_f16(a_f16)); + + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = vgetq_lane_f32(b_f32, 0); + c0_low = vfmaq_n_f32(c0_low, a_low, b0_lane0); + c0_high = vfmaq_n_f32(c0_high, a_high, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1_low = vfmaq_n_f32(c1_low, a_low, b0_lane1); + c1_high = vfmaq_n_f32(c1_high, a_high, b0_lane1); + + A += 8; + B += 2; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + + float32x4_t t0_l = vld1q_f32(col_0); + float32x4_t t0_h = vld1q_f32(col_0 + 4); + t0_l = vaddq_f32(t0_l, vmulq_n_f32(c0_low, alpha)); + t0_h = vaddq_f32(t0_h, vmulq_n_f32(c0_high, alpha)); + vst1q_f32(col_0, t0_l); + vst1q_f32(col_0 + 4, t0_h); + + float32x4_t t1_l = vld1q_f32(col_1); + float32x4_t t1_h = vld1q_f32(col_1 + 4); + t1_l = vaddq_f32(t1_l, vmulq_n_f32(c1_low, alpha)); + t1_h = vaddq_f32(t1_h, vmulq_n_f32(c1_high, alpha)); + vst1q_f32(col_1, t1_l); + vst1q_f32(col_1 + 4, t1_h); +} + +static inline void kernel_4x2(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x4_t c0 = vdupq_n_f32(0.0f); + float32x4_t c1 = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = 
vgetq_lane_f32(b_f32, 0); + c0 = vfmaq_n_f32(c0, a_f32, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1 = vfmaq_n_f32(c1, a_f32, b0_lane1); + + A += 4; + B += 2; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + + float32x4_t t0 = vld1q_f32(col_0); + t0 = vaddq_f32(t0, vmulq_n_f32(c0, alpha)); + vst1q_f32(col_0, t0); + + float32x4_t t1 = vld1q_f32(col_1); + t1 = vaddq_f32(t1, vmulq_n_f32(c1, alpha)); + vst1q_f32(col_1, t1); +} + +static inline void kernel_2x2(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + float32x2_t c0 = vdup_n_f32(0.0f); + float32x2_t c1 = vdup_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x2_t a_low = vget_low_f32(a_f32); + + float32x4_t b_f32 = vcvt_f32_f16(vld1_f16(B)); + + float32_t b0_lane0 = vgetq_lane_f32(b_f32, 0); + c0 = vfma_n_f32(c0, a_low, b0_lane0); + + float32_t b0_lane1 = vgetq_lane_f32(b_f32, 1); + c1 = vfma_n_f32(c1, a_low, b0_lane1); + ; + + A += 2; + B += 2; + } + + FLOAT *col_0 = C + 0 * ldc; + FLOAT *col_1 = C + 1 * ldc; + + float32x2_t t0 = vld1_f32(col_0); + t0 = vadd_f32(t0, vmul_n_f32(c0, alpha)); + vst1_f32(col_0, t0); + + float32x2_t t1 = vld1_f32(col_1); + t1 = vadd_f32(t1, vmul_n_f32(c1, alpha)); + vst1_f32(col_1, t1); +} + +static inline void kernel_1x2(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, BLASLONG ldc, FLOAT alpha) { + FLOAT c0 = 0, c1 = 0; + for (BLASLONG k = 0; k < K; ++k) { + FLOAT a = A[0]; + c0 += a * B[0]; + c1 += a * B[1]; + + A += 1; + B += 2; + } + + C[0 * ldc] += alpha * c0; + C[1 * ldc] += alpha * c1; +} + +static inline void kernel_8x1(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, FLOAT alpha) { + float32x4_t c0_low = vdupq_n_f32(0.0f); + float32x4_t c0_high = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float16x8_t a_f16 = vld1q_f16(A); + float32x4_t a_low = vcvt_f32_f16(vget_low_f16(a_f16)); + 
float32x4_t a_high = vcvt_f32_f16(vget_high_f16(a_f16)); + + float b_scalar = (float)B[0]; + + c0_low = vfmaq_n_f32(c0_low, a_low, b_scalar); + c0_high = vfmaq_n_f32(c0_high, a_high, b_scalar); + + A += 8; + B += 1; + } + + FLOAT *col_0 = C; + + float32x4_t t0_l = vld1q_f32(col_0); + float32x4_t t0_h = vld1q_f32(col_0 + 4); + t0_l = vaddq_f32(t0_l, vmulq_n_f32(c0_low, alpha)); + t0_h = vaddq_f32(t0_h, vmulq_n_f32(c0_high, alpha)); + vst1q_f32(col_0, t0_l); + vst1q_f32(col_0 + 4, t0_h); +} + +static inline void kernel_4x1(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, FLOAT alpha) { + float32x4_t c0 = vdupq_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float b_scalar = (float)B[0]; + c0 = vfmaq_n_f32(c0, a_f32, b_scalar); + + A += 4; + B += 1; + } + + FLOAT *col_0 = C; + float32x4_t t0 = vld1q_f32(col_0); + t0 = vaddq_f32(t0, vmulq_n_f32(c0, alpha)); + vst1q_f32(col_0, t0); +} + +static inline void kernel_2x1(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, FLOAT alpha) { + float32x2_t c0 = vdup_n_f32(0.0f); + + for (BLASLONG k = 0; k < K; ++k) { + float32x4_t a_f32 = vcvt_f32_f16(vld1_f16(A)); + float32x2_t a_low = vget_low_f32(a_f32); + + float b_scalar = (float)B[0]; + c0 = vfma_n_f32(c0, a_low, b_scalar); + + A += 2; + B += 1; + } + + FLOAT *col_0 = C; + float32x2_t t0 = vld1_f32(col_0); + t0 = vadd_f32(t0, vmul_n_f32(c0, alpha)); + vst1_f32(col_0, t0); +} + +static inline void kernel_1x1(BLASLONG K, const float16_t *A, const float16_t *B, FLOAT *C, FLOAT alpha) { + FLOAT sum = 0.0f; + for (BLASLONG k = 0; k < K; ++k) { + sum += A[0] * B[0]; + A += 1; + B += 1; + } + + C[0] += alpha * sum; +} + +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { + float16_t *A_base = (float16_t *)A; + float16_t *B_base = (float16_t *)B; + + FLOAT *Ccol = C; + BLASLONG m_rem1, m_rem2, m_rem3, m_rem4; + + while (N >= 8) { + const float16_t 
*Aptr = A_base; + const float16_t *Bptr = B_base; + FLOAT *Crow = Ccol; + + m_rem1 = M; + + while (m_rem1 >= 8) { + kernel_8x8(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 8; + Crow += 8; + m_rem1 -= 8; + } + if (m_rem1 >= 4) { + kernel_4x8(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 4; + Crow += 4; + m_rem1 -= 4; + } + if (m_rem1 >= 2) { + kernel_2x8(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 2; + Crow += 2; + m_rem1 -= 2; + } + if (m_rem1 >= 1) { + kernel_1x8(K, Aptr, Bptr, Crow, ldc, alpha); + } + + B_base += K * 8; + Ccol += ldc * 8; + N -= 8; + } + + if (N >= 4) { + const float16_t *Aptr = A_base; + const float16_t *Bptr = B_base; + FLOAT *Crow = Ccol; + + m_rem2 = M; + while (m_rem2 >= 8) { + kernel_8x4(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 8; + Crow += 8; + m_rem2 -= 8; + } + if (m_rem2 >= 4) { + kernel_4x4(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 4; + Crow += 4; + m_rem2 -= 4; + } + if (m_rem2 >= 2) { + kernel_2x4(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 2; + Crow += 2; + m_rem2 -= 2; + } + if (m_rem2 >= 1) { + kernel_1x4(K, Aptr, Bptr, Crow, ldc, alpha); + } + + B_base += K * 4; + Ccol += ldc * 4; + N -= 4; + } + + if (N >= 2) { + const float16_t *Aptr = A_base; + const float16_t *Bptr = B_base; + FLOAT *Crow = Ccol; + + m_rem3 = M; + while (m_rem3 >= 8) { + kernel_8x2(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 8; + Crow += 8; + m_rem3 -= 8; + } + if (m_rem3 >= 4) { + kernel_4x2(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 4; + Crow += 4; + m_rem3 -= 4; + } + if (m_rem3 >= 2) { + kernel_2x2(K, Aptr, Bptr, Crow, ldc, alpha); + Aptr += K * 2; + Crow += 2; + m_rem3 -= 2; + } + if (m_rem3 >= 1) { + kernel_1x2(K, Aptr, Bptr, Crow, ldc, alpha); + } + + B_base += K * 2; + Ccol += ldc * 2; + N -= 2; + } + + if (N >= 1) { + const float16_t *Aptr = A_base; + const float16_t *Bptr = B_base; + FLOAT *Crow = Ccol; + + m_rem4 = M; + while (m_rem4 >= 8) { + kernel_8x1(K, Aptr, Bptr, Crow, alpha); + Aptr += K * 8; + Crow += 8; + 
m_rem4 -= 8; + } + if (m_rem4 >= 4) { + kernel_4x1(K, Aptr, Bptr, Crow, alpha); + Aptr += K * 4; + Crow += 4; + m_rem4 -= 4; + } + if (m_rem4 >= 2) { + kernel_2x1(K, Aptr, Bptr, Crow, alpha); + Aptr += K * 2; + Crow += 2; + m_rem4 -= 2; + } + if (m_rem4 >= 1) { + kernel_1x1(K, Aptr, Bptr, Crow, alpha); + } + } + + return 0; +} \ No newline at end of file diff --git a/kernel/arm64/shgemm_ncopy_8_neoversen2.c b/kernel/arm64/shgemm_ncopy_8_neoversen2.c new file mode 100644 index 0000000000..3229a22f24 --- /dev/null +++ b/kernel/arm64/shgemm_ncopy_8_neoversen2.c @@ -0,0 +1,258 @@ +/*************************************************************************** + * Copyright (c) 2026, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/
+
+#include <arm_neon.h>
+
+#include "common.h"
+
+/*
+ * Transpose an 8x8 tile of half-precision values held in eight 128-bit
+ * vectors: on return, lane r of cols[c] holds lane c of rows[r].
+ *
+ * The transpose is done with three rounds of TRN1/TRN2 on progressively
+ * narrower lanes (64-bit, then 32-bit, then 16-bit). All of `rows` is read
+ * before anything is written to `cols`.
+ */
+static inline void transpose8x8(float16x8_t *rows, float16x8_t *cols) {
+  const float64x2_t r0 = vreinterpretq_f64_f16(rows[0]);
+  const float64x2_t r1 = vreinterpretq_f64_f16(rows[1]);
+  const float64x2_t r2 = vreinterpretq_f64_f16(rows[2]);
+  const float64x2_t r3 = vreinterpretq_f64_f16(rows[3]);
+  const float64x2_t r4 = vreinterpretq_f64_f16(rows[4]);
+  const float64x2_t r5 = vreinterpretq_f64_f16(rows[5]);
+  const float64x2_t r6 = vreinterpretq_f64_f16(rows[6]);
+  const float64x2_t r7 = vreinterpretq_f64_f16(rows[7]);
+
+  /* Round 1: pair row k with row k + 4 on 64-bit lanes.
+   * h0..h3 collect lanes 0-3 of each row, h4..h7 collect lanes 4-7. */
+  const float32x4_t h0 = vreinterpretq_f32_f64(vtrn1q_f64(r0, r4));
+  const float32x4_t h1 = vreinterpretq_f32_f64(vtrn1q_f64(r1, r5));
+  const float32x4_t h2 = vreinterpretq_f32_f64(vtrn1q_f64(r2, r6));
+  const float32x4_t h3 = vreinterpretq_f32_f64(vtrn1q_f64(r3, r7));
+  const float32x4_t h4 = vreinterpretq_f32_f64(vtrn2q_f64(r0, r4));
+  const float32x4_t h5 = vreinterpretq_f32_f64(vtrn2q_f64(r1, r5));
+  const float32x4_t h6 = vreinterpretq_f32_f64(vtrn2q_f64(r2, r6));
+  const float32x4_t h7 = vreinterpretq_f32_f64(vtrn2q_f64(r3, r7));
+
+  /* Round 2: pair vector k with vector k + 2 on 32-bit lanes. */
+  const float16x8_t w0 = vreinterpretq_f16_f32(vtrn1q_f32(h0, h2));
+  const float16x8_t w1 = vreinterpretq_f16_f32(vtrn1q_f32(h1, h3));
+  const float16x8_t w2 = vreinterpretq_f16_f32(vtrn2q_f32(h0, h2));
+  const float16x8_t w3 = vreinterpretq_f16_f32(vtrn2q_f32(h1, h3));
+  const float16x8_t w4 = vreinterpretq_f16_f32(vtrn1q_f32(h4, h6));
+  const float16x8_t w5 = vreinterpretq_f16_f32(vtrn1q_f32(h5, h7));
+  const float16x8_t w6 = vreinterpretq_f16_f32(vtrn2q_f32(h4, h6));
+  const float16x8_t w7 = vreinterpretq_f16_f32(vtrn2q_f32(h5, h7));
+
+  /* Round 3: pair adjacent vectors on 16-bit lanes to finish the transpose. */
+  cols[0] = vtrn1q_f16(w0, w1);
+  cols[1] = vtrn2q_f16(w0, w1);
+  cols[2] = vtrn1q_f16(w2, w3);
+  cols[3] = vtrn2q_f16(w2, w3);
+  cols[4] = vtrn1q_f16(w4, w5);
+  cols[5] = vtrn2q_f16(w4, w5);
+  cols[6] = vtrn1q_f16(w6, w7);
+  cols[7] = vtrn2q_f16(w6, w7);
+}
+
+/*
+ * Transpose a 4x4 tile of half-precision values held in four 64-bit
+ * vectors: on return, lane r of cols[c] holds lane c of rows[r].
+ *
+ * Each row is widened to 128 bits with a zero upper half so the q-form zip
+ * instructions can be used; only the lower halves carry data.
+ */
+static inline void transpose_4x4(float16x4_t *rows, float16x4_t *cols) {
+  const float16x4_t zero = vdup_n_f16(0.0f);
+
+  const float16x8_t t0 = vcombine_f16(rows[0], zero);
+  const float16x8_t t1 = vcombine_f16(rows[1], zero);
+  const float16x8_t t2 = vcombine_f16(rows[2], zero);
+  const float16x8_t t3 = vcombine_f16(rows[3], zero);
+
+  /* Interleave rows 0/2 and rows 1/3, then interleave those two results:
+   * the output is r0[k], r1[k], r2[k], r3[k] for k = 0..3 in order. */
+  const float16x8_t t02 = vzip1q_f16(t0, t2);
+  const float16x8_t t13 = vzip1q_f16(t1, t3);
+  const float16x8x2_t t0123 = vzipq_f16(t02, t13);
+
+  cols[0] = vget_low_f16(t0123.val[0]);
+  cols[1] = vget_high_f16(t0123.val[0]);
+  cols[2] = vget_low_f16(t0123.val[1]);
+  cols[3] = vget_high_f16(t0123.val[1]);
+}
+
+/*
+ * Pack `n` source vectors of `m` elements each from `a` into `b`.
+ *
+ * Source vector j starts at a + j * lda (presumably column j of a
+ * column-major matrix -- confirm against the caller). The vectors are
+ * consumed in groups of 8, then at most one group each of 4, 2 and 1,
+ * selected by the bits of n. For a group of width w, the output holds, for
+ * every i in [0, m), the w values vec[0][i] .. vec[w-1][i] contiguously.
+ * Groups are written to `b` back to back.
+ *
+ *   m    number of elements per source vector
+ *   n    number of source vectors
+ *   a    source data
+ *   lda  distance, in elements, between consecutive source vectors
+ *   b    destination buffer; m * n elements are written
+ *
+ * Always returns 0.
+ */
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
+  IFLOAT *a_offset = a;
+  IFLOAT *b_offset = b;
+
+  const BLASLONG m8 = m >> 3;
+  const BLASLONG m4 = m >> 2;
+  const BLASLONG m2 = m >> 1;
+
+  /* Groups of eight source vectors. */
+  for (BLASLONG j = 0; j < (n >> 3); j++) {
+    IFLOAT *src[8];
+    for (int c = 0; c < 8; c++) {
+      src[c] = a_offset + c * lda;
+    }
+    a_offset += 8 * lda;
+
+    /* Full 8x8 tiles: load 8 elements from each vector, transpose, store. */
+    for (BLASLONG i = 0; i < m8; i++) {
+      float16x8_t rows[8];
+      float16x8_t cols[8];
+
+      for (int c = 0; c < 8; c++) {
+        rows[c] = vld1q_f16((float16_t *)src[c]);
+        src[c] += 8;
+      }
+
+      transpose8x8(rows, cols);
+
+      for (int c = 0; c < 8; c++) {
+        vst1q_f16((float16_t *)b_offset + 8 * c, cols[c]);
+      }
+      b_offset += 64;
+    }
+
+    /* Remaining m % 8 elements, one index at a time. */
+    for (BLASLONG k = 0; k < (m & 7); k++) {
+      for (int c = 0; c < 8; c++) {
+        b_offset[c] = *src[c]++;
+      }
+      b_offset += 8;
+    }
+  }
+
+  /* One group of four source vectors. */
+  if (n & 4) {
+    IFLOAT *src[4];
+    for (int c = 0; c < 4; c++) {
+      src[c] = a_offset + c * lda;
+    }
+    a_offset += 4 * lda;
+
+    for (BLASLONG i = 0; i < m4; i++) {
+      float16x4_t rows[4];
+      float16x4_t cols[4];
+
+      for (int c = 0; c < 4; c++) {
+        rows[c] = vld1_f16((float16_t *)src[c]);
+        src[c] += 4;
+      }
+
+      transpose_4x4(rows, cols);
+
+      for (int c = 0; c < 4; c++) {
+        vst1_f16((float16_t *)b_offset + 4 * c, cols[c]);
+      }
+      b_offset += 16;
+    }
+
+    /* Remaining m % 4 elements. */
+    for (BLASLONG k = 0; k < (m & 3); k++) {
+      for (int c = 0; c < 4; c++) {
+        b_offset[c] = *src[c]++;
+      }
+      b_offset += 4;
+    }
+  }
+
+  /* One group of two source vectors. */
+  if (n & 2) {
+    IFLOAT *a0 = a_offset;
+    IFLOAT *a1 = a0 + lda;
+    a_offset += 2 * lda;
+
+    /* Scalar copies on purpose: a 64-bit vld1_f16 fetches four elements
+     * while only two per vector are consumed per iteration, so on the
+     * final iteration it would read past element m - 1 of a0 and a1. */
+    for (BLASLONG i = 0; i < m2; i++) {
+      b_offset[0] = a0[0];
+      b_offset[1] = a1[0];
+      b_offset[2] = a0[1];
+      b_offset[3] = a1[1];
+
+      a0 += 2;
+      a1 += 2;
+      b_offset += 4;
+    }
+
+    if (m & 1) {
+      b_offset[0] = *a0;
+      b_offset[1] = *a1;
+      b_offset += 2;
+    }
+  }
+
+  /* Last single source vector: straight copy. */
+  if (n & 1) {
+    IFLOAT *a0 = a_offset;
+    for (BLASLONG i = 0; i < m; i++) {
+      *b_offset++ = *a0++;
+    }
+  }
+
+  return 0;
+}
\ No newline at end of file diff --git a/kernel/arm64/shgemm_tcopy_8_neoversen2.c b/kernel/arm64/shgemm_tcopy_8_neoversen2.c new file mode 100644 index 0000000000..275abf124f --- /dev/null +++ b/kernel/arm64/shgemm_tcopy_8_neoversen2.c @@ -0,0 +1,87 @@ +/*************************************************************************** + * Copyright (c) 2026, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/
+
+#include <arm_sve.h>
+
+#include "common.h"
+
+/*
+ * Copy `m` source rows of `n` elements each from `a` into `b`, regrouping
+ * every row into chunks of 8, then at most one chunk each of 4, 2 and 1
+ * elements (selected by the bits of n).
+ *
+ * Source row j starts at a + j * lda. Destination layout, as written below:
+ *   - 8-wide chunk i of row j  -> b + i * 8 * m + j * 8
+ *   - 4-wide remainder of row j -> b + m * (n & ~7) + j * 4
+ *   - 2-wide remainder of row j -> b + m * (n & ~3) + j * 2
+ *   - last element of row j     -> b + m * (n & ~1) + j
+ *
+ *   m    number of source rows
+ *   n    number of elements copied from each row
+ *   a    source data
+ *   lda  distance, in elements, between consecutive source rows
+ *   b    destination buffer; m * n elements are written
+ *
+ * Always returns 0.
+ */
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b) {
+  const BLASLONG n8 = n >> 3;
+
+  IFLOAT *src_row = a;
+  IFLOAT *dst8 = b;                    /* start of the 8-wide chunks */
+  IFLOAT *dst4 = b + m * (n & ~7);     /* start of the 4-wide remainder */
+  IFLOAT *dst2 = b + m * (n & ~3);     /* start of the 2-wide remainder */
+  IFLOAT *dst1 = b + m * (n & ~1);     /* start of the 1-wide remainder */
+
+  /* Predicates enabling only the first 8 / first 4 16-bit lanes, so each
+   * predicated load/store moves exactly that many elements. */
+  svbool_t pg8 = svwhilelt_b16(0, 8);
+  svbool_t pg4 = svwhilelt_b16(0, 4);
+
+  for (BLASLONG j = 0; j < m; j++) {
+    IFLOAT *src = src_row;
+    IFLOAT *dst = dst8;
+
+    src_row += lda;
+    dst8 += 8;
+
+    /* Consecutive 8-wide chunks of one row land 8 * m elements apart. */
+    for (BLASLONG i = 0; i < n8; i++) {
+      svfloat16_t v = svld1_f16(pg8, (float16_t *)src);
+      svst1_f16(pg8, (float16_t *)dst, v);
+
+      src += 8;
+      dst += 8 * m;
+    }
+
+    if (n & 4) {
+      svfloat16_t v = svld1_f16(pg4, (float16_t *)src);
+      svst1_f16(pg4, (float16_t *)dst4, v);
+
+      src += 4;
+      dst4 += 4;
+    }
+
+    if (n & 2) {
+      dst2[0] = src[0];
+      dst2[1] = src[1];
+      src += 2;
+      dst2 += 2;
+    }
+
+    if (n & 1) {
+      dst1[0] = src[0];
+      dst1 += 1;
+    }
+  }
+
+  return 0;
+}
\ No newline at end of file diff --git a/param.h b/param.h index 4faaebff7c..c4a1b2520a 100644 --- a/param.h +++ b/param.h @@ -3682,6 +3682,13 @@ is a big desktop or server with abundant cache rather than a phone or embedded d #define SBGEMM_DEFAULT_UNROLL_M 8 #define SBGEMM_DEFAULT_UNROLL_N 8 +#undef SHGEMM_ALIGN_K +#undef SHGEMM_DEFAULT_UNROLL_M +#undef SHGEMM_DEFAULT_UNROLL_N +#define SHGEMM_ALIGN_K 4 +#define SHGEMM_DEFAULT_UNROLL_M 8 +#define SHGEMM_DEFAULT_UNROLL_N 8 + #define SGEMM_DEFAULT_UNROLL_M 16 #define SGEMM_DEFAULT_UNROLL_N 4 diff --git a/test/Makefile b/test/Makefile index f653b70b13..230ee8ef23 100644 --- a/test/Makefile +++ b/test/Makefile @@ -277,6 +277,10 @@ ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM @$(GREP) -q FATAL 
BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 @@ -302,6 +306,11 @@ ifeq ($(BUILD_BFLOAT16),1) OMP_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + OMP_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif + ifeq ($(BUILD_SINGLE),1) OMP_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0 @@ -325,6 +334,10 @@ ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=2 ./test_bgemm > BBLAT3.SUMM @$(GREP) -q FATAL BBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 endif +ifeq ($(BUILD_HFLOAT16),1) + OPENBLAS_NUM_THREADS=2 ./test_shgemm > SHBLAT3.SUMM + @$(GREP) -q FATAL SHBLAT3.SUMM && cat SHBLAT3.SUMM || exit 0 +endif ifeq ($(BUILD_SINGLE),1) OPENBLAS_NUM_THREADS=2 ./sblat3 < ./sblat3.dat @$(GREP) -q FATAL SBLAT3.SUMM && cat SBLAT3.SUMM || exit 0