1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <stdlib.h>
16 #include "hvb_crypto.h"
17 #include "hvb_util.h"
18 #include "hvb_rsa.h"
19 
/* Status codes returned by the long-integer helpers in this file; 0 means success. */
enum {
    RESULT_OK = 0,
    ERROR_MEMORY_EMPTY = 1,     /* hvb_malloc returned NULL */
    ERROR_MEMORY_NO_ENOUGH = 2, /* destination buffer is too small */
    ERROR_WORDLEN_ZERO = 3,     /* a zero-word number was requested */
};
26 
/*
 * Supply __WORDSIZE when the C library headers did not: 64 on LP64 targets,
 * 32 on targets defining __LP32__; any other data model stops the build.
 */
#if !defined(__WORDSIZE)
#  if defined(__LP64__)
#    define __WORDSIZE 64
#  elif defined(__LP32__)
#    define __WORDSIZE 32
#  else
#    error "not support word size "
#  endif
#endif
36 
/* Word ("limb") geometry for the long-integer arithmetic: one word is an unsigned long. */
#define WORD_BYTE_SIZE (sizeof(unsigned long))
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All bits of a word set. Written as ~0UL because the former
 * ((1UL << WORD_BIT_SIZE) - 1) shifts by the full width of the type,
 * which is undefined behaviour in C.
 */
#define WORD_BIT_MASK (~0UL)
#define byte2bit(byte) ((byte) << 3)
/* Half-word width and mask, used by the portable fallback of lin_mul_word. */
#define SWORD_BIT_SIZE (WORD_BIT_SIZE / 2)
#define SWORD_BIT_MASK ((1UL << SWORD_BIT_SIZE) - 1)
43 
lin_clear(struct long_int_num * p_a)44 static void lin_clear(struct long_int_num *p_a)
45 {
46     hvb_memset(p_a->data_mem, 0, p_a->mem_size);
47 }
48 
lin_copy(struct long_int_num * p_src,struct long_int_num * p_dst)49 static int lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)
50 {
51     if (p_src->valid_word_len > p_dst->mem_size) {
52         return ERROR_MEMORY_NO_ENOUGH;
53     }
54 
55     hvb_memcpy(p_dst->p_uint, p_src->p_uint, p_src->valid_word_len * WORD_BYTE_SIZE);
56 
57     p_dst->valid_word_len = p_src->valid_word_len;
58 
59     return RESULT_OK;
60 }
61 
lin_compare(struct long_int_num * p_a,struct long_int_num * p_b)62 static int lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)
63 {
64     int i;
65 
66     if (p_a->valid_word_len != p_b->valid_word_len) {
67         return p_a->valid_word_len - p_b->valid_word_len;
68     }
69 
70     for (i = p_a->valid_word_len - 1; i >= 0; --i) {
71         if (p_a->p_uint[i] != p_b->p_uint[i]) {
72             if (p_a->p_uint[i] > p_b->p_uint[i]) {
73                 return 1;
74             }
75             return -1;
76         }
77     }
78     return 0;
79 }
80 
lin_calloc(struct long_int_num * p_long_int,uint32_t word_len)81 static int lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)
82 {
83     unsigned long *p_data = NULL;
84 
85     if (word_len == 0) {
86         return ERROR_WORDLEN_ZERO;
87     }
88     p_data = hvb_malloc(word_len * WORD_BYTE_SIZE);
89     if (p_data == NULL) {
90         return ERROR_MEMORY_EMPTY;
91     }
92 
93     hvb_memset(p_data, 0, word_len * WORD_BYTE_SIZE);
94 
95     p_long_int->data_mem = p_data;
96     p_long_int->mem_size = word_len * WORD_BYTE_SIZE;
97     p_long_int->p_uint = p_data;
98     p_long_int->valid_word_len = 0;
99 
100     return RESULT_OK;
101 }
102 
lin_create(uint32_t word_len)103 struct long_int_num *lin_create(uint32_t word_len)
104 {
105     struct long_int_num *p_res = NULL;
106 
107     p_res = hvb_malloc(sizeof(struct long_int_num));
108     if (p_res == NULL) {
109         return NULL;
110     }
111 
112     if (lin_calloc(p_res, word_len) > 0) {
113         hvb_free(p_res);
114         return NULL;
115     }
116     p_res->valid_word_len = 0;
117     return p_res;
118 }
119 
lin_free(struct long_int_num * p_long_int)120 void lin_free(struct long_int_num *p_long_int)
121 {
122     if (!p_long_int) {
123         return;
124     }
125     if (p_long_int->p_uint != NULL) {
126         hvb_free(p_long_int->data_mem);
127         p_long_int->p_uint = NULL;
128     }
129     hvb_free(p_long_int);
130 
131     return;
132 }
133 
/*
 * Number of significant bytes in a big-endian byte string: size minus the count
 * of leading zero bytes. Returns 0 for a NULL pointer or an all-zero input.
 */
