diff mbox series

[v6,59/82] target/arm: Implement SVE mixed sign dot product (indexed)

Message ID 20210430202610.1136687-60-richard.henderson@linaro.org
State New
Headers show
Series target/arm: Implement SVE2 | expand

Commit Message

Richard Henderson April 30, 2021, 8:25 p.m. UTC
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>

---
 target/arm/cpu.h           |  5 +++
 target/arm/helper.h        |  4 +++
 target/arm/sve.decode      |  4 +++
 target/arm/translate-sve.c | 16 +++++++++
 target/arm/vec_helper.c    | 68 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 97 insertions(+)

-- 
2.25.1

Comments

Peter Maydell May 13, 2021, 12:57 p.m. UTC | #1
On Fri, 30 Apr 2021 at 22:04, Richard Henderson
<richard.henderson@linaro.org> wrote:
>

> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>

> ---

>  target/arm/cpu.h           |  5 +++

>  target/arm/helper.h        |  4 +++

>  target/arm/sve.decode      |  4 +++

>  target/arm/translate-sve.c | 16 +++++++++

>  target/arm/vec_helper.c    | 68 ++++++++++++++++++++++++++++++++++++++

>  5 files changed, 97 insertions(+)


> diff --git a/target/arm/vec_helper.c b/target/arm/vec_helper.c

> index 8b7269d8e1..98b707f4f5 100644

> --- a/target/arm/vec_helper.c

> +++ b/target/arm/vec_helper.c

> @@ -677,6 +677,74 @@ void HELPER(gvec_udot_idx_b)(void *vd, void *vn, void *vm,

>      clear_tail(d, opr_sz, simd_maxsz(desc));

>  }

>

> +void HELPER(gvec_sudot_idx_b)(void *vd, void *vn, void *vm,

> +                              void *va, uint32_t desc)

> +{

> +    intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;

> +    intptr_t index = simd_data(desc);

> +    int32_t *d = vd, *a = va;

> +    int8_t *n = vn;

> +    uint8_t *m_indexed = (uint8_t *)vm + index * 4;

> +

> +    /*

> +     * Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.

> +     * Otherwise opr_sz is a multiple of 16.

> +     */


These are only used by SVE, aren't they ? I guess maintaining
the parallelism with the helpers that are shared is worthwhile.

> +    segend = MIN(4, opr_sz_4);

> +    i = 0;

> +    do {

> +        uint8_t m0 = m_indexed[i * 4 + 0];

> +        uint8_t m1 = m_indexed[i * 4 + 1];

> +        uint8_t m2 = m_indexed[i * 4 + 2];

> +        uint8_t m3 = m_indexed[i * 4 + 3];

> +

> +        do {

> +            d[i] = (a[i] +

> +                    n[i * 4 + 0] * m0 +

> +                    n[i * 4 + 1] * m1 +

> +                    n[i * 4 + 2] * m2 +

> +                    n[i * 4 + 3] * m3);

> +        } while (++i < segend);

> +        segend = i + 4;

> +    } while (i < opr_sz_4);

> +

> +    clear_tail(d, opr_sz, simd_maxsz(desc));

> +}

> +

> +void HELPER(gvec_usdot_idx_b)(void *vd, void *vn, void *vm,

> +                              void *va, uint32_t desc)

> +{

> +    intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;

> +    intptr_t index = simd_data(desc);

> +    uint32_t *d = vd, *a = va;

> +    uint8_t *n = vn;

> +    int8_t *m_indexed = (int8_t *)vm + index * 4;

> +

> +    /*

> +     * Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.

> +     * Otherwise opr_sz is a multiple of 16.

> +     */

> +    segend = MIN(4, opr_sz_4);

> +    i = 0;

> +    do {

> +        int8_t m0 = m_indexed[i * 4 + 0];

> +        int8_t m1 = m_indexed[i * 4 + 1];

> +        int8_t m2 = m_indexed[i * 4 + 2];

> +        int8_t m3 = m_indexed[i * 4 + 3];

> +

> +        do {

> +            d[i] = (a[i] +

> +                    n[i * 4 + 0] * m0 +

> +                    n[i * 4 + 1] * m1 +

> +                    n[i * 4 + 2] * m2 +

> +                    n[i * 4 + 3] * m3);

> +        } while (++i < segend);

> +        segend = i + 4;

> +    } while (i < opr_sz_4);

> +

> +    clear_tail(d, opr_sz, simd_maxsz(desc));

> +}


