1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <climits>
17 #include <cstring>
18 #include <fcntl.h>
19 #include <gtest/gtest.h>
20 #include <iostream>
21 #include <sys/mman.h>
22 #include <sys/stat.h>
23 #include <unistd.h>
24 #include "hash_data_verifier.h"
25 #include "log.h"
26 #include "pkg_stream.h"
27 #include "pkg_utils.h"
28 #include "scope_guard.h"
29 #include "script_instruction.h"
30 #include "script_manager.h"
31 #include "script/script_unittest.h"
32 #include "script_utils.h"
33 #include "unittest_comm.h"
34 #include "utils.h"
35 
36 using namespace std;
37 using namespace Hpackage;
38 using namespace Uscript;
39 using namespace Updater;
40 using namespace testing::ext;
41 
42 namespace {
// Priority handed to ExecuteScript first; the test expects USCRIPT_ABOART for it
// (presumably because it lies one past the last valid priority level — confirm in ScriptManager).
constexpr int32_t SCRIPT_TEST_PRIORITY_NUM {3};
// Highest priority level the test executes; levels 0..SCRIPT_TEST_LAST_PRIORITY run in order.
constexpr int32_t SCRIPT_TEST_LAST_PRIORITY {SCRIPT_TEST_PRIORITY_NUM - 1};
45 
46 class TestPkgManager : public TestScriptPkgManager {
47 public:
GetFileInfo(const std::string & fileId)48     const FileInfo *GetFileInfo(const std::string &fileId) override
49     {
50         static FileInfo fileInfo {};
51         static std::vector<std::string> testFileNames = {
52             "loadScript.us",
53             "registerCmd.us",
54             "test_function.us",
55             "test_if.us",
56             "test_logic.us",
57             "test_math.us",
58             "test_native.us",
59             "testscript.us",
60             "Verse-script.us",
61             "test_script.us"
62         };
63         if (fileId == "hash_signed_data") {
64             fileInfo.unpackedSize = GetFileSize(TEST_PATH_FROM + fileId);
65             return &fileInfo;
66         }
67         if (std::find(testFileNames.begin(), testFileNames.end(), fileId) != testFileNames.end()) {
68             return &fileInfo;
69         }
70         return nullptr;
71     }
CreatePkgStream(StreamPtr & stream,const std::string & fileName,const PkgBuffer & buffer)72     int32_t CreatePkgStream(StreamPtr &stream, const std::string &fileName, const PkgBuffer &buffer) override
73     {
74         stream = new MemoryMapStream(this, fileName, buffer, PkgStream::PkgStreamType_Buffer);
75         return PKG_SUCCESS;
76     }
ExtractFile(const std::string & fileId,StreamPtr output)77     int32_t ExtractFile(const std::string &fileId, StreamPtr output) override
78     {
79         if (fileId != "hash_signed_data") {
80             return PKG_SUCCESS;
81         }
82         if (output == nullptr) {
83             return PKG_INVALID_STREAM;
84         }
85         auto stream = static_cast<MemoryMapStream *>(output);
86         auto fd = open((TEST_PATH_FROM + fileId).c_str(), O_RDWR);
87         if (fd == -1) {
88             PKG_LOGE("file %s not existed", (TEST_PATH_FROM + fileId).c_str());
89             return PKG_INVALID_FILE;
90         }
91         ON_SCOPE_EXIT(close) {
92             close(fd);
93         };
94         std::string content {};
95         if (!Utils::ReadFileToString(fd, content)) {
96             PKG_LOGE("read file to string failed");
97             return PKG_INVALID_FILE;
98         }
99         PkgBuffer buffer {};
100         stream->GetBuffer(buffer);
101         if (content.size() + 1 != buffer.length) {
102             PKG_LOGE("content size is not valid, %u != %u", content.size(), buffer.data.size());
103             return PKG_INVALID_FILE;
104         }
105         std::copy(content.begin(), content.end(), buffer.buffer);
106         return PKG_SUCCESS;
107     }
CreatePkgStream(StreamPtr & stream,const std::string & fileName,size_t size,int32_t type)108     int32_t CreatePkgStream(StreamPtr &stream, const std::string &fileName,
109          size_t size, int32_t type) override
110     {
111         FILE *file = nullptr;
112         std::string fileNameReal = fileName;
113         auto pos = fileName.rfind('/');
114         if (pos != std::string::npos) {
115             fileNameReal = fileName.substr(pos + 1);
116         }
117         char realPath[PATH_MAX + 1] = {};
118         if (realpath((TEST_PATH_FROM + fileNameReal).c_str(), realPath) == nullptr) {
119             LOG(ERROR) << (TEST_PATH_FROM + fileNameReal) << " realpath failed";
120             return PKG_INVALID_FILE;
121         }
122         file = fopen(realPath, "rb");
123         if (file == nullptr) {
124             PKG_LOGE("Fail to open file %s ", fileNameReal.c_str());
125             return PKG_INVALID_FILE;
126         }
127         stream = new FileStream(this, fileNameReal, file, PkgStream::PkgStreamType_Read);
128         return USCRIPT_SUCCESS;
129     }
ClosePkgStream(StreamPtr & stream)130     void ClosePkgStream(StreamPtr &stream) override
131     {
132         delete stream;
133     }
134 };
135 
136 
137 class TestScriptInstructionSparseImageWrite : public Uscript::UScriptInstruction {
138 public:
TestScriptInstructionSparseImageWrite()139     TestScriptInstructionSparseImageWrite() {}
~TestScriptInstructionSparseImageWrite()140     virtual ~TestScriptInstructionSparseImageWrite() {}
Execute(Uscript::UScriptEnv & env,Uscript::UScriptContext & context)141     int32_t Execute(Uscript::UScriptEnv &env, Uscript::UScriptContext &context) override
142     {
143         // 从参数中获取分区信息
144         std::string partitionName;
145         int32_t ret = context.GetParam(0, partitionName);
146         if (ret != USCRIPT_SUCCESS) {
147             LOG(ERROR) << "Error to get param";
148             return ret;
149         }
150         LOG(INFO) << "UScriptInstructionSparseImageWrite::Execute " << partitionName;
151         if (env.GetPkgManager() == nullptr) {
152             LOG(ERROR) << "Error to get pkg manager";
153             return USCRIPT_ERROR_EXECUTE;
154         }
155         return ret;
156     }
157 };
158 
159 class TestScriptInstructionFactory : public UScriptInstructionFactory {
160 public:
CreateInstructionInstance(UScriptInstructionPtr & instr,const std::string & name)161     virtual int32_t CreateInstructionInstance(UScriptInstructionPtr& instr, const std::string& name)
162     {
163         if (name == "sparse_image_write") {
164             instr = new TestScriptInstructionSparseImageWrite();
165         }
166         return USCRIPT_SUCCESS;
167     }
DestoryInstructionInstance(UScriptInstructionPtr & instr)168     virtual void DestoryInstructionInstance(UScriptInstructionPtr& instr)
169     {
170         delete instr;
171         instr = nullptr;
172     }
TestScriptInstructionFactory()173     TestScriptInstructionFactory() {}
~TestScriptInstructionFactory()174     virtual ~TestScriptInstructionFactory() {}
175 };
176 
177 class UTestScriptEnv : public UScriptEnv {
178 public:
UTestScriptEnv(Hpackage::PkgManager::PkgManagerPtr pkgManager)179     explicit UTestScriptEnv(Hpackage::PkgManager::PkgManagerPtr pkgManager) : UScriptEnv(pkgManager) {}
~UTestScriptEnv()180     ~UTestScriptEnv()
181     {
182         if (factory_ != nullptr) {
183             delete factory_;
184             factory_ = nullptr;
185         }
186     }
187 
PostMessage(const std::string & cmd,std::string content)188     virtual void PostMessage(const std::string &cmd, std::string content) {}
189 
GetInstructionFactory()190     virtual UScriptInstructionFactoryPtr GetInstructionFactory()
191     {
192         if (factory_ == nullptr) {
193             factory_ = new TestScriptInstructionFactory();
194         }
195         return factory_;
196     }
197 
GetInstructionNames() const198     virtual const std::vector<std::string> GetInstructionNames() const
199     {
200         static std::vector<std::string> updaterCmds = {"sparse_image_write"};
201         return updaterCmds;
202     }
203 
IsRetry() const204     virtual bool IsRetry() const
205     {
206         return isRetry;
207     }
208 
GetPostmsgFunc()209     virtual PostMessageFunction GetPostmsgFunc()
210     {
211         return nullptr;
212     }
213     UScriptInstructionFactory *factory_ = nullptr;
214 private:
215     bool isRetry = false;
216 };
217 
218 class UScriptTest : public ::testing::Test {
219 public:
UScriptTest()220     UScriptTest() {}
221 
~UScriptTest()222     ~UScriptTest()
223     {
224         ScriptManager::ReleaseScriptManager();
225     }
226 
TestUscriptExecute()227     int TestUscriptExecute()
228     {
229         int32_t ret {};
230         TestPkgManager packageManager;
231         UTestScriptEnv *env = new UTestScriptEnv(&packageManager);
232         HashDataVerifier verifier {&packageManager};
233         verifier.LoadHashDataAndPkcs7(TEST_PATH_FROM + "updater_fake_pkg.zip");
234         ScriptManager *manager = ScriptManager::GetScriptManager(env, &verifier);
235         if (manager == nullptr) {
236             USCRIPT_LOGI("create manager fail ret:%d", ret);
237             delete env;
238             return USCRIPT_INVALID_SCRIPT;
239         }
240         int32_t priority = SCRIPT_TEST_PRIORITY_NUM;
241         ret = manager->ExecuteScript(priority);
242         EXPECT_EQ(ret, USCRIPT_ABOART);
243         USCRIPT_LOGI("ExecuteScript ret:%d", ret);
244         priority = 0;
245         ret = manager->ExecuteScript(priority);
246         priority = 1;
247         ret = manager->ExecuteScript(priority);
248         priority = SCRIPT_TEST_LAST_PRIORITY;
249         ret = manager->ExecuteScript(priority);
250         delete env;
251         ScriptManager::ReleaseScriptManager();
252         return ret;
253     }
254 
255 protected:
SetUp()256     void SetUp() {}
TearDown()257     void TearDown() {}
TestBody()258     void TestBody() {}
259 
260 private:
261     std::vector<std::string> testFileNames_ = {
262         "loadScript.us",
263         "registerCmd.us",
264         "test_function.us",
265         "test_if.us",
266         "test_logic.us",
267         "test_math.us",
268         "test_native.us",
269         "testscript.us",
270         "Verse-script.us",
271         "test_script.us"
272     };
273 };
274 
275 HWTEST_F(UScriptTest, TestUscriptExecute, TestSize.Level1)
276 {
277     UScriptTest test;
278     EXPECT_EQ(0, test.TestUscriptExecute());
279 }
280 }
281