From patchwork Sat Jun 21 23:50:37 2025
From: Richard Henderson <richard.henderson@linaro.org>
To: qemu-devel@nongnu.org
Subject: [PATCH v2 101/101] tests/tcg/aarch64: Add sme2-matmul test case
Date: Sat, 21 Jun 2025 16:50:37 -0700
Message-ID: <20250621235037.74091-102-richard.henderson@linaro.org>
In-Reply-To: <20250621235037.74091-1-richard.henderson@linaro.org>
References: <20250621235037.74091-1-richard.henderson@linaro.org>

Extract from
https://learn.arm.com/learning-paths/cross-platform/multiplying-matrices-with-sme2/

Merge into two files and drop the EL3 setup code for FVP.
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 tests/tcg/aarch64/sme2-matmul-0.c | 236 ++++++++++++++++++++++
 tests/tcg/aarch64/Makefile.target |  11 +-
 tests/tcg/aarch64/sme2-matmul-1.S | 321 ++++++++++++++++++++++++++++++
 3 files changed, 567 insertions(+), 1 deletion(-)
 create mode 100644 tests/tcg/aarch64/sme2-matmul-0.c
 create mode 100644 tests/tcg/aarch64/sme2-matmul-1.S

diff --git a/tests/tcg/aarch64/sme2-matmul-0.c b/tests/tcg/aarch64/sme2-matmul-0.c
new file mode 100644
index 0000000000..35737c5694
--- /dev/null
+++ b/tests/tcg/aarch64/sme2-matmul-0.c
@@ -0,0 +1,236 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ * SPDX-License-Identifier: BSD-3-Clause-Clear
+ *
+ * Copied from
+ * https://learn.arm.com/learning-paths/cross-platform/multiplying-matrices-with-sme2/
+ *
+ * and modified for testing with qemu-aarch64.
+ */
+
+#include <math.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define DEBUG 0
+
+/*
+ * Vanilla matrix multiplication using the by-the-book definition.
+ */
+
+void preprocess_l(uint64_t nbr, uint64_t nbc, uint64_t SVL,
+                  const float *restrict a, float *restrict a_mod)
+{
+    // For all tiles of SVL x SVL data
+    for (uint64_t By = 0; By < nbr; By += SVL) {
+        for (uint64_t Bx = 0; Bx < nbc; Bx += SVL) {
+            // For this tile
+            const uint64_t dest = By * nbc + Bx * SVL;
+            for (uint64_t j = 0; j < SVL; j++) {
+                for (uint64_t i = 0; i < SVL && (Bx + i) < nbc; i++) {
+                    if (By + j < nbr) {
+                        a_mod[dest + i * SVL + j] = a[(By + j) * nbc + Bx + i];
+                    } else {
+                        // These elements are outside of matrix a, so zero them.
+                        a_mod[dest + i * SVL + j] = 0.0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+void matmul(uint64_t M, uint64_t K, uint64_t N,
+            const float *restrict matLeft, const float *restrict matRight,
+            float *restrict matResult)
+{
+    for (uint64_t m = 0; m < M; m++) {
+        for (uint64_t n = 0; n < N; n++) {
+            float acc = 0.0;
+
+            for (uint64_t k = 0; k < K; k++) {
+                acc += matLeft[m * K + k] * matRight[k * N + n];
+            }
+
+            matResult[m * N + n] = acc;
+        }
+    }
+}
+
+/*
+ * SME2 Matrix multiplication handwritten in assembly code. This is split in 2
+ * functions that have to be invoked one after the other, with a top level
+ * binding.
+ */
+
+/* Matrix preprocessing, in assembly. */
+void preprocess_l_asm(uint64_t M, uint64_t K, const float *restrict a,
+                      float *restrict a_mod);
+
+/* Matrix multiplication (with the *transposed* RHS), in assembly. */
+void matmul_asm_impl(uint64_t M, uint64_t K, uint64_t N,
+                     const float *restrict matLeft_mod,
+                     const float *restrict matRight, float *restrict matResult);
+
+/* The top level matrix multiplication. */
+void matmul_asm(uint64_t M, uint64_t K, uint64_t N,
+                const float *restrict matLeft, const float *restrict matRight,
+                float *restrict matLeft_mod, float *restrict matResult)
+{
+    __asm volatile("" : : :
+                   "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                   "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+                   "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
+                   "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+
+    preprocess_l_asm(M, K, matLeft, matLeft_mod);
+    matmul_asm_impl(M, K, N, matLeft_mod, matRight, matResult);
+
+    __asm volatile("" : : :
+                   "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+                   "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+                   "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
+                   "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+
+// Initialize an array of float.
+enum InitKind { RANDOM_INIT, LINEAR_INIT, DEAD_INIT };
+void initialize_matrix(float *mat, size_t num_elements, enum InitKind kind)
+{
+    for (size_t i = 0; i < num_elements; i++)
+        switch (kind) {
+        case RANDOM_INIT:
+            mat[i] = (((float)(rand() % 10000) / 100.0f) - 30.0);
+            break;
+        case LINEAR_INIT:
+            mat[i] = i + 1;
+            break;
+        case DEAD_INIT:
+            mat[i] = nan("");
+            break;
+        }
+}
+
+/* Pretty print a matrix. */
+void print_matrix(size_t nbr, size_t nbc, const float *mat, const char *name)
+{
+    printf("%s(%lu,%lu) = [", name, nbr, nbc);
+    for (size_t y = 0; y < nbr; y++) {
+        printf("\n  ");
+        for (size_t x = 0; x < nbc; x++)
+            printf("%9.2f, ", mat[y * nbc + x]);
+    }
+    printf("\n];\n");
+}
+
+/* Compare 2 matrices for equality. */
+unsigned compare_matrices(size_t nbr, size_t nbc, const float *reference,
+                          const float *result, const char *str)
+{
+    unsigned error = 0;
+
+    for (size_t y = 0; y < nbr; y++) {
+        for (size_t x = 0; x < nbc; x++) {
+            if (fabsf(reference[y * nbc + x] - result[y * nbc + x]) >
+                fabsf(0.0002f * reference[y * nbc + x])) {
+                error = 1;
+                if (DEBUG) {
+                    printf("%lu (%lu,%lu): %f <> %f\n", y * nbc + x, x, y,
+                           reference[y * nbc + x], result[y * nbc + x]);
+                }
+            }
+        }
+    }
+    if (DEBUG) {
+        if (error) {
+            print_matrix(nbr, nbc, reference, "reference");
+            print_matrix(nbr, nbc, result, "result");
+        }
+        printf("%s: %s !\n", str, error ? "FAILED" : "PASS");
+    }
+
+    return error;
+}
+
+uint64_t ool_svcntsw(void);
+
+/*
+ * Assumptions:
+ *  nbr in matLeft (M): any
+ *  nbc in matLeft, nbr in matRight (K): any K > 2
+ *  nbc in matRight (N): any
+ */
+
+int main(int argc, char **argv)
+{
+    /* Size parameters */
+    uint64_t M, N, K;
+    if (argc >= 4) {
+        M = strtoul(argv[1], NULL, 0);
+        K = strtoul(argv[2], NULL, 0);
+        N = strtoul(argv[3], NULL, 0);
+    } else {
+        /* Default: 125x35x70 */
+        M = 125;
+        K = 35;
+        N = 70;
+    }
+
+    if (DEBUG) {
+        printf("\nSME2 Matrix Multiply fp32 *asm* example "
+               "with args %lu %lu %lu\n", M, K, N);
+    }
+
+    const uint64_t SVL = ool_svcntsw();
+
+    /* Calculate M of transformed matLeft. */
+    const uint64_t M_mod = SVL * (M / SVL + (M % SVL != 0 ? 1 : 0));
+
+    float *matRight = (float *)malloc(K * N * sizeof(float));
+
+    float *matLeft = (float *)malloc(M * K * sizeof(float));
+    float *matLeft_mod = (float *)malloc(M_mod * K * sizeof(float));
+    float *matLeft_mod_ref = (float *)malloc(M_mod * K * sizeof(float));
+
+    float *matResult = (float *)malloc(M * N * sizeof(float));
+    float *matResult_ref = (float *)malloc(M * N * sizeof(float));
+
+    // initialize_matrix(matLeft, M * K, RANDOM_INIT);
+    // initialize_matrix(matRight, K * N, RANDOM_INIT);
+    initialize_matrix(matLeft, M * K, LINEAR_INIT);
+    initialize_matrix(matRight, K * N, LINEAR_INIT);
+    initialize_matrix(matLeft_mod, M_mod * K, DEAD_INIT);
+    initialize_matrix(matResult, M * N, DEAD_INIT);
+
+    if (DEBUG) {
+        print_matrix(M, K, matLeft, "matLeft");
+        print_matrix(K, N, matRight, "matRight");
+    }
+
+    matmul_asm(M, K, N, matLeft, matRight, matLeft_mod, matResult);
+
+    /* Compute the reference values with the vanilla implementations. */
+    matmul(M, K, N, matLeft, matRight, matResult_ref);
+    preprocess_l(M, K, SVL, matLeft, matLeft_mod_ref);
+
+    unsigned error = compare_matrices(K, M_mod, matLeft_mod_ref, matLeft_mod,
+                                      "Matrix preprocessing");
+    if (!error)
+        error = compare_matrices(M, N, matResult_ref, matResult,
+                                 "Matrix multiplication");
+
+    free(matRight);
+
+    free(matLeft);
+    free(matLeft_mod);
+    free(matLeft_mod_ref);
+
+    free(matResult);
+    free(matResult_ref);
+
+    return error ? EXIT_FAILURE : EXIT_SUCCESS;
+}
diff --git a/tests/tcg/aarch64/Makefile.target b/tests/tcg/aarch64/Makefile.target
index 16ddcf4f88..641c00cf02 100644
--- a/tests/tcg/aarch64/Makefile.target
+++ b/tests/tcg/aarch64/Makefile.target
@@ -28,7 +28,8 @@ config-cc.mak: Makefile
 	$(call cc-option,-march=armv8.5-a, CROSS_CC_HAS_ARMV8_5); \
 	$(call cc-option,-mbranch-protection=standard, CROSS_CC_HAS_ARMV8_BTI); \
 	$(call cc-option,-march=armv8.5-a+memtag, CROSS_CC_HAS_ARMV8_MTE); \
-	$(call cc-option,-Wa$(COMMA)-march=armv9-a+sme $$fnia, CROSS_AS_HAS_ARMV9_SME)) 3> config-cc.mak
+	$(call cc-option,-Wa$(COMMA)-march=armv9-a+sme $$fnia, CROSS_AS_HAS_ARMV9_SME); \
+	$(call cc-option,-Wa$(COMMA)-march=armv9-a+sme2 $$fnia, CROSS_AS_HAS_ARMV9_SME2)) 3> config-cc.mak
 -include config-cc.mak
 
 ifneq ($(CROSS_CC_HAS_ARMV8_2),)
@@ -75,6 +76,14 @@ AARCH64_TESTS += $(SME_TESTS)
 $(SME_TESTS): CFLAGS += $(CROSS_AS_HAS_ARMV9_SME)
 endif
 
+# SME2 Tests
+ifneq ($(CROSS_AS_HAS_ARMV9_SME2),)
+sme2-matmul: CFLAGS += $(CROSS_AS_HAS_ARMV9_SME2)
+sme2-matmul: sme2-matmul-0.c sme2-matmul-1.S
+	$(CC) $(CFLAGS) $(EXTRA_CFLAGS) $^ -o $@ $(LDFLAGS)
+AARCH64_TESTS += sme2-matmul
+endif
+
 # System Registers Tests
 AARCH64_TESTS += sysregs
 
diff --git a/tests/tcg/aarch64/sme2-matmul-1.S b/tests/tcg/aarch64/sme2-matmul-1.S
new file mode 100644
index 0000000000..5562e24c62
--- /dev/null
+++ b/tests/tcg/aarch64/sme2-matmul-1.S
@@ -0,0 +1,321 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+ * SPDX-License-Identifier: BSD-3-Clause-Clear
+ *
+ * Copied from
+ * https://learn.arm.com/learning-paths/cross-platform/multiplying-matrices-with-sme2/
+ *
+ * and modified for testing with qemu-aarch64.
+ */
+
+	.text
+	.cfi_sections .debug_frame	// put stack frame info into .debug_frame instead of .eh_frame
+
+	.global ool_svcntsw
+	.type ool_svcntsw, "function"
+	.cfi_startproc
+ool_svcntsw:
+	rdsvl	x0, #1
+	lsr	x0, x0, #1
+	ret
+	.cfi_endproc
+	.size ool_svcntsw, .-ool_svcntsw
+
+	.global preprocess_l_asm
+	.type preprocess_l_asm, "function"
+	.cfi_startproc
+
+preprocess_l_asm:
+	// preprocess_l_asm(uint64_t nbr, uint64_t nbc, const float *restrict a, float *a_mod);
+	// x0 : nbr
+	// x1 : nbc
+	// x2 : &a
+	// x3 : &a_mod
+	// x4 : SVLs (=cntw)
+	// x5 : Exit condition for inner loop
+	// x6 : a_ptr
+	// x7 : Outer loop counter
+	// x8 : a_base
+	// x9 : a_mod store base address
+	// x10: 32b Tile0 store end pointer
+	// x11: SVLs*nbc
+	// x12: Load/Store loop counter
+	// x13: 32b Tile1 store end pointer
+	// x14: 2*nbc
+	// x15: 3*nbc
+	// x16: 32b tile size
+
+// Assumptions:
+// nbr in matLeft (M): any
+// nbc in matLeft, nbr in matRight (K): any K > 2
+// nbc in matRight (N): any
+//
+// Left matrix re-arrangement:
+// Block of SVLs rows is transposed and contiguously stored.
+// Then the same transformation is applied to remaining blocks of SVLs rows.
+// The last block of rows is zero-padded to SVLs rows, if applicable.
+
+	smstart
+
+// constants
+	cntw	x4			// SVLs
+	mul	x11, x4, x1		// SVLs*nbc
+	lsl	x14, x1, #1		// 2*nbc
+	add	x15, x14, x1		// 3*nbc
+
+	mul	x16, x4, x4		// SVLs*SVLs
+
+	mov	x7, #0
+	whilelt	p0.s, x7, x0		// Tile predicate (M dimension)
+
+.Loop_outer:
+	mov	x8, x2			// a load base address
+	mov	x9, x3			// a_mod store base address
+	add	x5, x2, x1, lsl #2	// Exit condition for inner loop
+
+	add	x10, x9, x11, lsl #2	// 32b Tile0 store predicate condition
+	sub	x13, x10, x16, lsl #2	// 32b Tile1 store predicate condition
+	whilelt	pn8.b, x8, x5, vlx2	// Tile predicate-as-counter (K dimension)
+
+.Loop_inner:
+	mov	x6, x8			// a_ptr
+
+	mov	w12, #0			// Load_loop counter
+
+.Load_loop:
+	psel	pn10, pn8, p0.s[w12, 0]
+	psel	pn11, pn8, p0.s[w12, 1]
+	psel	pn12, pn8, p0.s[w12, 2]
+	psel	pn13, pn8, p0.s[w12, 3]
+	ld1w	{z20.s, z28.s}, pn10/z, [x6]			// Load 2 row vectors from a_ptr
+	ld1w	{z21.s, z29.s}, pn11/z, [x6, x1, lsl #2]	// Load " " " from a_ptr + nbc
+	ld1w	{z22.s, z30.s}, pn12/z, [x6, x14, lsl #2]	// Load " " " from a_ptr + nbc*2
+	ld1w	{z23.s, z31.s}, pn13/z, [x6, x15, lsl #2]	// Load " " " from a_ptr + nbc*3
+	mova	za0h.s[w12, 0:3], {z20.s-z23.s}
+	mova	za1h.s[w12, 0:3], {z28.s-z31.s}
+
+	add	x6, x6, x1, lsl #4	// a_ptr += 4*nbc FP32 elms [Bytes]
+	add	w12, w12, #4		// increment counter
+	cmp	w12, w4
+	b.mi	.Load_loop
+
+	mov	w12, #0			// Store_loop counter
+
+.Store_loop:
+	whilelt	pn10.b, x9, x10, vlx4
+	whilelt	pn11.b, x9, x13, vlx4
+	mova	{z0.s-z3.s}, za0v.s[w12, 0:3]
+	mova	{z4.s-z7.s}, za1v.s[w12, 0:3]
+	st1w	{z0.s-z3.s}, pn10, [x9]			// Store 4 col vectors to a_mod
+	st1w	{z4.s-z7.s}, pn11, [x9, x16, lsl #2]	// Store 4 col vectors to a_mod + SVLs*SVLs
+	addvl	x9, x9, #4		// a_mod += 4*SVLb [Bytes]
+	add	w12, w12, #4		// increment counter
+	cmp	w12, w4
+	b.mi	.Store_loop
+
+	add	x9, x9, x16, lsl #2
+	addvl	x8, x8, #2		// a_base += 2*SVLb [Bytes]
+	whilelt	pn8.b, x8, x5, vlx2
+	b.first	.Loop_inner
+
+	add	x3, x3, x11, lsl #2	// &a_mod += SVLs*nbc FP32 elms [Bytes]
+	add	x2, x2, x11, lsl #2	// &a += SVLs*nbc FP32 elms [Bytes]
+	incw	x7
+
+	whilelt	p0.s, x7, x0
+	b.first	.Loop_outer
+
+	smstop
+
+	ret
+	.cfi_endproc
+	.size preprocess_l_asm, .-preprocess_l_asm
+
+	.global matmul_asm_impl
+	.type matmul_asm_impl, "function"
+	.cfi_startproc
+
+matmul_asm_impl:
+	// matmul_asm_impl(M, K, N, matLeft, matRight, matResult_opt);
+	// x0 : M
+	// x1 : K, lda
+	// x2 : N, ldc, ldb
+	// x3 : &matLeft
+	// x4 : &matRight
+	// x5 : &matResult_opt
+	// x6 : SVLs-2
+	// x7 : a_ptr pointer
+	// x8 : a_ptr end address
+	// x9 : c_base pointer
+	// x10: c_ptr0 pointer
+	// x11: Exit condition for N loop
+	// x12: M loop counter
+	// x13: Store loop counter
+	// x14: Predicate select index
+	// x15: Exit condition for K loop
+	// x16: b_base pointer
+	// x17: b_ptr pointer
+	// x18: (SVLs+1)*ldc
+	// x19: ldb + SVLs
+	// x20: SVLs*lda + SVLs
+	// x21: c_ptr1 pointer
+	// x22: SVLs*lda
+	// x23: SVLs*ldc
+
+// Assumptions:
+// nbr in matLeft (M): any
+// nbc in matLeft, nbr in matRight (K): any K > 2
+// nbc in matRight (N): any
+//
+// Left matrix is pre-arranged.
+//
+// 32-bit accumulator mapping with 2x2 tiles processing
+
+	stp	x19, x20, [sp, #-48]!
+	stp	x21, x22, [sp, #16]
+	stp	x23, x24, [sp, #32]
+
+	smstart
+
+// constants
+	cntw	x6			// SVLs
+	mul	x22, x6, x1		// SVLs*lda
+	mul	x23, x6, x2		// SVLs*ldc
+	add	x18, x23, x2		// SVLs*ldc + ldc
+	add	x11, x4, x2, lsl #2	// Exit condition for N loop
+	mov	x12, #0
+	cntb	x6			// SVLb
+	mov	x14, #0
+	ptrue	pn10.b			// Predicate as counter for SME2 VLx2 (a_ptr loads)
+	whilelt	pn8.s, x12, x0, vlx2	// tiles predicate (M dimension)
+	sub	w6, w6, #8		// SVLb-8
+
+.Loop_M:
+	// Extracting tile 0/1 and tile 2/3 predicates (M dimension) from vlx2 predicate.
+	pext	{ p2.s, p3.s }, pn8[0]
+	mov	x16, x4			// b_base
+	mov	x9, x5			// c_base
+
+	whilelt	pn9.b, x16, x11, vlx2	// tiles predicate (N dimension)
+
+.Loop_N:
+	mov	x7, x3			// a_ptr = a_base
+	mov	x17, x16		// b_ptr = b_base
+	mov	x10, x9			// c_ptr0 = c_base
+
+	// Extracting tile 0/2 and tile 1/3 predicates (N dimension) from vlx2 predicate.
+	pext	{ p0.b, p1.b }, pn9[0]
+
+	add	x8, x3, x22, lsl #2	// a_base + SVLs*lda FP32 elms [Bytes]
+	addvl	x15, x8, #-1		// Exit condition for K loop
+	ld1w	{z1.s}, p2/z, [x7]	// Load 1st vector from a_ptr
+
+	zero	{za}
+	ld1w	{z2.s-z3.s}, pn9/z, [x17]		// Load 2 vectors from b_ptr
+
+	fmopa	za0.s, p2/m, p0/m, z1.s, z2.s		// ZA0 += 1st a_ptr vector OP 1st b_ptr vector
+	ld1w	{z5.s}, p3/z, [x7, x22, lsl #2]		// Load 2nd vector from a_ptr
+	addvl	x7, x7, #1		// a_ptr += SVLb [Bytes]
+
+.Loop_K:
+	fmopa	za2.s, p3/m, p0/m, z5.s, z2.s		// ZA2 += 2nd a_ptr vector OP 1st b_ptr vector
+
+	fmopa	za1.s, p2/m, p1/m, z1.s, z3.s		// ZA1 += 1st a_ptr vector OP 2nd b_ptr vector
+	ld1w	{z0.s-z1.s}, pn10/z, [x7]		// Load next 2 vectors from a_ptr
+
+	fmopa	za3.s, p3/m, p1/m, z5.s, z3.s		// ZA3 += 2nd a_ptr vector OP 2nd b_ptr vector
+	ld1w	{z6.s-z7.s}, pn9/z, [x17, x2, lsl #2]	// Load next 2 vectors from b_ptr
+
+	fmopa	za0.s, p2/m, p0/m, z0.s, z6.s		// ZA0 += 1st a_ptr vector OP 1st b_ptr vector
+	psel	pn11, pn10, p3.s[w14, 0]		// Select predicate-as-counter
+	ld1w	{z4.s-z5.s}, pn11/z, [x7, x22, lsl #2]	// Load next 2 vectors from a_ptr
+
+	fmopa	za2.s, p3/m, p0/m, z4.s, z6.s		// ZA2 += 2nd a_ptr vector OP 1st b_ptr vector
+	add	x17, x17, x2, lsl #3	// b_ptr += 2*ldb FP32 elms [Bytes]
+
+	fmopa	za1.s, p2/m, p1/m, z0.s, z7.s		// ZA1 += 1st a_ptr vector OP 2nd b_ptr vector
+
+	fmopa	za3.s, p3/m, p1/m, z4.s, z7.s		// ZA3 += 2nd a_ptr vector OP 2nd b_ptr vector
+	ld1w	{z2.s-z3.s}, pn9/z, [x17]		// Load next 2 vectors from b_ptr
+
+	fmopa	za0.s, p2/m, p0/m, z1.s, z2.s		// ZA0 += 1st a_ptr vector OP 1st b_ptr vector
+	addvl	x7, x7, #2		// a_ptr += 2*SVLb [Bytes]
+
+	cmp	x7, x15
+	b.mi	.Loop_K
+
+	fmopa	za2.s, p3/m, p0/m, z5.s, z2.s		// ZA2 += 2nd a_ptr vector OP 1st b_ptr vector
+
+	fmopa	za1.s, p2/m, p1/m, z1.s, z3.s		// ZA1 += 1st a_ptr vector OP 2nd b_ptr vector
+
+	fmopa	za3.s, p3/m, p1/m, z5.s, z3.s		// ZA3 += 2nd a_ptr vector OP 2nd b_ptr vector
+	add	x17, x17, x2, lsl #2	// b_ptr += 2*ldb FP32 elms [Bytes]
+
+	cmp	x7, x8
+	b.pl	.Ktail_end
+
+.Ktail_start:
+	ld1w	{z1.s}, p2/z, [x7]
+	ld1w	{z2.s-z3.s}, pn9/z, [x17]
+
+	fmopa	za0.s, p2/m, p0/m, z1.s, z2.s
+	ld1w	{z5.s}, p3/z, [x7, x22, lsl #2]
+
+	fmopa	za2.s, p3/m, p0/m, z5.s, z2.s
+
+	fmopa	za1.s, p2/m, p1/m, z1.s, z3.s
+
+	fmopa	za3.s, p3/m, p1/m, z5.s, z3.s
+
+.Ktail_end:
+	mov	w13, #0
+	psel	pn11, pn9, p2.b[w13, 0]
+	psel	pn12, pn9, p3.b[w13, 0]
+	// Move from ZA tiles to vectors: z0 = za0h[1], z1 = za1h[1], z2 = za2h[1], z3 = za3h[1]
+	mova	{ z0.b-z3.b }, za0h.b[w13, 0:3]
+	st1w	{ z0.s-z1.s }, pn11, [x10]		// Store to c_ptr0
+	st1w	{ z2.s-z3.s }, pn12, [x10, x23, lsl #2]	// Store to c_ptr0 + SVLs*ldc
+.Loop_store_ZA:
+	psel	pn11, pn9, p2.b[w13, 4]
+	psel	pn12, pn9, p3.b[w13, 4]
+	mova	{ z0.b-z3.b }, za0h.b[w13, 4:7]
+	st1w	{ z0.s-z1.s }, pn11, [x10, x2, lsl #2]	// Store to c_ptr0 + ldc
+	st1w	{ z2.s-z3.s }, pn12, [x10, x18, lsl #2]	// Store to c_ptr0 + (SVLs+1)*ldc
+
+	add	x10, x10, x2, lsl #3	// c_ptr0 += 2*ldc FP32 elms [Bytes]
+	add	w13, w13, #8
+
+	psel	pn11, pn9, p2.b[w13, 0]
+	psel	pn12, pn9, p3.b[w13, 0]
+	mova	{ z0.b-z3.b }, za0h.b[w13, 0:3]
+	st1w	{ z0.s-z1.s }, pn11, [x10]		// Store to c_ptr0
+	st1w	{ z2.s-z3.s }, pn12, [x10, x23, lsl #2]	// Store to c_ptr0 + SVLs*ldc
+	cmp	w13, w6
+	b.mi	.Loop_store_ZA
+
+	psel	pn11, pn9, p2.b[w13, 4]
+	psel	pn12, pn9, p3.b[w13, 4]
+	mova	{ z0.b-z3.b }, za0h.b[w13, 4:7]
+	st1w	{ z0.s-z1.s }, pn11, [x10, x2, lsl #2]	// Store to c_ptr0 + ldc
+	st1w	{ z2.s-z3.s }, pn12, [x10, x18, lsl #2]	// Store to c_ptr0 + (SVLs+1)*ldc
+
+	addvl	x9, x9, #2
+	addvl	x16, x16, #2		// b_base += 2*SVLb [Bytes]
+	whilelt	pn9.b, x16, x11, vlx2	// tile predicate (N dimension)
+	b.first	.Loop_N
+
+	add	x3, x3, x22, lsl #3	// a_base += 2*SVLs*lda FP32 elms [Bytes]
+	add	x5, x5, x23, lsl #3	// c_base += 2*SVLs*ldc FP32 elms [Bytes]
+	incw	x12, all, mul #2	// M loop counter += 2*SVLs
+	whilelt	pn8.s, x12, x0, vlx2	// tiles predicate (M dimension)
+	b.first	.Loop_M
+
+	smstop
+
+	ldp	x23, x24, [sp, #32]
+	ldp	x21, x22, [sp, #16]
+	ldp	x19, x20, [sp], #48
+
+	ret
+	.cfi_endproc
+	.size matmul_asm_impl, .-matmul_asm_impl