1 /*
2  * Copyright (c) 2020 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "sts_client.h"
17 
18 #if !(defined(_CUT_STS_) || defined(_CUT_STS_CLIENT_))
19 #include <string.h>
20 #include "securec.h"
21 #include "huks_adapter.h"
22 #include "log.h"
23 #include "commonutil.h"
24 #include "distribution.h"
25 
26 static int32_t build_start_request_data(void *handle, void *data);
27 static int32_t parse_start_response_data(void *handle, void *data);
28 static int32_t build_end_request_data(void *handle, void *data);
29 static int32_t parse_end_response_data(void *handle, void *data);
30 
31 static void generate_output_key(struct sts_client *sts_client);
32 
build_sts_client(const hc_handle hichain_handle,uint32_t key_length,const struct hc_auth_id * client,const struct hc_auth_id * server)33 struct sts_client *build_sts_client(const hc_handle hichain_handle, uint32_t key_length,
34     const struct hc_auth_id *client, const struct hc_auth_id *server)
35 {
36     struct sts_client *sts_client = (struct sts_client *)MALLOC(sizeof(struct sts_client));
37     if (sts_client == NULL) {
38         LOGE("Build sts client object failed");
39         return NULL;
40     }
41     (void)memset_s(sts_client, sizeof(*sts_client), 0, sizeof(*sts_client));
42 
43     struct client_virtual_func_group funcs = { build_start_request_data, parse_start_response_data,
44                                                build_end_request_data, parse_end_response_data };
45 
46     init_client(&sts_client->client_info, &funcs);
47     DBG_OUT("Build sts client object %u success", sts_client_sn(sts_client));
48 
49     sts_client->self_id = *client;
50     sts_client->peer_id = *server;
51     sts_client->identity = &((struct hichain *)hichain_handle)->identity;
52     sts_client->key_length = key_length;
53     sts_client->hichain_handle = hichain_handle;
54 
55     return sts_client;
56 }
57 
destroy_sts_client(struct sts_client * handle)58 void destroy_sts_client(struct sts_client *handle)
59 {
60     if (handle == NULL) {
61         DBG_OUT("Destroy sts client object failed");
62         return;
63     }
64     (void)memset_s(&handle->self_private_key, sizeof(struct stsk), 0, sizeof(struct stsk));
65     (void)memset_s(&handle->session_key, sizeof(struct sts_session_key), 0, sizeof(struct sts_session_key));
66     (void)memset_s(&handle->service_key, sizeof(struct hc_session_key), 0, sizeof(struct hc_session_key));
67     FREE(handle);
68     DBG_OUT("FREE sts client object success");
69 }
70 
build_start_request_data(void * handle,void * data)71 static int32_t build_start_request_data(void *handle, void *data)
72 {
73     struct sts_client *sts_client = (struct sts_client *)handle;
74     struct sts_start_request_data *send_data = (struct sts_start_request_data *)data;
75     struct st_key_pair key_pair;
76     int32_t ret = generate_st_key_pair(&key_pair);
77     if (ret != HC_OK) {
78         LOGE("Object %u generate_st_key_pair failed, error code is %d", sts_client_sn(sts_client), ret);
79         return HC_INPUT_ERROR;
80     }
81 
82     sts_client->self_private_key = key_pair.st_private_key;
83     sts_client->self_public_key = key_pair.st_public_key;
84 
85     struct random_value random_value = generate_random(CHALLENGE_BUFF_LENGTH);
86     if (memcpy_s(sts_client->my_challenge.challenge, sizeof(sts_client->my_challenge.challenge),
87         random_value.random_value, random_value.length) != EOK) {
88         return memory_copy_error(__func__, __LINE__);
89     }
90     sts_client->my_challenge.length = random_value.length;
91 
92     send_data->peer_version.first = 1;
93     send_data->peer_version.second = 0;
94     send_data->peer_version.third = 0;
95     send_data->peer_support_version.first = 1;
96     send_data->peer_support_version.second = 0;
97     send_data->peer_support_version.third = 0;
98     send_data->operation_code = sts_client->operation_code;
99     send_data->epk = sts_client->self_public_key;
100     send_data->challenge = sts_client->my_challenge;
101     send_data->package_name = sts_client->identity->package_name;
102     send_data->service_type = sts_client->identity->service_type;
103     send_data->self_auth_id = sts_client->self_id;
104     send_data->peer_user_type = ((struct hichain *)sts_client->hichain_handle)->type == HC_CENTRE ?
105         HC_USER_TYPE_CONTROLLER : HC_USER_TYPE_ACCESSORY;
106     send_data->key_length = sts_client->key_length;
107 
108     return HC_OK;
109 }
110 
parse_start_response_data(void * handle,void * data)111 static int32_t parse_start_response_data(void *handle, void *data)
112 {
113     struct sts_client *sts_client = (struct sts_client *)handle;
114     struct sts_start_response_data *receive = (struct sts_start_response_data *)data;
115 
116     sts_client->salt = receive->salt;
117     sts_client->peer_public_key = receive->epk;
118     sts_client->peer_challenge = receive->challenge;
119     sts_client->peer_auth_data = receive->auth_data;
120     sts_client->peer_user_type = receive->peer_user_type;
121 
122     struct sts_shared_secret shared_secret;
123 
124     int32_t ret = compute_sts_shared_secret(&sts_client->self_private_key,
125                                             &sts_client->peer_public_key, &shared_secret);
126     if (ret != HC_OK) {
127         LOGE("Object %u compute_shared_secret failed, error code is %d", sts_client_sn(sts_client), ret);
128         return ret;
129     }
130 
131     ret = compute_hkdf((struct var_buffer *)&shared_secret, &sts_client->salt, HICHAIN_AUTH_INFO,
132                        STS_SESSION_KEY_LENGTH, (struct var_buffer *)&sts_client->session_key);
133     (void)memset_s(&shared_secret, sizeof(struct sts_shared_secret), 0, sizeof(struct sts_shared_secret));
134     if (ret != HC_OK) {
135         LOGE("Object %u compute_hkdf failed, error code is %d", sts_client_sn(sts_client), ret);
136         return HC_STS_OBJECT_ERROR;
137     }
138 
139     return HC_OK;
140 }
141 
generate_sign_message(void * handle,struct uint8_buff * message)142 static int32_t generate_sign_message(void *handle, struct uint8_buff *message)
143 {
144     DBG_OUT("Called generate sign message");
145     check_ptr_return_val(handle, HC_INPUT_ERROR);
146     check_ptr_return_val(message, HC_INPUT_ERROR);
147     struct sts_client *sts_client = (struct sts_client *)handle;
148 
149     uint32_t len = sts_client->peer_public_key.length + sts_client->peer_id.length +
150               sts_client->self_public_key.length + sts_client->self_id.length;
151     uint8_t *info = (uint8_t *)MALLOC(len);
152     if (info == NULL) {
153         LOGE("Malloc info failed");
154         return HC_MALLOC_FAILED;
155     }
156 
157     uint32_t pos = 0;
158     (void)memcpy_s(info + pos, len - pos, sts_client->peer_public_key.stpk, sts_client->peer_public_key.length);
159     pos += sts_client->peer_public_key.length;
160     (void)memcpy_s(info + pos, len - pos, sts_client->peer_id.auth_id, sts_client->peer_id.length);
161     pos += sts_client->peer_id.length;
162     (void)memcpy_s(info + pos, len - pos, sts_client->self_public_key.stpk, sts_client->self_public_key.length);
163     pos += sts_client->self_public_key.length;
164     (void)memcpy_s(info + pos, len - pos, sts_client->self_id.auth_id, sts_client->self_id.length);
165 
166     message->val = info;
167     message->length = len;
168     message->size = len;
169     return HC_OK;
170 }
171 
verify_response_data(void * handle,const struct uint8_buff * message,struct signature * signature)172 static int32_t verify_response_data(void *handle, const struct uint8_buff *message, struct signature *signature)
173 {
174     DBG_OUT("Called verify request data");
175     check_ptr_return_val(handle, HC_INPUT_ERROR);
176     check_ptr_return_val(message, HC_INPUT_ERROR);
177     check_ptr_return_val(signature, HC_INPUT_ERROR);
178     struct sts_client *sts_client = (struct sts_client *)handle;
179 
180     struct hichain *hichain_handle = sts_client->hichain_handle;
181     enum huks_key_alias_type alias_type;
182 
183     if (hichain_handle->type == HC_CENTRE) {
184         if (sts_client->peer_user_type == HC_USER_TYPE_CONTROLLER) { /* center(as phone identity) -> phone */
185             alias_type = KEY_ALIAS_LT_KEY_PAIR;
186         } else { /* center -> accessory */
187             alias_type = KEY_ALIAS_ACCESSOR_PK;
188         }
189     } else { /* accessory -> center/phone */
190         alias_type = KEY_ALIAS_CONTROLLER_PK;
191     }
192 
193     struct service_id service_id = generate_service_id(sts_client->identity);
194     if (service_id.length == 0) {
195         LOGE("Generate service id failed");
196         return HC_GEN_SERVICE_ID_FAILED;
197     }
198     struct hc_key_alias key_alias = generate_key_alias(&service_id, &sts_client->peer_id, alias_type);
199     if (key_alias.length == 0) {
200         LOGE("Generate key alias failed");
201         return HC_GEN_ALIAS_FAILED;
202     }
203 
204     int32_t ret = verify(&key_alias, sts_client->peer_user_type, message, signature);
205     if (ret != HC_OK) {
206         LOGE("Object %u verify failed, error code is %d", sts_client_sn(sts_client), ret);
207         return HC_VERIFY_PROOF_FAILED;
208     }
209     return HC_OK;
210 }
211 
generate_sts_request_sign(void * handle,struct signature * signature)212 static int32_t generate_sts_request_sign(void *handle, struct signature *signature)
213 {
214     struct sts_client *sts_client = (struct sts_client *)handle;
215 
216     uint32_t len = sts_client->self_public_key.length + sts_client->self_id.length +
217                   sts_client->peer_public_key.length + sts_client->peer_id.length;
218     uint8_t *info = (uint8_t *)MALLOC(len);
219     if (info == NULL) {
220         LOGE("Malloc info failed");
221         return HC_MALLOC_FAILED;
222     }
223 
224     uint32_t pos = 0;
225     (void)memcpy_s(info + pos, len - pos, sts_client->self_public_key.stpk, sts_client->self_public_key.length);
226     pos += sts_client->self_public_key.length;
227     (void)memcpy_s(info + pos, len - pos, sts_client->self_id.auth_id, sts_client->self_id.length);
228     pos += sts_client->self_id.length;
229     (void)memcpy_s(info + pos, len - pos, sts_client->peer_public_key.stpk, sts_client->peer_public_key.length);
230     pos += sts_client->peer_public_key.length;
231     (void)memcpy_s(info + pos, len - pos, sts_client->peer_id.auth_id, sts_client->peer_id.length);
232 
233     struct service_id service_id = generate_service_id(sts_client->identity);
234     if (service_id.length == 0) {
235         LOGE("Generate service id failed");
236         FREE(info);
237         return HC_GEN_SERVICE_ID_FAILED;
238     }
239 #if (defined(_SUPPORT_SEC_CLONE_) || defined(_SUPPORT_SEC_CLONE_SERVER_))
240     struct hc_key_alias key_alias = generate_key_alias(&service_id, &sts_client->self_id, KEY_ALIAS_LT_KEY_PAIR);
241 #else
242     struct hc_key_alias key_alias = generate_key_alias(&service_id, &sts_client->self_id, KEY_ALIAS_ACCESSOR_PK);
243 #endif
244     if (key_alias.length == 0) {
245         LOGE("Generate key alias failed");
246         FREE(info);
247         return HC_GEN_ALIAS_FAILED;
248     }
249     struct uint8_buff sign_message = { info, len, len };
250     int32_t ret = sign(&key_alias, &sign_message, signature);
251     if (ret != HC_OK) {
252         LOGE("Object %u sign failed, error code is %d", sts_client_sn(sts_client), ret);
253         FREE(info);
254         return HC_SIGN_EXCHANGE_FAILED;
255     }
256 
257     FREE(info);
258     return ret;
259 }
260 
init_auth_data(struct uint8_buff * auth_data)261 static int32_t init_auth_data(struct uint8_buff *auth_data)
262 {
263     auth_data->size = HC_AUTH_DATA_BUFF_LEN;
264     auth_data->val = (uint8_t *)MALLOC(auth_data->size);
265     if (auth_data->val == NULL) {
266         LOGE("Malloc failed");
267         return HC_MALLOC_FAILED;
268     }
269     auth_data->length = 0;
270     (void)memset_s(auth_data->val, auth_data->size, 0, auth_data->size);
271     return HC_OK;
272 }
273 
init_signature(void * handle,struct signature * signature)274 static int32_t init_signature(void *handle, struct signature *signature)
275 {
276     struct sts_client *sts_client = (struct sts_client *)handle;
277     struct aes_aad aes_aad;
278 
279     if (memcpy_s(aes_aad.aad, sizeof(aes_aad.aad), sts_client->my_challenge.challenge,
280                  sts_client->my_challenge.length) != EOK) {
281         return memory_copy_error(__func__, __LINE__);
282     }
283 
284     aes_aad.length = sts_client->my_challenge.length;
285 
286     struct uint8_buff out_plain = { 0, 0, 0 };
287 
288     out_plain.val = (uint8_t *)MALLOC(sts_client->peer_auth_data.length);
289     if (out_plain.val == NULL) {
290         LOGE("Malloc peer_auth_data failed");
291         return HC_MALLOC_FAILED;
292     }
293     (void)memset_s(out_plain.val, sts_client->peer_auth_data.length, 0, sts_client->peer_auth_data.length);
294     out_plain.size = sts_client->peer_auth_data.length;
295 
296     struct uint8_buff auth_data = {sts_client->peer_auth_data.auth_data, sts_client->peer_auth_data.length,
297                                    sts_client->peer_auth_data.length};
298 
299     int32_t ret = aes_gcm_decrypt((struct var_buffer *)&sts_client->session_key, &auth_data, &aes_aad, &out_plain);
300     if (ret != HC_OK) {
301         FREE(out_plain.val);
302         LOGE("Object %u aes_gcm_decrypt failed, error code is %d", sts_client_sn(sts_client), ret);
303         return HC_DECRYPT_FAILED;
304     }
305 
306     if (memcpy_s(signature->signature, sizeof(signature->signature), out_plain.val, out_plain.length) != EOK) {
307         FREE(out_plain.val);
308         return memory_copy_error(__func__, __LINE__);
309     }
310     signature->length = out_plain.length;
311     FREE(out_plain.val);
312     return HC_OK;
313 }
314 
verify_data(void * handle)315 static int32_t verify_data(void *handle)
316 {
317     struct signature signature = { 0, {0} };
318     int32_t ret = init_signature(handle, &signature);
319     if (ret != HC_OK) {
320         return ret;
321     }
322 
323     struct uint8_buff message;
324     (void)memset_s(&message, sizeof(message), 0, sizeof(message));
325     ret = generate_sign_message(handle, &message);
326     if (ret != HC_OK) {
327         return ret;
328     }
329 
330     ret = verify_response_data(handle, &message, &signature);
331     FREE(message.val);
332     message.val = NULL;
333     if (ret != HC_OK) {
334         return ret;
335     }
336     return HC_OK;
337 }
338 
build_end_request_data(void * handle,void * data)339 static int32_t build_end_request_data(void *handle, void *data)
340 {
341     struct sts_client *sts_client = (struct sts_client *)handle;
342     struct sts_end_request_data *send = (struct sts_end_request_data *)data;
343 
344     int32_t ret = verify_data(handle);
345     if (ret != HC_OK) {
346         return ret;
347     }
348 
349     struct signature request_sign = { 0, {0} };
350     ret = generate_sts_request_sign(handle, &request_sign);
351     if (ret != HC_OK) {
352         return ret;
353     }
354 
355     struct uint8_buff out_auth_data;
356     ret = init_auth_data(&out_auth_data);
357     if (ret != HC_OK) {
358         return ret;
359     }
360 
361     struct aes_aad aes_aad;
362     if (memcpy_s(aes_aad.aad, sizeof(aes_aad.aad), sts_client->peer_challenge.challenge,
363         sts_client->peer_challenge.length) != EOK) {
364         FREE(out_auth_data.val);
365         return memory_copy_error(__func__, __LINE__);
366     }
367     aes_aad.length = sts_client->peer_challenge.length;
368 
369     struct uint8_buff plain = {request_sign.signature, request_sign.length, request_sign.length};
370     ret = aes_gcm_encrypt((struct var_buffer *)&sts_client->session_key, &plain, &aes_aad, &out_auth_data);
371     if (ret != HC_OK) {
372         FREE(out_auth_data.val);
373         LOGE("Object %u aes_gcm_encrypt failed, error code is %d", sts_client_sn(sts_client), ret);
374         return HC_ENCRYPT_FAILED;
375     }
376 
377     if (memcpy_s(send->auth_data.auth_data, sizeof(send->auth_data.auth_data),
378         out_auth_data.val, out_auth_data.length) != EOK) {
379         FREE(out_auth_data.val);
380         return memory_copy_error(__func__, __LINE__);
381     }
382     send->auth_data.length = out_auth_data.length;
383     FREE(out_auth_data.val);
384 
385     return HC_OK;
386 }
387 
parse_end_response_data(void * handle,void * data)388 static int32_t parse_end_response_data(void *handle, void *data)
389 {
390     struct sts_client *sts_client = (struct sts_client *)handle;
391     struct sts_end_response_data *receive = (struct sts_end_response_data *)data;
392 
393     struct uint8_buff auth_ret;
394     (void)memset_s(&auth_ret, sizeof(auth_ret), 0, sizeof(auth_ret));
395     auth_ret.val = (uint8_t *)MALLOC(receive->auth_return.length);
396     if (auth_ret.val == NULL) {
397         LOGE("Malloc auth_ret.val failed");
398         return HC_MALLOC_FAILED;
399     }
400     (void)memset_s(auth_ret.val, receive->auth_return.length, 0, receive->auth_return.length);
401     auth_ret.size = receive->auth_return.length;
402 
403     struct aes_aad aes_aad;
404     if (memcpy_s(aes_aad.aad, sizeof(aes_aad.aad), sts_client->my_challenge.challenge,
405         sts_client->my_challenge.length) != EOK) {
406         FREE(auth_ret.val);
407         return memory_copy_error(__func__, __LINE__);
408     }
409     aes_aad.length = sts_client->peer_challenge.length;
410     struct uint8_buff cipher = { receive->auth_return.auth_return,
411                                  receive->auth_return.length, receive->auth_return.length };
412 
413     int32_t ret = aes_gcm_decrypt((struct var_buffer *)&sts_client->session_key, &cipher, &aes_aad, &auth_ret);
414     if (ret != HC_OK) {
415         FREE(auth_ret.val);
416         LOGE("Object %u aes_gcm_encrypt failed, error code is %d", sts_client_sn(sts_client), ret);
417         return false;
418     }
419     FREE(auth_ret.val);
420     generate_output_key(sts_client);
421 
422     return ret;
423 }
424 
send_sts_start_request(struct sts_client * sts_client,struct message * send)425 int32_t send_sts_start_request(struct sts_client *sts_client, struct message *send)
426 {
427     check_ptr_return_val(sts_client, HC_INPUT_ERROR);
428     check_ptr_return_val(send, HC_INPUT_ERROR);
429     struct sts_start_request_data *send_data =
430         (struct sts_start_request_data *)MALLOC(sizeof(struct sts_start_request_data));
431     if (send_data == NULL) {
432         LOGE("Malloc struct STS_START_REQUEST_DATA failed");
433         return HC_MALLOC_FAILED;
434     }
435     (void)memset_s(send_data, sizeof(*send_data), 0, sizeof(*send_data));
436 
437     int32_t ret = send_start_request(sts_client, send_data);
438     if (ret != HC_OK) {
439         LOGE("Called send_start_request failed, error code is %d", ret);
440         FREE(send_data);
441         send->msg_code = INFORM_MESSAGE;
442     } else {
443         DBG_OUT("Called send_start_request success");
444         send->msg_code = AUTH_START_REQUEST;
445         send->payload = send_data;
446     }
447 
448     return ret;
449 }
450 
send_sts_end_request(struct sts_client * sts_client,const struct message * receive,struct message * send)451 int32_t send_sts_end_request(struct sts_client *sts_client, const struct message *receive, struct message *send)
452 {
453     DBG_OUT("Receive data send_sts_start_response");
454     check_ptr_return_val(sts_client, HC_INPUT_ERROR);
455     check_ptr_return_val(receive, HC_INPUT_ERROR);
456     check_ptr_return_val(send, HC_INPUT_ERROR);
457     struct sts_start_response_data *receive_data = (struct sts_start_response_data *)receive->payload;
458 
459     struct sts_end_request_data *send_data =
460         (struct sts_end_request_data *)MALLOC(sizeof(struct sts_end_request_data));
461     if (send_data == NULL) {
462         LOGE("Malloc struct STS_END_REQUEST_DATA failed");
463         return HC_MALLOC_FAILED;
464     }
465     (void)memset_s(send_data, sizeof(*send_data), 0, sizeof(*send_data));
466 
467     int32_t ret = send_end_request(sts_client, receive_data, send_data);
468     if (ret != HC_OK) {
469         LOGE("Called send_end_request failed, error code is %d", ret);
470         FREE(send_data);
471         send->msg_code = INFORM_MESSAGE;
472     } else {
473         DBG_OUT("Called send_end_request success");
474         send->msg_code = AUTH_ACK_REQUEST;
475         send->payload = send_data;
476     }
477 
478     return ret;
479 }
480 
receive_sts_end_response(struct sts_client * sts_client,struct message * receive)481 int32_t receive_sts_end_response(struct sts_client *sts_client, struct message *receive)
482 {
483     DBG_OUT("Receive sts end response data");
484     check_ptr_return_val(sts_client, HC_INPUT_ERROR);
485     check_ptr_return_val(receive, HC_INPUT_ERROR);
486     struct sts_end_response_data *receive_data = (struct sts_end_response_data *)receive->payload;
487     int32_t ret = receive_end_response(sts_client, receive_data);
488     if (ret != HC_OK) {
489         LOGE("Called receive_end_response failed, error code is %d", ret);
490         receive->msg_code = INFORM_MESSAGE;
491     } else {
492         DBG_OUT("Called receive_end_response success");
493         receive->msg_code = AUTH_ACK_RESPONSE;
494         receive->payload = receive_data;
495     }
496 
497     return ret;
498 }
499 
generate_output_key(struct sts_client * sts_client)500 static void generate_output_key(struct sts_client *sts_client)
501 {
502     DBG_OUT("Start sts_client generate output key");
503     int32_t ret = compute_hkdf((struct var_buffer *)&sts_client->session_key,
504                                &sts_client->salt,
505                                HICHAIN_RETURN_KEY,
506                                sts_client->key_length,
507                                (struct var_buffer *)&sts_client->service_key);
508     if (ret != HC_OK) {
509         LOGE("Object %u generate output key failed, error code is %d", sts_client_sn(sts_client), ret);
510         return;
511     } else {
512         DBG_OUT("Sts client generate output key success");
513         return;
514     }
515 }
516 
517 #else /* _CUT_XXX_ */
/*
 * STS client support is compiled out (_CUT_STS_ or _CUT_STS_CLIENT_):
 * no object can be built, so this always returns NULL.
 */
struct sts_client *build_sts_client(const hc_handle hichain_handle, uint32_t key_length,
    const struct hc_auth_id *client, const struct hc_auth_id *server)
{
    /* Parameters are intentionally unused in this configuration. */
    (void)server;
    (void)client;
    (void)key_length;
    (void)hichain_handle;
    return NULL;
}
527 
/*
 * STS client support is compiled out: there is never an object to release,
 * so this is a no-op.
 */
void destroy_sts_client(struct sts_client *handle)
{
    (void)handle; /* unused in this configuration */
}
533 
534 #endif /* _CUT_XXX_ */
535 
536