[v1,2/3] mldev: introduce support for IO layout

Message ID 20230830155303.30380-3-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series Spec changes to support multi I/O models |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Aug. 30, 2023, 3:53 p.m. UTC
  Introduce IO layout in ML device specification. IO layout
defines the expected arrangement of model input and output
buffers in the memory. Packed and Split layout support is
added in the specification.

Updated rte_ml_op to support array of rte_ml_buff_seg
pointers to support packed and split I/O layouts. Updated
ML quantize and dequantize APIs to support rte_ml_buff_seg
pointer arrays. Replaced batch_size with min_batches and
max_batches in rte_ml_model_info.

Implement support for model IO layout in ml/cnxk driver.
Updated the ML test application to support IO layout and
dropped support for '--batches' in test application.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 app/test-mldev/ml_options.c            |  15 --
 app/test-mldev/ml_options.h            |   2 -
 app/test-mldev/test_inference_common.c | 323 +++++++++++++++++++++----
 app/test-mldev/test_inference_common.h |   6 +
 app/test-mldev/test_model_common.c     |   6 -
 app/test-mldev/test_model_common.h     |   1 -
 doc/guides/tools/testmldev.rst         |   6 -
 drivers/ml/cnxk/cn10k_ml_dev.h         |   3 +
 drivers/ml/cnxk/cn10k_ml_model.c       |   6 +-
 drivers/ml/cnxk/cn10k_ml_ops.c         |  74 +++---
 lib/mldev/meson.build                  |   2 +-
 lib/mldev/rte_mldev.c                  |  12 +-
 lib/mldev/rte_mldev.h                  |  90 +++++--
 lib/mldev/rte_mldev_core.h             |  14 +-
 14 files changed, 415 insertions(+), 145 deletions(-)
  

Patch

diff --git a/app/test-mldev/ml_options.c b/app/test-mldev/ml_options.c
index 816e41fdb70..c0468f5eee4 100644
--- a/app/test-mldev/ml_options.c
+++ b/app/test-mldev/ml_options.c
@@ -28,7 +28,6 @@  ml_options_default(struct ml_options *opt)
 	opt->burst_size = 1;
 	opt->queue_pairs = 1;
 	opt->queue_size = 1;
-	opt->batches = 0;
 	opt->tolerance = 0.0;
 	opt->stats = false;
 	opt->debug = false;
@@ -212,18 +211,6 @@  ml_parse_queue_size(struct ml_options *opt, const char *arg)
 	return ret;
 }
 
