[3/4] ml/cnxk: add adapter enqueue function

Message ID 20240107154013.4676-4-syalavarthi@marvell.com (mailing list archive)
State Changes Requested, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of CNXK ML event adapter driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Jan. 7, 2024, 3:40 p.m. UTC
  Implemented ML adapter enqueue function. Rename internal
fast-path JD preparation function for poll mode. Added JD
preparation function for event mode. Updated meson build
dependencies for ml/cnxk driver.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/event/cnxk/cn10k_eventdev.c |   3 +
 drivers/ml/cnxk/cn10k_ml_event_dp.h |  16 ++++
 drivers/ml/cnxk/cn10k_ml_ops.c      | 129 ++++++++++++++++++++++++++--
 drivers/ml/cnxk/cn10k_ml_ops.h      |   1 +
 drivers/ml/cnxk/cnxk_ml_ops.h       |   8 ++
 drivers/ml/cnxk/meson.build         |   2 +-
 drivers/ml/cnxk/version.map         |   7 ++
 7 files changed, 160 insertions(+), 6 deletions(-)
 create mode 100644 drivers/ml/cnxk/cn10k_ml_event_dp.h
 create mode 100644 drivers/ml/cnxk/version.map
  

Patch

diff --git a/drivers/event/cnxk/cn10k_eventdev.c b/drivers/event/cnxk/cn10k_eventdev.c
index 201972cec9e..3b5dce23fe9 100644
--- a/drivers/event/cnxk/cn10k_eventdev.c
+++ b/drivers/event/cnxk/cn10k_eventdev.c
@@ -6,6 +6,7 @@ 
 #include "cn10k_worker.h"
 #include "cn10k_ethdev.h"
 #include "cn10k_cryptodev_ops.h"
+#include "cn10k_ml_event_dp.h"
 #include "cnxk_ml_ops.h"
 #include "cnxk_eventdev.h"
 #include "cnxk_worker.h"
@@ -478,6 +479,8 @@  cn10k_sso_fp_fns_set(struct rte_eventdev *event_dev)
 	else
 		event_dev->ca_enqueue = cn10k_cpt_sg_ver1_crypto_adapter_enqueue;
 
+	event_dev->mla_enqueue = cn10k_ml_adapter_enqueue;
+
 	if (dev->tx_offloads & NIX_TX_MULTI_SEG_F)
 		CN10K_SET_EVDEV_ENQ_OP(dev, event_dev->txa_enqueue, sso_hws_tx_adptr_enq_seg);
 	else
diff --git a/drivers/ml/cnxk/cn10k_ml_event_dp.h b/drivers/ml/cnxk/cn10k_ml_event_dp.h
new file mode 100644
index 00000000000..bf7fc57bceb
--- /dev/null
+++ b/drivers/ml/cnxk/cn10k_ml_event_dp.h
@@ -0,0 +1,16 @@ 
+/* SPDX-License-Identifier: BSD-3-Clause
+ * Copyright(C) 2024 Marvell.
+ */
+
+#ifndef _CN10K_ML_EVENT_DP_H_
+#define _CN10K_ML_EVENT_DP_H_
+
+#include <stdint.h>
+
+#include <rte_common.h>
+#include <rte_eventdev.h>
+
+__rte_internal
+__rte_hot uint16_t cn10k_ml_adapter_enqueue(void *ws, struct rte_event ev[], uint16_t nb_events);
+
+#endif /* _CN10K_ML_EVENT_DP_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 834e55e88e9..4bc17eaa8c4 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -2,11 +2,13 @@ 
  * Copyright (c) 2022 Marvell.
  */
 
+#include <rte_event_ml_adapter.h>
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
 #include <mldev_utils.h>
 
+#include "cn10k_ml_event_dp.h"
 #include "cnxk_ml_dev.h"
 #include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
@@ -144,8 +146,8 @@  cn10k_ml_prep_sp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_l
 }
 
 static __rte_always_inline void
-cn10k_ml_prep_fp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_req *req,
-				uint16_t index, void *input, void *output, uint16_t nb_batches)
+cn10k_ml_prep_fp_job_descriptor_poll(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_req *req,
+				     uint16_t index, void *input, void *output, uint16_t nb_batches)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
 
