[v5,18/39] ml/cnxk: enable support to start an ML model

Message ID 20230207160719.1307-19-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Series: Implementation of ML CNXK driver

Checks

Context        Check    Description
ci/checkpatch  success  coding style OK

Commit Message

Srikanth Yalavarthi Feb. 7, 2023, 4:06 p.m. UTC
  Implemented the model start driver function. A model start job
is checked for completion in synchronous mode. The tilemask and
OCM slot are calculated before starting the model. Model start
is enqueued through scratch registers. OCM pages are reserved
after the model start completes.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_dev.h |   3 +
 drivers/ml/cnxk/cn10k_ml_ops.c | 207 +++++++++++++++++++++++++++++++++
 drivers/ml/cnxk/cn10k_ml_ops.h |   4 +
 3 files changed, 214 insertions(+)
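
The start path is synchronous: the driver enqueues the job descriptor through scratch registers, then busy-polls for completion against a TSC-based deadline (ML_CN10K_CMD_TIMEOUT, 5 seconds). Below is a minimal, self-contained sketch of that enqueue-and-poll pattern; enqueue_job(), dequeue_job(), now_cycles() and cycles_per_sec() are hypothetical stand-ins for the roc_ml_scratch_enqueue(), roc_ml_scratch_dequeue(), plt_tsc_cycles() and plt_tsc_hz() helpers the patch actually uses.

#include <stdbool.h>
#include <stdint.h>

#define CMD_TIMEOUT_SEC 5 /* mirrors ML_CN10K_CMD_TIMEOUT */

/* Hypothetical stand-ins for the ROC/PLT helpers used by the driver. */
extern bool enqueue_job(void *jd);    /* ~ roc_ml_scratch_enqueue() */
extern bool dequeue_job(void *jd);    /* ~ roc_ml_scratch_dequeue() */
extern uint64_t now_cycles(void);     /* ~ plt_tsc_cycles() */
extern uint64_t cycles_per_sec(void); /* ~ plt_tsc_hz() */

static int
run_job_sync(void *jd)
{
	bool enqueued = false;
	bool dequeued = false;
	uint64_t deadline = now_cycles() + CMD_TIMEOUT_SEC * cycles_per_sec();

	do {
		if (!enqueued) /* retry until the scratch queue accepts the job */
			enqueued = enqueue_job(jd);
		if (enqueued && !dequeued) /* firmware posts completion status */
			dequeued = dequeue_job(jd);
		if (dequeued)
			break;
	} while (now_cycles() < deadline);

	/* On timeout the driver resets the scratch queue and returns -ETIME. */
	return dequeued ? 0 : -1;
}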
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h
index 68fcc957fa..8f6bc24370 100644
--- a/drivers/ml/cnxk/cn10k_ml_dev.h
+++ b/drivers/ml/cnxk/cn10k_ml_dev.h
@@ -33,6 +33,9 @@ 
 /* ML command timeout in seconds */
 #define ML_CN10K_CMD_TIMEOUT 5
 
+/* ML slow-path job flags */
+#define ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE BIT(0)
+
 /* Poll mode job state */
 #define ML_CN10K_POLL_JOB_START	 0
 #define ML_CN10K_POLL_JOB_FINISH 1
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 56adce12ea..e8ce65b182 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -114,6 +114,64 @@  cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
 	return NULL;
 }
 
+static void
+cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *mldev, struct cn10k_ml_model *model,
+				struct cn10k_ml_req *req, enum cn10k_ml_job_type job_type)
+{
+	struct cn10k_ml_model_metadata *metadata;
+	struct cn10k_ml_model_addr *addr;
+
+	metadata = &model->metadata;
+	addr = &model->addr;
+
+	memset(&req->jd, 0, sizeof(struct cn10k_ml_jd));
+	req->jd.hdr.jce.w0.u64 = 0;
+	req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
+	req->jd.hdr.model_id = model->model_id;
+	req->jd.hdr.job_type = job_type;
+	req->jd.hdr.fp_flags = 0x0;
+	req->jd.hdr.result = roc_ml_addr_ap2mlip(&mldev->roc, &req->result);
+
+	if (job_type == ML_CN10K_JOB_TYPE_MODEL_START) {
+		if (!model->metadata.model.ocm_relocatable)
+			req->jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
+		else
+			req->jd.hdr.sp_flags = 0x0;
+		req->jd.model_start.model_src_ddr_addr =
+			PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, addr->init_load_addr));
+		req->jd.model_start.model_dst_ddr_addr =
+			PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, addr->init_run_addr));
+		req->jd.model_start.model_init_offset = 0x0;
+		req->jd.model_start.model_main_offset = metadata->init_model.file_size;
+		req->jd.model_start.model_finish_offset =
+			metadata->init_model.file_size + metadata->main_model.file_size;
+		req->jd.model_start.model_init_size = metadata->init_model.file_size;
+		req->jd.model_start.model_main_size = metadata->main_model.file_size;
+		req->jd.model_start.model_finish_size = metadata->finish_model.file_size;
+		req->jd.model_start.model_wb_offset = metadata->init_model.file_size +
+						      metadata->main_model.file_size +
+						      metadata->finish_model.file_size;
+		req->jd.model_start.num_layers = metadata->model.num_layers;
+		req->jd.model_start.num_gather_entries = 0;
+		req->jd.model_start.num_scatter_entries = 0;
+		req->jd.model_start.tilemask = 0; /* Updated after reserving pages */
+		req->jd.model_start.batch_size = model->batch_size;
+		req->jd.model_start.ocm_wb_base_address = 0; /* Updated after reserving pages */
+		req->jd.model_start.ocm_wb_range_start = metadata->model.ocm_wb_range_start;
+		req->jd.model_start.ocm_wb_range_end = metadata->model.ocm_wb_range_end;
+		req->jd.model_start.ddr_wb_base_address = PLT_U64_CAST(roc_ml_addr_ap2mlip(
+			&mldev->roc,
+			PLT_PTR_ADD(addr->finish_load_addr, metadata->finish_model.file_size)));
+		req->jd.model_start.ddr_wb_range_start = metadata->model.ddr_wb_range_start;
+		req->jd.model_start.ddr_wb_range_end = metadata->model.ddr_wb_range_end;
+		req->jd.model_start.input.s.ddr_range_start = metadata->model.ddr_input_range_start;
+		req->jd.model_start.input.s.ddr_range_end = metadata->model.ddr_input_range_end;
+		req->jd.model_start.output.s.ddr_range_start =
+			metadata->model.ddr_output_range_start;
+		req->jd.model_start.output.s.ddr_range_end = metadata->model.ddr_output_range_end;
+	}
+}
+
 static int
 cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
 {
@@ -561,6 +619,154 @@  cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
 	return plt_memzone_free(plt_memzone_lookup(str));
 }
 
+int
+cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	struct cn10k_ml_model *model;
+	struct cn10k_ml_dev *mldev;
+	struct cn10k_ml_ocm *ocm;
+	struct cn10k_ml_req *req;
+
+	bool job_enqueued;
+	bool job_dequeued;
+	uint8_t num_tiles;
+	uint64_t tilemask;
+	int wb_page_start;
+	int tile_start;
+	int tile_end;
+	bool locked;
+	int ret = 0;
+
+	mldev = dev->data->dev_private;
+	ocm = &mldev->ocm;
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	/* Prepare JD */
+	req = model->req;
+	cn10k_ml_prep_sp_job_descriptor(mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_START);
+	req->result.error_code = 0x0;
+	req->result.user_ptr = NULL;
+
+	plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+	plt_wmb();
+
+	num_tiles = model->metadata.model.tile_end - model->metadata.model.tile_start + 1;
+
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+				plt_ml_dbg("Model already started, model = 0x%016lx",
+					   PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return 1;
+			}
+
+			if (model->state == ML_CN10K_MODEL_STATE_JOB_ACTIVE) {
+				plt_err("A slow-path job is active for the model = 0x%016lx",
+					PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return -EBUSY;
+			}
+
+			model->state = ML_CN10K_MODEL_STATE_JOB_ACTIVE;
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	while (!model->model_mem_map.ocm_reserved) {
+		if (plt_spinlock_trylock(&ocm->lock) != 0) {
+			wb_page_start = cn10k_ml_ocm_tilemask_find(
+				dev, num_tiles, model->model_mem_map.wb_pages,
+				model->model_mem_map.scratch_pages, &tilemask);
+
+			if (wb_page_start == -1) {
+				plt_err("Free pages not available on OCM tiles");
+				plt_err("Failed to start model = 0x%016lx, name = %s",
+					PLT_U64_CAST(model), model->metadata.model.name);
+
+				plt_spinlock_unlock(&ocm->lock);
+				return -ENOMEM;
+			}
+
+			model->model_mem_map.tilemask = tilemask;
+			model->model_mem_map.wb_page_start = wb_page_start;
+
+			cn10k_ml_ocm_reserve_pages(
+				dev, model->model_id, model->model_mem_map.tilemask,
+				model->model_mem_map.wb_page_start, model->model_mem_map.wb_pages,
+				model->model_mem_map.scratch_pages);
+			model->model_mem_map.ocm_reserved = true;
+			plt_spinlock_unlock(&ocm->lock);
+		}
+	}
+
+	/* Update JD */
+	cn10k_ml_ocm_tilecount(model->model_mem_map.tilemask, &tile_start, &tile_end);
+	req->jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
+	req->jd.model_start.ocm_wb_base_address =
+		model->model_mem_map.wb_page_start * ocm->page_size;
+
+	job_enqueued = false;
+	job_dequeued = false;
+	do {
+		if (!job_enqueued) {
+			req->timeout = plt_tsc_cycles() + ML_CN10K_CMD_TIMEOUT * plt_tsc_hz();
+			job_enqueued = roc_ml_scratch_enqueue(&mldev->roc, &req->jd);
+		}
+
+		if (job_enqueued && !job_dequeued)
+			job_dequeued = roc_ml_scratch_dequeue(&mldev->roc, &req->jd);
+
+		if (job_dequeued)
+			break;
+	} while (plt_tsc_cycles() < req->timeout);
+
+	if (job_dequeued) {
+		if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+			if (req->result.error_code == 0)
+				ret = 0;
+			else
+				ret = -1;
+		}
+	} else { /* Reset scratch registers */
+		roc_ml_scratch_queue_reset(&mldev->roc);
+		ret = -ETIME;
+	}
+
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			if (ret == 0)
+				model->state = ML_CN10K_MODEL_STATE_STARTED;
+			else
+				model->state = ML_CN10K_MODEL_STATE_UNKNOWN;
+
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	if (model->state == ML_CN10K_MODEL_STATE_UNKNOWN) {
+		while (model->model_mem_map.ocm_reserved) {
+			if (plt_spinlock_trylock(&ocm->lock) != 0) {
+				cn10k_ml_ocm_free_pages(dev, model->model_id);
+				model->model_mem_map.ocm_reserved = false;
+				model->model_mem_map.tilemask = 0x0;
+				plt_spinlock_unlock(&ocm->lock);
+			}
+		}
+	}
+
+	return ret;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cn10k_ml_dev_info_get,
@@ -576,4 +782,5 @@  struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Model ops */
 	.model_load = cn10k_ml_model_load,
 	.model_unload = cn10k_ml_model_unload,
+	.model_start = cn10k_ml_model_start,
 };
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index c86ce66f19..989af978c4 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -25,6 +25,9 @@  struct cn10k_ml_req {
 
 	/* Job command */
 	struct ml_job_cmd_s jcmd;
+
+	/* Timeout cycle */
+	uint64_t timeout;
 } __rte_aligned(ROC_ALIGN);
 
 /* Request queue */
@@ -61,5 +64,6 @@  extern struct rte_ml_dev_ops cn10k_ml_ops;
 int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
 			uint16_t *model_id);
 int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
 
 #endif /* _CN10K_ML_OPS_H_ */
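
For context, a hedged application-side sketch, assuming the rte_mldev API that this series builds on; the start call below reaches cn10k_ml_model_start() through the new .model_start op. The exact call sequence may differ in the merged version.

#include <rte_mldev.h>

static int
load_and_start(int16_t dev_id, struct rte_ml_model_params *params)
{
	uint16_t model_id;
	int ret;

	/* Load parses the model and places it in DDR. */
	ret = rte_ml_model_load(dev_id, params, &model_id);
	if (ret != 0)
		return ret;

	/* Start reserves OCM pages, fills the tilemask into the job
	 * descriptor and waits synchronously, up to ML_CN10K_CMD_TIMEOUT,
	 * for the MODEL_START job to complete. */
	return rte_ml_model_start(dev_id, model_id);
}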