@@ -1835,45 +1835,23 @@ cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
}
int
-cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
- struct rte_ml_model_info *model_info)
+cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+ void *buffer)
{
- struct cnxk_ml_model *model;
-
- model = dev->data->models[model_id];
-
- if (model == NULL) {
- plt_err("Invalid model_id = %u", model_id);
- return -EINVAL;
- }
-
- rte_memcpy(model_info, model->info, sizeof(struct rte_ml_model_info));
- model_info->input_info = ((struct rte_ml_model_info *)model->info)->input_info;
- model_info->output_info = ((struct rte_ml_model_info *)model->info)->output_info;
-
- return 0;
-}
-
-int
-cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
-{
- struct cnxk_ml_model *model;
-
- model = dev->data->models[model_id];
+ struct cnxk_ml_layer *layer;
- if (model == NULL) {
- plt_err("Invalid model_id = %u", model_id);
- return -EINVAL;
- }
+ RTE_SET_USED(cnxk_mldev);
if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN)
return -1;
else if (model->state != ML_CNXK_MODEL_STATE_LOADED)
return -EBUSY;
+ layer = &model->layer[0];
+
/* Update model weights & bias */
- rte_memcpy(model->layer[0].glow.addr.wb_load_addr, buffer,
- model->layer[0].glow.metadata.weights_bias.file_size);
+ rte_memcpy(layer->glow.addr.wb_load_addr, buffer,
+ layer->glow.metadata.weights_bias.file_size);
return 0;
}
@@ -317,9 +317,8 @@ int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_para
int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
-int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
- struct rte_ml_model_info *model_info);
-int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer);
+int cn10k_ml_model_params_update(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model,
+ void *buffer);
/* I/O ops */
int cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id,
@@ -664,6 +664,50 @@ cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
return cn10k_ml_model_stop(cnxk_mldev, model);
}
+static int
+cnxk_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
+ struct rte_ml_model_info *model_info)
+{
+ struct rte_ml_model_info *info;
+ struct cnxk_ml_model *model;
+
+ if ((dev == NULL) || (model_info == NULL))
+ return -EINVAL;
+
+ model = dev->data->models[model_id];
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ info = (struct rte_ml_model_info *)model->info;
+ rte_memcpy(model_info, info, sizeof(struct rte_ml_model_info));
+ model_info->input_info = info->input_info;
+ model_info->output_info = info->output_info;
+
+ return 0;
+}
+
+static int
+cnxk_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer)
+{
+ struct cnxk_ml_dev *cnxk_mldev;
+ struct cnxk_ml_model *model;
+
+ if ((dev == NULL) || (buffer == NULL))
+ return -EINVAL;
+
+ cnxk_mldev = dev->data->dev_private;
+
+ model = dev->data->models[model_id];
+ if (model == NULL) {
+ plt_err("Invalid model_id = %u", model_id);
+ return -EINVAL;
+ }
+
+ return cn10k_ml_model_params_update(cnxk_mldev, model, buffer);
+}
+
struct rte_ml_dev_ops cnxk_ml_ops = {
/* Device control ops */
.dev_info_get = cnxk_ml_dev_info_get,
@@ -691,8 +735,8 @@ struct rte_ml_dev_ops cnxk_ml_ops = {
.model_unload = cnxk_ml_model_unload,
.model_start = cnxk_ml_model_start,
.model_stop = cnxk_ml_model_stop,
- .model_info_get = cn10k_ml_model_info_get,
- .model_params_update = cn10k_ml_model_params_update,
+ .model_info_get = cnxk_ml_model_info_get,
+ .model_params_update = cnxk_ml_model_params_update,
/* I/O ops */
.io_quantize = cn10k_ml_io_quantize,