/*
 * Copyright (C) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mdns_packet_parser.h"

#include "netmgr_ext_log_wrapper.h"

#include <cstring>

namespace OHOS {
namespace NetManagerStandard {

namespace {

constexpr size_t MDNS_STR_INITIAL_SIZE = 16;
// Top two bits set in a length octet mark a compression pointer.
constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xc0;
constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xc000;
// Mask applied to a label length before it is written as the length octet.
constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3f;
// Largest payload offset that fits in the bits left over by DNS_STR_PTR_U16_MASK.
constexpr uint16_t DNS_STR_PTR_MAX_OFFSET = 0x3fff;
constexpr uint8_t DNS_STR_EOL = '\0';

// Appends the object representation of `data` (sizeof(T) bytes) to `payload`.
template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
{
    const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
    payload.insert(payload.end(), begin, begin + sizeof(T));
}

// Copies the object representation of `data` (sizeof(T) bytes) to `ptr`.
// The caller must guarantee that `ptr` addresses at least sizeof(T) writable bytes.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    std::memcpy(ptr, &data, sizeof(T));
}

// Reads sizeof(T) bytes at `raw` into `data` and returns the position just past them.
// memcpy is used instead of dereferencing a casted pointer because `raw` points
// into a byte buffer at an arbitrary offset and may not be aligned for T.
// The caller must guarantee that sizeof(T) bytes are readable at `raw`.
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}

// Reads a 16-bit value at `raw`, converts it with ntohs, and returns the position after it.
const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
{
    const uint8_t *next = ReadRawData(raw, data);
    data = ntohs(data);
    return next;
}

// Reads a 32-bit value at `raw`, converts it with ntohl, and returns the position after it.
const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
{
    const uint8_t *next = ReadRawData(raw, data);
    data = ntohl(data);
    return next;
}

// Returns `name` without its last character when it ends with MDNS_DOMAIN_SPLITER_STR,
// otherwise returns `name` unchanged.
std::string UnDotted(const std::string &name)
{
    return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
}

} // namespace

// Parses `payload` into a message. Error flags are reset first; the final read
// position is stored in pos_. Callers should inspect GetError() afterwards.
MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
{
    MDnsMessage msg;
    errorFlags_ = PARSE_OK;
    pos_ = Parse(payload.data(), payload, msg);
    return msg;
}

// Serializes `msg` into a freshly created payload, using a name cache that is
// local to this call.
MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
{
    MDnsPayload payload;
    MDnsPayload *cachedPayload = &payload;
    std::map<std::string, uint16_t> strCacheMap;
    Serialize(msg, payload, cachedPayload, strCacheMap);
    return payload;
}

// Parses the header followed by the question, answer, authority and additional
// sections, stopping at the first step that raises a PARSE_ERROR flag.
// Returns the read position reached.
const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
{
    auto failed = [this]() { return (errorFlags_ & PARSE_ERROR) != PARSE_OK; };

    begin = ParseHeader(begin, payload, msg.header);
    if (failed()) {
        return begin;
    }
    for (int i = 0; i < msg.header.qdcount; ++i) {
        begin = ParseQuestion(begin, payload, msg.questions);
        if (failed()) {
            return begin;
        }
    }
    for (int i = 0; i < msg.header.ancount; ++i) {
        begin = ParseRR(begin, payload, msg.answers);
        if (failed()) {
            return begin;
        }
    }
    for (int i = 0; i < msg.header.nscount; ++i) {
        begin = ParseRR(begin, payload, msg.authorities);
        if (failed()) {
            return begin;
        }
    }
    for (int i = 0; i < msg.header.arcount; ++i) {
        begin = ParseRR(begin, payload, msg.additional);
        if (failed()) {
            return begin;
        }
    }
    return begin;
}

// Reads the six 16-bit header fields. Sets PARSE_ERROR_BAD_SIZE and returns
// `begin` unchanged when fewer than sizeof(DNSProto::Header) bytes remain.
const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
                                              DNSProto::Header &header)
{
    const uint8_t *end = payload.data() + payload.size();
    if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
        return begin;
    }

    begin = ReadNUint16(begin, header.id);
    begin = ReadNUint16(begin, header.flags);
    begin = ReadNUint16(begin, header.qdcount);
    begin = ReadNUint16(begin, header.ancount);
    begin = ReadNUint16(begin, header.nscount);
    begin = ReadNUint16(begin, header.arcount);
    return begin;
}

// Parses one question (name, qtype, qclass) and appends it to `questions`.
// On error the partially filled entry is removed again.
const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
                                                std::vector<DNSProto::Question> &questions)
{
    questions.emplace_back();
    begin = ParseDnsString(begin, payload, questions.back().name);
    questions.back().name = UnDotted(questions.back().name);
    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
        questions.pop_back();
        return begin;
    }

    const uint8_t *end = payload.data() + payload.size();
    if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
        questions.pop_back();
        return begin;
    }

    begin = ReadNUint16(begin, questions.back().qtype);
    begin = ReadNUint16(begin, questions.back().qclass);
    return begin;
}

// Parses one resource record (name, rtype, rclass, ttl, length, rdata) and
// appends it to `answers`. The entry is removed again when the name or the
// fixed-size fields cannot be read; errors raised while parsing rdata leave
// the entry in place.
const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
                                          std::vector<DNSProto::ResourceRecord> &answers)
{
    answers.emplace_back();
    begin = ParseDnsString(begin, payload, answers.back().name);
    answers.back().name = UnDotted(answers.back().name);
    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
        answers.pop_back();
        return begin;
    }

    const uint8_t *end = payload.data() + payload.size();
    if (static_cast<ssize_t>(end - begin) <
        static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
        answers.pop_back();
        return begin;
    }

    begin = ReadNUint16(begin, answers.back().rtype);
    begin = ReadNUint16(begin, answers.back().rclass);
    begin = ReadNUint32(begin, answers.back().ttl);
    begin = ReadNUint16(begin, answers.back().length);
    return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
}

// Decodes the rdata of a record of the given `type` into `data`:
//   A -> in_addr, AAAA -> in6_addr, PTR -> std::string,
//   SRV -> DNSProto::RDataSrv, TXT -> TxtRecordEncoded.
// Any other type sets PARSE_WARNING_BAD_RRTYPE and skips `length` bytes.
const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
                                             std::any &data)
{
    switch (type) {
        case DNSProto::RRTYPE_A: {
            const uint8_t *end = payload.data() + payload.size();
            if ((static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(in_addr))) ||
                (length != sizeof(in_addr))) {
                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
                return begin;
            }
            in_addr addr;
            begin = ReadRawData(begin, addr);
            data = addr;
            return begin;
        }
        case DNSProto::RRTYPE_AAAA: {
            const uint8_t *end = payload.data() + payload.size();
            if ((static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(in6_addr))) ||
                (length != sizeof(in6_addr))) {
                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
                return begin;
            }
            in6_addr addr;
            begin = ReadRawData(begin, addr);
            data = addr;
            return begin;
        }
        case DNSProto::RRTYPE_PTR: {
            std::string str;
            begin = ParseDnsString(begin, payload, str);
            str = UnDotted(str);
            if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
                return begin;
            }
            data = str;
            return begin;
        }
        case DNSProto::RRTYPE_SRV: {
            return ParseSrv(begin, payload, data);
        }
        case DNSProto::RRTYPE_TXT: {
            return ParseTxt(begin, payload, length, data);
        }
        default: {
            errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
            // Never step past the end of the payload: a declared length larger
            // than the remaining bytes means the record is truncated.
            const uint8_t *end = payload.data() + payload.size();
            if (end - begin < length) {
                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
                return begin;
            }
            return begin + length;
        }
    }
}

// Parses SRV rdata: priority, weight and port (16 bits each) followed by a name.
// `data` is assigned only when no parse error has been raised.
const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
{
    const uint8_t *end = payload.data() + payload.size();
    if (static_cast<ssize_t>(end - begin) <
        static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
        return begin;
    }

    DNSProto::RDataSrv srv;
    begin = ReadNUint16(begin, srv.priority);
    begin = ReadNUint16(begin, srv.weight);
    begin = ReadNUint16(begin, srv.port);
    begin = ParseDnsString(begin, payload, srv.name);
    srv.name = UnDotted(srv.name);
    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
        return begin;
    }
    data = srv;
    return begin;
}

// Copies `length` raw bytes of TXT rdata into `data` as a TxtRecordEncoded.
// Sets PARSE_ERROR_BAD_SIZE when fewer than `length` bytes remain.
const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length,
                                           std::any &data)
{
    const uint8_t *end = payload.data() + payload.size();
    if (end - begin < length) {
        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
        return begin;
    }

    data = TxtRecordEncoded(begin, begin + length);
    return begin + length;
}

// Decodes a (possibly compressed) name starting at `begin` and appends it to
// `str`, each label followed by MDNS_DOMAIN_SPLITER.
// Returns the position after the terminating zero octet, or after the two-octet
// compression pointer when the name ends with one.
// If the payload ends before a terminator is seen, the loop exits and the
// current position is returned without raising an error flag.
const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
{
    const uint8_t *end = payload.data() + payload.size();
    const uint8_t *p = begin;
    str.reserve(MDNS_STR_INITIAL_SIZE);
    while (p && p < end) {
        if (*p == 0) {
            return p + 1;
        }
        if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
            // Plain label: one length octet followed by that many characters.
            str.append(reinterpret_cast<const char *>(p) + 1, *p);
            str.push_back(MDNS_DOMAIN_SPLITER);
            p += (*p + 1);
        } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
            // Compression pointer: the remaining 14 bits are an offset into the payload.
            if (end - p < static_cast<int>(sizeof(uint16_t))) {
                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
                return begin;
            }
            uint16_t offset;
            const uint8_t *afterPointer = ReadNUint16(p, offset);
            offset = static_cast<uint16_t>(offset & ~DNS_STR_PTR_U16_MASK);
            const uint8_t *next = payload.data() + offset;
            // Only pointers to a position before this name are followed, so every
            // nested call starts strictly earlier and the recursion terminates.
            if (next >= end || next >= begin) {
                errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
                return begin;
            }
            ParseDnsString(next, payload, str);
            if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
                return begin;
            }
            return afterPointer;
        } else {
            errorFlags_ |= PARSE_ERROR_BAD_STR;
            return p;
        }
    }
    return p;
}

// Writes the header and all four record sections of `msg` to `payload`.
// The section counts in the written header are taken from the vector sizes of
// `msg`, not from msg.header.
void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload, MDnsPayload *cachedPayload,
                                  std::map<std::string, uint16_t> &strCacheMap)
{
    payload.reserve(sizeof(DNSProto::Message));
    DNSProto::Header header = msg.header;
    header.qdcount = msg.questions.size();
    header.ancount = msg.answers.size();
    header.nscount = msg.authorities.size();
    header.arcount = msg.additional.size();
    SerializeHeader(header, msg, payload);
    for (uint16_t i = 0; i < header.qdcount; ++i) {
        SerializeQuestion(msg.questions[i], payload, cachedPayload, strCacheMap);
    }
    for (uint16_t i = 0; i < header.ancount; ++i) {
        SerializeRR(msg.answers[i], payload, cachedPayload, strCacheMap);
    }
    for (uint16_t i = 0; i < header.nscount; ++i) {
        SerializeRR(msg.authorities[i], payload, cachedPayload, strCacheMap);
    }
    for (uint16_t i = 0; i < header.arcount; ++i) {
        SerializeRR(msg.additional[i], payload, cachedPayload, strCacheMap);
    }
}

// Appends the six header fields to `payload`, each converted with htons.
// `msg` is not used by this function.
void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg,
                                        MDnsPayload &payload)
{
    WriteRawData(htons(header.id), payload);
    WriteRawData(htons(header.flags), payload);
    WriteRawData(htons(header.qdcount), payload);
    WriteRawData(htons(header.ancount), payload);
    WriteRawData(htons(header.nscount), payload);
    WriteRawData(htons(header.arcount), payload);
}

// Appends one question: encoded name, qtype, qclass.
void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, MDnsPayload &payload,
                                          MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
{
    SerializeDnsString(question.name, payload, cachedPayload, strCacheMap);
    WriteRawData(htons(question.qtype), payload);
    WriteRawData(htons(question.qclass), payload);
}

// Appends one resource record. The rdata length field is first written from
// rr.length as a placeholder and then overwritten with the number of bytes
// SerializeRData actually produced.
void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, MDnsPayload &payload,
                                    MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
{
    SerializeDnsString(rr.name, payload, cachedPayload, strCacheMap);
    WriteRawData(htons(rr.rtype), payload);
    WriteRawData(htons(rr.rclass), payload);
    WriteRawData(htonl(rr.ttl), payload);
    size_t lenStart = payload.size();
    WriteRawData(htons(rr.length), payload);
    SerializeRData(rr.rdata, payload, cachedPayload, strCacheMap);
    uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
    WriteRawData(htons(len), payload.data() + lenStart);
}

// Appends the rdata held in `rdata`, dispatching on the stored type
// (in_addr, in6_addr, std::string, DNSProto::RDataSrv, TxtRecordEncoded).
// Nothing is written when `rdata` holds any other type or is empty.
void MDnsPayloadParser::SerializeRData(const std::any &rdata, MDnsPayload &payload, MDnsPayload *cachedPayload,
                                       std::map<std::string, uint16_t> &strCacheMap)
{
    if (const auto *addr4 = std::any_cast<const in_addr>(&rdata)) {
        WriteRawData(*addr4, payload);
    } else if (const auto *addr6 = std::any_cast<const in6_addr>(&rdata)) {
        WriteRawData(*addr6, payload);
    } else if (const auto *name = std::any_cast<const std::string>(&rdata)) {
        SerializeDnsString(*name, payload, cachedPayload, strCacheMap);
    } else if (const auto *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
        WriteRawData(htons(srv->priority), payload);
        WriteRawData(htons(srv->weight), payload);
        WriteRawData(htons(srv->port), payload);
        SerializeDnsString(srv->name, payload, cachedPayload, strCacheMap);
    } else if (const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata)) {
        payload.insert(payload.end(), txt->begin(), txt->end());
    }
}

// Appends `str` as a sequence of length-prefixed labels terminated by DNS_STR_EOL.
// When `cachedPayload` is `payload` itself and the remaining suffix was written
// before, a two-octet compression pointer to that earlier occurrence is written
// instead and the function returns. Each newly written suffix is recorded in
// `strCacheMap` with its pointer value.
void MDnsPayloadParser::SerializeDnsString(const std::string &str, MDnsPayload &payload, MDnsPayload *cachedPayload,
                                           std::map<std::string, uint16_t> &strCacheMap)
{
    size_t pos = 0;
    while (pos < str.size()) {
        const std::string suffix = str.substr(pos);
        if (cachedPayload == &payload) {
            auto cached = strCacheMap.find(suffix);
            if (cached != strCacheMap.end()) {
                WriteRawData(htons(cached->second), payload);
                return;
            }
        }

        size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
        if (nextDot == std::string::npos) {
            nextDot = str.size();
        }
        uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
        const size_t labelStart = payload.size();
        WriteRawData(segLen, payload);
        for (int i = 0; i < segLen; ++i) {
            WriteRawData(str[pos + i], payload);
        }
        // A compression pointer has only 14 bits of offset; a label that starts
        // beyond that cannot be referenced, so it must not be cached.
        if (labelStart <= DNS_STR_PTR_MAX_OFFSET) {
            strCacheMap[suffix] = static_cast<uint16_t>(labelStart) | DNS_STR_PTR_U16_MASK;
        }
        pos = nextDot + 1;
    }
    WriteRawData(DNS_STR_EOL, payload);
}

// Returns the PARSE_ERROR bits of the current error flags (warning bits are masked out).
uint32_t MDnsPayloadParser::GetError() const
{
    return errorFlags_ & PARSE_ERROR;
}

} // namespace NetManagerStandard
} // namespace OHOS