1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "cert_utils.h"
17 
#include <cstdint>
#include <cstring>
#include <memory>
#include <openssl/rand.h>
#include <securec.h>
#include <string>
#include <utility>
#include <vector>
24 
25 #include "byte_buffer.h"
26 #include "errcode.h"
27 #include "huks_param_set.h"
28 #include "log.h"
29 
30 namespace OHOS {
31 namespace Security {
32 namespace CodeSign {
// Size in bytes of each certificate buffer allocated by ConstructDataToCertChain.
static constexpr uint32_t CERT_DATA_SIZE = 8192;
// Length in bytes of the random challenge created by GetRandomChallenge; it is also
// the largest size accepted by CheckChallengeSize.
static constexpr uint32_t CHALLENGE_LEN = 32;
35 
/*
 * Views a 32-bit word through a byte pointer so it can be passed to byte-oriented
 * copy routines (e.g. ByteBuffer::PutData) without changing the address.
 */
static inline uint8_t *CastToUint8Ptr(uint32_t *ptr)
{
    uint8_t *bytePtr = reinterpret_cast<uint8_t *>(ptr);
    return bytePtr;
}
40 
ConstructDataToCertChain(struct HksCertChain ** certChain,int certsCount)41 bool ConstructDataToCertChain(struct HksCertChain **certChain, int certsCount)
42 {
43     *certChain = static_cast<struct HksCertChain *>(malloc(sizeof(struct HksCertChain)));
44     if (*certChain == nullptr) {
45         LOG_ERROR("malloc fail");
46         return false;
47     }
48     (*certChain)->certsCount = CERT_COUNT;
49 
50     (*certChain)->certs = static_cast<struct HksBlob *>(malloc(sizeof(struct HksBlob) *
51         ((*certChain)->certsCount)));
52     if ((*certChain)->certs == nullptr) {
53         free(*certChain);
54         *certChain = nullptr;
55         return false;
56     }
57     for (uint32_t i = 0; i < (*certChain)->certsCount; i++) {
58         (*certChain)->certs[i].size = CERT_DATA_SIZE;
59         (*certChain)->certs[i].data = static_cast<uint8_t *>(malloc((*certChain)->certs[i].size));
60         if ((*certChain)->certs[i].data == nullptr) {
61             LOG_ERROR("malloc fail");
62             FreeCertChain(certChain, i);
63             return false;
64         }
65     }
66     return true;
67 }
68 
FreeCertChain(struct HksCertChain ** certChain,const uint32_t pos)69 void FreeCertChain(struct HksCertChain **certChain, const uint32_t pos)
70 {
71     if (*certChain == nullptr) {
72         return;
73     }
74     if ((*certChain)->certs == nullptr) {
75         free(*certChain);
76         *certChain = nullptr;
77         return;
78     }
79     for (uint32_t j = 0; j < pos; j++) {
80         if ((*certChain)->certs[j].data != nullptr) {
81             free((*certChain)->certs[j].data);
82             (*certChain)->certs[j].data = nullptr;
83         }
84     }
85     free((*certChain)->certs);
86     (*certChain)->certs = nullptr;
87     free(*certChain);
88     *certChain = nullptr;
89 }
90 
FormattedCertChain(const HksCertChain * certChain,ByteBuffer & buffer)91 bool FormattedCertChain(const HksCertChain *certChain, ByteBuffer &buffer)
92 {
93     uint32_t certsCount = certChain->certsCount;
94     uint32_t totalLen = sizeof(uint32_t);
95     for (uint32_t i = 0; i < certsCount; i++) {
96         totalLen += sizeof(uint32_t) + certChain->certs[i].size;
97     }
98 
99     buffer.Resize(totalLen);
100     if (!buffer.PutData(0, CastToUint8Ptr(&certsCount), sizeof(uint32_t))) {
101         return false;
102     }
103     uint32_t pos = sizeof(uint32_t);
104     for (uint32_t i = 0; i < certsCount; i++) {
105         if (!buffer.PutData(pos, CastToUint8Ptr(&certChain->certs[i].size), sizeof(uint32_t))) {
106             return false;
107         }
108         pos += sizeof(uint32_t);
109         if (!buffer.PutData(pos, certChain->certs[i].data, certChain->certs[i].size)) {
110             return false;
111         }
112         pos += certChain->certs[i].size;
113     }
114     return true;
115 }
116 
/*
 * Reads a host-byte-order uint32_t from the front of bufferPtr into retSize. It then
 * advances bufferPtr and decreases restSize by sizeof(uint32_t).
 *
 * @return false when bufferPtr is null or fewer than sizeof(uint32_t) bytes remain;
 *         in that case no argument is modified. Returns true otherwise.
 */
static inline bool CheckSizeAndAssign(uint8_t *&bufferPtr, uint32_t &restSize, uint32_t &retSize)
{
    if (bufferPtr == nullptr || restSize < sizeof(uint32_t)) {
        return false;
    }
    // The input is an arbitrary byte stream: the word may be misaligned, and reading it
    // through a reinterpret_cast'ed uint32_t* violates strict aliasing. memcpy is safe.
    std::memcpy(&retSize, bufferPtr, sizeof(uint32_t));
    bufferPtr += sizeof(uint32_t);
    restSize -= sizeof(uint32_t);
    return true;
}
127 
CheckSizeAndCopy(uint8_t * & bufferPtr,uint32_t & restSize,const uint32_t size,ByteBuffer & ret)128 static inline bool CheckSizeAndCopy(uint8_t *&bufferPtr, uint32_t &restSize, const uint32_t size,
129     ByteBuffer &ret)
130 {
131     if (restSize < size) {
132         return false;
133     }
134     if (!ret.CopyFrom(bufferPtr, size)) {
135         return false;
136     }
137     bufferPtr += size;
138     restSize -= size;
139     return true;
140 }
141 
GetCertChainFormBuffer(const ByteBuffer & certChainBuffer,ByteBuffer & signCert,ByteBuffer & issuer,std::vector<ByteBuffer> & chain)142 bool GetCertChainFormBuffer(const ByteBuffer &certChainBuffer,
143     ByteBuffer &signCert, ByteBuffer &issuer, std::vector<ByteBuffer> &chain)
144 {
145     uint8_t *rawPtr = certChainBuffer.GetBuffer();
146     if (rawPtr == nullptr || certChainBuffer.GetSize() < sizeof(uint32_t)) {
147         LOG_ERROR("empty cert chain buffer.");
148         return false;
149     }
150     uint32_t certsCount = *reinterpret_cast<uint32_t *>(rawPtr);
151     rawPtr += sizeof(uint32_t);
152     if (certsCount == 0) {
153         return false;
154     }
155 
156     uint32_t certSize;
157     bool ret = true;
158     uint32_t restSize = certChainBuffer.GetSize() - sizeof(uint32_t);
159     for (uint32_t i = 0; i < certsCount - 1; i++) {
160         if (!CheckSizeAndAssign(rawPtr, restSize, certSize)) {
161             return false;
162         }
163         if (i == 0) {
164             ret = CheckSizeAndCopy(rawPtr, restSize, certSize, signCert);
165         } else if (i == 1) {
166             ret = CheckSizeAndCopy(rawPtr, restSize, certSize, issuer);
167         } else {
168             ByteBuffer cert;
169             ret = CheckSizeAndCopy(rawPtr, restSize, certSize, cert);
170             chain.emplace_back(cert);
171         }
172         if (!ret) {
173             LOG_ERROR("failed at index = %{public}u", i);
174             break;
175         }
176     }
177     return ret;
178 }
179 
GetRandomChallenge()180 std::unique_ptr<ByteBuffer> GetRandomChallenge()
181 {
182     std::unique_ptr<ByteBuffer> challenge = std::make_unique<ByteBuffer>(CHALLENGE_LEN);
183     if (challenge == nullptr) {
184         return nullptr;
185     }
186     RAND_bytes(challenge->GetBuffer(), CHALLENGE_LEN);
187     return challenge;
188 }
189 
CheckChallengeSize(uint32_t size)190 bool CheckChallengeSize(uint32_t size)
191 {
192     if (size > CHALLENGE_LEN) {
193         return false;
194     }
195     return true;
196 }
197 }
198 }
199 }