-static int
-ml_parse_batches(struct ml_options *opt, const char *arg)
-{
-	int ret;
-
-	ret = parser_read_uint16(&opt->batches, arg);
-	if (ret != 0)
-		ml_err("Invalid option, batches = %s\n", arg);
-
-	return ret;
-}
-
 static int
 ml_parse_tolerance(struct ml_options *opt, const char *arg)
 {
@@ -286,7 +273,6 @@  static struct option lgopts[] = {
 	{ML_BURST_SIZE, 1, 0, 0},
 	{ML_QUEUE_PAIRS, 1, 0, 0},
 	{ML_QUEUE_SIZE, 1, 0, 0},
-	{ML_BATCHES, 1, 0, 0},
 	{ML_TOLERANCE, 1, 0, 0},
 	{ML_STATS, 0, 0, 0},
 	{ML_DEBUG, 0, 0, 0},
@@ -308,7 +294,6 @@  ml_opts_parse_long(int opt_idx, struct ml_options *opt)
 		{ML_BURST_SIZE, ml_parse_burst_size},
 		{ML_QUEUE_PAIRS, ml_parse_queue_pairs},
 		{ML_QUEUE_SIZE, ml_parse_queue_size},
-		{ML_BATCHES, ml_parse_batches},
 		{ML_TOLERANCE, ml_parse_tolerance},
 	};
 
diff --git a/app/test-mldev/ml_options.h b/app/test-mldev/ml_options.h
index 622a4c05fc2..90e22adeac1 100644
--- a/app/test-mldev/ml_options.h
+++ b/app/test-mldev/ml_options.h
@@ -21,7 +21,6 @@ 
 #define ML_BURST_SIZE  ("burst_size")
 #define ML_QUEUE_PAIRS ("queue_pairs")
 #define ML_QUEUE_SIZE  ("queue_size")
-#define ML_BATCHES     ("batches")
 #define ML_TOLERANCE   ("tolerance")
 #define ML_STATS       ("stats")
 #define ML_DEBUG       ("debug")
@@ -44,7 +43,6 @@  struct ml_options {
 	uint16_t burst_size;
 	uint16_t queue_pairs;
 	uint16_t queue_size;
-	uint16_t batches;
 	float tolerance;
 	bool stats;
 	bool debug;
diff --git a/app/test-mldev/test_inference_common.c b/app/test-mldev/test_inference_common.c
index 6bda37b0fab..0018cc92514 100644
--- a/app/test-mldev/test_inference_common.c
+++ b/app/test-mldev/test_inference_common.c
@@ -47,7 +47,10 @@  ml_enqueue_single(void *arg)
 	uint64_t start_cycle;
 	uint32_t burst_enq;
 	uint32_t lcore_id;
+	uint64_t offset;
+	uint64_t bufsz;
 	uint16_t fid;
+	uint32_t i;
 	int ret;
 
 	lcore_id = rte_lcore_id();
@@ -66,24 +69,64 @@  ml_enqueue_single(void *arg)
 	if (ret != 0)
 		goto next_model;
 
-retry:
+retry_req:
 	ret = rte_mempool_get(t->model[fid].io_pool, (void **)&req);
 	if (ret != 0)
-		goto retry;
+		goto retry_req;
+
+retry_inp_segs:
+	ret = rte_mempool_get_bulk(t->buf_seg_pool, (void **)req->inp_buf_segs,
+				   t->model[fid].info.nb_inputs);
+	if (ret != 0)
+		goto retry_inp_segs;
+
+retry_out_segs:
+	ret = rte_mempool_get_bulk(t->buf_seg_pool, (void **)req->out_buf_segs,
+				   t->model[fid].info.nb_outputs);
+	if (ret != 0)
+		goto retry_out_segs;
 
 	op->model_id = t->model[fid].id;
-	op->nb_batches = t->model[fid].nb_batches;
+	op->nb_batches = t->model[fid].info.min_batches;
 	op->mempool = t->op_pool;
+	op->input = req->inp_buf_segs;
+	op->output = req->out_buf_segs;
+	op->user_ptr = req;
 
-	op->input.addr = req->input;
-	op->input.length = t->model[fid].inp_qsize;
-	op->input.next = NULL;
+	if (t->model[fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED) {
+		op->input[0]->addr = req->input;
+		op->input[0]->iova_addr = rte_mem_virt2iova(req->input);
+		op->input[0]->length = t->model[fid].inp_qsize;
+		op->input[0]->next = NULL;
+
+		op->output[0]->addr = req->output;
+		op->output[0]->iova_addr = rte_mem_virt2iova(req->output);
+		op->output[0]->length = t->model[fid].out_qsize;
+		op->output[0]->next = NULL;
+	} else {
+		offset = 0;
+		for (i = 0; i < t->model[fid].info.nb_inputs; i++) {
+			bufsz = RTE_ALIGN_CEIL(t->model[fid].info.input_info[i].size,
+					       t->cmn.dev_info.align_size);
+			op->input[i]->addr = req->input + offset;
+			op->input[i]->iova_addr = rte_mem_virt2iova(req->input + offset);
+			op->input[i]->length = bufsz;
+			op->input[i]->next = NULL;
+			offset += bufsz;
+		}
 
-	op->output.addr = req->output;
-	op->output.length = t->model[fid].out_qsize;
-	op->output.next = NULL;
+		offset = 0;
+		for (i = 0; i < t->model[fid].info.nb_outputs; i++) {
+			bufsz = RTE_ALIGN_CEIL(t->model[fid].info.output_info[i].size,
+					       t->cmn.dev_info.align_size);
+			op->output[i]->addr = req->output + offset;
+			op->output[i]->iova_addr = rte_mem_virt2iova(req->output + offset);
+			op->output[i]->length = bufsz;
+			op->output[i]->next = NULL;
+			offset += bufsz;
+		}
+	}
 
-	op->user_ptr = req;
 	req->niters++;
 	req->fid = fid;
 
@@ -143,6 +186,10 @@  ml_dequeue_single(void *arg)
 		}
 		req = (struct ml_request *)op->user_ptr;
 		rte_mempool_put(t->model[req->fid].io_pool, req);
+		rte_mempool_put_bulk(t->buf_seg_pool, (void **)op->input,
+				     t->model[req->fid].info.nb_inputs);
+		rte_mempool_put_bulk(t->buf_seg_pool, (void **)op->output,
+				     t->model[req->fid].info.nb_outputs);
 		rte_mempool_put(t->op_pool, op);
 	}
 
@@ -164,9 +211,12 @@  ml_enqueue_burst(void *arg)
 	uint16_t burst_enq;
 	uint32_t lcore_id;
 	uint16_t pending;
+	uint64_t offset;
+	uint64_t bufsz;
 	uint16_t idx;
 	uint16_t fid;
 	uint16_t i;
+	uint16_t j;
 	int ret;
 
 	lcore_id = rte_lcore_id();
@@ -186,25 +236,70 @@  ml_enqueue_burst(void *arg)
 	if (ret != 0)
 		goto next_model;
 
-retry:
+retry_reqs:
 	ret = rte_mempool_get_bulk(t->model[fid].io_pool, (void **)args->reqs, ops_count);
 	if (ret != 0)
-		goto retry;
+		goto retry_reqs;
 
 	for (i = 0; i < ops_count; i++) {
+retry_inp_segs:
+		ret = rte_mempool_get_bulk(t->buf_seg_pool, (void **)args->reqs[i]->inp_buf_segs,
+					   t->model[fid].info.nb_inputs);
+		if (ret != 0)
+			goto retry_inp_segs;
+
+retry_out_segs:
+		ret = rte_mempool_get_bulk(t->buf_seg_pool, (void **)args->reqs[i]->out_buf_segs,
+					   t->model[fid].info.nb_outputs);
+		if (ret != 0)
+			goto retry_out_segs;
+
 		args->enq_ops[i]->model_id = t->model[fid].id;
-		args->enq_ops[i]->nb_batches = t->model[fid].nb_batches;
+		args->enq_ops[i]->nb_batches = t->model[fid].info.min_batches;
 		args->enq_ops[i]->mempool = t->op_pool;
+		args->enq_ops[i]->input = args->reqs[i]->inp_buf_segs;
+		args->enq_ops[i]->output = args->reqs[i]->out_buf_segs;
+		args->enq_ops[i]->user_ptr = args->reqs[i];
 
-		args->enq_ops[i]->input.addr = args->reqs[i]->input;
-		args->enq_ops[i]->input.length = t->model[fid].inp_qsize;
-		args->enq_ops[i]->input.next = NULL;
+		if (t->model[fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED) {
+			args->enq_ops[i]->input[0]->addr = args->reqs[i]->input;
+			args->enq_ops[i]->input[0]->iova_addr =
+				rte_mem_virt2iova(args->reqs[i]->input);
+			args->enq_ops[i]->input[0]->length = t->model[fid].inp_qsize;
+			args->enq_ops[i]->input[0]->next = NULL;
+
+			args->enq_ops[i]->output[0]->addr = args->reqs[i]->output;
+			args->enq_ops[i]->output[0]->iova_addr =
+				rte_mem_virt2iova(args->reqs[i]->output);
+			args->enq_ops[i]->output[0]->length = t->model[fid].out_qsize;
+			args->enq_ops[i]->output[0]->next = NULL;
+		} else {
+			offset = 0;
+			for (j = 0; j < t->model[fid].info.nb_inputs; j++) {
+				bufsz = RTE_ALIGN_CEIL(t->model[fid].info.input_info[i].size,
+						       t->cmn.dev_info.align_size);
+
+				args->enq_ops[i]->input[j]->addr = args->reqs[i]->input + offset;
+				args->enq_ops[i]->input[j]->iova_addr =
+					rte_mem_virt2iova(args->reqs[i]->input + offset);
+				args->enq_ops[i]->input[j]->length = t->model[fid].inp_qsize;
+				args->enq_ops[i]->input[j]->next = NULL;
+				offset += bufsz;
+			}
 
-		args->enq_ops[i]->output.addr = args->reqs[i]->output;
-		args->enq_ops[i]->output.length = t->model[fid].out_qsize;
-		args->enq_ops[i]->output.next = NULL;
+			offset = 0;
+			for (j = 0; j < t->model[fid].info.nb_outputs; j++) {
+				bufsz = RTE_ALIGN_CEIL(t->model[fid].info.output_info[i].size,
+						       t->cmn.dev_info.align_size);
+				args->enq_ops[i]->output[j]->addr = args->reqs[i]->output + offset;
+				args->enq_ops[i]->output[j]->iova_addr =
+					rte_mem_virt2iova(args->reqs[i]->output + offset);
+				args->enq_ops[i]->output[j]->length = t->model[fid].out_qsize;
+				args->enq_ops[i]->output[j]->next = NULL;
+				offset += bufsz;
+			}
+		}
 
-		args->enq_ops[i]->user_ptr = args->reqs[i];
 		args->reqs[i]->niters++;
 		args->reqs[i]->fid = fid;
 	}
@@ -277,6 +372,11 @@  ml_dequeue_burst(void *arg)
 			req = (struct ml_request *)args->deq_ops[i]->user_ptr;
 			if (req != NULL)
 				rte_mempool_put(t->model[req->fid].io_pool, req);
+
+			rte_mempool_put_bulk(t->buf_seg_pool, (void **)args->deq_ops[i]->input,
+					     t->model[req->fid].info.nb_inputs);
+			rte_mempool_put_bulk(t->buf_seg_pool, (void **)args->deq_ops[i]->output,
+					     t->model[req->fid].info.nb_outputs);
 		}
 		rte_mempool_put_bulk(t->op_pool, (void *)args->deq_ops, burst_deq);
 	}
