diff mbox series

[v5,26/45] target/arm: Implement FMOPA, FMOPS (widening)

Message ID 20220706082411.1664825-27-richard.henderson@linaro.org
State Superseded
Headers show
Series target/arm: Scalable Matrix Extension | expand

Commit Message

Richard Henderson July 6, 2022, 8:23 a.m. UTC
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 target/arm/helper-sme.h    |  2 ++
 target/arm/sme.decode      |  1 +
 target/arm/sme_helper.c    | 68 ++++++++++++++++++++++++++++++++++++++
 target/arm/translate-sme.c |  1 +
 4 files changed, 72 insertions(+)

Comments

Peter Maydell July 7, 2022, 9:50 a.m. UTC | #1
On Wed, 6 Jul 2022 at 10:26, Richard Henderson
<richard.henderson@linaro.org> wrote:
>
> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>


> +static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
> +                          float_status *s_std, float_status *s_odd)
> +{
> +    float64 e1r = float16_to_float64(e1 & 0xffff, true, s_std);
> +    float64 e1c = float16_to_float64(e1 >> 16, true, s_std);
> +    float64 e2r = float16_to_float64(e2 & 0xffff, true, s_std);
> +    float64 e2c = float16_to_float64(e2 >> 16, true, s_std);
> +    float64 t64;
> +    float32 t32;
> +
> +    /*
> +     * The ARM pseudocode function FPDot performs both multiplies
> +     * and the add with a single rounding operation.  Emulate this
> +     * by performing the first multiply in round-to-odd, then doing
> +     * the second multiply as fused multiply-add, and rounding to
> +     * float32 all in one step.
> +     */

I guess if we find we're not producing quite bit-accurate results
we can come back and revisit this :-)

> +    t64 = float64_mul(e1r, e2r, s_odd);
> +    t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
> +
> +    /* This conversion is exact, because we've already rounded. */
> +    t32 = float64_to_float32(t64, s_std);
> +
> +    /* The final accumulation step is not fused. */
> +    return float32_add(sum, t32, s_std);
> +}
> +
> +void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
> +                         void *vpm, void *vst, uint32_t desc)
> +{
> +    intptr_t row, col, oprsz = simd_maxsz(desc);
> +    uint32_t neg = simd_data(desc) << 15;
> +    uint16_t *pn = vpn, *pm = vpm;
> +    float_status fpst_odd, fpst_std = *(float_status *)vst;
> +
> +    set_default_nan_mode(true, &fpst_std);
> +    fpst_odd = fpst_std;
> +    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
> +
> +    for (row = 0; row < oprsz; ) {
> +        uint16_t pa = pn[H2(row >> 4)];
> +        do {
> +            void *vza_row = vza + tile_vslice_offset(row);
> +            uint32_t n = *(uint32_t *)(vzn + row);

More missing H macros.

> +
> +            n = f16mop_adj_pair(n, pa, neg);
> +
> +            for (col = 0; col < oprsz; ) {
> +                uint16_t pb = pm[H2(col >> 4)];
> +                do {
> +                    if ((pa & 0b0101) == 0b0101 || (pb & 0b0101) == 0b0101) {

Wrong condition again?

> +                        uint32_t *a = vza_row + col;
> +                        uint32_t m = *(uint32_t *)(vzm + col);
> +
> +                        m = f16mop_adj_pair(m, pb, neg);
> +                        *a = f16_dotadd(*a, n, m, &fpst_std, &fpst_odd);
> +
> +                        col += 4;
> +                        pb >>= 4;
> +                    }
> +                } while (col & 15);
> +            }
> +            row += 4;
> +            pa >>= 4;
> +        } while (row & 15);
> +    }
> +}

thanks
-- PMM
diff mbox series

Patch