uint32_t bn_get_valid_len(const uint8_t *pd, uint32_t size)
{
    uint32_t lead = 0;

    if (pd == NULL) {
        return 0;
    }

    /* Skip the zero bytes at the front, never reading past size. */
    while (lead < size && pd[lead] == 0) {
        ++lead;
    }

    return size - lead;
}
147 
lin_update_valid_len(struct long_int_num * p_a)148 void lin_update_valid_len(struct long_int_num *p_a)
149 {
150     unsigned long *p_data = NULL;
151     uint32_t i;
152 
153     if (!p_a) {
154         return;
155     }
156 
157     p_data = p_a->p_uint + p_a->valid_word_len - 1;
158     for (i = 0; i < p_a->valid_word_len; ++i) {
159         if (*p_data != 0) {
160             break;
161         }
162         --p_data;
163     }
164     p_a->valid_word_len -= i;
165 }
166 
/*
 * Full-width product of two words: a * b == (*res_hi << word_bits) + *res_low.
 *
 * The implementation is chosen at compile time:
 *   - AArch64: plain multiply for the low word, UMULH for the high word;
 *   - compilers with a 128-bit integer type (GCC and Clang define
 *     __SIZEOF_INT128__ exactly when `unsigned __int128` is available);
 *   - otherwise a portable half-word schoolbook multiplication.
 *
 * The 128-bit branch used to be guarded by `defined(__uint128_t)`. __uint128_t is
 * a type name, not a macro, so that test was always false and the branch was dead
 * code (its 32-bit variant did not even declare `bb`).
 */
static void lin_mul_word(unsigned long a, unsigned long b, unsigned long *res_hi, unsigned long *res_low)
{
#if defined(__aarch64__)
    unsigned long hi = 0;
    *res_low = a * b;
    __asm__ volatile ("umulh %0, %1, %2" : "+r"(hi) : "r"(a), "r"(b) :);
    *res_hi = hi;
#elif defined(__SIZEOF_INT128__)
    /* __extension__ keeps -pedantic builds quiet about the GNU 128-bit type. */
    __extension__ typedef unsigned __int128 lin_dword_t;
    lin_dword_t prod = (lin_dword_t)a * b;

    /* Split at the width of unsigned long (the file's WORD_BIT_SIZE). */
    *res_hi = (unsigned long)(prod >> (sizeof(unsigned long) * 8));
    *res_low = (unsigned long)prod;
#else
    unsigned long a_h, a_l;
    unsigned long b_h, b_l;
    unsigned long res_h, res_l;
    unsigned long c, t;
    a_h = a >> SWORD_BIT_SIZE;
    a_l = a & SWORD_BIT_MASK;
    b_h = b >> SWORD_BIT_SIZE;
    b_l = b & SWORD_BIT_MASK;

    res_h = a_h * b_h;
    res_l = a_l * b_l;

    /* Add each cross product, split across the two result words, with carry. */
    c = a_h * b_l;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;

    c = a_l * b_h;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;
    *res_hi  = res_h;
    *res_low = res_l;
#endif
}
218 
lin_sub(struct long_int_num * p_a,struct long_int_num * p_b)219 static void lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)
220 {
221     uint32_t i;
222     unsigned long c;
223     unsigned long t;
224 
225     c = 0;
226     for (i = 0; i < p_b->valid_word_len; ++i) {
227         t = p_a->p_uint[i] < c;
228         p_a->p_uint[i] = p_a->p_uint[i] - c;
229 
230         c = (p_a->p_uint[i] < p_b->p_uint[i]) + t;
231         p_a->p_uint[i] = p_a->p_uint[i] - p_b->p_uint[i];
232     }
233     for (; i < p_a->valid_word_len && c; ++i) {
234         t = p_a->p_uint[i] < c;
235         p_a->p_uint[i] = p_a->p_uint[i] - c;
236         c = t;
237     }
238     lin_update_valid_len(p_a);
239 }
240 
/*
 * Two-word accumulate: (r_h:r_l) = (a_h:a_l) + b, where `a` and `r` are name
 * prefixes pasted with _h / _l. The carry out of the low word is detected by
 * unsigned wrap-around (new low word < b). `b` is evaluated twice, the second
 * time after r_l has been written, so it must have no side effects and must not
 * be r_l itself. It may be r_h, because r_h is read before it is assigned.
 */
