@@ -5,6 +5,8 @@
#include <rte_mldev.h>
#include <rte_mldev_pmd.h>
+#include <mldev_utils.h>
+
#include "cn10k_ml_dev.h"
#include "cn10k_ml_model.h"
#include "cn10k_ml_ops.h"
@@ -983,6 +985,153 @@ cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t
return 0;
}
+static int /* Quantize: convert app-supplied dequantized (float32) input into the model's native input types. */
+cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
+ void *qbuffer) /* dbuffer = float32 source, qbuffer = quantized destination; 0 on success, -errno on failure */
+{
+ struct cn10k_ml_model *model;
+ uint8_t *lcl_dbuffer; /* walking cursor into dbuffer */
+ uint8_t *lcl_qbuffer; /* walking cursor into qbuffer */
+ uint32_t batch_id;
+ uint32_t i;
+ int ret;
+
+ model = dev->data->models[model_id];
+
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ lcl_dbuffer = dbuffer;
+ lcl_qbuffer = qbuffer;
+ batch_id = 0;
+
+next_batch: /* do-while shape: at least one pass runs even when nb_batches == 0 */
+ for (i = 0; i < model->metadata.model.num_input; i++) {
+ if (model->metadata.input[i].input_type ==
+ model->metadata.input[i].model_input_type) {
+ rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d); /* types match: straight copy */
+ } else {
+ switch (model->metadata.input[i].model_input_type) { /* convert via mldev_utils helpers */
+ case RTE_ML_IO_TYPE_INT8:
+ ret = rte_ml_io_float32_to_int8(model->metadata.input[i].qscale,
+ model->addr.input[i].nb_elements,
+ lcl_dbuffer, lcl_qbuffer);
+ break;
+ case RTE_ML_IO_TYPE_UINT8:
+ ret = rte_ml_io_float32_to_uint8(model->metadata.input[i].qscale,
+ model->addr.input[i].nb_elements,
+ lcl_dbuffer, lcl_qbuffer);
+ break;
+ case RTE_ML_IO_TYPE_INT16:
+ ret = rte_ml_io_float32_to_int16(model->metadata.input[i].qscale,
+ model->addr.input[i].nb_elements,
+ lcl_dbuffer, lcl_qbuffer);
+ break;
+ case RTE_ML_IO_TYPE_UINT16:
+ ret = rte_ml_io_float32_to_uint16(model->metadata.input[i].qscale,
+ model->addr.input[i].nb_elements,
+ lcl_dbuffer, lcl_qbuffer);
+ break;
+ case RTE_ML_IO_TYPE_FP16: /* fp16 conversion takes no qscale factor */
+ ret = rte_ml_io_float32_to_float16(model->addr.input[i].nb_elements,
+ lcl_dbuffer, lcl_qbuffer);
+ break;
+ default:
+ plt_err("Unsupported model_input_type[%u] : %u", i,
+ model->metadata.input[i].model_input_type);
+ ret = -ENOTSUP;
+ }
+ if (ret < 0)
+ return ret; /* propagate conversion failure */
+ }
+
+ lcl_dbuffer += model->addr.input[i].sz_d; /* sz_d/sz_q: per-pass dequantized/quantized sizes — TODO confirm */
+ lcl_qbuffer += model->addr.input[i].sz_q;
+ }
+
+ batch_id++;
+ if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size)) /* ceil(nb_batches / batch_size) passes total */
+ goto next_batch;
+
+ return 0;
+}
+
+static int /* Dequantize: convert the model's native output types back to float32 for the application. */
+cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
+ void *qbuffer, void *dbuffer) /* qbuffer = quantized source, dbuffer = float32 destination; 0 or -errno */
+{
+ struct cn10k_ml_model *model;
+ uint8_t *lcl_qbuffer; /* walking cursor into qbuffer */
+ uint8_t *lcl_dbuffer; /* walking cursor into dbuffer */
+ uint32_t batch_id;
+ uint32_t i;
+ int ret;
+
+ model = dev->data->models[model_id];
+
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ lcl_dbuffer = dbuffer;
+ lcl_qbuffer = qbuffer;
+ batch_id = 0;
+
+next_batch: /* do-while shape: at least one pass runs even when nb_batches == 0 */
+ for (i = 0; i < model->metadata.model.num_output; i++) {
+ if (model->metadata.output[i].output_type ==
+ model->metadata.output[i].model_output_type) {
+ rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q); /* types match: straight copy */
+ } else {
+ switch (model->metadata.output[i].model_output_type) { /* convert via mldev_utils helpers */
+ case RTE_ML_IO_TYPE_INT8:
+ ret = rte_ml_io_int8_to_float32(model->metadata.output[i].dscale,
+ model->addr.output[i].nb_elements,
+ lcl_qbuffer, lcl_dbuffer);
+ break;
+ case RTE_ML_IO_TYPE_UINT8:
+ ret = rte_ml_io_uint8_to_float32(model->metadata.output[i].dscale,
+ model->addr.output[i].nb_elements,
+ lcl_qbuffer, lcl_dbuffer);
+ break;
+ case RTE_ML_IO_TYPE_INT16:
+ ret = rte_ml_io_int16_to_float32(model->metadata.output[i].dscale,
+ model->addr.output[i].nb_elements,
+ lcl_qbuffer, lcl_dbuffer);
+ break;
+ case RTE_ML_IO_TYPE_UINT16:
+ ret = rte_ml_io_uint16_to_float32(model->metadata.output[i].dscale,
+ model->addr.output[i].nb_elements,
+ lcl_qbuffer, lcl_dbuffer);
+ break;
+ case RTE_ML_IO_TYPE_FP16: /* fp16 conversion takes no dscale factor */
+ ret = rte_ml_io_float16_to_float32(
+ model->addr.output[i].nb_elements, lcl_qbuffer,
+ lcl_dbuffer);
+ break;
+ default:
+ plt_err("Unsupported model_output_type[%u] : %u", i,
+ model->metadata.output[i].model_output_type);
+ ret = -ENOTSUP;
+ }
+ if (ret < 0)
+ return ret; /* propagate conversion failure */
+ }
+
+ lcl_qbuffer += model->addr.output[i].sz_q; /* sz_q/sz_d: per-pass quantized/dequantized sizes — TODO confirm */
+ lcl_dbuffer += model->addr.output[i].sz_d;
+ }
+
+ batch_id++;
+ if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size)) /* ceil(nb_batches / batch_size) passes total */
+ goto next_batch;
+
+ return 0;
+}
+
struct rte_ml_dev_ops cn10k_ml_ops = {
/* Device control ops */
.dev_info_get = cn10k_ml_dev_info_get,
@@ -1006,4 +1155,6 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
/* I/O ops */
.io_input_size_get = cn10k_ml_io_input_size_get,
.io_output_size_get = cn10k_ml_io_output_size_get,
+ .io_quantize = cn10k_ml_io_quantize,
+ .io_dequantize = cn10k_ml_io_dequantize,
};