1 /*
2  * Copyright (c) 2020-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "cipher.h"
17 
18 #include <stdbool.h>
19 #include <stdint.h>
20 #include <stdlib.h>
21 #include <string.h>
22 
23 #include "base64.h"
24 #include "cipher_log.h"
25 #include "ctr_drbg.h"
26 #include "entropy.h"
27 #include "md.h"
28 #include "pk.h"
29 #include "rsa.h"
30 #include "securec.h"
31 
/*
 * Per-block overhead of RSAES-OAEP with SHA-256 (2 * hLen + 2 = 2 * 32 + 2 bytes):
 * one RSA block can carry at most mbedtls_rsa_get_len() - RSA_KEY_BYTE plaintext bytes.
 */
#define RSA_KEY_BYTE   66
/* Used to size base64 output: every 3 input bytes become 4 output characters. */
#define NUM_FOUR       4
#define NUM_THREE      3
/* Mode selectors from the Mbed TLS 2.x API; not referenced in this file. */
#define MBEDTLS_RSA_PUBLIC	0 /**< Request public key operation. */
#define MBEDTLS_RSA_PRIVATE	1 /**< Request private key operation. */
37 
RsaMallocPrivateKey(const unsigned char * key,size_t * keyLen)38 static char *RsaMallocPrivateKey(const unsigned char *key, size_t *keyLen)
39 {
40     int32_t ret;
41     const char start[] = "-----BEGIN RSA PRIVATE KEY-----\n";
42     const char end[] = "\n-----END RSA PRIVATE KEY-----\n";
43     size_t startLen = strlen(start);
44     size_t endLen = strlen(end);
45     size_t keyFinalLen = *keyLen + startLen + endLen + 1;
46 
47     char *privateKey = malloc(keyFinalLen);
48     if (privateKey == NULL) {
49         return NULL;
50     }
51 
52     (void)memset_s(privateKey, keyFinalLen, 0, keyFinalLen);
53     ret = memcpy_s(privateKey, keyFinalLen, start, startLen);
54     if (ret != EOK) {
55         CIPHER_LOG_E("memcpy failed.");
56         free(privateKey);
57         return NULL;
58     }
59 
60     ret = memcpy_s(privateKey + startLen, keyFinalLen - startLen, key, *keyLen);
61     if (ret != EOK) {
62         CIPHER_LOG_E("memcpy failed.");
63         free(privateKey);
64         return NULL;
65     }
66 
67     ret = memcpy_s(privateKey + startLen + *keyLen, keyFinalLen - startLen - *keyLen, end, endLen);
68     if (ret != EOK) {
69         CIPHER_LOG_E("memcpy failed.");
70         (void)memset_s(privateKey, keyFinalLen, 0, keyFinalLen);
71         free(privateKey);
72         return NULL;
73     }
74 
75     *keyLen = keyFinalLen;
76     return privateKey;
77 }
78 
RsaMallocPublicKey(const unsigned char * key,size_t * keyLen)79 static char *RsaMallocPublicKey(const unsigned char *key, size_t *keyLen)
80 {
81     int32_t ret;
82     const char start[] = "-----BEGIN PUBLIC KEY-----\n";
83     const char end[] = "\n-----END PUBLIC KEY-----\n";
84     size_t startLen = strlen(start);
85     size_t endLen = strlen(end);
86     size_t keyFinalLen = *keyLen + startLen + endLen + 1;
87 
88     char *pubKey = malloc(keyFinalLen);
89     if (pubKey == NULL) {
90         return NULL;
91     }
92 
93     (void)memset_s(pubKey, keyFinalLen, 0, keyFinalLen);
94     ret = memcpy_s(pubKey, keyFinalLen, start, startLen);
95     if (ret != EOK) {
96         CIPHER_LOG_E("memcpy failed.");
97         free(pubKey);
98         return NULL;
99     }
100 
101     ret = memcpy_s(pubKey + startLen, keyFinalLen - startLen, key, *keyLen);
102     if (ret != EOK) {
103         CIPHER_LOG_E("memcpy failed.");
104         free(pubKey);
105         return NULL;
106     }
107 
108     ret = memcpy_s(pubKey + startLen + *keyLen, keyFinalLen - startLen - *keyLen, end, endLen);
109     if (ret != EOK) {
110         CIPHER_LOG_E("memcpy failed.");
111         (void)memset_s(pubKey, keyFinalLen, 0, keyFinalLen);
112         free(pubKey);
113         return NULL;
114     }
115 
116     *keyLen = keyFinalLen;
117     return pubKey;
118 }
119 
/*
 * Initialise a CTR-DRBG context and an entropy context, then seed the DRBG
 * from mbedtls_entropy_func.
 *
 * Both contexts are initialised even when seeding fails, so callers must
 * release them (RsaDeinit() or the mbedtls *_free functions) in every case.
 * Returns ERROR_SUCCESS when seeding succeeded, ERROR_CODE_GENERAL otherwise.
 * A seeding failure is logged rather than ignored: drawing random bytes from
 * an unseeded DRBG would weaken OAEP padding and RSA blinding.
 */
