@@ -508,8 +508,10 @@ cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
int qp_id;
int ret;
- PLT_SET_USED(size);
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
PLT_SET_USED(layer_name);
+#endif
+ PLT_SET_USED(size);
cnxk_mldev = (struct cnxk_ml_dev *)device;
if (cnxk_mldev == NULL) {
@@ -523,6 +525,24 @@ cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uin
return -EINVAL;
}
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+ if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+ break;
+ }
+
+ if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+ plt_err("Invalid layer name: %s", layer_name);
+ return -EINVAL;
+ }
+
+ if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+ plt_err("Invalid layer name / type: %s", layer_name);
+ return -EINVAL;
+ }
+ }
+#endif
layer = &model->layer[layer_id];
ret = cn10k_ml_model_metadata_check(buffer, size);
@@ -11,6 +11,8 @@
#include "cnxk_ml_io.h"
+struct cnxk_ml_model;
+
/* Maximum number of objects per model */
#define ML_MVTVM_MODEL_OBJECT_MAX 3
@@ -9,6 +9,8 @@
#include <rte_mldev.h>
#include <rte_mldev_pmd.h>
+#include "cn10k_ml_ops.h"
+
#include "mvtvm_ml_model.h"
#include "mvtvm_ml_ops.h"
@@ -53,9 +55,13 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
struct cnxk_ml_model *model)
{
struct mvtvm_ml_model_object object[ML_MVTVM_MODEL_OBJECT_MAX];
+ struct tvmrt_glow_callback *callback;
char str[RTE_MEMZONE_NAMESIZE];
const struct plt_memzone *mz;
size_t model_object_size = 0;
+ uint16_t nb_mrvl_layers;
+ uint16_t nb_llvm_layers;
+ uint8_t layer_id = 0;
uint64_t mz_size = 0;
int ret;
@@ -103,5 +109,82 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *
rte_memcpy(model->mvtvm.object.params.addr, object[2].buffer, object[2].size);
rte_free(object[2].buffer);
+ /* Get metadata - stage 1 */
+ ret = tvmdp_model_metadata_get_stage1(model->mvtvm.object.json.addr,
+ model->mvtvm.object.json.size,
+ &model->mvtvm.metadata);
+ if (ret != 0) {
+ plt_err("TVMDP: Failed to parse metadata - stage 1, model_id = %u, error = %d",
+ model->model_id, ret);
+ goto error;
+ }
+
+ /* Set model fields */
+ plt_strlcpy(model->name, model->mvtvm.metadata.model.name, TVMDP_NAME_STRLEN);
+ model->batch_size = 1;
+ model->nb_layers = model->mvtvm.metadata.model.nb_layers;
+
+ /* Update layer info */
+ nb_mrvl_layers = 0;
+ nb_llvm_layers = 0;
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+ strncpy(model->layer[layer_id].name,
+ model->mvtvm.metadata.model.layer[layer_id].name, TVMDP_NAME_STRLEN);
+ if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "mrvl") == 0 ||
+ strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "MRVL") == 0) {
+ model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_MRVL;
+ nb_mrvl_layers++;
+ } else if (strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "llvm") == 0 ||
+ strcmp(model->mvtvm.metadata.model.layer[layer_id].type, "LLVM") == 0) {
+ model->layer[layer_id].type = ML_CNXK_LAYER_TYPE_LLVM;
+ nb_llvm_layers++;
+ }
+ }
+
+ if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 0)) {
+ plt_err("Invalid model, nb_llvm_layers = %u, nb_mrvl_layers = %u", nb_llvm_layers,
+ nb_mrvl_layers);
+ goto error;
+ }
+
+ /* Set model subtype */
+ if ((nb_llvm_layers == 0) && (nb_mrvl_layers == 1))
+ model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_MRVL;
+ else if ((nb_llvm_layers > 0) && (nb_mrvl_layers == 0))
+ model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_LLVM;
+ else
+ model->subtype = ML_CNXK_MODEL_SUBTYPE_TVM_HYBRID;
+
+ /* Set callback function array */
+ if (model->subtype != ML_CNXK_MODEL_SUBTYPE_TVM_LLVM) {
+ callback = &model->mvtvm.cb;
+ callback->tvmrt_glow_layer_load = cn10k_ml_layer_load;
+ callback->tvmrt_glow_layer_unload = cn10k_ml_layer_unload;
+ } else {
+ callback = NULL;
+ }
+
+ /* Initialize model in TVMDP */
+ ret = tvmdp_model_load(cnxk_mldev, model->model_id, (void *)(&model->mvtvm.object),
+ callback);
+ if (ret != 0) {
+ plt_err("TVMDP: Model load failed, model_id = %u, error = %d", model->model_id,
+ ret);
+ goto error;
+ }
+
+ /* Get model metadata - stage 2 */
+ ret = tvmdp_model_metadata_get_stage2(model->model_id, &model->mvtvm.metadata);
+ if (ret != 0) {
+ plt_err("TVMDP: Failed to get metadata, model_id = %u, error = %d\n",
+ model->model_id, ret);
+ goto error;
+ }
+
return 0;
+
+error:
+ rte_memzone_free(mz);
+
+ return ret;
}