[v5,25/34] ml/cnxk: enable OCM check for multilayer TVM model

Message ID 20231018064806.24145-26-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Series Implementation of revised ml/cnxk driver

Checks

Context         Check      Description
ci/checkpatch   success    coding style OK

Commit Message

Srikanth Yalavarthi Oct. 18, 2023, 6:47 a.m. UTC
  From: Anup Prabhu <aprabhu@marvell.com>

Enabled the OCM size requirement check for multi-layer TVM
models. OCM scratch and WB page requirements are computed
across all layers during the load stage.

Signed-off-by: Anup Prabhu <aprabhu@marvell.com>
---
 drivers/ml/cnxk/cnxk_ml_ops.c | 60 +++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
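
For reference, a minimal standalone sketch (not part of the patch) of the
aggregation this check performs at load time: WB pages accumulate across
layers while scratch pages are tracked as a per-layer maximum, and the sum
is compared against the OCM page budget. struct layer_info, ocm_check()
and the values in main() are hypothetical names used only for illustration.

#include <errno.h>
#include <stdint.h>
#include <stdio.h>

struct layer_info {
	uint16_t wb_pages;      /* OCM write-back pages needed by this layer */
	uint16_t scratch_pages; /* OCM scratch pages needed by this layer */
};

/* Return 0 when the model fits in OCM, -ENOMEM otherwise. */
static int
ocm_check(const struct layer_info *layers, uint16_t nb_layers, uint16_t ocm_num_pages)
{
	uint16_t total_wb_pages = 0;
	uint16_t max_scratch_pages = 0;
	uint16_t i;

	for (i = 0; i < nb_layers; i++) {
		/* WB pages stay resident for every layer, so they accumulate. */
		total_wb_pages += layers[i].wb_pages;
		/* Scratch requirement is tracked as the per-layer maximum. */
		if (layers[i].scratch_pages > max_scratch_pages)
			max_scratch_pages = layers[i].scratch_pages;
	}

	if ((uint32_t)total_wb_pages + max_scratch_pages > ocm_num_pages) {
		fprintf(stderr, "total_wb_pages (%u) + scratch_pages (%u) > %u\n",
			total_wb_pages, max_scratch_pages, ocm_num_pages);
		return -ENOMEM;
	}

	return 0;
}

int
main(void)
{
	/* Two hypothetical layers against a 16-page OCM budget:
	 * 6 + 4 WB pages plus max(3, 5) scratch pages = 15 pages, which fits. */
	struct layer_info layers[] = {{6, 3}, {4, 5}};

	return ocm_check(layers, 2, 16) == 0 ? 0 : 1;
}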

Patch

diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index cd95a3c7ad..03f4783b3f 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -1023,8 +1023,12 @@  cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
 
 	char str[RTE_MEMZONE_NAMESIZE];
 	const struct plt_memzone *mz;
+	uint16_t max_scratch_pages;
+	struct cn10k_ml_ocm *ocm;
 	uint64_t model_info_size;
+	uint16_t total_wb_pages;
 	uint16_t lcl_model_id;
+	uint16_t layer_id;
 	uint64_t mz_size;
 	bool found;
 	int ret;
@@ -1086,6 +1090,62 @@  cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
 	if (ret != 0)
 		goto error;
 
+	max_scratch_pages = 0;
+	total_wb_pages = 0;
+	layer_id = 0;
+
+	ocm = &cnxk_mldev->cn10k_mldev.ocm;
+
+	if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+		total_wb_pages = total_wb_pages + model->layer[layer_id].glow.ocm_map.wb_pages;
+		max_scratch_pages = PLT_MAX(max_scratch_pages,
+					    model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+	} else {
+		for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+			if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) {
+				total_wb_pages = total_wb_pages +
+						 model->layer[layer_id].glow.ocm_map.wb_pages;
+				max_scratch_pages =
+					PLT_MAX(max_scratch_pages,
+						model->layer[layer_id].glow.ocm_map.scratch_pages);
+			}
+		}
+#endif
+	}
+
+	if ((total_wb_pages + max_scratch_pages) > ocm->num_pages) {
+		plt_err("model_id = %u: total_wb_pages (%u) + scratch_pages (%u) >  %u\n",
+			lcl_model_id, total_wb_pages, max_scratch_pages, ocm->num_pages);
+
+		if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+			plt_ml_dbg("layer_id = %u: wb_pages = %u, scratch_pages = %u\n", layer_id,
+				   model->layer[layer_id].glow.ocm_map.wb_pages,
+				   model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+		} else {
+			for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers;
+			     layer_id++) {
+				if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) {
+					plt_ml_dbg(
+						"layer_id = %u: wb_pages = %u, scratch_pages = %u\n",
+						layer_id,
+						model->layer[layer_id].glow.ocm_map.wb_pages,
+						model->layer[layer_id].glow.ocm_map.scratch_pages);
+				}
+			}
+#endif
+		}
+
+		if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+			cn10k_ml_model_unload(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+		else
+			mvtvm_ml_model_unload(cnxk_mldev, model);
+#endif
+
+		return -ENOMEM;
+	}
 	plt_spinlock_init(&model->lock);
 	model->state = ML_CNXK_MODEL_STATE_LOADED;
 	cnxk_mldev->nb_models_loaded++;
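
For context, a hedged usage sketch of how an application would observe this
check through the public mldev API, assuming the driver's -ENOMEM is returned
unchanged by rte_ml_model_load(); load_model(), model_buf and model_len are
placeholder names, and device/model setup is omitted.

#include <errno.h>
#include <stdio.h>

#include <rte_mldev.h>

static int
load_model(int16_t dev_id, void *model_buf, size_t model_len)
{
	struct rte_ml_model_params params = {.addr = model_buf, .size = model_len};
	uint16_t model_id;
	int ret;

	ret = rte_ml_model_load(dev_id, &params, &model_id);
	if (ret == -ENOMEM) {
		/* Aggregate WB pages plus max scratch pages exceed the device OCM. */
		fprintf(stderr, "model does not fit in OCM\n");
	}

	return ret;
}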