[v5,10/34] ml/cnxk: update model start and stop functions

Message ID 20231018064806.24145-11-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Oct. 18, 2023, 6:47 a.m. UTC
  Implemented cnxk wrapper functions to start and stop
ML models. Wrapper functions would invoke the cn10k
model start and stop functions.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ocm.c |  28 ++--
 drivers/ml/cnxk/cn10k_ml_ocm.h |  12 +-
 drivers/ml/cnxk/cn10k_ml_ops.c | 282 ++++++++++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_ops.h |   8 +-
 drivers/ml/cnxk/cnxk_ml_ops.c  |  48 +++++-
 drivers/ml/cnxk/cnxk_ml_ops.h  |   1 +
 6 files changed, 240 insertions(+), 139 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ocm.c b/drivers/ml/cnxk/cn10k_ml_ocm.c
index d71c36eae6..2197e5e0ed 100644
--- a/drivers/ml/cnxk/cn10k_ml_ocm.c
+++ b/drivers/ml/cnxk/cn10k_ml_ocm.c
@@ -215,11 +215,10 @@  cn10k_ml_ocm_tilecount(uint64_t tilemask, int *start, int *end)
  * scratch & WB pages and OCM allocation mode.
  */
 int
-cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t wb_pages,
+cn10k_ml_ocm_tilemask_find(struct cnxk_ml_dev *cnxk_mldev, uint8_t num_tiles, uint16_t wb_pages,
 			   uint16_t scratch_pages, uint64_t *tilemask)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct cn10k_ml_ocm *ocm;
 
 	uint16_t used_scratch_pages_max;
@@ -238,7 +237,6 @@  cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t w
 	int max_slot_sz;
 	int page_id;
 
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 	ocm = &cn10k_mldev->ocm;
 
@@ -333,12 +331,10 @@  cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t w
 }
 
 void
-cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id,
+cn10k_ml_ocm_reserve_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id,
 			   uint64_t tilemask, int wb_page_start, uint16_t wb_pages,
 			   uint16_t scratch_pages)
 {
-	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
 	struct cnxk_ml_layer *layer;
 	struct cn10k_ml_ocm *ocm;
@@ -351,10 +347,8 @@  cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t l
 	int tile_id;
 	int page_id;
 
-	cnxk_mldev = dev->data->dev_private;
-	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-	ocm = &cn10k_mldev->ocm;
-	model = dev->data->models[model_id];
+	ocm = &cnxk_mldev->cn10k_mldev.ocm;
+	model = cnxk_mldev->mldev->data->models[model_id];
 	layer = &model->layer[layer_id];
 
 	/* Get first set bit, tile_start */
@@ -396,12 +390,10 @@  cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t l
 }
 
 void
-cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id)
+cn10k_ml_ocm_free_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id)
 {
 	struct cnxk_ml_model *local_model;
 	struct cnxk_ml_layer *local_layer;
-	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
 	struct cnxk_ml_layer *layer;
 	struct cn10k_ml_ocm *ocm;
@@ -416,10 +408,8 @@  cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t laye
 	uint16_t i;
 	uint16_t j;
 
-	cnxk_mldev = dev->data->dev_private;
-	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-	ocm = &cn10k_mldev->ocm;
-	model = dev->data->models[model_id];
+	ocm = &cnxk_mldev->cn10k_mldev.ocm;
+	model = cnxk_mldev->mldev->data->models[model_id];
 	layer = &model->layer[layer_id];
 
 	/* Update OCM info for WB memory */
@@ -438,8 +428,8 @@  cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t laye
 
 		/* Get max scratch pages required, excluding the current model */
 		scratch_resize_pages = 0;
-		for (i = 0; i < dev->data->nb_models; i++) {
-			local_model = dev->data->models[i];
+		for (i = 0; i < cnxk_mldev->mldev->data->nb_models; i++) {
+			local_model = cnxk_mldev->mldev->data->models[i];
 			if (local_model == NULL)
 				continue;
 
diff --git a/drivers/ml/cnxk/cn10k_ml_ocm.h b/drivers/ml/cnxk/cn10k_ml_ocm.h
index 720f8caf76..97b723a56a 100644
--- a/drivers/ml/cnxk/cn10k_ml_ocm.h
+++ b/drivers/ml/cnxk/cn10k_ml_ocm.h
@@ -8,6 +8,8 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+struct cnxk_ml_dev;
+
 /* Number of OCM tiles. */
 #define ML_CN10K_OCM_NUMTILES 0x8
 
@@ -75,12 +77,12 @@  struct cn10k_ml_ocm {
 };
 
 int cn10k_ml_ocm_tilecount(uint64_t tilemask, int *start, int *end);
-int cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t wb_pages,
+int cn10k_ml_ocm_tilemask_find(struct cnxk_ml_dev *cnxk_mldev, uint8_t num_tiles, uint16_t wb_pages,
 			       uint16_t scratch_pages, uint64_t *tilemask);
-void cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id,
-				uint64_t tilemask, int wb_page_start, uint16_t wb_pages,
-				uint16_t scratch_pages);
-void cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id);
+void cn10k_ml_ocm_reserve_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id,
+				uint16_t layer_id, uint64_t tilemask, int wb_page_start,
+				uint16_t wb_pages, uint16_t scratch_pages);
+void cn10k_ml_ocm_free_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id);
 void cn10k_ml_ocm_print(struct rte_ml_dev *dev, FILE *fp);
 
 #endif /* _CN10K_ML_OCM_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index ad2effb904..c677861645 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -248,26 +248,28 @@  cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 }
 
 static void
-cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_model *model,
+cn10k_ml_prep_sp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer,
 				struct cnxk_ml_req *req, enum cn10k_ml_job_type job_type)
 {
 	struct cn10k_ml_model_metadata *metadata;
 	struct cn10k_ml_layer_addr *addr;
+	struct cn10k_ml_dev *cn10k_mldev;
 
-	metadata = &model->glow.metadata;
-	addr = &model->layer[0].glow.addr;
+	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+	metadata = &layer->glow.metadata;
+	addr = &layer->glow.addr;
 
 	memset(&req->cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
 	req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
 	req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->cn10k_req.status);
-	req->cn10k_req.jd.hdr.model_id = model->model_id;
+	req->cn10k_req.jd.hdr.model_id = layer->index;
 	req->cn10k_req.jd.hdr.job_type = job_type;
 	req->cn10k_req.jd.hdr.fp_flags = 0x0;
 	req->cn10k_req.jd.hdr.result =
 		roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
 
 	if (job_type == ML_CN10K_JOB_TYPE_MODEL_START) {
-		if (!model->glow.metadata.model.ocm_relocatable)
+		if (!layer->glow.metadata.model.ocm_relocatable)
 			req->cn10k_req.jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
 		else
 			req->cn10k_req.jd.hdr.sp_flags = 0x0;
@@ -291,7 +293,7 @@  cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
 		req->cn10k_req.jd.model_start.num_gather_entries = 0;
 		req->cn10k_req.jd.model_start.num_scatter_entries = 0;
 		req->cn10k_req.jd.model_start.tilemask = 0; /* Updated after reserving pages */
-		req->cn10k_req.jd.model_start.batch_size = model->batch_size;
+		req->cn10k_req.jd.model_start.batch_size = layer->batch_size;
 		req->cn10k_req.jd.model_start.ocm_wb_base_address =
 			0; /* Updated after reserving pages */
 		req->cn10k_req.jd.model_start.ocm_wb_range_start =
@@ -323,9 +325,13 @@  cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
 }
 
 static __rte_always_inline void
-cn10k_ml_prep_fp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_req *req,
+cn10k_ml_prep_fp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_req *req,
 				struct rte_ml_op *op)
 {
+	struct cn10k_ml_dev *cn10k_mldev;
+
+	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+
 	req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
 	req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(req->status);
 	req->cn10k_req.jd.hdr.model_id = op->model_id;
@@ -714,10 +720,8 @@  cn10k_ml_model_xstats_reset(struct rte_ml_dev *dev, int32_t model_id, const uint
 }
 
 static int
-cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_cache_model_data(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer)
 {
-	struct rte_ml_model_info *info;
-	struct cnxk_ml_model *model;
 	struct rte_ml_buff_seg seg[2];
 	struct rte_ml_buff_seg *inp;
 	struct rte_ml_buff_seg *out;
@@ -730,22 +734,20 @@  cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 	int ret = 0;
 	uint32_t i;
 
-	model = dev->data->models[model_id];
-	info = (struct rte_ml_model_info *)model->info;
 	inp = &seg[0];
 	out = &seg[1];
 
 	/* Create input and output buffers. */
-	for (i = 0; i < info->nb_inputs; i++)
-		isize += info->input_info[i].size;
+	for (i = 0; i < layer->info.nb_inputs; i++)
+		isize += layer->info.input[i].sz_q;
 
-	for (i = 0; i < info->nb_outputs; i++)
-		osize += info->output_info[i].size;
+	for (i = 0; i < layer->info.nb_outputs; i++)
+		osize += layer->info.output[i].sz_q;
 
-	isize = model->batch_size * isize;
-	osize = model->batch_size * osize;
+	isize = layer->batch_size * isize;
+	osize = layer->batch_size * osize;
 
-	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", model_id);
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", layer->index);
 	mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE);
 	if (mz == NULL)
 		return -ENOMEM;
@@ -761,15 +763,15 @@  cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 	seg[1].length = osize;
 	seg[1].next = NULL;
 
-	op.model_id = model_id;
-	op.nb_batches = model->batch_size;
+	op.model_id = layer->index;
+	op.nb_batches = layer->batch_size;
 	op.mempool = NULL;
 
 	op.input = &inp;
 	op.output = &out;
 
-	memset(model->layer[0].glow.req, 0, sizeof(struct cnxk_ml_req));
-	ret = cn10k_ml_inference_sync(dev, &op);
+	memset(layer->glow.req, 0, sizeof(struct cnxk_ml_req));
+	ret = cn10k_ml_inference_sync(cnxk_mldev, &op);
 	plt_memzone_free(mz);
 
 	return ret;
@@ -1506,14 +1508,16 @@  cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *mode
 }
 
 int
-cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
 	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 	struct cn10k_ml_ocm *ocm;
 	struct cnxk_ml_req *req;
 
+	uint16_t layer_id = 0;
 	bool job_enqueued;
 	bool job_dequeued;
 	uint8_t num_tiles;
@@ -1524,85 +1528,89 @@  cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
 	bool locked;
 	int ret = 0;
 
-	cnxk_mldev = dev->data->dev_private;
-	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-	ocm = &cn10k_mldev->ocm;
-	model = dev->data->models[model_id];
+	PLT_SET_USED(layer_name);
 
+	cnxk_mldev = (struct cnxk_ml_dev *)device;
+	if (cnxk_mldev == NULL) {
+		plt_err("Invalid device = %p", device);
+		return -EINVAL;
+	}
+
+	model = cnxk_mldev->mldev->data->models[model_id];
 	if (model == NULL) {
 		plt_err("Invalid model_id = %u", model_id);
 		return -EINVAL;
 	}
 
+	layer = &model->layer[layer_id];
+	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+	ocm = &cn10k_mldev->ocm;
+
 	/* Prepare JD */
-	req = model->layer[0].glow.req;
-	cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_START);
+	req = layer->glow.req;
+	cn10k_ml_prep_sp_job_descriptor(cnxk_mldev, layer, req, ML_CN10K_JOB_TYPE_MODEL_START);
 	req->cn10k_req.result.error_code = 0x0;
 	req->cn10k_req.result.user_ptr = NULL;
 
 	plt_write64(ML_CNXK_POLL_JOB_START, &req->cn10k_req.status);
 	plt_wmb();
 
