[v5,11/34] ml/cnxk: update model utility functions

Message ID 20231018064806.24145-12-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Oct. 18, 2023, 6:47 a.m. UTC
  Added cnxk wrapper function to update model params and
fetch model info.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 38 ++++++---------------------
 drivers/ml/cnxk/cn10k_ml_ops.h |  5 ++--
 drivers/ml/cnxk/cnxk_ml_ops.c  | 48 ++++++++++++++++++++++++++++++++--
 3 files changed, 56 insertions(+), 35 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index c677861645..c0d6216485 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -1835,45 +1835,23 @@  cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
 }
 
 int
-cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
-			struct rte_ml_model_info *model_info)
+cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+			     void *buffer)
 {
-	struct cnxk_ml_model *model;
-
-	model = dev->data->models[model_id];
-
-	if (model == NULL) {
-		plt_err("Invalid model_id = %u", model_id);
-		return -EINVAL;
-	}
-
-	rte_memcpy(model_info, model->info, sizeof(struct rte_ml_model_info));
-	model_info->input_info = ((struct rte_ml_model_info *)model->info)->input_info;
-	model_info->output_info = ((struct rte_ml_model_info *)model->info)->output_info;
-
-	return 0;
-}
-
-int
-cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
-{
-	struct cnxk_ml_model *model;
-
-	model = dev->data->models[model_id];
+	struct cnxk_ml_layer *layer;
 
-	if (model == NULL) {
-		plt_err("Invalid model_id = %u", model_id);
-		return -EINVAL;
-	}
+	RTE_SET_USED(cnxk_mldev);
 
 	if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN)
 		return -1;
 	else if (model->state != ML_CNXK_MODEL_STATE_LOADED)
 		return -EBUSY;
 
+	layer = &model->layer[0];
+
 	/* Update model weights & bias */
-	rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer,
-		   model->layer[0].glow.metadata.weights_bias.file_size);
+	rte_memcpy(layer->glow.addr.wb_load_addr, buffer,
+		   layer->glow.metadata.weights_bias.file_size);
 
 	return 0;
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index a222a43d55..ef12069f0d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -317,9 +317,8 @@  int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_para
 int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
-int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
-			    struct rte_ml_model_info *model_info);
-int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer);
+int cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+				 void *buffer);
 
 /* I/O ops */
 int cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id,
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index b61ed45876..9ce37fcfd1 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -604,6 +604,50 @@  cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
 	return cn10k_ml_model_stop(cnxk_mldev, model);
 }
 
+static int
+cnxk_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
+		       struct rte_ml_model_info *model_info)
+{
+	struct rte_ml_model_info *info;
+	struct cnxk_ml_model *model;
+
+	if ((dev == NULL) || (model_info == NULL))
+		return -EINVAL;
+
+	model = dev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	info = (struct rte_ml_model_info *)model->info;
+	rte_memcpy(model_info, info, sizeof(struct rte_ml_model_info));
+	model_info->input_info = info->input_info;
+	model_info->output_info = info->output_info;
+
+	return 0;
+}
+
+static int
+cnxk_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+
+	if ((dev == NULL) || (buffer == NULL))
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	model = dev->data->models[model_id];
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	return cn10k_ml_model_params_update(cnxk_mldev, model, buffer);
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cnxk_ml_dev_info_get,
@@ -631,8 +675,8 @@  struct rte_ml_dev_ops cnxk_ml_ops = {
 	.model_unload = cnxk_ml_model_unload,
 	.model_start = cnxk_ml_model_start,
 	.model_stop = cnxk_ml_model_stop,
-	.model_info_get = cn10k_ml_model_info_get,
-	.model_params_update = cn10k_ml_model_params_update,
+	.model_info_get = cnxk_ml_model_info_get,
+	.model_params_update = cnxk_ml_model_params_update,
 
 	/* I/O ops */
 	.io_quantize = cn10k_ml_io_quantize,