@@ -798,7 +798,9 @@ cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name)
bool locked;
int ret = 0;
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
PLT_SET_USED(layer_name);
+#endif
cnxk_mldev = (struct cnxk_ml_dev *)device;
if (cnxk_mldev == NULL) {
@@ -812,6 +814,25 @@ cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name)
return -EINVAL;
}
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+ if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+ break;
+ }
+
+ if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+ plt_err("Invalid layer name: %s", layer_name);
+ return -EINVAL;
+ }
+
+ if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+ plt_err("Invalid layer name / type: %s", layer_name);
+ return -EINVAL;
+ }
+ }
+#endif
+
layer = &model->layer[layer_id];
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
ocm = &cn10k_mldev->ocm;
@@ -981,7 +1002,9 @@ cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name)
bool locked;
int ret = 0;
+#ifndef RTE_MLDEV_CNXK_ENABLE_MVTVM
PLT_SET_USED(layer_name);
+#endif
cnxk_mldev = (struct cnxk_ml_dev *)device;
if (cnxk_mldev == NULL) {
@@ -995,6 +1018,25 @@ cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name)
return -EINVAL;
}
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ if (model->type == ML_CNXK_MODEL_TYPE_TVM) {
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+ if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+ break;
+ }
+
+ if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+ plt_err("Invalid layer name: %s", layer_name);
+ return -EINVAL;
+ }
+
+ if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+ plt_err("Invalid layer name / type: %s", layer_name);
+ return -EINVAL;
+ }
+ }
+#endif
+
layer = &model->layer[layer_id];
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
ocm = &cn10k_mldev->ocm;
@@ -1233,7 +1233,14 @@ cnxk_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
return -EINVAL;
}
- return cn10k_ml_model_start(cnxk_mldev, model);
+ if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+ return cn10k_ml_model_start(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ else
+ return mvtvm_ml_model_start(cnxk_mldev, model);
+#endif
+
+ return 0;
}
int
@@ -1253,7 +1260,14 @@ cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
return -EINVAL;
}
- return cn10k_ml_model_stop(cnxk_mldev, model);
+ if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+ return cn10k_ml_model_stop(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ else
+ return mvtvm_ml_model_stop(cnxk_mldev, model);
+#endif
+
+ return 0;
}
static int
@@ -219,3 +219,55 @@ mvtvm_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *mode
return plt_memzone_free(mz);
}
+
+int
+mvtvm_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+ struct cnxk_ml_layer *layer;
+
+ uint16_t layer_id = 0;
+ int ret = 0;
+
+next_layer:
+ layer = &model->layer[layer_id];
+ if (layer->type == ML_CNXK_LAYER_TYPE_MRVL) {
+ ret = cn10k_ml_layer_start(cnxk_mldev, model->model_id, layer->name);
+ if (ret != 0) {
+ plt_err("Layer start failed, model_id = %u, layer_name = %s, error = %d",
+ model->model_id, layer->name, ret);
+ return ret;
+ }
+ }
+ layer_id++;
+
+ if (layer_id < model->nb_layers)
+ goto next_layer;
+
+ return 0;
+}
+
+int
+mvtvm_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+ struct cnxk_ml_layer *layer;
+
+ uint16_t layer_id = 0;
+ int ret = 0;
+
+next_layer:
+ layer = &model->layer[layer_id];
+ if (layer->type == ML_CNXK_LAYER_TYPE_MRVL) {
+ ret = cn10k_ml_layer_stop(cnxk_mldev, model->model_id, layer->name);
+ if (ret != 0) {
+ plt_err("Layer stop failed, model_id = %u, layer_name = %s, error = %d",
+ model->model_id, layer->name, ret);
+ return ret;
+ }
+ }
+ layer_id++;
+
+ if (layer_id < model->nb_layers)
+ goto next_layer;
+
+ return 0;
+}
@@ -15,5 +15,7 @@ int mvtvm_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
struct cnxk_ml_model *model);
int mvtvm_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
+int mvtvm_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
+int mvtvm_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
#endif /* _MVTVM_ML_OPS_H_ */