[v5,09/34] ml/cnxk: update model load and unload functions

Message ID 20231018064806.24145-10-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Oct. 18, 2023, 6:47 a.m. UTC
  Implemented cnxk wrapper functions to load and unload
ML models. Wrapper functions would invoke the cn10k
model load and unload functions.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_model.c | 244 ++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_model.h |  26 ++-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 296 ++++++++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_ops.h   |  12 +-
 drivers/ml/cnxk/cnxk_ml_dev.h    |  15 ++
 drivers/ml/cnxk/cnxk_ml_ops.c    | 144 ++++++++++++++-
 drivers/ml/cnxk/cnxk_ml_ops.h    |   2 +
 7 files changed, 462 insertions(+), 277 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 5d37e9bf8a..69a60b9b90 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -316,42 +316,31 @@  cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, uint8_t *buffer, uint8_t
 {
 	struct cn10k_ml_model_metadata *metadata;
 	struct cn10k_ml_layer_addr *addr;
-	size_t model_data_size;
 	uint8_t *dma_addr_load;
-	uint8_t *dma_addr_run;
 	int fpos;
 
 	metadata = &layer->glow.metadata;
 	addr = &layer->glow.addr;
-	model_data_size = metadata->init_model.file_size + metadata->main_model.file_size +
-			  metadata->finish_model.file_size + metadata->weights_bias.file_size;
 
 	/* Base address */
 	addr->base_dma_addr_load = base_dma_addr;
-	addr->base_dma_addr_run = PLT_PTR_ADD(addr->base_dma_addr_load, model_data_size);
 
 	/* Init section */
 	dma_addr_load = addr->base_dma_addr_load;
-	dma_addr_run = addr->base_dma_addr_run;
 	fpos = sizeof(struct cn10k_ml_model_metadata);
 	addr->init_load_addr = dma_addr_load;
-	addr->init_run_addr = dma_addr_run;
 	rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), metadata->init_model.file_size);
 
 	/* Main section */
 	dma_addr_load += metadata->init_model.file_size;
-	dma_addr_run += metadata->init_model.file_size;
 	fpos += metadata->init_model.file_size;
 	addr->main_load_addr = dma_addr_load;
-	addr->main_run_addr = dma_addr_run;
 	rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), metadata->main_model.file_size);
 
 	/* Finish section */
 	dma_addr_load += metadata->main_model.file_size;
-	dma_addr_run += metadata->main_model.file_size;
 	fpos += metadata->main_model.file_size;
 	addr->finish_load_addr = dma_addr_load;
-	addr->finish_run_addr = dma_addr_run;
 	rte_memcpy(dma_addr_load, PLT_PTR_ADD(buffer, fpos), metadata->finish_model.file_size);
 
 	/* Weights and Bias section */
@@ -363,140 +352,146 @@  cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, uint8_t *buffer, uint8_t
 }
 
 void
