diff mbox series

[RFC,4/4] testing: add a getrandom() GRND_TIMESTAMP vDSO demonstration/benchmark

Message ID ee7dec1ec967c38080c44f73246e9b8636b8b624.1673539719.git.ydroneaud@opteya.com
State New
Headers show
Series random: a simple vDSO mechanism for reseeding userspace CSPRNGs | expand

Commit Message

Yann Droneaud Jan. 12, 2023, 5:02 p.m. UTC
Link: https://lore.kernel.org/all/cover.1673539719.git.ydroneaud@opteya.com/
Signed-off-by: Yann Droneaud <ydroneaud@opteya.com>
---
 tools/testing/crypto/getrandom/Makefile       |   4 +
 .../testing/crypto/getrandom/test-getrandom.c | 307 ++++++++++++++++++
 2 files changed, 311 insertions(+)
 create mode 100644 tools/testing/crypto/getrandom/Makefile
 create mode 100644 tools/testing/crypto/getrandom/test-getrandom.c
diff mbox series

Patch

diff --git a/tools/testing/crypto/getrandom/Makefile b/tools/testing/crypto/getrandom/Makefile
new file mode 100644
index 000000000000..1370b6f1ae94
--- /dev/null
+++ b/tools/testing/crypto/getrandom/Makefile
@@ -0,0 +1,4 @@ 
+# SPDX-License-Identifier: GPL-2.0+
+
+test-getrandom: test-getrandom.c
+	$(CC) $(CPPFLAGS) $(CFLAGS) -I ../../../../usr/include/ -O2 -Wall -Wextra -o $@ $^ -ldl
diff --git a/tools/testing/crypto/getrandom/test-getrandom.c b/tools/testing/crypto/getrandom/test-getrandom.c
new file mode 100644
index 000000000000..311eef503f50
--- /dev/null
+++ b/tools/testing/crypto/getrandom/test-getrandom.c
@@ -0,0 +1,307 @@ 
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2022 Yann Droneaud. All Rights Reserved.
+ */
+
+#include <dlfcn.h>
+#include <errno.h>
+#include <inttypes.h>
+#include <stdarg.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <linux/random.h>
+
+static size_t pagesz;
+static size_t discarded;
+
+typedef ssize_t(*getrandom_fn) (void *, size_t, int);
+
+static bool grnd_timestamp;
+static getrandom_fn getrandom_vDSO;
+
+static ssize_t getrandom_syscall(void *buffer, size_t size, int flags)
+{
+	return syscall(SYS_getrandom, buffer, size, flags);
+}
+
+static ssize_t timestamp(getrandom_fn _getrandom, uint64_t *grnd_ts,
+			 size_t size)
+{
+	ssize_t ret;
+
+	ret = _getrandom(grnd_ts, size, GRND_TIMESTAMP);
+	if (ret < 0) {
+		fprintf(stderr,
+			"getrandom(,,GRND_TIMESTAMP) failed: %ld (%s)\n", -ret,
+			strerror((int)-ret));
+		return -1;
+	}
+
+	return ret;
+}
+
+static void fetch(getrandom_fn _getrandom, void *buffer, size_t size)
+{
+	ssize_t ret;
+
+	ret = _getrandom(buffer, size, 0);
+	if (ret < 0) {
+		fprintf(stderr, "getrandom(,,0) failed: %ld (%s)\n", -ret,
+			strerror((int)-ret));
+		exit(EXIT_FAILURE);
+	}
+}
+
+struct rng {
+	uint64_t grnd_ts;
+	size_t availsz;		/* available bytes in buffer */
+	size_t buffersz;	/* buffer size */
+	uint8_t buffer[];
+};
+
+static struct rng *rng;
+
+static void init_rng(void)
+{
+	int r;
+	ssize_t s;
+	void *p;
+
+	r = getpagesize();
+	if (r == -1) {
+		fprintf(stderr, "getpagesize() failed: %d\n", errno);
+		exit(EXIT_FAILURE);
+	}
+
+	pagesz = (size_t)r;
+
+	p = mmap(NULL, pagesz, PROT_READ | PROT_WRITE,
+		 MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
+	if (p == MAP_FAILED) {
+		fprintf(stderr, "mmap() failed: %d\n", errno);
+		exit(EXIT_FAILURE);
+	}
+
+	r = madvise(p, pagesz, MADV_DONTDUMP | MADV_WIPEONFORK);
+	if (r == -1) {
+		fprintf(stderr, "madvise() failed: %d\n", errno);
+		exit(EXIT_FAILURE);
+	}
+
+	r = mlock(p, pagesz);
+	if (r == -1)
+		fprintf(stderr, "mlock() failed: %d\n", errno);
+
+	rng = p;
+
+	s = timestamp(getrandom_syscall, &rng->grnd_ts, sizeof(rng->grnd_ts));
+	if (s == -1)
+		return;
+
+	printf("getrandom() support GRND_TIMESTAMP\n");
+
+	grnd_timestamp = true;
+}
+
+static void init_vdso(void)
+{
+	void *h;
+	void *p;
+
+	h = dlopen("linux-vdso.so.1", RTLD_LAZY | RTLD_LOCAL | RTLD_NOLOAD);
+	if (!h) {
+		fprintf(stderr, "failed to open vDSO: %s\n", dlerror());
+		return;
+	}
+
+	p = dlsym(h, "__vdso_getrandom");
+	if (!p) {
+		fprintf(stderr, "getrandom() not found in vDSO: %s\n",
+			dlerror());
+		return;
+	}
+
+	printf("found getrandom() in vDSO at %p\n", p);
+
+	getrandom_vDSO = p;
+}
+
+/*
+ * 1) check timestamp isn't expired
+ * 2) if expired or there's not enough data in buffer
+ *     a) if expired, reset buffer size,
+ *     b) fetch new random stream
+ *     c) check timestamp
+ *     d) if expired, reset buffer size, go to b)
+ *
+ */
+static void ensure(getrandom_fn _getrandom, size_t request)
+{
+	ssize_t r;
+
+	r = timestamp(_getrandom, &rng->grnd_ts, sizeof(rng->grnd_ts));
+	switch (r) {
+	case 0:	/* timestamp didn't change */
+		/* enough available random bytes ? */
+		if (rng->availsz >= request)
+			return;
+
+		/* increase buffer size when drained */
+		if (rng->buffersz < pagesz - sizeof(*rng))
+			rng->buffersz *= 2;
+
+		/* no less than 32 */
+		if (rng->buffersz < 32)
+			rng->buffersz = 32;
+
+		/* no more than a full page minus the rng structure */
+		if (rng->buffersz > pagesz - sizeof(*rng))
+			rng->buffersz = pagesz - sizeof(*rng);
+
+		break;
+
+	case sizeof(rng->grnd_ts):	/* timestamp did change, random bytes must be discarded */
+		rng->buffersz = 32;	/* reset size */
+		break;
+
+	default:
+		fprintf(stderr, "unexpected timestamp size %zd\n", r);
+		exit(EXIT_FAILURE);
+	}
+
+	/* keep fetching if timestamp is updated */
+	for (;;) {
+		if (rng->availsz)
+			discarded += rng->availsz;
+
+		fetch(_getrandom, rng->buffer, rng->buffersz);
+		rng->availsz = rng->buffersz;
+
+		r = timestamp(_getrandom, &rng->grnd_ts, sizeof(rng->grnd_ts));
+
+		switch (r) {
+		case 0:	/* timestamp didn't change between previous check and last fetch */
+			return;
+
+		case sizeof(rng->grnd_ts):	/* timestamp did change, random bytes just fetched must be discarded */
+			rng->buffersz = 32;	/* reset size */
+			continue;	/* retry again */
+
+		default:
+			fprintf(stderr, "unexpected timestamp size %zd\n", r);
+			exit(EXIT_FAILURE);
+		}
+	}
+}
+
+/* arc4random() */
+static void get_direct(getrandom_fn _getrandom)
+{
+	uint32_t v;
+	fetch(_getrandom, &v, sizeof(v));
+}
+
+static void get_pooled(getrandom_fn _getrandom)
+{
+	ensure(_getrandom, sizeof(uint32_t));
+	rng->availsz -= sizeof(uint32_t);
+}
+
+static inline struct timespec timespec_sub(const struct timespec *a,
+					   const struct timespec *b)
+{
+	struct timespec res;
+
+	res.tv_sec = a->tv_sec - b->tv_sec;
+	res.tv_nsec = a->tv_nsec - b->tv_nsec;
+	if (res.tv_nsec < 0) {
+		res.tv_sec--;
+		res.tv_nsec += 1000000000L;
+	}
+
+	return res;
+}
+
+#define SAMPLES 13
+#define VALUES (16 * 1024 * 1024)
+
+static void test_direct(getrandom_fn _getrandom, const char *method)
+{
+	struct timespec start, end, diff;
+
+	for (int i = 0; i < SAMPLES; i++) {
+		clock_gettime(CLOCK_MONOTONIC, &start);
+
+		for (uint32_t j = 0; j < VALUES; j++)
+			get_direct(_getrandom);
+
+		clock_gettime(CLOCK_MONOTONIC, &end);
+
+		diff = timespec_sub(&end, &start);
+
+		printf("== direct %s getrandom(), %u u32, %lu.%09lu s, %.3f M u32/s, %.3f ns/u32\n",
+		       method, VALUES, diff.tv_sec, diff.tv_nsec,
+		       VALUES / (1000000 *
+				 (diff.tv_sec +
+				  (double)diff.tv_nsec / 1000000000UL)),
+		       (double)(diff.tv_sec * 1000000000UL +
+				diff.tv_nsec) / VALUES);
+	}
+}
+
+static void test_pooled(getrandom_fn _getrandom, const char *method)
+{
+	struct timespec start, end, diff;
+
+	for (int i = 0; i < SAMPLES; i++) {
+		discarded = 0;
+
+		clock_gettime(CLOCK_MONOTONIC, &start);
+
+		for (uint32_t j = 0; j < VALUES; j++)
+			get_pooled(_getrandom);
+
+		clock_gettime(CLOCK_MONOTONIC, &end);
+
+		diff = timespec_sub(&end, &start);
+
+		printf("== pooled %s getrandom(), %u u32, %lu.%09lu s, %.3f M u32/s, %.3f ns/u32, (%zu bytes discarded)\n",
+		       method, VALUES, diff.tv_sec, diff.tv_nsec,
+		       VALUES / (1000000 *
+				 (diff.tv_sec +
+				  (double)diff.tv_nsec / 1000000000UL)),
+		       (double)(diff.tv_sec * 1000000000UL +
+				diff.tv_nsec) / VALUES,
+		       discarded);
+	}
+}
+
+int main(void)
+{
+	printf("getrandom(,,GRND_TIMESTAMP) test\n");
+
+	init_rng();
+	init_vdso();
+
+	while (1) {
+		test_direct(getrandom_syscall, "syscall");
+
+		if (getrandom_vDSO)
+			test_direct(getrandom_vDSO, "vDSO");
+
+		if (grnd_timestamp)
+			test_pooled(getrandom_syscall, "syscall");
+
+		if (getrandom_vDSO && grnd_timestamp)
+			test_pooled(getrandom_vDSO, "vDSO");
+	}
+
+	return 0;
+}