[v2,33/37] ml/cnxk: add support to enable model data caching

Message ID 20221208201806.21893-34-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series Implementation of ML CNXK driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Dec. 8, 2022, 8:18 p.m. UTC
  Added device argument 'cache_model_data' to enable model data
caching. An inference request would be executed with dummy data
in synchronous mode during model start stage. This run would
cache the model weights and bias in the memory and result in
improved inference throughput.

cache_model_data = 1, enable (default)
cache_model_data = 0, disable

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_dev.c | 33 ++++++++++++++++++++--
 drivers/ml/cnxk/cn10k_ml_dev.h |  3 ++
 drivers/ml/cnxk/cn10k_ml_ops.c | 50 ++++++++++++++++++++++++++++++++++
 3 files changed, 84 insertions(+), 2 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_dev.c b/drivers/ml/cnxk/cn10k_ml_dev.c
index 0b345b3d4e..b844a42677 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.c
+++ b/drivers/ml/cnxk/cn10k_ml_dev.c
@@ -20,10 +20,12 @@ 
 #define CN10K_ML_FW_PATH		"fw_path"
 #define CN10K_ML_FW_ENABLE_DPE_WARNINGS "enable_dpe_warnings"
 #define CN10K_ML_FW_REPORT_DPE_WARNINGS "report_dpe_warnings"
+#define CN10K_ML_DEV_CACHE_MODEL_DATA	"cache_model_data"
 
 #define CN10K_ML_FW_PATH_DEFAULT		"/lib/firmware/mlip-fw.bin"
 #define CN10K_ML_FW_ENABLE_DPE_WARNINGS_DEFAULT 1
 #define CN10K_ML_FW_REPORT_DPE_WARNINGS_DEFAULT 0
+#define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT	1
 
 /* ML firmware macros */
 #define FW_MEMZONE_NAME		 "ml_cn10k_fw_mz"
@@ -38,7 +40,8 @@ 
 #define FW_REPORT_DPE_WARNING_BITMASK BIT(1)
 
 static const char *const valid_args[] = {CN10K_ML_FW_PATH, CN10K_ML_FW_ENABLE_DPE_WARNINGS,
-					 CN10K_ML_FW_REPORT_DPE_WARNINGS, NULL};
+					 CN10K_ML_FW_REPORT_DPE_WARNINGS,
+					 CN10K_ML_DEV_CACHE_MODEL_DATA, NULL};
 
 /* Dummy operations for ML device */
 struct rte_ml_dev_ops ml_dev_dummy_ops = {0};
@@ -76,6 +79,7 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 {
 	bool enable_dpe_warnings_set = false;
 	bool report_dpe_warnings_set = false;
+	bool cache_model_data_set = false;
 	struct rte_kvargs *kvlist = NULL;
 	bool fw_path_set = false;
 	char *fw_path = NULL;
@@ -124,6 +128,18 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 		report_dpe_warnings_set = true;
 	}
 
+	if (rte_kvargs_count(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA) == 1) {
+		ret = rte_kvargs_process(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA, &parse_integer_arg,
+					 &mldev->cache_model_data);
+		if (ret < 0) {
+			plt_err("Error processing arguments, key = %s\n",
+				CN10K_ML_DEV_CACHE_MODEL_DATA);
+			ret = -EINVAL;
+			goto exit;
+		}
+		cache_model_data_set = true;
+	}
+
 check_args:
 	if (!fw_path_set)
 		mldev->fw.path = CN10K_ML_FW_PATH_DEFAULT;
@@ -155,6 +171,18 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 	}
 	plt_info("ML: %s = %d", CN10K_ML_FW_REPORT_DPE_WARNINGS, mldev->fw.report_dpe_warnings);
 
+	if (!cache_model_data_set) {
+		mldev->cache_model_data = CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT;
+	} else {
+		if ((mldev->cache_model_data < 0) || (mldev->cache_model_data > 1)) {
+			plt_err("Invalid argument, %s = %d\n", CN10K_ML_DEV_CACHE_MODEL_DATA,
+				mldev->cache_model_data);
+			ret = -EINVAL;
+			goto exit;
+		}
+	}
+	plt_info("ML: %s = %d", CN10K_ML_DEV_CACHE_MODEL_DATA, mldev->cache_model_data);
+
 exit:
 	if (kvlist)
 		rte_kvargs_free(kvlist);
@@ -694,4 +722,5 @@  RTE_PMD_REGISTER_KMOD_DEP(MLDEV_NAME_CN10K_PMD, "vfio-pci");
 
 RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD,
 			      CN10K_ML_FW_PATH "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
-					       "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS "=<0|1>");
+					       "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
+					       "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA "=<0|1>");
diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h
index 52c8bd1af7..24e4823196 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.h
+++ b/drivers/ml/cnxk/cn10k_ml_dev.h
@@ -381,6 +381,9 @@  struct cn10k_ml_dev {
 
 	/* xstats status */
 	bool xstats_enabled;
+
+	/* Enable / disable model data caching */
+	int cache_model_data;
 };
 
 uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw);
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 5d29a55e66..a44d77df76 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -488,6 +488,49 @@  cn10k_ml_model_xstat_reset(struct rte_ml_dev *dev, uint16_t model_id,
 	}
 }
 
+static int
+cn10k_ml_cache_model_data(struct rte_ml_dev *dev, int16_t model_id)
+{
+	struct cn10k_ml_model *model;
+	struct rte_ml_op op;
+
+	char str[RTE_MEMZONE_NAMESIZE];
+	const struct plt_memzone *mz;
+	uint64_t isize = 0;
+	uint64_t osize = 0;
+	int ret = 0;
+
+	model = dev->data->models[model_id];
+
+	/* Create input and output buffers. */
+	rte_ml_io_input_size_get(dev->data->dev_id, model_id, model->batch_size, &isize, NULL);
+	rte_ml_io_output_size_get(dev->data->dev_id, model_id, model->batch_size, &osize, NULL);
+
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%d", "ml_dummy_io", model_id);
+	mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE);
+	if (mz == NULL)
+		return -ENOMEM;
+	memset(mz->addr, 0, isize + osize);
+
+	op.model_id = model_id;
+	op.nb_batches = model->batch_size;
+	op.mempool = NULL;
+
+	op.input.addr = mz->addr;
+	op.input.length = isize;
+	op.input.next = NULL;
+
+	op.output.addr = PLT_PTR_ADD(op.input.addr, isize);
+	op.output.length = osize;
+	op.output.next = NULL;
+
+	memset(model->req, 0, sizeof(struct cn10k_ml_req));
+	ret = cn10k_ml_inference_sync(dev, &op);
+	plt_memzone_free(mz);
+
+	return ret;
+}
+
 static int
 cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
 {
@@ -1471,6 +1514,13 @@  cn10k_ml_model_start(struct rte_ml_dev *dev, int16_t model_id)
 		}
 	}
 
+	if (ret < 0) { /* Call unload to update model and FW state, ignore error */
+		rte_ml_model_stop(dev->data->dev_id, model_id);
+	} else {
+		if (mldev->cache_model_data && roc_model_is_cn10ka())
+			ret = cn10k_ml_cache_model_data(dev, model_id);
+	}
+
 	return ret;
 }