@@ -315,6 +415,12 @@  test_inference_cap_check(struct ml_options *opt)
 		return false;
 	}
 
+	if (dev_info.max_io < ML_TEST_MAX_IO_SIZE) {
+		ml_err("Insufficient capabilities:  Max I/O, count = %u > (max limit = %u)",
+		       ML_TEST_MAX_IO_SIZE, dev_info.max_io);
+		return false;
+	}
+
 	return true;
 }
 
@@ -403,11 +509,6 @@  test_inference_opt_dump(struct ml_options *opt)
 	ml_dump("tolerance", "%-7.3f", opt->tolerance);
 	ml_dump("stats", "%s", (opt->stats ? "true" : "false"));
 
-	if (opt->batches == 0)
-		ml_dump("batches", "%u (default batch size)", opt->batches);
-	else
-		ml_dump("batches", "%u", opt->batches);
-
 	ml_dump_begin("filelist");
 	for (i = 0; i < opt->nb_filelist; i++) {
 		ml_dump_list("model", i, opt->filelist[i].model);
@@ -492,10 +593,18 @@  void
 test_inference_destroy(struct ml_test *test, struct ml_options *opt)
 {
 	struct test_inference *t;
+	uint32_t lcore_id;
 
 	RTE_SET_USED(opt);
 
 	t = ml_test_priv(test);
+
+	for (lcore_id = 0; lcore_id < RTE_MAX_LCORE; lcore_id++) {
+		rte_free(t->args[lcore_id].enq_ops);
+		rte_free(t->args[lcore_id].deq_ops);
+		rte_free(t->args[lcore_id].reqs);
+	}
+
 	rte_free(t);
 }
 
@@ -572,19 +681,62 @@  ml_request_initialize(struct rte_mempool *mp, void *opaque, void *obj, unsigned
 {
 	struct test_inference *t = ml_test_priv((struct ml_test *)opaque);
 	struct ml_request *req = (struct ml_request *)obj;
+	struct rte_ml_buff_seg dbuff_seg[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg qbuff_seg[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg *q_segs[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg *d_segs[ML_TEST_MAX_IO_SIZE];
+	uint64_t offset;
+	uint64_t bufsz;
+	uint32_t i;
 
 	RTE_SET_USED(mp);
 	RTE_SET_USED(obj_idx);
 
 	req->input = (uint8_t *)obj +
-		     RTE_ALIGN_CEIL(sizeof(struct ml_request), t->cmn.dev_info.min_align_size);
-	req->output = req->input +
-		      RTE_ALIGN_CEIL(t->model[t->fid].inp_qsize, t->cmn.dev_info.min_align_size);
+		     RTE_ALIGN_CEIL(sizeof(struct ml_request), t->cmn.dev_info.align_size);
+	req->output =
+		req->input + RTE_ALIGN_CEIL(t->model[t->fid].inp_qsize, t->cmn.dev_info.align_size);
 	req->niters = 0;
 
+	if (t->model[t->fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED) {
+		dbuff_seg[0].addr = t->model[t->fid].input;
+		dbuff_seg[0].iova_addr = rte_mem_virt2iova(t->model[t->fid].input);
+		dbuff_seg[0].length = t->model[t->fid].inp_dsize;
+		dbuff_seg[0].next = NULL;
+		d_segs[0] = &dbuff_seg[0];
+
+		qbuff_seg[0].addr = req->input;
+		qbuff_seg[0].iova_addr = rte_mem_virt2iova(req->input);
+		qbuff_seg[0].length = t->model[t->fid].inp_qsize;
+		qbuff_seg[0].next = NULL;
+		q_segs[0] = &qbuff_seg[0];
+	} else {
+		offset = 0;
+		for (i = 0; i < t->model[t->fid].info.nb_inputs; i++) {
+			bufsz = t->model[t->fid].info.input_info[i].nb_elements * sizeof(float);
+			dbuff_seg[i].addr = t->model[t->fid].input + offset;
+			dbuff_seg[i].iova_addr = rte_mem_virt2iova(t->model[t->fid].input + offset);
+			dbuff_seg[i].length = bufsz;
+			dbuff_seg[i].next = NULL;
+			d_segs[i] = &dbuff_seg[i];
+			offset += bufsz;
+		}
+
+		offset = 0;
+		for (i = 0; i < t->model[t->fid].info.nb_inputs; i++) {
+			bufsz = RTE_ALIGN_CEIL(t->model[t->fid].info.input_info[i].size,
+					       t->cmn.dev_info.align_size);
+			qbuff_seg[i].addr = req->input + offset;
+			qbuff_seg[i].iova_addr = rte_mem_virt2iova(req->input + offset);
+			qbuff_seg[i].length = bufsz;
+			qbuff_seg[i].next = NULL;
+			q_segs[i] = &qbuff_seg[i];
+			offset += bufsz;
+		}
+	}
+
 	/* quantize data */
-	rte_ml_io_quantize(t->cmn.opt->dev_id, t->model[t->fid].id, t->model[t->fid].nb_batches,
-			   t->model[t->fid].input, req->input);
+	rte_ml_io_quantize(t->cmn.opt->dev_id, t->model[t->fid].id, d_segs, q_segs);
 }
 
 int
@@ -599,24 +751,39 @@  ml_inference_iomem_setup(struct ml_test *test, struct ml_options *opt, uint16_t
 	uint32_t buff_size;
 	uint32_t mz_size;
 	size_t fsize;
+	uint32_t i;
 	int ret;
 
 	/* get input buffer size */
-	ret = rte_ml_io_input_size_get(opt->dev_id, t->model[fid].id, t->model[fid].nb_batches,
-				       &t->model[fid].inp_qsize, &t->model[fid].inp_dsize);
-	if (ret != 0) {
-		ml_err("Failed to get input size, model : %s\n", opt->filelist[fid].model);
-		return ret;
+	t->model[fid].inp_qsize = 0;
+	for (i = 0; i < t->model[fid].info.nb_inputs; i++) {
+		if (t->model[fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED)
+			t->model[fid].inp_qsize += t->model[fid].info.input_info[i].size;
+		else
+			t->model[fid].inp_qsize += RTE_ALIGN_CEIL(
+				t->model[fid].info.input_info[i].size, t->cmn.dev_info.align_size);
 	}
 
 	/* get output buffer size */
-	ret = rte_ml_io_output_size_get(opt->dev_id, t->model[fid].id, t->model[fid].nb_batches,
-					&t->model[fid].out_qsize, &t->model[fid].out_dsize);
-	if (ret != 0) {
-		ml_err("Failed to get input size, model : %s\n", opt->filelist[fid].model);
-		return ret;
+	t->model[fid].out_qsize = 0;
+	for (i = 0; i < t->model[fid].info.nb_outputs; i++) {
+		if (t->model[fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED)
+			t->model[fid].out_qsize += t->model[fid].info.output_info[i].size;
+		else
+			t->model[fid].out_qsize += RTE_ALIGN_CEIL(
+				t->model[fid].info.output_info[i].size, t->cmn.dev_info.align_size);
 	}
 
+	t->model[fid].inp_dsize = 0;
+	for (i = 0; i < t->model[fid].info.nb_inputs; i++)
+		t->model[fid].inp_dsize +=
+			t->model[fid].info.input_info[i].nb_elements * sizeof(float);
+
+	t->model[fid].out_dsize = 0;
+	for (i = 0; i < t->model[fid].info.nb_outputs; i++)
+		t->model[fid].out_dsize +=
+			t->model[fid].info.output_info[i].nb_elements * sizeof(float);
+
 	/* allocate buffer for user data */
 	mz_size = t->model[fid].inp_dsize + t->model[fid].out_dsize;
 	if (strcmp(opt->filelist[fid].reference, "\0") != 0)
@@ -673,9 +840,9 @@  ml_inference_iomem_setup(struct ml_test *test, struct ml_options *opt, uint16_t
 	/* create mempool for quantized input and output buffers. ml_request_initialize is
 	 * used as a callback for object creation.
 	 */
-	buff_size = RTE_ALIGN_CEIL(sizeof(struct ml_request), t->cmn.dev_info.min_align_size) +
-		    RTE_ALIGN_CEIL(t->model[fid].inp_qsize, t->cmn.dev_info.min_align_size) +
-		    RTE_ALIGN_CEIL(t->model[fid].out_qsize, t->cmn.dev_info.min_align_size);
+	buff_size = RTE_ALIGN_CEIL(sizeof(struct ml_request), t->cmn.dev_info.align_size) +
+		    RTE_ALIGN_CEIL(t->model[fid].inp_qsize, t->cmn.dev_info.align_size) +
+		    RTE_ALIGN_CEIL(t->model[fid].out_qsize, t->cmn.dev_info.align_size);
 	nb_buffers = RTE_MIN((uint64_t)ML_TEST_MAX_POOL_SIZE, opt->repetitions);
 
 	t->fid = fid;
@@ -740,6 +907,18 @@  ml_inference_mem_setup(struct ml_test *test, struct ml_options *opt)
 		return -ENOMEM;
 	}
 
+	/* create buf_segs pool of with element of uint8_t. external buffers are attached to the
+	 * buf_segs while queuing inference requests.
+	 */
+	t->buf_seg_pool = rte_mempool_create("ml_test_mbuf_pool", ML_TEST_MAX_POOL_SIZE * 2,
+					     sizeof(struct rte_ml_buff_seg), 0, 0, NULL, NULL, NULL,
+					     NULL, opt->socket_id, 0);
+	if (t->buf_seg_pool == NULL) {
+		ml_err("Failed to create buf_segs pool : %s\n", "ml_test_mbuf_pool");
+		rte_ml_op_pool_free(t->op_pool);
+		return -ENOMEM;
+	}
+
 	return 0;
 }
 
@@ -752,6 +931,9 @@  ml_inference_mem_destroy(struct ml_test *test, struct ml_options *opt)
 
 	/* release op pool */
 	rte_mempool_free(t->op_pool);
+
+	/* release buf_segs pool */
+	rte_mempool_free(t->buf_seg_pool);
 }
 
 static bool
@@ -781,8 +963,10 @@  ml_inference_validation(struct ml_test *test, struct ml_request *req)
 		j = 0;
 next_element:
 		match = false;
-		deviation =
-			(*reference == 0 ? 0 : 100 * fabs(*output - *reference) / fabs(*reference));
+		if ((*reference == 0) && (*output == 0))
+			deviation = 0;
+		else
+			deviation = 100 * fabs(*output - *reference) / fabs(*reference);
 		if (deviation <= t->cmn.opt->tolerance)
 			match = true;
 		else
@@ -817,14 +1001,59 @@  ml_request_finish(struct rte_mempool *mp, void *opaque, void *obj, unsigned int
 	bool error = false;
 	char *dump_path;
 
+	struct rte_ml_buff_seg qbuff_seg[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg dbuff_seg[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg *q_segs[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg *d_segs[ML_TEST_MAX_IO_SIZE];
+	uint64_t offset;
+	uint64_t bufsz;
+	uint32_t i;
+
 	RTE_SET_USED(mp);
 
 	if (req->niters == 0)
 		return;
 
 	t->nb_used++;
-	rte_ml_io_dequantize(t->cmn.opt->dev_id, model->id, t->model[req->fid].nb_batches,
-			     req->output, model->output);
+
+	if (t->model[req->fid].info.io_layout == RTE_ML_IO_LAYOUT_PACKED) {
+		qbuff_seg[0].addr = req->output;
+		qbuff_seg[0].iova_addr = rte_mem_virt2iova(req->output);
+		qbuff_seg[0].length = t->model[req->fid].out_qsize;
+		qbuff_seg[0].next = NULL;
+		q_segs[0] = &qbuff_seg[0];
+
+		dbuff_seg[0].addr = model->output;
+		dbuff_seg[0].iova_addr = rte_mem_virt2iova(model->output);
+		dbuff_seg[0].length = t->model[req->fid].out_dsize;
+		dbuff_seg[0].next = NULL;
+		d_segs[0] = &dbuff_seg[0];
+	} else {
+		offset = 0;
+		for (i = 0; i < t->model[req->fid].info.nb_outputs; i++) {
+			bufsz = RTE_ALIGN_CEIL(t->model[req->fid].info.output_info[i].size,
+					       t->cmn.dev_info.align_size);
+			qbuff_seg[i].addr = req->output + offset;
+			qbuff_seg[i].iova_addr = rte_mem_virt2iova(req->output + offset);
+			qbuff_seg[i].length = bufsz;
+			qbuff_seg[i].next = NULL;
+			q_segs[i] = &qbuff_seg[i];
+			offset += bufsz;
+		}
+
+		offset = 0;
+		for (i = 0; i < t->model[req->fid].info.nb_outputs; i++) {
+			bufsz = t->model[req->fid].info.output_info[i].nb_elements * sizeof(float);
+			dbuff_seg[i].addr = model->output + offset;
+			dbuff_seg[i].iova_addr = rte_mem_virt2iova(model->output + offset);
+			dbuff_seg[i].length = bufsz;
+			dbuff_seg[i].next = NULL;
+			d_segs[i] = &dbuff_seg[i];
+			offset += bufsz;
+		}
+	}
+
+	rte_ml_io_dequantize(t->cmn.opt->dev_id, model->id, q_segs, d_segs);
 
 	if (model->reference == NULL)
 		goto dump_output_pass;
diff --git a/app/test-mldev/test_inference_common.h b/app/test-mldev/test_inference_common.h
index 8f27af25e4f..3f4ba3219be 100644
--- a/app/test-mldev/test_inference_common.h
+++ b/app/test-mldev/test_inference_common.h
@@ -11,11 +11,16 @@ 
 
 #include "test_model_common.h"
 
+#define ML_TEST_MAX_IO_SIZE 32
+
 struct ml_request {
 	uint8_t *input;
 	uint8_t *output;
 	uint16_t fid;
 	uint64_t niters;
+
+	struct rte_ml_buff_seg *inp_buf_segs[ML_TEST_MAX_IO_SIZE];
+	struct rte_ml_buff_seg *out_buf_segs[ML_TEST_MAX_IO_SIZE];
 };
 
 struct ml_core_args {
@@ -38,6 +43,7 @@  struct test_inference {
 
 	/* test specific data */
 	struct ml_model model[ML_TEST_MAX_MODELS];
+	struct rte_mempool *buf_seg_pool;
 	struct rte_mempool *op_pool;
 
 	uint64_t nb_used;
diff --git a/app/test-mldev/test_model_common.c b/app/test-mldev/test_model_common.c
index 8dbb0ff89ff..c517a506117 100644
--- a/app/test-mldev/test_model_common.c
+++ b/app/test-mldev/test_model_common.c
@@ -50,12 +50,6 @@  ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *mod
 		return ret;
 	}
 
-	/* Update number of batches */
-	if (opt->batches == 0)
-		model->nb_batches = model->info.batch_size;
-	else
-		model->nb_batches = opt->batches;
-
 	model->state = MODEL_LOADED;
 
 	return 0;
diff --git a/app/test-mldev/test_model_common.h b/app/test-mldev/test_model_common.h
index c1021ef1b6a..a207e54ab71 100644
--- a/app/test-mldev/test_model_common.h
+++ b/app/test-mldev/test_model_common.h
@@ -31,7 +31,6 @@  struct ml_model {
 	uint8_t *reference;
 
 	struct rte_mempool *io_pool;
-	uint32_t nb_batches;
 };
 
 int ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model,
diff --git a/doc/guides/tools/testmldev.rst b/doc/guides/tools/testmldev.rst
index 741abd722e2..9b1565a4576 100644
--- a/doc/guides/tools/testmldev.rst
+++ b/doc/guides/tools/testmldev.rst
@@ -106,11 +106,6 @@  The following are the command-line options supported by the test application.
   Queue size would translate into ``rte_ml_dev_qp_conf::nb_desc`` field during queue-pair creation.
   Default value is ``1``.
 
-``--batches <n>``
-  Set the number batches in the input file provided for inference run.
-  When not specified, the test would assume the number of batches
-  is the batch size of the model.
-
 ``--tolerance <n>``
   Set the tolerance value in percentage to be used for output validation.
   Default value is ``0``.
@@ -282,7 +277,6 @@  Supported command line options for inference tests are following::
    --burst_size
    --queue_pairs
    --queue_size
-   --batches
    --tolerance
    --stats
 
diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h
index 6ca0b0bb6e2..c73bf7d001a 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.h
+++ b/drivers/ml/cnxk/cn10k_ml_dev.h
@@ -30,6 +30,9 @@ 
 /* Maximum number of descriptors per queue-pair */
 #define ML_CN10K_MAX_DESC_PER_QP 1024
 
+/* Maximum number of inputs / outputs per model */
+#define ML_CN10K_MAX_INPUT_OUTPUT 32
+
 /* Maximum number of segments for IO data */
 #define ML_CN10K_MAX_SEGMENTS 1
 
diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 26df8d9ff94..e0b750cd8ef 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -520,9 +520,11 @@  cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 	struct rte_ml_model_info *info;
 	struct rte_ml_io_info *output;
 	struct rte_ml_io_info *input;
+	struct cn10k_ml_dev *mldev;
 	uint8_t i;
 	uint8_t j;
 
+	mldev = dev->data->dev_private;
 	metadata = &model->metadata;
 	info = PLT_PTR_CAST(model->info);
 	input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
@@ -537,7 +539,9 @@  cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 		 metadata->model.version[3]);
 	info->model_id = model->model_id;
 	info->device_id = dev->data->dev_id;
-	info->batch_size = model->batch_size;
+	info->io_layout = RTE_ML_IO_LAYOUT_PACKED;
+	info->min_batches = model->batch_size;
+	info->max_batches = mldev->fw.req->jd.fw_load.cap.s.max_num_batches / model->batch_size;
 	info->nb_inputs = metadata->model.num_input;
 	info->input_info = input;
 	info->nb_outputs = metadata->model.num_output;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index e3faab81ba3..1d72fb52a6a 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -471,9 +471,9 @@  cn10k_ml_prep_fp_job_descriptor(struct rte_ml_dev *dev, struct cn10k_ml_req *req
 	req->jd.hdr.sp_flags = 0x0;
 	req->jd.hdr.result = roc_ml_addr_ap2mlip(&mldev->roc, &req->result);
 	req->jd.model_run.input_ddr_addr =
-		PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, op->input.addr));
+		PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, op->input[0]->addr));
 	req->jd.model_run.output_ddr_addr =
-		PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, op->output.addr));
+		PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, op->output[0]->addr));
 	req->jd.model_run.num_batches = op->nb_batches;
 }
 
@@ -856,7 +856,11 @@  cn10k_ml_model_xstats_reset(struct rte_ml_dev *dev, int32_t model_id, const uint
 static int
 cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 {
+	struct rte_ml_model_info *info;
 	struct cn10k_ml_model *model;
+	struct rte_ml_buff_seg seg[2];
+	struct rte_ml_buff_seg *inp;
+	struct rte_ml_buff_seg *out;
 	struct rte_ml_op op;
 
 	char str[RTE_MEMZONE_NAMESIZE];
@@ -864,12 +868,22 @@  cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 	uint64_t isize = 0;
 	uint64_t osize = 0;
 	int ret = 0;
+	uint32_t i;
 
 	model = dev->data->models[model_id];
+	info = (struct rte_ml_model_info *)model->info;
+	inp = &seg[0];
+	out = &seg[1];
 
 	/* Create input and output buffers. */
-	rte_ml_io_input_size_get(dev->data->dev_id, model_id, model->batch_size, &isize, NULL);
-	rte_ml_io_output_size_get(dev->data->dev_id, model_id, model->batch_size, &osize, NULL);
+	for (i = 0; i < info->nb_inputs; i++)
+		isize += info->input_info[i].size;
+
+	for (i = 0; i < info->nb_outputs; i++)
+		osize += info->output_info[i].size;
+
+	isize = model->batch_size * isize;
+	osize = model->batch_size * osize;
 
 	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", model_id);
 	mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE);
@@ -877,17 +891,22 @@  cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 		return -ENOMEM;
 	memset(mz->addr, 0, isize + osize);
 
+	seg[0].addr = mz->addr;
+	seg[0].iova_addr = mz->iova;
+	seg[0].length = isize;
+	seg[0].next = NULL;
+
+	seg[1].addr = PLT_PTR_ADD(mz->addr, isize);
+	seg[1].iova_addr = mz->iova + isize;
+	seg[1].length = osize;
+	seg[1].next = NULL;
+
 	op.model_id = model_id;
 	op.nb_batches = model->batch_size;
 	op.mempool = NULL;
 
-	op.input.addr = mz->addr;
-	op.input.length = isize;
-	op.input.next = NULL;
-
-	op.output.addr = PLT_PTR_ADD(op.input.addr, isize);
-	op.output.length = osize;
-	op.output.next = NULL;
+	op.input = &inp;
+	op.output = &out;
 
 	memset(model->req, 0, sizeof(struct cn10k_ml_req));
 	ret = cn10k_ml_inference_sync(dev, &op);
@@ -919,8 +938,9 @@  cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
 	else if (strcmp(mldev->fw.poll_mem, "ddr") == 0)
 		dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP;
 
+	dev_info->max_io = ML_CN10K_MAX_INPUT_OUTPUT;
 	dev_info->max_segments = ML_CN10K_MAX_SEGMENTS;
-	dev_info->min_align_size = ML_CN10K_ALIGN_SIZE;
+	dev_info->align_size = ML_CN10K_ALIGN_SIZE;
 
 	return 0;
 }
@@ -2139,15 +2159,14 @@  cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t
 }
 
 static int
-cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
-		     void *qbuffer)
+cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, struct rte_ml_buff_seg **dbuffer,
+		     struct rte_ml_buff_seg **qbuffer)
 {
 	struct cn10k_ml_model *model;
 	uint8_t model_input_type;
 	uint8_t *lcl_dbuffer;
 	uint8_t *lcl_qbuffer;
 	uint8_t input_type;
-	uint32_t batch_id;
 	float qscale;
 	uint32_t i;
 	uint32_t j;
@@ -2160,11 +2179,9 @@  cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
 		return -EINVAL;
 	}
 
-	lcl_dbuffer = dbuffer;
-	lcl_qbuffer = qbuffer;
-	batch_id = 0;
+	lcl_dbuffer = dbuffer[0]->addr;
+	lcl_qbuffer = qbuffer[0]->addr;
 
-next_batch:
 	for (i = 0; i < model->metadata.model.num_input; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
 			input_type = model->metadata.input1[i].input_type;
@@ -2218,23 +2235,18 @@  cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
 		lcl_qbuffer += model->addr.input[i].sz_q;
 	}
 
-	batch_id++;
-	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
-		goto next_batch;
-
 	return 0;
 }
 
 static int
-cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
-		       void *qbuffer, void *dbuffer)
+cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, struct rte_ml_buff_seg **qbuffer,
+		       struct rte_ml_buff_seg **dbuffer)
 {
 	struct cn10k_ml_model *model;
 	uint8_t model_output_type;
 	uint8_t *lcl_qbuffer;
 	uint8_t *lcl_dbuffer;
 	uint8_t output_type;
-	uint32_t batch_id;
 	float dscale;
 	uint32_t i;
 	uint32_t j;
@@ -2247,11 +2259,9 @@  cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
 		return -EINVAL;
 	}
 