static int32_t RsaInit(mbedtls_ctr_drbg_context *ctrDrbg, mbedtls_entropy_context *entropy)
{
    mbedtls_ctr_drbg_init(ctrDrbg);
    mbedtls_entropy_init(entropy);
    int ret = mbedtls_ctr_drbg_seed(ctrDrbg, mbedtls_entropy_func, entropy, NULL, 0);
    if (ret != 0) {
        CIPHER_LOG_E("ctr drbg seed failed, ret:%d.", ret);
        return ERROR_CODE_GENERAL;
    }
    return ERROR_SUCCESS;
}
126 
/*
 * Parse a raw RSA private key body into pk and configure it for
 * RSAES-OAEP with SHA-256.
 *
 * pk:     a context already set up with mbedtls_pk_init(); the caller frees it
 *         with mbedtls_pk_free() whatever the result.
 * key:    key body without PEM armor (wrapped by RsaMallocPrivateKey()).
 * keyLen: length of key in bytes.
 * Returns ERROR_SUCCESS, or ERROR_CODE_GENERAL if the key cannot be parsed,
 * is not RSA, fails the consistency check, or OAEP/SHA-256 cannot be selected.
 * The temporary PEM copy is wiped before being freed.
 */
static int32_t RsaLoadPrivateKey(mbedtls_pk_context *pk, const unsigned char *key, size_t keyLen)
{
    size_t finalKeyLen = keyLen;
    char *finalKey = RsaMallocPrivateKey(key, &finalKeyLen);
    if (finalKey == NULL) {
        CIPHER_LOG_E("malloc private key error, final Key Length:%zu.", finalKeyLen);
        return ERROR_CODE_GENERAL;
    }

    mbedtls_ctr_drbg_context ctrDrbg;
    mbedtls_entropy_context entropy;
    RsaInit(&ctrDrbg, &entropy);

    int32_t result = ERROR_CODE_GENERAL;
    do {
        int32_t ret = mbedtls_pk_parse_key(pk, (const unsigned char *)finalKey, finalKeyLen, NULL, 0,
            mbedtls_ctr_drbg_random, &ctrDrbg);
        if (ret != 0) {
            CIPHER_LOG_E("parse private key error, ret:%d.", ret);
            break;
        }

        /* NULL when the parsed key is not an RSA key. */
        mbedtls_rsa_context *rsa = mbedtls_pk_rsa(*pk);
        if (rsa == NULL) {
            CIPHER_LOG_E("rsa error");
            break;
        }

        if (mbedtls_rsa_check_privkey(rsa) != 0) {
            CIPHER_LOG_E("check private key failed.");
            break;
        }

        /*
         * Set padding as OAEPWITHSHA256. This can fail (e.g. SHA-256 unavailable);
         * ignoring that would silently leave the default PKCS#1 v1.5 padding active.
         */
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256);
        if (ret != 0) {
            CIPHER_LOG_E("set padding error, ret:%d.", ret);
            break;
        }
        result = ERROR_SUCCESS;
    } while (0);

    /* finalKey holds private key material: wipe it before returning it to the heap. */
    (void)memset_s(finalKey, finalKeyLen, 0, finalKeyLen);
    free(finalKey);
    mbedtls_ctr_drbg_free(&ctrDrbg);
    mbedtls_entropy_free(&entropy);
    return result;
}
176 
/*
 * Parse a raw RSA public key body into pk and configure it for
 * RSAES-OAEP with SHA-256.
 *
 * pk:     a context already set up with mbedtls_pk_init(); the caller frees it
 *         with mbedtls_pk_free() whatever the result.
 * key:    key body without PEM armor (wrapped by RsaMallocPublicKey()).
 * keyLen: length of key in bytes.
 * Returns ERROR_SUCCESS, or ERROR_CODE_GENERAL if the key cannot be parsed,
 * is not RSA, fails the public-key check, or OAEP/SHA-256 cannot be selected.
 */
