diff mbox series

[v6,13/28] ntsync: Introduce alertable waits.

Message ID 20241209185904.507350-14-zfigura@codeweavers.com
State New
Headers show
Series NT synchronization primitive driver | expand

Commit Message

Elizabeth Figura Dec. 9, 2024, 6:58 p.m. UTC
NT waits can optionally be made "alertable". This is a special channel for
thread wakeup that is mildly similar to SIGIO. A thread has an internal single
bit of "alerted" state, and if a thread is alerted while an alertable wait, the
wait will return a special value, consume the "alerted" state, and will not
consume any of its objects.

Alerts are implemented using events; the user-space NT emulator is expected to
create an internal ntsync event for each thread and pass that event to wait
functions.

Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
---
 drivers/misc/ntsync.c       | 70 ++++++++++++++++++++++++++++++++-----
 include/uapi/linux/ntsync.h |  3 +-
 2 files changed, 63 insertions(+), 10 deletions(-)
diff mbox series

Patch

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index 5a5ee7b6ee92..3fac06270549 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -884,22 +884,29 @@  static int setup_wait(struct ntsync_device *dev,
 		      const struct ntsync_wait_args *args, bool all,
 		      struct ntsync_q **ret_q)
 {
+	int fds[NTSYNC_MAX_WAIT_COUNT + 1];
 	const __u32 count = args->count;
-	int fds[NTSYNC_MAX_WAIT_COUNT];
 	struct ntsync_q *q;
+	__u32 total_count;
 	__u32 i, j;
 
-	if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
+	if (args->pad || (args->flags & ~NTSYNC_WAIT_REALTIME))
 		return -EINVAL;
 
 	if (args->count > NTSYNC_MAX_WAIT_COUNT)
 		return -EINVAL;
 
+	total_count = count;
+	if (args->alert)
+		total_count++;
+
 	if (copy_from_user(fds, u64_to_user_ptr(args->objs),
 			   array_size(count, sizeof(*fds))))
 		return -EFAULT;
+	if (args->alert)
+		fds[count] = args->alert;
 
-	q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
+	q = kmalloc(struct_size(q, entries, total_count), GFP_KERNEL);
 	if (!q)
 		return -ENOMEM;
 	q->task = current;
@@ -909,7 +916,7 @@  static int setup_wait(struct ntsync_device *dev,
 	q->ownerdead = false;
 	q->count = count;
 
-	for (i = 0; i < count; i++) {
+	for (i = 0; i < total_count; i++) {
 		struct ntsync_q_entry *entry = &q->entries[i];
 		struct ntsync_obj *obj = get_obj(dev, fds[i]);
 
@@ -959,10 +966,10 @@  static void try_wake_any_obj(struct ntsync_obj *obj)
 static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 {
 	struct ntsync_wait_args args;
+	__u32 i, total_count;
 	struct ntsync_q *q;
 	int signaled;
 	bool all;
-	__u32 i;
 	int ret;
 
 	if (copy_from_user(&args, argp, sizeof(args)))
@@ -972,9 +979,13 @@  static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 	if (ret < 0)
 		return ret;
 
+	total_count = args.count;
+	if (args.alert)
+		total_count++;
+
 	/* queue ourselves */
 
-	for (i = 0; i < args.count; i++) {
+	for (i = 0; i < total_count; i++) {
 		struct ntsync_q_entry *entry = &q->entries[i];
 		struct ntsync_obj *obj = entry->obj;
 
@@ -983,9 +994,15 @@  static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 		ntsync_unlock_obj(dev, obj, all);
 	}
 
-	/* check if we are already signaled */
+	/*
+	 * Check if we are already signaled.
+	 *
+	 * Note that the API requires that normal objects are checked before
+	 * the alert event. Hence we queue the alert event last, and check
+	 * objects in order.
+	 */
 
-	for (i = 0; i < args.count; i++) {
+	for (i = 0; i < total_count; i++) {
 		struct ntsync_obj *obj = q->entries[i].obj;
 
 		if (atomic_read(&q->signaled) != -1)
@@ -1002,7 +1019,7 @@  static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
 
 	/* and finally, unqueue */
 
-	for (i = 0; i < args.count; i++) {
+	for (i = 0; i < total_count; i++) {
 		struct ntsync_q_entry *entry = &q->entries[i];
 		struct ntsync_obj *obj = entry->obj;
 
@@ -1062,6 +1079,14 @@  static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 		 */
 		list_add_tail(&entry->node, &obj->all_waiters);
 	}
+	if (args.alert) {
+		struct ntsync_q_entry *entry = &q->entries[args.count];
+		struct ntsync_obj *obj = entry->obj;
+
+		dev_lock_obj(dev, obj);
+		list_add_tail(&entry->node, &obj->any_waiters);
+		dev_unlock_obj(dev, obj);
+	}
 
 	/* check if we are already signaled */
 
@@ -1069,6 +1094,21 @@  static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 
 	mutex_unlock(&dev->wait_all_lock);
 
+	/*
+	 * Check if the alert event is signaled, making sure to do so only
+	 * after checking if the other objects are signaled.
+	 */
+
+	if (args.alert) {
+		struct ntsync_obj *obj = q->entries[args.count].obj;
+
+		if (atomic_read(&q->signaled) == -1) {
+			bool all = ntsync_lock_obj(dev, obj);
+			try_wake_any_obj(obj);
+			ntsync_unlock_obj(dev, obj, all);
+		}
+	}
+
 	/* sleep */
 
 	ret = ntsync_schedule(q, &args);
@@ -1094,6 +1134,18 @@  static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
 
 	mutex_unlock(&dev->wait_all_lock);
 
+	if (args.alert) {
+		struct ntsync_q_entry *entry = &q->entries[args.count];
+		struct ntsync_obj *obj = entry->obj;
+		bool all;
+
+		all = ntsync_lock_obj(dev, obj);
+		list_del(&entry->node);
+		ntsync_unlock_obj(dev, obj, all);
+
+		put_obj(obj);
+	}
+
 	signaled = atomic_read(&q->signaled);
 	if (signaled != -1) {
 		struct ntsync_wait_args __user *user_args = argp;
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index 74abeba832f7..4a8095a3fc34 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -37,7 +37,8 @@  struct ntsync_wait_args {
 	__u32 index;
 	__u32 flags;
 	__u32 owner;
-	__u32 pad[2];
+	__u32 alert;
+	__u32 pad;
 };
 
 #define NTSYNC_MAX_WAIT_COUNT 64