-cn10k_ml_layer_info_update(struct cnxk_ml_layer *layer)
+cn10k_ml_layer_io_info_set(struct cnxk_ml_io_info *io_info,
+			   struct cn10k_ml_model_metadata *metadata)
 {
-	struct cn10k_ml_model_metadata *metadata;
 	uint8_t i;
 	uint8_t j;
 
-	metadata = &layer->glow.metadata;
-
 	/* Inputs */
-	layer->info.nb_inputs = metadata->model.num_input;
-	layer->info.total_input_sz_d = 0;
-	layer->info.total_input_sz_q = 0;
+	io_info->nb_inputs = metadata->model.num_input;
+	io_info->total_input_sz_d = 0;
+	io_info->total_input_sz_q = 0;
 	for (i = 0; i < metadata->model.num_input; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
-			strncpy(layer->info.input[i].name, (char *)metadata->input1[i].input_name,
+			strncpy(io_info->input[i].name, (char *)metadata->input1[i].input_name,
 				MRVL_ML_INPUT_NAME_LEN);
-			layer->info.input[i].dtype = metadata->input1[i].input_type;
-			layer->info.input[i].qtype = metadata->input1[i].model_input_type;
-			layer->info.input[i].nb_dims = 4;
-			layer->info.input[i].shape[0] = metadata->input1[i].shape.w;
-			layer->info.input[i].shape[1] = metadata->input1[i].shape.x;
-			layer->info.input[i].shape[2] = metadata->input1[i].shape.y;
-			layer->info.input[i].shape[3] = metadata->input1[i].shape.z;
-			layer->info.input[i].nb_elements =
+			io_info->input[i].dtype = metadata->input1[i].input_type;
+			io_info->input[i].qtype = metadata->input1[i].model_input_type;
+			io_info->input[i].nb_dims = 4;
+			io_info->input[i].shape[0] = metadata->input1[i].shape.w;
+			io_info->input[i].shape[1] = metadata->input1[i].shape.x;
+			io_info->input[i].shape[2] = metadata->input1[i].shape.y;
+			io_info->input[i].shape[3] = metadata->input1[i].shape.z;
+			io_info->input[i].nb_elements =
 				metadata->input1[i].shape.w * metadata->input1[i].shape.x *
 				metadata->input1[i].shape.y * metadata->input1[i].shape.z;
-			layer->info.input[i].sz_d =
-				layer->info.input[i].nb_elements *
+			io_info->input[i].sz_d =
+				io_info->input[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->input1[i].input_type);
-			layer->info.input[i].sz_q =
-				layer->info.input[i].nb_elements *
+			io_info->input[i].sz_q =
+				io_info->input[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
-			layer->info.input[i].scale = metadata->input1[i].qscale;
+			io_info->input[i].scale = metadata->input1[i].qscale;
 
-			layer->info.total_input_sz_d += layer->info.input[i].sz_d;
-			layer->info.total_input_sz_q += layer->info.input[i].sz_q;
+			io_info->total_input_sz_d += io_info->input[i].sz_d;
+			io_info->total_input_sz_q += io_info->input[i].sz_q;
 
 			plt_ml_dbg(
-				"index = %u, input1[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
-				layer->index, i, metadata->input1[i].shape.w,
+				"layer_name = %s, input1[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+				metadata->model.name, i, metadata->input1[i].shape.w,
 				metadata->input1[i].shape.x, metadata->input1[i].shape.y,
-				metadata->input1[i].shape.z, layer->info.input[i].sz_d,
-				layer->info.input[i].sz_q);
+				metadata->input1[i].shape.z, io_info->input[i].sz_d,
+				io_info->input[i].sz_q);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
 
-			strncpy(layer->info.input[i].name, (char *)metadata->input2[j].input_name,
+			strncpy(io_info->input[i].name, (char *)metadata->input2[j].input_name,
 				MRVL_ML_INPUT_NAME_LEN);
-			layer->info.input[i].dtype = metadata->input2[j].input_type;
-			layer->info.input[i].qtype = metadata->input2[j].model_input_type;
-			layer->info.input[i].nb_dims = 4;
-			layer->info.input[i].shape[0] = metadata->input2[j].shape.w;
-			layer->info.input[i].shape[1] = metadata->input2[j].shape.x;
-			layer->info.input[i].shape[2] = metadata->input2[j].shape.y;
-			layer->info.input[i].shape[3] = metadata->input2[j].shape.z;
-			layer->info.input[i].nb_elements =
+			io_info->input[i].dtype = metadata->input2[j].input_type;
+			io_info->input[i].qtype = metadata->input2[j].model_input_type;
+			io_info->input[i].nb_dims = 4;
+			io_info->input[i].shape[0] = metadata->input2[j].shape.w;
+			io_info->input[i].shape[1] = metadata->input2[j].shape.x;
+			io_info->input[i].shape[2] = metadata->input2[j].shape.y;
+			io_info->input[i].shape[3] = metadata->input2[j].shape.z;
+			io_info->input[i].nb_elements =
 				metadata->input2[j].shape.w * metadata->input2[j].shape.x *
 				metadata->input2[j].shape.y * metadata->input2[j].shape.z;
-			layer->info.input[i].sz_d =
-				layer->info.input[i].nb_elements *
+			io_info->input[i].sz_d =
+				io_info->input[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->input2[j].input_type);
-			layer->info.input[i].sz_q =
-				layer->info.input[i].nb_elements *
+			io_info->input[i].sz_q =
+				io_info->input[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
-			layer->info.input[i].scale = metadata->input2[j].qscale;
+			io_info->input[i].scale = metadata->input2[j].qscale;
 
-			layer->info.total_input_sz_d += layer->info.input[i].sz_d;
-			layer->info.total_input_sz_q += layer->info.input[i].sz_q;
+			io_info->total_input_sz_d += io_info->input[i].sz_d;
+			io_info->total_input_sz_q += io_info->input[i].sz_q;
 
 			plt_ml_dbg(
-				"index = %u, input2[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
-				layer->index, j, metadata->input2[j].shape.w,
+				"layer_name = %s, input2[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+				metadata->model.name, j, metadata->input2[j].shape.w,
 				metadata->input2[j].shape.x, metadata->input2[j].shape.y,
-				metadata->input2[j].shape.z, layer->info.input[i].sz_d,
-				layer->info.input[i].sz_q);
+				metadata->input2[j].shape.z, io_info->input[i].sz_d,
+				io_info->input[i].sz_q);
 		}
 	}
 
 	/* Outputs */
-	layer->info.nb_outputs = metadata->model.num_output;
-	layer->info.total_output_sz_q = 0;
-	layer->info.total_output_sz_d = 0;
+	io_info->nb_outputs = metadata->model.num_output;
+	io_info->total_output_sz_q = 0;
+	io_info->total_output_sz_d = 0;
 	for (i = 0; i < metadata->model.num_output; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
-			strncpy(layer->info.output[i].name,
-				(char *)metadata->output1[i].output_name, MRVL_ML_OUTPUT_NAME_LEN);
-			layer->info.output[i].dtype = metadata->output1[i].output_type;
-			layer->info.output[i].qtype = metadata->output1[i].model_output_type;
-			layer->info.output[i].nb_dims = 1;
-			layer->info.output[i].shape[0] = metadata->output1[i].size;
-			layer->info.output[i].nb_elements = metadata->output1[i].size;
-			layer->info.output[i].sz_d =
-				layer->info.output[i].nb_elements *
+			strncpy(io_info->output[i].name, (char *)metadata->output1[i].output_name,
+				MRVL_ML_OUTPUT_NAME_LEN);
+			io_info->output[i].dtype = metadata->output1[i].output_type;
+			io_info->output[i].qtype = metadata->output1[i].model_output_type;
+			io_info->output[i].nb_dims = 1;
+			io_info->output[i].shape[0] = metadata->output1[i].size;
+			io_info->output[i].nb_elements = metadata->output1[i].size;
+			io_info->output[i].sz_d =
+				io_info->output[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->output1[i].output_type);
-			layer->info.output[i].sz_q =
-				layer->info.output[i].nb_elements *
+			io_info->output[i].sz_q =
+				io_info->output[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
-			layer->info.output[i].scale = metadata->output1[i].dscale;
+			io_info->output[i].scale = metadata->output1[i].dscale;
 
-			layer->info.total_output_sz_q += layer->info.output[i].sz_q;
-			layer->info.total_output_sz_d += layer->info.output[i].sz_d;
+			io_info->total_output_sz_q += io_info->output[i].sz_q;
+			io_info->total_output_sz_d += io_info->output[i].sz_d;
 
-			plt_ml_dbg("index = %u, output1[%u] - sz_d = %u, sz_q = %u", layer->index,
-				   i, layer->info.output[i].sz_d, layer->info.output[i].sz_q);
+			plt_ml_dbg("layer_name = %s, output1[%u] - sz_d = %u, sz_q = %u",
+				   metadata->model.name, i, io_info->output[i].sz_d,
+				   io_info->output[i].sz_q);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
 
-			strncpy(layer->info.output[i].name,
-				(char *)metadata->output2[j].output_name, MRVL_ML_OUTPUT_NAME_LEN);
-			layer->info.output[i].dtype = metadata->output2[j].output_type;
-			layer->info.output[i].qtype = metadata->output2[j].model_output_type;
-			layer->info.output[i].nb_dims = 1;
-			layer->info.output[i].shape[0] = metadata->output2[j].size;
-			layer->info.output[i].nb_elements = metadata->output2[j].size;
-			layer->info.output[i].sz_d =
-				layer->info.output[i].nb_elements *
+			strncpy(io_info->output[i].name, (char *)metadata->output2[j].output_name,
+				MRVL_ML_OUTPUT_NAME_LEN);
+			io_info->output[i].dtype = metadata->output2[j].output_type;
+			io_info->output[i].qtype = metadata->output2[j].model_output_type;
+			io_info->output[i].nb_dims = 1;
+			io_info->output[i].shape[0] = metadata->output2[j].size;
+			io_info->output[i].nb_elements = metadata->output2[j].size;
+			io_info->output[i].sz_d =
+				io_info->output[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->output2[j].output_type);
-			layer->info.output[i].sz_q =
-				layer->info.output[i].nb_elements *
+			io_info->output[i].sz_q =
+				io_info->output[i].nb_elements *
 				rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
-			layer->info.output[i].scale = metadata->output2[j].dscale;
+			io_info->output[i].scale = metadata->output2[j].dscale;
 
-			layer->info.total_output_sz_q += layer->info.output[i].sz_q;
-			layer->info.total_output_sz_d += layer->info.output[i].sz_d;
+			io_info->total_output_sz_q += io_info->output[i].sz_q;
+			io_info->total_output_sz_d += io_info->output[i].sz_d;
 
-			plt_ml_dbg("index = %u, output2[%u] - sz_d = %u, sz_q = %u", layer->index,
-				   j, layer->info.output[i].sz_d, layer->info.output[i].sz_q);
+			plt_ml_dbg("layer_name = %s, output2[%u] - sz_d = %u, sz_q = %u",
+				   metadata->model.name, j, io_info->output[i].sz_d,
+				   io_info->output[i].sz_q);
 		}
 	}
 }
 
+struct cnxk_ml_io_info *
+cn10k_ml_model_io_info_get(struct cnxk_ml_model *model, uint16_t layer_id)
+{
+	return &model->layer[layer_id].info;
+}
+
 int
-cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_id, uint8_t *buffer,
-			       uint16_t *wb_pages, uint16_t *scratch_pages)
+cn10k_ml_model_ocm_pages_count(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer,
+			       uint8_t *buffer, uint16_t *wb_pages, uint16_t *scratch_pages)
 {
 	struct cn10k_ml_model_metadata *metadata;
 	struct cn10k_ml_ocm *ocm;
@@ -504,7 +499,7 @@  cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_
 	uint64_t wb_size;
 
 	metadata = (struct cn10k_ml_model_metadata *)buffer;
-	ocm = &cn10k_mldev->ocm;
+	ocm = &cnxk_mldev->cn10k_mldev.ocm;
 
 	/* Assume wb_size is zero for non-relocatable models */
 	if (metadata->model.ocm_relocatable)
@@ -516,7 +511,7 @@  cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_
 		*wb_pages = wb_size / ocm->page_size + 1;
 	else
 		*wb_pages = wb_size / ocm->page_size;
-	plt_ml_dbg("model_id = %u, wb_size = %" PRIu64 ", wb_pages = %u", model_id, wb_size,
+	plt_ml_dbg("index = %u, wb_size = %" PRIu64 ", wb_pages = %u", layer->index, wb_size,
 		   *wb_pages);
 
 	scratch_size = ocm->size_per_tile - metadata->model.ocm_tmp_range_floor;
@@ -524,15 +519,15 @@  cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_
 		*scratch_pages = scratch_size / ocm->page_size + 1;
 	else
 		*scratch_pages = scratch_size / ocm->page_size;
-	plt_ml_dbg("model_id = %u, scratch_size = %" PRIu64 ", scratch_pages = %u", model_id,
+	plt_ml_dbg("index = %u, scratch_size = %" PRIu64 ", scratch_pages = %u", layer->index,
 		   scratch_size, *scratch_pages);
 
 	/* Check if the model can be loaded on OCM */
-	if ((*wb_pages + *scratch_pages) > cn10k_mldev->ocm.num_pages) {
+	if ((*wb_pages + *scratch_pages) > ocm->num_pages) {
 		plt_err("Cannot create the model, OCM relocatable = %u",
 			metadata->model.ocm_relocatable);
 		plt_err("wb_pages (%u) + scratch_pages (%u) > %u", *wb_pages, *scratch_pages,
-			cn10k_mldev->ocm.num_pages);
+			ocm->num_pages);
 		return -ENOMEM;
 	}
 
@@ -540,28 +535,25 @@  cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_
 	 * prevent the library from allocating the remaining space on the tile to other models.
 	 */
 	if (!metadata->model.ocm_relocatable)
-		*scratch_pages = PLT_MAX(PLT_U64_CAST(*scratch_pages),
-					 PLT_U64_CAST(cn10k_mldev->ocm.num_pages));
+		*scratch_pages =
+			PLT_MAX(PLT_U64_CAST(*scratch_pages), PLT_U64_CAST(ocm->num_pages));
 
 	return 0;
 }
 
 void
-cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
+cn10k_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+			struct cnxk_ml_io_info *io_info, struct cn10k_ml_model_metadata *metadata)
 {
-	struct cn10k_ml_model_metadata *metadata;
-	struct cnxk_ml_dev *cnxk_mldev;
 	struct rte_ml_model_info *info;
 	struct rte_ml_io_info *output;
 	struct rte_ml_io_info *input;
-	struct cnxk_ml_layer *layer;
 	uint8_t i;
 
-	cnxk_mldev = dev->data->dev_private;
 	metadata = &model->glow.metadata;
 	info = PLT_PTR_CAST(model->info);
 	input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
-	output = PLT_PTR_ADD(input, metadata->model.num_input * sizeof(struct rte_ml_io_info));
+	output = PLT_PTR_ADD(input, ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct rte_ml_io_info));
 
 	/* Set model info */
 	memset(info, 0, sizeof(struct rte_ml_model_info));
@@ -570,39 +562,37 @@  cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model)
 		 metadata->model.version[1], metadata->model.version[2],
 		 metadata->model.version[3]);
 	info->model_id = model->model_id;