static int32_t RsaLoadPublicKey(mbedtls_pk_context *pk, const unsigned char *key, size_t keyLen)
{
    size_t finalKeyLen = keyLen;
    char *finalKey = RsaMallocPublicKey(key, &finalKeyLen);
    if (finalKey == NULL) {
        CIPHER_LOG_E("malloc public key error, final Key Length:%zu.", finalKeyLen);
        return ERROR_CODE_GENERAL;
    }

    int32_t result = ERROR_CODE_GENERAL;
    do {
        int32_t ret = mbedtls_pk_parse_public_key(pk, (const unsigned char *)finalKey, finalKeyLen);
        if (ret != 0) {
            CIPHER_LOG_E("parse public key error, ret:%d.", ret);
            break;
        }

        /* NULL when the parsed key is not an RSA key. */
        mbedtls_rsa_context *rsa = mbedtls_pk_rsa(*pk);
        if (rsa == NULL) {
            CIPHER_LOG_E("pk rsa error");
            break;
        }

        if (mbedtls_rsa_check_pubkey(rsa) != 0) {
            CIPHER_LOG_E("check public key failed.");
            break;
        }

        /*
         * Set padding as OAEPWITHSHA256. This can fail (e.g. SHA-256 unavailable);
         * ignoring that would silently encrypt with the default PKCS#1 v1.5 padding.
         */
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256);
        if (ret != 0) {
            CIPHER_LOG_E("set padding error, ret:%d.", ret);
            break;
        }
        result = ERROR_SUCCESS;
    } while (0);

    (void)memset_s(finalKey, finalKeyLen, 0, finalKeyLen);
    free(finalKey);
    return result;
}
216 
/*
 * Counterpart of RsaInit(): release the DRBG context, then the entropy
 * context it was seeded from.
 */
