@@ -33,6 +33,9 @@
/* ML command timeout in seconds */
#define ML_CN10K_CMD_TIMEOUT 5
+/* ML slow-path job flags */
+#define ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE BIT(0)
+
/* Poll mode job state */
#define ML_CN10K_POLL_JOB_START 0
#define ML_CN10K_POLL_JOB_FINISH 1
@@ -114,6 +114,64 @@ cn10k_ml_qp_create(const struct rte_ml_dev *dev, uint16_t qp_id, uint32_t nb_des
return NULL;
}
+/* Prepare the job descriptor for a slow-path (scratch register) job.
+ *
+ * Zeroes req->jd and fills the header: the completion-status address
+ * (req->status, polled by the caller), model ID, job type, and the
+ * result address translated to the MLIP address space.
+ *
+ * For ML_CN10K_JOB_TYPE_MODEL_START jobs, also fills the model_start
+ * section from the model metadata: DDR load/run addresses of the model
+ * binary, offsets and sizes of the init/main/finish sections, write-back
+ * ranges, and input/output DDR ranges. The tilemask and
+ * ocm_wb_base_address fields are left 0 here and are updated by the
+ * caller once OCM pages have been reserved.
+ */
+static void
+cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *mldev, struct cn10k_ml_model *model,
+				struct cn10k_ml_req *req, enum cn10k_ml_job_type job_type)
+{
+	struct cn10k_ml_model_metadata *metadata;
+	struct cn10k_ml_model_addr *addr;
+
+	metadata = &model->metadata;
+	addr = &model->addr;
+
+	memset(&req->jd, 0, sizeof(struct cn10k_ml_jd));
+	req->jd.hdr.jce.w0.u64 = 0;
+	req->jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->status);
+	req->jd.hdr.model_id = model->model_id;
+	req->jd.hdr.job_type = job_type;
+	req->jd.hdr.fp_flags = 0x0;
+	req->jd.hdr.result = roc_ml_addr_ap2mlip(&mldev->roc, &req->result);
+
+	if (job_type == ML_CN10K_JOB_TYPE_MODEL_START) {
+		/* Tell the firmware when the model's OCM region must not be
+		 * relocated (metadata flag cleared).
+		 */
+		if (!model->metadata.model.ocm_relocatable)
+			req->jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
+		else
+			req->jd.hdr.sp_flags = 0x0;
+		/* DDR addresses are translated to the MLIP address space */
+		req->jd.model_start.model_src_ddr_addr =
+			PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, addr->init_load_addr));
+		req->jd.model_start.model_dst_ddr_addr =
+			PLT_U64_CAST(roc_ml_addr_ap2mlip(&mldev->roc, addr->init_run_addr));
+		/* init, main and finish sections are laid out back to back,
+		 * so each offset is the running sum of the preceding sizes.
+		 */
+		req->jd.model_start.model_init_offset = 0x0;
+		req->jd.model_start.model_main_offset = metadata->init_model.file_size;
+		req->jd.model_start.model_finish_offset =
+			metadata->init_model.file_size + metadata->main_model.file_size;
+		req->jd.model_start.model_init_size = metadata->init_model.file_size;
+		req->jd.model_start.model_main_size = metadata->main_model.file_size;
+		req->jd.model_start.model_finish_size = metadata->finish_model.file_size;
+		/* Write-back data follows the three code sections */
+		req->jd.model_start.model_wb_offset = metadata->init_model.file_size +
+						      metadata->main_model.file_size +
+						      metadata->finish_model.file_size;
+		req->jd.model_start.num_layers = metadata->model.num_layers;
+		req->jd.model_start.num_gather_entries = 0;
+		req->jd.model_start.num_scatter_entries = 0;
+		req->jd.model_start.tilemask = 0; /* Updated after reserving pages */
+		req->jd.model_start.batch_size = model->batch_size;
+		req->jd.model_start.ocm_wb_base_address = 0; /* Updated after reserving pages */
+		req->jd.model_start.ocm_wb_range_start = metadata->model.ocm_wb_range_start;
+		req->jd.model_start.ocm_wb_range_end = metadata->model.ocm_wb_range_end;
+		/* DDR write-back region starts right after the finish section
+		 * of the loaded image.
+		 */
+		req->jd.model_start.ddr_wb_base_address = PLT_U64_CAST(roc_ml_addr_ap2mlip(
+			&mldev->roc,
+			PLT_PTR_ADD(addr->finish_load_addr, metadata->finish_model.file_size)));
+		req->jd.model_start.ddr_wb_range_start = metadata->model.ddr_wb_range_start;
+		req->jd.model_start.ddr_wb_range_end = metadata->model.ddr_wb_range_end;
+		req->jd.model_start.input.s.ddr_range_start = metadata->model.ddr_input_range_start;
+		req->jd.model_start.input.s.ddr_range_end = metadata->model.ddr_input_range_end;
+		req->jd.model_start.output.s.ddr_range_start =
+			metadata->model.ddr_output_range_start;
+		req->jd.model_start.output.s.ddr_range_end = metadata->model.ddr_output_range_end;
+	}
+}
+
static int
cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
{
@@ -560,6 +618,155 @@ cn10k_ml_model_unload(struct rte_ml_dev *dev, int16_t model_id)
return plt_memzone_free(plt_memzone_lookup(str));
}
+/* Start a loaded model.
+ *
+ * Reserves OCM pages for the model's write-back data and issues a
+ * MODEL_START slow-path job to the ML firmware through the scratch
+ * registers. The model state is guarded by model->lock and is moved to
+ * JOB_ACTIVE for the duration of the job.
+ *
+ * Returns 0 on success, 1 if the model is already started, -EINVAL for
+ * an unknown model_id, -EBUSY if another slow-path job is active for the
+ * model, -ENOMEM if OCM pages could not be reserved, -ETIME if the
+ * firmware did not complete the job within ML_CN10K_CMD_TIMEOUT, and -1
+ * if the firmware reported an error for the job.
+ */
+int
+cn10k_ml_model_start(struct rte_ml_dev *dev, int16_t model_id)
+{
+	struct cn10k_ml_model *model;
+	struct cn10k_ml_dev *mldev;
+	struct cn10k_ml_ocm *ocm;
+	struct cn10k_ml_req *req;
+
+	bool job_enqueued;
+	bool job_dequeued;
+	uint8_t num_tiles;
+	uint64_t tilemask;
+	int wb_page_start;
+	int tile_start;
+	int tile_end;
+	bool locked;
+	int ret = 0;
+
+	mldev = dev->data->dev_private;
+	ocm = &mldev->ocm;
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %d", model_id);
+		return -EINVAL;
+	}
+
+	/* Prepare JD */
+	req = model->req;
+	cn10k_ml_prep_sp_job_descriptor(mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_START);
+
+	req->result.error_code = 0x0;
+	req->result.user_ptr = NULL;
+
+	plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+	plt_wmb();
+
+	num_tiles = model->metadata.model.tile_end - model->metadata.model.tile_start + 1;
+
+	/* Move the model to JOB_ACTIVE, rejecting redundant or concurrent
+	 * start requests.
+	 */
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+				plt_ml_dbg("Model already started, model = 0x%016lx",
+					   PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return 1;
+			}
+
+			if (model->state == ML_CN10K_MODEL_STATE_JOB_ACTIVE) {
+				plt_err("A slow-path job is active for the model = 0x%016lx",
+					PLT_U64_CAST(model));
+				plt_spinlock_unlock(&model->lock);
+				return -EBUSY;
+			}
+
+			model->state = ML_CN10K_MODEL_STATE_JOB_ACTIVE;
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	/* Reserve OCM pages for the model's write-back data */
+	while (!model->model_mem_map.ocm_reserved) {
+		if (plt_spinlock_trylock(&ocm->lock) != 0) {
+			wb_page_start = cn10k_ml_ocm_tilemask_find(
+				dev, num_tiles, model->model_mem_map.wb_pages,
+				model->model_mem_map.scratch_pages, &tilemask);
+
+			if (wb_page_start == -1) {
+				plt_err("Free pages not available on OCM tiles");
+				plt_err("Failed to start model = 0x%016lx, name = %s",
+					PLT_U64_CAST(model), model->metadata.model.name);
+
+				plt_spinlock_unlock(&ocm->lock);
+
+				/* Roll back the JOB_ACTIVE state set above;
+				 * otherwise every subsequent start attempt
+				 * for this model would fail with -EBUSY.
+				 */
+				locked = false;
+				while (!locked) {
+					if (plt_spinlock_trylock(&model->lock) != 0) {
+						model->state = ML_CN10K_MODEL_STATE_UNKNOWN;
+						plt_spinlock_unlock(&model->lock);
+						locked = true;
+					}
+				}
+				return -ENOMEM;
+			}
+
+			model->model_mem_map.tilemask = tilemask;
+			model->model_mem_map.wb_page_start = wb_page_start;
+
+			cn10k_ml_ocm_reserve_pages(
+				dev, model->model_id, model->model_mem_map.tilemask,
+				model->model_mem_map.wb_page_start, model->model_mem_map.wb_pages,
+				model->model_mem_map.scratch_pages);
+			model->model_mem_map.ocm_reserved = true;
+			plt_spinlock_unlock(&ocm->lock);
+		}
+	}
+
+	/* Update JD with the reserved OCM region */
+	cn10k_ml_ocm_tilecount(model->model_mem_map.tilemask, &tile_start, &tile_end);
+	req->jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
+	req->jd.model_start.ocm_wb_base_address =
+		model->model_mem_map.wb_page_start * ocm->page_size;
+
+	/* Enqueue the job through the scratch registers and poll for
+	 * completion until the timeout expires.
+	 */
+	job_enqueued = false;
+	job_dequeued = false;
+	do {
+		if (!job_enqueued) {
+			req->timeout = plt_tsc_cycles() + ML_CN10K_CMD_TIMEOUT * plt_tsc_hz();
+			job_enqueued = roc_ml_scratch_enqueue(&mldev->roc, &req->jd);
+		}
+
+		if (job_enqueued && !job_dequeued)
+			job_dequeued = roc_ml_scratch_dequeue(&mldev->roc, &req->jd);
+
+		if (job_dequeued)
+			break;
+	} while (plt_tsc_cycles() < req->timeout);
+
+	if (job_dequeued) {
+		if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+			if (req->result.error_code == 0)
+				ret = 0;
+			else
+				ret = -1;
+		}
+	} else { /* Reset scratch registers */
+		roc_ml_scratch_queue_reset(&mldev->roc);
+		ret = -ETIME;
+	}
+
+	/* Publish the final model state based on the job result */
+	locked = false;
+	while (!locked) {
+		if (plt_spinlock_trylock(&model->lock) != 0) {
+			if (ret == 0)
+				model->state = ML_CN10K_MODEL_STATE_STARTED;
+			else
+				model->state = ML_CN10K_MODEL_STATE_UNKNOWN;
+
+			plt_spinlock_unlock(&model->lock);
+			locked = true;
+		}
+	}
+
+	/* On failure, release the OCM pages reserved above */
+	if (model->state == ML_CN10K_MODEL_STATE_UNKNOWN) {
+		while (model->model_mem_map.ocm_reserved) {
+			if (plt_spinlock_trylock(&ocm->lock) != 0) {
+				cn10k_ml_ocm_free_pages(dev, model->model_id);
+				model->model_mem_map.ocm_reserved = false;
+				model->model_mem_map.tilemask = 0x0;
+				plt_spinlock_unlock(&ocm->lock);
+			}
+		}
+	}
+
+	return ret;
+}
+
struct rte_ml_dev_ops cn10k_ml_ops = {
/* Device control ops */
.dev_info_get = cn10k_ml_dev_info_get,
@@ -575,4 +782,5 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
/* Model ops */
.model_load = cn10k_ml_model_load,
.model_unload = cn10k_ml_model_unload,
+ .model_start = cn10k_ml_model_start,
};
@@ -25,6 +25,9 @@ struct cn10k_ml_req {
/* Job command */
struct ml_job_cmd_s jcmd;
+
+ /* Request timeout cycle */
+ uint64_t timeout;
} __rte_aligned(ROC_ALIGN);
/* ML request queue */
@@ -61,5 +64,6 @@ extern struct rte_ml_dev_ops cn10k_ml_ops;
int cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
int16_t *model_id);
int cn10k_ml_model_unload(struct rte_ml_dev *dev, int16_t model_id);
+int cn10k_ml_model_start(struct rte_ml_dev *dev, int16_t model_id);
#endif /* _CN10K_ML_OPS_H_ */