1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "shader_module_gles.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 
21 #include <base/containers/array_view.h>
22 #include <base/containers/fixed_string.h>
23 #include <base/containers/iterator.h>
24 #include <base/containers/string.h>
25 #include <base/containers/string_view.h>
26 #include <base/containers/type_traits.h>
27 #include <base/math/vector.h>
28 #include <render/device/pipeline_layout_desc.h>
29 #include <render/namespace.h>
30 
31 #include "device/gpu_program_util.h"
32 #include "device/shader_manager.h"
33 #include "gles/spirv_cross_helpers_gles.h"
34 #include "util/log.h"
35 
36 using namespace BASE_NS;
37 
38 RENDER_BEGIN_NAMESPACE()
39 namespace {
40 template<typename SetType>
Collect(const uint32_t set,const DescriptorSetLayoutBinding & binding,SetType & sets)41 void Collect(const uint32_t set, const DescriptorSetLayoutBinding& binding, SetType& sets)
42 {
43     const auto name = "s" + to_string(set) + "_b" + to_string(binding.binding);
44     sets.push_back({ static_cast<uint8_t>(set), static_cast<uint8_t>(binding.binding),
45         static_cast<uint8_t>(binding.descriptorCount), string { name } });
46 }
47 
CollectRes(const PipelineLayout & pipeline,ShaderModulePlatformDataGLES & plat_)48 void CollectRes(const PipelineLayout& pipeline, ShaderModulePlatformDataGLES& plat_)
49 {
50     struct Bind {
51         uint8_t set;
52         uint8_t bind;
53     };
54     vector<Bind> samplers;
55     vector<Bind> images;
56     for (const auto& set : pipeline.descriptorSetLayouts) {
57         if (set.set != PipelineLayoutConstants::INVALID_INDEX) {
58             for (const auto& binding : set.bindings) {
59                 switch (binding.descriptorType) {
60                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLER:
61                         samplers.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
62                         break;
63                     case DescriptorType::CORE_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
64                         Collect(set.set, binding, plat_.cbSets);
65                         break;
66                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
67                         images.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
68                         break;
69                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_IMAGE:
70                         Collect(set.set, binding, plat_.ciSets);
71                         break;
72                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
73                         break;
74                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
75                         break;
76                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
77                         Collect(set.set, binding, plat_.ubSets);
78                         break;
79                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER:
80                         Collect(set.set, binding, plat_.sbSets);
81                         break;
82                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
83                         break;
84                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
85                         break;
86                     case DescriptorType::CORE_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
87                         Collect(set.set, binding, plat_.siSets);
88                         break;
89                     case DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE:
90                         break;
91                     case DescriptorType::CORE_DESCRIPTOR_TYPE_MAX_ENUM:
92                         break;
93                 }
94             }
95         }
96     }
97     for (const auto& sBinding : samplers) {
98         for (const auto& iBinding : images) {
99             const auto name = "s" + to_string(iBinding.set) + "_b" + to_string(iBinding.bind) + "_s" +
100                               to_string(sBinding.set) + "_b" + to_string(sBinding.bind);
101             plat_.combSets.push_back({ sBinding.set, sBinding.bind, iBinding.set, iBinding.bind, string { name } });
102         }
103     }
104 }
105 
CreateSpecInfos(array_view<const ShaderSpecialization::Constant> constants,vector<Gles::SpecConstantInfo> & outSpecInfo)106 void CreateSpecInfos(
107     array_view<const ShaderSpecialization::Constant> constants, vector<Gles::SpecConstantInfo>& outSpecInfo)
108 {
109     static_assert(static_cast<uint32_t>(Gles::SpecConstantInfo::Types::BOOL) ==
110                   static_cast<uint32_t>(ShaderSpecialization::Constant::Type::BOOL));
111     for (const auto& constant : constants) {
112         Gles::SpecConstantInfo info { static_cast<Gles::SpecConstantInfo::Types>(constant.type), constant.id, 1U, 1U,
113             {} };
114         outSpecInfo.push_back(info);
115     }
116 }
117 
SortSets(PipelineLayout & pipelineLayout)118 void SortSets(PipelineLayout& pipelineLayout)
119 {
120     pipelineLayout.descriptorSetCount = 0;
121     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
122         DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
123         if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
124             pipelineLayout.descriptorSetCount++;
125             std::sort(currSet.bindings.begin(), currSet.bindings.end(),
126                 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
127         }
128     }
129 }
130 } // namespace
namespace {
/**
 * Minimal little-endian cursor over a reflection byte blob.
 * Each Get* call decodes a value at the current position and advances the cursor past it.
 * Kept in an unnamed namespace so this file-local helper has internal linkage and cannot clash
 * (ODR) with a same-named type in another translation unit of the same namespace.
 * NOTE: no bounds checking is performed; the caller must guarantee that the blob holds every byte read.
 */
struct Reader {
    const uint8_t* ptr;

    /** Reads one byte and advances by 1. */
    uint8_t GetUint8()
    {
        return *ptr++;
    }

    /** Reads a little-endian 16-bit unsigned value and advances by 2. */
    uint16_t GetUint16()
    {
        const uint16_t value =
            static_cast<uint16_t>(static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U));
        ptr += sizeof(uint16_t);
        return value;
    }

