Added driver functions to get input and output buffer sizes
for a given batch size. This function would compute the buffer
size based on specific requirements of the device.
Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
drivers/ml/cnxk/cn10k_ml_ops.c | 52 ++++++++++++++++++++++++++++++++++
1 file changed, 52 insertions(+)
@@ -935,6 +935,54 @@ cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *bu
return 0;
}
+static int
+cn10k_ml_io_input_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+ uint64_t *input_qsize, uint64_t *input_dsize)
+{
+ struct cn10k_ml_model *model;
+
+ model = dev->data->models[model_id];
+
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ if (input_qsize != NULL)
+ *input_qsize = PLT_U64_CAST(model->addr.total_input_sz_q *
+ PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+ if (input_dsize != NULL)
+ *input_dsize = PLT_U64_CAST(model->addr.total_input_sz_d *
+ PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+ return 0;
+}
+
+static int
+cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+ uint64_t *output_qsize, uint64_t *output_dsize)
+{
+ struct cn10k_ml_model *model;
+
+ model = dev->data->models[model_id];
+
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ if (output_qsize != NULL)
+ *output_qsize = PLT_U64_CAST(model->addr.total_output_sz_q *
+ PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+ if (output_dsize != NULL)
+ *output_dsize = PLT_U64_CAST(model->addr.total_output_sz_d *
+ PLT_DIV_CEIL(nb_batches, model->batch_size));
+
+ return 0;
+}
+
struct rte_ml_dev_ops cn10k_ml_ops = {
/* Device control ops */
.dev_info_get = cn10k_ml_dev_info_get,
@@ -954,4 +1002,8 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
.model_stop = cn10k_ml_model_stop,
.model_info_get = cn10k_ml_model_info_get,
.model_params_update = cn10k_ml_model_params_update,
+
+ /* I/O ops */
+ .io_input_size_get = cn10k_ml_io_input_size_get,
+ .io_output_size_get = cn10k_ml_io_output_size_get,
};