static void RsaDeinit(mbedtls_ctr_drbg_context *ctrDrbg, mbedtls_entropy_context *entropy)
{
    /* Generator first ... */
    mbedtls_ctr_drbg_free(ctrDrbg);
    /* ... then its entropy source. */
    mbedtls_entropy_free(entropy);
}
222 
RsaEncryptBase64Encode(int32_t cipherTotalLen,char * cipherText,int32_t cipherTextLen)223 static int32_t RsaEncryptBase64Encode(int32_t cipherTotalLen, char *cipherText, int32_t cipherTextLen)
224 {
225     if (cipherTotalLen <= 0) {
226         return ERROR_CODE_GENERAL;
227     }
228 
229     char *tempBuf = malloc(cipherTotalLen);
230     if (tempBuf == NULL) {
231         CIPHER_LOG_E("RsaEncrypt Base64Encode malloc fail.");
232         return ERROR_CODE_GENERAL;
233     }
234 
235     int32_t ret = memcpy_s(tempBuf, cipherTotalLen, cipherText, cipherTotalLen);
236     if (ret != EOK) {
237         CIPHER_LOG_E("memcpy fail.");
238         free(tempBuf);
239         return ERROR_CODE_GENERAL;
240     }
241 
242     (void)memset_s(cipherText, cipherTextLen, 0, cipherTextLen);
243     size_t dataLen = 0;
244     ret = mbedtls_base64_encode(NULL, 0, &dataLen, (const unsigned char *)tempBuf, cipherTotalLen);
245     if (ret != MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL) {
246         CIPHER_LOG_E("base64_encode fail.");
247         free(tempBuf);
248         return ERROR_CODE_GENERAL;
249     }
250 
251     ret = mbedtls_base64_encode((unsigned char *)cipherText, cipherTextLen, &dataLen,
252         (const unsigned char *)tempBuf, cipherTotalLen);
253     if (ret != 0) {
254         CIPHER_LOG_E("base64_encode fail.");
255         free(tempBuf);
256         return ERROR_CODE_GENERAL;
257     }
258     free(tempBuf);
259     return ERROR_SUCCESS;
260 }
261 
RsaEncryptMultipleBlock(mbedtls_rsa_context * rsa,const char * plainText,char * cipherText,int32_t cipherTextLen)262 static int32_t RsaEncryptMultipleBlock(mbedtls_rsa_context *rsa, const char *plainText,
263     char *cipherText, int32_t cipherTextLen)
264 {
265     mbedtls_ctr_drbg_context ctrDrbg;
266     mbedtls_entropy_context entropy;
267     int32_t rsaLen = mbedtls_rsa_get_len(rsa);
268     int32_t rsaContentLen = rsaLen - RSA_KEY_BYTE;
269     if ((rsaContentLen <= 0) || (rsaLen <= 0)) {
270         CIPHER_LOG_E("rsa content len:%d, rsaLen:%d.", rsaContentLen, rsaLen);
271         return ERROR_CODE_GENERAL;
272     }
273     int32_t count = strlen((const char *)(uintptr_t)plainText) / rsaContentLen;
274     int32_t remain = strlen((const char *)(uintptr_t)plainText) % rsaContentLen;
275     int32_t cipherTotalLen = 0;
276     unsigned char *buf = (unsigned char *)malloc(rsaLen);
277     if (buf == NULL) {
278         return ERROR_CODE_GENERAL;
279     }
280     int32_t ret = ERROR_CODE_GENERAL;
281     do {
282         RsaInit(&ctrDrbg, &entropy);
283         bool isBreak = false;
284         for (int32_t i = 0; i < count; i++) {
285             (void)memset_s(buf, rsaLen, 0, rsaLen);
286             if (mbedtls_rsa_pkcs1_encrypt(rsa, mbedtls_ctr_drbg_random, &ctrDrbg,
287                 rsaContentLen, (const unsigned char *)(plainText + i * rsaContentLen), buf)) {
288                 isBreak = true;
289                 break;
290             }
291             if (memcpy_s(cipherText + i * rsaLen, cipherTextLen - i * rsaLen, buf, rsaLen)) {
292                 isBreak = true;
293                 break;
294             }
295             cipherTotalLen += rsaLen;
296         }
297         if (isBreak) {
298             break;
299         }
300         if (remain > 0) {
301             (void)memset_s(buf, rsaLen, 0, rsaLen);
302             if (mbedtls_rsa_pkcs1_encrypt(rsa, mbedtls_ctr_drbg_random, &ctrDrbg,
303                 remain, (const unsigned char *)(plainText + count * rsaContentLen), buf)) {
304                 break;
305             }
306             if (memcpy_s(cipherText + count * rsaLen, cipherTextLen - count * rsaLen, buf, rsaLen)) {
307                 break;
308             }
309             cipherTotalLen += rsaLen;
310         }
311         if (RsaEncryptBase64Encode(cipherTotalLen, cipherText, cipherTextLen)) {
312             break;
313         }
314         ret = ERROR_SUCCESS;
315     } while (0);
316 
317     free(buf);
318     RsaDeinit(&ctrDrbg, &entropy);
319     return ret;
320 }
321 
RsaEncrypt(RsaKeyData * key,const RsaData * plain,RsaData * cipher)322 static int32_t RsaEncrypt(RsaKeyData *key, const RsaData *plain, RsaData *cipher)
323 {
324     if ((key->trans != NULL) && (strcmp(key->trans, "RSA/None/OAEPWithSHA256AndMGF1Padding"))) {
325         return ERROR_CODE_GENERAL;
326     }
327 
328     mbedtls_pk_context pk;
329     mbedtls_pk_init(&pk);
330     if (RsaLoadPublicKey(&pk, (const unsigned char *)key->key, key->keyLen) != 0) {
331         mbedtls_pk_free(&pk);
332         return ERROR_CODE_GENERAL;
333     }
334 
335     mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk);
336     if (rsa == NULL) {
337         mbedtls_pk_free(&pk);
338         return ERROR_CODE_GENERAL;
339     }
340 
341     size_t rsaLen = mbedtls_rsa_get_len(rsa);
342     size_t rsaContentLen = rsaLen - RSA_KEY_BYTE;
343     if (rsaContentLen <= 0) {
344         mbedtls_pk_free(&pk);
345         return ERROR_CODE_GENERAL;
346     }
347 
348     size_t count = plain->length / rsaContentLen;
349     size_t remain = plain->length % rsaContentLen;
350     if (cipher->data == NULL) {
351         cipher->length = rsaLen * count + (remain ? rsaLen : 0);
352         cipher->length = (cipher->length / NUM_THREE + 1) * NUM_FOUR + 1;
353         mbedtls_pk_free(&pk);
354         return ERROR_SUCCESS;
355     }
356 
357     if (RsaEncryptMultipleBlock(rsa, plain->data, cipher->data, cipher->length) != 0) {
358         CIPHER_LOG_E("Rsa encrypt block error.");
359         mbedtls_pk_free(&pk);
360         return ERROR_CODE_GENERAL;
361     }
362 
363     mbedtls_pk_free(&pk);
364     return ERROR_SUCCESS;
365 }
366 
CheckParamAndMallocBuf(size_t rsaLen,const RsaData * cipher,unsigned char ** buf,unsigned char ** tembuf)367 static int32_t CheckParamAndMallocBuf(size_t rsaLen, const RsaData *cipher, unsigned char **buf, unsigned char **tembuf)
368 {
369     if ((rsaLen == 0) || (cipher->length == 0)) {
370         return ERROR_CODE_GENERAL;
371     }
372     *buf = (unsigned char*)malloc(rsaLen);
373     if (*buf == NULL) {
374         return ERROR_CODE_GENERAL;
375     }
376     *tembuf = (unsigned char*)malloc(cipher->length);
377     if (*tembuf == NULL) {
378         free(*buf);
379         *buf = NULL;
380         return ERROR_CODE_GENERAL;
381     }
382     return ERROR_SUCCESS;
383 }
384 
/*
 * Base64-decode cipher->data (cipher->length bytes) and RSA-decrypt it block
 * by block (rsaLen bytes per block, using the padding configured on rsa).
 * The recovered plaintext is appended to plain->data, whose capacity is
 * plain->length.
 *
 * On success plain->length is set to the total plaintext size. The output is
 * not NUL-terminated.
 * Returns ERROR_SUCCESS, or ERROR_CODE_GENERAL on bad base64, a decoded length
 * that is not a whole number of blocks, a decryption failure, or insufficient
 * output capacity. On failure any plaintext already written to plain->data is
 * wiped.
 */
