[v2,36/37] ml/cnxk: add support to select poll memory region

Message ID: 20221208201806.21893-37-syalavarthi@marvell.com (mailing list archive)
State: Superseded, archived
Delegated to: Thomas Monjalon
Series: Implementation of ML CNXK driver

Checks

Context        Check    Description
ci/checkpatch  success  coding style OK

Commit Message

Srikanth Yalavarthi Dec. 8, 2022, 8:18 p.m. UTC
Added device argument "poll_mem" to select the memory
region to be used for polling in fast-path requests.

Implemented support to use scratch registers for polling.
The available pool of scratch registers is one-to-one
mapped with the internal request queue.

poll_mem:
ddr:      Use DDR memory location for polling (default)
register: Use scratch registers for polling
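
An illustrative way to pass the new devarg on the EAL command
line (the PCI address below is hypothetical):

    <dpdk-application> -a 0000:00:10.0,poll_mem=register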

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_dev.c |  47 +++++++++++--
 drivers/ml/cnxk/cn10k_ml_dev.h |  24 +++++++
 drivers/ml/cnxk/cn10k_ml_ops.c | 124 +++++++++++++++++++++++++++++++--
 drivers/ml/cnxk/cn10k_ml_ops.h |   9 +++
 4 files changed, 192 insertions(+), 12 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_dev.c b/drivers/ml/cnxk/cn10k_ml_dev.c
index 33709dae6f..153a0bdf4c 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.c
+++ b/drivers/ml/cnxk/cn10k_ml_dev.c
@@ -23,6 +23,7 @@ 
 #define CN10K_ML_DEV_CACHE_MODEL_DATA	"cache_model_data"
 #define CN10K_ML_OCM_ALLOC_MODE		"ocm_alloc_mode"
 #define CN10K_ML_DEV_HW_QUEUE_LOCK	"hw_queue_lock"
+#define CN10K_ML_FW_POLL_MEM		"poll_mem"
 
 #define CN10K_ML_FW_PATH_DEFAULT		"/lib/firmware/mlip-fw.bin"
 #define CN10K_ML_FW_ENABLE_DPE_WARNINGS_DEFAULT 1
@@ -30,6 +31,7 @@ 
 #define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT	1
 #define CN10K_ML_OCM_ALLOC_MODE_DEFAULT		"lowest"
 #define CN10K_ML_DEV_HW_QUEUE_LOCK_DEFAULT	1
+#define CN10K_ML_FW_POLL_MEM_DEFAULT		"ddr"
 
 /* ML firmware macros */
 #define FW_MEMZONE_NAME		 "ml_cn10k_fw_mz"
@@ -42,6 +44,7 @@ 
 /* Firmware flags */
 #define FW_ENABLE_DPE_WARNING_BITMASK BIT(0)
 #define FW_REPORT_DPE_WARNING_BITMASK BIT(1)
+#define FW_USE_DDR_POLL_ADDR_FP	      BIT(2)
 
 static const char *const valid_args[] = {CN10K_ML_FW_PATH,
 					 CN10K_ML_FW_ENABLE_DPE_WARNINGS,
@@ -49,6 +52,7 @@  static const char *const valid_args[] = {CN10K_ML_FW_PATH,
 					 CN10K_ML_DEV_CACHE_MODEL_DATA,
 					 CN10K_ML_OCM_ALLOC_MODE,
 					 CN10K_ML_DEV_HW_QUEUE_LOCK,
+					 CN10K_ML_FW_POLL_MEM,
 					 NULL};
 
 /* Dummy operations for ML device */
@@ -92,7 +96,9 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 	bool ocm_alloc_mode_set = false;
 	bool hw_queue_lock_set = false;
 	char *ocm_alloc_mode = NULL;
+	bool poll_mem_set = false;
 	bool fw_path_set = false;
+	char *poll_mem = NULL;
 	char *fw_path = NULL;
 	int ret = 0;
 
@@ -174,6 +180,17 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 		hw_queue_lock_set = true;
 	}
 
+	if (rte_kvargs_count(kvlist, CN10K_ML_FW_POLL_MEM) == 1) {
+		ret = rte_kvargs_process(kvlist, CN10K_ML_FW_POLL_MEM, &parse_string_arg,
+					 &poll_mem);
+		if (ret < 0) {
+			plt_err("Error processing arguments, key = %s\n", CN10K_ML_FW_POLL_MEM);
+			ret = -EINVAL;
+			goto exit;
+		}
+		poll_mem_set = true;
+	}
+
 check_args:
 	if (!fw_path_set)
 		mldev->fw.path = CN10K_ML_FW_PATH_DEFAULT;
@@ -243,6 +260,18 @@  cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
 	}
 	plt_info("ML: %s = %d", CN10K_ML_DEV_HW_QUEUE_LOCK, mldev->hw_queue_lock);
 
+	if (!poll_mem_set) {
+		mldev->fw.poll_mem = CN10K_ML_FW_POLL_MEM_DEFAULT;
+	} else {
+		if (!((strcmp(poll_mem, "ddr") == 0) || (strcmp(poll_mem, "register") == 0))) {
+			plt_err("Invalid argument, %s = %s\n", CN10K_ML_FW_POLL_MEM, poll_mem);
+			ret = -EINVAL;
+			goto exit;
+		}
+		mldev->fw.poll_mem = poll_mem;
+	}
+	plt_info("ML: %s = %s", CN10K_ML_FW_POLL_MEM, mldev->fw.poll_mem);
+
 exit:
 	if (kvlist)
 		rte_kvargs_free(kvlist);
@@ -376,6 +405,11 @@  cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw)
 	if (fw->report_dpe_warnings)
 		flags = flags | FW_REPORT_DPE_WARNING_BITMASK;
 
+	if (strcmp(fw->poll_mem, "ddr") == 0)
+		flags = flags | FW_USE_DDR_POLL_ADDR_FP;
+	else if (strcmp(fw->poll_mem, "register") == 0)
+		flags = flags & ~FW_USE_DDR_POLL_ADDR_FP;
+
 	return flags;
 }
 
@@ -780,9 +814,10 @@  RTE_PMD_REGISTER_PCI(MLDEV_NAME_CN10K_PMD, cn10k_mldev_pmd);
 RTE_PMD_REGISTER_PCI_TABLE(MLDEV_NAME_CN10K_PMD, pci_id_ml_table);
 RTE_PMD_REGISTER_KMOD_DEP(MLDEV_NAME_CN10K_PMD, "vfio-pci");
 
-RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD, CN10K_ML_FW_PATH
-			      "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
-			      "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
-			      "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA
-			      "=<0|1>" CN10K_ML_OCM_ALLOC_MODE
-			      "=<lowest|largest>" CN10K_ML_DEV_HW_QUEUE_LOCK "=<0|1>");
+RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD,
+			      CN10K_ML_FW_PATH "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
+					       "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
+					       "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA
+					       "=<0|1>" CN10K_ML_OCM_ALLOC_MODE
+					       "=<lowest|largest>" CN10K_ML_DEV_HW_QUEUE_LOCK
+					       "=<0|1>" CN10K_ML_FW_POLL_MEM "=<ddr|register>");
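
The firmware flag translation added above is small enough to sketch
standalone. A minimal, self-contained rendering of the same logic
(BIT() is defined locally here, and the function name is illustrative):

#include <stdint.h>
#include <string.h>

#define BIT(n)                  (1ULL << (n))
#define FW_USE_DDR_POLL_ADDR_FP BIT(2)

/* "ddr" sets the DDR-poll flag reported to firmware; "register" clears it. */
static uint64_t
poll_mem_to_fw_flags(uint64_t flags, const char *poll_mem)
{
	if (strcmp(poll_mem, "ddr") == 0)
		flags |= FW_USE_DDR_POLL_ADDR_FP;
	else if (strcmp(poll_mem, "register") == 0)
		flags &= ~FW_USE_DDR_POLL_ADDR_FP;

	return flags;
}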
diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h
index 4b65efecc5..092a023144 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.h
+++ b/drivers/ml/cnxk/cn10k_ml_dev.h
@@ -43,6 +43,18 @@ 
 #define ML_CN10K_POLL_JOB_START	 0
 #define ML_CN10K_POLL_JOB_FINISH 1
 
+/* Memory barrier macros */
+#if defined(RTE_ARCH_ARM)
+#define dmb_st ({ asm volatile("dmb st" : : : "memory"); })
+#define dsb_st ({ asm volatile("dsb st" : : : "memory"); })
+#else
+#define dmb_st
+#define dsb_st
+#endif
+
+struct cn10k_ml_req;
+struct cn10k_ml_qp;
+
 /* ML Job types */
 enum cn10k_ml_job_type {
 	ML_CN10K_JOB_TYPE_MODEL_RUN = 0,
@@ -358,6 +370,9 @@  struct cn10k_ml_fw {
 	/* Report DPE warnings */
 	int report_dpe_warnings;
 
+	/* Memory to be used for polling in fast-path requests */
+	const char *poll_mem;
+
 	/* Data buffer */
 	uint8_t *data;
 
@@ -393,6 +408,15 @@  struct cn10k_ml_dev {
 
 	/* JCMD enqueue function handler */
 	bool (*ml_jcmdq_enqueue)(struct roc_ml *roc_ml, struct ml_job_cmd_s *job_cmd);
+
+	/* Poll handling function pointers */
+	void (*set_poll_addr)(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx);
+	void (*set_poll_ptr)(struct roc_ml *roc_ml, struct cn10k_ml_req *req);
+	uint64_t (*get_poll_ptr)(struct roc_ml *roc_ml, struct cn10k_ml_req *req);
+
+	/* Memory barrier function pointers to handle synchronization */
+	void (*set_enq_barrier)(void);
+	void (*set_deq_barrier)(void);
 };
 
 uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw);
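
The new function-pointer members resolve the poll_mem choice once at
configure time, so the fast path makes an indirect call instead of a
string comparison per request. A simplified, self-contained sketch of
the pattern (types and names are illustrative, not the driver's API):

#include <stdint.h>
#include <string.h>

struct poll_ops {
	void (*set_poll_ptr)(volatile uint64_t *slot);
	uint64_t (*get_poll_ptr)(const volatile uint64_t *slot);
};

static void
ddr_set_poll_ptr(volatile uint64_t *slot)
{
	*slot = 0; /* e.g. ML_CN10K_POLL_JOB_START */
}

static uint64_t
ddr_get_poll_ptr(const volatile uint64_t *slot)
{
	return *slot;
}

/* Called once at configure time; "register" mode would install handlers
 * backed by the register read/write helpers instead.
 */
static void
poll_ops_configure(struct poll_ops *ops, const char *poll_mem)
{
	if (strcmp(poll_mem, "ddr") == 0) {
		ops->set_poll_ptr = ddr_set_poll_ptr;
		ops->get_poll_ptr = ddr_get_poll_ptr;
	}
}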
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index f787455a7f..b73ce8c97a 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -23,6 +23,11 @@ 
 #define ML_FLAGS_POLL_COMPL BIT(0)
 #define ML_FLAGS_SSO_COMPL  BIT(1)
 
+/* Scratch register range for poll mode requests */
+#define ML_POLL_REGISTER_SYNC  1023
+#define ML_POLL_REGISTER_START 1024
+#define ML_POLL_REGISTER_END   2047
+
 /* Error message length */
 #define ERRMSG_LEN 32
 
@@ -76,6 +81,80 @@  print_line(FILE *fp, int len)
 	fprintf(fp, "\n");
 }
 
+static inline void
+cn10k_ml_set_poll_addr_ddr(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx)
+{
+	PLT_SET_USED(qp);
+	PLT_SET_USED(idx);
+
+	req->compl_W1 = PLT_U64_CAST(&req->status);
+}
+
+static inline void
+cn10k_ml_set_poll_addr_reg(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx)
+{
+	req->compl_W1 = ML_SCRATCH(qp->block_start + idx % qp->block_size);
+}
+
+static inline void
+cn10k_ml_set_poll_ptr_ddr(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+	PLT_SET_USED(roc_ml);
+
+	plt_write64(ML_CN10K_POLL_JOB_START, req->compl_W1);
+}
+
+static inline void
+cn10k_ml_set_poll_ptr_reg(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+	roc_ml_reg_write64(roc_ml, ML_CN10K_POLL_JOB_START, req->compl_W1);
+}
+
+static inline uint64_t
+cn10k_ml_get_poll_ptr_ddr(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+	PLT_SET_USED(roc_ml);
+
+	return plt_read64(req->compl_W1);
+}
+
+static inline uint64_t
+cn10k_ml_get_poll_ptr_reg(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+	return roc_ml_reg_read64(roc_ml, req->compl_W1);
+}
+
+static inline void
+cn10k_ml_set_sync_addr(struct cn10k_ml_dev *mldev, struct cn10k_ml_req *req)
+{
+	if (strcmp(mldev->fw.poll_mem, "ddr") == 0)
+		req->compl_W1 = PLT_U64_CAST(&req->status);
+	else if (strcmp(mldev->fw.poll_mem, "register") == 0)
+		req->compl_W1 = ML_SCRATCH(ML_POLL_REGISTER_SYNC);
+}
+
+static inline void
+cn10k_ml_enq_barrier_ddr(void)
+{
+}
+
+static inline void
+cn10k_ml_deq_barrier_ddr(void)
+{
+}
+
+static inline void
+cn10k_ml_enq_barrier_register(void)
+{
+	dmb_st;
+}
+
+static inline void
+cn10k_ml_deq_barrier_register(void)
+{
+	dsb_st;
+}
+
 static void
 qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
 {
@@ -163,6 +242,9 @@  cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
 	qp->stats.dequeued_count = 0;
 	qp->stats.enqueue_err_count = 0;
 	qp->stats.dequeue_err_count = 0;
+	qp->block_size =
+		(ML_POLL_REGISTER_END - ML_POLL_REGISTER_START + 1) / dev->data->nb_queue_pairs;
+	qp->block_start = ML_POLL_REGISTER_START + qp_id * qp->block_size;
 
 	/* Initialize job command */
 	for (i = 0; i < qp->nb_desc; i++) {
@@ -341,7 +423,7 @@  cn10k_ml_prep_fp_job_descriptor(struct rte_ml_dev *dev, struct cn10k_ml_req *req
 	mldev = dev->data->dev_private;
 
 	req->jd.hdr.jce.w0.u64 = 0;
-	req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
+	req->jd.hdr.jce.w1.u64 = req->compl_W1;
 	req->jd.hdr.model_id = op->model_id;
 	req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_MODEL_RUN;
 	req->jd.hdr.fp_flags = ML_FLAGS_POLL_COMPL;
@@ -549,7 +631,11 @@  cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
 	else
 		dev_info->max_queue_pairs = ML_CN10K_MAX_QP_PER_DEVICE_LF;
 
-	dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP;
+	if (strcmp(mldev->fw.poll_mem, "register") == 0)
+		dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP / dev_info->max_queue_pairs;
+	else if (strcmp(mldev->fw.poll_mem, "ddr") == 0)
+		dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP;
+
 	dev_info->max_segments = ML_CN10K_MAX_SEGMENTS;
 	dev_info->min_align_size = ML_CN10K_ALIGN_SIZE;
 
@@ -717,6 +803,26 @@  cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
 	else
 		mldev->ml_jcmdq_enqueue = roc_ml_jcmdq_enqueue_lf;
 
+	/* Set polling function pointers */
+	if (strcmp(mldev->fw.poll_mem, "ddr") == 0) {
+		mldev->set_poll_addr = cn10k_ml_set_poll_addr_ddr;
+		mldev->set_poll_ptr = cn10k_ml_set_poll_ptr_ddr;
+		mldev->get_poll_ptr = cn10k_ml_get_poll_ptr_ddr;
+	} else if (strcmp(mldev->fw.poll_mem, "register") == 0) {
+		mldev->set_poll_addr = cn10k_ml_set_poll_addr_reg;
+		mldev->set_poll_ptr = cn10k_ml_set_poll_ptr_reg;
+		mldev->get_poll_ptr = cn10k_ml_get_poll_ptr_reg;
+	}
+
+	/* Set barrier function pointers */
+	if (strcmp(mldev->fw.poll_mem, "ddr") == 0) {
+		mldev->set_enq_barrier = cn10k_ml_enq_barrier_ddr;
+		mldev->set_deq_barrier = cn10k_ml_deq_barrier_ddr;
+	} else if (strcmp(mldev->fw.poll_mem, "register") == 0) {
+		mldev->set_enq_barrier = cn10k_ml_enq_barrier_register;
+		mldev->set_deq_barrier = cn10k_ml_deq_barrier_register;
+	}
+
 	dev->enqueue_burst = cn10k_ml_enqueue_burst;
 	dev->dequeue_burst = cn10k_ml_dequeue_burst;
 	dev->op_error_get = cn10k_ml_op_error_get;
@@ -2003,13 +2109,15 @@  cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 	op = ops[count];
 	req = &queue->reqs[head];
 
+	mldev->set_poll_addr(qp, req, head);
 	cn10k_ml_prep_fp_job_descriptor(dev, req, op);
 
 	memset(&req->result, 0, sizeof(struct cn10k_ml_result));
 	req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
 	req->result.user_ptr = op->user_ptr;
+	mldev->set_enq_barrier();
 
-	plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+	mldev->set_poll_ptr(&mldev->roc, req);
 	enqueued = mldev->ml_jcmdq_enqueue(&mldev->roc, &req->jcmd);
 	if (unlikely(!enqueued))
 		goto jcmdq_full;
@@ -2035,6 +2143,7 @@  cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 		       uint16_t nb_ops)
 {
 	struct cn10k_ml_queue *queue;
+	struct cn10k_ml_dev *mldev;
 	struct cn10k_ml_req *req;
 	struct cn10k_ml_qp *qp;
 
@@ -2042,6 +2151,7 @@  cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 	uint16_t count;
 	uint64_t tail;
 
+	mldev = dev->data->dev_private;
 	qp = dev->data->queue_pairs[qp_id];
 	queue = &qp->queue;
 
@@ -2054,7 +2164,7 @@  cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 
 dequeue_req:
 	req = &queue->reqs[tail];
-	status = plt_read64(&req->status);
+	status = mldev->get_poll_ptr(&mldev->roc, req);
 	if (unlikely(status != ML_CN10K_POLL_JOB_FINISH)) {
 		if (plt_tsc_cycles() < req->timeout)
 			goto empty_or_active;
@@ -2062,6 +2172,7 @@  cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
 			req->result.error_code.s.etype = ML_ETYPE_DRIVER;
 	}
 
+	mldev->set_deq_barrier();
 	cn10k_ml_result_update(dev, qp_id, &req->result, req->op);
 	ops[count] = req->op;
 
@@ -2119,13 +2230,14 @@  cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
 	model = dev->data->models[op->model_id];
 	req = model->req;
 
+	cn10k_ml_set_sync_addr(mldev, req);
 	cn10k_ml_prep_fp_job_descriptor(dev, req, op);
 
 	memset(&req->result, 0, sizeof(struct cn10k_ml_result));
 	req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
 	req->result.user_ptr = op->user_ptr;
 
-	plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+	mldev->set_poll_ptr(&mldev->roc, req);
 	req->jcmd.w1.s.jobptr = PLT_U64_CAST(&req->jd);
 
 	timeout = true;
@@ -2145,7 +2257,7 @@  cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
 
 	timeout = true;
 	do {
-		if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+		if (mldev->get_poll_ptr(&mldev->roc, req) == ML_CN10K_POLL_JOB_FINISH) {
 			timeout = false;
 			break;
 		}
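
Two details worth noting in the ops changes above: the sync path uses a
dedicated scratch register (ML_POLL_REGISTER_SYNC = 1023, just below the
fast-path pool) so synchronous inference never collides with a queue
pair's block, and dev_info caps max_desc so that every in-flight
descriptor can own a register. A comment-only sizing sketch (assuming
ML_CN10K_MAX_DESC_PER_QP is 1024; its definition is outside this patch):

/* poll_mem=register: pool holds 2047 - 1024 + 1 = 1024 scratch registers.
 * With 4 queue pairs: max_desc = 1024 / 4 = 256 descriptors per queue pair.
 * poll_mem=ddr: max_desc stays at ML_CN10K_MAX_DESC_PER_QP, since each
 * request polls its own status word in DDR.
 */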
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 4c38f1938a..f09c67f186 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -26,6 +26,9 @@  struct cn10k_ml_req {
 	/* Job command */
 	struct ml_job_cmd_s jcmd;
 
+	/* Job completion W1 */
+	uint64_t compl_W1;
+
 	/* Request timeout cycle */
 	uint64_t timeout;
 
@@ -61,6 +64,12 @@  struct cn10k_ml_qp {
 
 	/* Queue pair statistics */
 	struct rte_ml_dev_stats stats;
+
+	/* Register block start for polling */
+	uint32_t block_start;
+
+	/* Register block size for polling */
+	uint32_t block_size;
 };
 
 /* CN10K device ops */
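
For reference, the per-queue-pair block fields added above partition the
scratch pool evenly, and a request index maps to a register via
idx % block_size. A small standalone illustration of that arithmetic
(the queue-pair count and request index are made up):

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

#define ML_POLL_REGISTER_START 1024
#define ML_POLL_REGISTER_END   2047

int
main(void)
{
	uint16_t nb_queue_pairs = 4; /* hypothetical configuration */
	uint32_t block_size =
		(ML_POLL_REGISTER_END - ML_POLL_REGISTER_START + 1) / nb_queue_pairs;
	uint64_t idx = 300; /* arbitrary request index within a queue */
	uint16_t qp_id;

	for (qp_id = 0; qp_id < nb_queue_pairs; qp_id++) {
		uint32_t block_start = ML_POLL_REGISTER_START + qp_id * block_size;

		printf("qp %u: scratch [%u..%u], idx %" PRIu64 " -> register %" PRIu64 "\n",
		       (unsigned int)qp_id, block_start, block_start + block_size - 1,
		       idx, (uint64_t)(block_start + idx % block_size));
	}

	return 0;
}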