[v5,07/34] ml/cnxk: update device handling functions

Message ID 20231018064806.24145-8-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Jerin Jacob
Headers
Series Implementation of revised ml/cnxk driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Oct. 18, 2023, 6:47 a.m. UTC
  Implement CNXK wrapper functions for dev_info_get,
dev_configure, dev_close, dev_start and dev_stop. The
wrapper functions allocate / release common resources
for the ML driver and invoke device specific functions.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 230 ++------------------------
 drivers/ml/cnxk/cn10k_ml_ops.h |  16 +-
 drivers/ml/cnxk/cnxk_ml_dev.h  |   3 +
 drivers/ml/cnxk/cnxk_ml_ops.c  | 286 ++++++++++++++++++++++++++++++++-
 drivers/ml/cnxk/cnxk_ml_ops.h  |   3 +
 5 files changed, 314 insertions(+), 224 deletions(-)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 66b38fc1eb..6d8f2c8777 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -101,7 +101,7 @@  qp_memzone_name_get(char *name, int size, int dev_id, int qp_id)
 	snprintf(name, size, "cnxk_ml_qp_mem_%u:%u", dev_id, qp_id);
 }
 
-static int
+int
 cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp)
 {
 	const struct rte_memzone *qp_mem;
@@ -861,20 +861,12 @@  cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
 }
 
 int
-cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_dev_info *dev_info)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 
-	if (dev_info == NULL)
-		return -EINVAL;
-
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
-	memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
-	dev_info->driver_name = dev->device->driver->name;
-	dev_info->max_models = ML_CNXK_MAX_MODELS;
 	if (cn10k_mldev->hw_queue_lock)
 		dev_info->max_queue_pairs = ML_CN10K_MAX_QP_PER_DEVICE_SL;
 	else
@@ -889,143 +881,17 @@  cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
 }
 
 int
-cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf)
+cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct rte_ml_dev_config *conf)
 {
-	struct rte_ml_dev_info dev_info;
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
-	struct cnxk_ml_model *model;
 	struct cn10k_ml_ocm *ocm;
-	struct cnxk_ml_qp *qp;
-	uint16_t model_id;
-	uint32_t mz_size;
 	uint16_t tile_id;
-	uint16_t qp_id;
 	int ret;
 
-	if (dev == NULL || conf == NULL)
-		return -EINVAL;
+	RTE_SET_USED(conf);
 
-	/* Get CN10K device handle */
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
-	cn10k_ml_dev_info_get(dev, &dev_info);
-	if (conf->nb_models > dev_info.max_models) {
-		plt_err("Invalid device config, nb_models > %u\n", dev_info.max_models);
-		return -EINVAL;
-	}
-
-	if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
-		plt_err("Invalid device config, nb_queue_pairs > %u\n", dev_info.max_queue_pairs);
-		return -EINVAL;
-	}
-
-	if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
-		plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, nb_models = %u",
-			   conf->nb_queue_pairs, conf->nb_models);
-
-		/* Load firmware */
-		ret = cn10k_ml_fw_load(cnxk_mldev);
-		if (ret != 0)
-			return ret;
-	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
-		plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, nb_models = %u",
-			   conf->nb_queue_pairs, conf->nb_models);
-	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
-		plt_err("Device can't be reconfigured in started state\n");
-		return -ENOTSUP;
-	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
-		plt_err("Device can't be reconfigured after close\n");
-		return -ENOTSUP;
-	}
-
-	/* Configure queue-pairs */
-	if (dev->data->queue_pairs == NULL) {
-		mz_size = sizeof(dev->data->queue_pairs[0]) * conf->nb_queue_pairs;
-		dev->data->queue_pairs =
-			rte_zmalloc("cn10k_mldev_queue_pairs", mz_size, RTE_CACHE_LINE_SIZE);
-		if (dev->data->queue_pairs == NULL) {
-			dev->data->nb_queue_pairs = 0;
-			plt_err("Failed to get memory for queue_pairs, nb_queue_pairs %u",
-				conf->nb_queue_pairs);
-			return -ENOMEM;
-		}
-	} else { /* Re-configure */
-		void **queue_pairs;
-
-		/* Release all queue pairs as ML spec doesn't support queue_pair_destroy. */
-		for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
-			qp = dev->data->queue_pairs[qp_id];
-			if (qp != NULL) {
-				ret = cn10k_ml_dev_queue_pair_release(dev, qp_id);
-				if (ret < 0)
-					return ret;
-			}
-		}
-
-		queue_pairs = dev->data->queue_pairs;
-		queue_pairs =
-			rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * conf->nb_queue_pairs,
-				    RTE_CACHE_LINE_SIZE);
-		if (queue_pairs == NULL) {
-			dev->data->nb_queue_pairs = 0;
-			plt_err("Failed to realloc queue_pairs, nb_queue_pairs = %u",
-				conf->nb_queue_pairs);
-			ret = -ENOMEM;
-			goto error;
-		}
-
-		memset(queue_pairs, 0, sizeof(queue_pairs[0]) * conf->nb_queue_pairs);
-		dev->data->queue_pairs = queue_pairs;
-	}
-	dev->data->nb_queue_pairs = conf->nb_queue_pairs;
-
-	/* Allocate ML models */
-	if (dev->data->models == NULL) {
-		mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
-		dev->data->models = rte_zmalloc("cn10k_mldev_models", mz_size, RTE_CACHE_LINE_SIZE);
-		if (dev->data->models == NULL) {
-			dev->data->nb_models = 0;
-			plt_err("Failed to get memory for ml_models, nb_models %u",
-				conf->nb_models);
-			ret = -ENOMEM;
-			goto error;
-		}
-	} else {
-		/* Re-configure */
-		void **models;
-
-		/* Stop and unload all models */
-		for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
-			model = dev->data->models[model_id];
-			if (model != NULL) {
-				if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-					if (cn10k_ml_model_stop(dev, model_id) != 0)
-						plt_err("Could not stop model %u", model_id);
-				}
-				if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-					if (cn10k_ml_model_unload(dev, model_id) != 0)
-						plt_err("Could not unload model %u", model_id);
-				}
-				dev->data->models[model_id] = NULL;
-			}
-		}
-
-		models = dev->data->models;
-		models = rte_realloc(models, sizeof(models[0]) * conf->nb_models,
-				     RTE_CACHE_LINE_SIZE);
-		if (models == NULL) {
-			dev->data->nb_models = 0;
-			plt_err("Failed to realloc ml_models, nb_models = %u", conf->nb_models);
-			ret = -ENOMEM;
-			goto error;
-		}
-		memset(models, 0, sizeof(models[0]) * conf->nb_models);
-		dev->data->models = models;
-	}
-	dev->data->nb_models = conf->nb_models;
-
 	ocm = &cn10k_mldev->ocm;
 	ocm->num_tiles = ML_CN10K_OCM_NUMTILES;
 	ocm->size_per_tile = ML_CN10K_OCM_TILESIZE;
@@ -1038,8 +904,7 @@  cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
 		rte_zmalloc("ocm_mask", ocm->mask_words * ocm->num_tiles, RTE_CACHE_LINE_SIZE);
 	if (ocm->ocm_mask == NULL) {
 		plt_err("Unable to allocate memory for OCM mask");
-		ret = -ENOMEM;
-		goto error;
+		return -ENOMEM;
 	}
 
 	for (tile_id = 0; tile_id < ocm->num_tiles; tile_id++) {
@@ -1050,10 +915,10 @@  cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
 	rte_spinlock_init(&ocm->lock);
 
 	/* Initialize xstats */
-	ret = cn10k_ml_xstats_init(dev);
+	ret = cn10k_ml_xstats_init(cnxk_mldev->mldev);
 	if (ret != 0) {
 		plt_err("Failed to initialize xstats");
-		goto error;
+		return ret;
 	}
 
 	/* Set JCMDQ enqueue function */
@@ -1067,77 +932,25 @@  cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *c
 	cn10k_mldev->set_poll_ptr = cn10k_ml_set_poll_ptr;
 	cn10k_mldev->get_poll_ptr = cn10k_ml_get_poll_ptr;
 
-	dev->enqueue_burst = cn10k_ml_enqueue_burst;
-	dev->dequeue_burst = cn10k_ml_dequeue_burst;
-	dev->op_error_get = cn10k_ml_op_error_get;
-
-	cnxk_mldev->nb_models_loaded = 0;
-	cnxk_mldev->nb_models_started = 0;
-	cnxk_mldev->nb_models_stopped = 0;
-	cnxk_mldev->nb_models_unloaded = 0;
-	cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+	cnxk_mldev->mldev->enqueue_burst = cn10k_ml_enqueue_burst;
+	cnxk_mldev->mldev->dequeue_burst = cn10k_ml_dequeue_burst;
+	cnxk_mldev->mldev->op_error_get = cn10k_ml_op_error_get;
 
 	return 0;
-
-error:
-	rte_free(dev->data->queue_pairs);
-
-	rte_free(dev->data->models);
-
-	return ret;
 }
 
 int