-	num_tiles = model->layer[0].glow.metadata.model.tile_end -
-		    model->layer[0].glow.metadata.model.tile_start + 1;
+	num_tiles = layer->glow.metadata.model.tile_end - layer->glow.metadata.model.tile_start + 1;
 
 	locked = false;
 	while (!locked) {
 		if (plt_spinlock_trylock(&model->lock) != 0) {
-			if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-				plt_ml_dbg("Model already started, model = 0x%016lx",
-					   PLT_U64_CAST(model));
+			if (layer->state == ML_CNXK_LAYER_STATE_STARTED) {
+				plt_ml_dbg("Layer already started, model_id = %u, layer_id = %u",
+					   model->model_id, layer_id);
 				plt_spinlock_unlock(&model->lock);
 				return 1;
 			}
 
-			if (model->state == ML_CNXK_MODEL_STATE_JOB_ACTIVE) {
-				plt_err("A slow-path job is active for the model = 0x%016lx",
-					PLT_U64_CAST(model));
+			if (layer->state == ML_CNXK_LAYER_STATE_JOB_ACTIVE) {
+				plt_err("A slow-path job is active for the model_id = %u",
+					model->model_id);
 				plt_spinlock_unlock(&model->lock);
 				return -EBUSY;
 			}
 
-			model->state = ML_CNXK_MODEL_STATE_JOB_ACTIVE;
+			layer->state = ML_CNXK_LAYER_STATE_JOB_ACTIVE;
 			plt_spinlock_unlock(&model->lock);
 			locked = true;
 		}
 	}
 
-	while (!model->layer[0].glow.ocm_map.ocm_reserved) {
+	while (!layer->glow.ocm_map.ocm_reserved) {
 		if (plt_spinlock_trylock(&ocm->lock) != 0) {
 			wb_page_start = cn10k_ml_ocm_tilemask_find(
-				dev, num_tiles, model->layer[0].glow.ocm_map.wb_pages,
-				model->layer[0].glow.ocm_map.scratch_pages, &tilemask);
+				cnxk_mldev, num_tiles, layer->glow.ocm_map.wb_pages,
+				layer->glow.ocm_map.scratch_pages, &tilemask);
 
 			if (wb_page_start == -1) {
 				plt_err("Free pages not available on OCM tiles");
-				plt_err("Failed to start model = 0x%016lx, name = %s",
-					PLT_U64_CAST(model),
-					model->layer[0].glow.metadata.model.name);
-
+				plt_err("Failed to start layer, model_id = %u, layer_id = %u",
+					model->model_id, layer_id);
 				plt_spinlock_unlock(&ocm->lock);
 				return -ENOMEM;
 			}
 
-			model->layer[0].glow.ocm_map.tilemask = tilemask;
-			model->layer[0].glow.ocm_map.wb_page_start = wb_page_start;
+			layer->glow.ocm_map.tilemask = tilemask;
+			layer->glow.ocm_map.wb_page_start = wb_page_start;
 
-			cn10k_ml_ocm_reserve_pages(dev, model->model_id, 0,
-						   model->layer[0].glow.ocm_map.tilemask,
-						   model->layer[0].glow.ocm_map.wb_page_start,
-						   model->layer[0].glow.ocm_map.wb_pages,
-						   model->layer[0].glow.ocm_map.scratch_pages);
-			model->layer[0].glow.ocm_map.ocm_reserved = true;
+			cn10k_ml_ocm_reserve_pages(
+				cnxk_mldev, model->model_id, layer_id, layer->glow.ocm_map.tilemask,
+				layer->glow.ocm_map.wb_page_start, layer->glow.ocm_map.wb_pages,
+				layer->glow.ocm_map.scratch_pages);
+			layer->glow.ocm_map.ocm_reserved = true;
 			plt_spinlock_unlock(&ocm->lock);
 		}
 	}
 
 	/* Update JD */
-	cn10k_ml_ocm_tilecount(model->layer[0].glow.ocm_map.tilemask, &tile_start, &tile_end);
+	cn10k_ml_ocm_tilecount(layer->glow.ocm_map.tilemask, &tile_start, &tile_end);
 	req->cn10k_req.jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
 	req->cn10k_req.jd.model_start.ocm_wb_base_address =
-		model->layer[0].glow.ocm_map.wb_page_start * ocm->page_size;
+		layer->glow.ocm_map.wb_page_start * ocm->page_size;
 
 	job_enqueued = false;
 	job_dequeued = false;
@@ -1636,66 +1644,94 @@  cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
 	locked = false;
 	while (!locked) {
 		if (plt_spinlock_trylock(&model->lock) != 0) {
-			if (ret == 0) {
-				model->state = ML_CNXK_MODEL_STATE_STARTED;
-				cnxk_mldev->nb_models_started++;
-			} else {
-				model->state = ML_CNXK_MODEL_STATE_UNKNOWN;
-			}
+			if (ret == 0)
+				layer->state = ML_CNXK_LAYER_STATE_STARTED;
+			else
+				layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
 
 			plt_spinlock_unlock(&model->lock);
 			locked = true;
 		}
 	}
 
-	if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN) {
-		while (model->layer[0].glow.ocm_map.ocm_reserved) {
+	if (layer->state == ML_CNXK_LAYER_STATE_UNKNOWN) {
+		while (layer->glow.ocm_map.ocm_reserved) {
 			if (plt_spinlock_trylock(&ocm->lock) != 0) {
-				cn10k_ml_ocm_free_pages(dev, model->model_id, 0);
-				model->layer[0].glow.ocm_map.ocm_reserved = false;
-				model->layer[0].glow.ocm_map.tilemask = 0x0;
+				cn10k_ml_ocm_free_pages(cnxk_mldev, model->model_id, layer_id);
+				layer->glow.ocm_map.ocm_reserved = false;
+				layer->glow.ocm_map.tilemask = 0x0;
 				plt_spinlock_unlock(&ocm->lock);
 			}
 		}
 	}
 
-	if (ret < 0) { /* Call unload to update model and FW state, ignore error */
-		rte_ml_model_stop(dev->data->dev_id, model_id);
+	if (ret < 0) {
+		cn10k_ml_layer_stop(device, model_id, layer_name);
 	} else {
-		if (cn10k_mldev->cache_model_data && roc_model_is_cn10ka())
-			ret = cn10k_ml_cache_model_data(dev, model_id);
+		if (cn10k_mldev->cache_model_data)
+			ret = cn10k_ml_cache_model_data(cnxk_mldev, layer);
 	}
 
 	return ret;
 }
 
 int
-cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+	struct cnxk_ml_layer *layer;
+	int ret;
+
+	layer = &model->layer[0];
+	ret = cn10k_ml_layer_start(cnxk_mldev, model->model_id, layer->name);
+	if (ret != 0) {
+		plt_err("CN10K Model start failed, model_id = %u, error = %d", model->model_id,
+			ret);
+		return ret;
+	}
+
+	cnxk_mldev->nb_models_started++;
+	model->state = ML_CNXK_MODEL_STATE_STARTED;
+
+	return 0;
+}
+
+int
+cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
 	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 	struct cn10k_ml_ocm *ocm;
 	struct cnxk_ml_req *req;
 
+	uint16_t layer_id = 0;
 	bool job_enqueued;
 	bool job_dequeued;
 	bool locked;
 	int ret = 0;
 
-	cnxk_mldev = dev->data->dev_private;
-	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-	ocm = &cn10k_mldev->ocm;
-	model = dev->data->models[model_id];
+	PLT_SET_USED(layer_name);
+
+	cnxk_mldev = (struct cnxk_ml_dev *)device;
+	if (cnxk_mldev == NULL) {
+		plt_err("Invalid device = %p", device);
+		return -EINVAL;
+	}
 
+	model = cnxk_mldev->mldev->data->models[model_id];
 	if (model == NULL) {
 		plt_err("Invalid model_id = %u", model_id);
 		return -EINVAL;
 	}
 
+	layer = &model->layer[layer_id];
+	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+	ocm = &cn10k_mldev->ocm;
+
 	/* Prepare JD */
-	req = model->layer[0].glow.req;
-	cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
+	req = layer->glow.req;
+	cn10k_ml_prep_sp_job_descriptor(cnxk_mldev, layer, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
 	req->cn10k_req.result.error_code = 0x0;
 	req->cn10k_req.result.user_ptr = NULL;
 
@@ -1705,31 +1741,31 @@  cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
 	locked = false;
 	while (!locked) {
 		if (plt_spinlock_trylock(&model->lock) != 0) {
-			if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-				plt_ml_dbg("Model not started, model = 0x%016lx",
-					   PLT_U64_CAST(model));
+			if (layer->state == ML_CNXK_LAYER_STATE_LOADED) {
+				plt_ml_dbg("Layer not started, model_id = %u, layer_id = %u",
+					   model->model_id, layer_id);
 				plt_spinlock_unlock(&model->lock);
 				return 1;
 			}
 
-			if (model->state == ML_CNXK_MODEL_STATE_JOB_ACTIVE) {
-				plt_err("A slow-path job is active for the model = 0x%016lx",
-					PLT_U64_CAST(model));
+			if (layer->state == ML_CNXK_LAYER_STATE_JOB_ACTIVE) {
+				plt_err("A slow-path job is active for the layer, model_id = %u, layer_id = %u",
+					model->model_id, layer_id);
 				plt_spinlock_unlock(&model->lock);
 				return -EBUSY;
 			}
 
-			model->state = ML_CNXK_MODEL_STATE_JOB_ACTIVE;
+			layer->state = ML_CNXK_LAYER_STATE_JOB_ACTIVE;
 			plt_spinlock_unlock(&model->lock);
 			locked = true;
 		}
 	}
 
-	while (model->layer[0].glow.ocm_map.ocm_reserved) {
+	while (layer->glow.ocm_map.ocm_reserved) {
 		if (plt_spinlock_trylock(&ocm->lock) != 0) {
-			cn10k_ml_ocm_free_pages(dev, model->model_id, 0);
-			model->layer[0].glow.ocm_map.ocm_reserved = false;
-			model->layer[0].glow.ocm_map.tilemask = 0x0;
+			cn10k_ml_ocm_free_pages(cnxk_mldev, model->model_id, layer_id);
+			layer->glow.ocm_map.ocm_reserved = false;
+			layer->glow.ocm_map.tilemask = 0x0;
 			plt_spinlock_unlock(&ocm->lock);
 		}
 	}
@@ -1766,8 +1802,11 @@  cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
 	locked = false;
 	while (!locked) {
 		if (plt_spinlock_trylock(&model->lock) != 0) {
-			cnxk_mldev->nb_models_stopped++;
-			model->state = ML_CNXK_MODEL_STATE_LOADED;
+			if (ret == 0)
+				layer->state = ML_CNXK_LAYER_STATE_LOADED;
+			else
+				layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
+
 			plt_spinlock_unlock(&model->lock);
 			locked = true;
 		}
@@ -1776,6 +1815,25 @@  cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
 	return ret;
 }
 
+int
+cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+	struct cnxk_ml_layer *layer;
+	int ret;
+
+	layer = &model->layer[0];
+	ret = cn10k_ml_layer_stop(cnxk_mldev, model->model_id, layer->name);
+	if (ret != 0) {
+		plt_err("CN10K Model stop failed, model_id = %u, error = %d", model->model_id, ret);
+		return ret;
+	}
+
+	cnxk_mldev->nb_models_stopped++;
+	model->state = ML_CNXK_MODEL_STATE_LOADED;
+
+	return 0;
+}
+
 int
 cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
 			struct rte_ml_model_info *model_info)
@@ -2003,30 +2061,35 @@  queue_free_count(uint64_t head, uint64_t tail, uint64_t nb_desc)
 }
 
 static __rte_always_inline void
-cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cnxk_ml_req *req)
+cn10k_ml_result_update(struct cnxk_ml_dev *cnxk_mldev, int qp_id, struct cnxk_ml_req *req)
 {
 	union cn10k_ml_error_code *error_code;
 	struct cn10k_ml_layer_xstats *xstats;
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct cn10k_ml_result *result;
 	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 	struct cnxk_ml_qp *qp;
 	struct rte_ml_op *op;
 	uint64_t hw_latency;
 	uint64_t fw_latency;
+	uint16_t model_id;
+	uint16_t layer_id;
 
 	result = &req->cn10k_req.result;
 	op = req->op;
 
 	if (likely(result->error_code == 0)) {
-		model = dev->data->models[op->model_id];
+		model_id = cnxk_mldev->index_map[op->model_id].model_id;
+		layer_id = cnxk_mldev->index_map[op->model_id].layer_id;
+		model = cnxk_mldev->mldev->data->models[model_id];
+		layer = &model->layer[layer_id];
 		if (likely(qp_id >= 0)) {
-			qp = dev->data->queue_pairs[qp_id];
+			qp = cnxk_mldev->mldev->data->queue_pairs[qp_id];
 			qp->stats.dequeued_count++;
-			xstats = &model->layer[0].glow.burst_xstats[qp_id];
+			xstats = &layer->glow.burst_xstats[qp_id];
 		} else {
-			xstats = model->layer[0].glow.sync_xstats;
+			xstats = layer->glow.sync_xstats;
 		}
 
 		if (unlikely(xstats->dequeued_count == xstats->hw_reset_count)) {
@@ -2054,14 +2117,13 @@  cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cnxk_ml_req *re
 		op->status = RTE_ML_OP_STATUS_SUCCESS;
 	} else {
 		if (likely(qp_id >= 0)) {
-			qp = dev->data->queue_pairs[qp_id];
+			qp = cnxk_mldev->mldev->data->queue_pairs[qp_id];
 			qp->stats.dequeue_err_count++;
 		}
 
 		/* Handle driver error */
 		error_code = (union cn10k_ml_error_code *)&result->error_code;
 		if (error_code->s.etype == ML_ETYPE_DRIVER) {
-			cnxk_mldev = dev->data->dev_private;
 			cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
 			/* Check for exception */
@@ -2116,7 +2178,7 @@  cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 	req = &queue->reqs[head];
 
 	cn10k_mldev->set_poll_addr(req);
-	cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
+	cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, op);
 
 	memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
 	error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
@@ -2183,7 +2245,7 @@  cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 		}
 	}
 
-	cn10k_ml_result_update(dev, qp_id, req);
+	cn10k_ml_result_update(cnxk_mldev, qp_id, req);
 	ops[count] = req->op;
 
 	queue_index_advance(&tail, qp->nb_desc);
@@ -2232,23 +2294,27 @@  cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op, struct rte_m
 }
 
 __rte_hot int
-cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
+cn10k_ml_inference_sync(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_op *op)
 {
 	union cn10k_ml_error_code *error_code;
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 	struct cnxk_ml_req *req;
+	uint16_t model_id;
+	uint16_t layer_id;
 	bool timeout;
 	int ret = 0;
 
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-	model = dev->data->models[op->model_id];
-	req = model->layer[0].glow.req;
+	model_id = cnxk_mldev->index_map[op->model_id].model_id;
+	layer_id = cnxk_mldev->index_map[op->model_id].layer_id;
+	model = cnxk_mldev->mldev->data->models[model_id];
+	layer = &model->layer[layer_id];
+	req = layer->glow.req;
 
 	cn10k_ml_set_poll_addr(req);
-	cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
+	cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, op);
 
 	memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
 	error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
@@ -2284,7 +2350,7 @@  cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
 	if (timeout)
 		ret = -ETIME;
 	else
-		cn10k_ml_result_update(dev, -1, req);
+		cn10k_ml_result_update(cnxk_mldev, -1, req);
 
 error_enqueue:
 	return ret;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 677219dfdf..a222a43d55 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -315,8 +315,8 @@  int cn10k_ml_dev_xstats_reset(struct rte_ml_dev *dev, enum rte_ml_dev_xstats_mod
 int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
 			struct cnxk_ml_model *model);
 int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