-	lcl_dbuffer = dbuffer;
-	lcl_qbuffer = qbuffer;
-	batch_id = 0;
+	lcl_dbuffer = dbuffer[0]->addr;
+	lcl_qbuffer = qbuffer[0]->addr;
 
-next_batch:
 	for (i = 0; i < model->metadata.model.num_output; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
 			output_type = model->metadata.output1[i].output_type;
@@ -2306,10 +2316,6 @@  cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
 		lcl_dbuffer += model->addr.output[i].sz_d;
 	}
 
-	batch_id++;
-	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
-		goto next_batch;
-
 	return 0;
 }
 
diff --git a/lib/mldev/meson.build b/lib/mldev/meson.build
index 5769b0640a1..0079ccd2052 100644
--- a/lib/mldev/meson.build
+++ b/lib/mldev/meson.build
@@ -35,7 +35,7 @@  driver_sdk_headers += files(
         'mldev_utils.h',
 )
 
-deps += ['mempool']
+deps += ['mempool', 'mbuf']
 
 if get_option('buildtype').contains('debug')
         cflags += [ '-DRTE_LIBRTE_ML_DEV_DEBUG' ]
diff --git a/lib/mldev/rte_mldev.c b/lib/mldev/rte_mldev.c
index 0d8ccd32124..9a48ed3e944 100644
--- a/lib/mldev/rte_mldev.c
+++ b/lib/mldev/rte_mldev.c
@@ -730,8 +730,8 @@  rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches
 }
 
 int
-rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
-		   void *qbuffer)
+rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **dbuffer,
+		   struct rte_ml_buff_seg **qbuffer)
 {
 	struct rte_ml_dev *dev;
 
@@ -754,12 +754,12 @@  rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void
 		return -EINVAL;
 	}
 
-	return (*dev->dev_ops->io_quantize)(dev, model_id, nb_batches, dbuffer, qbuffer);
+	return (*dev->dev_ops->io_quantize)(dev, model_id, dbuffer, qbuffer);
 }
 
 int
-rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *qbuffer,
-		     void *dbuffer)
+rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **qbuffer,
+		     struct rte_ml_buff_seg **dbuffer)
 {
 	struct rte_ml_dev *dev;
 
@@ -782,7 +782,7 @@  rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, voi
 		return -EINVAL;
 	}
 
-	return (*dev->dev_ops->io_dequantize)(dev, model_id, nb_batches, qbuffer, dbuffer);
+	return (*dev->dev_ops->io_dequantize)(dev, model_id, qbuffer, dbuffer);
 }
 
 /** Initialise rte_ml_op mempool element */
diff --git a/lib/mldev/rte_mldev.h b/lib/mldev/rte_mldev.h
index 6204df09308..316c6fd0188 100644
--- a/lib/mldev/rte_mldev.h
+++ b/lib/mldev/rte_mldev.h
@@ -228,12 +228,14 @@  struct rte_ml_dev_info {
 	/**< Maximum allowed number of descriptors for queue pair by the device.
 	 * @see struct rte_ml_dev_qp_conf::nb_desc
 	 */
+	uint16_t max_io;
+	/**< Maximum number of inputs/outputs supported per model. */
 	uint16_t max_segments;
 	/**< Maximum number of scatter-gather entries supported by the device.
 	 * @see struct rte_ml_buff_seg  struct rte_ml_buff_seg::next
 	 */
-	uint16_t min_align_size;
-	/**< Minimum alignment size of IO buffers used by the device. */
+	uint16_t align_size;
+	/**< Alignment size of IO buffers used by the device. */
 };
 
 /**
@@ -429,10 +431,28 @@  struct rte_ml_op {
 	/**< Reserved for future use. */
 	struct rte_mempool *mempool;
 	/**< Pool from which operation is allocated. */
