/*
 * Copyright (C) 2023-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mdns_client_fuzzer.h"

#include <cstddef>
#include <cstdint>
#include <new>
#include <string>
#include <vector>

#include <securec.h>

#include "i_mdns_event.h"
#include "iservice_registry.h"
#include "message_parcel.h"
#include "net_manager_ext_constants.h"
#include "refbase.h"
#include "system_ability_definition.h"
#include "netmgr_ext_log_wrapper.h"
#include "mdns_protocol_impl.h"
#define private public
#include "mdns_client.h"
#include "mdns_service.h"
#undef private

namespace OHOS {
namespace NetManagerStandard {

// No-op registration callback stub; it only exists so a real remote object can be
// written into the request parcel / passed to MDnsClient.
class IRegistrationCallbackTest : public IRemoteStub<IRegistrationCallback> {
public:
    void HandleRegister(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
    void HandleUnRegister(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
    void HandleRegisterResult(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
};

// No-op discovery callback stub.
class IDiscoveryCallbackTest : public IRemoteStub<IDiscoveryCallback> {
public:
    void HandleStartDiscover(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
    void HandleStopDiscover(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
    void HandleServiceFound(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
    void HandleServiceLost(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
};

// No-op resolve callback stub.
class IResolveCallbackTest : public IRemoteStub<IResolveCallback> {
public:
    void HandleResolveResult(const MDnsServiceInfo &serviceInfo, int32_t retCode) override {}
};

// Cursor over the current fuzz input, consumed by GetData()/GetStringFromData().
static const uint8_t *g_baseFuzzData = nullptr;
static size_t g_baseFuzzSize = 0;
static size_t g_baseFuzzPos = 0;
static bool g_isInited = false;
static constexpr size_t STR_LEN = 10;

/**
 * Points the global cursor at a new fuzz input and rewinds it.
 * @return false when there is no input to consume.
 */
bool InitGlobalData(const uint8_t *data, size_t size)
{
    if (data == nullptr || size == 0) {
        return false;
    }
    g_baseFuzzData = data;
    g_baseFuzzSize = size;
    g_baseFuzzPos = 0;
    return true;
}

/**
 * Reads sizeof(T) raw bytes from the cursor and advances it.
 * Returns a value-initialized T (and does not advance) when too few bytes remain.
 */
template <class T>
T GetData()
{
    T object{};
    size_t objectSize = sizeof(object);
    if (g_baseFuzzData == nullptr || objectSize > g_baseFuzzSize - g_baseFuzzPos) {
        return object;
    }
    errno_t ret = memcpy_s(&object, objectSize, g_baseFuzzData + g_baseFuzzPos, objectSize);
    if (ret != EOK) {
        return {};
    }
    g_baseFuzzPos += objectSize;
    return object;
}

/**
 * Consumes (strLen - 1) chars from the cursor and returns them with C-string
 * semantics, i.e. truncated at the first NUL byte.
 * Uses std::string instead of a variable-length array: VLAs are not standard C++
 * and the old code wrote cstr[-1] (UB) for non-positive lengths.
 * @return "" when strLen <= 1.
 */
std::string GetStringFromData(int strLen)
{
    if (strLen <= 1) {
        return "";
    }
    std::string str;
    str.reserve(static_cast<size_t>(strLen - 1));
    for (int i = 0; i < strLen - 1; i++) {
        str.push_back(GetData<char>());
    }
    // Keep the historical behavior of constructing from a NUL-terminated buffer.
    const size_t nulPos = str.find('\0');
    if (nulPos != std::string::npos) {
        str.erase(nulPos);
    }
    return str;
}

bool WriteInterfaceToken(MessageParcel &data)
{
    return data.WriteInterfaceToken(IMDnsService::GetDescriptor());
}

/**
 * Builds an IPC request parcel: interface token followed by an MDnsServiceInfo
 * whose fields are drawn from the fuzz input.
 * @return false if the input is empty, allocation fails, or marshalling fails.
 */
__attribute__((no_sanitize("cfi"))) bool GetMessageParcel(const uint8_t *data, size_t size, MessageParcel &dataParcel)
{
    if (!InitGlobalData(data, size)) {
        return false;
    }
    if (!WriteInterfaceToken(dataParcel)) {
        return false;
    }
    sptr<MDnsServiceInfo> info = new (std::nothrow) MDnsServiceInfo();
    if (info == nullptr) {
        // new (std::nothrow) may fail; dereferencing it below would crash the harness.
        return false;
    }
    info->name = GetStringFromData(STR_LEN);
    info->type = GetStringFromData(STR_LEN);
    info->family = GetData<int32_t>();
    info->addr = GetStringFromData(STR_LEN);
    info->port = GetData<int32_t>();
    std::string str = GetStringFromData(STR_LEN);
    info->txtRecord = std::vector<uint8_t>(str.begin(), str.end());
    return MDnsServiceInfo::Marshalling(dataParcel, info);
}

// One-time initialization of the MDnsService singleton.
void Init()
{
    if (!g_isInited) {
        DelayedSingleton<MDnsService>::GetInstance()->Init();
        g_isInited = true;
    }
}

/**
 * Dispatches an IPC request directly into the MDnsService stub.
 * @return the stub's OnRemoteRequest result.
 */
__attribute__((no_sanitize("cfi"))) int32_t OnRemoteRequest(uint32_t code, MessageParcel &data)
{
    Init();  // no-op after the first call
    MessageParcel reply;
    MessageOption option;
    return DelayedSingleton<MDnsService>::GetInstance()->OnRemoteRequest(code, data, reply, option);
}

/**
 * Shared body of the IPC fuzz tests: builds the request parcel, appends a fresh
 * callback stub of type CallbackT as a remote object, and sends request `code`.
 */
