[v4,27/34] ml/cnxk: update internal TVM model info structure

Message ID 20231017165951.27299-28-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Oct. 17, 2023, 4:59 p.m. UTC
  From: Prince Takkar <ptakkar@marvell.com>

Added support to update internal model info structure
for TVM models.

Signed-off-by: Prince Takkar <ptakkar@marvell.com>
Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/mvtvm_ml_model.c | 65 ++++++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_model.h |  2 +
 drivers/ml/cnxk/mvtvm_ml_ops.c   |  3 ++
 3 files changed, 70 insertions(+)
  

Patch

diff --git a/drivers/ml/cnxk/mvtvm_ml_model.c b/drivers/ml/cnxk/mvtvm_ml_model.c
index 14f4b258d8..569147aca7 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.c
+++ b/drivers/ml/cnxk/mvtvm_ml_model.c
@@ -11,6 +11,7 @@ 
 
 #include <roc_api.h>
 
+#include "cnxk_ml_dev.h"
 #include "cnxk_ml_model.h"
 
 /* Objects list */
@@ -246,3 +247,67 @@  mvtvm_ml_model_io_info_get(struct cnxk_ml_model *model, uint16_t layer_id)
 
 	return &model->mvtvm.info;
 }
+
+void
+mvtvm_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+	struct tvmdp_model_metadata *metadata;
+	struct rte_ml_model_info *info;
+	struct rte_ml_io_info *output;
+	struct rte_ml_io_info *input;
+	uint8_t i;
+
+	info = PLT_PTR_CAST(model->info);
+	input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
+	output = PLT_PTR_ADD(input, ML_CNXK_MODEL_MAX_INPUT_OUTPUT * sizeof(struct rte_ml_io_info));
+
+	/* Reset model info */
+	memset(info, 0, sizeof(struct rte_ml_model_info));
+
+	if (model->subtype == ML_CNXK_MODEL_SUBTYPE_TVM_MRVL)
+		goto tvm_mrvl_model;
+
+	metadata = &model->mvtvm.metadata;
+	rte_memcpy(info->name, metadata->model.name, TVMDP_NAME_STRLEN);
+	snprintf(info->version, RTE_ML_STR_MAX, "%u.%u.%u.%u", metadata->model.version[0],
+		 metadata->model.version[1], metadata->model.version[2],
+		 metadata->model.version[3]);
+	info->model_id = model->model_id;
+	info->device_id = cnxk_mldev->mldev->data->dev_id;
+	info->io_layout = RTE_ML_IO_LAYOUT_SPLIT;
+	info->min_batches = model->batch_size;
+	info->max_batches = model->batch_size;
+	info->nb_inputs = metadata->model.num_input;
+	info->input_info = input;
+	info->nb_outputs = metadata->model.num_output;
+	info->output_info = output;
+	info->wb_size = 0;
+
+	/* Set input info */
+	for (i = 0; i < info->nb_inputs; i++) {
+		rte_memcpy(input[i].name, metadata->input[i].name, MRVL_ML_INPUT_NAME_LEN);
+		input[i].nb_dims = metadata->input[i].ndim;
+		input[i].shape = &model->mvtvm.info.input[i].shape[0];
+		input[i].type = model->mvtvm.info.input[i].qtype;
+		input[i].nb_elements = model->mvtvm.info.input[i].nb_elements;
+		input[i].size = model->mvtvm.info.input[i].nb_elements *
+				rte_ml_io_type_size_get(model->mvtvm.info.input[i].qtype);
+	}
+
+	/* Set output info */
+	for (i = 0; i < info->nb_outputs; i++) {
+		rte_memcpy(output[i].name, metadata->output[i].name, MRVL_ML_OUTPUT_NAME_LEN);
+		output[i].nb_dims = metadata->output[i].ndim;
+		output[i].shape = &model->mvtvm.info.output[i].shape[0];
+		output[i].type = model->mvtvm.info.output[i].qtype;
+		output[i].nb_elements = model->mvtvm.info.output[i].nb_elements;
+		output[i].size = model->mvtvm.info.output[i].nb_elements *
+				 rte_ml_io_type_size_get(model->mvtvm.info.output[i].qtype);
+	}
+
+	return;
+
+tvm_mrvl_model:
+	cn10k_ml_model_info_set(cnxk_mldev, model, &model->mvtvm.info,
+				&model->layer[0].glow.metadata);
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index e86581bc6a..a1247ffbde 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -11,6 +11,7 @@ 
 
 #include "cnxk_ml_io.h"
 
+struct cnxk_ml_dev;
 struct cnxk_ml_model;
 
 /* Maximum number of objects per model */
@@ -52,5 +53,6 @@  int mvtvm_ml_model_get_layer_id(struct cnxk_ml_model *model, const char *layer_n
 				uint16_t *layer_id);
 void mvtvm_ml_model_io_info_set(struct cnxk_ml_model *model);
 struct cnxk_ml_io_info *mvtvm_ml_model_io_info_get(struct cnxk_ml_model *model, uint16_t layer_id);
+void mvtvm_ml_model_info_set(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 
 #endif /* _MVTVM_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index 1d0b3544a7..f13ba76207 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -178,6 +178,9 @@  mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
 	/* Update model I/O data */
 	mvtvm_ml_model_io_info_set(model);
 
+	/* Set model info */
+	mvtvm_ml_model_info_set(cnxk_mldev, model);
+
 	return 0;
 
 error: