@@ -20,10 +20,12 @@
#define CN10K_ML_FW_PATH "fw_path"
#define CN10K_ML_FW_ENABLE_DPE_WARNINGS "enable_dpe_warnings"
#define CN10K_ML_FW_REPORT_DPE_WARNINGS "report_dpe_warnings"
+#define CN10K_ML_DEV_CACHE_MODEL_DATA "cache_model_data"
#define CN10K_ML_FW_PATH_DEFAULT "/lib/firmware/mlip-fw.bin"
#define CN10K_ML_FW_ENABLE_DPE_WARNINGS_DEFAULT 1
#define CN10K_ML_FW_REPORT_DPE_WARNINGS_DEFAULT 0
+#define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT 1
/* ML firmware macros */
#define FW_MEMZONE_NAME "ml_cn10k_fw_mz"
@@ -38,7 +40,8 @@
#define FW_REPORT_DPE_WARNING_BITMASK BIT(1)
static const char *const valid_args[] = {CN10K_ML_FW_PATH, CN10K_ML_FW_ENABLE_DPE_WARNINGS,
- CN10K_ML_FW_REPORT_DPE_WARNINGS, NULL};
+ CN10K_ML_FW_REPORT_DPE_WARNINGS,
+ CN10K_ML_DEV_CACHE_MODEL_DATA, NULL};
/* Dummy operations for ML device */
struct rte_ml_dev_ops ml_dev_dummy_ops = {0};
@@ -76,6 +79,7 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
{
bool enable_dpe_warnings_set = false;
bool report_dpe_warnings_set = false;
+ bool cache_model_data_set = false;
struct rte_kvargs *kvlist = NULL;
bool fw_path_set = false;
char *fw_path = NULL;
@@ -124,6 +128,18 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
report_dpe_warnings_set = true;
}
+ if (rte_kvargs_count(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA) == 1) {
+ ret = rte_kvargs_process(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA, &parse_integer_arg,
+ &mldev->cache_model_data);
+ if (ret < 0) {
+ plt_err("Error processing arguments, key = %s\n",
+ CN10K_ML_DEV_CACHE_MODEL_DATA);
+ ret = -EINVAL;
+ goto exit;
+ }
+ cache_model_data_set = true;
+ }
+
check_args:
if (!fw_path_set)
mldev->fw.path = CN10K_ML_FW_PATH_DEFAULT;
@@ -155,6 +171,18 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde
}
plt_info("ML: %s = %d", CN10K_ML_FW_REPORT_DPE_WARNINGS, mldev->fw.report_dpe_warnings);
+ if (!cache_model_data_set) {
+ mldev->cache_model_data = CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT;
+ } else {
+ if ((mldev->cache_model_data < 0) || (mldev->cache_model_data > 1)) {
+ plt_err("Invalid argument, %s = %d\n", CN10K_ML_DEV_CACHE_MODEL_DATA,
+ mldev->cache_model_data);
+ ret = -EINVAL;
+ goto exit;
+ }
+ }
+ plt_info("ML: %s = %d", CN10K_ML_DEV_CACHE_MODEL_DATA, mldev->cache_model_data);
+
exit:
if (kvlist)
rte_kvargs_free(kvlist);
@@ -694,4 +722,5 @@ RTE_PMD_REGISTER_KMOD_DEP(MLDEV_NAME_CN10K_PMD, "vfio-pci");
RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD,
CN10K_ML_FW_PATH "=<path>" CN10K_ML_FW_ENABLE_DPE_WARNINGS
- "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS "=<0|1>");
+ "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS
+ "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA "=<0|1>");
@@ -381,6 +381,9 @@ struct cn10k_ml_dev {
/* xstats status */
bool xstats_enabled;
+
+ /* Enable / disable model data caching */
+ int cache_model_data;
};
uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw);
@@ -488,6 +488,77 @@ cn10k_ml_model_xstat_reset(struct rte_ml_dev *dev, uint16_t model_id,
}
}
+/*
+ * Run one dummy synchronous inference on a model.
+ *
+ * Zero-filled input and output buffers sized for model->batch_size are carved
+ * from one temporary memzone, passed to cn10k_ml_inference_sync() and freed.
+ * cn10k_ml_model_start() calls this when the cache_model_data devarg is enabled
+ * on cn10ka; judging by the names, the dummy run gets the model data cached (the
+ * caching mechanism itself is not visible in this function).
+ *
+ * Returns 0 on success, -ENOMEM if the memzone cannot be reserved, -EINVAL if
+ * the model reports no I/O data, otherwise the status of the failing size query
+ * or of cn10k_ml_inference_sync().
+ */
+static int
+cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
+{
+	char str[RTE_MEMZONE_NAMESIZE];
+	const struct plt_memzone *mz;
+	struct cn10k_ml_model *model;
+	struct rte_ml_op op = {0}; /* Fields not set below must not hold stack garbage. */
+	uint64_t isize = 0;
+	uint64_t osize = 0;
+	int ret;
+
+	model = dev->data->models[model_id];
+
+	/* Query the buffer sizes; a failed query must not fall through with length 0. */
+	ret = rte_ml_io_input_size_get(dev->data->dev_id, model_id, model->batch_size,
+				       &isize, NULL);
+	if (ret != 0)
+		return ret;
+
+	ret = rte_ml_io_output_size_get(dev->data->dev_id, model_id, model->batch_size,
+					&osize, NULL);
+	if (ret != 0)
+		return ret;
+
+	/*
+	 * Assumes plt_memzone_reserve_aligned() behaves like rte_memzone_reserve_aligned(),
+	 * where a length of 0 reserves the biggest free zone instead of failing.
+	 */
+	if (isize + osize == 0)
+		return -EINVAL;
+
+	/* Input and output share one memzone: input first, output right behind it. */
+	snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", model_id);
+	mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE);
+	if (mz == NULL)
+		return -ENOMEM;
+	memset(mz->addr, 0, isize + osize);
+
+	op.model_id = model_id;
+	op.nb_batches = model->batch_size;
+	op.mempool = NULL;
+
+	op.input.addr = mz->addr;
+	op.input.length = isize;
+	op.input.next = NULL;
+
+	op.output.addr = PLT_PTR_ADD(op.input.addr, isize);
+	op.output.length = osize;
+	op.output.next = NULL;
+
+	/* Clear the model's cn10k_ml_req before the run; free the memzone whatever the result. */
+	memset(model->req, 0, sizeof(struct cn10k_ml_req));
+	ret = cn10k_ml_inference_sync(dev, &op);
+	plt_memzone_free(mz);
+
+	return ret;
+}
+
static int
cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
{
@@ -1467,6 +1510,13 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
}
}
+ if (ret < 0) { /* Call unload to update model and FW state, ignore error */
+ rte_ml_model_stop(dev->data->dev_id, model_id);
+ } else {
+ if (mldev->cache_model_data && roc_model_is_cn10ka())
+ ret = cn10k_ml_cache_model_data(dev, model_id);
+ }
+
return ret;
}