-	struct rte_ml_buff_seg input;
-	/**< Input buffer to hold the inference data. */
-	struct rte_ml_buff_seg output;
-	/**< Output buffer to hold the inference output by the driver. */
+	struct rte_ml_buff_seg **input;
+	/**< Array of buffer segments to hold the inference input data.
+	 *
+	 * When the model supports IO layout RTE_ML_IO_LAYOUT_PACKED, size of
+	 * the array is 1.
+	 *
+	 * When the model supports IO layout RTE_ML_IO_LAYOUT_SPLIT, size of
+	 * the array is rte_ml_model_info::nb_inputs.
+	 *
+	 * @see struct rte_ml_dev_info::io_layout
+	 */
+	struct rte_ml_buff_seg **output;
+	/**< Array of buffer segments to hold the inference output data.
+	 *
+	 * When the model supports IO layout RTE_ML_IO_LAYOUT_PACKED, size of
+	 * the array is 1.
+	 *
+	 * When the model supports IO layout RTE_ML_IO_LAYOUT_SPLIT, size of
+	 * the array is rte_ml_model_info::nb_outputs.
+	 *
+	 * @see struct rte_ml_dev_info::io_layout
+	 */
 	union {
 		uint64_t user_u64;
 		/**< User data as uint64_t.*/
@@ -863,7 +883,37 @@  enum rte_ml_io_type {
 	/**< 16-bit brain floating point number. */
 };
 
-/** Input and output data information structure
+/** ML I/O buffer layout */
+enum rte_ml_io_layout {
+	RTE_ML_IO_LAYOUT_PACKED,
+	/**< All inputs for the model should packed in a single buffer with
+	 * no padding between individual inputs. The buffer is expected to
+	 * be aligned to rte_ml_dev_info::align_size.
+	 *
+	 * When I/O segmentation is supported by the device, the packed
+	 * data can be split into multiple segments. In this case, each
+	 * segment is expected to be aligned to rte_ml_dev_info::align_size
+	 *
+	 * Same applies to output.
+	 *
+	 * @see struct rte_ml_dev_info::max_segments
+	 */
+	RTE_ML_IO_LAYOUT_SPLIT
+	/**< Each input for the model should be stored as separate buffers
+	 * and each input should be aligned to rte_ml_dev_info::align_size.
+	 *
+	 * When I/O segmentation is supported, each input can be split into
+	 * multiple segments. In this case, each segment is expected to be
+	 * aligned to rte_ml_dev_info::align_size
+	 *
+	 * Same applies to output.
+	 *
+	 * @see struct rte_ml_dev_info::max_segments
+	 */
+};
+
+/**
+ * Input and output data information structure
  *
  * Specifies the type and shape of input and output data.
  */
@@ -873,7 +923,7 @@  struct rte_ml_io_info {
 	uint32_t nb_dims;
 	/**< Number of dimensions in shape */
 	uint32_t *shape;
-	/**< Shape of the tensor */
+	/**< Shape of the tensor for rte_ml_model_info::min_batches of the model. */
 	enum rte_ml_io_type type;
 	/**< Type of data
 	 * @see enum rte_ml_io_type
@@ -894,8 +944,16 @@  struct rte_ml_model_info {
 	/**< Model ID */
 	uint16_t device_id;
 	/**< Device ID */
-	uint16_t batch_size;
-	/**< Maximum number of batches that the model can process simultaneously */
+	enum rte_ml_io_layout io_layout;
+	/**< I/O buffer layout for the model */
+	uint16_t min_batches;
+	/**< Minimum number of batches that the model can process
+	 * in one inference request
+	 */
+	uint16_t max_batches;
+	/**< Maximum number of batches that the model can process
+	 * in one inference request
+	 */
 	uint32_t nb_inputs;
 	/**< Number of inputs */
 	const struct rte_ml_io_info *input_info;
@@ -1021,8 +1079,6 @@  rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches
  *   The identifier of the device.
  * @param[in] model_id
  *   Identifier for the model
- * @param[in] nb_batches
- *   Number of batches in the dequantized input buffer
  * @param[in] dbuffer
  *   Address of dequantized input data
  * @param[in] qbuffer
@@ -1034,8 +1090,8 @@  rte_ml_io_output_size_get(int16_t dev_id, uint16_t model_id, uint32_t nb_batches
  */
 __rte_experimental
 int
-rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
-		   void *qbuffer);
+rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **dbuffer,
+		   struct rte_ml_buff_seg **qbuffer);
 
 /**
  * Dequantize output data.
@@ -1047,8 +1103,6 @@  rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void
  *   The identifier of the device.
  * @param[in] model_id
  *   Identifier for the model
- * @param[in] nb_batches
- *   Number of batches in the dequantized output buffer
  * @param[in] qbuffer
  *   Address of quantized output data
  * @param[in] dbuffer
@@ -1060,8 +1114,8 @@  rte_ml_io_quantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void
  */
 __rte_experimental
 int
-rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, uint16_t nb_batches, void *qbuffer,
-		     void *dbuffer);
+rte_ml_io_dequantize(int16_t dev_id, uint16_t model_id, struct rte_ml_buff_seg **qbuffer,
+		     struct rte_ml_buff_seg **dbuffer);
 
 /* ML op pool operations */
 
