[v5,23/39] ml/cnxk: enable quantization and dequantization

Message ID 20230207160719.1307-24-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series Implementation of ML CNXK driver |

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Feb. 7, 2023, 4:07 p.m. UTC
  Implemented driver functions to quantize / dequantize input
and output data. Support is enabled for multiple batches.
Quantization / dequantization use the type conversion functions
defined in ML common code.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 151 +++++++++++++++++++++++++++++++++
 1 file changed, 151 insertions(+)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index b5c89bee40..231c9b340b 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -5,6 +5,8 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include <mldev_utils.h>
+
 #include "cn10k_ml_dev.h"
 #include "cn10k_ml_model.h"
 #include "cn10k_ml_ops.h"
@@ -983,6 +985,153 @@  cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t
 	return 0;
 }
 
+static int
+cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
+		     void *qbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_dbuffer;
+	uint8_t *lcl_qbuffer;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	lcl_dbuffer = dbuffer;
+	lcl_qbuffer = qbuffer;
+	batch_id = 0;
+
+next_batch:
+	for (i = 0; i < model->metadata.model.num_input; i++) {
+		if (model->metadata.input[i].input_type ==
+		    model->metadata.input[i].model_input_type) {
+			rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d);
+		} else {
+			switch (model->metadata.input[i].model_input_type) {
+			case RTE_ML_IO_TYPE_INT8:
+				ret = rte_ml_io_float32_to_int8(model->metadata.input[i].qscale,
+								model->addr.input[i].nb_elements,
+								lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT8:
+				ret = rte_ml_io_float32_to_uint8(model->metadata.input[i].qscale,
+								 model->addr.input[i].nb_elements,
+								 lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_INT16:
+				ret = rte_ml_io_float32_to_int16(model->metadata.input[i].qscale,
+								 model->addr.input[i].nb_elements,
+								 lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT16:
+				ret = rte_ml_io_float32_to_uint16(model->metadata.input[i].qscale,
+								  model->addr.input[i].nb_elements,
+								  lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_FP16:
+				ret = rte_ml_io_float32_to_float16(model->addr.input[i].nb_elements,
+								   lcl_dbuffer, lcl_qbuffer);
+				break;
+			default:
+				plt_err("Unsupported model_input_type[%u] : %u", i,
+					model->metadata.input[i].model_input_type);
+				ret = -ENOTSUP;
+			}
+			if (ret < 0)
+				return ret;
+		}
+
+		lcl_dbuffer += model->addr.input[i].sz_d;
+		lcl_qbuffer += model->addr.input[i].sz_q;
+	}
+
+	batch_id++;
+	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
+		goto next_batch;
+
+	return 0;
+}
+
+static int
+cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
+		       void *qbuffer, void *dbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_qbuffer;
+	uint8_t *lcl_dbuffer;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	lcl_dbuffer = dbuffer;
+	lcl_qbuffer = qbuffer;
+	batch_id = 0;
+
+next_batch:
+	for (i = 0; i < model->metadata.model.num_output; i++) {
+		if (model->metadata.output[i].output_type ==
+		    model->metadata.output[i].model_output_type) {
+			rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q);
+		} else {
+			switch (model->metadata.output[i].model_output_type) {
+			case RTE_ML_IO_TYPE_INT8:
+				ret = rte_ml_io_int8_to_float32(model->metadata.output[i].dscale,
+								model->addr.output[i].nb_elements,
+								lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT8:
+				ret = rte_ml_io_uint8_to_float32(model->metadata.output[i].dscale,
+								 model->addr.output[i].nb_elements,
+								 lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_INT16:
+				ret = rte_ml_io_int16_to_float32(model->metadata.output[i].dscale,
+								 model->addr.output[i].nb_elements,
+								 lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT16:
+				ret = rte_ml_io_uint16_to_float32(model->metadata.output[i].dscale,
+								  model->addr.output[i].nb_elements,
+								  lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_FP16:
+				ret = rte_ml_io_float16_to_float32(
+					model->addr.output[i].nb_elements, lcl_qbuffer,
+					lcl_dbuffer);
+				break;
+			default:
+				plt_err("Unsupported model_output_type[%u] : %u", i,
+					model->metadata.output[i].model_output_type);
+				ret = -ENOTSUP;
+			}
+			if (ret < 0)
+				return ret;
+		}
+
+		lcl_qbuffer += model->addr.output[i].sz_q;
+		lcl_dbuffer += model->addr.output[i].sz_d;
+	}
+
+	batch_id++;
+	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
+		goto next_batch;
+
+	return 0;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cn10k_ml_dev_info_get,
@@ -1006,4 +1155,6 @@  struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* I/O ops */
 	.io_input_size_get = cn10k_ml_io_input_size_get,
 	.io_output_size_get = cn10k_ml_io_output_size_get,
+	.io_quantize = cn10k_ml_io_quantize,
+	.io_dequantize = cn10k_ml_io_dequantize,
 };