-cn10k_ml_dev_close(struct rte_ml_dev *dev)
+cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
-	struct cnxk_ml_model *model;
-	struct cnxk_ml_qp *qp;
-	uint16_t model_id;
-	uint16_t qp_id;
 
-	if (dev == NULL)
-		return -EINVAL;
-
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
 	/* Release ocm_mask memory */
 	rte_free(cn10k_mldev->ocm.ocm_mask);
 
-	/* Stop and unload all models */
-	for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
-		model = dev->data->models[model_id];
-		if (model != NULL) {
-			if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-				if (cn10k_ml_model_stop(dev, model_id) != 0)
-					plt_err("Could not stop model %u", model_id);
-			}
-			if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-				if (cn10k_ml_model_unload(dev, model_id) != 0)
-					plt_err("Could not unload model %u", model_id);
-			}
-			dev->data->models[model_id] = NULL;
-		}
-	}
-
-	rte_free(dev->data->models);
-
-	/* Destroy all queue pairs */
-	for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
-		qp = dev->data->queue_pairs[qp_id];
-		if (qp != NULL) {
-			if (cnxk_ml_qp_destroy(dev, qp) != 0)
-				plt_err("Could not destroy queue pair %u", qp_id);
-			dev->data->queue_pairs[qp_id] = NULL;
-		}
-	}
-
-	rte_free(dev->data->queue_pairs);
-
 	/* Un-initialize xstats */
-	cn10k_ml_xstats_uninit(dev);
+	cn10k_ml_xstats_uninit(cnxk_mldev->mldev);
 
 	/* Unload firmware */
 	cn10k_ml_fw_unload(cnxk_mldev);
@@ -1154,20 +967,15 @@  cn10k_ml_dev_close(struct rte_ml_dev *dev)
 	roc_ml_reg_write64(&cn10k_mldev->roc, 0, ML_MLR_BASE);
 	plt_ml_dbg("ML_MLR_BASE = 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_MLR_BASE));
 
-	cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
-
-	/* Remove PCI device */
-	return rte_dev_remove(dev->device);
+	return 0;
 }
 
 int
-cn10k_ml_dev_start(struct rte_ml_dev *dev)
+cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	uint64_t reg_val64;
 
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
 	reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1175,19 +983,15 @@  cn10k_ml_dev_start(struct rte_ml_dev *dev)
 	roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
 	plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG));
 
-	cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
-
 	return 0;
 }
 
 int
-cn10k_ml_dev_stop(struct rte_ml_dev *dev)
+cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev)
 {
 	struct cn10k_ml_dev *cn10k_mldev;
-	struct cnxk_ml_dev *cnxk_mldev;
 	uint64_t reg_val64;
 
-	cnxk_mldev = dev->data->dev_private;
 	cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
 	reg_val64 = roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG);
@@ -1195,8 +999,6 @@  cn10k_ml_dev_stop(struct rte_ml_dev *dev)
 	roc_ml_reg_write64(&cn10k_mldev->roc, reg_val64, ML_CFG);
 	plt_ml_dbg("ML_CFG => 0x%016lx", roc_ml_reg_read64(&cn10k_mldev->roc, ML_CFG));
 
-	cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
-
 	return 0;
 }
 
@@ -1217,7 +1019,7 @@  cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
 	if (dev->data->queue_pairs[queue_pair_id] != NULL)
 		cn10k_ml_dev_queue_pair_release(dev, queue_pair_id);
 
