1 /*
2  * Copyright (C) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "crypto_hash_to_point.h"
17 #include <openssl/bn.h>
18 #include <openssl/evp.h>
19 #include <openssl/rand.h>
20 #include "hal_error.h"
21 #include "hc_log.h"
22 #include "hc_types.h"
23 #include "hks_type.h"
24 
/* Byte length of every g_curveParam* table; also passed as the input length to BN_bin2bn(). */
#define KEY_BYTES_CURVE25519                 32
26 
/*
 * BIGNUM holders for the curve constants used by the hash-to-point computation.
 * The fields are allocated by CurveInitConstPara, filled from the g_curveParam*
 * tables by CurveSetConstPara and released by CurveFreeConstPara.
 */
struct CurveConstPara {
    BIGNUM *p;          /* loaded from g_curveParamP; modulus of every reduction */
    BIGNUM *one;        /* loaded from g_curveParamOne */
    BIGNUM *d;          /* loaded from g_curveParamD; not read by the computations in this file */
    BIGNUM *k;          /* loaded from g_curveParamK; not read by the computations in this file */
    BIGNUM *capitalA;   /* loaded from g_curveParamCapitalA */
    BIGNUM *minusA;     /* loaded from g_curveParamMinusA */
    BIGNUM *u;          /* loaded from g_curveParamU */
    BIGNUM *q;          /* loaded from g_curveParamQ; exponent used in CurveHashToPointCalcC */
};
37 
/*
 * Curve constant tables. CurveSetConstPara loads each one with BN_bin2bn(),
 * which interprets the bytes as a big-endian integer.
 */

/* RFC 8032, the prime of Curve25519, p = 2^255 - 19 (big-endian bytes) */
static const uint8_t g_curveParamP[KEY_BYTES_CURVE25519] = {
    0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
};

/* one = 1 (big-endian bytes) */
static const uint8_t g_curveParamOne[KEY_BYTES_CURVE25519] = {
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
};

/*
 * RFC 8032, a non-zero element in the finite field GF(p), not equal to 1.
 * NOTE(review): the bytes appear to be the edwards25519 constant d written
 * least-significant byte first, while BN_bin2bn() reads big-endian. The loaded
 * value (para->d) is not read by any computation in this file - confirm the
 * intended byte order before using it.
 */
static const uint8_t g_curveParamD[KEY_BYTES_CURVE25519] = {
    0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75, 0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
    0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c, 0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52
};

/*
 * k = (p - 1) / 2, stored least-significant byte first (the byte-reversed copy
 * of g_curveParamQ). Because BN_bin2bn() reads big-endian, the BIGNUM loaded
 * from this table is not (p - 1) / 2; para->k is not read by any computation
 * in this file.
 */
static const uint8_t g_curveParamK[KEY_BYTES_CURVE25519] = {
    0xf6, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x3f
};

/* A = 486662 = 0x076d06, the Montgomery coefficient of Curve25519 (see RFC 7748); big-endian bytes */
static const uint8_t g_curveParamCapitalA[KEY_BYTES_CURVE25519] = {
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
};

/* -A mod p = p - 486662 (big-endian bytes) */
static const uint8_t g_curveParamMinusA[KEY_BYTES_CURVE25519] = {
    0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xf8, 0x92, 0xe7
};

/* u = 2, the multiplier in b := -A / (1 + u * a ^ 2); big-endian bytes */
static const uint8_t g_curveParamU[KEY_BYTES_CURVE25519] = {
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02
};

/*
 * q = endian_swap(k) = (p - 1) / 2 = 2^254 - 10 in big-endian bytes.
 * Used as the exponent in CurveHashToPointCalcC.
 */
