[v2,22/37] ml/cnxk: enable quantization and dequantization

Message ID 20221208201806.21893-23-syalavarthi@marvell.com (mailing list archive)
State Superseded, archived
Delegated to: Thomas Monjalon
Headers
Series: Implementation of ML CNXK driver

Checks

Context Check Description
ci/checkpatch success coding style OK

Commit Message

Srikanth Yalavarthi Dec. 8, 2022, 8:17 p.m. UTC
  Implemented driver functions to quantize / dequantize input
and output data. Support is enabled for multiple batches.
Quantization / dequantization use the type conversion functions
defined in ML common code.

Signed-off-by: Srikanth Yalavarthi <syalavarthi@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 150 +++++++++++++++++++++++++++++++++
 1 file changed, 150 insertions(+)
  

Patch

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index c96f17ebd8..9868d2a598 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -5,6 +5,8 @@ 
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include <ml_utils.h>
+
 #include "cn10k_ml_dev.h"
 #include "cn10k_ml_model.h"
 #include "cn10k_ml_ops.h"
@@ -987,6 +989,152 @@  cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, int16_t model_id, uint32_t n
 	return 0;
 }
 
+/*
+ * Quantize application-format (dequantized) input data into the model's
+ * native input format, for nb_batches batches. For each model input, data is
+ * copied verbatim when the app-side input_type already equals the
+ * model_input_type; otherwise it is converted from float32 using the ML
+ * common type-conversion helpers, applying the per-input quantization scale
+ * (qscale) where the target is a fixed-point type.
+ *
+ * NOTE(review): model_id is not range-checked before indexing
+ * dev->data->models[] — presumably validated by the rte_mldev layer; confirm.
+ * NOTE(review): when conversion is needed, dbuffer is assumed to hold
+ * float32 elements (implied by the ml_float32_to_* helpers) — confirm against
+ * the io_input_size_get contract.
+ *
+ * Returns 0 on success, negative errno on failure.
+ */
+static int
+cn10k_ml_io_quantize(struct rte_ml_dev *dev, int16_t model_id, uint16_t nb_batches, void *dbuffer,
+		     void *qbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_dbuffer;
+	uint8_t *lcl_qbuffer;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %d", model_id);
+		return -EINVAL;
+	}
+
+	lcl_dbuffer = dbuffer;
+	lcl_qbuffer = qbuffer;
+	batch_id = 0;
+
+/* Process one model batch per pass; cursors advance by the per-input
+ * dequantized (sz_d) and quantized (sz_q) sizes after each input.
+ */
+next_batch:
+	for (i = 0; i < model->metadata.model.num_input; i++) {
+		if (model->metadata.input[i].input_type ==
+		    model->metadata.input[i].model_input_type) {
+			/* Types match: no conversion, straight copy. */
+			memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d);
+		} else {
+			switch (model->metadata.input[i].model_input_type) {
+			case RTE_ML_IO_TYPE_INT8:
+				ret = ml_float32_to_int8(model->metadata.input[i].qscale,
+							 model->addr.input[i].nb_elements,
+							 lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT8:
+				ret = ml_float32_to_uint8(model->metadata.input[i].qscale,
+							  model->addr.input[i].nb_elements,
+							  lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_INT16:
+				ret = ml_float32_to_int16(model->metadata.input[i].qscale,
+							  model->addr.input[i].nb_elements,
+							  lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT16:
+				ret = ml_float32_to_uint16(model->metadata.input[i].qscale,
+							   model->addr.input[i].nb_elements,
+							   lcl_dbuffer, lcl_qbuffer);
+				break;
+			case RTE_ML_IO_TYPE_FP16:
+				/* Pure float32 -> float16 narrowing; no scale involved. */
+				ret = ml_float32_to_float16(model->addr.input[i].nb_elements,
+							    lcl_dbuffer, lcl_qbuffer);
+				break;
+			default:
+				plt_err("Unsupported model_input_type[%u] : %u", i,
+					model->metadata.input[i].model_input_type);
+				ret = -ENOTSUP;
+			}
+			if (ret < 0)
+				return ret;
+		}
+
+		lcl_dbuffer += model->addr.input[i].sz_d;
+		lcl_qbuffer += model->addr.input[i].sz_q;
+	}
+
+	/* nb_batches is an app-level count; the device consumes batch_size
+	 * batches per pass, so round up to cover a trailing partial chunk.
+	 */
+	batch_id++;
+	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
+		goto next_batch;
+
+	return 0;
+}
+
+/*
+ * Dequantize model-native output data into application format (float32) for
+ * nb_batches batches. Mirror of cn10k_ml_io_quantize: each output is copied
+ * verbatim when output_type equals model_output_type, otherwise converted to
+ * float32 via the ML common helpers using the per-output dequantization
+ * scale (dscale) for fixed-point source types.
+ *
+ * NOTE(review): model_id is not range-checked before indexing
+ * dev->data->models[] — presumably validated by the rte_mldev layer; confirm.
+ *
+ * Returns 0 on success, negative errno on failure.
+ */
+static int
+cn10k_ml_io_dequantize(struct rte_ml_dev *dev, int16_t model_id, uint16_t nb_batches, void *qbuffer,
+		       void *dbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_qbuffer;
+	uint8_t *lcl_dbuffer;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	model = dev->data->models[model_id];
+
+	if (model == NULL) {
+		plt_err("Invalid model_id = %d", model_id);
+		return -EINVAL;
+	}
+
+	lcl_dbuffer = dbuffer;
+	lcl_qbuffer = qbuffer;
+	batch_id = 0;
+
+/* Process one model batch per pass; cursors advance by the per-output
+ * quantized (sz_q) and dequantized (sz_d) sizes after each output.
+ */
+next_batch:
+	for (i = 0; i < model->metadata.model.num_output; i++) {
+		if (model->metadata.output[i].output_type ==
+		    model->metadata.output[i].model_output_type) {
+			/* Types match: no conversion, straight copy. */
+			memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q);
+		} else {
+			switch (model->metadata.output[i].model_output_type) {
+			case RTE_ML_IO_TYPE_INT8:
+				ret = ml_int8_to_float32(model->metadata.output[i].dscale,
+							 model->addr.output[i].nb_elements,
+							 lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT8:
+				ret = ml_uint8_to_float32(model->metadata.output[i].dscale,
+							  model->addr.output[i].nb_elements,
+							  lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_INT16:
+				ret = ml_int16_to_float32(model->metadata.output[i].dscale,
+							  model->addr.output[i].nb_elements,
+							  lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_UINT16:
+				ret = ml_uint16_to_float32(model->metadata.output[i].dscale,
+							   model->addr.output[i].nb_elements,
+							   lcl_qbuffer, lcl_dbuffer);
+				break;
+			case RTE_ML_IO_TYPE_FP16:
+				/* Pure float16 -> float32 widening; no scale involved. */
+				ret = ml_float16_to_float32(model->addr.output[i].nb_elements,
+							    lcl_qbuffer, lcl_dbuffer);
+				break;
+			default:
+				plt_err("Unsupported model_output_type[%u] : %u", i,
+					model->metadata.output[i].model_output_type);
+				ret = -ENOTSUP;
+			}
+			if (ret < 0)
+				return ret;
+		}
+
+		lcl_qbuffer += model->addr.output[i].sz_q;
+		lcl_dbuffer += model->addr.output[i].sz_d;
+	}
+
+	/* nb_batches is an app-level count; the device produces batch_size
+	 * batches per pass, so round up to cover a trailing partial chunk.
+	 */
+	batch_id++;
+	if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size))
+		goto next_batch;
+
+	return 0;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* Device control ops */
 	.dev_info_get = cn10k_ml_dev_info_get,
@@ -1010,4 +1158,6 @@  struct rte_ml_dev_ops cn10k_ml_ops = {
 	/* I/O ops */
 	.io_input_size_get = cn10k_ml_io_input_size_get,
 	.io_output_size_get = cn10k_ml_io_output_size_get,
+	.io_quantize = cn10k_ml_io_quantize,
+	.io_dequantize = cn10k_ml_io_dequantize,
 };