    /** Reads a little-endian 32-bit unsigned value and advances by 4. */
    uint32_t GetUint32()
    {
        // Widen each byte to uint32_t before shifting. Shifting the promoted (signed int) byte >= 0x80
        // left by 24 is not representable in int: UB before C++14, implementation-defined until C++20.
        const uint32_t value = static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U) |
                               (static_cast<uint32_t>(ptr[2]) << 16U) | (static_cast<uint32_t>(ptr[3]) << 24U);
        ptr += sizeof(uint32_t);
        return value;
    }

    /**
     * Reads a u16 length prefix followed by that many bytes.
     * @return A view into the blob itself (no copy); it is valid only as long as the blob is.
     */
    string_view GetStringView()
    {
        const uint16_t len = GetUint16();
        const string_view value(static_cast<const char*>(static_cast<const void*>(ptr)), len);
        ptr += len;
        return value;
    }
};
} // namespace
160 template<typename ShaderBase>
ProcessShaderModule(ShaderBase & me,const ShaderModuleCreateInfo & createInfo)161 void ProcessShaderModule(ShaderBase& me, const ShaderModuleCreateInfo& createInfo)
162 {
163     me.pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
164     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_VERTEX_BIT) {
165         me.vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
166         me.vertexInputBindingDescriptions_.reserve(me.vertexInputAttributeDescriptions_.size());
167         for (const auto& attrib : me.vertexInputAttributeDescriptions_) {
168             VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
169             bindingDesc.binding = attrib.binding;
170             bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
171             bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
172             me.vertexInputBindingDescriptions_.push_back(bindingDesc);
173         }
174         me.vidv_.bindingDescriptions = { me.vertexInputBindingDescriptions_.data(),
175             me.vertexInputBindingDescriptions_.size() };
176         me.vidv_.attributeDescriptions = { me.vertexInputAttributeDescriptions_.data(),
177             me.vertexInputAttributeDescriptions_.size() };
178     }
179 
180     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_COMPUTE_BIT) {
181         const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
182         me.stg_.x = tgs.x;
183         me.stg_.y = tgs.y;
184         me.stg_.z = tgs.z;
185     }
186     if (auto* ptr = createInfo.reflectionData.GetPushConstants(); ptr) {
187         Reader read { ptr };
188         const auto constants = read.GetUint8();
189         for (uint8_t i = 0U; i < constants; ++i) {
190             Gles::PushConstantReflection refl;
191             refl.type = read.GetUint32();
192             refl.offset = read.GetUint16();
193             refl.size = read.GetUint16();
194             refl.arraySize = read.GetUint16();
195             refl.arrayStride = read.GetUint16();
196             refl.matrixStride = read.GetUint16();
197             refl.name = "CORE_PC_0";
198             refl.name += read.GetStringView();
199             refl.stage = me.shaderStageFlags_;
200             me.plat_.infos.push_back(move(refl));
201         }
202     }
203 
204     me.constants_ = createInfo.reflectionData.GetSpecializationConstants();
205     me.sscv_.constants = { me.constants_.data(), me.constants_.size() };
206     CollectRes(me.pipelineLayout_, me.plat_);
207     CreateSpecInfos(me.constants_, me.specInfo_);
208     // sort bindings inside sets (and count them)
209     SortSets(me.pipelineLayout_);
210 
211     me.source_.assign(
212         static_cast<const char*>(static_cast<const void*>(createInfo.spvData.data())), createInfo.spvData.size());
213 }
214 
215 template<typename ShaderBase>
SpecializeShaderModule(const ShaderBase & base,const ShaderSpecializationConstantDataView & specData)216 string SpecializeShaderModule(const ShaderBase& base, const ShaderSpecializationConstantDataView& specData)
217 {
218     return Gles::Specialize(base.shaderStageFlags_, base.source_, base.constants_, specData);
219 }
220 
ShaderModuleGLES(Device & device,const ShaderModuleCreateInfo & createInfo)221 ShaderModuleGLES::ShaderModuleGLES(Device& device, const ShaderModuleCreateInfo& createInfo)
222     : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
223 {
224     if (createInfo.reflectionData.IsValid() &&
225         (shaderStageFlags_ &
226             (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT))) {
227         ProcessShaderModule(*this, createInfo);
228     } else {
229         PLUGIN_LOG_E("invalid shader stages or invalid reflection data for shader module, invalid shader module");
230     }
231 }
232 
233 ShaderModuleGLES::~ShaderModuleGLES() = default;
234 
GetShaderStageFlags() const235 ShaderStageFlags ShaderModuleGLES::GetShaderStageFlags() const
236 {
237     return shaderStageFlags_;
238 }
239 
GetGLSL(const ShaderSpecializationConstantDataView & specData) const240 string ShaderModuleGLES::GetGLSL(const ShaderSpecializationConstantDataView& specData) const
241 {
242     return SpecializeShaderModule(*this, specData);
243 }
244 
GetPlatformData() const245 const ShaderModulePlatformData& ShaderModuleGLES::GetPlatformData() const
246 {
247     return plat_;
248 }
249 
GetPipelineLayout() const250 const PipelineLayout& ShaderModuleGLES::GetPipelineLayout() const
251 {
252     return pipelineLayout_;
253 }
254 
GetSpecilization() const255 ShaderSpecializationConstantView ShaderModuleGLES::GetSpecilization() const
256 {
257     return sscv_;
258 }
259 
GetVertexInputDeclaration() const260 VertexInputDeclarationView ShaderModuleGLES::GetVertexInputDeclaration() const
261 {
262     return vidv_;
263 }
264 
GetThreadGroupSize() const265 ShaderThreadGroup ShaderModuleGLES::GetThreadGroupSize() const
266 {
267     return stg_;
268 }
269 RENDER_END_NAMESPACE()
270