1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "spirv_cross_helpers_gles.h"

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <glcorearb.h>
20
21 namespace Gles {
22 namespace {
23 static const spirv_cross::SPIRConstant invalid {};
24
FindConstant(const std::vector<PushConstantReflection> & reflections,const PushConstantReflection & reflection)25 int32_t FindConstant(const std::vector<PushConstantReflection>& reflections, const PushConstantReflection& reflection)
26 {
27 for (size_t i = 0; i < reflections.size(); i++) {
28 if (reflection.name == reflections[i].name) {
29 // Check that it's actually same and not a conflict!.
30 if (reflection.type != reflections[i].type) {
31 return -2;
32 }
33 if (reflection.offset != reflections[i].offset) {
34 return -2;
35 }
36 if (reflection.size != reflections[i].size) {
37 return -2;
38 }
39 if (reflection.arraySize != reflections[i].arraySize) {
40 return -2;
41 }
42 if (reflection.arrayStride != reflections[i].arrayStride) {
43 return -2;
44 }
45 if (reflection.matrixStride != reflections[i].matrixStride) {
46 return -2;
47 }
48 return (int32_t)i;
49 }
50 }
51 return -1;
52 }
53
// Finds the constant whose debug name equals `name`.
// Returns a reference to the constant owned by `compiler`, or the `invalid` sentinel when no
// constant has that name.
// NOTE(review): `compiler` is downcast to CoreCompiler to reach GetConstants(); this assumes every
// caller really passes a CoreCompiler instance - confirm at the call sites.
const spirv_cross::SPIRConstant& ConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    // Named, const-preserving cast (GetConstants() is a const member, so const need not be dropped).
    const auto& coreCompiler = static_cast<const CoreCompiler&>(compiler);
    const auto constants = coreCompiler.GetConstants();
    for (const auto& c : constants) {
        if (compiler.get_name(c.self) == name) {
            // `c` is a copy that dies with `constants`; hand out the compiler-owned instance instead.
            return compiler.get_constant(c.self);
        }
    }
    return invalid;
}
67
// Finds the specialization constant whose debug name equals `name`.
// Returns the constant owned by `compiler`, or the `invalid` sentinel when none matches.
const spirv_cross::SPIRConstant& SpecConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    for (const auto& entry : compiler.get_specialization_constants()) {
        if (!(compiler.get_name(entry.id) == name)) {
            continue;
        }
        return compiler.get_constant(entry.id);
    }
    // Nothing matched: report the sentinel.
    return invalid;
}
81 } // namespace
82
83 // inherit from CompilerGLSL to have better access
CoreCompiler(const uint32_t * ir,size_t wordCount)84 CoreCompiler::CoreCompiler(const uint32_t* ir, size_t wordCount) : CompilerGLSL(ir, wordCount) {}
85
GetConstants() const86 const std::vector<spirv_cross::SPIRConstant> CoreCompiler::GetConstants() const
87 {
88 std::vector<spirv_cross::SPIRConstant> consts;
89 ir.for_each_typed_id<spirv_cross::SPIRConstant>(
90 [&consts](uint32_t, const spirv_cross::SPIRConstant& c) { consts.push_back(c); });
91 return consts;
92 }
93
GetIr() const94 const spirv_cross::ParsedIR& CoreCompiler::GetIr() const
95 {
96 return ir;
97 }
98
ReflectPushConstants(spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)99 void ReflectPushConstants(spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
100 std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
101 {
102 char ids[64];
103 int id = 0;
104 // There can be only one push_constant_buffer, but since spirv-cross has prepared for this to be relaxed, we will
105 // too.
106 std::string name = "CORE_PC_00";
107 for (auto& remap : resources.push_constant_buffers) {
108 const auto& blockType = compiler.get_type(remap.base_type_id);
109 (void)(blockType);
110 sprintf(ids, "%d", id);
111 name.resize(8);
112 name.append(ids);
113 compiler.set_name(remap.id, name);
114 assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
115 ProcessStruct(std::string_view(name.data(), name.size()), 0, compiler, remap.base_type_id, reflections, stage);
116 id++;
117 }
118 }
119
120 // Converts specialization constant to normal constant, (to reduce unnecessary clutter in glsl)
ConvertSpecConstToConstant(spirv_cross::CompilerGLSL & compiler,const char * name)121 void ConvertSpecConstToConstant(spirv_cross::CompilerGLSL& compiler, const char* name)
122 {
123 const auto& c = SpecConstByName(compiler, name);
124 if (c.self == invalid.self) {
125 return;
126 }
127 compiler.unset_decoration(c.self, spv::Decoration::DecorationSpecId);
128 }
129
130 // Converts constant declaration to uniform. (actually only works on spec constants)
ConvertConstantToUniform(const spirv_cross::CompilerGLSL & compiler,std::string & source,const char * name)131 void ConvertConstantToUniform(const spirv_cross::CompilerGLSL& compiler, std::string& source, const char* name)
132 {
133 std::string tmp;
134 tmp.reserve(strlen(name) + 16);
135 const auto& constant = ConstByName(compiler, name);
136 if (constant.self == invalid.self) {
137 return;
138 }
139 const auto& type = compiler.get_type(constant.constant_type);
140 if (type.basetype == spirv_cross::SPIRType::Boolean) {
141 tmp += "const bool ";
142 } else if (type.basetype == spirv_cross::SPIRType::UInt) {
143 tmp += "const uint ";
144 } else if (type.basetype == spirv_cross::SPIRType::Int) {
145 tmp += "const int ";
146 } else if (type.basetype == spirv_cross::SPIRType::Float) {
147 tmp += "const float ";
148 } else {
149 assert(false && "Unhandled specialization constant type");
150 }
151 // We expect spirv_cross to generate them with certain pattern..
152 tmp += name;
153 tmp += " =";
154 const auto p = source.find(tmp);
155 if (p != std::string::npos) {
156 // found it, change it. (changes const to uniform)
157 auto bi = source.begin() + (int64_t)p;
158 auto ei = bi + 6;
159 source.replace(bi, ei, "uniform ");
160
161 // remove the initializer..
162 const auto p2 = source.find('=', p);
163 const auto p3 = source.find(';', p);
164 if ((p2 != std::string::npos) && (p3 != std::string::npos)) {
165 if (p2 < p3) {
166 // should be correct (tm)
167 bi = source.begin() + (int64_t)p2;
168 ei = source.begin() + (int64_t)p3;
169 source.erase(bi, ei);
170 }
171 }
172 }
173 }
174
SetSpecMacro(spirv_cross::CompilerGLSL & compiler,const char * name,uint32_t value)175 void SetSpecMacro(spirv_cross::CompilerGLSL& compiler, const char* name, uint32_t value)
176 {
177 const auto& vc = SpecConstByName(compiler, name);
178 if (vc.self != invalid.self) {
179 const uint32_t constantId = compiler.get_decoration(vc.self, spv::Decoration::DecorationSpecId);
180 char buf[1024];
181 sprintf(buf, "#define SPIRV_CROSS_CONSTANT_ID_%u %du", constantId, value);
182 compiler.add_header_line(buf);
183 }
184 }
185
ProcessStruct(std::string_view baseName,size_t baseOffset,const spirv_cross::Compiler & compiler,uint32_t structTypeId,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)186 void ProcessStruct(std::string_view baseName, size_t baseOffset, const spirv_cross::Compiler& compiler,
187 uint32_t structTypeId, std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
188 {
189 const auto& structType = compiler.get_type(structTypeId);
190 reflections.reserve(reflections.size() + structType.member_types.size());
191 for (uint32_t bi = 0; bi < structType.member_types.size(); bi++) {
192 const uint32_t memberTypeId = structType.member_types[bi];
193 const auto& memberType = compiler.get_type(memberTypeId);
194 const auto& name = compiler.get_member_name(structTypeId, bi);
195
196 PushConstantReflection t;
197 t.stage = stage;
198 t.name = baseName;
199 t.name += '.';
200 t.name += name;
201 t.arrayStride = 0;
202 t.matrixStride = 0;
203 // Get member offset within this struct.
204 t.offset = baseOffset + compiler.type_struct_member_offset(structType, bi);
205 t.size = compiler.get_declared_struct_member_size(structType, bi);
206 t.arraySize = 0;
207 if (!memberType.array.empty()) {
208 // Get array stride, e.g. float4 foo[]; Will have array stride of 16 bytes.
209 t.arrayStride = compiler.type_struct_member_array_stride(structType, bi);
210 t.arraySize = memberType.array[0]; // We don't support arrays of arrays. just use the size of first.
211 }
212
213 if (memberType.columns > 1) {
214 // Get bytes stride between columns (if column major), for float4x4 -> 16 bytes.
215 t.matrixStride = compiler.type_struct_member_matrix_stride(structType, bi);
216 }
217
218 switch (memberType.basetype) {
219 case spirv_cross::SPIRType::Struct:
220 ProcessStruct(t.name, t.offset, compiler, memberTypeId, reflections, stage);
221 continue;
222 break;
223 case spirv_cross::SPIRType::UInt: {
224 constexpr GLenum type[5][5] = { { 0, 0, 0, 0, 0 }, { 0, GL_UNSIGNED_INT, 0, 0, 0 },
225 { 0, GL_UNSIGNED_INT_VEC2, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC3, 0, 0, 0 },
226 { 0, GL_UNSIGNED_INT_VEC4, 0, 0, 0 } };
227 t.type = type[memberType.vecsize][memberType.columns];
228 break;
229 }
230 case spirv_cross::SPIRType::Float: {
231 constexpr GLenum type[5][5] = {
232 { 0, 0, 0, 0, 0 },
233 { 0, GL_FLOAT, 0, 0, 0 },
234 { 0, GL_FLOAT_VEC2, GL_FLOAT_MAT2, GL_FLOAT_MAT3x2, GL_FLOAT_MAT4x2 },
235 { 0, GL_FLOAT_VEC3, GL_FLOAT_MAT2x3, GL_FLOAT_MAT3, GL_FLOAT_MAT4x3 },
236 { 0, GL_FLOAT_VEC4, GL_FLOAT_MAT2x4, GL_FLOAT_MAT3x4, GL_FLOAT_MAT4 },
237 };
238 t.type = type[memberType.vecsize][memberType.columns];
239 break;
240 }
241 case spirv_cross::SPIRType::Unknown:
242 case spirv_cross::SPIRType::Void:
243 case spirv_cross::SPIRType::Boolean:
244 case spirv_cross::SPIRType::Char:
245 case spirv_cross::SPIRType::SByte:
246 case spirv_cross::SPIRType::UByte:
247 case spirv_cross::SPIRType::Short:
248 case spirv_cross::SPIRType::UShort:
249 case spirv_cross::SPIRType::Int:
250 case spirv_cross::SPIRType::Int64:
251 case spirv_cross::SPIRType::UInt64:
252 case spirv_cross::SPIRType::AtomicCounter:
253 case spirv_cross::SPIRType::Half:
254 case spirv_cross::SPIRType::Double:
255 case spirv_cross::SPIRType::Image:
256 case spirv_cross::SPIRType::SampledImage:
257 case spirv_cross::SPIRType::Sampler:
258 case spirv_cross::SPIRType::AccelerationStructure:
259 case spirv_cross::SPIRType::ControlPointArray:
260 default:
261 t.type = 0;
262 break;
263 }
264 assert((t.type != 0) && "Unhandled Type!");
265 const int32_t res = FindConstant(reflections, t);
266 assert((res >= -1) && "Push constant conflict.");
267 if (res == -1) {
268 reflections.push_back(t);
269 }
270 }
271 }
272
273 #ifdef PLUGIN_UNUSED_SPRIV_CROSS_HELPERS
DefineForSpec(const std::vector<SpecConstantInfo> & reflectionInfo,uint32_t spcid,uintptr_t offset,std::string & result)274 bool DefineForSpec(
275 const std::vector<SpecConstantInfo>& reflectionInfo, uint32_t spcid, uintptr_t offset, std::string& result)
276 {
277 // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 4294967295\n" //worst case for bool
278 // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 4294967295\n" //worst case for uint32
279 // "#define SPIRV_CROSS_CONSTANT_ID_4294967295 -2147483648\n"//worst case for int32
280 // and floats can be REALLY long..
281 char buf[1024];
282 bool ok = false;
283 for (const auto& c : reflectionInfo) {
284 if (c.constantId == spcid) {
285 if (c.name == "CORE_BACKEND_TYPE") {
286 ok = true; // backend type can't change anymore..
287 continue;
288 }
289 const auto& type = c.constantType;
290 [[maybe_unused]] const size_t size = c.vectorSize * c.columns * sizeof(uint32_t);
291 assert((size == sizeof(uint32_t)) && "Specialization constant size is not 4!");
292 // The constant_id can only be applied to a scalar *int*, a scalar *float* or a scalar *bool*.
293 // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gl_spirv.txt
294 if ((type == SpecConstantInfo::Types::BOOL) || (type == SpecConstantInfo::Types::UINT32)) {
295 const uint32_t value = *reinterpret_cast<uint32_t*>(offset);
296 const int len = sprintf(buf, "%u %uu\n", c.constantId, value);
297 ok = len > 0;
298 } else if (type == SpecConstantInfo::Types::INT32) {
299 const int32_t value = *reinterpret_cast<int32_t*>(offset);
300 const int len = sprintf(buf, "%u %d\n", c.constantId, value);
301 ok = len > 0;
302 } else if (type == SpecConstantInfo::Types::FLOAT) {
303 const float value = *reinterpret_cast<float_t*>(offset);
304 // NOTE: resulting constant might not be the same. due to float -> string -> float conversions.
305 const int len = sprintf(buf, "%u %f\n", c.constantId, value);
306 ok = len > 0;
307 } else {
308 assert(false && "Unhandled specialization constant type");
309 }
310 if (ok) {
311 result.append("#define SPIRV_CROSS_CONSTANT_ID_");
312 result.append(buf);
313 }
314 break;
315 }
316 }
317 return ok;
318 }
319
// Returns `shaderIn` with `Defines` inserted directly after its first line (the line that is
// expected to hold the #version directive, which has to stay first).
//   shaderIn - shader source; may be empty.
//   Defines  - text to insert, expected to end with a newline.
// An empty shader yields just `Defines`. A shader without any newline is treated as a single
// first line: a newline is added after it and `Defines` follows.
std::string InsertDefines(std::string_view shaderIn, std::string_view Defines)
{
    std::string shaderOut;
    if (shaderIn.empty()) {
        shaderOut.append(Defines);
        return shaderOut;
    }

    // +1 for the newline that may have to be added below.
    shaderOut.reserve(shaderIn.length() + Defines.length() + 1);

    const size_t newlinePos = shaderIn.find('\n');
    if (newlinePos == std::string_view::npos) {
        // Single line shader: terminate the line so the defines do not end up in front of (or glued to) it.
        shaderOut.append(shaderIn);
        shaderOut += '\n';
        shaderOut.append(Defines);
        return shaderOut;
    }

    const size_t restStart = newlinePos + 1;
    shaderOut.append(shaderIn.substr(0, restStart));
    shaderOut.append(Defines);
    shaderOut.append(shaderIn.substr(restStart));
    return shaderOut;
}
335
Specialize(ShaderStageFlags mask,std::string_view shaderTemplate,const std::vector<SpecConstantInfo> & info,const ShaderSpecializationConstantDataView & data)336 std::string Specialize(ShaderStageFlags mask, std::string_view shaderTemplate,
337 const std::vector<SpecConstantInfo>& info, const ShaderSpecializationConstantDataView& data)
338 {
339 if (shaderTemplate.empty()) {
340 return {};
341 }
342 bool ok = false;
343 for (const auto& spc : data.constants) {
344 if (spc.shaderStage & mask) {
345 ok = true;
346 break;
347 }
348 }
349 if (!ok) {
350 // nothing to specialize
351 return std::string(shaderTemplate);
352 }
353 // Create defines..
354 const uintptr_t base = (uintptr_t)data.data.data();
355 std::string defines;
356 defines.reserve(256);
357 for (const auto& spc : data.constants) {
358 if (spc.shaderStage & mask) {
359 const uintptr_t offset = base + spc.offset;
360 DefineForSpec(info, spc.id, offset, defines);
361 }
362 }
363 // inject defines to shader source.
364 return InsertDefines(shaderTemplate, defines);
365 }
366
CreateSpecInfos(const spirv_cross::Compiler & compiler,std::vector<SpecConstantInfo> & outSpecInfo)367 void CreateSpecInfos(const spirv_cross::Compiler& compiler, std::vector<SpecConstantInfo>& outSpecInfo)
368 {
369 const auto& specInfo = compiler.get_specialization_constants();
370 for (const auto& c : specInfo) {
371 SpecConstantInfo t;
372 t.constantId = c.constant_id;
373 const spirv_cross::SPIRConstant& constant = compiler.get_constant(c.id);
374 const auto& name = compiler.get_name(c.id);
375 const auto type = compiler.get_type(constant.constant_type);
376 if (type.basetype == spirv_cross::SPIRType::Boolean) {
377 t.constantType = SpecConstantInfo::Types::BOOL;
378 } else if (type.basetype == spirv_cross::SPIRType::UInt) {
379 t.constantType = SpecConstantInfo::Types::UINT32;
380 } else if (type.basetype == spirv_cross::SPIRType::Int) {
381 t.constantType = SpecConstantInfo::Types::INT32;
382 } else if (type.basetype == spirv_cross::SPIRType::Float) {
383 t.constantType = SpecConstantInfo::Types::FLOAT;
384 } else {
385 assert(false && "Unhandled specialization constant type");
386 }
387 t.vectorSize = constant.vector_size();
388 t.columns = constant.columns();
389 t.name.assign(name.data(), name.size());
390 outSpecInfo.push_back(std::move(t));
391 }
392 }
393
ConstId(spirv_cross::CompilerGLSL & compiler,const char * name)394 uint32_t ConstId(spirv_cross::CompilerGLSL& compiler, const char* name)
395 {
396 const auto& c = ConstByName(compiler, name);
397 return c.self;
398 }
399
SpecConstId(spirv_cross::CompilerGLSL & compiler,const char * name)400 uint32_t SpecConstId(spirv_cross::CompilerGLSL& compiler, const char* name)
401 {
402 const auto& c = SpecConstByName(compiler, name);
403 return c.self;
404 }
405 #endif
406
407 } // namespace Gles
408