diff mbox series

[32/55] target/arm: Implement MVE VRMLALDAVH, VRMLSLDAVH

Message ID 20210607165821.9892-33-peter.maydell@linaro.org
State Superseded
Headers show
Series target/arm: First slice of MVE implementation | expand

Commit Message

Peter Maydell June 7, 2021, 4:57 p.m. UTC
Implement the MVE VRMLALDAVH and VRMLSLDAVH insns, which accumulate
the results of a rounded multiply of pairs of elements into a 72-bit
accumulator, returning the top 64 bits in a pair of general purpose
registers.

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>

---
 target/arm/helper-mve.h    |  8 ++++++++
 target/arm/mve.decode      |  7 +++++++
 target/arm/mve_helper.c    | 35 +++++++++++++++++++++++++++++++++++
 target/arm/translate-mve.c | 24 ++++++++++++++++++++++++
 4 files changed, 74 insertions(+)

-- 
2.20.1

Comments

Richard Henderson June 9, 2021, 1:05 a.m. UTC | #1
On 6/7/21 9:57 AM, Peter Maydell wrote:
> +#define DO_LDAVH(OP, ESIZE, TYPE, H, XCHG, EVENACC, ODDACC, TO128)      \

> +    uint64_t HELPER(glue(mve_, OP))(CPUARMState *env, void *vn,         \

> +                                    void *vm, uint64_t a)               \

