From patchwork Tue Feb 7 16:07:03 2023 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Srikanth Yalavarthi X-Patchwork-Id: 123342 X-Patchwork-Delegate: thomas@monjalon.net Return-Path: X-Original-To: patchwork@inbox.dpdk.org Delivered-To: patchwork@inbox.dpdk.org Received: from mails.dpdk.org (mails.dpdk.org [217.70.189.124]) by inbox.dpdk.org (Postfix) with ESMTP id 2749E41C30; Tue, 7 Feb 2023 17:10:45 +0100 (CET) Received: from mails.dpdk.org (localhost [127.0.0.1]) by mails.dpdk.org (Postfix) with ESMTP id 38B3C42F84; Tue, 7 Feb 2023 17:07:56 +0100 (CET) Received: from mx0b-0016f401.pphosted.com (mx0b-0016f401.pphosted.com [67.231.156.173]) by mails.dpdk.org (Postfix) with ESMTP id C10E742D41 for ; Tue, 7 Feb 2023 17:07:33 +0100 (CET) Received: from pps.filterd (m0045851.ppops.net [127.0.0.1]) by mx0b-0016f401.pphosted.com (8.17.1.19/8.17.1.19) with ESMTP id 317BL2vn005847 for ; Tue, 7 Feb 2023 08:07:33 -0800 DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=marvell.com; h=from : to : cc : subject : date : message-id : in-reply-to : references : mime-version : content-type; s=pfpt0220; bh=C4bPUhYr92S/nF2vE8AGx3IoRio0mSEO7kPfyqW87ew=; b=SdKQ93b9hh4TJeB6wxEYTtRFIgIaw81a8IcGwbWkSSGZuPicPBwty6Rzj0JlhFxspV0Y w0h71EXWcnNptZwoxhOBVKnGR/ClH/c3EDNa5+pWFwSlMbwZWqe8c1stfuKcFJeUn65t 4rknkR5tZdNm0JZBYzTue7HE3KK9HBeLlnbfPe8AxLrEKYJpxYAOdFnXwSYLT58RbpqV 30WhMX2pOTRAWISdLoGU+2TV7+vLoMbrZFA/V1IYG+poacwa+a8CzRpPBEcAUn9kTDC6 4ZNRBHYvemoM+3LAzK26kwpjlvMEKYnczDUOEfyPNX4scjH9FmcrBJuyuIq1Hn/JYgu9 Ww== Received: from dc5-exch01.marvell.com ([199.233.59.181]) by mx0b-0016f401.pphosted.com (PPS) with ESMTPS id 3nhqrtmsnd-11 (version=TLSv1.2 cipher=ECDHE-RSA-AES256-SHA384 bits=256 verify=NOT) for ; Tue, 07 Feb 2023 08:07:32 -0800 Received: from DC5-EXCH01.marvell.com (10.69.176.38) by DC5-EXCH01.marvell.com (10.69.176.38) with Microsoft SMTP Server (TLS) id 15.0.1497.42; Tue, 7 Feb 2023 08:07:28 -0800 Received: from maili.marvell.com (10.69.176.80) by DC5-EXCH01.marvell.com (10.69.176.38) with Microsoft SMTP Server id 15.0.1497.42 via Frontend Transport; Tue, 7 Feb 2023 08:07:28 -0800 Received: from ml-host-33.caveonetworks.com (unknown [10.110.143.233]) by maili.marvell.com (Postfix) with ESMTP id 5DBD63F7088; Tue, 7 Feb 2023 08:07:28 -0800 (PST) From: Srikanth Yalavarthi To: Srikanth Yalavarthi CC: , , , , , Subject: [PATCH v5 23/39] ml/cnxk: enable quantization and dequantization Date: Tue, 7 Feb 2023 08:07:03 -0800 Message-ID: <20230207160719.1307-24-syalavarthi@marvell.com> X-Mailer: git-send-email 2.17.1 In-Reply-To: <20230207160719.1307-1-syalavarthi@marvell.com> References: <20221208200220.20267-1-syalavarthi@marvell.com> <20230207160719.1307-1-syalavarthi@marvell.com> MIME-Version: 1.0 X-Proofpoint-GUID: eAak-1V55Rjns1x7k7cnyouUsePujCBV X-Proofpoint-ORIG-GUID: eAak-1V55Rjns1x7k7cnyouUsePujCBV X-Proofpoint-Virus-Version: vendor=baseguard engine=ICAP:2.0.219,Aquarius:18.0.930,Hydra:6.0.562,FMLib:17.11.122.1 definitions=2023-02-07_07,2023-02-06_03,2022-06-22_01 X-BeenThere: dev@dpdk.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: DPDK patches and discussions List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: dev-bounces@dpdk.org Implemented driver functions to quantize / dequantize input and output data. Support is enabled for multiple batches. Quantization / dequantization use the type conversion functions defined in ML common code. Signed-off-by: Srikanth Yalavarthi --- drivers/ml/cnxk/cn10k_ml_ops.c | 151 +++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c index b5c89bee40..231c9b340b 100644 --- a/drivers/ml/cnxk/cn10k_ml_ops.c +++ b/drivers/ml/cnxk/cn10k_ml_ops.c @@ -5,6 +5,8 @@ #include #include +#include + #include "cn10k_ml_dev.h" #include "cn10k_ml_model.h" #include "cn10k_ml_ops.h" @@ -983,6 +985,153 @@ cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t return 0; } +static int +cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, void *dbuffer, + void *qbuffer) +{ + struct cn10k_ml_model *model; + uint8_t *lcl_dbuffer; + uint8_t *lcl_qbuffer; + uint32_t batch_id; + uint32_t i; + int ret; + + model = dev->data->models[model_id]; + + if (model == NULL) { + plt_err("Invalid model_id = %u", model_id); + return -EINVAL; + } + + lcl_dbuffer = dbuffer; + lcl_qbuffer = qbuffer; + batch_id = 0; + +next_batch: + for (i = 0; i < model->metadata.model.num_input; i++) { + if (model->metadata.input[i].input_type == + model->metadata.input[i].model_input_type) { + rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d); + } else { + switch (model->metadata.input[i].model_input_type) { + case RTE_ML_IO_TYPE_INT8: + ret = rte_ml_io_float32_to_int8(model->metadata.input[i].qscale, + model->addr.input[i].nb_elements, + lcl_dbuffer, lcl_qbuffer); + break; + case RTE_ML_IO_TYPE_UINT8: + ret = rte_ml_io_float32_to_uint8(model->metadata.input[i].qscale, + model->addr.input[i].nb_elements, + lcl_dbuffer, lcl_qbuffer); + break; + case RTE_ML_IO_TYPE_INT16: + ret = rte_ml_io_float32_to_int16(model->metadata.input[i].qscale, + model->addr.input[i].nb_elements, + lcl_dbuffer, lcl_qbuffer); + break; + case RTE_ML_IO_TYPE_UINT16: + ret = rte_ml_io_float32_to_uint16(model->metadata.input[i].qscale, + model->addr.input[i].nb_elements, + lcl_dbuffer, lcl_qbuffer); + break; + case RTE_ML_IO_TYPE_FP16: + ret = rte_ml_io_float32_to_float16(model->addr.input[i].nb_elements, + lcl_dbuffer, lcl_qbuffer); + break; + default: + plt_err("Unsupported model_input_type[%u] : %u", i, + model->metadata.input[i].model_input_type); + ret = -ENOTSUP; + } + if (ret < 0) + return ret; + } + + lcl_dbuffer += model->addr.input[i].sz_d; + lcl_qbuffer += model->addr.input[i].sz_q; + } + + batch_id++; + if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size)) + goto next_batch; + + return 0; +} + +static int +cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, + void *qbuffer, void *dbuffer) +{ + struct cn10k_ml_model *model; + uint8_t *lcl_qbuffer; + uint8_t *lcl_dbuffer; + uint32_t batch_id; + uint32_t i; + int ret; + + model = dev->data->models[model_id]; + + if (model == NULL) { + plt_err("Invalid model_id = %u", model_id); + return -EINVAL; + } + + lcl_dbuffer = dbuffer; + lcl_qbuffer = qbuffer; + batch_id = 0; + +next_batch: + for (i = 0; i < model->metadata.model.num_output; i++) { + if (model->metadata.output[i].output_type == + model->metadata.output[i].model_output_type) { + rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q); + } else { + switch (model->metadata.output[i].model_output_type) { + case RTE_ML_IO_TYPE_INT8: + ret = rte_ml_io_int8_to_float32(model->metadata.output[i].dscale, + model->addr.output[i].nb_elements, + lcl_qbuffer, lcl_dbuffer); + break; + case RTE_ML_IO_TYPE_UINT8: + ret = rte_ml_io_uint8_to_float32(model->metadata.output[i].dscale, + model->addr.output[i].nb_elements, + lcl_qbuffer, lcl_dbuffer); + break; + case RTE_ML_IO_TYPE_INT16: + ret = rte_ml_io_int16_to_float32(model->metadata.output[i].dscale, + model->addr.output[i].nb_elements, + lcl_qbuffer, lcl_dbuffer); + break; + case RTE_ML_IO_TYPE_UINT16: + ret = rte_ml_io_uint16_to_float32(model->metadata.output[i].dscale, + model->addr.output[i].nb_elements, + lcl_qbuffer, lcl_dbuffer); + break; + case RTE_ML_IO_TYPE_FP16: + ret = rte_ml_io_float16_to_float32( + model->addr.output[i].nb_elements, lcl_qbuffer, + lcl_dbuffer); + break; + default: + plt_err("Unsupported model_output_type[%u] : %u", i, + model->metadata.output[i].model_output_type); + ret = -ENOTSUP; + } + if (ret < 0) + return ret; + } + + lcl_qbuffer += model->addr.output[i].sz_q; + lcl_dbuffer += model->addr.output[i].sz_d; + } + + batch_id++; + if (batch_id < PLT_DIV_CEIL(nb_batches, model->batch_size)) + goto next_batch; + + return 0; +} + struct rte_ml_dev_ops cn10k_ml_ops = { /* Device control ops */ .dev_info_get = cn10k_ml_dev_info_get, @@ -1006,4 +1155,6 @@ struct rte_ml_dev_ops cn10k_ml_ops = { /* I/O ops */ .io_input_size_get = cn10k_ml_io_input_size_get, .io_output_size_get = cn10k_ml_io_output_size_get, + .io_quantize = cn10k_ml_io_quantize, + .io_dequantize = cn10k_ml_io_dequantize, };