@@ -166,6 +168,33 @@  cn10k_ml_prep_fp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_r
 	req->cn10k_req.jd.model_run.num_batches = nb_batches;
 }
 
+static __rte_always_inline void
+cn10k_ml_prep_fp_job_descriptor_event(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_req *req,
+				      uint16_t index, void *input, void *output, uint16_t nb_batches
+
+				      ,
+				      uint64_t *compl_W0)
+{
+
+	struct cn10k_ml_dev *cn10k_mldev;
+
+	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+
+	req->cn10k_req.jd.hdr.jce.w0.u64 = *compl_W0;
+	req->cn10k_req.jd.hdr.jce.w1.s.wqp = PLT_U64_CAST(req);
+	req->cn10k_req.jd.hdr.model_id = index;
+	req->cn10k_req.jd.hdr.job_type = ML_CN10K_JOB_TYPE_MODEL_RUN;
+	req->cn10k_req.jd.hdr.fp_flags = ML_FLAGS_SSO_COMPL;
+	req->cn10k_req.jd.hdr.sp_flags = 0x0;
+	req->cn10k_req.jd.hdr.result =
+		roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
+	req->cn10k_req.jd.model_run.input_ddr_addr =
+		PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, input));
+	req->cn10k_req.jd.model_run.output_ddr_addr =
+		PLT_U64_CAST(roc_ml_addr_ap2mlip(&cn10k_mldev->roc, output));
+	req->cn10k_req.jd.model_run.num_batches = nb_batches;
+}
+
 static void
 cn10k_ml_xstats_layer_name_update(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id,
 				  uint16_t layer_id)
@@ -1305,13 +1334,16 @@  cn10k_ml_enqueue_single(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_op *op, ui
 
 	model = cnxk_mldev->mldev->data->models[op->model_id];
 	model->set_poll_addr(req);
-	cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, model->layer[layer_id].index,
-					op->input[0]->addr, op->output[0]->addr, op->nb_batches);
+	cn10k_ml_prep_fp_job_descriptor_poll(cnxk_mldev, req, model->layer[layer_id].index,
+					     op->input[0]->addr, op->output[0]->addr,
+					     op->nb_batches);
 
 	memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
 	error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
 	error_code->s.etype = ML_CNXK_ETYPE_UNKNOWN;
 	req->cn10k_req.result.user_ptr = op->user_ptr;
+	req->cnxk_mldev = cnxk_mldev;
+	req->qp_id = qp->id;
 
 	cnxk_ml_set_poll_ptr(req);
 	if (unlikely(!cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jcmd)))
@@ -1383,7 +1415,7 @@  cn10k_ml_inference_sync(void *device, uint16_t index, void *input, void *output,
 	op.impl_opaque = 0;
 
 	cn10k_ml_set_poll_addr(req);
-	cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, index, input, output, nb_batches);
+	cn10k_ml_prep_fp_job_descriptor_poll(cnxk_mldev, req, index, input, output, nb_batches);
 
 	memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
 	error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
@@ -1541,3 +1573,90 @@  cn10k_ml_free(const char *name)
 
 	return plt_memzone_free(mz);
 }