> +    {                                                                   \

> +        uint16_t mask = mve_element_mask(env);                          \

> +        unsigned e;                                                     \

> +        TYPE *n = vn, *m = vm;                                          \

> +        Int128 acc = TO128(a);                                          \


This seems to miss the << 8.

Which suggests that the whole thing can be done without Int128:

> +        for (e = 0; e < 16 / ESIZE; e++, mask >>= ESIZE) {              \

> +            if (mask & 1) {                                             \

> +                if (e & 1) {                                            \

> +                    acc = ODDACC(acc, TO128(n[H(e - 1 * XCHG)] * m[H(e)])); \


   tmp = n * m;
   tmp = (tmp >> 8) + ((tmp >> 7) & 1);
   acc ODDACC tmp;

> +static bool trans_VRMLALDAVH_S(DisasContext *s, arg_vmlaldav *a)

> +{

> +    MVEGenDualAccOpFn *fns[] = {


static const, etc.


r~
Peter Maydell June 14, 2021, 10:19 a.m. UTC | #2
On Wed, 9 Jun 2021 at 02:05, Richard Henderson
<richard.henderson@linaro.org> wrote:
>

> On 6/7/21 9:57 AM, Peter Maydell wrote:

> > +#define DO_LDAVH(OP, ESIZE, TYPE, H, XCHG, EVENACC, ODDACC, TO128)      \

> > +    uint64_t HELPER(glue(mve_, OP))(CPUARMState *env, void *vn,         \

> > +                                    void *vm, uint64_t a)               \

> > +    {                                                                   \

> > +        uint16_t mask = mve_element_mask(env);                          \

> > +        unsigned e;                                                     \

> > +        TYPE *n = vn, *m = vm;                                          \

> > +        Int128 acc = TO128(a);                                          \

>

> This seems to miss the << 8.


Oops, yes it does.

> Which suggests that the whole thing can be done without Int128:

>

> > +        for (e = 0; e < 16 / ESIZE; e++, mask >>= ESIZE) {              \

> > +            if (mask & 1) {                                             \

> > +                if (e & 1) {                                            \

> > +                    acc = ODDACC(acc, TO128(n[H(e - 1 * XCHG)] * m[H(e)])); \

>

>    tmp = n * m;

>    tmp = (tmp >> 8) + ((tmp >> 7) & 1);

>    acc ODDACC tmp;


I'm not sure about this suggestion though. It throws away all
of the bottom 7 bits of the product, but because we're iterating through
this 4 times and adding (potentially) four of these products together,
those bottom 7 bits in the 4 products might be able to add together
to become significant enough to affect the final result.

-- PMM
diff mbox series

Patch

diff --git a/target/arm/helper-mve.h b/target/arm/helper-mve.h
index 7789da1986b..723bef4a83a 100644
--- a/target/arm/helper-mve.h
+++ b/target/arm/helper-mve.h
@@ -159,3 +159,11 @@  DEF_HELPER_FLAGS_4(mve_vmlsldavsh, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
 DEF_HELPER_FLAGS_4(mve_vmlsldavsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
 DEF_HELPER_FLAGS_4(mve_vmlsldavxsh, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
 DEF_HELPER_FLAGS_4(mve_vmlsldavxsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
+
+DEF_HELPER_FLAGS_4(mve_vrmlaldavhsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
+DEF_HELPER_FLAGS_4(mve_vrmlaldavhxsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
+
+DEF_HELPER_FLAGS_4(mve_vrmlaldavhuw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
+
+DEF_HELPER_FLAGS_4(mve_vrmlsldavhsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
+DEF_HELPER_FLAGS_4(mve_vrmlsldavhxsw, TCG_CALL_NO_WG, i64, env, ptr, ptr, i64)
diff --git a/target/arm/mve.decode b/target/arm/mve.decode
index 1be2d6b270f..ac68f072bbe 100644
--- a/target/arm/mve.decode
+++ b/target/arm/mve.decode
@@ -143,7 +143,14 @@  VDUP             1110 1110 1 0 10 ... 0 .... 1011 . 0 0 1 0000 @vdup size=2
 
 @vmlaldav        .... .... . ... ... . ... . .... .... qm:3 . \
                  qn=%qn rdahi=%rdahi rdalo=%rdalo size=%size_16 &vmlaldav
+@vmlaldav_nosz   .... .... . ... ... . ... . .... .... qm:3 . \
+                 qn=%qn rdahi=%rdahi rdalo=%rdalo size=0 &vmlaldav
 VMLALDAV_S       1110 1110 1 ... ... . ... x:1 1110 . 0 a:1 0 ... 0 @vmlaldav
 VMLALDAV_U       1111 1110 1 ... ... . ... x:1 1110 . 0 a:1 0 ... 0 @vmlaldav
 
 VMLSLDAV         1110 1110 1 ... ... . ... x:1 1110 . 0 a:1 0 ... 1 @vmlaldav
+
+VRMLALDAVH_S     1110 1110 1 ... ... 0 ... x:1 1111 . 0 a:1 0 ... 0 @vmlaldav_nosz
+VRMLALDAVH_U     1111 1110 1 ... ... 0 ... x:1 1111 . 0 a:1 0 ... 0 @vmlaldav_nosz
+
+VRMLSLDAVH       1111 1110 1 ... ... 0 ... x:1 1110 . 0 a:1 0 ... 1 @vmlaldav_nosz
diff --git a/target/arm/mve_helper.c b/target/arm/mve_helper.c
index 1c22e2777d9..b22a7535308 100644
--- a/target/arm/mve_helper.c
+++ b/target/arm/mve_helper.c
@@ -18,6 +18,7 @@ 
  */
 
 #include "qemu/osdep.h"
+#include "qemu/int128.h"
 #include "cpu.h"
 #include "internals.h"
 #include "exec/helper-proto.h"
@@ -512,3 +513,37 @@  DO_LDAV(vmlsldavsh, 2, int16_t, H2, false, +=, -=)
 DO_LDAV(vmlsldavxsh, 2, int16_t, H2, true, +=, -=)
 DO_LDAV(vmlsldavsw, 4, int32_t, H4, false, +=, -=)
 DO_LDAV(vmlsldavxsw, 4, int32_t, H4, true, +=, -=)
+
+/*
+ * Rounding multiply add long dual accumulate high: we must keep
+ * a 72-bit internal accumulator value and return the top 64 bits.
+ */
+#define DO_LDAVH(OP, ESIZE, TYPE, H, XCHG, EVENACC, ODDACC, TO128)      \
+    uint64_t HELPER(glue(mve_, OP))(CPUARMState *env, void *vn,         \
+                                    void *vm, uint64_t a)               \
+    {                                                                   \
+        uint16_t mask = mve_element_mask(env);                          \
+        unsigned e;                                                     \
+        TYPE *n = vn, *m = vm;                                          \
+        Int128 acc = TO128(a);                                          \
+        for (e = 0; e < 16 / ESIZE; e++, mask >>= ESIZE) {              \
+            if (mask & 1) {                                             \
+                if (e & 1) {                                            \
+                    acc = ODDACC(acc, TO128(n[H(e - 1 * XCHG)] * m[H(e)])); \
+                } else {                                                \
+                    acc = EVENACC(acc, TO128(n[H(e + 1 * XCHG)] * m[H(e)])); \
+                }                                                       \
+                acc = int128_add(acc, 1 << 7);                          \
+            }                                                           \
+        }                                                               \
+        mve_advance_vpt(env);                                           \
+        return int128_getlo(int128_rshift(acc, 8));                     \
+    }
+
+DO_LDAVH(vrmlaldavhsw, 4, int32_t, H4, false, int128_add, int128_add, int128_makes64)
+DO_LDAVH(vrmlaldavhxsw, 4, int32_t, H4, true, int128_add, int128_add, int128_makes64)
+
+DO_LDAVH(vrmlaldavhuw, 4, uint32_t, H4, false, int128_add, int128_add, int128_make64)
+
+DO_LDAVH(vrmlsldavhsw, 4, int32_t, H4, false, int128_add, int128_sub, int128_makes64)
+DO_LDAVH(vrmlsldavhxsw, 4, int32_t, H4, true, int128_add, int128_sub, int128_makes64)
diff --git a/target/arm/translate-mve.c b/target/arm/translate-mve.c
index 66d713a24e2..6792fca798d 100644
--- a/target/arm/translate-mve.c
+++ b/target/arm/translate-mve.c
@@ -508,3 +508,27 @@  static bool trans_VMLSLDAV(DisasContext *s, arg_vmlaldav *a)
     };
     return do_long_dual_acc(s, a, fns[a->size][a->x]);
 }
+
+static bool trans_VRMLALDAVH_S(DisasContext *s, arg_vmlaldav *a)
+{
+    MVEGenDualAccOpFn *fns[] = {
+        gen_helper_mve_vrmlaldavhsw, gen_helper_mve_vrmlaldavhxsw,
+    };
+    return do_long_dual_acc(s, a, fns[a->x]);
+}
+
+static bool trans_VRMLALDAVH_U(DisasContext *s, arg_vmlaldav *a)
+{
+    MVEGenDualAccOpFn *fns[] = {
+        gen_helper_mve_vrmlaldavhuw, NULL,
+    };
+    return do_long_dual_acc(s, a, fns[a->x]);
+}
+
+static bool trans_VRMLSLDAVH(DisasContext *s, arg_vmlaldav *a)
+{
+    MVEGenDualAccOpFn *fns[] = {
+        gen_helper_mve_vrmlsldavhsw, gen_helper_mve_vrmlsldavhxsw,
+    };
+    return do_long_dual_acc(s, a, fns[a->x]);
+}