diff mbox series

[RFC,11/11] scsi: Use extract_iter_to_sg()

Message ID 20230630152524.661208-12-dhowells@redhat.com
State New
Headers show
Series None | expand

Commit Message

David Howells June 30, 2023, 3:25 p.m. UTC
Use extract_iter_to_sg() to build a scatterlist from an iterator.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: James E.J. Bottomley <jejb@linux.ibm.com>
cc: Martin K. Petersen <martin.petersen@oracle.com>
cc: Christoph Hellwig <hch@lst.de>
cc: linux-scsi@vger.kernel.org
---
 drivers/vhost/scsi.c | 79 +++++++++++++-------------------------------
 1 file changed, 23 insertions(+), 56 deletions(-)
diff mbox series

Patch

diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index bb10fa4bb4f6..7bb41e2a0d64 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -75,6 +75,9 @@  struct vhost_scsi_cmd {
 	u32 tvc_prot_sgl_count;
 	/* Saved unpacked SCSI LUN for vhost_scsi_target_queue_cmd() */
 	u32 tvc_lun;
+	/* Cleanup modes for scatterlists */
+	unsigned int tvc_need_unpin;
+	unsigned int tvc_prot_need_unpin;
 	/* Pointer to the SGL formatted memory from virtio-scsi */
 	struct scatterlist *tvc_sgl;
 	struct scatterlist *tvc_prot_sgl;
@@ -327,14 +330,13 @@  static void vhost_scsi_release_cmd_res(struct se_cmd *se_cmd)
 	struct vhost_scsi_inflight *inflight = tv_cmd->inflight;
 	int i;
 
-	if (tv_cmd->tvc_sgl_count) {
+	if (tv_cmd->tvc_need_unpin && tv_cmd->tvc_sgl_count)
 		for (i = 0; i < tv_cmd->tvc_sgl_count; i++)
-			put_page(sg_page(&tv_cmd->tvc_sgl[i]));
-	}
-	if (tv_cmd->tvc_prot_sgl_count) {
+			unpin_user_page(sg_page(&tv_cmd->tvc_sgl[i]));
+
+	if (tv_cmd->tvc_prot_need_unpin && tv_cmd->tvc_prot_sgl_count)
 		for (i = 0; i < tv_cmd->tvc_prot_sgl_count; i++)
-			put_page(sg_page(&tv_cmd->tvc_prot_sgl[i]));
-	}
+			unpin_user_page(sg_page(&tv_cmd->tvc_prot_sgl[i]));
 
 	sbitmap_clear_bit(&svq->scsi_tags, se_cmd->map_tag);
 	vhost_scsi_put_inflight(inflight);
@@ -606,38 +608,6 @@  vhost_scsi_get_cmd(struct vhost_virtqueue *vq, struct vhost_scsi_tpg *tpg,
 	return cmd;
 }
 
-/*
- * Map a user memory range into a scatterlist
- *
- * Returns the number of scatterlist entries used or -errno on error.
- */
-static int
-vhost_scsi_map_to_sgl(struct vhost_scsi_cmd *cmd,
-		      struct iov_iter *iter,
-		      struct scatterlist *sgl,
-		      bool write)
-{
-	struct page **pages = cmd->tvc_upages;
-	struct scatterlist *sg = sgl;
-	ssize_t bytes;
-	size_t offset;
-	unsigned int npages = 0;
-
-	bytes = iov_iter_get_pages2(iter, pages, LONG_MAX,
-				VHOST_SCSI_PREALLOC_UPAGES, &offset);
-	/* No pages were pinned */
-	if (bytes <= 0)
-		return bytes < 0 ? bytes : -EFAULT;
-
-	while (bytes) {
-		unsigned n = min_t(unsigned, PAGE_SIZE - offset, bytes);
-		sg_set_page(sg++, pages[npages++], n, offset);
-		bytes -= n;
-		offset = 0;
-	}
-	return npages;
-}
-
 static int
 vhost_scsi_calc_sgls(struct iov_iter *iter, size_t bytes, int max_sgls)
 {
@@ -661,24 +631,19 @@  vhost_scsi_calc_sgls(struct iov_iter *iter, size_t bytes, int max_sgls)
 static int
 vhost_scsi_iov_to_sgl(struct vhost_scsi_cmd *cmd, bool write,
 		      struct iov_iter *iter,
-		      struct scatterlist *sg, int sg_count)
+		      struct scatterlist *sg, int sg_count,
+		      unsigned int *need_unpin)
 {
-	struct scatterlist *p = sg;
-	int ret;
+	struct sg_table sgt = { .sgl = sg };
+	ssize_t ret;
 
-	while (iov_iter_count(iter)) {
-		ret = vhost_scsi_map_to_sgl(cmd, iter, sg, write);
-		if (ret < 0) {
-			while (p < sg) {
-				struct page *page = sg_page(p++);
-				if (page)
-					put_page(page);
-			}
-			return ret;
-		}
-		sg += ret;
-	}
-	return 0;
+	ret = extract_iter_to_sg(iter, LONG_MAX, &sgt, sg_count,
+				 write ? WRITE_FROM_ITER : READ_INTO_ITER);
+	if (ret > 0)
+		sg_mark_end(sg + sgt.nents - 1);
+
+	*need_unpin = iov_iter_extract_will_pin(iter);
+	return ret;
 }
 
 static int
@@ -702,7 +667,8 @@  vhost_scsi_mapal(struct vhost_scsi_cmd *cmd,
 
 		ret = vhost_scsi_iov_to_sgl(cmd, write, prot_iter,
 					    cmd->tvc_prot_sgl,
-					    cmd->tvc_prot_sgl_count);
+					    cmd->tvc_prot_sgl_count,
+					    &cmd->tvc_prot_need_unpin);
 		if (ret < 0) {
 			cmd->tvc_prot_sgl_count = 0;
 			return ret;
@@ -719,7 +685,8 @@  vhost_scsi_mapal(struct vhost_scsi_cmd *cmd,
 		  cmd->tvc_sgl, cmd->tvc_sgl_count);
 
 	ret = vhost_scsi_iov_to_sgl(cmd, write, data_iter,
-				    cmd->tvc_sgl, cmd->tvc_sgl_count);
+				    cmd->tvc_sgl, cmd->tvc_sgl_count,
+				    &cmd->tvc_need_unpin);
 	if (ret < 0) {
 		cmd->tvc_sgl_count = 0;
 		return ret;