1# 使用MindSpore Lite进行端侧训练 (C/C++)
2
3## 场景介绍
4
5MindSpore Lite是一款AI引擎,它提供了面向不同硬件设备AI模型推理的功能,目前已经在图像分类、目标识别、人脸识别、文字识别等应用中广泛使用,同时支持在端侧设备上进行部署训练,让模型在实际业务场景中自适应用户的行为。
6
7本文介绍使用MindSpore Lite端侧AI引擎进行模型训练的通用开发流程。
8
9
10## 接口说明
11此处给出使用MindSpore Lite进行模型训练相关的部分接口,具体请见下方表格
12
13| 接口名称        | 描述        |
14| ------------------ | ----------------- |
15|OH_AI_ContextHandle OH_AI_ContextCreate()|创建一个上下文的对象。注意:此接口需跟OH_AI_ContextDestroy配套使用。|
16|OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type)|创建一个运行时设备信息对象。|
17|void OH_AI_ContextDestroy(OH_AI_ContextHandle *context)|释放上下文对象。|
18|void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info)|添加运行时设备信息。|
19|OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate()|创建训练配置对象指针。|
20|void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg)|销毁训练配置对象指针。|
21|OH_AI_ModelHandle OH_AI_ModelCreate()|创建一个模型对象。|
22|OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg)|通过模型文件加载并编译MindSpore训练模型。|
23|OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after)|单步训练模型。|
24|OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train)|设置训练模式。|
25|OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, OH_AI_QuantizationType quantization_type, bool export_inference_only, char **output_tensor_name, size_t num)|导出训练后的ms模型。|
26|void OH_AI_ModelDestroy(OH_AI_ModelHandle *model)|释放一个模型对象。|
27
28
29## 开发步骤
30使用MindSpore Lite进行模型训练的开发流程如下图所示。
31
32**图 1** 使用MindSpore Lite进行模型训练的开发流程
33![how-to-use-train](figures/train_sequence_unify_api.png)
34
35进入主要流程之前需要先引用相关的头文件,并编写函数生成随机的输入,具体如下:
36
37```c
38#include <stdlib.h>
39#include <stdio.h>
40#include <string.h>
41#include "mindspore/model.h"
42
43int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
44  for (size_t i = 0; i < inputs.handle_num; ++i) {
45    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
46    if (input_data == NULL) {
47      printf("OH_AI_TensorGetMutableData failed.\n");
48      return  OH_AI_STATUS_LITE_ERROR;
49    }
50    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
51    const int divisor = 10;
52    for (size_t j = 0; j < num; j++) {
53      input_data[j] = (float)(rand() % divisor) / divisor;  // 0--0.9f
54    }
55  }
56  return OH_AI_STATUS_SUCCESS;
57}
58```
59
60然后进入主要的开发步骤,包括模型的准备、读取、编译、训练、模型导出和释放,具体开发过程及细节请见下文的开发步骤及示例。
61
621. 模型准备。
63
64    准备的模型格式为`.ms`,本文以[lenet_train.ms](https://gitee.com/openharmony-sig/compatibility/blob/master/test_suite/resource/master/standard%20system/acts/resource/ai/mindspore/lenet_train/lenet_train.ms)为例(此模型是提前准备的`ms`模型)。如果开发者需要使用自己准备的模型,可以按如下步骤操作:
65
66    - 首先基于MindSpore架构使用Python创建网络模型,并导出为`.mindir`文件,详细指南参考[这里](https://www.mindspore.cn/tutorials/zh-CN/r2.1/beginner/quick_start.html)67    - 然后将`.mindir`模型文件转换成`.ms`文件,转换操作步骤可以参考[训练模型转换](https://www.mindspore.cn/lite/docs/zh-CN/r2.1/use/converter_train.html),`.ms`文件可以导入端侧设备并基于MindSpore端侧框架进行训练。
68
692. 创建上下文,设置设备类型、训练配置等参数。
70
71    ```c
72    // Create and init context, add CPU device info
73    OH_AI_ContextHandle context = OH_AI_ContextCreate();
74    if (context == NULL) {
75        printf("OH_AI_ContextCreate failed.\n");
76        return OH_AI_STATUS_LITE_ERROR;
77    }
78
79    OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
80    if (cpu_device_info == NULL) {
81        printf("OH_AI_DeviceInfoCreate failed.\n");
82        OH_AI_ContextDestroy(&context);
83        return OH_AI_STATUS_LITE_ERROR;
84    }
85    OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
86
87    // Create trainCfg
88    OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
89    if (trainCfg == NULL) {
90        printf("OH_AI_TrainCfgCreate failed.\n");
91        OH_AI_ContextDestroy(&context);
92        return OH_AI_STATUS_LITE_ERROR;
93    }
94    ```
95
963. 创建、加载与编译模型。
97
98    调用OH_AI_TrainModelBuildFromFile加载并编译模型。
99
100    ```c
101    // Create model
102    OH_AI_ModelHandle model = OH_AI_ModelCreate();
103    if (model == NULL) {
104        printf("OH_AI_ModelCreate failed.\n");
105        OH_AI_TrainCfgDestroy(&trainCfg);
106        OH_AI_ContextDestroy(&context);
107        return OH_AI_STATUS_LITE_ERROR;
108    }
109
110    // Build model
111    int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
112    if (ret != OH_AI_STATUS_SUCCESS) {
113        printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
114        OH_AI_ModelDestroy(&model);
115        OH_AI_ContextDestroy(&context);
116        return ret;
117    }
118    ```
119
1204. 输入数据。
121
122    模型执行之前需要向输入的张量中填充数据。本例使用随机的数据对模型进行填充。
123
124    ```c
125    // Get Inputs
126    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
127    if (inputs.handle_list == NULL) {
128        printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
129        OH_AI_ModelDestroy(&model);
130        OH_AI_ContextDestroy(&context);
131        return ret;
132    }
133
134    // Generate random data as input data.
135    ret = GenerateInputDataWithRandom(inputs);
136    if (ret != OH_AI_STATUS_SUCCESS) {
137        printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
138        OH_AI_ModelDestroy(&model);
139        OH_AI_ContextDestroy(&context);
140        return ret;
141    }
142    ```
143
1445. 执行训练。
145
146    使用OH_AI_ModelSetTrainMode接口设置训练模式,使用OH_AI_RunStep接口进行模型训练。
147
148    ```c
149    // Set Traim Mode
150    ret = OH_AI_ModelSetTrainMode(model, true);
151    if (ret != OH_AI_STATUS_SUCCESS) {
152        printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
153        OH_AI_ModelDestroy(&model);
154        OH_AI_ContextDestroy(&context);
155        return ret;
156    }
157
158    // Model Train Step
159    ret = OH_AI_RunStep(model, NULL, NULL);
160    if (ret != OH_AI_STATUS_SUCCESS) {
161        printf("OH_AI_RunStep failed, ret: %d.\n", ret);
162        OH_AI_ModelDestroy(&model);
163        OH_AI_ContextDestroy(&context);
164        return ret;
165    }
166    printf("Train Step Success.\n");
167    ```
168
1696. 导出训练后模型。
170
171    使用OH_AI_ExportModel接口导出训练后模型。
172
173    ```c
174    // Export Train Model
175    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
176    if (ret != OH_AI_STATUS_SUCCESS) {
177        printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
178        OH_AI_ModelDestroy(&model);
179        OH_AI_ContextDestroy(&context);
180        return ret;
181    }
182    printf("Export Train Model Success.\n");
183
184    // Export Inference Model
185    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
186    if (ret != OH_AI_STATUS_SUCCESS) {
187        printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
188        OH_AI_ModelDestroy(&model);
189        OH_AI_ContextDestroy(&context);
190        return ret;
191    }
192    printf("Export Inference Model Success.\n");
193    ```
194
1957. 释放模型。
196
197    不再使用MindSpore Lite推理框架时,需要释放已经创建的模型。
198
199    ```c
200    // Delete model and context.
201    OH_AI_ModelDestroy(&model);
202    OH_AI_ContextDestroy(&context);
203    ```
204
205
206## 调测验证
207
2081. 编写CMakeLists.txt209    ```c
210    cmake_minimum_required(VERSION 3.14)
211    project(TrainDemo)
212
213    add_executable(train_demo main.c)
214
215    target_link_libraries(
216            train_demo
217            mindspore_lite_ndk
218    )
219    ```
220
221   - 使用ohos-sdk交叉编译,需要对CMake设置native工具链路径,即:`-DCMAKE_TOOLCHAIN_FILE="/xxx/native/build/cmake/ohos.toolchain.camke"`。
222
223   - 编译命令如下,其中OHOS_NDK需要设置为native工具链路径:
224      ```shell
225        mkdir -p build
226
227        cd ./build || exit
228        OHOS_NDK=""
229        cmake -G "Unix Makefiles" \
230              -S ../ \
231              -DCMAKE_TOOLCHAIN_FILE="$OHOS_NDK/build/cmake/ohos.toolchain.cmake" \
232              -DOHOS_ARCH=arm64-v8a \
233              -DCMAKE_BUILD_TYPE=Release
234
235        make
236      ```
237
2382. 运行编译的可执行程序。
239
240    - 使用hdc连接设备,并将train_demo和lenet_train.ms推送到设备中的相同目录。
241    - 使用hdc shell进入设备,并进入train_demo所在的目录执行如下命令,即可得到结果。
242
243    ```shell
244    ./train_demo ./lenet_train.ms export_train_model export_infer_model
245    ```
246
247    得到如下输出:
248
249    ```shell
250    Train Step Success.
251    Export Train Model Success.
252    Export Inference Model Success.
253    Tensor name: Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op121, tensor size is 80, elements num: 20.
254    output data is:
255    0.000265 0.000231 0.000254 0.000269 0.000238 0.000228
256    ```
257
258    在train_demo所在目录可以看到导出的两个模型文件:export_train_model.msexport_infer_model.ms259
260
261## 完整示例
262
263```c
264#include <stdlib.h>
265#include <stdio.h>
266#include <string.h>
267#include "mindspore/model.h"
268
269int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
270  for (size_t i = 0; i < inputs.handle_num; ++i) {
271    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
272    if (input_data == NULL) {
273      printf("OH_AI_TensorGetMutableData failed.\n");
274      return  OH_AI_STATUS_LITE_ERROR;
275    }
276    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
277    const int divisor = 10;
278    for (size_t j = 0; j < num; j++) {
279      input_data[j] = (float)(rand() % divisor) / divisor;  // 0--0.9f
280    }
281  }
282  return OH_AI_STATUS_SUCCESS;
283}
284
285int ModelPredict(char* model_file) {
286  // Create and init context, add CPU device info
287  OH_AI_ContextHandle context = OH_AI_ContextCreate();
288  if (context == NULL) {
289    printf("OH_AI_ContextCreate failed.\n");
290    return OH_AI_STATUS_LITE_ERROR;
291  }
292
293  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
294  if (cpu_device_info == NULL) {
295    printf("OH_AI_DeviceInfoCreate failed.\n");
296    OH_AI_ContextDestroy(&context);
297    return OH_AI_STATUS_LITE_ERROR;
298  }
299  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
300
301  // Create model
302  OH_AI_ModelHandle model = OH_AI_ModelCreate();
303  if (model == NULL) {
304    printf("OH_AI_ModelCreate failed.\n");
305    OH_AI_ContextDestroy(&context);
306    return OH_AI_STATUS_LITE_ERROR;
307  }
308
309  // Build model
310  int ret = OH_AI_ModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context);
311  if (ret != OH_AI_STATUS_SUCCESS) {
312    printf("OH_AI_ModelBuildFromFile failed, ret: %d.\n", ret);
313    OH_AI_ModelDestroy(&model);
314    OH_AI_ContextDestroy(&context);
315    return ret;
316  }
317
318  // Get Inputs
319  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
320  if (inputs.handle_list == NULL) {
321    printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
322    OH_AI_ModelDestroy(&model);
323    OH_AI_ContextDestroy(&context);
324    return ret;
325  }
326
327  // Generate random data as input data.
328  ret = GenerateInputDataWithRandom(inputs);
329  if (ret != OH_AI_STATUS_SUCCESS) {
330    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
331    OH_AI_ModelDestroy(&model);
332    OH_AI_ContextDestroy(&context);
333    return ret;
334  }
335
336  // Model Predict
337  OH_AI_TensorHandleArray outputs;
338  ret = OH_AI_ModelPredict(model, inputs, &outputs, NULL, NULL);
339  if (ret != OH_AI_STATUS_SUCCESS) {
340    printf("MSModelPredict failed, ret: %d.\n", ret);
341    OH_AI_ModelDestroy(&model);
342    OH_AI_ContextDestroy(&context);
343    return ret;
344  }
345
346  // Print Output Tensor Data.
347  for (size_t i = 0; i < outputs.handle_num; ++i) {
348    OH_AI_TensorHandle tensor = outputs.handle_list[i];
349    int64_t element_num = OH_AI_TensorGetElementNum(tensor);
350    printf("Tensor name: %s, tensor size is %ld ,elements num: %ld.\n", OH_AI_TensorGetName(tensor),
351           OH_AI_TensorGetDataSize(tensor), element_num);
352    const float *data = (const float *)OH_AI_TensorGetData(tensor);
353    printf("output data is:\n");
354    const int max_print_num = 50;
355    for (int j = 0; j < element_num && j <= max_print_num; ++j) {
356      printf("%f ", data[j]);
357    }
358    printf("\n");
359  }
360
361  OH_AI_ModelDestroy(&model);
362  OH_AI_ContextDestroy(&context);
363  return OH_AI_STATUS_SUCCESS;
364}
365
366int TrainDemo(int argc, const char **argv) {
367  if (argc < 4) {
368    printf("Model file must be provided.\n");
369    printf("Export Train Model path must be provided.\n");
370    printf("Export Inference Model path must be provided.\n");
371    return OH_AI_STATUS_LITE_ERROR;
372  }
373  const char *model_file = argv[1];
374  const char *export_train_model = argv[2];
375  const char *export_infer_model = argv[3];
376
377  // Create and init context, add CPU device info
378  OH_AI_ContextHandle context = OH_AI_ContextCreate();
379  if (context == NULL) {
380    printf("OH_AI_ContextCreate failed.\n");
381    return OH_AI_STATUS_LITE_ERROR;
382  }
383
384  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
385  if (cpu_device_info == NULL) {
386    printf("OH_AI_DeviceInfoCreate failed.\n");
387    OH_AI_ContextDestroy(&context);
388    return OH_AI_STATUS_LITE_ERROR;
389  }
390  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);
391
392  // Create trainCfg
393  OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
394  if (trainCfg == NULL) {
395    printf("OH_AI_TrainCfgCreate failed.\n");
396    OH_AI_ContextDestroy(&context);
397    return OH_AI_STATUS_LITE_ERROR;
398  }
399
400  // Create model
401  OH_AI_ModelHandle model = OH_AI_ModelCreate();
402  if (model == NULL) {
403    printf("OH_AI_ModelCreate failed.\n");
404    OH_AI_TrainCfgDestroy(&trainCfg);
405    OH_AI_ContextDestroy(&context);
406    return OH_AI_STATUS_LITE_ERROR;
407  }
408
409  // Build model
410  int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
411  if (ret != OH_AI_STATUS_SUCCESS) {
412    printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
413    OH_AI_ModelDestroy(&model);
414    OH_AI_ContextDestroy(&context);
415    return ret;
416  }
417
418  // Get Inputs
419  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
420  if (inputs.handle_list == NULL) {
421    printf("OH_AI_ModelGetInputs failed, ret: %d.\n", ret);
422    OH_AI_ModelDestroy(&model);
423    OH_AI_ContextDestroy(&context);
424    return ret;
425  }
426
427  // Generate random data as input data.
428  ret = GenerateInputDataWithRandom(inputs);
429  if (ret != OH_AI_STATUS_SUCCESS) {
430    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
431    OH_AI_ModelDestroy(&model);
432    OH_AI_ContextDestroy(&context);
433    return ret;
434  }
435
436  // Set Traim Mode
437  ret = OH_AI_ModelSetTrainMode(model, true);
438  if (ret != OH_AI_STATUS_SUCCESS) {
439    printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
440    OH_AI_ModelDestroy(&model);
441	OH_AI_ContextDestroy(&context);
442    return ret;
443  }
444
445  // Model Train Step
446  ret = OH_AI_RunStep(model, NULL, NULL);
447  if (ret != OH_AI_STATUS_SUCCESS) {
448    printf("OH_AI_RunStep failed, ret: %d.\n", ret);
449    OH_AI_ModelDestroy(&model);
450    OH_AI_ContextDestroy(&context);
451    return ret;
452  }
453  printf("Train Step Success.\n");
454
455  // Export Train Model
456  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
457  if (ret != OH_AI_STATUS_SUCCESS) {
458    printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
459    OH_AI_ModelDestroy(&model);
460    OH_AI_ContextDestroy(&context);
461    return ret;
462  }
463  printf("Export Train Model Success.\n");
464
465  // Export Inference Model
466  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
467  if (ret != OH_AI_STATUS_SUCCESS) {
468    printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
469    OH_AI_ModelDestroy(&model);
470    OH_AI_ContextDestroy(&context);
471    return ret;
472  }
473  printf("Export Inference Model Success.\n");
474
475  // Delete model and context.
476  OH_AI_ModelDestroy(&model);
477  OH_AI_ContextDestroy(&context);
478
479  // Use The Exported Model to predict
480  ret = ModelPredict(strcat(export_infer_model, ".ms"));
481  if (ret != OH_AI_STATUS_SUCCESS) {
482    printf("Exported Model to predict failed, ret: %d.\n", ret);
483    return ret;
484  }
485  return OH_AI_STATUS_SUCCESS;
486}
487
488int main(int argc, const char **argv) { return TrainDemo(argc, argv); }
489
490```