diff mbox series

[RFC,v7,12/22] virtio/vsock: fetch length for SEQPACKET record

Message ID 20210323131258.2461163-1-arseny.krasnov@kaspersky.com
State New
Headers show
Series virtio/vsock: introduce SOCK_SEQPACKET support | expand

Commit Message

Arseny Krasnov March 23, 2021, 1:12 p.m. UTC
This adds transport callback which tries to fetch record begin marker
from socket's rx queue. It is called from af_vsock.c before reading data
packets of record.

Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>
---
 v6 -> v7:
 1) Now 'virtio_transport_seqpacket_seq_get_len()' returns 0, if rx
    queue of socket is empty. Else it returns length of current message
    to handle.
 2) If dequeue callback is called, but there is no detected length of
    message to dequeue, EAGAIN is returned, and outer loop restarts
    receiving.

 net/vmw_vsock/virtio_transport_common.c | 61 +++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

Comments

Stefano Garzarella March 25, 2021, 10:08 a.m. UTC | #1
On Tue, Mar 23, 2021 at 04:12:55PM +0300, Arseny Krasnov wrote:
>This adds transport callback which tries to fetch record begin marker

>from socket's rx queue. It is called from af_vsock.c before reading data

>packets of record.

>

>Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>

>---

> v6 -> v7:

> 1) Now 'virtio_transport_seqpacket_seq_get_len()' returns 0, if rx

>    queue of socket is empty. Else it returns length of current message

>    to handle.

> 2) If dequeue callback is called, but there is no detected length of

>    message to dequeue, EAGAIN is returned, and outer loop restarts

>    receiving.

>

> net/vmw_vsock/virtio_transport_common.c | 61 +++++++++++++++++++++++++

> 1 file changed, 61 insertions(+)

>

>diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c

>index a8f4326e45e8..41f05034593e 100644

>--- a/net/vmw_vsock/virtio_transport_common.c

>+++ b/net/vmw_vsock/virtio_transport_common.c

>@@ -399,6 +399,62 @@ static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt)

> 	virtio_transport_free_pkt(pkt);

> }

>

>+static size_t virtio_transport_drop_until_seq_begin(struct 

>virtio_vsock_sock *vvs)

>+{

>+	struct virtio_vsock_pkt *pkt, *n;

>+	size_t bytes_dropped = 0;

>+

>+	list_for_each_entry_safe(pkt, n, &vvs->rx_queue, list) {

>+		if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_SEQ_BEGIN)

>+			break;

>+

>+		bytes_dropped += le32_to_cpu(pkt->hdr.len);

>+		virtio_transport_dec_rx_pkt(vvs, pkt);

>+		virtio_transport_remove_pkt(pkt);

>+	}

>+

>+	return bytes_dropped;

>+}

>+

>+static size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk)

>+{

>+	struct virtio_vsock_seq_hdr *seq_hdr;

>+	struct virtio_vsock_sock *vvs;

>+	struct virtio_vsock_pkt *pkt;

>+	size_t bytes_dropped = 0;

>+

>+	vvs = vsk->trans;

>+

>+	spin_lock_bh(&vvs->rx_lock);

>+

>+	/* Have some record to process, return it's length. */

>+	if (vvs->seq_state.user_read_seq_len)

>+		goto out;

>+

>+	/* Fetch all orphaned 'RW' packets and send credit update. */

>+	bytes_dropped = virtio_transport_drop_until_seq_begin(vvs);

>+

>+	if (list_empty(&vvs->rx_queue))

>+		goto out;

>+

>+	pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);

>+

>+	vvs->seq_state.user_read_copied = 0;

>+

>+	seq_hdr = (struct virtio_vsock_seq_hdr *)pkt->buf;

>+	vvs->seq_state.user_read_seq_len = le32_to_cpu(seq_hdr->msg_len);

>+	vvs->seq_state.curr_rx_msg_id = le32_to_cpu(seq_hdr->msg_id);

>+	virtio_transport_dec_rx_pkt(vvs, pkt);

>+	virtio_transport_remove_pkt(pkt);

>+out:

>+	spin_unlock_bh(&vvs->rx_lock);

>+

>+	if (bytes_dropped)

>+		virtio_transport_send_credit_update(vsk);

>+

>+	return vvs->seq_state.user_read_seq_len;

>+}

>+

> static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,

> 						 struct msghdr *msg,

> 						 bool *msg_ready)

>@@ -522,6 +578,11 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,