-	info->device_id = dev->data->dev_id;
+	info->device_id = cnxk_mldev->mldev->data->dev_id;
 	info->io_layout = RTE_ML_IO_LAYOUT_PACKED;
 	info->min_batches = model->batch_size;
 	info->max_batches =
 		cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_num_batches /
 		model->batch_size;
-	info->nb_inputs = metadata->model.num_input;
+	info->nb_inputs = io_info->nb_inputs;
 	info->input_info = input;
-	info->nb_outputs = metadata->model.num_output;
+	info->nb_outputs = io_info->nb_outputs;
 	info->output_info = output;
 	info->wb_size = metadata->weights_bias.file_size;
 
 	/* Set input info */
-	layer = &model->layer[0];
 	for (i = 0; i < info->nb_inputs; i++) {
-		rte_memcpy(input[i].name, layer->info.input[i].name, MRVL_ML_INPUT_NAME_LEN);
-		input[i].nb_dims = layer->info.input[i].nb_dims;
-		input[i].shape = &layer->info.input[i].shape[0];
-		input[i].type = layer->info.input[i].qtype;
-		input[i].nb_elements = layer->info.input[i].nb_elements;
-		input[i].size = layer->info.input[i].nb_elements *
-				rte_ml_io_type_size_get(layer->info.input[i].qtype);
+		rte_memcpy(input[i].name, io_info->input[i].name, MRVL_ML_INPUT_NAME_LEN);
+		input[i].nb_dims = io_info->input[i].nb_dims;
+		input[i].shape = &io_info->input[i].shape[0];
+		input[i].type = io_info->input[i].qtype;
+		input[i].nb_elements = io_info->input[i].nb_elements;
+		input[i].size = io_info->input[i].nb_elements *
+				rte_ml_io_type_size_get(io_info->input[i].qtype);
 	}
 
 	/* Set output info */
