[v5,22/39] ml/cnxk: add support to get IO buffer sizes

Message ID 20230207160719.1307-23-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series Implementation of ML CNXK driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Feb. 7, 2023, 4:07 p.m. UTC
  Added driver functions to get input and output buffer sizes
for a given batch size. This function would compute the buffer
size based on specific requirements of the device.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 52 ++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 92bf1a0854..b5c89bee40 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -935,6 +935,54 @@  cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *bu
 	return 0;
 }
 
+static int
+cn10k_ml_io_input_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+			   uint64_t *input_qsize, uint64_t *input_dsize)
+{
+	struct cn10k_ml_model *model;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	if (input_qsize != NULL)
+		*input_qsize = PLT_U64_CAST(model->addr.total_input_sz_q *
+					    PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+	if (input_dsize != NULL)
+		*input_dsize = PLT_U64_CAST(model->addr.total_input_sz_d *
+					    PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+	return 0;
+}
+
+static int
+cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+			    uint64_t *output_qsize, uint64_t *output_dsize)
+{
+	struct cn10k_ml_model *model;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	if (output_qsize != NULL)
+		*output_qsize = PLT_U64_CAST(model->addr.total_output_sz_q *
+					     PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+	if (output_dsize != NULL)
+		*output_dsize = PLT_U64_CAST(model->addr.total_output_sz_d *
+					     PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+	return 0;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cn10k_ml_dev_info_get,
@@ -954,4 +1002,8 @@  struct rte_ml_dev_ops cn10k_ml_ops = {
 	.model_stop = cn10k_ml_model_stop,
 	.model_info_get = cn10k_ml_model_info_get,
 	.model_params_update = cn10k_ml_model_params_update,
+
+	/* I/O ops */
+	.io_input_size_get = cn10k_ml_io_input_size_get,
+	.io_output_size_get = cn10k_ml_io_output_size_get,
 };