#define dword_add_word(a, b, r)                                \
    do {                                                       \
        (r##_l) = (a##_l) + (b);                               \
        (r##_h) = (a##_h) + (((r##_l) < (b)) ? 1UL : 0UL);     \
    } while (0)
246 
montgomery_mul_add(struct long_int_num * p_a,unsigned long b,struct long_int_num * p_n,unsigned long n_n0_i,struct long_int_num * p_res)247 static void montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n,
248                                unsigned long n_n0_i, struct long_int_num *p_res)
249 {
250     unsigned long x_h, x_l;
251     unsigned long d0;
252     unsigned long y_h, y_l;
253     unsigned long t_h, t_l;
254     unsigned long *p_ad = p_a->p_uint;
255     unsigned long *p_nd = p_n->p_uint;
256     unsigned long *p_rd = p_res->p_uint;
257     uint32_t i;
258 
259     while (p_a->valid_word_len > p_n->valid_word_len){
260         lin_sub(p_a, p_n);
261     }
262 
263     lin_mul_word(p_a->p_uint[0], b, &x_h, &x_l);
264 
265     dword_add_word(x, p_rd[0], x);
266 
267     d0 = x_l * n_n0_i;
268 
269     lin_mul_word(d0, p_nd[0], &y_h, &y_l);
270     dword_add_word(y, x_l, y);
271 
272     for (i = 1; i < p_a->valid_word_len; ++i) {
273         lin_mul_word(p_ad[i], b, &t_h, &t_l);
274         dword_add_word(t, p_rd[i], t);
275         dword_add_word(t, x_h, x);
276 
277         lin_mul_word(d0, p_nd[i], &t_h, &t_l);
278         dword_add_word(t, x_l, t);
279         dword_add_word(t, y_h, y);
280 
281         p_rd[i - 1] = y_l;
282     }
283 
284     p_rd[i - 1] = x_h + y_h;
285 
286     p_res->valid_word_len = p_n->valid_word_len;
287     if (p_rd[i - 1] < x_h) {
288         lin_sub(p_res, p_n);
289     }
290 }
291 
montgomery_mod_mul(struct long_int_num * p_a,struct long_int_num * p_b,struct long_int_num * p_n,unsigned long n_n0_i,struct long_int_num * p_res)292 static void montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n,
293                                unsigned long n_n0_i, struct long_int_num *p_res)
294 {
295     uint32_t i;
296 
297     lin_clear(p_res);
298 
299     for (i = 0; i < p_b->valid_word_len; ++i) {
300         montgomery_mul_add(p_a, p_b->p_uint[i], p_n, n_n0_i, p_res);
301     }
302 }
303 
montgomery_mod_exp(struct long_int_num * p_m,struct long_int_num * p_n,unsigned long n_n0_i,struct long_int_num * p_rr,uint32_t exp)304 struct long_int_num *montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i,
305                                         struct long_int_num *p_rr, uint32_t exp)
306 {
307     struct long_int_num *p_res = NULL;
308     struct long_int_num *p_mr = NULL;
309     struct long_int_num *p_square = NULL;
310     int i;
311     if ((exp & 1UL) == 0) {
312         goto fail_final;
313     }
314 
315     p_mr = lin_create(p_n->valid_word_len);
316     if (p_mr == NULL) {
317         goto fail_final;
318     }
319 
320     p_square = lin_create(p_n->valid_word_len);
321     if (p_square == NULL) {
322         goto fail_final;
323     }
324 
325     p_res = lin_create(p_n->valid_word_len);
326     if (p_res == NULL) {
327         goto fail_final;
328     }
329 
330     montgomery_mod_mul(p_m, p_rr, p_n, n_n0_i, p_mr);
331     i = byte2bit(sizeof(exp)) - 1;
332     for (; i >= 0; --i) {
333         if (exp & (1UL << i)) {
334             break;
335         }
336     }
337 
338     lin_copy(p_mr, p_res);
339 
340     for (--i; i > 0; --i) {
341         montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
342         if (exp & (1UL << i)) {
343             montgomery_mod_mul(p_mr, p_square, p_n, n_n0_i, p_res);
344         } else {
345             lin_copy(p_square, p_res);
346         }
347     }
348     montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
349     montgomery_mod_mul(p_m, p_square, p_n, n_n0_i, p_res);
350 
351     if (lin_compare(p_res, p_n) >= 0) {
352         lin_sub(p_res, p_n);
353     }
354 
355 fail_final:
356     lin_free(p_mr);
357     lin_free(p_square);
358 
359     return p_res;
360 }
361 
lin_get_bitlen(struct long_int_num * p_a)362 uint32_t lin_get_bitlen(struct long_int_num *p_a)
363 {
364     int i;
365     int bit_len;
366     unsigned long *p_data = NULL;
367     unsigned long value;
368 
369     if (!p_a) {
370         return 0;
371     }
372     p_data = p_a->p_uint;
373     for (i = p_a->valid_word_len - 1; i >= 0; --i) {
374         if (p_data[i] != 0) {
375             break;
376         }
377     }
378 
379     bit_len = (i + 1) * WORD_BIT_SIZE;
380 
381     if (bit_len == 0) {
382         return 0;
383     }
384 
385     for (value = p_data[i]; ((signed long)value) > 0; value = value << 1) {
386         --bit_len;
387     }
388 
389     return bit_len;
390 }
391