> 	if (flags & MSG_PEEK)

> 		return -EOPNOTSUPP;

>

>+	*msg_len = virtio_transport_seqpacket_seq_get_len(vsk);

>+

>+	if (*msg_len == 0)

>+		return -EAGAIN;

>+


Okay, I see now, I think you can move this patch before the previous one 
or merge them in a single patch, it is better to review and to bisect.

As mentioned, I think we can return msg_len if 
virtio_transport_seqpacket_do_dequeue() does not fail, otherwise the 
error.

I mean something like this:

static ssize_t virtio_transport_seqpacket_do_dequeue(...)
{
	size_t msg_len;
	ssize_t ret;

	msg_len = virtio_transport_seqpacket_seq_get_len(vsk);
	if (msg_len == 0)
		return -EAGAIN;

	ret = virtio_transport_seqpacket_do_dequeue(vsk, msg, msg_ready);
	if (ret < 0)
		return ret;

	return msg_len;
}

> 	return virtio_transport_seqpacket_do_dequeue(vsk, msg, msg_ready);

> }

> EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);

>-- 2.25.1

>
Arseny Krasnov March 25, 2021, 4:02 p.m. UTC | #2
On 25.03.2021 13:08, Stefano Garzarella wrote:
> On Tue, Mar 23, 2021 at 04:12:55PM +0300, Arseny Krasnov wrote:

>> This adds transport callback which tries to fetch record begin marker

> >from socket's rx queue. It is called from af_vsock.c before reading data

>> packets of record.

>>

>> Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>

>> ---

>> v6 -> v7:

>> 1) Now 'virtio_transport_seqpacket_seq_get_len()' returns 0, if rx

>>    queue of socket is empty. Else it returns length of current message

>>    to handle.

>> 2) If dequeue callback is called, but there is no detected length of

>>    message to dequeue, EAGAIN is returned, and outer loop restarts

>>    receiving.

>>

>> net/vmw_vsock/virtio_transport_common.c | 61 +++++++++++++++++++++++++

>> 1 file changed, 61 insertions(+)

>>

>> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c

>> index a8f4326e45e8..41f05034593e 100644

>> --- a/net/vmw_vsock/virtio_transport_common.c

>> +++ b/net/vmw_vsock/virtio_transport_common.c

>> @@ -399,6 +399,62 @@ static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt)

>> 	virtio_transport_free_pkt(pkt);

>> }

>>

>> +static size_t virtio_transport_drop_until_seq_begin(struct 

>> virtio_vsock_sock *vvs)

>> +{

>> +	struct virtio_vsock_pkt *pkt, *n;

>> +	size_t bytes_dropped = 0;

>> +

>> +	list_for_each_entry_safe(pkt, n, &vvs->rx_queue, list) {

>> +		if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_SEQ_BEGIN)

>> +			break;

>> +

>> +		bytes_dropped += le32_to_cpu(pkt->hdr.len);

>> +		virtio_transport_dec_rx_pkt(vvs, pkt);

>> +		virtio_transport_remove_pkt(pkt);

>> +	}

>> +

>> +	return bytes_dropped;

>> +}

>> +

>> +static size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk)

>> +{

>> +	struct virtio_vsock_seq_hdr *seq_hdr;

>> +	struct virtio_vsock_sock *vvs;

>> +	struct virtio_vsock_pkt *pkt;

>> +	size_t bytes_dropped = 0;

>> +

>> +	vvs = vsk->trans;

>> +

>> +	spin_lock_bh(&vvs->rx_lock);

>> +

>> +	/* Have some record to process, return it's length. */

>> +	if (vvs->seq_state.user_read_seq_len)

>> +		goto out;

>> +

>> +	/* Fetch all orphaned 'RW' packets and send credit update. */

>> +	bytes_dropped = virtio_transport_drop_until_seq_begin(vvs);

>> +

>> +	if (list_empty(&vvs->rx_queue))

>> +		goto out;

>> +

>> +	pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);

>> +

>> +	vvs->seq_state.user_read_copied = 0;

>> +

>> +	seq_hdr = (struct virtio_vsock_seq_hdr *)pkt->buf;

>> +	vvs->seq_state.user_read_seq_len = le32_to_cpu(seq_hdr->msg_len);

>> +	vvs->seq_state.curr_rx_msg_id = le32_to_cpu(seq_hdr->msg_id);

>> +	virtio_transport_dec_rx_pkt(vvs, pkt);

