1 /*
2  * Copyright (C) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "das_version_util.h"

#include <stdint.h>

#include "hc_log.h"
#include "hc_types.h"
#include "string_util.h"
20 
21 #define BIND_PRIORITY_LEN 5
22 #define AUTH_PRIORITY_LEN 5
23 
24 typedef struct PriorityMapT {
25     uint32_t alg;
26     ProtocolType type;
27 } PriorityMap;
28 
29 VersionStruct g_defaultVersion = { 1, 0, 0 };
30 PriorityMap g_bindPriorityList[BIND_PRIORITY_LEN] = {
31     { EC_PAKE_V2, PAKE_V2 },
32     { DL_PAKE_V2, PAKE_V2 },
33     { EC_PAKE_V1, PAKE_V1 },
34     { DL_PAKE_V1, PAKE_V1 },
35     { ISO_ALG, ISO }
36 };
37 PriorityMap g_authPriorityList[AUTH_PRIORITY_LEN] = {
38     { PSK_SPEKE | EC_PAKE_V2, PAKE_V2 },
39     { PSK_SPEKE | EC_PAKE_V1, PAKE_V1 },
40     { ISO_ALG, ISO }
41 };
42 
/*
 * Splits the next delim-separated component off the front of str, in place.
 *
 * If str contains delim, the first occurrence is overwritten with '\0' and
 * *nextIdx is advanced past it. If str holds only a final component (no
 * delim), *nextIdx is advanced to that component's terminating '\0'. In both
 * cases, when str was obtained as base + *nextIdx, the updated *nextIdx
 * indexes the rest of the input, and a follow-up call on an exhausted
 * string reports the end.
 *
 * Returns str (now holding just the component), or NULL when str is NULL or
 * empty, i.e. no components remain. *nextIdx is untouched when NULL is
 * returned.
 */
static const char *GetSlice(char *str, char delim, int *nextIdx)
{
    if (str == NULL || nextIdx == NULL || str[0] == '\0') {
        return NULL;
    }
    uint32_t i = 0;
    while (str[i] != '\0' && str[i] != delim) {
        i++;
    }
    if (str[i] != '\0') {
        /* Stopped on a delimiter: terminate the component and skip the delimiter. */
        str[i] = '\0';
        *nextIdx += (int)(i + 1);
    } else {
        /* Final component: point at the terminator so the next call returns NULL. */
        *nextIdx += (int)i;
    }
    return str;
}
55 
StringToVersion(const char * verStr,VersionStruct * version)56 int32_t StringToVersion(const char* verStr, VersionStruct* version)
57 {
58     CHECK_PTR_RETURN_ERROR_CODE(version, "version");
59     CHECK_PTR_RETURN_ERROR_CODE(verStr, "verStr");
60 
61     const char *subVer = NULL;
62     int nextIdx = 0;
63 
64     uint32_t len = HcStrlen(verStr);
65     char *verStrTmp = (char *)HcMalloc(len + 1, 0);
66     if (verStrTmp == NULL) {
67         LOGE("Malloc for verStrTmp failed.");
68         return HC_ERR_ALLOC_MEMORY;
69     }
70     if (memcpy_s(verStrTmp, len + 1, verStr, len) != EOK) {
71         LOGE("Memcpy for verStrTmp failed.");
72         HcFree(verStrTmp);
73         return HC_ERR_MEMORY_COPY;
74     }
75 
76     subVer = GetSlice(verStrTmp, '.', &nextIdx);
77     if (subVer == NULL) {
78         goto CLEAN_UP;
79     }
80     version->first = (uint32_t)strtoul(subVer, NULL, DEC);
81 
82     subVer = GetSlice(verStrTmp + nextIdx, '.', &nextIdx);
83     if (subVer == NULL) {
84         goto CLEAN_UP;
85     }
86     version->second = (uint32_t)strtoul(subVer, NULL, DEC);
87 
88     subVer = GetSlice(verStrTmp + nextIdx, '.', &nextIdx);
89     if (subVer == NULL) {
90         goto CLEAN_UP;
91     }
92     version->third = (uint32_t)strtoul(subVer, NULL, DEC);
93 
94     HcFree(verStrTmp);
95     return HC_SUCCESS;
96 CLEAN_UP:
97     LOGE("GetSlice failed.");
98     HcFree(verStrTmp);
99     return HC_ERROR;
100 }
101 
VersionToString(const VersionStruct * version,char * verStr,uint32_t len)102 int32_t VersionToString(const VersionStruct *version, char *verStr, uint32_t len)
103 {
104     CHECK_PTR_RETURN_ERROR_CODE(version, "version");
105     CHECK_PTR_RETURN_ERROR_CODE(verStr, "verStr");
106 
107     char tmpStr[TMP_VERSION_STR_LEN] = { 0 };
108     if (sprintf_s(tmpStr, TMP_VERSION_STR_LEN, "%u.%u.%u", version->first, version->second, version->third) <= 0) {
109         LOGE("Convert version struct to string failed.");
110         return HC_ERR_CONVERT_FAILED;
111     }
112     uint32_t tmpStrLen = HcStrlen(tmpStr);
113     if (len < tmpStrLen + 1) {
114         LOGE("The length of verStr is too short, len: %u.", len);
115         return HC_ERR_INVALID_LEN;
116     }
117 
118     if (memcpy_s(verStr, len, tmpStr, tmpStrLen + 1) != 0) {
119         LOGE("Memcpy for verStr failed.");
120         return HC_ERR_MEMORY_COPY;
121     }
122 
123     return HC_SUCCESS;
124 }
125 
/*
 * Serializes *version and stores it under FIELD_GROUP_AND_MODULE_VERSION in
 * the FIELD_SEND_TO_PEER child object of jsonObj.
 *
 * If jsonObj has no FIELD_SEND_TO_PEER child, nothing is added and
 * HC_SUCCESS is still returned. Otherwise the function returns the
 * VersionToString error code, or HC_ERR_JSON_ADD if the string cannot be
 * added.
 */