static int32_t RsaPkcs1Decrypt(mbedtls_rsa_context *rsa, size_t rsaLen, RsaData *cipher, RsaData *plain)
{
    unsigned char *buf = NULL;
    unsigned char *tembuf = NULL;

    int32_t ret = CheckParamAndMallocBuf(rsaLen, cipher, &buf, &tembuf);
    if (ret != ERROR_SUCCESS) {
        return ret;
    }

    (void)memset_s(tembuf, cipher->length, 0, cipher->length);
    mbedtls_ctr_drbg_context ctrDrbg;
    mbedtls_entropy_context entropy;
    RsaInit(&ctrDrbg, &entropy);

    size_t totalPlainLen = 0;
    ret = ERROR_CODE_GENERAL;
    do {
        /* Decoded data is never longer than its base64 text, so cipher->length bytes suffice. */
        size_t dataLen = 0;
        if (mbedtls_base64_decode(tembuf, cipher->length, &dataLen, (const unsigned char *)cipher->data,
            cipher->length) != 0) {
            CIPHER_LOG_E("base64 decode failed.");
            break;
        }

        /* A trailing partial block means malformed input; do not silently drop it. */
        if ((dataLen % rsaLen) != 0) {
            CIPHER_LOG_E("cipher length:%zu is not a multiple of rsa len:%zu.", dataLen, rsaLen);
            break;
        }

        bool decrypted = true;
        for (size_t offset = 0; offset < dataLen; offset += rsaLen) {
            size_t plainLen = 0;
            (void)memset_s(buf, rsaLen, 0, rsaLen);
            if (mbedtls_rsa_pkcs1_decrypt(rsa, mbedtls_ctr_drbg_random, &ctrDrbg,
                &plainLen, tembuf + offset, buf, rsaLen) != 0) {
                decrypted = false;
                break;
            }
            if (memcpy_s(plain->data + totalPlainLen, plain->length - totalPlainLen, buf, plainLen) != EOK) {
                decrypted = false;
                break;
            }
            totalPlainLen += plainLen;
        }
        if (!decrypted) {
            break;
        }

        plain->length = totalPlainLen;
        ret = ERROR_SUCCESS;
    } while (0);

    if ((ret != ERROR_SUCCESS) && (totalPlainLen > 0)) {
        /* Do not leave a partially decrypted message in the caller's buffer. */
        (void)memset_s(plain->data, plain->length, 0, totalPlainLen);
    }
    /* buf held decrypted plaintext: wipe it before returning it to the heap. */
    (void)memset_s(buf, rsaLen, 0, rsaLen);
    RsaDeinit(&ctrDrbg, &entropy);
    free(tembuf);
    free(buf);
    return ret;
}
439 
/*
 * RSA-decrypt the base64 text in cipher->data with the private key in key.
 *
 * key->trans, when set, must be "RSA/None/OAEPWithSHA256AndMGF1Padding".
 * If plain->data is NULL, this is a size query: plain->length is set to
 * cipher->length. That is always enough, because decoded base64 is shorter
 * than its text and each plaintext chunk is shorter than its cipher block.
 * Otherwise the plaintext is written to plain->data and plain->length is
 * updated to its size.
 * Returns ERROR_SUCCESS or ERROR_CODE_GENERAL.
 */