static const uint8_t g_curveParamQ[KEY_BYTES_CURVE25519] = {
    0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xf6
};
85 
/* Release a BIGNUM with BN_free(); a NULL pointer is accepted and ignored. */
static void HcBnFree(BIGNUM *p)
{
    if (p != NULL) {
        BN_free(p);
    }
}
93 
/* Release a BN_CTX with BN_CTX_free(); a NULL pointer is accepted and ignored. */
static void HcBnCTXFree(BN_CTX *ctx)
{
    if (ctx != NULL) {
        BN_CTX_free(ctx);
    }
}
101 
CurveFreeConstPara(struct CurveConstPara * para)102 static void CurveFreeConstPara(struct CurveConstPara *para)
103 {
104     HcBnFree(para->p);
105     HcBnFree(para->one);
106     HcBnFree(para->d);
107     HcBnFree(para->k);
108     HcBnFree(para->capitalA);
109     HcBnFree(para->minusA);
110     HcBnFree(para->u);
111     HcBnFree(para->q);
112 }
113 
CurveInitConstPara(struct CurveConstPara * para)114 static int32_t CurveInitConstPara(struct CurveConstPara *para)
115 {
116     do {
117         para->p = BN_new();
118         if (para->p == NULL) {
119             break;
120         }
121         para->one = BN_new();
122         if (para->one == NULL) {
123             break;
124         }
125         para->d = BN_new();
126         if (para->d == NULL) {
127             break;
128         }
129         para->k = BN_new();
130         if (para->k == NULL) {
131             break;
132         }
133         para->capitalA = BN_new();
134         if (para->capitalA == NULL) {
135             break;
136         }
137         para->minusA = BN_new();
138         if (para->minusA == NULL) {
139             break;
140         }
141         para->u = BN_new();
142         if (para->u == NULL) {
143             break;
144         }
145         para->q = BN_new();
146         if (para->q == NULL) {
147             break;
148         }
149         return HAL_SUCCESS;
150     } while (0);
151 
152     CurveFreeConstPara(para);
153 
154     return HAL_FAILED;
155 }
156 
157 /* b := -A / (1 + u * a ^ 2) */
CurveHashToPointCalcB(const struct HksBlob * hash,const struct CurveConstPara * curvePara,BIGNUM * b,BN_CTX * ctx)158 static int32_t CurveHashToPointCalcB(const struct HksBlob *hash,
159     const struct CurveConstPara *curvePara, BIGNUM *b, BN_CTX *ctx)
160 {
161     BIGNUM *swap = BN_new();
162     if (swap == NULL) {
163         return HAL_FAILED;
164     }
165 
166     int32_t ret = HAL_FAILED;
167     do {
168         if (BN_bin2bn(hash->data, hash->size, swap) == NULL) {
169             break;
170         }
171         if (BN_mul(b, swap, swap, ctx) <= 0) {
172             break;
173         }
174         if (BN_mod(b, b, curvePara->p, ctx) <= 0) {
175             break;
176         }
177         if (BN_mul(swap, b, curvePara->u, ctx) <= 0) {
178             break;
179         }
180         if (BN_mod(swap, swap, curvePara->p, ctx) <= 0) {
181             break;
182         }
183         if (BN_add(b, swap, curvePara->one) <= 0) {
184             break;
185         }
186         if (BN_mod(b, b, curvePara->p, ctx) <= 0) {
187             break;
188         }
189         if (BN_mod_inverse(swap, b, curvePara->p, ctx) <= 0) {
190             break;
191         }
192         if (BN_mul(b, swap, curvePara->minusA, ctx) <= 0) {
193             break;
194         }
195         if (BN_mod(b, b, curvePara->p, ctx) <= 0) {
196             break;
197         }
198         ret = HAL_SUCCESS;
199     } while (0);
200     HcBnFree(swap);
201     return ret;
202 }
203 
/*
 * a := (b ^ 3 + A * b ^ 2 + b) mod p.
 *
 * b         - input value
 * curvePara - supplies p and capitalA
 * a         - receives the result
 * ctx       - scratch context for the BIGNUM arithmetic
 *
 * Returns HAL_SUCCESS on success, HAL_FAILED if a temporary cannot be
 * allocated or any BIGNUM operation fails.
 */
