[v2,23/34] ml/cnxk: fetch layer info and load TVM model

Message ID 20230920072528.14185-24-syalavarthi@marvell.com (mailing list archive)
State Changes Requested, archived
Delegated to: Jerin Jacob
Headers
Series Implemenation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Sept. 20, 2023, 7:25 a.m. UTC
  Added support to fetch TVM model layer information and
update internal structures based on the layer information
Set callback functions for layer load and unload and
enable model loading using TVMDP library. Added support
to fetch full metadata after model load.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c   | 22 ++++++++-
 drivers/ml/cnxk/mvtvm_ml_model.h |  2 +
 drivers/ml/cnxk/mvtvm_ml_ops.c   | 83 ++++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 1 deletion(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index db18f32052..79217165cd 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -508,8 +508,10 @@  cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
 	int qp_id;
 	int ret;
 
-	PLT_SET_USED(size);
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
 	PLT_SET_USED(layer_name);
+#endif
+	PLT_SET_USED(size);
 
 	cnxk_mldev = (struct cnxk_ml_dev *)device;
 	if (cnxk_mldev == NULL) {
@@ -523,6 +525,24 @@  cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
 		return -EINVAL;
 	}
 
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+	if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+		for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+			if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+				break;
+		}
+
+		if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+			plt_err("Invalid layer name: %s", layer_name);
+			return -EINVAL;
+		}
+
+		if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+			plt_err("Invalid layer name / type: %s", layer_name);
+			return -EINVAL;
+		}
+	}
+#endif
 	layer = &model->layer[layer_id];
 
 	ret = cn10k_ml_model_metadata_check(buffer, size);
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index 73a45a91d6..6c38217c15 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -11,6 +11,8 @@ 
 
 #include "cnxk_ml_io.h"
 
+struct cnxk_ml_model;
+
 /* Maximum number of objects per model */
 #define ML_MVTVM_MODEL_OBJECT_MAX 3
 
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index baa9099084..d9ec411385 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -7,6 +7,8 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "cn10k_ml_ops.h"
+
 #include "mvtvm_ml_model.h"
 #include "mvtvm_ml_ops.h"
 
@@ -51,9 +53,13 @@  mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
 		    struct cnxk_ml_model *model)
 {
 	struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+	struct tvmrt_glow_callback *callback;
 	char str[RTE_MEMZONE_NAMESIZE];
 	const struct plt_memzone *mz;
 	size_t model_object_size = 0;
+	uint16_t nb_mrvl_layers;
+	uint16_t nb_llvm_layers;
+	uint8_t layer_id = 0;
 	uint64_t mz_size = 0;
 	int ret;
 
@@ -101,5 +107,82 @@  mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
 	rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, object[2].size);
 	rte_free(object[2].buffer);
 
+	/* Get metadata - stage 1 */
+	ret = tvmdp_model_metadata_get_stage1(model->mvtvm.object.json.addr,
+					      model->mvtvm.object.json.size,
+					      &model->mvtvm.metadata);
+	if (ret != 0) {
+		plt_err("TVMDP: Failed to parse metadata - stage 1, model_id = %u, error = %d",
+			model->model_id, ret);
+		goto error;
+	}
+
+	/* Set model fields */
+	plt_strlcpy(model->name, model->mvtvm.metadata.model.name, TVMDP_NAME_STRLEN);
+	model->batch_size = 1;
+	model->nb_layers = model->mvtvm.metadata.model.nb_layers;
+
+	/* Update layer info */
+	nb_mrvl_layers = 0;
+	nb_llvm_layers = 0;
+	for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+		strncpy(model->layer[layer_id].name,
+			model->mvtvm.metadata.model.layer[layer_id].name, TVMDP_NAME_STRLEN);
+		if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "mrvl") == 0 ||
+		    strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "MRVL") == 0) {
+			model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_MRVL;
+			nb_mrvl_layers++;
+		} else if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "llvm") == 0 ||
+			   strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "LLVM") == 0) {
+			model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_LLVM;
+			nb_llvm_layers++;
+		}
+	}
+
+	if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 0)) {
+		plt_err("Invalid model, nb_llvm_layers = %u, nb_mrvl_layers = %u", nb_llvm_layers,
+			nb_mrvl_layers);
+		goto error;
+	}
+
+	/* Set model subtype */
+	if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 1))
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_MRVL;
+	else if ((nb_llvm_layers > 0) && (nb_mrvl_layers == 0))
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_LLVM;
+	else
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_HYBRID;
+
+	/* Set callback function array */
+	if (model->subtype != ML_CNXK_MODEL_SUBTYPE_TVM_LLVM) {
+		callback = &model->mvtvm.cb;
+		callback->tvmrt_glow_layer_load = cn10k_ml_layer_load;
+		callback->tvmrt_glow_layer_unload = cn10k_ml_layer_unload;
+	} else {
+		callback = NULL;
+	}
+
+	/* Initialize model in TVMDP */
+	ret = tvmdp_model_load(cnxk_mldev, model->model_id, (void *)(&model->mvtvm.object),
+			       callback);
+	if (ret != 0) {
+		plt_err("TVMDP: Model load failed, model_id = %u, error = %d", model->model_id,
+			ret);
+		goto error;
+	}
+
+	/* Get model metadata - stage 2 */
+	ret = tvmdp_model_metadata_get_stage2(model->model_id, &model->mvtvm.metadata);
+	if (ret != 0) {
+		plt_err("TVMDP: Failed to get metadata, model_id = %u, error = %d\n",
+			model->model_id, ret);
+		goto error;
+	}
+
 	return 0;
+
+error:
+	rte_memzone_free(mz);
+
+	return ret;
 }