diff mbox series

[2/4] VSOCK DRIVER: support communication using additional guest cid

Message ID 20210802120720.547894-3-fuguancheng@bytedance.com
State New
Headers show
Series [1/4] VSOCK DRIVER: Add multi-cid support for guest | expand

Commit Message

fuguancheng Aug. 2, 2021, 12:07 p.m. UTC
Changes in this patch are made to allow the guest communicate
with the host using the additional cids specified when
creating the guest.

In original settings, the packet sent with the additional CIDS will
be rejected when received by the host, the newly added function
vhost_vsock_contain_cid will fix this error.

Now that we have multiple CIDS, the VMADDR_CID_ANY now behaves like
this:
1. The client will use the first available cid specified in the cids
array if VMADDR_CID_ANY is used.
2. The host will still use the original default CID.
3. If a guest server binds to VMADDR_CID_ANY, then the server can
choose to connect to any of the available CIDs for this guest.

Signed-off-by: fuguancheng <fuguancheng@bytedance.com>
---
 drivers/vhost/vsock.c                   | 14 +++++++++++++-
 net/vmw_vsock/af_vsock.c                |  2 +-
 net/vmw_vsock/virtio_transport_common.c |  5 ++++-
 3 files changed, 18 insertions(+), 3 deletions(-)
diff mbox series

Patch

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index f66c87de91b8..013f8ebf8189 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -74,6 +74,18 @@  struct vhost_vsock {
 	bool seqpacket_allow;
 };
 
+static bool
+vhost_vsock_contain_cid(struct vhost_vsock *vsock, u32 cid)
+{
+	u32 index;
+
+	for (index = 0; index < vsock->num_cid; index++) {
+		if (cid == vsock->cids[index])
+			return true;
+	}
+	return false;
+}
+
 static u32 vhost_transport_get_local_cid(void)
 {
 	return VHOST_VSOCK_DEFAULT_HOST_CID;
@@ -584,7 +596,7 @@  static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 
 		/* Only accept correctly addressed packets */
 		if (vsock->num_cid > 0 &&
-		    (pkt->hdr.src_cid) == vsock->cids[0] &&
+			vhost_vsock_contain_cid(vsock, pkt->hdr.src_cid) &&
 		    le64_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid())
 			virtio_transport_recv_pkt(&vhost_transport, pkt);
 		else
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 4e1fbe74013f..c22ae7101e55 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -251,7 +251,7 @@  static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
 	list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
 			    connected_table) {
 		if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
-		    dst->svm_port == vsk->local_addr.svm_port) {
+		    vsock_addr_equals_addr(&vsk->local_addr, dst)) {
 			return sk_vsock(vsk);
 		}
 	}
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 169ba8b72a63..cb45e2f801f1 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -197,7 +197,10 @@  static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 	if (unlikely(!t_ops))
 		return -EFAULT;
 
-	src_cid = t_ops->transport.get_local_cid();
+	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
+		src_cid = vsk->local_addr.svm_cid;
+	else
+		src_cid = t_ops->transport.get_local_cid();
 	src_port = vsk->local_addr.svm_port;
 	if (!info->remote_cid) {
 		dst_cid	= vsk->remote_addr.svm_cid;