1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "spirv_cross_helpers_gles.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>

#include <glcorearb.h>
20 
21 namespace Gles {
22 namespace {
23 static const spirv_cross::SPIRConstant invalid {};
24 
FindConstant(const std::vector<PushConstantReflection> & reflections,const PushConstantReflection & reflection)25 int32_t FindConstant(const std::vector<PushConstantReflection>& reflections, const PushConstantReflection& reflection)
26 {
27     for (size_t i = 0; i < reflections.size(); i++) {
28         if (reflection.name == reflections[i].name) {
29             // Check that it's actually same and not a conflict!.
30             if (reflection.type != reflections[i].type) {
31                 return -2;
32             }
33             if (reflection.offset != reflections[i].offset) {
34                 return -2;
35             }
36             if (reflection.size != reflections[i].size) {
37                 return -2;
38             }
39             if (reflection.arraySize != reflections[i].arraySize) {
40                 return -2;
41             }
42             if (reflection.arrayStride != reflections[i].arrayStride) {
43                 return -2;
44             }
45             if (reflection.matrixStride != reflections[i].matrixStride) {
46                 return -2;
47             }
48             return (int32_t)i;
49         }
50     }
51     return -1;
52 }
53 
// Finds a constant recorded in the compiler's IR whose debug name equals `name`.
// Returns the compiler's SPIRConstant for the first match, or the `invalid` sentinel when nothing matches
// (callers compare `.self` with `invalid.self`).
// NOTE: assumes `compiler` really is a CoreCompiler; GetConstants() is only reachable through that subclass,
// and the downcast is unchecked.
const spirv_cross::SPIRConstant& ConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    // GetConstants() is a const member, so a const-preserving downcast is sufficient. The previous C-style cast
    // also stripped const without need.
    const auto& core = static_cast<const CoreCompiler&>(compiler);
    const auto constants = core.GetConstants();
    for (const auto& c : constants) {
        if (compiler.get_name(c.self) == name) {
            return compiler.get_constant(c.self);
        }
    }
    return invalid;
}
67 
// Looks up a specialization constant by its debug name.
// Returns the compiler's SPIRConstant for the first match, or the `invalid` sentinel when no specialization
// constant is named `name`.
const spirv_cross::SPIRConstant& SpecConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    const auto specConstants = compiler.get_specialization_constants();
    for (const auto& specConstant : specConstants) {
        if (compiler.get_name(specConstant.id) != name) {
            continue;
        }
        return compiler.get_constant(specConstant.id);
    }
    return invalid;
}
81 } // namespace
82 
83 // inherit from CompilerGLSL to have better access
CoreCompiler(const uint32_t * ir,size_t wordCount)84 CoreCompiler::CoreCompiler(const uint32_t* ir, size_t wordCount) : CompilerGLSL(ir, wordCount) {}
85 
GetConstants() const86 const std::vector<spirv_cross::SPIRConstant> CoreCompiler::GetConstants() const
87 {
88     std::vector<spirv_cross::SPIRConstant> consts;
89     ir.for_each_typed_id<spirv_cross::SPIRConstant>(
90         [&consts](uint32_t, const spirv_cross::SPIRConstant& c) { consts.push_back(c); });
91     return consts;
92 }
93 
GetIr() const94 const spirv_cross::ParsedIR& CoreCompiler::GetIr() const
95 {
96     return ir;
97 }
98 
ReflectPushConstants(spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)99 void ReflectPushConstants(spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
100     std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
101 {
102     char ids[64];
103     int id = 0;
104     // There can be only one push_constant_buffer, but since spirv-cross has prepared for this to be relaxed, we will
105     // too.
106     std::string name = "CORE_PC_00";
107     for (auto& remap : resources.push_constant_buffers) {
108         const auto& blockType = compiler.get_type(remap.base_type_id);
109         (void)(blockType);
110         sprintf(ids, "%d", id);
111         name.resize(8);
112         name.append(ids);
113         compiler.set_name(remap.id, name);
114         assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
115         ProcessStruct(std::string_view(name.data(), name.size()), 0, compiler, remap.base_type_id, reflections, stage);
116         id++;
117     }
118 }
119 
120 // Converts specialization constant to normal constant, (to reduce unnecessary clutter in glsl)
ConvertSpecConstToConstant(spirv_cross::CompilerGLSL & compiler,const char * name)121 void ConvertSpecConstToConstant(spirv_cross::CompilerGLSL& compiler, const char* name)
122 {
123     const auto& c = SpecConstByName(compiler, name);
124     if (c.self == invalid.self) {
125         return;
126     }
127     compiler.unset_decoration(c.self, spv::Decoration::DecorationSpecId);
128 }
129 
130 // Converts constant declaration to uniform. (actually only works on spec constants)
ConvertConstantToUniform(const spirv_cross::CompilerGLSL & compiler,std::string & source,const char * name)131 void ConvertConstantToUniform(const spirv_cross::CompilerGLSL& compiler, std::string& source, const char* name)
132 {
133     std::string tmp;
134     tmp.reserve(strlen(name) + 16);
135     const auto& constant = ConstByName(compiler, name);
136     if (constant.self == invalid.self) {
137         return;
138     }
139     const auto& type = compiler.get_type(constant.constant_type);
140     if (type.basetype == spirv_cross::SPIRType::Boolean) {
141         tmp += "const bool ";
142     } else if (type.basetype == spirv_cross::SPIRType::UInt) {
143         tmp += "const uint ";
144     } else if (type.basetype == spirv_cross::SPIRType::Int) {
145         tmp += "const int ";
146     } else if (type.basetype == spirv_cross::SPIRType::Float) {
147         tmp += "const float ";
148     } else {
149         assert(false && "Unhandled specialization constant type");
150     }
151     // We expect spirv_cross to generate them with certain pattern..
152     tmp += name;
153     tmp += " =";
154     const auto p = source.find(tmp);
155     if (p != std::string::npos) {
156         // found it, change it. (changes const to uniform)
157         auto bi = source.begin() + (int64_t)p;
158         auto ei = bi + 6;
159         source.replace(bi, ei, "uniform ");
160 
161         // remove the initializer..
162         const auto p2 = source.find('=', p);
163         const auto p3 = source.find(';', p);
164         if ((p2 != std::string::npos) && (p3 != std::string::npos)) {
165             if (p2 < p3) {
166                 // should be correct (tm)
167                 bi = source.begin() + (int64_t)p2;
168                 ei = source.begin() + (int64_t)p3;
169                 source.erase(bi, ei);
170             }
171         }
172     }
173 }
174 
SetSpecMacro(spirv_cross::CompilerGLSL & compiler,const char * name,uint32_t value)175 void SetSpecMacro(spirv_cross::CompilerGLSL& compiler, const char* name, uint32_t value)
176 {
177     const auto& vc = SpecConstByName(compiler, name);
178     if (vc.self != invalid.self) {
179         const uint32_t constantId = compiler.get_decoration(vc.self, spv::Decoration::DecorationSpecId);
180         char buf[1024];
181         sprintf(buf, "#define SPIRV_CROSS_CONSTANT_ID_%u %du", constantId, value);
182         compiler.add_header_line(buf);
183     }
184 }
185 
ProcessStruct(std::string_view baseName,size_t baseOffset,const spirv_cross::Compiler & compiler,uint32_t structTypeId,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)186 void ProcessStruct(std::string_view baseName, size_t baseOffset, const spirv_cross::Compiler& compiler,
187     uint32_t structTypeId, std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
188 {
189     const auto& structType = compiler.get_type(structTypeId);
190     reflections.reserve(reflections.size() + structType.member_types.size());
191     for (uint32_t bi = 0; bi < structType.member_types.size(); bi++) {
192         const uint32_t memberTypeId = structType.member_types[bi];
193         const auto& memberType = compiler.get_type(memberTypeId);
194         const auto& name = compiler.get_member_name(structTypeId, bi);
195 
196         PushConstantReflection t;
197         t.stage = stage;
198         t.name = baseName;
199         t.name += '.';
200         t.name += name;
201         t.arrayStride = 0;
202         t.matrixStride = 0;
203         // Get member offset within this struct.
204         t.offset = baseOffset + compiler.type_struct_member_offset(structType, bi);
205         t.size = compiler.get_declared_struct_member_size(structType, bi);
206         t.arraySize = 0;
207         if (!memberType.array.empty()) {
208             // Get array stride, e.g. float4 foo[]; Will have array stride of 16 bytes.
209             t.arrayStride = compiler.type_struct_member_array_stride(structType, bi);
210             t.arraySize = memberType.array[0]; // We don't support arrays of arrays. just use the size of first.
211         }
212 
213         if (memberType.columns > 1) {
214             // Get bytes stride between columns (if column major), for float4x4 -> 16 bytes.
215             t.matrixStride = compiler.type_struct_member_matrix_stride(structType, bi);
216         }
217 
218         switch (memberType.basetype) {
219             case spirv_cross::SPIRType::Struct:
220                 ProcessStruct(t.name, t.offset, compiler, memberTypeId, reflections, stage);
221                 continue;
222                 break;
223             case spirv_cross::SPIRType::UInt: {
224                 constexpr GLenum type[5][5] = { { 0, 0, 0, 0, 0 }, { 0, GL_UNSIGNED_INT, 0, 0, 0 },
225                     { 0, GL_UNSIGNED_INT_VEC2, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC3, 0, 0, 0 },
226                     { 0, GL_UNSIGNED_INT_VEC4, 0, 0, 0 } };
227                 t.type = type[memberType.vecsize][memberType.columns];
228                 break;
229             }
230             case spirv_cross::SPIRType::Float: {
231                 constexpr GLenum type[5][5] = {
232                     { 0, 0, 0, 0, 0 },
233                     { 0, GL_FLOAT, 0, 0, 0 },
234                     { 0, GL_FLOAT_VEC2, GL_FLOAT_MAT2, GL_FLOAT_MAT3x2, GL_FLOAT_MAT4x2 },
235                     { 0, GL_FLOAT_VEC3, GL_FLOAT_MAT2x3, GL_FLOAT_MAT3, GL_FLOAT_MAT4x3 },
236                     { 0, GL_FLOAT_VEC4, GL_FLOAT_MAT2x4, GL_FLOAT_MAT3x4, GL_FLOAT_MAT4 },
237                 };
238                 t.type = type[memberType.vecsize][memberType.columns];
239                 break;
240             }
241             case spirv_cross::SPIRType::Unknown:
242             case spirv_cross::SPIRType::Void:
243             case spirv_cross::SPIRType::Boolean:
244             case spirv_cross::SPIRType::Char:
245             case spirv_cross::SPIRType::SByte:
246             case spirv_cross::SPIRType::UByte:
247             case spirv_cross::SPIRType::Short:
248             case spirv_cross::SPIRType::UShort:
249             case spirv_cross::SPIRType::Int:
250             case spirv_cross::SPIRType::Int64:
251             case spirv_cross::SPIRType::UInt64:
252             case spirv_cross::SPIRType::AtomicCounter:
253             case spirv_cross::SPIRType::Half:
254             case spirv_cross::SPIRType::Double:
255             case spirv_cross::SPIRType::Image:
256             case spirv_cross::SPIRType::SampledImage:
257             case spirv_cross::SPIRType::Sampler:
258             case spirv_cross::SPIRType::AccelerationStructure:
259             case spirv_cross::SPIRType::ControlPointArray:
260             default:
261                 t.type = 0;
262                 break;
263         }
264         assert((t.type != 0) && "Unhandled Type!");
265         const int32_t res = FindConstant(reflections, t);
266         assert((res >= -1) && "Push constant conflict.");
267         if (res == -1) {
268             reflections.push_back(t);
269         }
270     }
271 }
272 
273 #ifdef PLUGIN_UNUSED_SPRIV_CROSS_HELPERS
DefineForSpec(const std::vector<SpecConstantInfo> & reflectionInfo,uint32_t spcid,uintptr_t offset,std::string & result)274 bool DefineForSpec(
275     const std::vector<SpecConstantInfo>& reflectionInfo, uint32_t spcid, uintptr_t offset, std::string& result)
276 {
277     // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 4294967295\n" //worst case for bool
278     // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 4294967295\n" //worst case for uint32
279     // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 -2147483648\n"//worst case for int32
280     // and floats can be REALLY long..
281     char buf[1024];
282     bool ok = false;
283     for (const auto& c : reflectionInfo) {
284         if (c.constantId == spcid) {
285             if (c.name == "CORE_BACKEND_TYPE") {
286                 ok = true; // backend type can't change anymore..
287                 continue;
288             }
289             const auto& type = c.constantType;
290             [[maybe_unused]] const size_t size = c.vectorSize * c.columns * sizeof(uint32_t);
291             assert((size == sizeof(uint32_t)) && "Specialization constant size is not 4!");
292             //  The constant_id can only be applied to a scalar *int*, a scalar *float* or a scalar *bool*.
293             //  https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gl_spirv.txt
294             if ((type == SpecConstantInfo::Types::BOOL) || (type == SpecConstantInfo::Types::UINT32)) {
295                 const uint32_t value = *reinterpret_cast<uint32_t*>(offset);
296                 const int len = sprintf(buf, "%u %uu\n", c.constantId, value);
297                 ok = len > 0;
298             } else if (type == SpecConstantInfo::Types::INT32) {
299                 const int32_t value = *reinterpret_cast<int32_t*>(offset);
300                 const int len = sprintf(buf, "%u %d\n", c.constantId, value);
301                 ok = len > 0;
302             } else if (type == SpecConstantInfo::Types::FLOAT) {
303                 const float value = *reinterpret_cast<float_t*>(offset);
304                 // NOTE: resulting constant might not be the same. due to float -> string -> float conversions.
305                 const int len = sprintf(buf, "%u %f\n", c.constantId, value);
306                 ok = len > 0;
307             } else {
308                 assert(false && "Unhandled specialization constant type");
309             }
310             if (ok) {
311                 result.append("#define SPIRV_CROSS_CONSTANT_ID_");
312                 result.append(buf);
313             }
314             break;
315         }
316     }
317     return ok;
318 }
319 
// Returns `shaderIn` with `Defines` inserted right after its first line (presumably the #version directive,
// which must stay first in GLSL).
// An empty `shaderIn` yields just `Defines`. A shader without any newline is a single first line, so a '\n' is
// added after it before the defines.
std::string InsertDefines(std::string_view shaderIn, std::string_view Defines)
{
    if (shaderIn.empty()) {
        return std::string(Defines);
    }
    std::string shaderOut;
    const size_t newline = shaderIn.find('\n');
    if (newline == std::string_view::npos) {
        // Without this branch, npos + 1 would wrap to 0 and the defines would be placed before the first line.
        shaderOut.reserve(shaderIn.size() + 1 + Defines.size());
        shaderOut.append(shaderIn);
        shaderOut.push_back('\n');
        shaderOut.append(Defines);
        return shaderOut;
    }
    const size_t split = newline + 1;
    shaderOut.reserve(shaderIn.size() + Defines.size());
    shaderOut.append(shaderIn.substr(0, split));
    shaderOut.append(Defines);
    shaderOut.append(shaderIn.substr(split));
    return shaderOut;
}
335 
Specialize(ShaderStageFlags mask,std::string_view shaderTemplate,const std::vector<SpecConstantInfo> & info,const ShaderSpecializationConstantDataView & data)336 std::string Specialize(ShaderStageFlags mask, std::string_view shaderTemplate,
337     const std::vector<SpecConstantInfo>& info, const ShaderSpecializationConstantDataView& data)
338 {
339     if (shaderTemplate.empty()) {
340         return {};
341     }
342     bool ok = false;
343     for (const auto& spc : data.constants) {
344         if (spc.shaderStage & mask) {
345             ok = true;
346             break;
347         }
348     }
349     if (!ok) {
350         // nothing to specialize
351         return std::string(shaderTemplate);
352     }
353     // Create defines..
354     const uintptr_t base = (uintptr_t)data.data.data();
355     std::string defines;
356     defines.reserve(256);
357     for (const auto& spc : data.constants) {
358         if (spc.shaderStage & mask) {
359             const uintptr_t offset = base + spc.offset;
360             DefineForSpec(info, spc.id, offset, defines);
361         }
362     }
363     // inject defines to shader source.
364     return InsertDefines(shaderTemplate, defines);
365 }
366 
CreateSpecInfos(const spirv_cross::Compiler & compiler,std::vector<SpecConstantInfo> & outSpecInfo)367 void CreateSpecInfos(const spirv_cross::Compiler& compiler, std::vector<SpecConstantInfo>& outSpecInfo)
368 {
369     const auto& specInfo = compiler.get_specialization_constants();
370     for (const auto& c : specInfo) {
371         SpecConstantInfo t;
372         t.constantId = c.constant_id;
373         const spirv_cross::SPIRConstant& constant = compiler.get_constant(c.id);
374         const auto& name = compiler.get_name(c.id);
375         const auto type = compiler.get_type(constant.constant_type);
376         if (type.basetype == spirv_cross::SPIRType::Boolean) {
377             t.constantType = SpecConstantInfo::Types::BOOL;
378         } else if (type.basetype == spirv_cross::SPIRType::UInt) {
379             t.constantType = SpecConstantInfo::Types::UINT32;
380         } else if (type.basetype == spirv_cross::SPIRType::Int) {
381             t.constantType = SpecConstantInfo::Types::INT32;
382         } else if (type.basetype == spirv_cross::SPIRType::Float) {
383             t.constantType = SpecConstantInfo::Types::FLOAT;
384         } else {
385             assert(false && "Unhandled specialization constant type");
386         }
387         t.vectorSize = constant.vector_size();
388         t.columns = constant.columns();
389         t.name.assign(name.data(), name.size());
390         outSpecInfo.push_back(std::move(t));
391     }
392 }
393 
ConstId(spirv_cross::CompilerGLSL & compiler,const char * name)394 uint32_t ConstId(spirv_cross::CompilerGLSL& compiler, const char* name)
395 {
396     const auto& c = ConstByName(compiler, name);
397     return c.self;
398 }
399 
SpecConstId(spirv_cross::CompilerGLSL & compiler,const char * name)400 uint32_t SpecConstId(spirv_cross::CompilerGLSL& compiler, const char* name)
401 {
402     const auto& c = SpecConstByName(compiler, name);
403     return c.self;
404 }
405 #endif
406 
407 } // namespace Gles
408