From patchwork Tue Dec 20 19:26:41 2022 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Srikanth Yalavarthi X-Patchwork-Id: 121133 X-Patchwork-Delegate: thomas@monjalon.net Return-Path: X-Original-To: patchwork@inbox.dpdk.org Delivered-To: patchwork@inbox.dpdk.org Received: from mails.dpdk.org (mails.dpdk.org [217.70.189.124]) by inbox.dpdk.org (Postfix) with ESMTP id CCD36A0545; Tue, 20 Dec 2022 20:30:30 +0100 (CET) Received: from mails.dpdk.org (localhost [127.0.0.1]) by mails.dpdk.org (Postfix) with ESMTP id 1228C42DAC; Tue, 20 Dec 2022 20:27:27 +0100 (CET) Received: from mx0b-0016f401.pphosted.com (mx0a-0016f401.pphosted.com [67.231.148.174]) by mails.dpdk.org (Postfix) with ESMTP id 931FA42D33 for ; Tue, 20 Dec 2022 20:27:03 +0100 (CET) Received: from pps.filterd (m0045849.ppops.net [127.0.0.1]) by mx0a-0016f401.pphosted.com (8.17.1.19/8.17.1.19) with ESMTP id 2BKHPQ13015491 for ; Tue, 20 Dec 2022 11:27:02 -0800 DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=marvell.com; h=from : to : cc : subject : date : message-id : in-reply-to : references : mime-version : content-type; s=pfpt0220; bh=H8qUp1n4cKIYJTX/ibpgWYX4F43dXfd9QfFTF5ex4tI=; b=AMQWCWJelaJRLt9E3sL5w49+o2rtXzxL5dih9hehOJGZH0bdg+vT37r79Z8xNY3NRwCv YfTLWEMfcH3IdMUMPyaOYw4YijRfwYayN84mHwrz+ygIYeTDmMaR/Ii4D08lvKE8bSgT ryBhe9PammVY4uj4jqt3Q/LIItqvKTqjuVDl6qdK4X8gFanDeF+jykYyKkblSS547ONe 7sVKvUX2UTf5dyr8bGVcUMCQC00FikdFavne342gdm0XmbLL6FKZlp5oULK39XrtnrSL MmQZPCxkTjMnJUJ+uaDsO27eFJJYlH+rWI5B36Q9T6OHkQBCtZp5AlZhuizQ5CG1fiFH sg== Received: from dc5-exch02.marvell.com ([199.233.59.182]) by mx0a-0016f401.pphosted.com (PPS) with ESMTPS id 3mkapj2tq2-8 (version=TLSv1.2 cipher=ECDHE-RSA-AES256-SHA384 bits=256 verify=NOT) for ; Tue, 20 Dec 2022 11:27:02 -0800 Received: from DC5-EXCH02.marvell.com (10.69.176.39) by DC5-EXCH02.marvell.com (10.69.176.39) with Microsoft SMTP Server (TLS) id 15.0.1497.42; Tue, 20 Dec 2022 11:27:01 -0800 Received: from maili.marvell.com (10.69.176.80) by DC5-EXCH02.marvell.com (10.69.176.39) with Microsoft SMTP Server id 15.0.1497.42 via Frontend Transport; Tue, 20 Dec 2022 11:27:01 -0800 Received: from ml-host-33.caveonetworks.com (unknown [10.110.143.233]) by maili.marvell.com (Postfix) with ESMTP id EEFE83F706F; Tue, 20 Dec 2022 11:27:00 -0800 (PST) From: Srikanth Yalavarthi To: Srikanth Yalavarthi CC: , , , Subject: [PATCH v3 34/38] ml/cnxk: add support to enable model data caching Date: Tue, 20 Dec 2022 11:26:41 -0800 Message-ID: <20221220192645.14042-35-syalavarthi@marvell.com> X-Mailer: git-send-email 2.17.1 In-Reply-To: <20221220192645.14042-1-syalavarthi@marvell.com> References: <20221208201806.21893-1-syalavarthi@marvell.com> <20221220192645.14042-1-syalavarthi@marvell.com> MIME-Version: 1.0 X-Proofpoint-ORIG-GUID: muqFe8TQjPBwJnFP9oYqJgcqkn7PslfM X-Proofpoint-GUID: muqFe8TQjPBwJnFP9oYqJgcqkn7PslfM X-Proofpoint-Virus-Version: vendor=baseguard engine=ICAP:2.0.205,Aquarius:18.0.923,Hydra:6.0.545,FMLib:17.11.122.1 definitions=2022-12-20_06,2022-12-20_01,2022-06-22_01 X-BeenThere: dev@dpdk.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: DPDK patches and discussions List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: dev-bounces@dpdk.org Added device argument 'cache_model_data' to enable model data caching. An inference request would be executed with dummy data in synchronous mode during model start stage. This run would cache the model weights and bias in the memory and result in improved inference throughput. cache_model_data = 1, enable (default) cache_model_data = 0, disable Signed-off-by: Srikanth Yalavarthi --- drivers/ml/cnxk/cn10k_ml_dev.c | 33 ++++++++++++++++++++-- drivers/ml/cnxk/cn10k_ml_dev.h | 3 ++ drivers/ml/cnxk/cn10k_ml_ops.c | 50 ++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/drivers/ml/cnxk/cn10k_ml_dev.c b/drivers/ml/cnxk/cn10k_ml_dev.c index ac6592891b..948708a420 100644 --- a/drivers/ml/cnxk/cn10k_ml_dev.c +++ b/drivers/ml/cnxk/cn10k_ml_dev.c @@ -20,10 +20,12 @@ #define CN10K_ML_FW_PATH "fw_path" #define CN10K_ML_FW_ENABLE_DPE_WARNINGS "enable_dpe_warnings" #define CN10K_ML_FW_REPORT_DPE_WARNINGS "report_dpe_warnings" +#define CN10K_ML_DEV_CACHE_MODEL_DATA "cache_model_data" #define CN10K_ML_FW_PATH_DEFAULT "/lib/firmware/mlip-fw.bin" #define CN10K_ML_FW_ENABLE_DPE_WARNINGS_DEFAULT 1 #define CN10K_ML_FW_REPORT_DPE_WARNINGS_DEFAULT 0 +#define CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT 1 /* ML firmware macros */ #define FW_MEMZONE_NAME "ml_cn10k_fw_mz" @@ -38,7 +40,8 @@ #define FW_REPORT_DPE_WARNING_BITMASK BIT(1) static const char *const valid_args[] = {CN10K_ML_FW_PATH, CN10K_ML_FW_ENABLE_DPE_WARNINGS, - CN10K_ML_FW_REPORT_DPE_WARNINGS, NULL}; + CN10K_ML_FW_REPORT_DPE_WARNINGS, + CN10K_ML_DEV_CACHE_MODEL_DATA, NULL}; /* Dummy operations for ML device */ struct rte_ml_dev_ops ml_dev_dummy_ops = {0}; @@ -76,6 +79,7 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde { bool enable_dpe_warnings_set = false; bool report_dpe_warnings_set = false; + bool cache_model_data_set = false; struct rte_kvargs *kvlist = NULL; bool fw_path_set = false; char *fw_path = NULL; @@ -124,6 +128,18 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde report_dpe_warnings_set = true; } + if (rte_kvargs_count(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA) == 1) { + ret = rte_kvargs_process(kvlist, CN10K_ML_DEV_CACHE_MODEL_DATA, &parse_integer_arg, + &mldev->cache_model_data); + if (ret < 0) { + plt_err("Error processing arguments, key = %s\n", + CN10K_ML_DEV_CACHE_MODEL_DATA); + ret = -EINVAL; + goto exit; + } + cache_model_data_set = true; + } + check_args: if (!fw_path_set) mldev->fw.path = CN10K_ML_FW_PATH_DEFAULT; @@ -155,6 +171,18 @@ cn10k_mldev_parse_devargs(struct rte_devargs *devargs, struct cn10k_ml_dev *mlde } plt_info("ML: %s = %d", CN10K_ML_FW_REPORT_DPE_WARNINGS, mldev->fw.report_dpe_warnings); + if (!cache_model_data_set) { + mldev->cache_model_data = CN10K_ML_DEV_CACHE_MODEL_DATA_DEFAULT; + } else { + if ((mldev->cache_model_data < 0) || (mldev->cache_model_data > 1)) { + plt_err("Invalid argument, %s = %d\n", CN10K_ML_DEV_CACHE_MODEL_DATA, + mldev->cache_model_data); + ret = -EINVAL; + goto exit; + } + } + plt_info("ML: %s = %d", CN10K_ML_DEV_CACHE_MODEL_DATA, mldev->cache_model_data); + exit: if (kvlist) rte_kvargs_free(kvlist); @@ -694,4 +722,5 @@ RTE_PMD_REGISTER_KMOD_DEP(MLDEV_NAME_CN10K_PMD, "vfio-pci"); RTE_PMD_REGISTER_PARAM_STRING(MLDEV_NAME_CN10K_PMD, CN10K_ML_FW_PATH "=" CN10K_ML_FW_ENABLE_DPE_WARNINGS - "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS "=<0|1>"); + "=<0|1>" CN10K_ML_FW_REPORT_DPE_WARNINGS + "=<0|1>" CN10K_ML_DEV_CACHE_MODEL_DATA "=<0|1>"); diff --git a/drivers/ml/cnxk/cn10k_ml_dev.h b/drivers/ml/cnxk/cn10k_ml_dev.h index 9ba56ffba6..718edadde7 100644 --- a/drivers/ml/cnxk/cn10k_ml_dev.h +++ b/drivers/ml/cnxk/cn10k_ml_dev.h @@ -381,6 +381,9 @@ struct cn10k_ml_dev { /* xstats status */ bool xstats_enabled; + + /* Enable / disable model data caching */ + int cache_model_data; }; uint64_t cn10k_ml_fw_flags_get(struct cn10k_ml_fw *fw); diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c index cdd9ae9c69..1e6d366c59 100644 --- a/drivers/ml/cnxk/cn10k_ml_ops.c +++ b/drivers/ml/cnxk/cn10k_ml_ops.c @@ -488,6 +488,49 @@ cn10k_ml_model_xstat_reset(struct rte_ml_dev *dev, uint16_t model_id, } } +static int +cn10k_ml_cache_model_data(struct rte_ml_dev *dev, int16_t model_id) +{ + struct cn10k_ml_model *model; + struct rte_ml_op op; + + char str[RTE_MEMZONE_NAMESIZE]; + const struct plt_memzone *mz; + uint64_t isize = 0; + uint64_t osize = 0; + int ret = 0; + + model = dev->data->models[model_id]; + + /* Create input and output buffers. */ + rte_ml_io_input_size_get(dev->data->dev_id, model_id, model->batch_size, &isize, NULL); + rte_ml_io_output_size_get(dev->data->dev_id, model_id, model->batch_size, &osize, NULL); + + snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%d", "ml_dummy_io", model_id); + mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE); + if (mz == NULL) + return -ENOMEM; + memset(mz->addr, 0, isize + osize); + + op.model_id = model_id; + op.nb_batches = model->batch_size; + op.mempool = NULL; + + op.input.addr = mz->addr; + op.input.length = isize; + op.input.next = NULL; + + op.output.addr = PLT_PTR_ADD(op.input.addr, isize); + op.output.length = osize; + op.output.next = NULL; + + memset(model->req, 0, sizeof(struct cn10k_ml_req)); + ret = cn10k_ml_inference_sync(dev, &op); + plt_memzone_free(mz); + + return ret; +} + static int cn10k_ml_dev_info_get(struct rte_ml_dev *dev, struct rte_ml_dev_info *dev_info) { @@ -1467,6 +1510,13 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, int16_t model_id) } } + if (ret < 0) { /* Call unload to update model and FW state, ignore error */ + rte_ml_model_stop(dev->data->dev_id, model_id); + } else { + if (mldev->cache_model_data && roc_model_is_cn10ka()) + ret = cn10k_ml_cache_model_data(dev, model_id); + } + return ret; }