static int32_t CurveHashToPointCalcA(const BIGNUM *b,
    const struct CurveConstPara *curvePara, BIGNUM *a, BN_CTX *ctx)
{
    BIGNUM *cube = BN_new();
    if (cube == NULL) {
        return HAL_FAILED;
    }

    BIGNUM *acc = BN_new();
    if (acc == NULL) {
        HcBnFree(cube);
        return HAL_FAILED;
    }

    /* The chain short-circuits, so evaluation stops at the first failing step. */
    int done =
        /* acc = b ^ 2 mod p */
        (BN_mul(acc, b, b, ctx) > 0) &&
        (BN_mod(acc, acc, curvePara->p, ctx) > 0) &&
        /* cube = b ^ 3 mod p */
        (BN_mul(cube, acc, b, ctx) > 0) &&
        (BN_mod(cube, cube, curvePara->p, ctx) > 0) &&
        /* a = A * b ^ 2 mod p */
        (BN_mul(a, acc, curvePara->capitalA, ctx) > 0) &&
        (BN_mod(a, a, curvePara->p, ctx) > 0) &&
        /* acc = (b ^ 3 + A * b ^ 2) mod p */
        (BN_add(acc, cube, a) > 0) &&
        (BN_mod(acc, acc, curvePara->p, ctx) > 0) &&
        /* a = (b ^ 3 + A * b ^ 2 + b) mod p */
        (BN_add(a, acc, b) > 0) &&
        (BN_mod(a, a, curvePara->p, ctx) > 0);

    HcBnFree(cube);
    HcBnFree(acc);
    return done ? HAL_SUCCESS : HAL_FAILED;
}
257 
/*
 * Select the output value c from b, depending on a ^ q mod p.
 *
 * c is first set to (-b - A) mod p. Then a ^ q mod p is computed; when q
 * compares greater than that power, b and c are exchanged with BN_swap(), so
 * c ends up holding the incoming b and b holds (-b - A) mod p.
 *
 * a         - base of the modular exponentiation
 * b         - candidate value; modified when the swap takes place
 * curvePara - supplies p, minusA and q
 * c         - receives the selected value
 * ctx       - scratch context for the BIGNUM arithmetic
 *
 * Returns HAL_SUCCESS on success, HAL_FAILED if the temporary cannot be
 * allocated or any BIGNUM operation fails.
 */