template <typename CallbackT>
void SendRequestWithCallback(const uint8_t *data, size_t size, uint32_t code)
{
    MessageParcel dataParcel;
    if (!GetMessageParcel(data, size, dataParcel)) {
        return;
    }
    sptr<CallbackT> callback = new (std::nothrow) CallbackT();
    if (callback == nullptr) {
        return;
    }
    dataParcel.WriteRemoteObject(callback->AsObject().GetRefPtr());
    OnRemoteRequest(code, dataParcel);
}

void RegisterServiceFuzzTest(const uint8_t *data, size_t size)
{
    NETMGR_EXT_LOG_D("RegisterServiceFuzzTest enter");
    SendRequestWithCallback<IRegistrationCallbackTest>(
        data, size, static_cast<uint32_t>(MdnsServiceInterfaceCode::CMD_REGISTER));
}

void UnRegisterServiceFuzzTest(const uint8_t *data, size_t size)
{
    NETMGR_EXT_LOG_D("UnRegisterServiceFuzzTest enter");
    SendRequestWithCallback<IRegistrationCallbackTest>(
        data, size, static_cast<uint32_t>(MdnsServiceInterfaceCode::CMD_STOP_REGISTER));
}

void StartDiscoverServiceFuzzTest(const uint8_t *data, size_t size)
{
    NETMGR_EXT_LOG_D("StartDiscoverServiceFuzzTest enter");
    SendRequestWithCallback<IDiscoveryCallbackTest>(
        data, size, static_cast<uint32_t>(MdnsServiceInterfaceCode::CMD_DISCOVER));
}

void StopDiscoverServiceFuzzTest(const uint8_t *data, size_t size)
{
    NETMGR_EXT_LOG_D("StopDiscoverServiceFuzzTest enter");
    SendRequestWithCallback<IDiscoveryCallbackTest>(
        data, size, static_cast<uint32_t>(MdnsServiceInterfaceCode::CMD_STOP_DISCOVER));
}

void ResolveServiceFuzzTest(const uint8_t *data, size_t size)
{
    NETMGR_EXT_LOG_D("ResolveServiceFuzzTest enter");
    SendRequestWithCallback<IResolveCallbackTest>(
        data, size, static_cast<uint32_t>(MdnsServiceInterfaceCode::CMD_RESOLVE));
}

// Drives MDnsClient::RegisterService/UnRegisterService with the raw input as the service name.
void MdnsRegisterServiceFuzzTest(const uint8_t *data, size_t size)
{
    if (data == nullptr || size == 0) {
        return;
    }
    MDnsServiceInfo serviceInfo;
    std::string name(reinterpret_cast<const char *>(data), size);
    serviceInfo.name = name;
    serviceInfo.port = static_cast<int32_t>(size % STR_LEN);
    sptr<IRegistrationCallbackTest> callback = new (std::nothrow) IRegistrationCallbackTest();
    if (callback == nullptr) {
        return;
    }
    DelayedSingleton<MDnsClient>::GetInstance()->RegisterService(serviceInfo, callback);
    DelayedSingleton<MDnsClient>::GetInstance()->UnRegisterService(callback);
}

// Drives MDnsClient::StartDiscoverService/StopDiscoverService with the raw input as the service type.
void MdnsStartDiscoverServiceFuzzTest(const uint8_t *data, size_t size)
{
    if (data == nullptr || size == 0) {
        return;
    }
    sptr<IDiscoveryCallbackTest> callback = new (std::nothrow) IDiscoveryCallbackTest();
    if (callback == nullptr) {
        return;
    }
    std::string serviceType(reinterpret_cast<const char *>(data), size);
    DelayedSingleton<MDnsClient>::GetInstance()->StartDiscoverService(serviceType, callback);
    DelayedSingleton<MDnsClient>::GetInstance()->StopDiscoverService(callback);
}

// Drives MDnsClient::ResolveService (raw input as name) followed by RestartResume().
void MdnsResolveServiceFuzzTest(const uint8_t *data, size_t size)
{
    if (data == nullptr || size == 0) {
        return;
    }
    sptr<IResolveCallbackTest> callback = new (std::nothrow) IResolveCallbackTest();
    if (callback == nullptr) {
        return;
    }
    MDnsServiceInfo serviceInfo;
    std::string name(reinterpret_cast<const char *>(data), size);
    serviceInfo.port = static_cast<int32_t>(size % STR_LEN);
    serviceInfo.name = name;
    DelayedSingleton<MDnsClient>::GetInstance()->ResolveService(serviceInfo, callback);
    DelayedSingleton<MDnsClient>::GetInstance()->RestartResume();
}

/**
 * Feeds the whole fuzz input to the mDNS packet parser.
 * It previously parsed only 9 bytes pulled from the shared global cursor, i.e. whatever
 * the last GetMessageParcel() call left behind, so the bytes parsed depended on call
 * order rather than on `data`.
 */
void ReceivePacketTest(const uint8_t *data, size_t size)
{
    if (data == nullptr || size == 0) {
        return;
    }
    std::vector<uint8_t> packet(data, data + size);
    MDnsPayloadParser parser;
    MDnsMessage msg = parser.FromBytes(packet);
    (void)msg;
}
} // namespace NetManagerStandard
} // namespace OHOS

/* Fuzzer entry point */
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
{
    /* Run every fuzz target on the same input; each one guards against empty input itself. */
    OHOS::NetManagerStandard::RegisterServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::StartDiscoverServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::StopDiscoverServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::ResolveServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::UnRegisterServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::MdnsRegisterServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::MdnsStartDiscoverServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::MdnsResolveServiceFuzzTest(data, size);
    OHOS::NetManagerStandard::ReceivePacketTest(data, size);
    return 0;
}