int32_t AddSingleVersionToJson(CJson *jsonObj, const VersionStruct *version)
{
    CHECK_PTR_RETURN_ERROR_CODE(jsonObj, "jsonObj");
    CHECK_PTR_RETURN_ERROR_CODE(version, "version");

    char versionStr[TMP_VERSION_STR_LEN] = { 0 };
    int32_t res = VersionToString(version, versionStr, TMP_VERSION_STR_LEN);
    if (res != HC_SUCCESS) {
        LOGE("VersionToString failed, res: %x.", res);
        return res;
    }

    CJson *sendToPeer = GetObjFromJson(jsonObj, FIELD_SEND_TO_PEER);
    if (sendToPeer != NULL) {
        if (AddStringToJson(sendToPeer, FIELD_GROUP_AND_MODULE_VERSION, versionStr) != HC_SUCCESS) {
            LOGE("Add group and module version to sendToPeer failed.");
            return HC_ERR_JSON_ADD;
        }
        return HC_SUCCESS;
    }
    LOGD("There is not sendToPeer in json.");
    return HC_SUCCESS;
}
149 
/*
 * Reads the FIELD_GROUP_AND_MODULE_VERSION string from jsonObj and parses it
 * into *version via StringToVersion.
 *
 * Returns HC_SUCCESS, HC_ERR_JSON_GET when the field is missing, or the
 * StringToVersion error code.
 */
int32_t GetSingleVersionFromJson(const CJson* jsonObj, VersionStruct *version)
{
    CHECK_PTR_RETURN_ERROR_CODE(jsonObj, "jsonObj");
    CHECK_PTR_RETURN_ERROR_CODE(version, "version");

    const char *encoded = GetStringFromJson(jsonObj, FIELD_GROUP_AND_MODULE_VERSION);
    if (encoded == NULL) {
        LOGE("Get group and module version from json failed.");
        return HC_ERR_JSON_GET;
    }

    int32_t res = StringToVersion(encoded, version);
    if (res != HC_SUCCESS) {
        LOGE("StringToVersion failed, res: %x.", res);
    }
    return res;
}
168 
/*
 * Resets *version to MAJOR_VERSION_NO.0.0.
 * A NULL version is logged and otherwise ignored.
 */