diff --git a/target/arm/helper-sme.h b/target/arm/helper-sme.h
index 1d68fb8c74..4d5d05db3a 100644
--- a/target/arm/helper-sme.h
+++ b/target/arm/helper-sme.h
@@ -121,6 +121,8 @@  DEF_HELPER_FLAGS_5(sme_addva_s, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addha_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addva_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 
+DEF_HELPER_FLAGS_7(sme_fmopa_h, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_7(sme_fmopa_s, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_7(sme_fmopa_d, TCG_CALL_NO_RWG,
diff --git a/target/arm/sme.decode b/target/arm/sme.decode
index afd9c0dffd..e8d27fd8a0 100644
--- a/target/arm/sme.decode
+++ b/target/arm/sme.decode
@@ -75,3 +75,4 @@  FMOPA_s         10000000 100 ..... ... ... ..... . 00 ..        @op_32
 FMOPA_d         10000000 110 ..... ... ... ..... . 0 ...        @op_64
 
 BFMOPA          10000001 100 ..... ... ... ..... . 00 ..        @op_32
+FMOPA_h         10000001 101 ..... ... ... ..... . 00 ..        @op_32
diff --git a/target/arm/sme_helper.c b/target/arm/sme_helper.c
index 4b437bb913..e92f53ecab 100644
--- a/target/arm/sme_helper.c
+++ b/target/arm/sme_helper.c
@@ -998,6 +998,101 @@  static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
     return pair;
 }
 
+/*
+ * Return SUM + (E1[0] * E2[0] + E1[1] * E2[1]) as float32, where E1
+ * and E2 each pack two float16 values: element 0 in bits [15:0] and
+ * element 1 in bits [31:16].  S_ODD (a round-to-odd copy of S_STD)
+ * is used only for the first multiply; S_STD for everything else.
+ */
+static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
+                          float_status *s_std, float_status *s_odd)
+{
+    float64 e1r = float16_to_float64(e1 & 0xffff, true, s_std);
+    float64 e1c = float16_to_float64(e1 >> 16, true, s_std);
+    float64 e2r = float16_to_float64(e2 & 0xffff, true, s_std);
+    float64 e2c = float16_to_float64(e2 >> 16, true, s_std);
+    float64 t64;
+    float32 t32;
+
+    /*
+     * The ARM pseudocode function FPDot performs both multiplies
+     * and the add with a single rounding operation.  Emulate this
+     * by performing the first multiply in round-to-odd, then doing
+     * the second multiply as fused multiply-add, and rounding to
+     * float32 all in one step.
+     */
+    t64 = float64_mul(e1r, e2r, s_odd);
+    t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
+
+    /* This conversion is exact, because we've already rounded. */
+    t32 = float64_to_float32(t64, s_std);
+
+    /* The final accumulation step is not fused. */
+    return float32_add(sum, t32, s_std);
+}
+
+/*
+ * FMOPA, FMOPS (widening): float16 outer product accumulated into a
+ * float32 ZA tile.  Each 32-bit element of ZN/ZM holds a pair of
+ * float16 values; the pair's two predicate bits are at offsets 0 and 2
+ * of the corresponding nibble of PN/PM.
+ */
+void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, void *vst, uint32_t desc)
+{
+    intptr_t row, col, oprsz = simd_maxsz(desc);
+    /*
+     * FMOPS negates the row operand only (negating both operands would
+     * cancel in each product).  The mask covers the sign bits of both
+     * float16 halves of a 32-bit pair.
+     */
+    uint32_t neg = simd_data(desc) * 0x80008000u;
+    uint16_t *pn = vpn, *pm = vpm;
+    float_status fpst_odd, fpst_std;
+
+    /*
+     * Work on local copies of float_status: one with default-NaN mode,
+     * and a round-to-odd variant of it for f16_dotadd.  Exception flags
+     * raised here are not written back to *VST.
+     */
+    fpst_std = *(float_status *)vst;
+    set_default_nan_mode(true, &fpst_std);
+    fpst_odd = fpst_std;
+    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
+
+    for (row = 0; row < oprsz; ) {
+        uint16_t prow = pn[H2(row >> 4)];
+        do {
+            void *vza_row = vza + tile_vslice_offset(row);
+            uint32_t n = *(uint32_t *)(vzn + H1_4(row));
+
+            n = f16mop_adj_pair(n, prow, neg);
+
+            for (col = 0; col < oprsz; ) {
+                uint16_t pcol = pm[H2(col >> 4)];
+                do {
+                    /*
+                     * Update the tile element when, for at least one
+                     * half of the pair, both row and column are active.
+                     */
+                    if (prow & pcol & 0b0101) {
+                        uint32_t *a = vza_row + H1_4(col);
+                        uint32_t m = *(uint32_t *)(vzm + H1_4(col));
+
+                        m = f16mop_adj_pair(m, pcol, 0);
+                        *a = f16_dotadd(*a, n, m, &fpst_std, &fpst_odd);
+                    }
+                    /* Advance even when this element is inactive. */
+                    col += 4;
+                    pcol >>= 4;
+                } while (col & 15);
+            }
+            row += 4;
+            prow >>= 4;
+        } while (row & 15);
+    }
+}
+
 void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn,
                         void *vpm, uint32_t desc)
 {
diff --git a/target/arm/translate-sme.c b/target/arm/translate-sme.c
index ecb7583c55..c2953b22ce 100644
--- a/target/arm/translate-sme.c
+++ b/target/arm/translate-sme.c
@@ -355,6 +355,7 @@  static bool do_outprod_fpst(DisasContext *s, arg_op *a, MemOp esz,
     return true;
 }
 
+TRANS_FEAT(FMOPA_h, aa64_sme, do_outprod_fpst, a, MO_32, gen_helper_sme_fmopa_h)
 TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst, a, MO_32, gen_helper_sme_fmopa_s)
 TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst, a, MO_64, gen_helper_sme_fmopa_d)