-	layer = &model->layer[0];
 	for (i = 0; i < info->nb_outputs; i++) {
-		rte_memcpy(output[i].name, layer->info.output[i].name, MRVL_ML_INPUT_NAME_LEN);
-		output[i].nb_dims = layer->info.output[i].nb_dims;
-		output[i].shape = &layer->info.output[i].shape[0];
-		output[i].type = layer->info.output[i].qtype;
-		output[i].nb_elements = layer->info.output[i].nb_elements;
-		output[i].size = layer->info.output[i].nb_elements *
-				 rte_ml_io_type_size_get(layer->info.output[i].qtype);
+		rte_memcpy(output[i].name, io_info->output[i].name, MRVL_ML_INPUT_NAME_LEN);
+		output[i].nb_dims = io_info->output[i].nb_dims;
+		output[i].shape = &io_info->output[i].shape[0];
+		output[i].type = io_info->output[i].qtype;
+		output[i].nb_elements = io_info->output[i].nb_elements;
+		output[i].size = io_info->output[i].nb_elements *
+				 rte_ml_io_type_size_get(io_info->output[i].qtype);
 	}
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 5c32f48c68..b891c9d627 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -9,9 +9,11 @@ 
 
 #include <roc_api.h>
 
-#include "cn10k_ml_dev.h"
 #include "cn10k_ml_ocm.h"
 
+#include "cnxk_ml_io.h"
+
+struct cnxk_ml_dev;
 struct cnxk_ml_model;
 struct cnxk_ml_layer;
 struct cnxk_ml_req;
@@ -366,27 +368,15 @@  struct cn10k_ml_layer_addr {
 	/* Base DMA address for load */
 	void *base_dma_addr_load;
 
-	/* Base DMA address for run */
-	void *base_dma_addr_run;
-
 	/* Init section load address */
 	void *init_load_addr;
 
-	/* Init section run address */
-	void *init_run_addr;
-
 	/* Main section load address */
 	void *main_load_addr;
 
-	/* Main section run address */
-	void *main_run_addr;
-
 	/* Finish section load address */
 	void *finish_load_addr;
 
-	/* Finish section run address */
-	void *finish_run_addr;
-
 	/* Weights and Bias base address */
 	void *wb_base_addr;
 
@@ -462,9 +452,13 @@  int cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size);
 void cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata);
 void cn10k_ml_layer_addr_update(struct cnxk_ml_layer *layer, uint8_t *buffer,
 				uint8_t *base_dma_addr);
-void cn10k_ml_layer_info_update(struct cnxk_ml_layer *layer);
-int cn10k_ml_model_ocm_pages_count(struct cn10k_ml_dev *cn10k_mldev, uint16_t model_id,
+void cn10k_ml_layer_io_info_set(struct cnxk_ml_io_info *io_info,
+				struct cn10k_ml_model_metadata *metadata);
+struct cnxk_ml_io_info *cn10k_ml_model_io_info_get(struct cnxk_ml_model *model, uint16_t layer_id);
+int cn10k_ml_model_ocm_pages_count(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer,
 				   uint8_t *buffer, uint16_t *wb_pages, uint16_t *scratch_pages);
-void cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cnxk_ml_model *model);
+void cn10k_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+			     struct cnxk_ml_io_info *io_info,
+			     struct cn10k_ml_model_metadata *metadata);
 
 #endif /* _CN10K_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index e3c688a55f..ad2effb904 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -15,6 +15,9 @@ 
 /* ML model macros */
 #define CN10K_ML_MODEL_MEMZONE_NAME "ml_cn10k_model_mz"
 
+/* ML layer macros */
+#define CN10K_ML_LAYER_MEMZONE_NAME "ml_cn10k_layer_mz"
+
 /* Debug print width */
 #define STR_LEN	  12
 #define FIELD_LEN 16
@@ -273,7 +276,7 @@  cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
 		req->cn10k_req.jd.model_start.extended_args = PLT_U64_CAST(
 			roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.extended_args));
 		req->cn10k_req.jd.model_start.model_dst_ddr_addr =
-			PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, addr->init_run_addr));
+			PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, addr->init_load_addr));
 		req->cn10k_req.jd.model_start.model_init_offset = 0x0;
 		req->cn10k_req.jd.model_start.model_main_offset = metadata->init_model.file_size;
 		req->cn10k_req.jd.model_start.model_finish_offset =
@@ -1261,85 +1264,171 @@  cn10k_ml_dev_selftest(struct rte_ml_dev *dev)
 }
 
 int
-cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, uint16_t *model_id)
+cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uint8_t *buffer,
+		    size_t size, uint16_t *index)
 {
 	struct cn10k_ml_model_metadata *metadata;
 	struct cnxk_ml_dev *cnxk_mldev;
 	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 
 	char str[RTE_MEMZONE_NAMESIZE];
 	const struct plt_memzone *mz;
-	size_t model_scratch_size;
-	size_t model_stats_size;
-	size_t model_data_size;
-	size_t model_info_size;
+	size_t layer_object_size = 0;
+	size_t layer_scratch_size;
+	size_t layer_xstats_size;
 	uint8_t *base_dma_addr;
 	uint16_t scratch_pages;
+	uint16_t layer_id = 0;
 	uint16_t wb_pages;
 	uint64_t mz_size;
 	uint16_t idx;
-	bool found;
 	int qp_id;
 	int ret;
 
-	ret = cn10k_ml_model_metadata_check(params->addr, params->size);
+	PLT_SET_USED(size);
+	PLT_SET_USED(layer_name);
+
+	cnxk_mldev = (struct cnxk_ml_dev *)device;
+	if (cnxk_mldev == NULL) {
+		plt_err("Invalid device = %p", device);
+		return -EINVAL;
+	}
+
+	model = cnxk_mldev->mldev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	layer = &model->layer[layer_id];
+
+	ret = cn10k_ml_model_metadata_check(buffer, size);
 	if (ret != 0)
 		return ret;
 
-	cnxk_mldev = dev->data->dev_private;
-
-	/* Find model ID */
-	found = false;
-	for (idx = 0; idx < dev->data->nb_models; idx++) {
-		if (dev->data->models[idx] == NULL) {
-			found = true;
+	/* Get index */
+	for (idx = 0; idx < cnxk_mldev->max_nb_layers; idx++) {
+		if (!cnxk_mldev->index_map[idx].active) {
+			layer->index = idx;
 			break;
 		}
 	}
 
-	if (!found) {
-		plt_err("No slots available to load new model");
-		return -ENOMEM;
+	if (idx >= cnxk_mldev->max_nb_layers) {
+		plt_err("No slots available for model layers, model_id = %u, layer_id = %u",
+			model->model_id, layer_id);
+		return -1;
 	}
 
+	layer->model = model;
+
 	/* Get WB and scratch pages, check if model can be loaded. */
-	ret = cn10k_ml_model_ocm_pages_count(&cnxk_mldev->cn10k_mldev, idx, params->addr, &wb_pages,
-					     &scratch_pages);
+	ret = cn10k_ml_model_ocm_pages_count(cnxk_mldev, layer, buffer, &wb_pages, &scratch_pages);
 	if (ret < 0)
 		return ret;
 
-	/* Compute memzone size */
-	metadata = (struct cn10k_ml_model_metadata *)params->addr;
-	model_data_size = metadata->init_model.file_size + metadata->main_model.file_size +
-			  metadata->finish_model.file_size + metadata->weights_bias.file_size;
-	model_scratch_size = PLT_ALIGN_CEIL(metadata->model.ddr_scratch_range_end -
+	/* Compute layer memzone size */
+	metadata = (struct cn10k_ml_model_metadata *)buffer;
+	layer_object_size = metadata->init_model.file_size + metadata->main_model.file_size +
+			    metadata->finish_model.file_size + metadata->weights_bias.file_size;
+	layer_object_size = PLT_ALIGN_CEIL(layer_object_size, ML_CN10K_ALIGN_SIZE);
+	layer_scratch_size = PLT_ALIGN_CEIL(metadata->model.ddr_scratch_range_end -
 						    metadata->model.ddr_scratch_range_start + 1,
 					    ML_CN10K_ALIGN_SIZE);
-	model_data_size = PLT_ALIGN_CEIL(model_data_size, ML_CN10K_ALIGN_SIZE);
-	model_info_size = sizeof(struct rte_ml_model_info) +
-			  metadata->model.num_input * sizeof(struct rte_ml_io_info) +
-			  metadata->model.num_output * sizeof(struct rte_ml_io_info);
-	model_info_size = PLT_ALIGN_CEIL(model_info_size, ML_CN10K_ALIGN_SIZE);
-	model_stats_size = (dev->data->nb_queue_pairs + 1) * sizeof(struct cn10k_ml_layer_xstats);
-
-	mz_size = PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), ML_CN10K_ALIGN_SIZE) +
-		  2 * model_data_size + model_scratch_size + model_info_size +
-		  PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE) +
-		  model_stats_size;
+	layer_xstats_size = (cnxk_mldev->mldev->data->nb_queue_pairs + 1) *
+			    sizeof(struct cn10k_ml_layer_xstats);
 
-	/* Allocate memzone for model object and model data */
-	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", CN10K_ML_MODEL_MEMZONE_NAME, idx);
+	/* Allocate memzone for model data */
+	mz_size = layer_object_size + layer_scratch_size +
+		  PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE) +
+		  layer_xstats_size;
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u_%u", CN10K_ML_LAYER_MEMZONE_NAME,
+		 model->model_id, layer_id);
 	mz = plt_memzone_reserve_aligned(str, mz_size, 0, ML_CN10K_ALIGN_SIZE);
 	if (!mz) {
 		plt_err("plt_memzone_reserve failed : %s", str);
 		return -ENOMEM;
 	}
 
