@@ -23,6 +23,7 @@
#define CN10K_ML_DEV_CACHE_MODEL_DATA "cache_model_data"
#define CN10K_ML_OCM_ALLOC_MODE "ocm_alloc_mode"
#define CN10K_ML_DEV_HW_QUEUE_LOCK "hw_queue_lock"
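+/* Device argument selecting the memory used for fast-path completion polling: "ddr" (default) or "register" */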
+#define CN10K_ML_FW_POLL_MEM "poll_mem"
#define CN10K_ML_FW_PATH_DEFAULT "/lib/firmware/mlip-fw.bin"
#define CN10K_ML_FW_ENABLE_DPE_WARNINGS_DEFAULT 1
@@ -30,6 +31,7 @@
#define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT 1
#define CN10K_ML_OCM_ALLOC_MODE_DEFAULT "lowest"
#define CN10K_ML_DEV_HW_QUEUE_LOCK_DEFAULT 1
+#define CN10K_ML_FW_POLL_MEM_DEFAULT "ddr"
/* ML firmware macros */
#define FW_MEMZONE_NAME "ml_cn10k_fw_mz"
@@ -42,6 +44,7 @@
/* Firmware flags */
#define FW_ENABLE_DPE_WARNING_BITMASK BIT(0)
#define FW_REPORT_DPE_WARNING_BITMASK BIT(1)
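+/* Flag to firmware: poll for completions in DDR when set, in ML scratch registers when clear */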
+#define FW_USE_DDR_POLL_ADDR_FP BIT(2)
static const char *const valid_args[] = {CN10K_ML_FW_PATH,
CN10K_ML_FW_ENABLE_DPE_WARNINGS,
@@ -49,6 +52,7 @@ static const char *const valid_args[] = {CN10K_ML_FW_PATH,
CN10K_ML_DEV_CACHE_MODEL_DATA,
CN10K_ML_OCM_ALLOC_MODE,
CN10K_ML_DEV_HW_QUEUE_LOCK,
+ CN10K_ML_FW_POLL_MEM,
NULL};
/* Dummy operations for ML device */
@@ -92,7 +96,9 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
bool ocm_alloc_mode_set = false;
bool hw_queue_lock_set = false;
char *ocm_alloc_mode = NULL;
+ bool poll_mem_set = false;
bool fw_path_set = false;
+ char *poll_mem = NULL;
char *fw_path = NULL;
int ret = 0;
@@ -174,6 +180,17 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
hw_queue_lock_set = true;
}
+ if (rte_kvargs_count(kvlist, CN10K_ML_FW_POLL_MEM) == 1) {
+ ret = rte_kvargs_process(kvlist, CN10K_ML_FW_POLL_MEM, &parse_string_arg,
+ &poll_mem);
+ if (ret < 0) {
+ plt_err("Error processing arguments, key = %s\n", CN10K_ML_FW_POLL_MEM);
+ ret = -EINVAL;
+ goto exit;
+ }
+ poll_mem_set = true;
+ }
+
check_args:
if (!fw_path_set)
mldev->fw.path = CN10K_ML_FW_PATH_DEFAULT;
@@ -243,6 +260,18 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
}
plt_info("ML: %s = %d", CN10K_ML_DEV_HW_QUEUE_LOCK, mldev->hw_queue_lock);
+ if (!poll_mem_set) {
+ mldev->fw.poll_mem = CN10K_ML_FW_POLL_MEM_DEFAULT;
+ } else {
+ if (!((strcmp(poll_mem, "ddr") == 0) || (strcmp(poll_mem, "register") == 0))) {
+ plt_err("Invalid argument, %s = %s\n", CN10K_ML_FW_POLL_MEM, poll_mem);
+ ret = -EINVAL;
+ goto exit;
+ }
+ mldev->fw.poll_mem = poll_mem;
+ }
+ plt_info("ML: %s = %s", CN10K_ML_FW_POLL_MEM, mldev->fw.poll_mem);
+
exit:
if (kvlist)
rte_kvargs_free(kvlist);
@@ -376,6 +405,11 @@ cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw)
if (fw->report_dpe_warnings)
flags = flags | FW_REPORT_DPE_WARNING_BITMASK;
+ if (strcmp(fw->poll_mem, "ddr") == 0)
+ flags = flags | FW_USE_DDR_POLL_ADDR_FP;
+ else if (strcmp(fw->poll_mem, "register") == 0)
+ flags = flags & ~FW_USE_DDR_POLL_ADDR_FP;
+
return flags;
}
@@ -780,9 +814,10 @@ RTE_PMD_REGISTER_PCI(MLDEV_NAME_CN10K_PMD, cn10k_mldev_pmd);
RTE_PMD_REGISTER_PCI_TABLE(MLDEV_NAME_CN10K_PMD, pci_id_ml_table);
RTE_PMD_REGISTER_KMOD_DEP(MLDEV_NAME_CN10K_PMD, "vfio-pci");
-RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD, CN10K_ML_FW_PATH
- "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
- "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
- "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA
- "=<0|1>" CN10K_ML_OCM_ALLOC_MODE
- "=<lowest|largest>" CN10K_ML_DEV_HW_QUEUE_LOCK "=<0|1>");
+RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD,
+ CN10K_ML_FW_PATH "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
+ "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
+ "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA
+ "=<0|1>" CN10K_ML_OCM_ALLOC_MODE
+ "=<lowest|largest>" CN10K_ML_DEV_HW_QUEUE_LOCK
+ "=<0|1>" CN10K_ML_FW_POLL_MEM "=<ddr|register>");
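+
+/*
+ * Usage sketch (the PCI address below is hypothetical):
+ *   dpdk-test-mldev -a 0000:00:10.0,poll_mem=register ...
+ * selects scratch-register polling; without the devarg the "ddr" default applies.
+ */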
@@ -43,6 +43,18 @@
#define ML_CN10K_POLL_JOB_START 0
#define ML_CN10K_POLL_JOB_FINISH 1
+/* Memory barrier macros: Arm store barriers used around scratch-register polling; empty on other architectures */
+#if defined(RTE_ARCH_ARM)
+#define dmb_st ({ asm volatile("dmb st" : : : "memory"); })
+#define dsb_st ({ asm volatile("dsb st" : : : "memory"); })
+#else
+#define dmb_st
+#define dsb_st
+#endif
+
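+/* Forward declarations for the poll handler prototypes in struct cn10k_ml_dev */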
+struct cn10k_ml_req;
+struct cn10k_ml_qp;
+
/* ML Job types */
enum cn10k_ml_job_type {
ML_CN10K_JOB_TYPE_MODEL_RUN = 0,
@@ -358,6 +370,9 @@ struct cn10k_ml_fw {
/* Report DPE warnings */
int report_dpe_warnings;
+ /* Memory to be used for polling in fast-path requests */
+ const char *poll_mem;
+
/* Data buffer */
uint8_t *data;
@@ -393,6 +408,15 @@ struct cn10k_ml_dev {
/* JCMD enqueue function handler */
bool (*ml_jcmdq_enqueue)(struct roc_ml *roc_ml, struct ml_job_cmd_s *job_cmd);
+
+ /* Poll handling function pointers */
+ void (*set_poll_addr)(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx);
+ void (*set_poll_ptr)(struct roc_ml *roc_ml, struct cn10k_ml_req *req);
+ uint64_t (*get_poll_ptr)(struct roc_ml *roc_ml, struct cn10k_ml_req *req);
+
+ /* Memory barrier function pointers for enqueue/dequeue synchronization */
+ void (*set_enq_barrier)(void);
+ void (*set_deq_barrier)(void);
};
uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw);
@@ -23,6 +23,11 @@
#define ML_FLAGS_POLL_COMPL BIT(0)
#define ML_FLAGS_SSO_COMPL BIT(1)
+/*
+ * Scratch register range for poll mode requests: register 1023 is reserved
+ * for synchronous inference, registers 1024-2047 are partitioned across
+ * queue pairs.
+ */
+#define ML_POLL_REGISTER_SYNC 1023
+#define ML_POLL_REGISTER_START 1024
+#define ML_POLL_REGISTER_END 2047
+
/* Error message length */
#define ERRMSG_LEN 32
@@ -76,6 +81,80 @@ print_line(FILE *fp, int len)
fprintf(fp, "\n");
}
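+
+/*
+ * Fast-path poll helpers: the _ddr variants poll the status word embedded in
+ * the request, the _reg variants poll an ML scratch register assigned from
+ * the queue pair's register block.
+ */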
+static inline void
+cn10k_ml_set_poll_addr_ddr(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx)
+{
+ PLT_SET_USED(qp);
+ PLT_SET_USED(idx);
+
+ req->compl_W1 = PLT_U64_CAST(&req->status);
+}
+
+static inline void
+cn10k_ml_set_poll_addr_reg(struct cn10k_ml_qp *qp, struct cn10k_ml_req *req, uint64_t idx)
+{
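+ /* Map the descriptor index into this queue pair's scratch register block */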
+ req->compl_W1 = ML_SCRATCH(qp->block_start + (idx % qp->block_size));
+}
+
+static inline void
+cn10k_ml_set_poll_ptr_ddr(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+ PLT_SET_USED(roc_ml);
+
+ plt_write64(ML_CN10K_POLL_JOB_START, req->compl_W1);
+}
+
+static inline void
+cn10k_ml_set_poll_ptr_reg(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+ roc_ml_reg_write64(roc_ml, ML_CN10K_POLL_JOB_START, req->compl_W1);
+}
+
+static inline uint64_t
+cn10k_ml_get_poll_ptr_ddr(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+ PLT_SET_USED(roc_ml);
+
+ return plt_read64(req->compl_W1);
+}
+
+static inline uint64_t
+cn10k_ml_get_poll_ptr_reg(struct roc_ml *roc_ml, struct cn10k_ml_req *req)
+{
+ return roc_ml_reg_read64(roc_ml, req->compl_W1);
+}
+
+static inline void
+cn10k_ml_set_sync_addr(struct cn10k_ml_dev *mldev, struct cn10k_ml_req *req)
+{
+ if (strcmp(mldev->fw.poll_mem, "ddr") == 0)
+ req->compl_W1 = PLT_U64_CAST(&req->status);
+ else if (strcmp(mldev->fw.poll_mem, "register") == 0)
+ req->compl_W1 = ML_SCRATCH(ML_POLL_REGISTER_SYNC);
+}
+
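+/* DDR poll mode needs no explicit barriers; the _register variants below use Arm store barriers */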
+static inline void
+cn10k_ml_enq_barrier_ddr(void)
+{
+}
+
+static inline void
+cn10k_ml_deq_barrier_ddr(void)
+{
+}
+
+static inline void
+cn10k_ml_enq_barrier_register(void)
+{
+ dmb_st;
+}
+
+static inline void
+cn10k_ml_deq_barrier_register(void)
+{
+ dsb_st;
+}
+
static void
qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
{
@@ -163,6 +242,9 @@ cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
qp->stats.dequeued_count = 0;
qp->stats.enqueue_err_count = 0;
qp->stats.dequeue_err_count = 0;
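+ /* Partition the scratch register range equally across all queue pairs */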
+ qp->block_size =
+ (ML_POLL_REGISTER_END - ML_POLL_REGISTER_START + 1) / dev->data->nb_queue_pairs;
+ qp->block_start = ML_POLL_REGISTER_START + qp_id * qp->block_size;
/* Initialize job command */
for (i = 0; i < qp->nb_desc; i++) {
@@ -341,7 +423,7 @@ cn10k_ml_prep_fp_job_descriptor(struct rte_ml_dev *dev, struct cn10k_ml_req *req
mldev = dev->data->dev_private;
req->jd.hdr.jce.w0.u64 = 0;
- req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
+ req->jd.hdr.jce.w1.u64 = req->compl_W1;
req->jd.hdr.model_id = op->model_id;
req->jd.hdr.job_type = ML_CN10K_JOB_TYPE_MODEL_RUN;
req->jd.hdr.fp_flags = ML_FLAGS_POLL_COMPL;
@@ -549,7 +631,11 @@ cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
else
dev_info->max_queue_pairs = ML_CN10K_MAX_QP_PER_DEVICE_LF;
- dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP;
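+ /* In register poll mode each descriptor needs its own scratch register, so the per-QP descriptor budget shrinks with the queue-pair count */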
+ if (strcmp(mldev->fw.poll_mem, "register") == 0)
+ dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP / dev_info->max_queue_pairs;
+ else if (strcmp(mldev->fw.poll_mem, "ddr") == 0)
+ dev_info->max_desc = ML_CN10K_MAX_DESC_PER_QP;
+
dev_info->max_segments = ML_CN10K_MAX_SEGMENTS;
dev_info->min_align_size = ML_CN10K_ALIGN_SIZE;
@@ -717,6 +803,26 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
else
mldev->ml_jcmdq_enqueue = roc_ml_jcmdq_enqueue_lf;
+ /* Set polling and barrier function pointers */
+ if (strcmp(mldev->fw.poll_mem, "ddr") == 0) {
+ mldev->set_poll_addr = cn10k_ml_set_poll_addr_ddr;
+ mldev->set_poll_ptr = cn10k_ml_set_poll_ptr_ddr;
+ mldev->get_poll_ptr = cn10k_ml_get_poll_ptr_ddr;
+ mldev->set_enq_barrier = cn10k_ml_enq_barrier_ddr;
+ mldev->set_deq_barrier = cn10k_ml_deq_barrier_ddr;
+ } else if (strcmp(mldev->fw.poll_mem, "register") == 0) {
+ mldev->set_poll_addr = cn10k_ml_set_poll_addr_reg;
+ mldev->set_poll_ptr = cn10k_ml_set_poll_ptr_reg;
+ mldev->get_poll_ptr = cn10k_ml_get_poll_ptr_reg;
+ mldev->set_enq_barrier = cn10k_ml_enq_barrier_register;
+ mldev->set_deq_barrier = cn10k_ml_deq_barrier_register;
+ }
+
dev->enqueue_burst = cn10k_ml_enqueue_burst;
dev->dequeue_burst = cn10k_ml_dequeue_burst;
dev->op_error_get = cn10k_ml_op_error_get;
@@ -2003,13 +2109,15 @@ cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
op = ops[count];
req = &queue->reqs[head];
+ mldev->set_poll_addr(qp, req, head);
cn10k_ml_prep_fp_job_descriptor(dev, req, op);
memset(&req->result, 0, sizeof(struct cn10k_ml_result));
req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
req->result.user_ptr = op->user_ptr;
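+ /* Ensure the descriptor and result stores complete before the poll word and job command are written (no-op in DDR mode) */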
+ mldev->set_enq_barrier();
- plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+ mldev->set_poll_ptr(&mldev->roc, req);
enqueued = mldev->ml_jcmdq_enqueue(&mldev->roc, &req->jcmd);
if (unlikely(!enqueued))
goto jcmdq_full;
@@ -2035,6 +2143,7 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
uint16_t nb_ops)
{
struct cn10k_ml_queue *queue;
+ struct cn10k_ml_dev *mldev;
struct cn10k_ml_req *req;
struct cn10k_ml_qp *qp;
@@ -2042,6 +2151,7 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
uint16_t count;
uint64_t tail;
+ mldev = dev->data->dev_private;
qp = dev->data->queue_pairs[qp_id];
queue = &qp->queue;
@@ -2054,7 +2164,7 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
dequeue_req:
req = &queue->reqs[tail];
- status = plt_read64(&req->status);
+ status = mldev->get_poll_ptr(&mldev->roc, req);
if (unlikely(status != ML_CN10K_POLL_JOB_FINISH)) {
if (plt_tsc_cycles() < req->timeout)
goto empty_or_active;
@@ -2062,6 +2172,7 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
req->result.error_code.s.etype = ML_ETYPE_DRIVER;
}
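+ /* Barrier between the status poll and the result update (no-op in DDR mode) */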
+ mldev->set_deq_barrier();
cn10k_ml_result_update(dev, qp_id, &req->result, req->op);
ops[count] = req->op;
@@ -2119,13 +2230,14 @@ cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
model = dev->data->models[op->model_id];
req = model->req;
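+ /* Use the request status word (DDR) or the dedicated SYNC scratch register for polling */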
+ cn10k_ml_set_sync_addr(mldev, req);
cn10k_ml_prep_fp_job_descriptor(dev, req, op);
memset(&req->result, 0, sizeof(struct cn10k_ml_result));
req->result.error_code.s.etype = ML_ETYPE_UNKNOWN;
req->result.user_ptr = op->user_ptr;
- plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+ mldev->set_poll_ptr(&mldev->roc, req);
req->jcmd.w1.s.jobptr = PLT_U64_CAST(&req->jd);
timeout = true;
@@ -2145,7 +2257,7 @@ cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
timeout = true;
do {
- if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+ if (mldev->get_poll_ptr(&mldev->roc, req) == ML_CN10K_POLL_JOB_FINISH) {
timeout = false;
break;
}
@@ -26,6 +26,9 @@ struct cn10k_ml_req {
/* Job command */
struct ml_job_cmd_s jcmd;
+ /* Job completion poll address, written to JCE word W1 of the job descriptor */
+ uint64_t compl_W1;
+
/* Request timeout cycle */
uint64_t timeout;
@@ -61,6 +64,12 @@ struct cn10k_ml_qp {
/* Queue pair statistics */
struct rte_ml_dev_stats stats;
+
+ /* Register block start for polling */
+ uint32_t block_start;
+
+ /* Register block size for polling */
+ uint32_t block_size;
};
/* CN10K device ops */