-int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
-int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
+int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
 			    struct rte_ml_model_info *model_info);
 int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer);
@@ -335,7 +335,7 @@  __rte_hot uint16_t cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id
 					  struct rte_ml_op **ops, uint16_t nb_ops);
 __rte_hot int cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op,
 				    struct rte_ml_op_error *error);
-__rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op);
+__rte_hot int cn10k_ml_inference_sync(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_op *op);
 
 /* Misc ops */
 void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp *qp);
@@ -344,5 +344,7 @@  void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp *q
 int cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uint8_t *buffer,
 			size_t size, uint16_t *index);
 int cn10k_ml_layer_unload(void *device, uint16_t model_id, const char *layer_name);
+int cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name);
+int cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name);
 
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index 1d8b84269d..b61ed45876 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -240,7 +240,7 @@  cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *co
 			model = dev->data->models[model_id];
 			if (model != NULL) {
 				if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-					if (cn10k_ml_model_stop(dev, model_id) != 0)
+					if (cnxk_ml_model_stop(dev, model_id) != 0)
 						plt_err("Could not stop model %u", model_id);
 				}
 				if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
@@ -332,7 +332,7 @@  cnxk_ml_dev_close(struct rte_ml_dev *dev)
 		model = dev->data->models[model_id];
 		if (model != NULL) {
 			if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-				if (cn10k_ml_model_stop(dev, model_id) != 0)
+				if (cnxk_ml_model_stop(dev, model_id) != 0)
 					plt_err("Could not stop model %u", model_id);
 			}
 			if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
@@ -564,6 +564,46 @@  cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
 	return plt_memzone_free(plt_memzone_lookup(str));
 }
 
+static int
+cnxk_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	model = dev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	return cn10k_ml_model_start(cnxk_mldev, model);
+}
+
+int
+cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	model = dev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	return cn10k_ml_model_stop(cnxk_mldev, model);
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cnxk_ml_dev_info_get,
@@ -589,8 +629,8 @@  struct rte_ml_dev_ops cnxk_ml_ops = {
 	/* Model ops */
 	.model_load = cnxk_ml_model_load,
 	.model_unload = cnxk_ml_model_unload,
-	.model_start = cn10k_ml_model_start,
-	.model_stop = cn10k_ml_model_stop,
+	.model_start = cnxk_ml_model_start,
+	.model_stop = cnxk_ml_model_stop,
 	.model_info_get = cn10k_ml_model_info_get,
 	.model_params_update = cn10k_ml_model_params_update,
 
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index bc14f6e5b9..d27ca0d0cb 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -63,5 +63,6 @@  struct cnxk_ml_qp {
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
 int cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+int cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 
 #endif /* _CNXK_ML_OPS_H_ */