[v1,23/34] ml/cnxk: fetch layer info and load TVM model

Message ID 20230830155927.3566-24-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series Implemenation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Aug. 30, 2023, 3:59 p.m. UTC
  Added support to fetch TVM model layer information and
update internal structures based on the layer information
Set callback functions for layer load and unload and
enable model loading using TVMDP library. Added support
to fetch full metadata after model load.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c   | 22 ++++++++-
 drivers/ml/cnxk/mvtvm_ml_model.h |  2 +
 drivers/ml/cnxk/mvtvm_ml_ops.c   | 83 ++++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 1 deletion(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index db18f320527..79217165cd5 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -508,8 +508,10 @@  cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
 	int qp_id;
 	int ret;
 
-	PLT_SET_USED(size);
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
 	PLT_SET_USED(layer_name);
+#endif
+	PLT_SET_USED(size);
 
 	cnxk_mldev = (struct cnxk_ml_dev *)device;
 	if (cnxk_mldev == NULL) {
@@ -523,6 +525,24 @@  cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
 		return -EINVAL;
 	}
 
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+	if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+		for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+			if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+				break;
+		}
+
+		if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+			plt_err("Invalid layer name: %s", layer_name);
+			return -EINVAL;
+		}
+
+		if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+			plt_err("Invalid layer name / type: %s", layer_name);
+			return -EINVAL;
+		}
+	}
+#endif
 	layer = &model->layer[layer_id];
 
 	ret = cn10k_ml_model_metadata_check(buffer, size);
diff --git a/drivers/ml/cnxk/mvtvm_ml_model.h b/drivers/ml/cnxk/mvtvm_ml_model.h
index 73a45a91d66..6c38217c158 100644
--- a/drivers/ml/cnxk/mvtvm_ml_model.h
+++ b/drivers/ml/cnxk/mvtvm_ml_model.h
@@ -11,6 +11,8 @@ 
 
 #include "cnxk_ml_io.h"
 
+struct cnxk_ml_model;
+
 /* Maximum number of objects per model */
 #define ML_MVTVM_MODEL_OBJECT_MAX 3
 
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index 1bdd4515771..5c30bbf6b89 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -9,6 +9,8 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "cn10k_ml_ops.h"
+
 #include "mvtvm_ml_model.h"
 #include "mvtvm_ml_ops.h"
 
@@ -53,9 +55,13 @@  mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
 		    struct cnxk_ml_model *model)
 {
 	struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+	struct tvmrt_glow_callback *callback;
 	char str[RTE_MEMZONE_NAMESIZE];
 	const struct plt_memzone *mz;
 	size_t model_object_size = 0;
+	uint16_t nb_mrvl_layers;
+	uint16_t nb_llvm_layers;
+	uint8_t layer_id = 0;
 	uint64_t mz_size = 0;
 	int ret;
 
@@ -103,5 +109,82 @@  mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
 	rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, object[2].size);
 	rte_free(object[2].buffer);
 
+	/* Get metadata - stage 1 */
+	ret = tvmdp_model_metadata_get_stage1(model->mvtvm.object.json.addr,
+					      model->mvtvm.object.json.size,
+					      &model->mvtvm.metadata);
+	if (ret != 0) {
+		plt_err("TVMDP: Failed to parse metadata - stage 1, model_id = %u, error = %d",
+			model->model_id, ret);
+		goto error;
+	}
+
+	/* Set model fields */
+	plt_strlcpy(model->name, model->mvtvm.metadata.model.name, TVMDP_NAME_STRLEN);
+	model->batch_size = 1;
+	model->nb_layers = model->mvtvm.metadata.model.nb_layers;
+
+	/* Update layer info */
+	nb_mrvl_layers = 0;
+	nb_llvm_layers = 0;
+	for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+		strncpy(model->layer[layer_id].name,
+			model->mvtvm.metadata.model.layer[layer_id].name, TVMDP_NAME_STRLEN);
+		if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "mrvl") == 0 ||
+		    strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "MRVL") == 0) {
+			model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_MRVL;
+			nb_mrvl_layers++;
+		} else if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "llvm") == 0 ||
+			   strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "LLVM") == 0) {
+			model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_LLVM;
+			nb_llvm_layers++;
+		}
+	}
+
+	if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 0)) {
+		plt_err("Invalid model, nb_llvm_layers = %u, nb_mrvl_layers = %u", nb_llvm_layers,
+			nb_mrvl_layers);
+		goto error;
+	}
+
+	/* Set model subtype */
+	if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 1))
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_MRVL;
+	else if ((nb_llvm_layers > 0) && (nb_mrvl_layers == 0))
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_LLVM;
+	else
+		model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_HYBRID;
+
+	/* Set callback function array */
+	if (model->subtype != ML_CNXK_MODEL_SUBTYPE_TVM_LLVM) {
+		callback = &model->mvtvm.cb;
+		callback->tvmrt_glow_layer_load = cn10k_ml_layer_load;
+		callback->tvmrt_glow_layer_unload = cn10k_ml_layer_unload;
+	} else {
+		callback = NULL;
+	}
+
+	/* Initialize model in TVMDP */
+	ret = tvmdp_model_load(cnxk_mldev, model->model_id, (void *)(&model->mvtvm.object),
+			       callback);
+	if (ret != 0) {
+		plt_err("TVMDP: Model load failed, model_id = %u, error = %d", model->model_id,
+			ret);
+		goto error;
+	}
+
+	/* Get model metadata - stage 2 */
+	ret = tvmdp_model_metadata_get_stage2(model->model_id, &model->mvtvm.metadata);
+	if (ret != 0) {
+		plt_err("TVMDP: Failed to get metadata, model_id = %u, error = %d\n",
+			model->model_id, ret);
+		goto error;
+	}
+
 	return 0;
+
+error:
+	rte_memzone_free(mz);
+
+	return ret;
 }