[v3,30/51] target/arm: Implement FMOPA, FMOPS (non-widening)

Message ID 20220620175235.60881-31-richard.henderson@linaro.org
State Superseded
Series target/arm: Scalable Matrix Extension

Commit Message

Richard Henderson June 20, 2022, 5:52 p.m. UTC
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 target/arm/helper-sme.h    |  5 +++
 target/arm/sme.decode      |  9 +++++
 target/arm/sme_helper.c    | 67 ++++++++++++++++++++++++++++++++++++++
 target/arm/translate-sme.c | 33 +++++++++++++++++++
 4 files changed, 114 insertions(+)

Comments

Peter Maydell June 24, 2022, 12:31 p.m. UTC | #1
On Mon, 20 Jun 2022 at 19:07, Richard Henderson
<richard.henderson@linaro.org> wrote:
>
> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>

> +void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
> +                         void *vpm, void *vst, uint32_t desc)
> +{
> +    intptr_t row, col, oprsz = simd_maxsz(desc);
> +    uint32_t neg = simd_data(desc) << 31;
> +    uint16_t *pn = vpn, *pm = vpm;
> +
> +    bool save_dn = get_default_nan_mode(vst);
> +    set_default_nan_mode(true, vst);
> +
> +    for (row = 0; row < oprsz; ) {
> +        uint16_t pa = pn[H2(row >> 4)];
> +        do {
> +            if (pa & 1) {
> +                void *vza_row = vza + row * sizeof(ARMVectorReg);
> +                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
> +
> +                for (col = 0; col < oprsz; ) {
> +                    uint16_t pb = pm[H2(col >> 4)];
> +                    do {
> +                        if (pb & 1) {
> +                            uint32_t *a = vza_row + col;
> +                            uint32_t *m = vzm + col;
> +                            *a = float32_muladd(n, *m, *a, 0, vst);
> +                        }
> +                        col += 4;
> +                        pb >>= 4;
> +                    } while (col & 15);
> +                }
> +            }
> +            row += 4;
> +            pa >>= 4;
> +        } while (row & 15);
> +    }

The code for the double version seems straightforward:
row counts from 0 up to the number of rows, and we
do something per row. Why is the single precision version
doing something with an unrolled loop here? It's confusing
that 'oprsz' in the two functions isn't the same thing --
in the double version we divide by the element size, but
here we don't.

> +
> +    set_default_nan_mode(save_dn, vst);
> +}
> +
> +void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
> +                         void *vpm, void *vst, uint32_t desc)
> +{
> +    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
> +    uint64_t neg = (uint64_t)simd_data(desc) << 63;
> +    uint64_t *za = vza, *zn = vzn, *zm = vzm;
> +    uint8_t *pn = vpn, *pm = vpm;
> +
> +    bool save_dn = get_default_nan_mode(vst);
> +    set_default_nan_mode(true, vst);
> +
> +    for (row = 0; row < oprsz; ++row) {
> +        if (pn[H1(row)] & 1) {
> +            uint64_t *za_row = &za[row * sizeof(ARMVectorReg)];
> +            uint64_t n = zn[row] ^ neg;
> +
> +            for (col = 0; col < oprsz; ++col) {
> +                if (pm[H1(col)] & 1) {
> +                    uint64_t *a = &za_row[col];
> +                    *a = float64_muladd(n, zm[col], *a, 0, vst);
> +                }
> +            }
> +        }
> +    }
> +
> +    set_default_nan_mode(save_dn, vst);
> +}

The pseudocode says that we ignore floating point exceptions
(i.e. do not accumulate them in the FPSR) -- it passes fpexc == false
to FPMulAdd(). Don't we need to do something special to arrange
for that?

thanks
-- PMM
Richard Henderson June 24, 2022, 2:16 p.m. UTC | #2
On 6/24/22 05:31, Peter Maydell wrote:
> On Mon, 20 Jun 2022 at 19:07, Richard Henderson
> <richard.henderson@linaro.org> wrote:
>>
>> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
> 
>> +void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
>> +                         void *vpm, void *vst, uint32_t desc)
>> +{
>> +    intptr_t row, col, oprsz = simd_maxsz(desc);
>> +    uint32_t neg = simd_data(desc) << 31;
>> +    uint16_t *pn = vpn, *pm = vpm;
>> +
>> +    bool save_dn = get_default_nan_mode(vst);
>> +    set_default_nan_mode(true, vst);
>> +
>> +    for (row = 0; row < oprsz; ) {
>> +        uint16_t pa = pn[H2(row >> 4)];
>> +        do {
>> +            if (pa & 1) {
>> +                void *vza_row = vza + row * sizeof(ARMVectorReg);
>> +                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
>> +
>> +                for (col = 0; col < oprsz; ) {
>> +                    uint16_t pb = pm[H2(col >> 4)];
>> +                    do {
>> +                        if (pb & 1) {
>> +                            uint32_t *a = vza_row + col;
>> +                            uint32_t *m = vzm + col;
>> +                            *a = float32_muladd(n, *m, *a, 0, vst);
>> +                        }
>> +                        col += 4;
>> +                        pb >>= 4;
>> +                    } while (col & 15);
>> +                }
>> +            }
>> +            row += 4;
>> +            pa >>= 4;
>> +        } while (row & 15);
>> +    }
> 
> The code for the double version seems straightforward:
> row counts from 0 up to the number of rows, and we
> do something per row. Why is the single precision version
> doing something with an unrolled loop here? It's confusing
> that 'oprsz' in the two functions isn't the same thing --
> in the double version we divide by the element size, but
> here we don't.

It's all about the predicate addressing.  For doubles, the bits are spaced 8 bits apart,
which makes it easy, as you see.  For singles, the bits are spaced 4 bits apart, which is
inconvenient.  Anyway, just as over in sve_helper.c, I load a uint16_t at a time and shift
to find each predicate bit.

So it's not unrolled, exactly.  There's a second loop over the predicates.  And since this
is a matrix op, we get loops nested four deep.
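For illustration only, here is a minimal sketch (not code from the patch) of the predicate
layout being described, written per element index rather than per byte offset and ignoring
the host-endian H1()/H2() adjustments; the helper name is hypothetical:

    /*
     * SVE/SME predicates hold one bit per vector byte, so the governing bit
     * for 32-bit element i is bit (i * 4) of the predicate: byte i / 2,
     * bit (i & 1) * 4.  For 64-bit element i it is bit (i * 8), i.e. bit 0
     * of byte i, which is why the double helper can simply test pn[i] & 1.
     */
    static inline bool pred_elem_active_32(const uint8_t *pg, intptr_t i)
    {
        return (pg[i >> 1] >> ((i & 1) * 4)) & 1;
    }

The helper in the patch instead walks the predicate a uint16_t at a time and shifts,
amortising each predicate load across the four 32-bit elements it governs.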

> The pseudocode says that we ignore floating point exceptions
> (i.e. do not accumulate them in the FPSR) -- it passes fpexc == false
> to FPMulAdd(). Don't we need to do something special to arrange
> for that?

Oops, somewhere I read that as "do not trap" rather than "do not accumulate".
But R_TGSKG is very clear on this: accumulate.
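
Whichever reading wins, the softfloat status argument makes either behaviour easy to
express.  Purely as a sketch, if the flags were not meant to accumulate, the cumulative
exception flags could be saved and restored around the loop, mirroring the
get_default_nan_mode()/set_default_nan_mode() save/restore the helper already does
(get_float_exception_flags()/set_float_exception_flags() are the existing softfloat
accessors; the placement inside the helper is assumed):

    int save_flags = get_float_exception_flags(vst);
    /* ... the float32_muladd()/float64_muladd() loops run here ... */
    set_float_exception_flags(save_flags, vst);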


r~

Patch

diff --git a/target/arm/helper-sme.h b/target/arm/helper-sme.h
index 6f0fce7e2c..727095a3eb 100644
--- a/target/arm/helper-sme.h
+++ b/target/arm/helper-sme.h
@@ -119,3 +119,8 @@  DEF_HELPER_FLAGS_5(sme_addha_s, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addva_s, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addha_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addva_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
+
+DEF_HELPER_FLAGS_7(sme_fmopa_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_7(sme_fmopa_d, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
diff --git a/target/arm/sme.decode b/target/arm/sme.decode
index 8cb6c4053c..ba4774d174 100644
--- a/target/arm/sme.decode
+++ b/target/arm/sme.decode
@@ -64,3 +64,12 @@  ADDHA_s         11000000 10 01000 0 ... ... ..... 000 ..        @adda_32
 ADDVA_s         11000000 10 01000 1 ... ... ..... 000 ..        @adda_32
 ADDHA_d         11000000 11 01000 0 ... ... ..... 00 ...        @adda_64
 ADDVA_d         11000000 11 01000 1 ... ... ..... 00 ...        @adda_64
+
+### SME Outer Product
+
+&op             zad zn zm pm pn sub:bool
+@op_32          ........ ... zm:5 pm:3 pn:3 zn:5 sub:1 .. zad:2 &op
+@op_64          ........ ... zm:5 pm:3 pn:3 zn:5 sub:1 .  zad:3 &op
+
+FMOPA_s         10000000 100 ..... ... ... ..... . 00 ..        @op_32
+FMOPA_d         10000000 110 ..... ... ... ..... . 0 ...        @op_64
diff --git a/target/arm/sme_helper.c b/target/arm/sme_helper.c
index 799e44c047..62d9690cae 100644
--- a/target/arm/sme_helper.c
+++ b/target/arm/sme_helper.c
@@ -25,6 +25,7 @@ 
 #include "exec/cpu_ldst.h"
 #include "exec/exec-all.h"
 #include "qemu/int128.h"
+#include "fpu/softfloat.h"
 #include "vec_internal.h"
 #include "sve_ldst_internal.h"
 
@@ -897,3 +898,69 @@  void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
         }
     }
 }
+
+void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, void *vst, uint32_t desc)
+{
+    intptr_t row, col, oprsz = simd_maxsz(desc);
+    uint32_t neg = simd_data(desc) << 31;
+    uint16_t *pn = vpn, *pm = vpm;
+
+    bool save_dn = get_default_nan_mode(vst);
+    set_default_nan_mode(true, vst);
+
+    for (row = 0; row < oprsz; ) {
+        uint16_t pa = pn[H2(row >> 4)];
+        do {
+            if (pa & 1) {
+                void *vza_row = vza + row * sizeof(ARMVectorReg);
+                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
+
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pb = pm[H2(col >> 4)];
+                    do {
+                        if (pb & 1) {
+                            uint32_t *a = vza_row + col;
+                            uint32_t *m = vzm + col;
+                            *a = float32_muladd(n, *m, *a, 0, vst);
+                        }
+                        col += 4;
+                        pb >>= 4;
+                    } while (col & 15);
+                }
+            }
+            row += 4;
+            pa >>= 4;
+        } while (row & 15);
+    }
+
+    set_default_nan_mode(save_dn, vst);
+}
+
+void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, void *vst, uint32_t desc)
+{
+    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
+    uint64_t neg = (uint64_t)simd_data(desc) << 63;
+    uint64_t *za = vza, *zn = vzn, *zm = vzm;
+    uint8_t *pn = vpn, *pm = vpm;
+
+    bool save_dn = get_default_nan_mode(vst);
+    set_default_nan_mode(true, vst);
+
+    for (row = 0; row < oprsz; ++row) {
+        if (pn[H1(row)] & 1) {
+            uint64_t *za_row = &za[row * sizeof(ARMVectorReg)];
+            uint64_t n = zn[row] ^ neg;
+
+            for (col = 0; col < oprsz; ++col) {
+                if (pm[H1(col)] & 1) {
+                    uint64_t *a = &za_row[col];
+                    *a = float64_muladd(n, zm[col], *a, 0, vst);
+                }
+            }
+        }
+    }
+
+    set_default_nan_mode(save_dn, vst);
+}
diff --git a/target/arm/translate-sme.c b/target/arm/translate-sme.c
index e9676b2415..e6e4541e76 100644
--- a/target/arm/translate-sme.c
+++ b/target/arm/translate-sme.c
@@ -273,3 +273,36 @@  TRANS_FEAT(ADDHA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addha_s)
 TRANS_FEAT(ADDVA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addva_s)
 TRANS_FEAT(ADDHA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addha_d)
 TRANS_FEAT(ADDVA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addva_d)
+
+static bool do_outprod_fpst(DisasContext *s, arg_op *a, MemOp esz,
+                            gen_helper_gvec_5_ptr *fn)
+{
+    uint32_t desc = simd_desc(s->svl, s->svl, a->sub);
+    TCGv_ptr za, zn, zm, pn, pm, fpst;
+
+    if (!sme_smza_enabled_check(s)) {
+        return true;
+    }
+
+    /* Sum XZR+zad to find ZAd. */
+    za = get_tile_rowcol(s, esz, 31, a->zad, false);
+    zn = vec_full_reg_ptr(s, a->zn);
+    zm = vec_full_reg_ptr(s, a->zm);
+    pn = pred_full_reg_ptr(s, a->pn);
+    pm = pred_full_reg_ptr(s, a->pm);
+    fpst = fpstatus_ptr(FPST_FPCR);
+
+    fn(za, zn, zm, pn, pm, fpst, tcg_constant_i32(desc));
+
+    tcg_temp_free_ptr(za);
+    tcg_temp_free_ptr(zn);
+    tcg_temp_free_ptr(pn);
+    tcg_temp_free_ptr(pm);
+    tcg_temp_free_ptr(fpst);
+    return true;
+}
+
+TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst,
+           a, MO_32, gen_helper_sme_fmopa_s)
+TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst,
+           a, MO_64, gen_helper_sme_fmopa_d)