-	model = mz->addr;
-	model->cnxk_mldev = cnxk_mldev;
-	model->model_id = idx;
-	dev->data->models[idx] = model;
+	/* Copy metadata to internal buffer */
+	rte_memcpy(&layer->glow.metadata, buffer, sizeof(struct cn10k_ml_model_metadata));
+	cn10k_ml_model_metadata_update(&layer->glow.metadata);
+
+	/* Set layer name */
+	rte_memcpy(layer->name, layer->glow.metadata.model.name, MRVL_ML_MODEL_NAME_LEN);
+
+	/* Enable support for batch_size of 256 */
+	if (layer->glow.metadata.model.batch_size == 0)
+		layer->batch_size = 256;
+	else
+		layer->batch_size = layer->glow.metadata.model.batch_size;
+
+	/* Set DMA base address */
+	base_dma_addr = mz->addr;
+	cn10k_ml_layer_addr_update(layer, buffer, base_dma_addr);
+
+	/* Set scratch base address */
+	layer->glow.addr.scratch_base_addr = PLT_PTR_ADD(base_dma_addr, layer_object_size);
+
+	/* Update internal I/O data structure */
+	cn10k_ml_layer_io_info_set(&layer->info, &layer->glow.metadata);
+
+	/* Initialize model_mem_map */
+	memset(&layer->glow.ocm_map, 0, sizeof(struct cn10k_ml_ocm_layer_map));
+	layer->glow.ocm_map.ocm_reserved = false;
+	layer->glow.ocm_map.tilemask = 0;
+	layer->glow.ocm_map.wb_page_start = -1;
+	layer->glow.ocm_map.wb_pages = wb_pages;
+	layer->glow.ocm_map.scratch_pages = scratch_pages;
+
+	/* Set slow-path request address and state */
+	layer->glow.req = PLT_PTR_ADD(mz->addr, layer_object_size + layer_scratch_size);
+
+	/* Reset burst and sync stats */
+	layer->glow.burst_xstats = PLT_PTR_ADD(
+		layer->glow.req, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE));
+	for (qp_id = 0; qp_id < cnxk_mldev->mldev->data->nb_queue_pairs + 1; qp_id++) {
+		layer->glow.burst_xstats[qp_id].hw_latency_tot = 0;
+		layer->glow.burst_xstats[qp_id].hw_latency_min = UINT64_MAX;
+		layer->glow.burst_xstats[qp_id].hw_latency_max = 0;
+		layer->glow.burst_xstats[qp_id].fw_latency_tot = 0;
+		layer->glow.burst_xstats[qp_id].fw_latency_min = UINT64_MAX;
+		layer->glow.burst_xstats[qp_id].fw_latency_max = 0;
+		layer->glow.burst_xstats[qp_id].hw_reset_count = 0;
+		layer->glow.burst_xstats[qp_id].fw_reset_count = 0;
+		layer->glow.burst_xstats[qp_id].dequeued_count = 0;
+	}
+
+	layer->glow.sync_xstats =
+		PLT_PTR_ADD(layer->glow.burst_xstats, cnxk_mldev->mldev->data->nb_queue_pairs *
+							      sizeof(struct cn10k_ml_layer_xstats));
+
+	/* Update xstats names */
+	cn10k_ml_xstats_model_name_update(cnxk_mldev->mldev, idx);
+
+	layer->state = ML_CNXK_LAYER_STATE_LOADED;
+	cnxk_mldev->index_map[idx].model_id = model->model_id;
+	cnxk_mldev->index_map[idx].layer_id = layer_id;
+	cnxk_mldev->index_map[idx].active = true;
+	*index = idx;
+
+	return 0;
+}
+
+int
+cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
+		    struct cnxk_ml_model *model)
+{
+	struct cnxk_ml_layer *layer;
+	int ret;
+
+	/* Metadata check */
+	ret = cn10k_ml_model_metadata_check(params->addr, params->size);
+	if (ret != 0)
+		return ret;
 
+	/* Copy metadata to internal buffer */
 	rte_memcpy(&model->glow.metadata, params->addr, sizeof(struct cn10k_ml_model_metadata));
 	cn10k_ml_model_metadata_update(&model->glow.metadata);
 
@@ -1358,99 +1447,62 @@  cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
 	 */
 	model->nb_layers = 1;
 
-	/* Copy metadata to internal buffer */
-	rte_memcpy(&model->layer[0].glow.metadata, params->addr,
-		   sizeof(struct cn10k_ml_model_metadata));
-	cn10k_ml_model_metadata_update(&model->layer[0].glow.metadata);
-	model->layer[0].model = model;
-
-	/* Set DMA base address */
-	base_dma_addr = PLT_PTR_ADD(
-		mz->addr, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), ML_CN10K_ALIGN_SIZE));
-	cn10k_ml_layer_addr_update(&model->layer[0], params->addr, base_dma_addr);
-	model->layer[0].glow.addr.scratch_base_addr =
-		PLT_PTR_ADD(base_dma_addr, 2 * model_data_size);
-
-	/* Copy data from load to run. run address to be used by MLIP */
-	rte_memcpy(model->layer[0].glow.addr.base_dma_addr_run,
-		   model->layer[0].glow.addr.base_dma_addr_load, model_data_size);
-
-	/* Update internal I/O data structure */
-	cn10k_ml_layer_info_update(&model->layer[0]);
-
-	/* Initialize model_mem_map */
-	memset(&model->layer[0].glow.ocm_map, 0, sizeof(struct cn10k_ml_ocm_layer_map));
-	model->layer[0].glow.ocm_map.ocm_reserved = false;
-	model->layer[0].glow.ocm_map.tilemask = 0;
-	model->layer[0].glow.ocm_map.wb_page_start = -1;
-	model->layer[0].glow.ocm_map.wb_pages = wb_pages;
-	model->layer[0].glow.ocm_map.scratch_pages = scratch_pages;
-
-	/* Set model info */
-	model->info = PLT_PTR_ADD(model->layer[0].glow.addr.scratch_base_addr, model_scratch_size);
-	cn10k_ml_model_info_set(dev, model);
-
-	/* Set slow-path request address and state */
-	model->layer[0].glow.req = PLT_PTR_ADD(model->info, model_info_size);
-
-	/* Reset burst and sync stats */
-	model->layer[0].glow.burst_xstats =
-		PLT_PTR_ADD(model->layer[0].glow.req,
-			    PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_req), ML_CN10K_ALIGN_SIZE));
-	for (qp_id = 0; qp_id < dev->data->nb_queue_pairs + 1; qp_id++) {
-		model->layer[0].glow.burst_xstats[qp_id].hw_latency_tot = 0;
-		model->layer[0].glow.burst_xstats[qp_id].hw_latency_min = UINT64_MAX;
-		model->layer[0].glow.burst_xstats[qp_id].hw_latency_max = 0;
-		model->layer[0].glow.burst_xstats[qp_id].fw_latency_tot = 0;
-		model->layer[0].glow.burst_xstats[qp_id].fw_latency_min = UINT64_MAX;
-		model->layer[0].glow.burst_xstats[qp_id].fw_latency_max = 0;
-		model->layer[0].glow.burst_xstats[qp_id].hw_reset_count = 0;
-		model->layer[0].glow.burst_xstats[qp_id].fw_reset_count = 0;
-		model->layer[0].glow.burst_xstats[qp_id].dequeued_count = 0;
+	/* Load layer and get the index */
+	layer = &model->layer[0];
+	ret = cn10k_ml_layer_load(cnxk_mldev, model->model_id, NULL, params->addr, params->size,
+				  &layer->index);
+	if (ret != 0) {
+		plt_err("Model layer load failed: model_id = %u, layer_id = %u", model->model_id,
+			0);
+		return ret;
 	}
 
-	model->layer[0].glow.sync_xstats =
-		PLT_PTR_ADD(model->layer[0].glow.burst_xstats,
-			    dev->data->nb_queue_pairs * sizeof(struct cn10k_ml_layer_xstats));
-
-	plt_spinlock_init(&model->lock);
-	model->state = ML_CNXK_MODEL_STATE_LOADED;
-	dev->data->models[idx] = model;
-	cnxk_mldev->nb_models_loaded++;
-
-	/* Update xstats names */
-	cn10k_ml_xstats_model_name_update(dev, idx);
-
-	*model_id = idx;
+	cn10k_ml_model_info_set(cnxk_mldev, model, &model->layer[0].info, &model->glow.metadata);
 
 	return 0;
 }
 
 int
-cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_layer_unload(void *device, uint16_t model_id, const char *layer_name)
 {
-	char str[RTE_MEMZONE_NAMESIZE];
-	struct cnxk_ml_model *model;
 	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+	struct cnxk_ml_layer *layer;
 
-	cnxk_mldev = dev->data->dev_private;
-	model = dev->data->models[model_id];
+	char str[RTE_MEMZONE_NAMESIZE];
+	uint16_t layer_id = 0;
+	int ret;
 
+	PLT_SET_USED(layer_name);
+
+	cnxk_mldev = (struct cnxk_ml_dev *)device;
+	if (cnxk_mldev == NULL) {
+		plt_err("Invalid device = %p", device);
+		return -EINVAL;
+	}
+
+	model = cnxk_mldev->mldev->data->models[model_id];
 	if (model == NULL) {
 		plt_err("Invalid model_id = %u", model_id);
 		return -EINVAL;
 	}
 
-	if (model->state != ML_CNXK_MODEL_STATE_LOADED) {
-		plt_err("Cannot unload. Model in use.");
-		return -EBUSY;
-	}
+	layer = &model->layer[layer_id];
 
-	dev->data->models[model_id] = NULL;
-	cnxk_mldev->nb_models_unloaded++;
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u_%u", CN10K_ML_LAYER_MEMZONE_NAME,
+		 model->model_id, layer_id);
+	ret = plt_memzone_free(plt_memzone_lookup(str));
 
-	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", CN10K_ML_MODEL_MEMZONE_NAME, model_id);
-	return plt_memzone_free(plt_memzone_lookup(str));
+	layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
+	cnxk_mldev->index_map[layer->index].active = false;
+
+	return ret;
+}
+
+int
+cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+	return cn10k_ml_layer_unload(cnxk_mldev, model->model_id, NULL);
 }
 
 int
@@ -1748,7 +1800,6 @@  int
 cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
 {
 	struct cnxk_ml_model *model;
-	size_t size;
 
 	model = dev->data->models[model_id];
 
@@ -1762,19 +1813,10 @@  cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *bu
 	else if (model->state != ML_CNXK_MODEL_STATE_LOADED)
 		return -EBUSY;
 
-	size = model->layer[0].glow.metadata.init_model.file_size +
-	       model->layer[0].glow.metadata.main_model.file_size +
-	       model->layer[0].glow.metadata.finish_model.file_size +
-	       model->layer[0].glow.metadata.weights_bias.file_size;
-
 	/* Update model weights & bias */
 	rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer,
 		   model->layer[0].glow.metadata.weights_bias.file_size);
 
-	/* Copy data from load to run. run address to be used by MLIP */
-	rte_memcpy(model->layer[0].glow.addr.base_dma_addr_run,
-		   model->layer[0].glow.addr.base_dma_addr_load, size);
-
 	return 0;
 }
 
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 2d0a49d5cd..677219dfdf 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -12,6 +12,7 @@ 
 
 struct cnxk_ml_dev;
 struct cnxk_ml_qp;
+struct cnxk_ml_model;
 
 /* Firmware version string length */
 #define MLDEV_FIRMWARE_VERSION_LENGTH 32
@@ -311,9 +312,9 @@  int cn10k_ml_dev_xstats_reset(struct rte_ml_dev *dev, enum rte_ml_dev_xstats_mod
 			      int32_t model_id, const uint16_t stat_ids[], uint16_t nb_ids);
 
 /* Slow-path ops */
-int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
-			uint16_t *model_id);
-int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
+			struct cnxk_ml_model *model);
+int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
@@ -339,4 +340,9 @@  __rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *
 /* Misc ops */
 void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp *qp);
 
+/* Layer ops */
+int cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uint8_t *buffer,
+			size_t size, uint16_t *index);
+int cn10k_ml_layer_unload(void *device, uint16_t model_id, const char *layer_name);
+
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_dev.h b/drivers/ml/cnxk/cnxk_ml_dev.h
index 02605fa28f..1590249abd 100644
--- a/drivers/ml/cnxk/cnxk_ml_dev.h
+++ b/drivers/ml/cnxk/cnxk_ml_dev.h
@@ -31,6 +31,18 @@  enum cnxk_ml_dev_state {
 	ML_CNXK_DEV_STATE_CLOSED
 };
 
+/* Index to model and layer ID map */
+struct cnxk_ml_index_map {
+	/* Model ID */
+	uint16_t model_id;
+
+	/* Layer ID */
+	uint16_t layer_id;
+
+	/* Layer status */
+	bool active;
+};
+
 /* Device private data */
 struct cnxk_ml_dev {
 	/* RTE device */
@@ -56,6 +68,9 @@  struct cnxk_ml_dev {
 
 	/* Maximum number of layers */
 	uint64_t max_nb_layers;
+
+	/* Index map */
+	struct cnxk_ml_index_map *index_map;
 };
 
 #endif /* _CNXK_ML_DEV_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index aa56dd2276..1d8b84269d 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -10,6 +10,9 @@ 
 #include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
 
+/* ML model macros */
+#define CNXK_ML_MODEL_MEMZONE_NAME "ml_cnxk_model_mz"
+
 static void
 qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
 {
@@ -137,6 +140,7 @@  cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *co
 	uint16_t model_id;
 	uint32_t mz_size;
 	uint16_t qp_id;
+	uint64_t i;
 	int ret;
 
 	if (dev == NULL)
@@ -240,7 +244,7 @@  cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *co
 						plt_err("Could not stop model %u", model_id);
 				}
 				if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-					if (cn10k_ml_model_unload(dev, model_id) != 0)
+					if (cnxk_ml_model_unload(dev, model_id) != 0)
 						plt_err("Could not unload model %u", model_id);
 				}
 				dev->data->models[model_id] = NULL;
