@@ -101,7 +101,7 @@ qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
snprintf(name, size, "cnxk_ml_qp_mem_%u:%u", dev_id, qp_id);
}
-static int
+int
cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp)
{
const struct rte_memzone *qp_mem;
@@ -861,20 +861,12 @@ cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
}
int
-cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_dev_info *dev_info)
{
struct cn10k_ml_dev *cn10k_mldev;
- struct cnxk_ml_dev *cnxk_mldev;
- if (dev_info == NULL)
- return -EINVAL;
-
- cnxk_mldev = dev->data->dev_private;
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
- memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
- dev_info->driver_name = dev->device->driver->name;
- dev_info->max_models = ML_CNXK_MAX_MODELS;
if (cn10k_mldev->hw_queue_lock)
dev_info->max_queue_pairs = ML_CN10K_MAX_QP_PER_DEVICE_SL;
else
@@ -889,143 +881,17 @@ cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
}
int
-cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf)
+cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct rte_ml_dev_config *conf)
{
- struct rte_ml_dev_info dev_info;
struct cn10k_ml_dev *cn10k_mldev;
- struct cnxk_ml_dev *cnxk_mldev;
- struct cnxk_ml_model *model;
struct cn10k_ml_ocm *ocm;
- struct cnxk_ml_qp *qp;
- uint16_t model_id;
- uint32_t mz_size;
uint16_t tile_id;
- uint16_t qp_id;
int ret;
- if (dev == NULL || conf == NULL)
- return -EINVAL;
+ RTE_SET_USED(conf);
- /* Get CN10K device handle */
- cnxk_mldev = dev->data->dev_private;
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
- cn10k_ml_dev_info_get(dev, &dev_info);
- if (conf->nb_models > dev_info.max_models) {
- plt_err("Invalid device config, nb_models > %u\n", dev_info.max_models);
- return -EINVAL;
- }
-
- if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
- plt_err("Invalid device config, nb_queue_pairs > %u\n", dev_info.max_queue_pairs);
- return -EINVAL;
- }
-
- if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
- plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, nb_models = %u",
- conf->nb_queue_pairs, conf->nb_models);
-
- /* Load firmware */
- ret = cn10k_ml_fw_load(cnxk_mldev);
- if (ret != 0)
- return ret;
- } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
- plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, nb_models = %u",
- conf->nb_queue_pairs, conf->nb_models);
- } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
- plt_err("Device can't be reconfigured in started state\n");
- return -ENOTSUP;
- } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
- plt_err("Device can't be reconfigured after close\n");
- return -ENOTSUP;
- }
-
- /* Configure queue-pairs */
- if (dev->data->queue_pairs == NULL) {
- mz_size = sizeof(dev->data->queue_pairs[0]) * conf->nb_queue_pairs;
- dev->data->queue_pairs =
- rte_zmalloc("cn10k_mldev_queue_pairs", mz_size, RTE_CACHE_LINE_SIZE);
- if (dev->data->queue_pairs == NULL) {
- dev->data->nb_queue_pairs = 0;
- plt_err("Failed to get memory for queue_pairs, nb_queue_pairs %u",
- conf->nb_queue_pairs);
- return -ENOMEM;
- }
- } else { /* Re-configure */
- void **queue_pairs;
-
- /* Release all queue pairs as ML spec doesn't support queue_pair_destroy. */
- for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
- qp = dev->data->queue_pairs[qp_id];
- if (qp != NULL) {
- ret = cn10k_ml_dev_queue_pair_release(dev, qp_id);
- if (ret < 0)
- return ret;
- }
- }
-
- queue_pairs = dev->data->queue_pairs;
- queue_pairs =
- rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * conf->nb_queue_pairs,
- RTE_CACHE_LINE_SIZE);
- if (queue_pairs == NULL) {
- dev->data->nb_queue_pairs = 0;
- plt_err("Failed to realloc queue_pairs, nb_queue_pairs = %u",
- conf->nb_queue_pairs);
- ret = -ENOMEM;
- goto error;
- }
-
- memset(queue_pairs, 0, sizeof(queue_pairs[0]) * conf->nb_queue_pairs);
- dev->data->queue_pairs = queue_pairs;
- }
- dev->data->nb_queue_pairs = conf->nb_queue_pairs;
-
- /* Allocate ML models */
- if (dev->data->models == NULL) {
- mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
- dev->data->models = rte_zmalloc("cn10k_mldev_models", mz_size, RTE_CACHE_LINE_SIZE);
- if (dev->data->models == NULL) {
- dev->data->nb_models = 0;
- plt_err("Failed to get memory for ml_models, nb_models %u",
- conf->nb_models);
- ret = -ENOMEM;
- goto error;
- }
- } else {
- /* Re-configure */
- void **models;
-
- /* Stop and unload all models */
- for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
- model = dev->data->models[model_id];
- if (model != NULL) {
- if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
- if (cn10k_ml_model_stop(dev, model_id) != 0)
- plt_err("Could not stop model %u", model_id);
- }
- if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
- if (cn10k_ml_model_unload(dev, model_id) != 0)
- plt_err("Could not unload model %u", model_id);
- }
- dev->data->models[model_id] = NULL;
- }
- }
-
- models = dev->data->models;
- models = rte_realloc(models, sizeof(models[0]) * conf->nb_models,
- RTE_CACHE_LINE_SIZE);
- if (models == NULL) {
- dev->data->nb_models = 0;
- plt_err("Failed to realloc ml_models, nb_models = %u", conf->nb_models);
- ret = -ENOMEM;
- goto error;
- }
- memset(models, 0, sizeof(models[0]) * conf->nb_models);
- dev->data->models = models;
- }
- dev->data->nb_models = conf->nb_models;
-
ocm = &cn10k_mldev->ocm;
ocm->num_tiles = ML_CN10K_OCM_NUMTILES;
ocm->size_per_tile = ML_CN10K_OCM_TILESIZE;
@@ -1038,8 +904,7 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
rte_zmalloc("ocm_mask", ocm->mask_words * ocm->num_tiles, RTE_CACHE_LINE_SIZE);
if (ocm->ocm_mask == NULL) {
plt_err("Unable to allocate memory for OCM mask");
- ret = -ENOMEM;
- goto error;
+ return -ENOMEM;
}
for (tile_id = 0; tile_id < ocm->num_tiles; tile_id++) {
@@ -1050,10 +915,10 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
rte_spinlock_init(&ocm->lock);
/* Initialize xstats */
- ret = cn10k_ml_xstats_init(dev);
+ ret = cn10k_ml_xstats_init(cnxk_mldev->mldev);
if (ret != 0) {
plt_err("Failed to initialize xstats");
- goto error;
+ return ret;
}
/* Set JCMDQ enqueue function */
@@ -1067,77 +932,25 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
cn10k_mldev->set_poll_ptr = cn10k_ml_set_poll_ptr;
cn10k_mldev->get_poll_ptr = cn10k_ml_get_poll_ptr;
- dev->enqueue_burst = cn10k_ml_enqueue_burst;
- dev->dequeue_burst = cn10k_ml_dequeue_burst;
- dev->op_error_get = cn10k_ml_op_error_get;
-
- cnxk_mldev->nb_models_loaded = 0;
- cnxk_mldev->nb_models_started = 0;
- cnxk_mldev->nb_models_stopped = 0;
- cnxk_mldev->nb_models_unloaded = 0;
- cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+ cnxk_mldev->mldev->enqueue_burst = cn10k_ml_enqueue_burst;
+ cnxk_mldev->mldev->dequeue_burst = cn10k_ml_dequeue_burst;
+ cnxk_mldev->mldev->op_error_get = cn10k_ml_op_error_get;
return 0;
-
-error:
- rte_free(dev->data->queue_pairs);
-
- rte_free(dev->data->models);
-
- return ret;
}
int
-cn10k_ml_dev_close(struct rte_ml_dev *dev)
+cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
{
struct cn10k_ml_dev *cn10k_mldev;
- struct cnxk_ml_dev *cnxk_mldev;
- struct cnxk_ml_model *model;
- struct cnxk_ml_qp *qp;
- uint16_t model_id;
- uint16_t qp_id;
- if (dev == NULL)
- return -EINVAL;
-
- cnxk_mldev = dev->data->dev_private;
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
/* Release ocm_mask memory */
rte_free(cn10k_mldev->ocm.ocm_mask);
- /* Stop and unload all models */
- for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
- model = dev->data->models[model_id];
- if (model != NULL) {
- if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
- if (cn10k_ml_model_stop(dev, model_id) != 0)
- plt_err("Could not stop model %u", model_id);
- }
- if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
- if (cn10k_ml_model_unload(dev, model_id) != 0)
- plt_err("Could not unload model %u", model_id);
- }
- dev->data->models[model_id] = NULL;
- }
- }
-
- rte_free(dev->data->models);
-
- /* Destroy all queue pairs */
- for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
- qp = dev->data->queue_pairs[qp_id];
- if (qp != NULL) {
- if (cnxk_ml_qp_destroy(dev, qp) != 0)
- plt_err("Could not destroy queue pair %u", qp_id);
- dev->data->queue_pairs[qp_id] = NULL;
- }
- }
-
- rte_free(dev->data->queue_pairs);
-
/* Un-initialize xstats */
- cn10k_ml_xstats_uninit(dev);
+ cn10k_ml_xstats_uninit(cnxk_mldev->mldev);
/* Unload firmware */
cn10k_ml_fw_unload(cnxk_mldev);
@@ -1154,20 +967,15 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
roc_ml_reg_write64(&cn10k_mldev->roc, 0, ML_MLR_BASE);
plt_ml_dbg("ML_MLR_BASE = 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_MLR_BASE));
- cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
-
- /* Remove PCI device */
- return rte_dev_remove(dev->device);
+ return 0;
}
int
-cn10k_ml_dev_start(struct rte_ml_dev *dev)
+cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev)
{
struct cn10k_ml_dev *cn10k_mldev;
- struct cnxk_ml_dev *cnxk_mldev;
uint64_t reg_val64;
- cnxk_mldev = dev->data->dev_private;
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1175,19 +983,15 @@ cn10k_ml_dev_start(struct rte_ml_dev *dev)
roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG));
- cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
-
return 0;
}
int
-cn10k_ml_dev_stop(struct rte_ml_dev *dev)
+cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev)
{
struct cn10k_ml_dev *cn10k_mldev;
- struct cnxk_ml_dev *cnxk_mldev;
uint64_t reg_val64;
- cnxk_mldev = dev->data->dev_private;
cn10k_mldev = &cnxk_mldev->cn10k_mldev;
reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1195,8 +999,6 @@ cn10k_ml_dev_stop(struct rte_ml_dev *dev)
roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG));
- cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
-
return 0;
}
@@ -1217,7 +1019,7 @@ cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
if (dev->data->queue_pairs[queue_pair_id] != NULL)
cn10k_ml_dev_queue_pair_release(dev, queue_pair_id);
- cn10k_ml_dev_info_get(dev, &dev_info);
+ cnxk_ml_dev_info_get(dev, &dev_info);
if ((qp_conf->nb_desc > dev_info.max_desc) || (qp_conf->nb_desc == 0)) {
plt_err("Could not setup queue pair for %u descriptors", qp_conf->nb_desc);
return -EINVAL;
@@ -10,6 +10,9 @@
#include <roc_api.h>
+struct cnxk_ml_dev;
+struct cnxk_ml_qp;
+
/* Firmware version string length */
#define MLDEV_FIRMWARE_VERSION_LENGTH 32
@@ -286,11 +289,11 @@ struct cn10k_ml_req {
};
/* Device ops */
-int cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info);
-int cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf);
-int cn10k_ml_dev_close(struct rte_ml_dev *dev);
-int cn10k_ml_dev_start(struct rte_ml_dev *dev);
-int cn10k_ml_dev_stop(struct rte_ml_dev *dev);
+int cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_dev_info *dev_info);
+int cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct rte_ml_dev_config *conf);
+int cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev);
int cn10k_ml_dev_dump(struct rte_ml_dev *dev, FILE *fp);
int cn10k_ml_dev_selftest(struct rte_ml_dev *dev);
int cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
@@ -336,4 +339,7 @@ __rte_hot int cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op
struct rte_ml_op_error *error);
__rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op);
+/* Temporarily set below functions as non-static */
+int cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp);
+
#endif /* _CN10K_ML_OPS_H_ */
@@ -53,6 +53,9 @@ struct cnxk_ml_dev {
/* CN10K device structure */
struct cn10k_ml_dev cn10k_mldev;
+
+ /* Maximum number of layers */
+ uint64_t max_nb_layers;
};
#endif /* _CNXK_ML_DEV_H_ */
@@ -5,15 +5,291 @@
#include <rte_mldev.h>
#include <rte_mldev_pmd.h>
+#include "cnxk_ml_dev.h"
+#include "cnxk_ml_io.h"
+#include "cnxk_ml_model.h"
#include "cnxk_ml_ops.h"
+int
+cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+{
+ struct cnxk_ml_dev *cnxk_mldev;
+
+ if (dev == NULL || dev_info == NULL)
+ return -EINVAL;
+
+ cnxk_mldev = dev->data->dev_private;
+
+ memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
+ dev_info->driver_name = dev->device->driver->name;
+ dev_info->max_models = ML_CNXK_MAX_MODELS;
+
+ return cn10k_ml_dev_info_get(cnxk_mldev, dev_info);
+}
+
+static int
+cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf)
+{
+ struct rte_ml_dev_info dev_info;
+ struct cnxk_ml_dev *cnxk_mldev;
+ struct cnxk_ml_model *model;
+ struct cnxk_ml_qp *qp;
+ uint16_t model_id;
+ uint32_t mz_size;
+ uint16_t qp_id;
+ int ret;
+
+ if (dev == NULL)
+ return -EINVAL;
+
+ /* Get CNXK device handle */
+ cnxk_mldev = dev->data->dev_private;
+
+ cnxk_ml_dev_info_get(dev, &dev_info);
+ if (conf->nb_models > dev_info.max_models) {
+ plt_err("Invalid device config, nb_models > %u\n", dev_info.max_models);
+ return -EINVAL;
+ }
+
+ if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
+ plt_err("Invalid device config, nb_queue_pairs > %u\n", dev_info.max_queue_pairs);
+ return -EINVAL;
+ }
+
+ if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
+ plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+ conf->nb_queue_pairs, conf->nb_models);
+
+ /* Load firmware */
+ ret = cn10k_ml_fw_load(cnxk_mldev);
+ if (ret != 0)
+ return ret;
+ } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
+ plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+ conf->nb_queue_pairs, conf->nb_models);
+ } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
+ plt_err("Device can't be reconfigured in started state\n");
+ return -ENOTSUP;
+ } else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
+ plt_err("Device can't be reconfigured after close\n");
+ return -ENOTSUP;
+ }
+
+ /* Configure queue-pairs */
+ if (dev->data->queue_pairs == NULL) {
+ mz_size = sizeof(dev->data->queue_pairs[0]) * conf->nb_queue_pairs;
+ dev->data->queue_pairs =
+ rte_zmalloc("cnxk_mldev_queue_pairs", mz_size, RTE_CACHE_LINE_SIZE);
+ if (dev->data->queue_pairs == NULL) {
+ dev->data->nb_queue_pairs = 0;
+ plt_err("Failed to get memory for queue_pairs, nb_queue_pairs %u",
+ conf->nb_queue_pairs);
+ return -ENOMEM;
+ }
+ } else { /* Re-configure */
+ void **queue_pairs;
+
+ /* Release all queue pairs as ML spec doesn't support queue_pair_destroy. */
+ for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+ qp = dev->data->queue_pairs[qp_id];
+ if (qp != NULL) {
+ ret = cn10k_ml_dev_queue_pair_release(dev, qp_id);
+ if (ret < 0)
+ return ret;
+ }
+ }
+
+ queue_pairs = dev->data->queue_pairs;
+ queue_pairs =
+ rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * conf->nb_queue_pairs,
+ RTE_CACHE_LINE_SIZE);
+ if (queue_pairs == NULL) {
+ dev->data->nb_queue_pairs = 0;
+ plt_err("Failed to realloc queue_pairs, nb_queue_pairs = %u",
+ conf->nb_queue_pairs);
+ ret = -ENOMEM;
+ goto error;
+ }
+
+ memset(queue_pairs, 0, sizeof(queue_pairs[0]) * conf->nb_queue_pairs);
+ dev->data->queue_pairs = queue_pairs;
+ }
+ dev->data->nb_queue_pairs = conf->nb_queue_pairs;
+
+ /* Allocate ML models */
+ if (dev->data->models == NULL) {
+ mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
+ dev->data->models = rte_zmalloc("cnxk_mldev_models", mz_size, RTE_CACHE_LINE_SIZE);
+ if (dev->data->models == NULL) {
+ dev->data->nb_models = 0;
+ plt_err("Failed to get memory for ml_models, nb_models %u",
+ conf->nb_models);
+ ret = -ENOMEM;
+ goto error;
+ }
+ } else {
+ /* Re-configure */
+ void **models;
+
+ /* Stop and unload all models */
+ for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+ model = dev->data->models[model_id];
+ if (model != NULL) {
+ if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+ if (cn10k_ml_model_stop(dev, model_id) != 0)
+ plt_err("Could not stop model %u", model_id);
+ }
+ if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+ if (cn10k_ml_model_unload(dev, model_id) != 0)
+ plt_err("Could not unload model %u", model_id);
+ }
+ dev->data->models[model_id] = NULL;
+ }
+ }
+
+ models = dev->data->models;
+ models = rte_realloc(models, sizeof(models[0]) * conf->nb_models,
+ RTE_CACHE_LINE_SIZE);
+ if (models == NULL) {
+ dev->data->nb_models = 0;
+ plt_err("Failed to realloc ml_models, nb_models = %u", conf->nb_models);
+ ret = -ENOMEM;
+ goto error;
+ }
+ memset(models, 0, sizeof(models[0]) * conf->nb_models);
+ dev->data->models = models;
+ }
+ dev->data->nb_models = conf->nb_models;
+
+ ret = cn10k_ml_dev_configure(cnxk_mldev, conf);
+ if (ret != 0) {
+ plt_err("Failed to configure CN10K ML Device");
+ goto error;
+ }
+
+ /* Set device capabilities */
+ cnxk_mldev->max_nb_layers =
+ cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_models;
+
+ cnxk_mldev->nb_models_loaded = 0;
+ cnxk_mldev->nb_models_started = 0;
+ cnxk_mldev->nb_models_stopped = 0;
+ cnxk_mldev->nb_models_unloaded = 0;
+ cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+ return 0;
+
+error:
+ rte_free(dev->data->queue_pairs);
+ rte_free(dev->data->models);
+
+ return ret;
+}
+
+static int
+cnxk_ml_dev_close(struct rte_ml_dev *dev)
+{
+ struct cnxk_ml_dev *cnxk_mldev;
+ struct cnxk_ml_model *model;
+ struct cnxk_ml_qp *qp;
+ uint16_t model_id;
+ uint16_t qp_id;
+
+ if (dev == NULL)
+ return -EINVAL;
+
+ cnxk_mldev = dev->data->dev_private;
+
+ if (cn10k_ml_dev_close(cnxk_mldev) != 0)
+ plt_err("Failed to close CN10K ML Device");
+
+ /* Stop and unload all models */
+ for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+ model = dev->data->models[model_id];
+ if (model != NULL) {
+ if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+ if (cn10k_ml_model_stop(dev, model_id) != 0)
+ plt_err("Could not stop model %u", model_id);
+ }
+ if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+ if (cn10k_ml_model_unload(dev, model_id) != 0)
+ plt_err("Could not unload model %u", model_id);
+ }
+ dev->data->models[model_id] = NULL;
+ }
+ }
+
+ rte_free(dev->data->models);
+
+ /* Destroy all queue pairs */
+ for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+ qp = dev->data->queue_pairs[qp_id];
+ if (qp != NULL) {
+ if (cnxk_ml_qp_destroy(dev, qp) != 0)
+ plt_err("Could not destroy queue pair %u", qp_id);
+ dev->data->queue_pairs[qp_id] = NULL;
+ }
+ }
+
+ rte_free(dev->data->queue_pairs);
+
+ cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
+
+ /* Remove PCI device */
+ return rte_dev_remove(dev->device);
+}
+
+static int
+cnxk_ml_dev_start(struct rte_ml_dev *dev)
+{
+ struct cnxk_ml_dev *cnxk_mldev;
+ int ret;
+
+ if (dev == NULL)
+ return -EINVAL;
+
+ cnxk_mldev = dev->data->dev_private;
+
+ ret = cn10k_ml_dev_start(cnxk_mldev);
+ if (ret != 0) {
+ plt_err("Failed to start CN10K ML Device");
+ return ret;
+ }
+
+ cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
+
+ return 0;
+}
+
+static int
+cnxk_ml_dev_stop(struct rte_ml_dev *dev)
+{
+ struct cnxk_ml_dev *cnxk_mldev;
+ int ret;
+
+ if (dev == NULL)
+ return -EINVAL;
+
+ cnxk_mldev = dev->data->dev_private;
+
+ ret = cn10k_ml_dev_stop(cnxk_mldev);
+ if (ret != 0) {
+ plt_err("Failed to stop CN10K ML Device");
+ return ret;
+ }
+
+ cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+ return 0;
+}
+
struct rte_ml_dev_ops cnxk_ml_ops = {
/* Device control ops */
- .dev_info_get = cn10k_ml_dev_info_get,
- .dev_configure = cn10k_ml_dev_configure,
- .dev_close = cn10k_ml_dev_close,
- .dev_start = cn10k_ml_dev_start,
- .dev_stop = cn10k_ml_dev_stop,
+ .dev_info_get = cnxk_ml_dev_info_get,
+ .dev_configure = cnxk_ml_dev_configure,
+ .dev_close = cnxk_ml_dev_close,
+ .dev_start = cnxk_ml_dev_start,
+ .dev_stop = cnxk_ml_dev_stop,
.dev_dump = cn10k_ml_dev_dump,
.dev_selftest = cn10k_ml_dev_selftest,
@@ -62,4 +62,7 @@ struct cnxk_ml_qp {
extern struct rte_ml_dev_ops cnxk_ml_ops;
+/* Temporarily set cnxk driver functions as non-static */
+int cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info);
+
#endif /* _CNXK_ML_OPS_H_ */