>> +	virtio_transport_remove_pkt(pkt);

>> +out:

>> +	spin_unlock_bh(&vvs->rx_lock);

>> +

>> +	if (bytes_dropped)

>> +		virtio_transport_send_credit_update(vsk);

>> +

>> +	return vvs->seq_state.user_read_seq_len;

>> +}

>> +

>> static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,

>> 						 struct msghdr *msg,

>> 						 bool *msg_ready)

>> @@ -522,6 +578,11 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,

>> 	if (flags & MSG_PEEK)

>> 		return -EOPNOTSUPP;

>>

>> +	*msg_len = virtio_transport_seqpacket_seq_get_len(vsk);

>> +

>> +	if (*msg_len == 0)

>> +		return -EAGAIN;

>> +

> Okay, I see now, I think you can move this patch before the previous one 

> or merge them in a single patch, it is better to review and to bisect.

>

> As mentioned, I think we can return msg_len if 

> virtio_transport_seqpacket_do_dequeue() does not fail, otherwise the 

> error.

>

> I mean something like this:

>

> static ssize_t virtio_transport_seqpacket_do_dequeue(...)

> {

> 	size_t msg_len;

> 	ssize_t ret;

>

> 	msg_len = virtio_transport_seqpacket_seq_get_len(vsk);

> 	if (msg_len == 0)

> 		return -EAGAIN;

>

> 	ret = virtio_transport_seqpacket_do_dequeue(vsk, msg, msg_ready);

> 	if (ret < 0)

> 		return ret;

>

> 	return msg_len;

> }

Ack
>

>> 	return virtio_transport_seqpacket_do_dequeue(vsk, msg, msg_ready);

>> }

>> EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);

>> -- 2.25.1

>>

>
diff mbox series

Patch

diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index a8f4326e45e8..41f05034593e 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -399,6 +399,62 @@  static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt)
 	virtio_transport_free_pkt(pkt);
 }
 
+static size_t virtio_transport_drop_until_seq_begin(struct virtio_vsock_sock *vvs)
+{
+	struct virtio_vsock_pkt *pkt, *n;
+	size_t bytes_dropped = 0;
+
+	list_for_each_entry_safe(pkt, n, &vvs->rx_queue, list) {
+		if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_SEQ_BEGIN)
+			break;
+
+		bytes_dropped += le32_to_cpu(pkt->hdr.len);
+		virtio_transport_dec_rx_pkt(vvs, pkt);
+		virtio_transport_remove_pkt(pkt);
+	}
+
+	return bytes_dropped;
+}
+
+static size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk)
+{
+	struct virtio_vsock_seq_hdr *seq_hdr;
+	struct virtio_vsock_sock *vvs;
+	struct virtio_vsock_pkt *pkt;
+	size_t bytes_dropped = 0;
+
+	vvs = vsk->trans;
+
+	spin_lock_bh(&vvs->rx_lock);
+
+	/* Have some record to process, return it's length. */
+	if (vvs->seq_state.user_read_seq_len)
+		goto out;
+
+	/* Fetch all orphaned 'RW' packets and send credit update. */
+	bytes_dropped = virtio_transport_drop_until_seq_begin(vvs);
+
+	if (list_empty(&vvs->rx_queue))
+		goto out;
+
+	pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
+
+	vvs->seq_state.user_read_copied = 0;
+
+	seq_hdr = (struct virtio_vsock_seq_hdr *)pkt->buf;
+	vvs->seq_state.user_read_seq_len = le32_to_cpu(seq_hdr->msg_len);
+	vvs->seq_state.curr_rx_msg_id = le32_to_cpu(seq_hdr->msg_id);
+	virtio_transport_dec_rx_pkt(vvs, pkt);
+	virtio_transport_remove_pkt(pkt);
+out:
+	spin_unlock_bh(&vvs->rx_lock);
+
+	if (bytes_dropped)
+		virtio_transport_send_credit_update(vsk);
+
+	return vvs->seq_state.user_read_seq_len;
+}
+
 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 						 struct msghdr *msg,
 						 bool *msg_ready)
@@ -522,6 +578,11 @@  virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 	if (flags & MSG_PEEK)
 		return -EOPNOTSUPP;
 
+	*msg_len = virtio_transport_seqpacket_seq_get_len(vsk);
+
+	if (*msg_len == 0)
+		return -EAGAIN;
+
 	return virtio_transport_seqpacket_do_dequeue(vsk, msg, msg_ready);
 }
 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);