-	cn10k_ml_dev_info_get(dev, &dev_info);
+	cnxk_ml_dev_info_get(dev, &dev_info);
 	if ((qp_conf->nb_desc > dev_info.max_desc) || (qp_conf->nb_desc == 0)) {
 		plt_err("Could not setup queue pair for %u descriptors", qp_conf->nb_desc);
 		return -EINVAL;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 16480b9ad8..d50b5bede7 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -10,6 +10,9 @@ 
 
 #include <roc_api.h>
 
+struct cnxk_ml_dev;
+struct cnxk_ml_qp;
+
 /* Firmware version string length */
 #define MLDEV_FIRMWARE_VERSION_LENGTH 32
 
@@ -286,11 +289,11 @@  struct cn10k_ml_req {
 };
 
 /* Device ops */
-int cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info);
-int cn10k_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf);
-int cn10k_ml_dev_close(struct rte_ml_dev *dev);
-int cn10k_ml_dev_start(struct rte_ml_dev *dev);
-int cn10k_ml_dev_stop(struct rte_ml_dev *dev);
+int cn10k_ml_dev_info_get(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_dev_info *dev_info);
+int cn10k_ml_dev_configure(struct cnxk_ml_dev *cnxk_mldev, const struct rte_ml_dev_config *conf);
+int cn10k_ml_dev_close(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_start(struct cnxk_ml_dev *cnxk_mldev);
+int cn10k_ml_dev_stop(struct cnxk_ml_dev *cnxk_mldev);
 int cn10k_ml_dev_dump(struct rte_ml_dev *dev, FILE *fp);
 int cn10k_ml_dev_selftest(struct rte_ml_dev *dev);
 int cn10k_ml_dev_queue_pair_setup(struct rte_ml_dev *dev, uint16_t queue_pair_id,
@@ -336,4 +339,7 @@  __rte_hot int cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op
 				    struct rte_ml_op_error *error);
 __rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op);
 
+/* Temporarily set below functions as non-static */
+int cnxk_ml_qp_destroy(const struct rte_ml_dev *dev, struct cnxk_ml_qp *qp);
+
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_dev.h b/drivers/ml/cnxk/cnxk_ml_dev.h
index 51315de622..02605fa28f 100644
--- a/drivers/ml/cnxk/cnxk_ml_dev.h
+++ b/drivers/ml/cnxk/cnxk_ml_dev.h
@@ -53,6 +53,9 @@  struct cnxk_ml_dev {
 
 	/* CN10K device structure */
 	struct cn10k_ml_dev cn10k_mldev;
+
+	/* Maximum number of layers */
+	uint64_t max_nb_layers;
 };
 
 #endif /* _CNXK_ML_DEV_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index 03402681c5..07a4daabc5 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -5,15 +5,291 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include "cnxk_ml_dev.h"
+#include "cnxk_ml_io.h"
+#include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
 
