From patchwork Thu Dec 8 19:35:29 2022
X-Patchwork-Submitter: Srikanth Yalavarthi
X-Patchwork-Id: 120596
X-Patchwork-Delegate: thomas@monjalon.net
From: Srikanth Yalavarthi
To: Thomas Monjalon, Srikanth Yalavarthi
Subject: [PATCH v1 1/4] common/ml: add initial files for ML common code
Date: Thu, 8 Dec 2022 11:35:29 -0800
Message-ID: <20221208193532.16718-2-syalavarthi@marvell.com>
In-Reply-To: <20221208193532.16718-1-syalavarthi@marvell.com>
References: <20221208193532.16718-1-syalavarthi@marvell.com>

Added initial files for common ML driver code. Implemented ML type to size conversion, type to string and format to string conversion utility functions.
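A minimal usage sketch, illustrative only and not part of the patch: it shows how a caller could combine these helpers to size and label an I/O buffer. enum rte_ml_io_type is assumed to come from the mldev series this patch depends on, and example_io_buffer_size() is a hypothetical wrapper.

    #include <stdint.h>

    #include <rte_mldev.h>

    #include "ml_utils.h"

    /* Return the byte size of a buffer holding nb_elements values of the given
     * type, or a negative errno if the type is unknown; also fill in its name.
     */
    static int64_t
    example_io_buffer_size(enum rte_ml_io_type type, uint64_t nb_elements,
                           char *name, int name_len)
    {
            int type_size = ml_io_type_size_get(type);

            if (type_size < 0)
                    return type_size;

            ml_io_type_to_str(type, name, name_len);

            return (int64_t)nb_elements * type_size;
    }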
Signed-off-by: Srikanth Yalavarthi --- Depends-on: series-26046 ("app/mldev: implement test framework for mldev") MAINTAINERS | 8 +++ drivers/common/meson.build | 1 + drivers/common/ml/meson.build | 20 +++++++ drivers/common/ml/ml_utils.c | 110 ++++++++++++++++++++++++++++++++++ drivers/common/ml/ml_utils.h | 50 ++++++++++++++++ drivers/common/ml/version.map | 9 +++ 6 files changed, 198 insertions(+) create mode 100644 drivers/common/ml/meson.build create mode 100644 drivers/common/ml/ml_utils.c create mode 100644 drivers/common/ml/ml_utils.h create mode 100644 drivers/common/ml/version.map -- 2.17.1 diff --git a/MAINTAINERS b/MAINTAINERS index 5fa276fafa..6412209bff 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -1431,6 +1431,14 @@ F: drivers/raw/dpaa2_cmdif/ F: doc/guides/rawdevs/dpaa2_cmdif.rst +ML Device Drivers +------------------------ + +ML common code +M: Srikanth Yalavarthi +F: drivers/common/ml/ + + Packet processing ----------------- diff --git a/drivers/common/meson.build b/drivers/common/meson.build index b63d899d50..0878dde0a0 100644 --- a/drivers/common/meson.build +++ b/drivers/common/meson.build @@ -9,4 +9,5 @@ drivers = [ 'idpf', 'mvep', 'octeontx', + 'ml', ] diff --git a/drivers/common/ml/meson.build b/drivers/common/ml/meson.build new file mode 100644 index 0000000000..2749ab6c2e --- /dev/null +++ b/drivers/common/ml/meson.build @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2022 Marvell. + +if not is_linux or not dpdk_conf.get('RTE_ARCH_64') + build = false + reason = 'only supported on 64-bit Linux' + subdir_done() +endif + +headers = files( + 'ml_utils.h', +) + +sources = files( + 'ml_utils.c', +) + +deps += ['mldev'] + +pmd_supports_disable_iova_as_pa = true diff --git a/drivers/common/ml/ml_utils.c b/drivers/common/ml/ml_utils.c new file mode 100644 index 0000000000..45c1f76a54 --- /dev/null +++ b/drivers/common/ml/ml_utils.c @@ -0,0 +1,110 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. 
+ */ + +#include + +#include "ml_utils.h" + +int +ml_io_type_size_get(enum rte_ml_io_type type) +{ + switch (type) { + case RTE_ML_IO_TYPE_UNKNOWN: + return -EINVAL; + case RTE_ML_IO_TYPE_INT8: + return sizeof(int8_t); + case RTE_ML_IO_TYPE_UINT8: + return sizeof(uint8_t); + case RTE_ML_IO_TYPE_INT16: + return sizeof(int16_t); + case RTE_ML_IO_TYPE_UINT16: + return sizeof(uint16_t); + case RTE_ML_IO_TYPE_INT32: + return sizeof(int32_t); + case RTE_ML_IO_TYPE_UINT32: + return sizeof(uint32_t); + case RTE_ML_IO_TYPE_FP8: + return sizeof(uint8_t); + case RTE_ML_IO_TYPE_FP16: + return sizeof(uint8_t) * 2; + case RTE_ML_IO_TYPE_FP32: + return sizeof(uint8_t) * 4; + case RTE_ML_IO_TYPE_BFLOAT16: + return sizeof(uint8_t) * 2; + default: + return -EINVAL; + } +} + +void +ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len) +{ + switch (type) { + case RTE_ML_IO_TYPE_UNKNOWN: + rte_strlcpy(str, "unknown", len); + break; + case RTE_ML_IO_TYPE_INT8: + rte_strlcpy(str, "int8", len); + break; + case RTE_ML_IO_TYPE_UINT8: + rte_strlcpy(str, "uint8", len); + break; + case RTE_ML_IO_TYPE_INT16: + rte_strlcpy(str, "int16", len); + break; + case RTE_ML_IO_TYPE_UINT16: + rte_strlcpy(str, "uint16", len); + break; + case RTE_ML_IO_TYPE_INT32: + rte_strlcpy(str, "int32", len); + break; + case RTE_ML_IO_TYPE_UINT32: + rte_strlcpy(str, "uint32", len); + break; + case RTE_ML_IO_TYPE_FP8: + rte_strlcpy(str, "float8", len); + break; + case RTE_ML_IO_TYPE_FP16: + rte_strlcpy(str, "float16", len); + break; + case RTE_ML_IO_TYPE_FP32: + rte_strlcpy(str, "float32", len); + break; + case RTE_ML_IO_TYPE_BFLOAT16: + rte_strlcpy(str, "bfloat16", len); + break; + default: + rte_strlcpy(str, "invalid", len); + } +} + +void +ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len) +{ + switch (format) { + case RTE_ML_IO_FORMAT_NCHW: + rte_strlcpy(str, "NCHW", len); + break; + case RTE_ML_IO_FORMAT_NHWC: + rte_strlcpy(str, "NHWC", len); + break; + case RTE_ML_IO_FORMAT_CHWN: + rte_strlcpy(str, "CHWN", len); + break; + case RTE_ML_IO_FORMAT_3D: + rte_strlcpy(str, "3D", len); + break; + case RTE_ML_IO_FORMAT_2D: + rte_strlcpy(str, "Matrix", len); + break; + case RTE_ML_IO_FORMAT_1D: + rte_strlcpy(str, "Vector", len); + break; + case RTE_ML_IO_FORMAT_SCALAR: + rte_strlcpy(str, "Scalar", len); + break; + default: + rte_strlcpy(str, "invalid", len); + } +} diff --git a/drivers/common/ml/ml_utils.h b/drivers/common/ml/ml_utils.h new file mode 100644 index 0000000000..b6adb98e04 --- /dev/null +++ b/drivers/common/ml/ml_utils.h @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. + */ + +#ifndef _ML_UTILS_H_ +#define _ML_UTILS_H_ + +#include +#include + +/** + * Get the size an ML IO type in bytes. + * + * @param[in] type + * Enumeration of ML IO data type. + * + * @return + * - > 0, Size of the data type in bytes. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_io_type_size_get(enum rte_ml_io_type type); + +/** + * Get the name of an ML IO type. + * + * @param[in] type + * Enumeration of ML IO data type. + * @param[in] str + * Address of character array. + * @param[in] len + * Length of character array. + */ +__rte_internal +void ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len); + +/** + * Get the name of an ML IO format. + * + * @param[in] type + * Enumeration of ML IO format. + * @param[in] str + * Address of character array. + * @param[in] len + * Length of character array. 
+ */ +__rte_internal +void ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len); + +#endif /*_ML_UTILS_H_ */ diff --git a/drivers/common/ml/version.map b/drivers/common/ml/version.map new file mode 100644 index 0000000000..7e33755f2f --- /dev/null +++ b/drivers/common/ml/version.map @@ -0,0 +1,9 @@ +INTERNAL { + global: + + ml_io_type_size_get; + ml_io_type_to_str; + ml_io_format_to_str; + + local: *; +};

From patchwork Thu Dec 8 19:35:30 2022
X-Patchwork-Submitter: Srikanth Yalavarthi
X-Patchwork-Id: 120597
X-Patchwork-Delegate: thomas@monjalon.net
From: Srikanth Yalavarthi
To: Srikanth Yalavarthi
Subject: [PATCH v1 2/4] common/ml: add data type conversion routines
Date: Thu, 8 Dec 2022 11:35:30 -0800
Message-ID: <20221208193532.16718-3-syalavarthi@marvell.com>
In-Reply-To: <20221208193532.16718-1-syalavarthi@marvell.com>
References: <20221208193532.16718-1-syalavarthi@marvell.com>
List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: dev-bounces@dpdk.org Type conversion routines transform data from higher to lower precision data types or vice-versa. These conversion functions can be used by the ML driver implementations for quantization and de-quantization. Added driver routines for type conversion. These driver routines invoke the architecture specific functions. Signed-off-by: Srikanth Yalavarthi --- drivers/common/ml/ml_utils.c | 132 +++++++++++++++++++ drivers/common/ml/ml_utils.h | 233 ++++++++++++++++++++++++++++++++++ drivers/common/ml/version.map | 16 +++ 3 files changed, 381 insertions(+) diff --git a/drivers/common/ml/ml_utils.c b/drivers/common/ml/ml_utils.c index 45c1f76a54..553e906172 100644 --- a/drivers/common/ml/ml_utils.c +++ b/drivers/common/ml/ml_utils.c @@ -2,6 +2,10 @@ * Copyright (c) 2022 Marvell. */ +#include +#include + +#include #include #include "ml_utils.h" @@ -108,3 +112,131 @@ ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len) rte_strlcpy(str, "invalid", len); } } + +int +ml_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(scale); + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float32_to_float16(uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float16_to_float32(uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} + +int +ml_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output) +{ + RTE_SET_USED(nb_elements); + RTE_SET_USED(input); + RTE_SET_USED(output); + + return -ENOTSUP; +} diff --git 
a/drivers/common/ml/ml_utils.h b/drivers/common/ml/ml_utils.h index b6adb98e04..9726c6e3b5 100644 --- a/drivers/common/ml/ml_utils.h +++ b/drivers/common/ml/ml_utils.h @@ -47,4 +47,237 @@ void ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len); __rte_internal void ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len); +/** + * Convert a buffer containing numbers in single precision floating format (float32) to signed 8-bit + * integer format (INT8). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * @param[out] output + * Output buffer to store INT8 numbers. Size of buffer is equal to (nb_elements * 1) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in signed 8-bit integer format (INT8) to single precision + * floating format (float32). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing INT8 numbers. Size of buffer is equal to (nb_elements * 1) bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in single precision floating format (float32) to unsigned + * 8-bit integer format (UINT8). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * @param[out] output + * Output buffer to store UINT8 numbers. Size of buffer is equal to (nb_elements * 1) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in unsigned 8-bit integer format (UINT8) to single precision + * floating format (float32). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing UINT8 numbers. Size of buffer is equal to (nb_elements * 1) bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in single precision floating format (float32) to signed + * 16-bit integer format (INT16). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * @param[out] output + * Output buffer to store INT16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. 
+ * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in signed 16-bit integer format (INT16) to single precision + * floating format (float32). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing INT16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in single precision floating format (float32) to unsigned + * 16-bit integer format (UINT16). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * @param[out] output + * Output buffer to store UINT16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in unsigned 16-bit integer format (UINT16) to single + * precision floating format (float32). + * + * @param[in] scale + * Scale factor for conversion. + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing UINT16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in single precision floating format (float32) to half + * precision floating point format (FP16). + * + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements *4) bytes. + * @param[out] output + * Output buffer to store float16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_float16(uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in half precision floating format (FP16) to single precision + * floating point format (float32). + * + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. 
+ */ +__rte_internal +int ml_float16_to_float32(uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in single precision floating format (float32) to brain + * floating point format (bfloat16). + * + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing float32 numbers. Size of buffer is equal to (nb_elements *4) bytes. + * @param[out] output + * Output buffer to store bfloat16 numbers. Size of buffer is equal to (nb_elements * 2) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output); + +/** + * Convert a buffer containing numbers in brain floating point format (bfloat16) to single precision + * floating point format (float32). + * + * @param[in] nb_elements + * Number of elements in the buffer. + * @param[in] input + * Input buffer containing bfloat16 numbers. Size of buffer is equal to (nb_elements * 2) + * bytes. + * @param[out] output + * Output buffer to store float32 numbers. Size of buffer is equal to (nb_elements * 4) bytes. + * + * @return + * - 0, Success. + * - < 0, Error code on failure. + */ +__rte_internal +int ml_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output); + #endif /*_ML_UTILS_H_ */ diff --git a/drivers/common/ml/version.map b/drivers/common/ml/version.map index 7e33755f2f..35f270f637 100644 --- a/drivers/common/ml/version.map +++ b/drivers/common/ml/version.map @@ -5,5 +5,21 @@ INTERNAL { ml_io_type_to_str; ml_io_format_to_str; + ml_float32_to_int8; + ml_int8_to_float32; + ml_float32_to_uint8; + ml_uint8_to_float32; + + ml_float32_to_int16; + ml_int16_to_float32; + ml_float32_to_uint16; + ml_uint16_to_float32; + + ml_float32_to_float16; + ml_float16_to_float32; + + ml_float32_to_bfloat16; + ml_bfloat16_to_float32; + local: *; }; From patchwork Thu Dec 8 19:35:31 2022 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Srikanth Yalavarthi X-Patchwork-Id: 120598 X-Patchwork-Delegate: thomas@monjalon.net Return-Path: X-Original-To: patchwork@inbox.dpdk.org Delivered-To: patchwork@inbox.dpdk.org Received: from mails.dpdk.org (mails.dpdk.org [217.70.189.124]) by inbox.dpdk.org (Postfix) with ESMTP id 5BF36A0032; Thu, 8 Dec 2022 20:35:54 +0100 (CET) Received: from mails.dpdk.org (localhost [127.0.0.1]) by mails.dpdk.org (Postfix) with ESMTP id BED6D410D2; Thu, 8 Dec 2022 20:35:41 +0100 (CET) Received: from mx0b-0016f401.pphosted.com (mx0a-0016f401.pphosted.com [67.231.148.174]) by mails.dpdk.org (Postfix) with ESMTP id 1C8824003F for ; Thu, 8 Dec 2022 20:35:38 +0100 (CET) Received: from pps.filterd (m0045849.ppops.net [127.0.0.1]) by mx0a-0016f401.pphosted.com (8.17.1.19/8.17.1.19) with ESMTP id 2B8J8KK9001363 for ; Thu, 8 Dec 2022 11:35:38 -0800 DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=marvell.com; h=from : to : cc : subject : date : message-id : in-reply-to : references : mime-version : content-type; s=pfpt0220; bh=60xifUUqMC6VVtpQj+UHSzLU0skaXKs8sOwtPquPFd0=; b=F1QyXI0WOdErazUxVLm3yNlr5QST4LZK/y2EagggTkyo/Am2u2dg/nXc12Q4eqSMpaLf Cr2xjWpoTc5uZUQQO43vBaewBYwqXZoyocs5lhtoH4jHN6i9GuNsm5NxtLTgHMkS4HLB fhWhjWjjdGwwQFJL+nYbi602YNVVpUXiD6aohxF5zRrACWrDCEh9ib07vubQHembrLBm ZnmSGzDdNEgYOweGW6QeSXvFJuvMjRfhfNtL0BQgp/hQxi0YVpQl/38zZJKNqCmcidL0 kel78sIn+zql04DHrPaabthZUT6zD3k70gSZIIdHZB4Qnq+HjwZpMAW/oi11Tj8w1zPg rw== Received: from dc5-exch01.marvell.com 
([199.233.59.181]) by mx0a-0016f401.pphosted.com (PPS) with ESMTPS id 3mb22svkjj-4 (version=TLSv1.2 cipher=ECDHE-RSA-AES256-SHA384 bits=256 verify=NOT) for ; Thu, 08 Dec 2022 11:35:38 -0800 Received: from DC5-EXCH02.marvell.com (10.69.176.39) by DC5-EXCH01.marvell.com (10.69.176.38) with Microsoft SMTP Server (TLS) id 15.0.1497.2; Thu, 8 Dec 2022 11:35:36 -0800 Received: from maili.marvell.com (10.69.176.80) by DC5-EXCH02.marvell.com (10.69.176.39) with Microsoft SMTP Server id 15.0.1497.18 via Frontend Transport; Thu, 8 Dec 2022 11:35:36 -0800 Received: from ml-host-33.caveonetworks.com (unknown [10.110.143.233]) by maili.marvell.com (Postfix) with ESMTP id D03253F7091; Thu, 8 Dec 2022 11:35:35 -0800 (PST) From: Srikanth Yalavarthi To: Srikanth Yalavarthi CC: , , , Subject: [PATCH v1 3/4] common/ml: add generic type conversion functions Date: Thu, 8 Dec 2022 11:35:31 -0800 Message-ID: <20221208193532.16718-4-syalavarthi@marvell.com> X-Mailer: git-send-email 2.17.1 In-Reply-To: <20221208193532.16718-1-syalavarthi@marvell.com> References: <20221208193532.16718-1-syalavarthi@marvell.com> MIME-Version: 1.0 X-Proofpoint-ORIG-GUID: x3yzA5EhmAyE3wFWliL2GH8L2cuU6VnL X-Proofpoint-GUID: x3yzA5EhmAyE3wFWliL2GH8L2cuU6VnL X-Proofpoint-Virus-Version: vendor=baseguard engine=ICAP:2.0.205,Aquarius:18.0.923,Hydra:6.0.545,FMLib:17.11.122.1 definitions=2022-12-08_11,2022-12-08_01,2022-06-22_01 X-BeenThere: dev@dpdk.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: DPDK patches and discussions List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: dev-bounces@dpdk.org Added generic implementations to support conversion of data types. Support is enabled to handle int8, uint8, int16, uint16, float16, float32 and bfloat16 types. Signed-off-by: Srikanth Yalavarthi --- drivers/common/ml/meson.build | 2 + drivers/common/ml/ml_utils.c | 86 +--- drivers/common/ml/ml_utils_generic.c | 716 +++++++++++++++++++++++++++ drivers/common/ml/ml_utils_generic.h | 23 + 4 files changed, 758 insertions(+), 69 deletions(-) create mode 100644 drivers/common/ml/ml_utils_generic.c create mode 100644 drivers/common/ml/ml_utils_generic.h diff --git a/drivers/common/ml/meson.build b/drivers/common/ml/meson.build index 2749ab6c2e..84ae84ee4e 100644 --- a/drivers/common/ml/meson.build +++ b/drivers/common/ml/meson.build @@ -9,10 +9,12 @@ endif headers = files( 'ml_utils.h', + 'ml_utils_generic.h', ) sources = files( 'ml_utils.c', + 'ml_utils_generic.c', ) deps += ['mldev'] diff --git a/drivers/common/ml/ml_utils.c b/drivers/common/ml/ml_utils.c index 553e906172..e2edef0904 100644 --- a/drivers/common/ml/ml_utils.c +++ b/drivers/common/ml/ml_utils.c @@ -5,10 +5,14 @@ #include #include -#include #include #include "ml_utils.h" +#include "ml_utils_generic.h" + +#if defined(__ARM_NEON__) +#include "ml_utils_neon.h" +#endif int ml_io_type_size_get(enum rte_ml_io_type type) @@ -116,127 +120,71 @@ ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len) int ml_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_int8_generic(scale, nb_elements, input, output); } int ml_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_int8_to_float32_generic(scale, nb_elements, input, 
output); } int ml_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_uint8_generic(scale, nb_elements, input, output); } int ml_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_uint8_to_float32_generic(scale, nb_elements, input, output); } int ml_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_int16_generic(scale, nb_elements, input, output); } int ml_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_int16_to_float32_generic(scale, nb_elements, input, output); } int ml_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_uint16_generic(scale, nb_elements, input, output); } int ml_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(scale); - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_uint16_to_float32_generic(scale, nb_elements, input, output); } int ml_float32_to_float16(uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_float16_generic(nb_elements, input, output); } int ml_float16_to_float32(uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float16_to_float32_generic(nb_elements, input, output); } int ml_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_float32_to_bfloat16_generic(nb_elements, input, output); } int ml_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output) { - RTE_SET_USED(nb_elements); - RTE_SET_USED(input); - RTE_SET_USED(output); - - return -ENOTSUP; + return ml_bfloat16_to_float32_generic(nb_elements, input, output); } diff --git a/drivers/common/ml/ml_utils_generic.c b/drivers/common/ml/ml_utils_generic.c new file mode 100644 index 0000000000..ab67a2ac7f --- /dev/null +++ b/drivers/common/ml/ml_utils_generic.c @@ -0,0 +1,716 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. 
+ */ + +#include +#include +#include + +#include "ml_utils.h" +#include "ml_utils_generic.h" + +#ifndef BIT +#define BIT(nr) (1UL << (nr)) +#endif + +#ifndef BITS_PER_LONG +#define BITS_PER_LONG (__SIZEOF_LONG__ * 8) +#endif + +#ifndef GENMASK_U32 +#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h)))) +#endif + +/* float32: bit index of MSB & LSB of sign, exponent and mantissa */ +#define FP32_LSB_M 0 +#define FP32_MSB_M 22 +#define FP32_LSB_E 23 +#define FP32_MSB_E 30 +#define FP32_LSB_S 31 +#define FP32_MSB_S 31 + +/* float32: bitmask for sign, exponent and mantissa */ +#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S) +#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E) +#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M) + +/* float16: bit index of MSB & LSB of sign, exponent and mantissa */ +#define FP16_LSB_M 0 +#define FP16_MSB_M 9 +#define FP16_LSB_E 10 +#define FP16_MSB_E 14 +#define FP16_LSB_S 15 +#define FP16_MSB_S 15 + +/* float16: bitmask for sign, exponent and mantissa */ +#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S) +#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E) +#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M) + +/* BFLOAT16: bit index of MSB & LSB of sign, exponent and mantissa */ +#define BF16_LSB_M 0 +#define BF16_MSB_M 6 +#define BF16_LSB_E 7 +#define BF16_MSB_E 14 +#define BF16_LSB_S 15 +#define BF16_MSB_S 15 + +/* BFLOAT16: bitmask for sign, exponent and mantissa */ +#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S) +#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E) +#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M) + +/* Exponent bias */ +#define FP32_BIAS_E 127 +#define FP16_BIAS_E 15 +#define BF16_BIAS_E 127 + +#define FP32_PACK(sign, exponent, mantissa) \ + (((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa)) + +#define FP16_PACK(sign, exponent, mantissa) \ + (((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa)) + +#define BF16_PACK(sign, exponent, mantissa) \ + (((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa)) + +/* Represent float32 as float and uint32_t */ +union float32 { + float f; + uint32_t u; +}; + +int +ml_float32_to_int8_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + int8_t *output_buffer; + uint64_t i; + int i32; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (int8_t *)output; + + for (i = 0; i < nb_elements; i++) { + i32 = (int32_t)round((*input_buffer) * scale); + + if (i32 < INT8_MIN) + i32 = INT8_MIN; + + if (i32 > INT8_MAX) + i32 = INT8_MAX; + + *output_buffer = (int8_t)i32; + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_int8_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + int8_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (int8_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = scale * (float)(*input_buffer); + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_float32_to_uint8_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint8_t *output_buffer; + int32_t i32; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == 
NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint8_t *)output; + + for (i = 0; i < nb_elements; i++) { + i32 = (int32_t)round((*input_buffer) * scale); + + if (i32 < 0) + i32 = 0; + + if (i32 > UINT8_MAX) + i32 = UINT8_MAX; + + *output_buffer = (uint8_t)i32; + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_uint8_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + uint8_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint8_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = scale * (float)(*input_buffer); + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_float32_to_int16_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + int16_t *output_buffer; + int32_t i32; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (int16_t *)output; + + for (i = 0; i < nb_elements; i++) { + i32 = (int32_t)round((*input_buffer) * scale); + + if (i32 < INT16_MIN) + i32 = INT16_MIN; + + if (i32 > INT16_MAX) + i32 = INT16_MAX; + + *output_buffer = (int16_t)i32; + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_int16_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + int16_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (int16_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = scale * (float)(*input_buffer); + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_float32_to_uint16_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint16_t *output_buffer; + int32_t i32; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint16_t *)output; + + for (i = 0; i < nb_elements; i++) { + i32 = (int32_t)round((*input_buffer) * scale); + + if (i32 < 0) + i32 = 0; + + if (i32 > UINT16_MAX) + i32 = UINT16_MAX; + + *output_buffer = (uint16_t)i32; + + input_buffer++; + output_buffer++; + } + + return 0; +} + +int +ml_uint16_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output) +{ + uint16_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint16_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = scale * (float)(*input_buffer); + + input_buffer++; + output_buffer++; + } + + return 0; +} + +/* Convert a single precision floating point number (float32) into a half precision + * floating point number (float16) using round to nearest rounding mode. 
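 *
 * Worked example (illustrative, not part of the submitted patch): 1.0f has
 * sign 0, biased exponent 127 and mantissa 0; re-biasing gives
 * be_16 = 127 - 127 + 15 = 15, so the result is FP16_PACK(0, 15, 0) = 0x3C00.
 * A value such as 65536.0f has biased exponent 143, so be_16 = 31 and the
 * conversion saturates to infinity (0x7C00).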
+ */ +static uint16_t +__float32_to_float16_generic_rtn(float x) +{ + union float32 f32; /* float32 input */ + uint32_t f32_s; /* float32 sign */ + uint32_t f32_e; /* float32 exponent */ + uint32_t f32_m; /* float32 mantissa */ + uint16_t f16_s; /* float16 sign */ + uint16_t f16_e; /* float16 exponent */ + uint16_t f16_m; /* float16 mantissa */ + uint32_t tbits; /* number of truncated bits */ + uint32_t tmsb; /* MSB position of truncated bits */ + uint32_t m_32; /* temporary float32 mantissa */ + uint16_t m_16; /* temporary float16 mantissa */ + uint16_t u16; /* float16 output */ + int be_16; /* float16 biased exponent, signed */ + + f32.f = x; + f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; + f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; + f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; + + f16_s = f32_s; + f16_e = 0; + f16_m = 0; + + switch (f32_e) { + case (0): /* float32: zero or subnormal number */ + f16_e = 0; + if (f32_m == 0) /* zero */ + f16_m = 0; + else /* subnormal number, convert to zero */ + f16_m = 0; + break; + case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ + f16_e = FP16_MASK_E >> FP16_LSB_E; + if (f32_m == 0) { /* infinity */ + f16_m = 0; + } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ + f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M); + f16_m |= BIT(FP16_MSB_M); + } + break; + default: /* float32: normal number */ + /* compute biased exponent for float16 */ + be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E; + + /* overflow, be_16 = [31-INF], set to infinity */ + if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) { + f16_e = FP16_MASK_E >> FP16_LSB_E; + f16_m = 0; + } else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) { + /* normal float16, be_16 = [1:30]*/ + f16_e = be_16; + m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E); + tmsb = FP32_MSB_M - FP16_MSB_M - 1; + if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) { + /* round: non-zero truncated bits except MSB */ + m_16++; + + /* overflow into exponent */ + if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) + f16_e++; + } else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) { + /* round: MSB of truncated bits and LSB of m_16 is set */ + if ((m_16 & 0x1) == 0x1) { + m_16++; + + /* overflow into exponent */ + if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) + f16_e++; + } + } + f16_m = m_16 & FP16_MASK_M; + } else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) { + /* underflow: zero / subnormal, be_16 = [-9:0] */ + f16_e = 0; + + /* add implicit leading zero */ + m_32 = f32_m | BIT(FP32_LSB_E); + tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1; + m_16 = m_32 >> tbits; + + /* if non-leading truncated bits are set */ + if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { + m_16++; + + /* overflow into exponent */ + if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) + f16_e++; + } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { + /* if leading truncated bit is set */ + if ((m_16 & 0x1) == 0x1) { + m_16++; + + /* overflow into exponent */ + if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1) + f16_e++; + } + } + f16_m = m_16 & FP16_MASK_M; + } else if (be_16 == -(int)(FP16_MSB_M + 1)) { + /* underflow: zero, be_16 = [-10] */ + f16_e = 0; + if (f32_m != 0) + f16_m = 1; + else + f16_m = 0; + } else { + /* underflow: zero, be_16 = [-INF:-11] */ + f16_e = 0; + f16_m = 0; + } + + break; + } + + u16 = FP16_PACK(f16_s, f16_e, f16_m); + + return u16; +} + +int +ml_float32_to_float16_generic(uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint16_t 
*output_buffer; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint16_t *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = __float32_to_float16_generic_rtn(*input_buffer); + + input_buffer = input_buffer + 1; + output_buffer = output_buffer + 1; + } + + return 0; +} + +/* Convert a half precision floating point number (float16) into a single precision + * floating point number (float32). + */ +static float +__float16_to_float32_generic_rtx(uint16_t f16) +{ + union float32 f32; /* float32 output */ + uint16_t f16_s; /* float16 sign */ + uint16_t f16_e; /* float16 exponent */ + uint16_t f16_m; /* float16 mantissa */ + uint32_t f32_s; /* float32 sign */ + uint32_t f32_e; /* float32 exponent */ + uint32_t f32_m; /* float32 mantissa*/ + uint8_t shift; /* number of bits to be shifted */ + uint32_t clz; /* count of leading zeroes */ + int e_16; /* float16 exponent unbiased */ + + f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S; + f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E; + f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M; + + f32_s = f16_s; + switch (f16_e) { + case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */ + f32_e = FP32_MASK_E >> FP32_LSB_E; + if (f16_m == 0x0) { /* infinity */ + f32_m = f16_m; + } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ + f32_m = f16_m; + shift = FP32_MSB_M - FP16_MSB_M; + f32_m = (f32_m << shift) & FP32_MASK_M; + f32_m |= BIT(FP32_MSB_M); + } + break; + case 0: /* float16: zero or sub-normal */ + f32_m = f16_m; + if (f16_m == 0) { /* zero signed */ + f32_e = 0; + } else { /* subnormal numbers */ + clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E; + e_16 = (int)f16_e - clz; + f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; + + shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1; + f32_m = (f32_m << shift) & FP32_MASK_M; + } + break; + default: /* normal numbers */ + f32_m = f16_m; + e_16 = (int)f16_e; + f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; + + shift = (FP32_MSB_M - FP16_MSB_M); + f32_m = (f32_m << shift) & FP32_MASK_M; + } + + f32.u = FP32_PACK(f32_s, f32_e, f32_m); + + return f32.f; +} + +int +ml_float16_to_float32_generic(uint64_t nb_elements, void *input, void *output) +{ + uint16_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint16_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = __float16_to_float32_generic_rtx(*input_buffer); + + input_buffer = input_buffer + 1; + output_buffer = output_buffer + 1; + } + + return 0; +} + +/* Convert a single precision floating point number (float32) into a + * brain float number (bfloat16) using round to nearest rounding mode. 
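 *
 * Worked example (illustrative): bfloat16 keeps the float32 sign and 8-bit
 * exponent and truncates the 23-bit mantissa to 7 bits with rounding, so
 * 1.0f (0x3F800000) becomes BF16_PACK(0, 127, 0) = 0x3F80; in general the
 * result is the upper 16 bits of the float32 pattern plus rounding on the
 * 16 truncated mantissa bits.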
+ */ +static uint16_t +__float32_to_bfloat16_generic_rtn(float x) +{ + union float32 f32; /* float32 input */ + uint32_t f32_s; /* float32 sign */ + uint32_t f32_e; /* float32 exponent */ + uint32_t f32_m; /* float32 mantissa */ + uint16_t b16_s; /* float16 sign */ + uint16_t b16_e; /* float16 exponent */ + uint16_t b16_m; /* float16 mantissa */ + uint32_t tbits; /* number of truncated bits */ + uint16_t u16; /* float16 output */ + + f32.f = x; + f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; + f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; + f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; + + b16_s = f32_s; + b16_e = 0; + b16_m = 0; + + switch (f32_e) { + case (0): /* float32: zero or subnormal number */ + b16_e = 0; + if (f32_m == 0) /* zero */ + b16_m = 0; + else /* subnormal float32 number, normal bfloat16 */ + goto bf16_normal; + break; + case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ + b16_e = BF16_MASK_E >> BF16_LSB_E; + if (f32_m == 0) { /* infinity */ + b16_m = 0; + } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ + b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M); + b16_m |= BIT(BF16_MSB_M); + } + break; + default: /* float32: normal number, normal bfloat16 */ + goto bf16_normal; + } + + goto bf16_pack; + +bf16_normal: + b16_e = f32_e; + tbits = FP32_MSB_M - BF16_MSB_M; + b16_m = f32_m >> tbits; + + /* if non-leading truncated bits are set */ + if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { + b16_m++; + + /* if overflow into exponent */ + if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) + b16_e++; + } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { + /* if only leading truncated bit is set */ + if ((b16_m & 0x1) == 0x1) { + b16_m++; + + /* if overflow into exponent */ + if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) + b16_e++; + } + } + b16_m = b16_m & BF16_MASK_M; + +bf16_pack: + u16 = BF16_PACK(b16_s, b16_e, b16_m); + + return u16; +} + +int +ml_float32_to_bfloat16_generic(uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint16_t *output_buffer; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint16_t *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = __float32_to_bfloat16_generic_rtn(*input_buffer); + + input_buffer = input_buffer + 1; + output_buffer = output_buffer + 1; + } + + return 0; +} + +/* Convert a brain float number (bfloat16) into a + * single precision floating point number (float32). 
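 *
 * Worked example (illustrative): for a normal bfloat16 input the mantissa is
 * shifted left by 16 bits and the exponent bias is unchanged (127), so 0x3F80
 * expands to FP32_PACK(0, 127, 0) = 0x3F800000 = 1.0f. The conversion is
 * exact, since every bfloat16 value is representable in float32.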
+ */ +static float +__bfloat16_to_float32_generic_rtx(uint16_t f16) +{ + union float32 f32; /* float32 output */ + uint16_t b16_s; /* float16 sign */ + uint16_t b16_e; /* float16 exponent */ + uint16_t b16_m; /* float16 mantissa */ + uint32_t f32_s; /* float32 sign */ + uint32_t f32_e; /* float32 exponent */ + uint32_t f32_m; /* float32 mantissa*/ + uint8_t shift; /* number of bits to be shifted */ + + b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S; + b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E; + b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M; + + f32_s = b16_s; + switch (b16_e) { + case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */ + f32_e = FP32_MASK_E >> FP32_LSB_E; + if (b16_m == 0x0) { /* infinity */ + f32_m = 0; + } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ + f32_m = b16_m; + shift = FP32_MSB_M - BF16_MSB_M; + f32_m = (f32_m << shift) & FP32_MASK_M; + f32_m |= BIT(FP32_MSB_M); + } + break; + case 0: /* bfloat16: zero or subnormal */ + f32_m = b16_m; + if (b16_m == 0) { /* zero signed */ + f32_e = 0; + } else { /* subnormal numbers */ + goto fp32_normal; + } + break; + default: /* bfloat16: normal number */ + goto fp32_normal; + } + + goto fp32_pack; + +fp32_normal: + f32_m = b16_m; + f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E; + + shift = (FP32_MSB_M - BF16_MSB_M); + f32_m = (f32_m << shift) & FP32_MASK_M; + +fp32_pack: + f32.u = FP32_PACK(f32_s, f32_e, f32_m); + + return f32.f; +} + +int +ml_bfloat16_to_float32_generic(uint64_t nb_elements, void *input, void *output) +{ + uint16_t *input_buffer; + float *output_buffer; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint16_t *)input; + output_buffer = (float *)output; + + for (i = 0; i < nb_elements; i++) { + *output_buffer = __bfloat16_to_float32_generic_rtx(*input_buffer); + + input_buffer = input_buffer + 1; + output_buffer = output_buffer + 1; + } + + return 0; +} diff --git a/drivers/common/ml/ml_utils_generic.h b/drivers/common/ml/ml_utils_generic.h new file mode 100644 index 0000000000..9d47d8466e --- /dev/null +++ b/drivers/common/ml/ml_utils_generic.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. 
+ */ + +#ifndef _ML_UTILS_GENERIC_H_ +#define _ML_UTILS_GENERIC_H_ + +#include + +int ml_float32_to_int8_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_int8_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_uint8_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_uint8_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_int16_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_int16_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_uint16_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_uint16_to_float32_generic(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_float16_generic(uint64_t nb_elements, void *input, void *output); +int ml_float16_to_float32_generic(uint64_t nb_elements, void *input, void *output); +int ml_float32_to_bfloat16_generic(uint64_t nb_elements, void *input, void *output); +int ml_bfloat16_to_float32_generic(uint64_t nb_elements, void *input, void *output); + +#endif /*_ML_UTILS_GENERIC_H_ */

From patchwork Thu Dec 8 19:35:32 2022
X-Patchwork-Submitter: Srikanth Yalavarthi
X-Patchwork-Id: 120599
X-Patchwork-Delegate: thomas@monjalon.net
From: Srikanth Yalavarthi To: Srikanth Yalavarthi , Ruifeng Wang CC: , , , Subject: [PATCH v1 4/4] common/ml: add Arm NEON type conversion routines Date: Thu, 8 Dec 2022 11:35:32 -0800 Message-ID: <20221208193532.16718-5-syalavarthi@marvell.com> X-Mailer: git-send-email 2.17.1 In-Reply-To: <20221208193532.16718-1-syalavarthi@marvell.com> References: <20221208193532.16718-1-syalavarthi@marvell.com> MIME-Version: 1.0 X-Proofpoint-ORIG-GUID: 3iE-h1Ss98HQxxLtoQdhZ4kjJawZjPii X-Proofpoint-GUID: 3iE-h1Ss98HQxxLtoQdhZ4kjJawZjPii X-Proofpoint-Virus-Version: vendor=baseguard engine=ICAP:2.0.205,Aquarius:18.0.923,Hydra:6.0.545,FMLib:17.11.122.1 definitions=2022-12-08_11,2022-12-08_01,2022-06-22_01 X-BeenThere: dev@dpdk.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: DPDK patches and discussions List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: dev-bounces@dpdk.org Added ARM NEON intrinsic based implementations to support conversion of data types. Support is enabled to handle int8, uint8, int16, uint16, float16, float32 and bfloat16 types. Signed-off-by: Srikanth Yalavarthi --- drivers/common/ml/meson.build | 5 + drivers/common/ml/ml_utils.c | 48 ++ drivers/common/ml/ml_utils_neon.c | 950 ++++++++++++++++++++++++++++++ drivers/common/ml/ml_utils_neon.h | 23 + 4 files changed, 1026 insertions(+) create mode 100644 drivers/common/ml/ml_utils_neon.c create mode 100644 drivers/common/ml/ml_utils_neon.h diff --git a/drivers/common/ml/meson.build b/drivers/common/ml/meson.build index 84ae84ee4e..f7ce19b4b4 100644 --- a/drivers/common/ml/meson.build +++ b/drivers/common/ml/meson.build @@ -17,6 +17,11 @@ sources = files( 'ml_utils_generic.c', ) +if arch_subdir == 'arm' + headers += files('ml_utils_neon.h') + sources += files('ml_utils_neon.c') +endif + deps += ['mldev'] pmd_supports_disable_iova_as_pa = true diff --git a/drivers/common/ml/ml_utils.c b/drivers/common/ml/ml_utils.c index e2edef0904..3edcf09fde 100644 --- a/drivers/common/ml/ml_utils.c +++ b/drivers/common/ml/ml_utils.c @@ -120,71 +120,119 @@ ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len) int ml_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float32_to_int8_neon(scale, nb_elements, input, output); +#else return ml_float32_to_int8_generic(scale, nb_elements, input, output); +#endif } int ml_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_int8_to_float32_neon(scale, nb_elements, input, output); +#else return ml_int8_to_float32_generic(scale, nb_elements, input, output); +#endif } int ml_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float32_to_uint8_neon(scale, nb_elements, input, output); +#else return ml_float32_to_uint8_generic(scale, nb_elements, input, output); +#endif } int ml_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_uint8_to_float32_neon(scale, nb_elements, input, output); +#else return ml_uint8_to_float32_generic(scale, nb_elements, input, output); +#endif } int ml_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float32_to_int16_neon(scale, nb_elements, input, output); +#else return ml_float32_to_int16_generic(scale, nb_elements, input, output); +#endif } int ml_int16_to_float32(float scale, uint64_t nb_elements, 
void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_int16_to_float32_neon(scale, nb_elements, input, output); +#else return ml_int16_to_float32_generic(scale, nb_elements, input, output); +#endif } int ml_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float32_to_uint16_neon(scale, nb_elements, input, output); +#else return ml_float32_to_uint16_generic(scale, nb_elements, input, output); +#endif } int ml_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_uint16_to_float32_neon(scale, nb_elements, input, output); +#else return ml_uint16_to_float32_generic(scale, nb_elements, input, output); +#endif } int ml_float32_to_float16(uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float32_to_float16_neon(nb_elements, input, output); +#else return ml_float32_to_float16_generic(nb_elements, input, output); +#endif } int ml_float16_to_float32(uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_NEON__) + return ml_float16_to_float32_neon(nb_elements, input, output); +#else return ml_float16_to_float32_generic(nb_elements, input, output); +#endif } int ml_float32_to_bfloat16(uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_FEATURE_BF16) + return ml_float32_to_bfloat16_neon(nb_elements, input, output); +#else return ml_float32_to_bfloat16_generic(nb_elements, input, output); +#endif } int ml_bfloat16_to_float32(uint64_t nb_elements, void *input, void *output) { +#if defined(__ARM_FEATURE_BF16) + return ml_bfloat16_to_float32_neon(nb_elements, input, output); +#else return ml_bfloat16_to_float32_generic(nb_elements, input, output); +#endif } diff --git a/drivers/common/ml/ml_utils_neon.c b/drivers/common/ml/ml_utils_neon.c new file mode 100644 index 0000000000..b660de07ec --- /dev/null +++ b/drivers/common/ml/ml_utils_neon.c @@ -0,0 +1,950 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. + */ + +#include +#include +#include + +#include +#include + +#include "ml_utils.h" +#include "ml_utils_neon.h" + +#include + +static void +__float32_to_int8_neon_s8x8(float scale, float *input, int8_t *output) +{ + int16x4_t s16x4_l; + int16x4_t s16x4_h; + float32x4_t f32x4; + int16x8_t s16x8; + int32x4_t s32x4; + int32x4_t vmin; + int32x4_t vmax; + int8x8_t s8x8; + + /* set constants */ + vmin = vdupq_n_s32(INT8_MIN); + vmax = vdupq_n_s32(INT8_MAX); + + /* load 4 float32 elements, scale, convert, update ranges and narrow to int16. + * Use round to nearest with ties away rounding mode. + */ + f32x4 = vld1q_f32(input); + f32x4 = vmulq_n_f32(f32x4, scale); + s32x4 = vcvtaq_s32_f32(f32x4); + s32x4 = vminq_s32(s32x4, vmax); + s32x4 = vmaxq_s32(s32x4, vmin); + s16x4_l = vmovn_s32(s32x4); + + /* load next 4 float32 elements, scale, convert, update ranges and narrow to int16. + * Use round to nearest with ties away rounding mode. 
+ */ + f32x4 = vld1q_f32(input + 4); + f32x4 = vmulq_n_f32(f32x4, scale); + s32x4 = vcvtaq_s32_f32(f32x4); + s32x4 = vminq_s32(s32x4, vmax); + s32x4 = vmaxq_s32(s32x4, vmin); + s16x4_h = vmovn_s32(s32x4); + + /* combine lower and higher int16x4_t to int16x8_t */ + s16x8 = vcombine_s16(s16x4_l, s16x4_h); + + /* narrow to int8_t */ + s8x8 = vmovn_s16(s16x8); + + /* store 8 elements */ + vst1_s8(output, s8x8); +} + +static void +__float32_to_int8_neon_s8x1(float scale, float *input, int8_t *output) +{ + float32x2_t f32x2; + int32x2_t s32x2; + int32x2_t vmin; + int32x2_t vmax; + int8x8_t s8x8; + + /* set constants */ + vmin = vdup_n_s32(INT8_MIN); + vmax = vdup_n_s32(INT8_MAX); + + /* load element to 2 lanes */ + f32x2 = vld1_dup_f32(input); + + /* scale */ + f32x2 = vmul_n_f32(f32x2, scale); + + /* convert with use round to nearest with ties away rounding mode */ + s32x2 = vcvta_s32_f32(f32x2); + + /* update range [INT8_MIN:INT8_MAX] */ + s32x2 = vmin_s32(s32x2, vmax); + s32x2 = vmax_s32(s32x2, vmin); + + /* convert to int8_t */ + s8x8 = vreinterpret_s8_s32(s32x2); + + /* store lane 0 / 1 element */ + vst1_lane_s8(output, s8x8, 0); +} + +int +ml_float32_to_int8_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + int8_t *output_buffer; + uint32_t batch_size; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (int8_t *)output; + batch_size = 2 * sizeof(float) / sizeof(int8_t); + + /* convert batch_size elements in each iteration */ + for (i = 0; i < (nb_elements / batch_size); i++) { + __float32_to_int8_neon_s8x8(scale, input_buffer, output_buffer); + input_buffer += batch_size; + output_buffer += batch_size; + } + + /* convert leftover elements */ + i = i * batch_size; + for (; i < nb_elements; i++) { + __float32_to_int8_neon_s8x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__int8_to_float32_neon_f32x8(float scale, int8_t *input, float *output) +{ + float32x4_t f32x4; + int16x8_t s16x8; + int16x4_t s16x4; + int32x4_t s32x4; + int8x8_t s8x8; + + /* load 8 x int8_t elements */ + s8x8 = vld1_s8(input); + + /* widen int8_t to int16_t */ + s16x8 = vmovl_s8(s8x8); + + /* convert lower 4 elements: widen to int32_t, convert to float, scale and store */ + s16x4 = vget_low_s16(s16x8); + s32x4 = vmovl_s16(s16x4); + f32x4 = vcvtq_f32_s32(s32x4); + f32x4 = vmulq_n_f32(f32x4, scale); + vst1q_f32(output, f32x4); + + /* convert higher 4 elements: widen to int32_t, convert to float, scale and store */ + s16x4 = vget_high_s16(s16x8); + s32x4 = vmovl_s16(s16x4); + f32x4 = vcvtq_f32_s32(s32x4); + f32x4 = vmulq_n_f32(f32x4, scale); + vst1q_f32(output + 4, f32x4); +} + +static void +__int8_to_float32_neon_f32x1(float scale, int8_t *input, float *output) +{ + *output = scale * vcvts_f32_s32((int32_t)*input); +} + +int +ml_int8_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + int8_t *input_buffer; + float *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (int8_t *)input; + output_buffer = (float *)output; + vlen = 2 * sizeof(float) / sizeof(int8_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __int8_to_float32_neon_f32x8(scale, input_buffer, output_buffer); + input_buffer += vlen; + 
output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __int8_to_float32_neon_f32x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__float32_to_uint8_neon_u8x8(float scale, float *input, uint8_t *output) +{ + uint16x4_t u16x4_l; + uint16x4_t u16x4_h; + float32x4_t f32x4; + uint32x4_t u32x4; + uint16x8_t u16x8; + uint32x4_t vmax; + uint8x8_t u8x8; + + /* set constants */ + vmax = vdupq_n_u32(UINT8_MAX); + + /* load 4 float elements, scale, convert, update range and narrow to uint16_t. + * use round to nearest with ties away rounding mode. + */ + f32x4 = vld1q_f32(input); + f32x4 = vmulq_n_f32(f32x4, scale); + u32x4 = vcvtaq_u32_f32(f32x4); + u32x4 = vminq_u32(u32x4, vmax); + u16x4_l = vmovn_u32(u32x4); + + /* load next 4 float elements, scale, convert, update range and narrow to uint16_t + * use round to nearest with ties away rounding mode. + */ + f32x4 = vld1q_f32(input + 4); + f32x4 = vmulq_n_f32(f32x4, scale); + u32x4 = vcvtaq_u32_f32(f32x4); + u32x4 = vminq_u32(u32x4, vmax); + u16x4_h = vmovn_u32(u32x4); + + /* combine lower and higher uint16x4_t */ + u16x8 = vcombine_u16(u16x4_l, u16x4_h); + + /* narrow to uint8x8_t */ + u8x8 = vmovn_u16(u16x8); + + /* store 8 elements */ + vst1_u8(output, u8x8); +} + +static void +__float32_to_uint8_neon_u8x1(float scale, float *input, uint8_t *output) +{ + float32x2_t f32x2; + uint32x2_t u32x2; + uint32x2_t vmax; + uint8x8_t u8x8; + + /* set constants */ + vmax = vdup_n_u32(UINT8_MAX); + + /* load element to 2 lanes */ + f32x2 = vld1_dup_f32(input); + + /* scale */ + f32x2 = vmul_n_f32(f32x2, scale); + + /* convert to uin32_t using round to nearest with ties away rounding mode */ + u32x2 = vcvta_u32_f32(f32x2); + + /* update range [0:UINT8_MAX] */ + u32x2 = vmin_u32(u32x2, vmax); + + /* convert to uint8x8_t */ + u8x8 = vreinterpret_u8_u32(u32x2); + + /* store lane 0 / 1 element */ + vst1_lane_u8(output, u8x8, 0); +} + +int +ml_float32_to_uint8_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint8_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint8_t *)output; + vlen = 2 * sizeof(float) / sizeof(uint8_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float32_to_uint8_neon_u8x8(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float32_to_uint8_neon_u8x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__uint8_to_float32_neon_f32x8(float scale, uint8_t *input, float *output) +{ + float32x4_t f32x4; + uint16x8_t u16x8; + uint16x4_t u16x4; + uint32x4_t u32x4; + uint8x8_t u8x8; + + /* load 8 x uint8_t elements */ + u8x8 = vld1_u8(input); + + /* widen uint8_t to uint16_t */ + u16x8 = vmovl_u8(u8x8); + + /* convert lower 4 elements: widen to uint32_t, convert to float, scale and store */ + u16x4 = vget_low_u16(u16x8); + u32x4 = vmovl_u16(u16x4); + f32x4 = vcvtq_f32_u32(u32x4); + f32x4 = vmulq_n_f32(f32x4, scale); + vst1q_f32(output, f32x4); + + /* convert higher 4 elements: widen to uint32_t, convert to float, scale and store */ + u16x4 = vget_high_u16(u16x8); + u32x4 = vmovl_u16(u16x4); + 
f32x4 = vcvtq_f32_u32(u32x4); + f32x4 = vmulq_n_f32(f32x4, scale); + vst1q_f32(output + 4, f32x4); +} + +static void +__uint8_to_float32_neon_f32x1(float scale, uint8_t *input, float *output) +{ + *output = scale * vcvts_f32_u32((uint32_t)*input); +} + +int +ml_uint8_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + uint8_t *input_buffer; + float *output_buffer; + uint64_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint8_t *)input; + output_buffer = (float *)output; + vlen = 2 * sizeof(float) / sizeof(uint8_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __uint8_to_float32_neon_f32x8(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __uint8_to_float32_neon_f32x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__float32_to_int16_neon_s16x4(float scale, float *input, int16_t *output) +{ + float32x4_t f32x4; + int16x4_t s16x4; + int32x4_t s32x4; + int32x4_t vmin; + int32x4_t vmax; + + /* set constants */ + vmin = vdupq_n_s32(INT16_MIN); + vmax = vdupq_n_s32(INT16_MAX); + + /* load 4 x float elements */ + f32x4 = vld1q_f32(input); + + /* scale */ + f32x4 = vmulq_n_f32(f32x4, scale); + + /* convert to int32x4_t using round to nearest with ties away rounding mode */ + s32x4 = vcvtaq_s32_f32(f32x4); + + /* update range [INT16_MIN:INT16_MAX] */ + s32x4 = vminq_s32(s32x4, vmax); + s32x4 = vmaxq_s32(s32x4, vmin); + + /* narrow to int16x4_t */ + s16x4 = vmovn_s32(s32x4); + + /* store 4 elements */ + vst1_s16(output, s16x4); +} + +static void +__float32_to_int16_neon_s16x1(float scale, float *input, int16_t *output) +{ + float32x2_t f32x2; + int32x2_t s32x2; + int16x4_t s16x4; + int32x2_t vmin; + int32x2_t vmax; + + /* set constants */ + vmin = vdup_n_s32(INT16_MIN); + vmax = vdup_n_s32(INT16_MAX); + + /* load element to 2 lanes */ + f32x2 = vld1_dup_f32(input); + + /* scale */ + f32x2 = vmul_n_f32(f32x2, scale); + + /* convert using round to nearest with ties to away rounding mode */ + s32x2 = vcvta_s32_f32(f32x2); + + /* update range [INT16_MIN:INT16_MAX] */ + s32x2 = vmin_s32(s32x2, vmax); + s32x2 = vmax_s32(s32x2, vmin); + + /* convert to int16x4_t */ + s16x4 = vreinterpret_s16_s32(s32x2); + + /* store lane 0 / 1 element */ + vst1_lane_s16(output, s16x4, 0); +} + +int +ml_float32_to_int16_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + int16_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (int16_t *)output; + vlen = 2 * sizeof(float) / sizeof(int16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float32_to_int16_neon_s16x4(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float32_to_int16_neon_s16x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__int16_to_float32_neon_f32x4(float scale, int16_t *input, float *output) +{ + float32x4_t f32x4; + int16x4_t s16x4; + int32x4_t s32x4; + + /* 
load 4 x int16_t elements */ + s16x4 = vld1_s16(input); + + /* widen int16_t to int32_t */ + s32x4 = vmovl_s16(s16x4); + + /* convert uint32_t to float */ + f32x4 = vcvtq_f32_s32(s32x4); + + /* scale */ + f32x4 = vmulq_n_f32(f32x4, scale); + + /* store float32x4_t */ + vst1q_f32(output, f32x4); +} + +static void +__int16_to_float32_neon_f32x1(float scale, int16_t *input, float *output) +{ + *output = scale * vcvts_f32_s32((int32_t)*input); +} + +int +ml_int16_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + int16_t *input_buffer; + float *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (int16_t *)input; + output_buffer = (float *)output; + vlen = 2 * sizeof(float) / sizeof(int16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __int16_to_float32_neon_f32x4(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __int16_to_float32_neon_f32x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__float32_to_uint16_neon_u16x4(float scale, float *input, uint16_t *output) +{ + float32x4_t f32x4; + uint16x4_t u16x4; + uint32x4_t u32x4; + uint32x4_t vmax; + + /* set constants */ + vmax = vdupq_n_u32(UINT16_MAX); + + /* load 4 float elements */ + f32x4 = vld1q_f32(input); + + /* scale */ + f32x4 = vmulq_n_f32(f32x4, scale); + + /* convert using round to nearest with ties to away rounding mode */ + u32x4 = vcvtaq_u32_f32(f32x4); + + /* update range [0:UINT16_MAX] */ + u32x4 = vminq_u32(u32x4, vmax); + + /* narrow */ + u16x4 = vmovn_u32(u32x4); + + /* store 4 elements */ + vst1_u16(output, u16x4); +} + +static void +__float32_to_uint16_neon_u16x1(float scale, float *input, uint16_t *output) +{ + float32x2_t f32x2; + uint16x4_t u16x4; + int32x2_t s32x2; + int32x2_t vmax; + + /* set constants */ + vmax = vdup_n_s32(UINT16_MAX); + + /* load element to 2 lanes */ + f32x2 = vld1_dup_f32(input); + + /* scale */ + f32x2 = vmul_n_f32(f32x2, scale); + + /* convert using round to nearest with ties to away rounding mode */ + s32x2 = vcvta_s32_f32(f32x2); + + /* update range [0:UINT16_MAX] */ + s32x2 = vmin_s32(s32x2, vmax); + + /* convert to uint16x4_t */ + u16x4 = vreinterpret_u16_s32(s32x2); + + /* store lane 0 / 1 element */ + vst1_lane_u16(output, u16x4, 0); +} + +int +ml_float32_to_uint16_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + float *input_buffer; + uint16_t *output_buffer; + uint64_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float *)input; + output_buffer = (uint16_t *)output; + vlen = 2 * sizeof(float) / sizeof(uint16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float32_to_uint16_neon_u16x4(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float32_to_uint16_neon_u16x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__uint16_to_float32_neon_f32x4(float scale, uint16_t *input, float *output) +{ + float32x4_t f32x4; + uint16x4_t u16x4; + uint32x4_t 
u32x4; + + /* load 4 x uint16_t elements */ + u16x4 = vld1_u16(input); + + /* widen uint16_t to uint32_t */ + u32x4 = vmovl_u16(u16x4); + + /* convert uint32_t to float */ + f32x4 = vcvtq_f32_u32(u32x4); + + /* scale */ + f32x4 = vmulq_n_f32(f32x4, scale); + + /* store float32x4_t */ + vst1q_f32(output, f32x4); +} + +static void +__uint16_to_float32_neon_f32x1(float scale, uint16_t *input, float *output) +{ + *output = scale * vcvts_f32_u32((uint32_t)*input); +} + +int +ml_uint16_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output) +{ + uint16_t *input_buffer; + float *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (uint16_t *)input; + output_buffer = (float *)output; + vlen = 2 * sizeof(float) / sizeof(uint16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __uint16_to_float32_neon_f32x4(scale, input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __uint16_to_float32_neon_f32x1(scale, input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__float32_to_float16_neon_f16x4(float32_t *input, float16_t *output) +{ + float32x4_t f32x4; + float16x4_t f16x4; + + /* load 4 x float32_t elements */ + f32x4 = vld1q_f32(input); + + /* convert to float16x4_t */ + f16x4 = vcvt_f16_f32(f32x4); + + /* store float16x4_t */ + vst1_f16(output, f16x4); +} + +static void +__float32_to_float16_neon_f16x1(float32_t *input, float16_t *output) +{ + float32x4_t f32x4; + float16x4_t f16x4; + + /* load element to 4 lanes */ + f32x4 = vld1q_dup_f32(input); + + /* convert float32_t to float16_t */ + f16x4 = vcvt_f16_f32(f32x4); + + /* store lane 0 / 1 element */ + vst1_lane_f16(output, f16x4, 0); +} + +int +ml_float32_to_float16_neon(uint64_t nb_elements, void *input, void *output) +{ + float32_t *input_buffer; + float16_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float32_t *)input; + output_buffer = (float16_t *)output; + vlen = 2 * sizeof(float32_t) / sizeof(float16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float32_to_float16_neon_f16x4(input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float32_to_float16_neon_f16x1(input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__float16_to_float32_neon_f32x4(float16_t *input, float32_t *output) +{ + float16x4_t f16x4; + float32x4_t f32x4; + + /* load 4 x float16_t elements */ + f16x4 = vld1_f16(input); + + /* convert float16x4_t to float32x4_t */ + f32x4 = vcvt_f32_f16(f16x4); + + /* store float32x4_t */ + vst1q_f32(output, f32x4); +} + +static void +__float16_to_float32_neon_f32x1(float16_t *input, float32_t *output) +{ + float16x4_t f16x4; + float32x4_t f32x4; + + /* load element to 4 lanes */ + f16x4 = vld1_dup_f16(input); + + /* convert float16_t to float32_t */ + f32x4 = vcvt_f32_f16(f16x4); + + /* store 1 element */ + vst1q_lane_f32(output, f32x4, 0); +} + +int +ml_float16_to_float32_neon(uint64_t nb_elements, void *input, void *output) +{ + float16_t *input_buffer; + 
float32_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float16_t *)input; + output_buffer = (float32_t *)output; + vlen = 2 * sizeof(float32_t) / sizeof(float16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float16_to_float32_neon_f32x4(input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float16_to_float32_neon_f32x1(input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +#ifdef __ARM_FEATURE_BF16 + +static void +__float32_to_bfloat16_neon_f16x4(float32_t *input, bfloat16_t *output) +{ + float32x4_t f32x4; + bfloat16x4_t bf16x4; + + /* load 4 x float32_t elements */ + f32x4 = vld1q_f32(input); + + /* convert float32x4_t to bfloat16x4_t */ + bf16x4 = vcvt_bf16_f32(f32x4); + + /* store bfloat16x4_t */ + vst1_bf16(output, bf16x4); +} + +static void +__float32_to_bfloat16_neon_f16x1(float32_t *input, bfloat16_t *output) +{ + float32x4_t f32x4; + bfloat16x4_t bf16x4; + + /* load element to 4 lanes */ + f32x4 = vld1q_dup_f32(input); + + /* convert float32_t to bfloat16_t */ + bf16x4 = vcvt_bf16_f32(f32x4); + + /* store lane 0 / 1 element */ + vst1_lane_bf16(output, bf16x4, 0); +} + +int +ml_float32_to_bfloat16_neon(uint64_t nb_elements, void *input, void *output) +{ + float32_t *input_buffer; + bfloat16_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (float32_t *)input; + output_buffer = (bfloat16_t *)output; + vlen = 2 * sizeof(float32_t) / sizeof(bfloat16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __float32_to_bfloat16_neon_f16x4(input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __float32_to_bfloat16_neon_f16x1(input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +static void +__bfloat16_to_float32_neon_f32x4(bfloat16_t *input, float32_t *output) +{ + bfloat16x4_t bf16x4; + float32x4_t f32x4; + + /* load 4 x bfloat16_t elements */ + bf16x4 = vld1_bf16(input); + + /* convert bfloat16x4_t to float32x4_t */ + f32x4 = vcvt_f32_bf16(bf16x4); + + /* store float32x4_t */ + vst1q_f32(output, f32x4); +} + +static void +__bfloat16_to_float32_neon_f32x1(bfloat16_t *input, float32_t *output) +{ + bfloat16x4_t bf16x4; + float32x4_t f32x4; + + /* load element to 4 lanes */ + bf16x4 = vld1_dup_bf16(input); + + /* convert bfloat16_t to float32_t */ + f32x4 = vcvt_f32_bf16(bf16x4); + + /* store lane 0 / 1 element */ + vst1q_lane_f32(output, f32x4, 0); +} + +int +ml_bfloat16_to_float32_neon(uint64_t nb_elements, void *input, void *output) +{ + bfloat16_t *input_buffer; + float32_t *output_buffer; + uint32_t vlen; + uint64_t i; + + if ((nb_elements == 0) || (input == NULL) || (output == NULL)) + return -EINVAL; + + input_buffer = (bfloat16_t *)input; + output_buffer = (float32_t *)output; + vlen = 2 * sizeof(float32_t) / sizeof(bfloat16_t); + + /* convert vlen elements in each iteration */ + for (i = 0; i < (nb_elements / vlen); i++) { + __bfloat16_to_float32_neon_f32x4(input_buffer, output_buffer); + input_buffer += vlen; + output_buffer += vlen; + } + + /* convert leftover 
elements */ + i = i * vlen; + for (; i < nb_elements; i++) { + __bfloat16_to_float32_neon_f32x1(input_buffer, output_buffer); + input_buffer++; + output_buffer++; + } + + return 0; +} + +#endif /* __ARM_FEATURE_BF16 */ diff --git a/drivers/common/ml/ml_utils_neon.h b/drivers/common/ml/ml_utils_neon.h new file mode 100644 index 0000000000..d912049779 --- /dev/null +++ b/drivers/common/ml/ml_utils_neon.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2022 Marvell. + */ + +#ifndef _ML_UTILS_NEON_H_ +#define _ML_UTILS_NEON_H_ + +#include + +int ml_float32_to_int8_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_int8_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_uint8_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_uint8_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_int16_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_int16_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_uint16_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_uint16_to_float32_neon(float scale, uint64_t nb_elements, void *input, void *output); +int ml_float32_to_float16_neon(uint64_t nb_elements, void *input, void *output); +int ml_float16_to_float32_neon(uint64_t nb_elements, void *input, void *output); +int ml_float32_to_bfloat16_neon(uint64_t nb_elements, void *input, void *output); +int ml_bfloat16_to_float32_neon(uint64_t nb_elements, void *input, void *output); + +#endif /*_ML_UTILS_NEON_H_ */
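
The NEON kernels in this patch all follow the same numeric recipe in the quantizing direction: multiply by the scale, convert with round to nearest, ties away from zero (vcvtaq_s32_f32 / vcvtaq_u32_f32), then clamp to the range of the target integer type before narrowing. A minimal scalar sketch of that behaviour for the float32 to int8 case follows; it is an illustration only, not part of the patch, and the helper name is hypothetical.

/* Scalar reference (illustration only, not part of the patch): scale, round
 * to nearest with ties away from zero, then saturate to the int8_t range,
 * mirroring the steps performed by __float32_to_int8_neon_s8x8().
 */
#include <math.h>
#include <stdint.h>

static int8_t
float32_to_int8_ref(float scale, float x)
{
	float v = roundf(x * scale); /* roundf() rounds halfway cases away from zero */

	if (v <= (float)INT8_MIN)
		return INT8_MIN;
	if (v >= (float)INT8_MAX)
		return INT8_MAX;
	return (int8_t)v;
}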
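
Because the path selection stays inside the dispatch layer in ml_utils.c, a caller (for example, a PMD built on the common ML code) uses the same ml_* entry points whether the NEON or the generic implementation is compiled in. A short usage sketch follows, assuming ml_utils.h from drivers/common/ml is on the include path; the buffer contents and the scale value are illustrative only. Note that both conversion directions multiply by the scale argument, so the reciprocal scale is passed when dequantizing.

/* Usage sketch (illustrative, not part of the patch): quantize a small
 * float32 buffer to int8 and dequantize it back using the common helpers.
 */
#include <stdint.h>
#include <stdio.h>

#include "ml_utils.h"

int
main(void)
{
	float in[8] = {-1.5f, -0.5f, 0.0f, 0.25f, 0.5f, 0.75f, 1.0f, 1.5f};
	float out[8];
	int8_t q[8];
	float scale = 64.0f; /* illustrative quantization scale */
	int i;

	/* float32 -> int8: takes the NEON path when built for Arm with NEON */
	if (ml_float32_to_int8(scale, 8, in, q) != 0)
		return 1;

	/* int8 -> float32: pass the reciprocal scale to dequantize */
	if (ml_int8_to_float32(1.0f / scale, 8, q, out) != 0)
		return 1;

	for (i = 0; i < 8; i++)
		printf("%f -> %d -> %f\n", in[i], q[i], out[i]);

	return 0;
}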