void InitGroupAndModuleVersion(VersionStruct *version)
{
    if (version != NULL) {
        version->first = MAJOR_VERSION_NO;
        version->second = 0;
        version->third = 0;
        return;
    }
    LOGE("Version is null.");
}
179 
/*
 * Reads the FIELD_MIN_VERSION and FIELD_CURRENT_VERSION strings directly
 * from jsonObj and parses them into *minVer and *maxVer.
 *
 * Both strings are parsed even when the first one fails to parse. Any parse
 * failure is reported as HC_ERROR.
 */
int32_t GetVersionFromJson(const CJson* jsonObj, VersionStruct *minVer, VersionStruct *maxVer)
{
    CHECK_PTR_RETURN_ERROR_CODE(jsonObj, "jsonObj");
    CHECK_PTR_RETURN_ERROR_CODE(minVer, "minVer");
    CHECK_PTR_RETURN_ERROR_CODE(maxVer, "maxVer");

    const char *minStr = GetStringFromJson(jsonObj, FIELD_MIN_VERSION);
    CHECK_PTR_RETURN_ERROR_CODE(minStr, "minStr");
    const char *maxStr = GetStringFromJson(jsonObj, FIELD_CURRENT_VERSION);
    CHECK_PTR_RETURN_ERROR_CODE(maxStr, "maxStr");

    bool minParsed = (StringToVersion(minStr, minVer) == HC_SUCCESS);
    bool maxParsed = (StringToVersion(maxStr, maxVer) == HC_SUCCESS);
    if (minParsed && maxParsed) {
        return HC_SUCCESS;
    }
    LOGE("Convert version string to struct failed.");
    return HC_ERROR;
}
199 
/*
 * Builds a child object holding FIELD_MIN_VERSION and FIELD_CURRENT_VERSION
 * strings for *minVer and *maxVer, and attaches it to jsonObj under
 * FIELD_VERSION.
 *
 * Returns HC_ERROR if either version cannot be formatted, HC_ERR_JSON_CREATE
 * if the child object cannot be created, and HC_ERR_JSON_ADD if any insertion
 * fails.
 */
int32_t AddVersionToJson(CJson *jsonObj, const VersionStruct *minVer, const VersionStruct *maxVer)
{
    CHECK_PTR_RETURN_ERROR_CODE(jsonObj, "jsonObj");
    CHECK_PTR_RETURN_ERROR_CODE(minVer, "minVer");
    CHECK_PTR_RETURN_ERROR_CODE(maxVer, "maxVer");

    char minStr[TMP_VERSION_STR_LEN] = { 0 };
    char maxStr[TMP_VERSION_STR_LEN] = { 0 };
    bool minOk = (VersionToString(minVer, minStr, TMP_VERSION_STR_LEN) == HC_SUCCESS);
    bool maxOk = (VersionToString(maxVer, maxStr, TMP_VERSION_STR_LEN) == HC_SUCCESS);
    if (!(minOk && maxOk)) {
        return HC_ERROR;
    }

    CJson *versionObj = CreateJson();
    if (versionObj == NULL) {
        LOGE("CreateJson for version failed.");
        return HC_ERR_JSON_CREATE;
    }

    int32_t res = HC_SUCCESS;
    if (AddStringToJson(versionObj, FIELD_MIN_VERSION, minStr) != HC_SUCCESS) {
        LOGE("Add min version to json failed.");
        res = HC_ERR_JSON_ADD;
    } else if (AddStringToJson(versionObj, FIELD_CURRENT_VERSION, maxStr) != HC_SUCCESS) {
        LOGE("Add max version to json failed.");
        res = HC_ERR_JSON_ADD;
    } else if (AddObjToJson(jsonObj, FIELD_VERSION, versionObj) != HC_SUCCESS) {
        LOGE("Add version object to json failed.");
        res = HC_ERR_JSON_ADD;
    }
    /*
     * versionObj is released on every path, including after a successful
     * AddObjToJson. NOTE(review): this presumes AddObjToJson stores a copy
     * of the child rather than taking ownership; confirm.
     */
    FreeJson(versionObj);
    return res;
}
236 
/*
 * Returns true when *src and *des agree on all three components.
 * Neither pointer is checked for NULL.
 */