+int
+cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+
+	if (dev == NULL || dev_info == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	memset(dev_info, 0, sizeof(struct rte_ml_dev_info));
+	dev_info->driver_name = dev->device->driver->name;
+	dev_info->max_models = ML_CNXK_MAX_MODELS;
+
+	return cn10k_ml_dev_info_get(cnxk_mldev, dev_info);
+}
+
+static int
+cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *conf)
+{
+	struct rte_ml_dev_info dev_info;
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+	struct cnxk_ml_qp *qp;
+	uint16_t model_id;
+	uint32_t mz_size;
+	uint16_t qp_id;
+	int ret;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	/* Get CNXK device handle */
+	cnxk_mldev = dev->data->dev_private;
+
+	cnxk_ml_dev_info_get(dev, &dev_info);
+	if (conf->nb_models > dev_info.max_models) {
+		plt_err("Invalid device config, nb_models > %u\n", dev_info.max_models);
+		return -EINVAL;
+	}
+
+	if (conf->nb_queue_pairs > dev_info.max_queue_pairs) {
+		plt_err("Invalid device config, nb_queue_pairs > %u\n", dev_info.max_queue_pairs);
+		return -EINVAL;
+	}
+
+	if (cnxk_mldev->state == ML_CNXK_DEV_STATE_PROBED) {
+		plt_ml_dbg("Configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+			   conf->nb_queue_pairs, conf->nb_models);
+
+		/* Load firmware */
+		ret = cn10k_ml_fw_load(cnxk_mldev);
+		if (ret != 0)
+			return ret;
+	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CONFIGURED) {
+		plt_ml_dbg("Re-configuring ML device, nb_queue_pairs = %u, nb_models = %u",
+			   conf->nb_queue_pairs, conf->nb_models);
+	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_STARTED) {
+		plt_err("Device can't be reconfigured in started state\n");
+		return -ENOTSUP;
+	} else if (cnxk_mldev->state == ML_CNXK_DEV_STATE_CLOSED) {
+		plt_err("Device can't be reconfigured after close\n");
+		return -ENOTSUP;
+	}
+
+	/* Configure queue-pairs */
+	if (dev->data->queue_pairs == NULL) {
+		mz_size = sizeof(dev->data->queue_pairs[0]) * conf->nb_queue_pairs;
+		dev->data->queue_pairs =
+			rte_zmalloc("cnxk_mldev_queue_pairs", mz_size, RTE_CACHE_LINE_SIZE);
+		if (dev->data->queue_pairs == NULL) {
+			dev->data->nb_queue_pairs = 0;
+			plt_err("Failed to get memory for queue_pairs, nb_queue_pairs %u",
+				conf->nb_queue_pairs);
+			return -ENOMEM;
+		}
+	} else { /* Re-configure */
+		void **queue_pairs;
+
+		/* Release all queue pairs as ML spec doesn't support queue_pair_destroy. */
+		for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+			qp = dev->data->queue_pairs[qp_id];
+			if (qp != NULL) {
+				ret = cn10k_ml_dev_queue_pair_release(dev, qp_id);
+				if (ret < 0)
+					return ret;
+			}
+		}
+
+		queue_pairs = dev->data->queue_pairs;
+		queue_pairs =
+			rte_realloc(queue_pairs, sizeof(queue_pairs[0]) * conf->nb_queue_pairs,
+				    RTE_CACHE_LINE_SIZE);
+		if (queue_pairs == NULL) {
+			dev->data->nb_queue_pairs = 0;
+			plt_err("Failed to realloc queue_pairs, nb_queue_pairs = %u",
+				conf->nb_queue_pairs);
+			ret = -ENOMEM;
+			goto error;
+		}
+
+		memset(queue_pairs, 0, sizeof(queue_pairs[0]) * conf->nb_queue_pairs);
+		dev->data->queue_pairs = queue_pairs;
+	}
+	dev->data->nb_queue_pairs = conf->nb_queue_pairs;
+
+	/* Allocate ML models */
+	if (dev->data->models == NULL) {
+		mz_size = sizeof(dev->data->models[0]) * conf->nb_models;
+		dev->data->models = rte_zmalloc("cnxk_mldev_models", mz_size, RTE_CACHE_LINE_SIZE);
+		if (dev->data->models == NULL) {
+			dev->data->nb_models = 0;
+			plt_err("Failed to get memory for ml_models, nb_models %u",
+				conf->nb_models);
+			ret = -ENOMEM;
+			goto error;
+		}
+	} else {
+		/* Re-configure */
+		void **models;
+
+		/* Stop and unload all models */
+		for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+			model = dev->data->models[model_id];
+			if (model != NULL) {
+				if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+					if (cn10k_ml_model_stop(dev, model_id) != 0)
+						plt_err("Could not stop model %u", model_id);
+				}
+				if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+					if (cn10k_ml_model_unload(dev, model_id) != 0)
+						plt_err("Could not unload model %u", model_id);
+				}
+				dev->data->models[model_id] = NULL;
+			}
+		}
+
+		models = dev->data->models;
+		models = rte_realloc(models, sizeof(models[0]) * conf->nb_models,
+				     RTE_CACHE_LINE_SIZE);
+		if (models == NULL) {
+			dev->data->nb_models = 0;
+			plt_err("Failed to realloc ml_models, nb_models = %u", conf->nb_models);
+			ret = -ENOMEM;
+			goto error;
+		}
+		memset(models, 0, sizeof(models[0]) * conf->nb_models);
+		dev->data->models = models;
+	}
+	dev->data->nb_models = conf->nb_models;
+
+	ret = cn10k_ml_dev_configure(cnxk_mldev, conf);
+	if (ret != 0) {
+		plt_err("Failed to configure CN10K ML Device");
+		goto error;
+	}
+
+	/* Set device capabilities */
+	cnxk_mldev->max_nb_layers =
+		cnxk_mldev->cn10k_mldev.fw.req->cn10k_req.jd.fw_load.cap.s.max_models;
+
+	cnxk_mldev->nb_models_loaded = 0;
+	cnxk_mldev->nb_models_started = 0;
+	cnxk_mldev->nb_models_stopped = 0;
+	cnxk_mldev->nb_models_unloaded = 0;
+	cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+	return 0;
+
+error:
+	rte_free(dev->data->queue_pairs);
+	rte_free(dev->data->models);
+
+	return ret;
+}
+
+static int
+cnxk_ml_dev_close(struct rte_ml_dev *dev)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	struct cnxk_ml_model *model;
+	struct cnxk_ml_qp *qp;
+	uint16_t model_id;
+	uint16_t qp_id;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	if (cn10k_ml_dev_close(cnxk_mldev) != 0)
+		plt_err("Failed to close CN10K ML Device");
+
+	/* Stop and unload all models */
+	for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
+		model = dev->data->models[model_id];
+		if (model != NULL) {
+			if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
+				if (cn10k_ml_model_stop(dev, model_id) != 0)
+					plt_err("Could not stop model %u", model_id);
+			}
+			if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
+				if (cn10k_ml_model_unload(dev, model_id) != 0)
+					plt_err("Could not unload model %u", model_id);
+			}
+			dev->data->models[model_id] = NULL;
+		}
+	}
+
+	rte_free(dev->data->models);
+
+	/* Destroy all queue pairs */
+	for (qp_id = 0; qp_id < dev->data->nb_queue_pairs; qp_id++) {
+		qp = dev->data->queue_pairs[qp_id];
+		if (qp != NULL) {
+			if (cnxk_ml_qp_destroy(dev, qp) != 0)
+				plt_err("Could not destroy queue pair %u", qp_id);
+			dev->data->queue_pairs[qp_id] = NULL;
+		}
+	}
+
+	rte_free(dev->data->queue_pairs);
+
+	cnxk_mldev->state = ML_CNXK_DEV_STATE_CLOSED;
+
+	/* Remove PCI device */
+	return rte_dev_remove(dev->device);
+}
+
+static int
+cnxk_ml_dev_start(struct rte_ml_dev *dev)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	int ret;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	ret = cn10k_ml_dev_start(cnxk_mldev);
+	if (ret != 0) {
+		plt_err("Failed to start CN10K ML Device");
+		return ret;
+	}
+
+	cnxk_mldev->state = ML_CNXK_DEV_STATE_STARTED;
+
+	return 0;
+}
+
+static int
+cnxk_ml_dev_stop(struct rte_ml_dev *dev)
+{
+	struct cnxk_ml_dev *cnxk_mldev;
+	int ret;
+
+	if (dev == NULL)
+		return -EINVAL;
+
+	cnxk_mldev = dev->data->dev_private;
+
+	ret = cn10k_ml_dev_stop(cnxk_mldev);
+	if (ret != 0) {
+		plt_err("Failed to stop CN10K ML Device");
+		return ret;
+	}
+
+	cnxk_mldev->state = ML_CNXK_DEV_STATE_CONFIGURED;
+
+	return 0;
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
 	/* Device control ops */
-	.dev_info_get = cn10k_ml_dev_info_get,
-	.dev_configure = cn10k_ml_dev_configure,
-	.dev_close = cn10k_ml_dev_close,
-	.dev_start = cn10k_ml_dev_start,
-	.dev_stop = cn10k_ml_dev_stop,
+	.dev_info_get = cnxk_ml_dev_info_get,
+	.dev_configure = cnxk_ml_dev_configure,
+	.dev_close = cnxk_ml_dev_close,
+	.dev_start = cnxk_ml_dev_start,
+	.dev_stop = cnxk_ml_dev_stop,
 	.dev_dump = cn10k_ml_dev_dump,
 	.dev_selftest = cn10k_ml_dev_selftest,
 
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index a925c07580..2996928d7d 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -62,4 +62,7 @@  struct cnxk_ml_qp {
 
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
+/* Temporarily set cnxk driver functions as non-static */
+int cnxk_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info);
+
 #endif /* _CNXK_ML_OPS_H_ */