+
+static int
+cn10k_ml_meta_info_extract(struct rte_ml_op *op, struct cnxk_ml_qp **qp, uint64_t *W0,
+			   struct rte_ml_dev **dev)
+{
+	union rte_event_ml_metadata *eml_mdata;
+	struct rte_event *rsp_info;
+	union ml_jce_w0 jce_w0;
+	uint8_t mldev_id;
+	uint16_t qp_id;
+
+	eml_mdata = (union rte_event_ml_metadata *)((uint8_t *)op + op->private_data_offset);
+	rsp_info = &eml_mdata->response_info;
+	mldev_id = eml_mdata->request_info.mldev_id;
+	qp_id = eml_mdata->request_info.queue_pair_id;
+
+	*dev = rte_ml_dev_pmd_get_dev(mldev_id);
+	*qp = (*dev)->data->queue_pairs[qp_id];
+
+	jce_w0.s.ttype = rsp_info->sched_type;
+	jce_w0.s.pf_func = roc_ml_sso_pf_func_get();
+	jce_w0.s.ggrp = rsp_info->queue_id;
+	jce_w0.s.tag =
+		(RTE_EVENT_TYPE_MLDEV << 28) | (rsp_info->sub_event_type << 20) | rsp_info->flow_id;
+	*W0 = jce_w0.u64;
+
+	return 0;
+}
+
+__rte_hot uint16_t
+cn10k_ml_adapter_enqueue(void *ws, struct rte_event ev[], uint16_t nb_events)
+{
+	union cn10k_ml_error_code *error_code;
+	struct cn10k_ml_dev *cn10k_mldev;
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+	struct cnxk_ml_req *req;
+	struct cnxk_ml_qp *qp;
+
+	struct rte_ml_dev *dev;
+	struct rte_ml_op *op;
+
+	uint16_t count;
+	uint64_t W0;
+	int ret, i;
+
+	PLT_SET_USED(ws);
+
+	count = 0;
+	for (i = 0; i < nb_events; i++) {
+		op = ev[i].event_ptr;
+		ret = cn10k_ml_meta_info_extract(op, &qp, &W0, &dev);
+		if (ret) {
+			rte_errno = EINVAL;
+			return count;
+		}
+
+		cnxk_mldev = dev->data->dev_private;
+		cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+		if (rte_mempool_get(qp->mla.req_mp, (void **)(&req)) != 0) {
+			rte_errno = ENOMEM;
+			return 0;
+		}
+		req->cn10k_req.jcmd.w1.s.jobptr = PLT_U64_CAST(&req->cn10k_req.jd);
+
+		model = cnxk_mldev->mldev->data->models[op->model_id];
+		cn10k_ml_prep_fp_job_descriptor_event(cnxk_mldev, req, model->layer[0].index,
+						      op->input[0]->addr, op->output[0]->addr,
+						      op->nb_batches, &W0);
+		memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
+		error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
+		error_code->s.etype = ML_CNXK_ETYPE_UNKNOWN;
+		req->cn10k_req.result.user_ptr = op->user_ptr;
+		req->cnxk_mldev = cnxk_mldev;
+		req->qp_id = qp->id;
+		rte_wmb();
+
+		if (!cn10k_mldev->ml_jcmdq_enqueue(&cn10k_mldev->roc, &req->cn10k_req.jcmd)) {
+			rte_mempool_put(qp->mla.req_mp, req);
+			break;
+		}
+
+		count++;
+	}
+
+	return count;
+}
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index d225ed2098e..bf3a9fdc26c 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -5,6 +5,7 @@ 
 #ifndef _CN10K_ML_OPS_H_
 #define _CN10K_ML_OPS_H_
 
+#include <rte_eventdev.h>
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index 81f91df2a80..745701185ea 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -19,6 +19,8 @@ 
 #include "mvtvm_ml_stubs.h"
 #endif
 
+struct cnxk_ml_dev;
+
 /* Request structure */
 struct cnxk_ml_req {
 	/* Device specific request */
@@ -40,6 +42,12 @@  struct cnxk_ml_req {
 
 	/* Op */
 	struct rte_ml_op *op;
+
+	/* Device handle */
+	struct cnxk_ml_dev *cnxk_mldev;
+
+	/* Queue-pair ID */
+	uint16_t qp_id;
 } __rte_aligned(ROC_ALIGN);
 
 /* Request queue */
diff --git a/drivers/ml/cnxk/meson.build b/drivers/ml/cnxk/meson.build
index 0680a0faa5c..a37250babf4 100644
--- a/drivers/ml/cnxk/meson.build
+++ b/drivers/ml/cnxk/meson.build
@@ -55,7 +55,7 @@  sources = files(
         'cnxk_ml_utils.c',
 )
 
-deps += ['mldev', 'common_cnxk', 'kvargs', 'hash']
+deps += ['mldev', 'common_cnxk', 'kvargs', 'hash', 'eventdev']
 
 if enable_mvtvm
 
diff --git a/drivers/ml/cnxk/version.map b/drivers/ml/cnxk/version.map
new file mode 100644
index 00000000000..c2cacaf8c65
--- /dev/null
+++ b/drivers/ml/cnxk/version.map
@@ -0,0 +1,7 @@ 
+INTERNAL {
+	global:
+
+	cn10k_ml_adapter_enqueue;
+
+	local: *;
+};