Maybe we should macroify this, as unless I'm misreading them
gvec_sdot_idx_b, gvec_udot_idx_b, gvec_sudot_idx_b and gvec_usdot_idx_b
only differ in the types of the index and the data.

But if you'd rather not you can have a
Reviewed-by: Peter Maydell <peter.maydell@linaro.org>

for this version.

thanks
-- PMM
Richard Henderson May 14, 2021, 6:47 p.m. UTC | #2
On 5/13/21 7:57 AM, Peter Maydell wrote:
> Maybe we should macroify this, as unless I'm misreading them

> gvec_sdot_idx_b, gvec_udot_idx_b, gvec_sudot_idx_b and gvec_usdot_idx_b

> only differ in the types of the index and the data.


Done.

r~
diff mbox series

Patch

diff --git a/target/arm/cpu.h b/target/arm/cpu.h
index 595bc6349d..0a41142d35 100644
--- a/target/arm/cpu.h
+++ b/target/arm/cpu.h
@@ -4246,6 +4246,11 @@  static inline bool isar_feature_aa64_sve2_bitperm(const ARMISARegisters *id)
     return FIELD_EX64(id->id_aa64zfr0, ID_AA64ZFR0, BITPERM) != 0;
 }
 
+static inline bool isar_feature_aa64_sve_i8mm(const ARMISARegisters *id)
+{
+    return FIELD_EX64(id->id_aa64zfr0, ID_AA64ZFR0, I8MM) != 0;
+}
+
 static inline bool isar_feature_aa64_sve_f32mm(const ARMISARegisters *id)
 {
     return FIELD_EX64(id->id_aa64zfr0, ID_AA64ZFR0, F32MM) != 0;
diff --git a/target/arm/helper.h b/target/arm/helper.h
index e7c463fff5..e4c6458f98 100644
--- a/target/arm/helper.h
+++ b/target/arm/helper.h
@@ -621,6 +621,10 @@  DEF_HELPER_FLAGS_5(gvec_sdot_idx_h, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(gvec_udot_idx_h, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_5(gvec_sudot_idx_b, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_5(gvec_usdot_idx_b, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
 
 DEF_HELPER_FLAGS_5(gvec_fcaddh, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
diff --git a/target/arm/sve.decode b/target/arm/sve.decode
index 35010d755f..05360e2608 100644
--- a/target/arm/sve.decode
+++ b/target/arm/sve.decode
@@ -813,6 +813,10 @@  SQRDMLSH_zzxz_h 01000100 0. 1 ..... 000101 ..... .....   @rrxr_3 esz=1
 SQRDMLSH_zzxz_s 01000100 10 1 ..... 000101 ..... .....   @rrxr_2 esz=2
 SQRDMLSH_zzxz_d 01000100 11 1 ..... 000101 ..... .....   @rrxr_1 esz=3
 
+# SVE mixed sign dot product (indexed)
+USDOT_zzxw_s    01000100 10 1 ..... 000110 ..... .....   @rrxr_2 esz=2
+SUDOT_zzxw_s    01000100 10 1 ..... 000111 ..... .....   @rrxr_2 esz=2
+
 # SVE2 saturating multiply-add (indexed)
 SQDMLALB_zzxw_s 01000100 10 1 ..... 0010.0 ..... .....   @rrxr_3a esz=2
 SQDMLALB_zzxw_d 01000100 11 1 ..... 0010.0 ..... .....   @rrxr_2a esz=3
diff --git a/target/arm/translate-sve.c b/target/arm/translate-sve.c
index c72d1e4bf0..c988d0125a 100644
--- a/target/arm/translate-sve.c
+++ b/target/arm/translate-sve.c
@@ -3838,6 +3838,22 @@  DO_RRXR(trans_SDOT_zzxw_d, gen_helper_gvec_sdot_idx_h)
 DO_RRXR(trans_UDOT_zzxw_s, gen_helper_gvec_udot_idx_b)
 DO_RRXR(trans_UDOT_zzxw_d, gen_helper_gvec_udot_idx_h)
 
+static bool trans_SUDOT_zzxw_s(DisasContext *s, arg_rrxr_esz *a)
+{
+    if (!dc_isar_feature(aa64_sve_i8mm, s)) {
+        return false;
+    }
+    return do_zzxz_data(s, a, gen_helper_gvec_sudot_idx_b, a->index);
+}
+
+static bool trans_USDOT_zzxw_s(DisasContext *s, arg_rrxr_esz *a)
+{
+    if (!dc_isar_feature(aa64_sve_i8mm, s)) {
+        return false;
+    }
+    return do_zzxz_data(s, a, gen_helper_gvec_usdot_idx_b, a->index);
+}
+
 #undef DO_RRXR
 
 static bool do_sve2_zzx_data(DisasContext *s, arg_rrx_esz *a,
diff --git a/target/arm/vec_helper.c b/target/arm/vec_helper.c
index 8b7269d8e1..98b707f4f5 100644
--- a/target/arm/vec_helper.c
+++ b/target/arm/vec_helper.c
@@ -677,6 +677,74 @@  void HELPER(gvec_udot_idx_b)(void *vd, void *vn, void *vm,
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }
 
+void HELPER(gvec_sudot_idx_b)(void *vd, void *vn, void *vm,
+                              void *va, uint32_t desc)
+{
+    intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;
+    intptr_t index = simd_data(desc);
+    int32_t *d = vd, *a = va;
+    int8_t *n = vn;
+    uint8_t *m_indexed = (uint8_t *)vm + index * 4;
+
+    /*
+     * Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.
+     * Otherwise opr_sz is a multiple of 16.
+     */
+    segend = MIN(4, opr_sz_4);
+    i = 0;
+    do {
+        uint8_t m0 = m_indexed[i * 4 + 0];
+        uint8_t m1 = m_indexed[i * 4 + 1];
+        uint8_t m2 = m_indexed[i * 4 + 2];
+        uint8_t m3 = m_indexed[i * 4 + 3];
+
+        do {
+            d[i] = (a[i] +
+                    n[i * 4 + 0] * m0 +
+                    n[i * 4 + 1] * m1 +
+                    n[i * 4 + 2] * m2 +
+                    n[i * 4 + 3] * m3);
+        } while (++i < segend);
+        segend = i + 4;
+    } while (i < opr_sz_4);
+
+    clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
+void HELPER(gvec_usdot_idx_b)(void *vd, void *vn, void *vm,
+                              void *va, uint32_t desc)
+{
+    intptr_t i, segend, opr_sz = simd_oprsz(desc), opr_sz_4 = opr_sz / 4;
+    intptr_t index = simd_data(desc);
+    uint32_t *d = vd, *a = va;
+    uint8_t *n = vn;
+    int8_t *m_indexed = (int8_t *)vm + index * 4;
+
+    /*
+     * Notice the special case of opr_sz == 8, from aa64/aa32 advsimd.
+     * Otherwise opr_sz is a multiple of 16.
+     */
+    segend = MIN(4, opr_sz_4);
+    i = 0;
+    do {
+        int8_t m0 = m_indexed[i * 4 + 0];
+        int8_t m1 = m_indexed[i * 4 + 1];
+        int8_t m2 = m_indexed[i * 4 + 2];
+        int8_t m3 = m_indexed[i * 4 + 3];
+
+        do {
+            d[i] = (a[i] +
+                    n[i * 4 + 0] * m0 +
+                    n[i * 4 + 1] * m1 +
+                    n[i * 4 + 2] * m2 +
+                    n[i * 4 + 3] * m3);
+        } while (++i < segend);
+        segend = i + 4;
+    } while (i < opr_sz_4);
+
+    clear_tail(d, opr_sz, simd_maxsz(desc));
+}
+
 void HELPER(gvec_sdot_idx_h)(void *vd, void *vn, void *vm,
                              void *va, uint32_t desc)
 {