@@ -271,6 +275,23 @@  cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *co
 	cnxk_mldev->max_nb_layers =
 		cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_models;
 
+	/* Allocate and initialize index_map */
+	if (cnxk_mldev->index_map == NULL) {
+		cnxk_mldev->index_map =
+			rte_zmalloc("cnxk_ml_index_map",
+				    sizeof(struct cnxk_ml_index_map) * cnxk_mldev->max_nb_layers,
+				    RTE_CACHE_LINE_SIZE);
+		if (cnxk_mldev->index_map == NULL) {
+			plt_err("Failed to get memory for index_map, nb_layers %" PRIu64,
+				cnxk_mldev->max_nb_layers);
+			ret = -ENOMEM;
+			goto error;
+		}
+	}
+
+	for (i = 0; i < cnxk_mldev->max_nb_layers; i++)
+		cnxk_mldev->index_map[i].active = false;
+
 	cnxk_mldev->nb_models_loaded = 0;
 	cnxk_mldev->nb_models_started = 0;
 	cnxk_mldev->nb_models_stopped = 0;
@@ -303,6 +324,9 @@  cnxk_ml_dev_close(struct rte_ml_dev *dev)
 	if (cn10k_ml_dev_close(cnxk_mldev) != 0)
 		plt_err("Failed to close CN10K ML Device");
 
+	if (cnxk_mldev->index_map)
+		rte_free(cnxk_mldev->index_map);
+
 	/* Stop and unload all models */
 	for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
 		model = dev->data->models[model_id];
@@ -312,7 +336,7 @@  cnxk_ml_dev_close(struct rte_ml_dev *dev)
 					plt_err("Could not stop model %u", model_id);
 			}
 			if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-				if (cn10k_ml_model_unload(dev, model_id) != 0)
+				if (cnxk_ml_model_unload(dev, model_id) != 0)
 					plt_err("Could not unload model %u", model_id);
 			}
 			dev->data->models[model_id] = NULL;
@@ -428,6 +452,118 @@  cnxk_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
 	return 0;
 }
 
+static int
+cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, uint16_t *model_id)
+{
+	struct rte_ml_dev_info dev_info;
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+
+	char str[RTE_MEMZONE_NAMESIZE];
+	const struct plt_memzone *mz;
+	uint64_t model_info_size;
+	uint16_t lcl_model_id;
+	uint64_t mz_size;
+	bool found;
+	int ret;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	/* Find model ID */
+	found = false;
+	for (lcl_model_id = 0; lcl_model_id < dev->data->nb_models; lcl_model_id++) {
+		if (dev->data->models[lcl_model_id] == NULL) {
+			found = true;
+			break;
+		}
+	}
+
+	if (!found) {
+		plt_err("No slots available to load new model");
+		return -ENOMEM;
+	}
+
+	/* Compute memzone size */
+	cnxk_ml_dev_info_get(dev, &dev_info);
+	mz_size = PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), dev_info.align_size);
+	model_info_size = sizeof(struct rte_ml_model_info) +
+			  ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct rte_ml_io_info) +
+			  ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct rte_ml_io_info);
+	model_info_size = PLT_ALIGN_CEIL(model_info_size, dev_info.align_size);
+	mz_size += model_info_size;
+
+	/* Allocate memzone for model object */
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", CNXK_ML_MODEL_MEMZONE_NAME, lcl_model_id);
+	mz = plt_memzone_reserve_aligned(str, mz_size, 0, dev_info.align_size);
+	if (!mz) {
+		plt_err("Failed to allocate memory for cnxk_ml_model: %s", str);
+		return -ENOMEM;
+	}
+
+	model = mz->addr;
+	model->cnxk_mldev = cnxk_mldev;
+	model->model_id = lcl_model_id;
+	model->info = PLT_PTR_ADD(
+		model, PLT_ALIGN_CEIL(sizeof(struct cnxk_ml_model), dev_info.align_size));
+	dev->data->models[lcl_model_id] = model;
+
+	ret = cn10k_ml_model_load(cnxk_mldev, params, model);
+	if (ret != 0)
+		goto error;
+
+	plt_spinlock_init(&model->lock);
+	model->state = ML_CNXK_MODEL_STATE_LOADED;
+	cnxk_mldev->nb_models_loaded++;
+
+	*model_id = lcl_model_id;
+
+	return 0;
+
+error:
+	rte_memzone_free(mz);
+
+	return ret;
+}
+
+int
+cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+
+	char str[RTE_MEMZONE_NAMESIZE];
+	int ret;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	model = dev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	if (model->state != ML_CNXK_MODEL_STATE_LOADED) {
+		plt_err("Cannot unload. Model in use.");
+		return -EBUSY;
+	}
+
+	ret = cn10k_ml_model_unload(cnxk_mldev, model);
+	if (ret != 0)
+		return ret;
+
+	dev->data->models[model_id] = NULL;
+	cnxk_mldev->nb_models_unloaded++;
+
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", CNXK_ML_MODEL_MEMZONE_NAME, model_id);
+	return plt_memzone_free(plt_memzone_lookup(str));
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cnxk_ml_dev_info_get,
@@ -451,8 +587,8 @@  struct rte_ml_dev_ops cnxk_ml_ops = {
 	.dev_xstats_reset = cn10k_ml_dev_xstats_reset,
 
 	/* Model ops */
-	.model_load = cn10k_ml_model_load,
-	.model_unload = cn10k_ml_model_unload,
+	.model_load = cnxk_ml_model_load,
+	.model_unload = cnxk_ml_model_unload,
 	.model_start = cn10k_ml_model_start,
 	.model_stop = cn10k_ml_model_stop,
 	.model_info_get = cn10k_ml_model_info_get,
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index a925c07580..bc14f6e5b9 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -62,4 +62,6 @@  struct cnxk_ml_qp {
 
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
+int cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+
 #endif /* _CNXK_ML_OPS_H_ */