bool IsVersionEqual(VersionStruct *src, VersionStruct *des)
{
    return src->first == des->first &&
        src->second == des->second &&
        src->third == des->third;
}
244 
/*
 * Updates *curVersionSelf from the peer's current version.
 *
 * If the peer reports exactly g_defaultVersion, this side adopts all three
 * components of it. Otherwise only the third component is negotiated: it is
 * treated as a bitmask and reduced to the bits both sides share. An empty
 * intersection yields HC_ERR_UNSUPPORTED_VERSION, and in that case
 * curVersionSelf->third has already been set to 0.
 *
 * minVersionPeer and the first/second components are not consulted.
 */
int32_t NegotiateVersion(VersionStruct *minVersionPeer, VersionStruct *curVersionPeer,
    VersionStruct *curVersionSelf)
{
    (void)minVersionPeer;
    if (!IsVersionEqual(curVersionPeer, &g_defaultVersion)) {
        curVersionSelf->third &= curVersionPeer->third;
        if (curVersionSelf->third != 0) {
            return HC_SUCCESS;
        }
        LOGE("Unsupported version!");
        return HC_ERR_UNSUPPORTED_VERSION;
    }
    curVersionSelf->first = g_defaultVersion.first;
    curVersionSelf->second = g_defaultVersion.second;
    curVersionSelf->third = g_defaultVersion.third;
    return HC_SUCCESS;
}
262 
GetBindPrototolType(VersionStruct * curVersion)263 static ProtocolType GetBindPrototolType(VersionStruct *curVersion)
264 {
265     if (IsVersionEqual(curVersion, &g_defaultVersion)) {
266         return PAKE_V1;
267     }
268     for (int i = 0; i < BIND_PRIORITY_LEN; i++) {
269         if ((curVersion->third & g_bindPriorityList[i].alg) == g_bindPriorityList[i].alg) {
270             return g_bindPriorityList[i].type;
271         }
272     }
273     return PROTOCOL_TYPE_NONE;
274 }
275 
GetAuthPrototolType(VersionStruct * curVersion)276 static ProtocolType GetAuthPrototolType(VersionStruct *curVersion)
277 {
278     if (IsVersionEqual(curVersion, &g_defaultVersion)) {
279         LOGE("Not support STS.");
280         return PROTOCOL_TYPE_NONE;
281     }
282     for (int i = 0; i < AUTH_PRIORITY_LEN; i++) {
283         if ((curVersion->third & g_authPriorityList[i].alg) == g_authPriorityList[i].alg) {
284             return g_authPriorityList[i].type;
285         }
286     }
287     return PROTOCOL_TYPE_NONE;
288 }
289 
GetPrototolType(VersionStruct * curVersion,OperationCode opCode)290 ProtocolType GetPrototolType(VersionStruct *curVersion, OperationCode opCode)
291 {
292     switch (opCode) {
293         case OP_BIND:
294         case AUTH_KEY_AGREEMENT:
295             return GetBindPrototolType(curVersion);
296         case AUTHENTICATE:
297         case OP_UNBIND:
298             return GetAuthPrototolType(curVersion);
299         default:
300             LOGE("Unsupported opCode: %d.", opCode);
301     }
302     return PROTOCOL_TYPE_NONE;
303 }
304 
GetSupportedPakeAlg(VersionStruct * curVersion,ProtocolType protocolType)305 PakeAlgType GetSupportedPakeAlg(VersionStruct *curVersion, ProtocolType protocolType)
306 {
307     PakeAlgType pakeAlgType = PAKE_ALG_NONE;
308     if (protocolType == PAKE_V2) {
309         pakeAlgType = ((curVersion->third & EC_PAKE_V2) >> ALG_OFFSET_FOR_PAKE_V2) |
310             ((curVersion->third & DL_PAKE_V2) >> ALG_OFFSET_FOR_PAKE_V2);
311     } else if (protocolType == PAKE_V1) {
312         pakeAlgType = ((curVersion->third & EC_PAKE_V1) >> ALG_OFFSET_FOR_PAKE_V1) |
313             ((curVersion->third & DL_PAKE_V1) >> ALG_OFFSET_FOR_PAKE_V1);
314     } else {
315         LOGE("Invalid protocolType: %d.", protocolType);
316     }
317     return pakeAlgType;
318 }
319 
/* Returns true when curVersion->third has any PSK_SPEKE bit set. */
bool IsSupportedPsk(VersionStruct *curVersion)
{
    if ((curVersion->third & PSK_SPEKE) == 0) {
        return false;
    }
    return true;
}