static int32_t CurveHashToPointCalcC(const BIGNUM *a, BIGNUM *b,
    const struct CurveConstPara *curvePara, BIGNUM *c, BN_CTX *ctx)
{
    BIGNUM *power = BN_new();
    if (power == NULL) {
        return HAL_FAILED;
    }

    int32_t status = HAL_FAILED;
    /* If a is a quadratic residue modulo p, c := b and high_y := 1 Otherwise c := -b - A and high_y := 0 */
    if ((BN_sub(c, curvePara->p, b) > 0) &&
        (BN_mod(c, c, curvePara->p, ctx) > 0) &&
        (BN_add(c, c, curvePara->minusA) > 0) &&
        (BN_mod(c, c, curvePara->p, ctx) > 0) &&
        /* Sliding-window exponentiation: power = a^q mod p */
        (BN_mod_exp(power, a, curvePara->q, curvePara->p, ctx) > 0)) {
        if (BN_cmp(curvePara->q, power) > 0) {
            BN_swap(b, c);
        }
        status = HAL_SUCCESS;
    }

    HcBnFree(power);
    return status;
}
294 
CurveSetConstPara(struct CurveConstPara * para)295 static int32_t CurveSetConstPara(struct CurveConstPara *para)
296 {
297     int32_t ret = HAL_FAILED;
298     do {
299         if (BN_bin2bn(g_curveParamP, KEY_BYTES_CURVE25519, para->p) == NULL) {
300             break;
301         }
302         if (BN_bin2bn(g_curveParamOne, KEY_BYTES_CURVE25519, para->one) == NULL) {
303             break;
304         }
305         if (BN_bin2bn(g_curveParamD, KEY_BYTES_CURVE25519, para->d) == NULL) {
306             break;
307         }
308         if (BN_bin2bn(g_curveParamK, KEY_BYTES_CURVE25519, para->k) == NULL) {
309             break;
310         }
311         if (BN_bin2bn(g_curveParamCapitalA, KEY_BYTES_CURVE25519, para->capitalA) == NULL) {
312             break;
313         }
314         if (BN_bin2bn(g_curveParamMinusA, KEY_BYTES_CURVE25519, para->minusA) == NULL) {
315             break;
316         }
317         if (BN_bin2bn(g_curveParamU, KEY_BYTES_CURVE25519, para->u) == NULL) {
318             break;
319         }
320         if (BN_bin2bn(g_curveParamQ, KEY_BYTES_CURVE25519, para->q) == NULL) {
321             break;
322         }
323         ret = HAL_SUCCESS;
324     } while (0);
325 
326     return ret;
327 }
328 
CurveHashToPoint(const struct HksBlob * hash,struct HksBlob * point)329 static int32_t CurveHashToPoint(const struct HksBlob *hash, struct HksBlob *point)
330 {
331     struct CurveConstPara curvePara;
332     (void)memset_s(&curvePara, sizeof(curvePara), 0, sizeof(curvePara));
333     int32_t ret = CurveInitConstPara(&curvePara);
334     if (ret != HAL_SUCCESS) {
335         return HAL_ERR_BAD_ALLOC;
336     }
337     BIGNUM *a = BN_new();
338     BIGNUM *b = BN_new();
339     BIGNUM *c = BN_new();
340     BN_CTX *ctx = BN_CTX_new();
341     do {
342         if (a == NULL || b == NULL || c == NULL || ctx == NULL) {
343             ret = HAL_ERR_BAD_ALLOC;
344             break;
345         }
346         ret = CurveSetConstPara(&curvePara);
347         if (ret != HAL_SUCCESS) {
348             break;
349         }
350         ret = CurveHashToPointCalcB(hash, &curvePara, b, ctx);
351         if (ret != HAL_SUCCESS) {
352             break;
353         }
354         ret = CurveHashToPointCalcA(b, &curvePara, a, ctx);
355         if (ret != HAL_SUCCESS) {
356             break;
357         }
358         ret = CurveHashToPointCalcC(a, b, &curvePara, c, ctx);
359         if (ret != HAL_SUCCESS) {
360             break;
361         }
362         if (BN_bn2binpad(c, point->data, point->size) <= 0) {
363             ret = HAL_FAILED;
364             break;
365         }
366         ret = HAL_SUCCESS;
367     } while (0);
368     CurveFreeConstPara(&curvePara);
369     HcBnFree(a);
370     HcBnFree(b);
371     HcBnFree(c);
372     HcBnCTXFree(ctx);
373     return ret;
374 }
375 
EndianSwap(struct HksBlob * data)376 static int32_t EndianSwap(struct HksBlob *data)
377 {
378     uint32_t end = data->size - 1;
379     const uint32_t start = 0;
380 
381     /* count the middle index of array */
382     uint32_t cnt = data->size / 2; // 2 used to calculate half of the data size
383 
384     for (uint32_t i = 0; i < cnt; i++) {
385         uint8_t tmp;
386         tmp = data->data[start + i];
387         data->data[start + i] = data->data[end - i];
388         data->data[end - i] = tmp;
389     }
390     return HAL_SUCCESS;
391 }
392 
OpensslHashToPoint(const struct HksBlob * hash,struct HksBlob * point)393 int32_t OpensslHashToPoint(const struct HksBlob *hash, struct HksBlob *point)
394 {
395     int32_t ret = HAL_FAILED;
396     uint8_t *copyData = HcMalloc(hash->size, 0);
397     if (copyData == NULL) {
398         LOGE("malloc size %u failed", hash->size);
399         return HKS_ERROR_MALLOC_FAIL;
400     }
401     struct HksBlob hashCopy = { hash->size, copyData};
402 
403     do {
404         if (memcpy_s(hashCopy.data, hashCopy.size, hash->data, hash->size) != EOK) {
405             break;
406         }
407 
408         hashCopy.data[hashCopy.size - 1] &= 0x3f; /* RFC 8032 */
409         (void)EndianSwap(&hashCopy);
410         ret = CurveHashToPoint(&hashCopy, point);
411         if (ret != HAL_SUCCESS) {
412             LOGE("curve hash to point failed");
413             break;
414         }
415         (void)EndianSwap(point);
416     } while (0);
417     HcFree(hashCopy.data);
418     return ret;
419 }