static int32_t RsaDecrypt(RsaKeyData *key, RsaData *cipher, RsaData *plain)
{
    if ((key->trans != NULL) && (strcmp(key->trans, "RSA/None/OAEPWithSHA256AndMGF1Padding") != 0)) {
        return ERROR_CODE_GENERAL;
    }

    if (plain->data == NULL) {
        plain->length = cipher->length;
        return ERROR_SUCCESS;
    }

    mbedtls_pk_context pk;
    mbedtls_pk_init(&pk);

    int32_t ret = ERROR_CODE_GENERAL;
    do {
        if (RsaLoadPrivateKey(&pk, (const unsigned char *)key->key, key->keyLen) != ERROR_SUCCESS) {
            break;
        }

        /* Same guard as RsaEncrypt(): never hand a NULL context to the RSA primitives. */
        mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk);
        if (rsa == NULL) {
            break;
        }

        if (RsaPkcs1Decrypt(rsa, mbedtls_rsa_get_len(rsa), cipher, plain) != ERROR_SUCCESS) {
            CIPHER_LOG_E("Rsa pkcs1 decrypt failed.");
            break;
        }
        ret = ERROR_SUCCESS;
    } while (0);

    mbedtls_pk_free(&pk);
    return ret;
}
470 
/*
 * Public entry point: run the RSA operation named by key->action
 * ("encrypt" or "decrypt") on inData, writing the result to outData.
 *
 * key, inData and outData must be non-NULL, as must key->action, key->key
 * and inData->data. outData->data may be NULL to request the required output
 * size. Returns ERROR_CODE_GENERAL for invalid arguments or an unknown
 * action; otherwise returns the result of the selected operation.
 */
int32_t RsaCrypt(RsaKeyData *key, RsaData *inData, RsaData *outData)
{
    bool argsMissing = (key == NULL) || (inData == NULL) || (outData == NULL);
    /* Short-circuit keeps the member reads below from touching a NULL pointer. */
    if (argsMissing || (key->action == NULL) || (key->key == NULL) || (inData->data == NULL)) {
        return ERROR_CODE_GENERAL;
    }

    if (strcmp(key->action, "encrypt") == 0) {
        return RsaEncrypt(key, inData, outData);
    }
    if (strcmp(key->action, "decrypt") == 0) {
        return RsaDecrypt(key, inData, outData);
    }
    return ERROR_CODE_GENERAL;
}
489