diff --git a/lib/mldev/rte_mldev_core.h b/lib/mldev/rte_mldev_core.h
index 78b8b7633dd..8530b073162 100644
--- a/lib/mldev/rte_mldev_core.h
+++ b/lib/mldev/rte_mldev_core.h
@@ -523,8 +523,6 @@  typedef int (*mldev_io_output_size_get_t)(struct rte_ml_dev *dev, uint16_t model
  *	ML device pointer.
  * @param model_id
  *	Model ID to use.
- * @param nb_batches
- *	Number of batches.
  * @param dbuffer
  *	Pointer t de-quantized data buffer.
  * @param qbuffer
@@ -534,8 +532,9 @@  typedef int (*mldev_io_output_size_get_t)(struct rte_ml_dev *dev, uint16_t model
  *	- 0 on success.
  *	- <0, error on failure.
  */
-typedef int (*mldev_io_quantize_t)(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
-				   void *dbuffer, void *qbuffer);
+typedef int (*mldev_io_quantize_t)(struct rte_ml_dev *dev, uint16_t model_id,
+				   struct rte_ml_buff_seg **dbuffer,
+				   struct rte_ml_buff_seg **qbuffer);
 
 /**
  * @internal
@@ -546,8 +545,6 @@  typedef int (*mldev_io_quantize_t)(struct rte_ml_dev *dev, uint16_t model_id, ui
  *	ML device pointer.
  * @param model_id
  *	Model ID to use.
- * @param nb_batches
- *	Number of batches.
  * @param qbuffer
  *	Pointer t de-quantized data buffer.
  * @param dbuffer
@@ -557,8 +554,9 @@  typedef int (*mldev_io_quantize_t)(struct rte_ml_dev *dev, uint16_t model_id, ui
  *	- 0 on success.
  *	- <0, error on failure.
  */
-typedef int (*mldev_io_dequantize_t)(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
-				     void *qbuffer, void *dbuffer);
+typedef int (*mldev_io_dequantize_t)(struct rte_ml_dev *dev, uint16_t model_id,
+				     struct rte_ml_buff_seg **qbuffer,
+				     struct rte_ml_buff_seg **dbuffer);
 
 /**
  * @internal