1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "default_limits.h"
17 #include <glslang/Public/ShaderLang.h>
18 #include <SPIRV/GlslangToSpv.h>
19 #include <SPIRV/SpvTools.h>
20 #include <spirv-tools/optimizer.hpp>
21 
22 #include "spirv_cross.hpp"
23 
24 // #include "preprocess/preprocess.h"
#include <algorithm>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
36 
37 #include "io/dev/FileMonitor.h"
38 #include "lume/Log.h"
39 #include "shader_type.h"
40 #include "spirv_cross_helpers_gles.h"
41 
42 using namespace std::chrono_literals;
43 
44 // Enumerations from Engine which should match: Format, DescriptorType, ShaderStageFlagBits
/**
 * Format
 *
 * Data format identifiers. The numbering is sparse and every enumerator is given an explicit
 * value, so entries must never be renumbered or reordered by hand; the values are expected to
 * stay identical to the engine side enumeration.
 */
enum class Format {
    UNDEFINED = 0,

    // Packed 8 / 16 bit colour formats
    R4G4_UNORM_PACK8 = 1,
    R4G4B4A4_UNORM_PACK16 = 2,
    B4G4R4A4_UNORM_PACK16 = 3,
    R5G6B5_UNORM_PACK16 = 4,
    B5G6R5_UNORM_PACK16 = 5,
    R5G5B5A1_UNORM_PACK16 = 6,
    B5G5R5A1_UNORM_PACK16 = 7,
    A1R5G5B5_UNORM_PACK16 = 8,

    // 8 bit per channel, one channel
    R8_UNORM = 9,
    R8_SNORM = 10,
    R8_USCALED = 11,
    R8_SSCALED = 12,
    R8_UINT = 13,
    R8_SINT = 14,
    R8_SRGB = 15,

    // 8 bit per channel, two channels
    R8G8_UNORM = 16,
    R8G8_SNORM = 17,
    R8G8_USCALED = 18,
    R8G8_SSCALED = 19,
    R8G8_UINT = 20,
    R8G8_SINT = 21,
    R8G8_SRGB = 22,

    // 8 bit per channel, three channels (RGB order)
    R8G8B8_UNORM = 23,
    R8G8B8_SNORM = 24,
    R8G8B8_USCALED = 25,
    R8G8B8_SSCALED = 26,
    R8G8B8_UINT = 27,
    R8G8B8_SINT = 28,
    R8G8B8_SRGB = 29,

    // 8 bit per channel, three channels (BGR order); values 32 and 33 are not used here
    B8G8R8_UNORM = 30,
    B8G8R8_SNORM = 31,
    B8G8R8_UINT = 34,
    B8G8R8_SINT = 35,
    B8G8R8_SRGB = 36,

    // 8 bit per channel, four channels (RGBA order)
    R8G8B8A8_UNORM = 37,
    R8G8B8A8_SNORM = 38,
    R8G8B8A8_USCALED = 39,
    R8G8B8A8_SSCALED = 40,
    R8G8B8A8_UINT = 41,
    R8G8B8A8_SINT = 42,
    R8G8B8A8_SRGB = 43,

    // 8 bit per channel, four channels (BGRA order); values 46 and 47 are not used here
    B8G8R8A8_UNORM = 44,
    B8G8R8A8_SNORM = 45,
    B8G8R8A8_UINT = 48,
    B8G8R8A8_SINT = 49,
    B8G8R8A8_SRGB = 50,

    // 8 bit per channel, packed into 32 bits (ABGR order)
    A8B8G8R8_UNORM_PACK32 = 51,
    A8B8G8R8_SNORM_PACK32 = 52,
    A8B8G8R8_USCALED_PACK32 = 53,
    A8B8G8R8_SSCALED_PACK32 = 54,
    A8B8G8R8_UINT_PACK32 = 55,
    A8B8G8R8_SINT_PACK32 = 56,
    A8B8G8R8_SRGB_PACK32 = 57,

    // 2-10-10-10 packed formats; values 59..61 are not used here
    A2R10G10B10_UNORM_PACK32 = 58,
    A2R10G10B10_UINT_PACK32 = 62,
    A2R10G10B10_SINT_PACK32 = 63,
    A2B10G10R10_UNORM_PACK32 = 64,
    A2B10G10R10_SNORM_PACK32 = 65,
    A2B10G10R10_USCALED_PACK32 = 66,
    A2B10G10R10_SSCALED_PACK32 = 67,
    A2B10G10R10_UINT_PACK32 = 68,
    A2B10G10R10_SINT_PACK32 = 69,

    // 16 bit per channel, one channel
    R16_UNORM = 70,
    R16_SNORM = 71,
    R16_USCALED = 72,
    R16_SSCALED = 73,
    R16_UINT = 74,
    R16_SINT = 75,
    R16_SFLOAT = 76,

    // 16 bit per channel, two channels
    R16G16_UNORM = 77,
    R16G16_SNORM = 78,
    R16G16_USCALED = 79,
    R16G16_SSCALED = 80,
    R16G16_UINT = 81,
    R16G16_SINT = 82,
    R16G16_SFLOAT = 83,

    // 16 bit per channel, three channels
    R16G16B16_UNORM = 84,
    R16G16B16_SNORM = 85,
    R16G16B16_USCALED = 86,
    R16G16B16_SSCALED = 87,
    R16G16B16_UINT = 88,
    R16G16B16_SINT = 89,
    R16G16B16_SFLOAT = 90,

    // 16 bit per channel, four channels
    R16G16B16A16_UNORM = 91,
    R16G16B16A16_SNORM = 92,
    R16G16B16A16_USCALED = 93,
    R16G16B16A16_SSCALED = 94,
    R16G16B16A16_UINT = 95,
    R16G16B16A16_SINT = 96,
    R16G16B16A16_SFLOAT = 97,

    // 32 bit per channel, one to four channels
    R32_UINT = 98,
    R32_SINT = 99,
    R32_SFLOAT = 100,
    R32G32_UINT = 101,
    R32G32_SINT = 102,
    R32G32_SFLOAT = 103,
    R32G32B32_UINT = 104,
    R32G32B32_SINT = 105,
    R32G32B32_SFLOAT = 106,
    R32G32B32A32_UINT = 107,
    R32G32B32A32_SINT = 108,
    R32G32B32A32_SFLOAT = 109,

    // Packed float formats; values 110..121 are not used here
    B10G11R11_UFLOAT_PACK32 = 122,
    E5B9G9R9_UFLOAT_PACK32 = 123,

    // Depth / stencil formats; values 128 and 130 are not used here
    D16_UNORM = 124,
    X8_D24_UNORM_PACK32 = 125,
    D32_SFLOAT = 126,
    S8_UINT = 127,
    D24_UNORM_S8_UINT = 129,

    // BC block compressed formats
    BC1_RGB_UNORM_BLOCK = 131,
    BC1_RGB_SRGB_BLOCK = 132,
    BC1_RGBA_UNORM_BLOCK = 133,
    BC1_RGBA_SRGB_BLOCK = 134,
    BC2_UNORM_BLOCK = 135,
    BC2_SRGB_BLOCK = 136,
    BC3_UNORM_BLOCK = 137,
    BC3_SRGB_BLOCK = 138,
    BC4_UNORM_BLOCK = 139,
    BC4_SNORM_BLOCK = 140,
    BC5_UNORM_BLOCK = 141,
    BC5_SNORM_BLOCK = 142,
    BC6H_UFLOAT_BLOCK = 143,
    BC6H_SFLOAT_BLOCK = 144,
    BC7_UNORM_BLOCK = 145,
    BC7_SRGB_BLOCK = 146,

    // ETC2 / EAC block compressed formats
    ETC2_R8G8B8_UNORM_BLOCK = 147,
    ETC2_R8G8B8_SRGB_BLOCK = 148,
    ETC2_R8G8B8A1_UNORM_BLOCK = 149,
    ETC2_R8G8B8A1_SRGB_BLOCK = 150,
    ETC2_R8G8B8A8_UNORM_BLOCK = 151,
    ETC2_R8G8B8A8_SRGB_BLOCK = 152,
    EAC_R11_UNORM_BLOCK = 153,
    EAC_R11_SNORM_BLOCK = 154,
    EAC_R11G11_UNORM_BLOCK = 155,
    EAC_R11G11_SNORM_BLOCK = 156,

    // ASTC block compressed formats
    ASTC_4x4_UNORM_BLOCK = 157,
    ASTC_4x4_SRGB_BLOCK = 158,
    ASTC_5x4_UNORM_BLOCK = 159,
    ASTC_5x4_SRGB_BLOCK = 160,
    ASTC_5x5_UNORM_BLOCK = 161,
    ASTC_5x5_SRGB_BLOCK = 162,
    ASTC_6x5_UNORM_BLOCK = 163,
    ASTC_6x5_SRGB_BLOCK = 164,
    ASTC_6x6_UNORM_BLOCK = 165,
    ASTC_6x6_SRGB_BLOCK = 166,
    ASTC_8x5_UNORM_BLOCK = 167,
    ASTC_8x5_SRGB_BLOCK = 168,
    ASTC_8x6_UNORM_BLOCK = 169,
    ASTC_8x6_SRGB_BLOCK = 170,
    ASTC_8x8_UNORM_BLOCK = 171,
    ASTC_8x8_SRGB_BLOCK = 172,
    ASTC_10x5_UNORM_BLOCK = 173,
    ASTC_10x5_SRGB_BLOCK = 174,
    ASTC_10x6_UNORM_BLOCK = 175,
    ASTC_10x6_SRGB_BLOCK = 176,
    ASTC_10x8_UNORM_BLOCK = 177,
    ASTC_10x8_SRGB_BLOCK = 178,
    ASTC_10x10_UNORM_BLOCK = 179,
    ASTC_10x10_SRGB_BLOCK = 180,
    ASTC_12x10_UNORM_BLOCK = 181,
    ASTC_12x10_SRGB_BLOCK = 182,
    ASTC_12x12_UNORM_BLOCK = 183,
    ASTC_12x12_SRGB_BLOCK = 184,

    // 4:2:2 / 4:2:0 multi-planar and interleaved 8 bit formats
    G8B8G8R8_422_UNORM = 1000156000,
    B8G8R8G8_422_UNORM = 1000156001,
    G8_B8_R8_3PLANE_420_UNORM = 1000156002,
    G8_B8R8_2PLANE_420_UNORM = 1000156003,
    G8_B8_R8_3PLANE_422_UNORM = 1000156004,
    G8_B8R8_2PLANE_422_UNORM = 1000156005,

    // Max enumeration
    MAX_ENUM = 0x7FFFFFFF
};
390 
/**
 * Descriptor type
 *
 * Kind of resource bound through a descriptor set binding. Values are explicit and are expected
 * to stay identical to the engine side enumeration.
 */
enum class DescriptorType {
    // Samplers and images
    SAMPLER = 0,
    COMBINED_IMAGE_SAMPLER = 1,
    SAMPLED_IMAGE = 2,
    STORAGE_IMAGE = 3,

    // Texel buffers
    UNIFORM_TEXEL_BUFFER = 4,
    STORAGE_TEXEL_BUFFER = 5,

    // Uniform / storage buffers and their dynamic-offset variants
    UNIFORM_BUFFER = 6,
    STORAGE_BUFFER = 7,
    UNIFORM_BUFFER_DYNAMIC = 8,
    STORAGE_BUFFER_DYNAMIC = 9,

    // Input attachment
    INPUT_ATTACHMENT = 10,

    // Acceleration structure
    ACCELERATION_STRUCTURE = 1000150000,

    // Max enumeration
    MAX_ENUM = 0x7FFFFFFF
};
419 
/**
 * Vertex input rate
 *
 * Selects between per-vertex and per-instance input; MAX_ENUM is used elsewhere in this file as
 * the "not set" default.
 */
enum class VertexInputRate {
    VERTEX = 0,            // Vertex
    INSTANCE = 1,          // Instance
    MAX_ENUM = 0x7FFFFFFF  // Max enumeration
};
429 
/**
 * Pipeline layout constants
 *
 * Compile-time limits and sentinel values shared by the pipeline layout structures.
 */
struct PipelineLayoutConstants {
    /** Number of descriptor set slots available in a pipeline layout */
    static constexpr uint32_t MAX_DESCRIPTOR_SET_COUNT = 4u;
    /** Upper limit for dynamic descriptor offsets */
    static constexpr uint32_t MAX_DYNAMIC_DESCRIPTOR_OFFSET_COUNT = 16u;
    /** Sentinel (all bits set) marking a set / binding index as unassigned */
    static constexpr uint32_t INVALID_INDEX = ~0u;
    /** Upper limit for the push constant block size, in bytes */
    static constexpr uint32_t MAX_PUSH_CONSTANT_BYTE_SIZE = 128u;
};
441 
442 /** Descriptor set layout binding */
443 struct DescriptorSetLayoutBinding {
444     /** Binding */
445     uint32_t binding { PipelineLayoutConstants::INVALID_INDEX };
446     /** Descriptor type */
447     DescriptorType descriptorType { DescriptorType::MAX_ENUM };
448     /** Descriptor count */
449     uint32_t descriptorCount { 0 };
450     /** Stage flags */
451     ShaderStageFlags shaderStageFlags;
452 };
453 
454 /** Descriptor set layout */
455 struct DescriptorSetLayout {
456     /** Set */
457     uint32_t set { PipelineLayoutConstants::INVALID_INDEX };
458     /** Bindings */
459     std::vector<DescriptorSetLayoutBinding> bindings;
460 };
461 
462 /** Push constant */
463 struct PushConstant {
464     /** Shader stage flags */
465     ShaderStageFlags shaderStageFlags;
466     /** Byte size */
467     uint32_t byteSize { 0 };
468 };
469 
470 /** Pipeline layout */
471 struct PipelineLayout {
472     /** Push constant */
473     PushConstant pushConstant;
474     /** Descriptor set count */
475     uint32_t descriptorSetCount { 0 };
476     /** Descriptor sets */
477     DescriptorSetLayout descriptorSetLayouts[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] {};
478 };
479 
/** Reserved constant id index. NOTE(review): presumably a specialization constant id boundary - confirm against usage. */
constexpr uint32_t RESERVED_CONSTANT_ID_INDEX = 256u;
481 
482 /** Vertex input declaration */
483 struct VertexInputDeclaration {
484     /** Vertex input binding description */
485     struct VertexInputBindingDescription {
486         /** Binding */
487         uint32_t binding { ~0u };
488         /** Stride */
489         uint32_t stride { 0u };
490         /** Vertex input rate */
491         VertexInputRate vertexInputRate { VertexInputRate::MAX_ENUM };
492     };
493 
494     /** Vertex input attribute description */
495     struct VertexInputAttributeDescription {
496         /** Location */
497         uint32_t location { ~0u };
498         /** Binding */
499         uint32_t binding { ~0u };
500         /** Format */
501         Format format { Format::UNDEFINED };
502         /** Offset */
503         uint32_t offset { 0u };
504     };
505 };
506 
507 struct VertexAttributeInfo {
508     uint32_t byteSize { 0 };
509     VertexInputDeclaration::VertexInputAttributeDescription description;
510 };
511 
/**
 * Three component unsigned integer vector.
 *
 * All components default to zero so that a default constructed value is well defined; the struct
 * stays an aggregate, so `UVec3 { x, y, z }` keeps working.
 */
struct UVec3 {
    uint32_t x { 0u };
    uint32_t y { 0u };
    uint32_t z { 0u };
};
517 
/**
 * Accessor over a serialized shader reflection blob.
 *
 * The getters decode individual sections of the blob; each section is located through the
 * offsets stored in the ReflectionHeader at the start of the data. Check IsValid() before
 * relying on the other getters.
 */
struct ShaderReflectionData {
    /** Raw reflection bytes, expected to begin with a ReflectionHeader. */
    array_view<const uint8_t> reflectionData;

    /** Returns true when the data is at least header sized and starts with REFLECTION_TAG. */
    bool IsValid() const;
    /** Returns the shader stage stored in the header's type field. */
    ShaderStageFlags GetStageFlags() const;
    /** Decodes the push constant and descriptor set sections. */
    PipelineLayout GetPipelineLayout() const;
    /** Decodes the specialization constant section (id and type of every constant). */
    std::vector<ShaderSpecializationConstant> GetSpecializationConstants() const;
    /** Decodes the vertex input section (location and format of every input). */
    std::vector<VertexInputDeclaration::VertexInputAttributeDescription> GetInputDescriptions() const;
    /** Decodes the local size section as three 32-bit values (x, y, z). */
    UVec3 GetLocalSize() const;
};
528 
/** Bundles what describes one shader module: its stage flags, SPIR-V bytes and reflection data. */
struct ShaderModuleCreateInfo {
    /** Shader stage flags */
    ShaderStageFlags shaderStageFlags;
    /** SPIR-V data as bytes */
    array_view<const uint8_t> spvData;
    /** Reflection data belonging to the module */
    ShaderReflectionData reflectionData;
};
534 
/**
 * Settings shared by the compilation helpers in this file.
 *
 * The two path members are references, so the referenced paths must outlive the settings object,
 * and the struct cannot be default constructed or assigned.
 */
struct CompilationSettings {
    /** Target environment; selects the Vulkan client version in the compile functions */
    ShaderEnv shaderEnv;
    /** Directories searched (in order) when resolving #include directives */
    std::vector<std::filesystem::path> shaderIncludePaths;
    /** SPIR-V optimizer; may be empty */
    std::optional<spvtools::Optimizer> optimizer;
    /** Shader source root; used as the base for relative includes */
    std::filesystem::path& shaderSourcePath;
    /** Destination path for compiled shaders */
    std::filesystem::path& compiledShaderDestinationPath;
};
542 
/** Magic bytes ('r', 'f', 'l' and a zero byte) expected at the very start of a reflection blob. */
constexpr uint8_t REFLECTION_TAG[] = { 'r', 'f', 'l', '\0' };

/**
 * Fixed header at the beginning of a reflection blob.
 *
 * The offset fields are byte offsets from the start of the blob; the readers in this file treat
 * an offset of 0 as "section not present".
 */
struct ReflectionHeader {
    /** Must compare equal to REFLECTION_TAG */
    uint8_t tag[sizeof(REFLECTION_TAG)];
    /** Shader stage; converted to ShaderStageFlagBits by the readers */
    uint16_t type;
    /** Offset of the push constant section */
    uint16_t offsetPushConstants;
    /** Offset of the specialization constant section */
    uint16_t offsetSpecializationConstants;
    /** Offset of the descriptor set section */
    uint16_t offsetDescriptorSets;
    /** Offset of the vertex input section */
    uint16_t offsetInputs;
    /** Offset of the local size section */
    uint16_t offsetLocalSize;
};
553 
/**
 * Scope guard: runs the initializer in the constructor and the deinitializer in the destructor.
 *
 * The guard is non-copyable so that the deinitializer runs exactly once. Empty callables are
 * skipped instead of throwing std::bad_function_call (which would terminate the program when
 * raised from the destructor).
 */
class scope {
private:
    std::function<void()> init;
    std::function<void()> deinit;

public:
    /**
     * @param initializer Callable invoked immediately; may be empty.
     * @param deinitalizer Callable invoked when the guard is destroyed; may be empty.
     */
    scope(std::function<void()> initializer, std::function<void()> deinitalizer)
        : init(std::move(initializer)), deinit(std::move(deinitalizer))
    {
        if (init) {
            init();
        }
    }

    // Copying would run the deinitializer once per copy.
    scope(const scope&) = delete;
    scope& operator=(const scope&) = delete;

    ~scope()
    {
        if (deinit) {
            deinit();
        }
    }
};
571 
IsValid() const572 bool ShaderReflectionData::IsValid() const
573 {
574     if (reflectionData.size() < sizeof(ReflectionHeader)) {
575         return false;
576     }
577     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
578     return memcmp(header.tag, REFLECTION_TAG, sizeof(REFLECTION_TAG)) == 0;
579 }
580 
GetStageFlags() const581 ShaderStageFlags ShaderReflectionData::GetStageFlags() const
582 {
583     ShaderStageFlags flags;
584     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
585     flags = static_cast<ShaderStageFlagBits>(header.type);
586     return flags;
587 }
588 
GetPipelineLayout() const589 PipelineLayout ShaderReflectionData::GetPipelineLayout() const
590 {
591     PipelineLayout pipelineLayout;
592     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
593     if (header.offsetPushConstants && header.offsetPushConstants < reflectionData.size()) {
594         auto ptr = reflectionData.data() + header.offsetPushConstants;
595         const auto constants = *ptr;
596         if (constants) {
597             pipelineLayout.pushConstant.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
598             pipelineLayout.pushConstant.byteSize = static_cast<uint32_t>(*(ptr + 1) | (*(ptr + 2) << 8));
599         }
600     }
601     if (header.offsetDescriptorSets && header.offsetDescriptorSets < reflectionData.size()) {
602         auto ptr = reflectionData.data() + header.offsetDescriptorSets;
603         pipelineLayout.descriptorSetCount = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
604         ptr += 2;
605         for (auto i = 0u; i < pipelineLayout.descriptorSetCount; ++i) {
606             // write to correct set location
607             const uint32_t set = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
608             assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
609             auto& layout = pipelineLayout.descriptorSetLayouts[set];
610             layout.set = set;
611             ptr += 2;
612             const auto bindings = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
613             ptr += 2;
614             for (auto j = 0u; j < bindings; ++j) {
615                 DescriptorSetLayoutBinding binding;
616                 binding.binding = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
617                 ptr += 2;
618                 binding.descriptorType = static_cast<DescriptorType>(*ptr | (*(ptr + 1) << 8));
619                 ptr += 2;
620                 binding.descriptorCount = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
621                 ptr += 2;
622                 binding.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
623                 layout.bindings.push_back(binding);
624             }
625         }
626     }
627     return pipelineLayout;
628 }
629 
GetSpecializationConstants() const630 std::vector<ShaderSpecializationConstant> ShaderReflectionData::GetSpecializationConstants() const
631 {
632     std::vector<ShaderSpecializationConstant> constants;
633     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
634     if (header.offsetSpecializationConstants && header.offsetSpecializationConstants < reflectionData.size()) {
635         auto ptr = reflectionData.data() + header.offsetSpecializationConstants;
636         const auto size = *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24;
637         ptr += 4;
638         for (auto i = 0; i < size; ++i) {
639             ShaderSpecializationConstant constant;
640             constant.shaderStage = static_cast<ShaderStageFlagBits>(header.type);
641             constant.id = static_cast<uint32_t>(*ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
642             ptr += 4;
643             constant.type = static_cast<ShaderSpecializationConstant::Type>(
644                 *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
645             ptr += 4;
646             constant.offset = 0;
647             constants.push_back(constant);
648         }
649     }
650     return constants;
651 }
652 
GetInputDescriptions() const653 std::vector<VertexInputDeclaration::VertexInputAttributeDescription> ShaderReflectionData::GetInputDescriptions() const
654 {
655     std::vector<VertexInputDeclaration::VertexInputAttributeDescription> inputs;
656     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
657     if (header.offsetInputs && header.offsetInputs < reflectionData.size()) {
658         auto ptr = reflectionData.data() + header.offsetInputs;
659         const auto size = *(ptr) | (*(ptr + 1) << 8);
660         ptr += 2;
661         for (auto i = 0; i < size; ++i) {
662             VertexInputDeclaration::VertexInputAttributeDescription desc;
663             desc.location = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
664             ptr += 2;
665             desc.binding = desc.location;
666             desc.format = static_cast<Format>(*(ptr) | (*(ptr + 1) << 8));
667             ptr += 2;
668             desc.offset = 0;
669             inputs.push_back(desc);
670         }
671     }
672     return inputs;
673 }
674 
GetLocalSize() const675 UVec3 ShaderReflectionData::GetLocalSize() const
676 {
677     UVec3 sizes;
678     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
679     if (header.offsetLocalSize && header.offsetLocalSize < reflectionData.size()) {
680         auto ptr = reflectionData.data() + header.offsetLocalSize;
681         sizes.x = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
682         ptr += 4;
683         sizes.y = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
684         ptr += 4;
685         sizes.z = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
686         ptr += 4;
687     }
688     return sizes;
689 }
690 
readFileToString(std::string_view aFilename)691 std::string readFileToString(std::string_view aFilename)
692 {
693     std::stringstream ss;
694     std::ifstream file;
695 
696     file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
697     try {
698         file.open(aFilename.data(), std::ios::in);
699 
700         if (!file.fail()) {
701             ss << file.rdbuf();
702             return ss.str();
703         }
704     } catch (std::exception const& ex) {
705         LUME_LOG_E("Error reading file: '%s': %s", aFilename.data(), ex.what());
706         return {};
707     }
708     return {};
709 }
710 
711 class FileIncluder : public glslang::TShader::Includer {
712 public:
713     const CompilationSettings& settings;
FileIncluder(const CompilationSettings & compilationSettings)714     FileIncluder(const CompilationSettings& compilationSettings) : settings(compilationSettings) {}
715 
716 private:
include(const char * headerName,const char * includerName,size_t inclusionDepth,bool relative)717     virtual IncludeResult* include(
718         const char* headerName, const char* includerName, size_t inclusionDepth, bool relative)
719     {
720         std::filesystem::path path;
721         bool found = false;
722         if (relative == true) {
723             path.append(settings.shaderSourcePath.c_str());
724             path.append(includerName);
725             path = path.parent_path();
726             path.append(headerName);
727             found = std::filesystem::exists(path);
728         }
729 
730         for (int i = 0; i < settings.shaderIncludePaths.size() && found == false; ++i) {
731             path.assign(settings.shaderIncludePaths[i]);
732             path.append(headerName);
733             found = std::filesystem::exists(path);
734         }
735 
736         if (found == true) {
737             auto str = path.string();
738 
739             std::ifstream file(path);
740             file.seekg(0, file.end);
741             std::streampos length = file.tellg();
742             file.seekg(0, file.beg);
743 
744             char* memory = new (std::nothrow) char[length + std::streampos(1)];
745             if (memory == 0) {
746                 return nullptr;
747             }
748 
749             char* last = std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), memory);
750             IncludeResult* result = new (std::nothrow) IncludeResult(str, memory, std::distance(memory, last), 0);
751             if (result == 0) {
752                 delete memory;
753                 return nullptr;
754             }
755 
756             return result;
757         }
758 
759         return nullptr;
760     }
761 
includeSystem(const char * headerName,const char * includerName,size_t inclusionDepth)762     virtual IncludeResult* includeSystem(const char* headerName, const char* includerName, size_t inclusionDepth)
763     {
764         return include(headerName, includerName, inclusionDepth, false);
765     }
766 
includeLocal(const char * headerName,const char * includerName,size_t inclusionDepth)767     virtual IncludeResult* includeLocal(const char* headerName, const char* includerName, size_t inclusionDepth)
768     {
769         return include(headerName, includerName, inclusionDepth, true);
770     }
771 
releaseInclude(IncludeResult * include)772     virtual void releaseInclude(IncludeResult* include)
773     {
774         delete include;
775     }
776 };
777 
ToSpirVVersion(glslang::EShTargetClientVersion env_version)778 glslang::EShTargetLanguageVersion ToSpirVVersion(glslang::EShTargetClientVersion env_version)
779 {
780     if (env_version == glslang::EShTargetVulkan_1_0) {
781         return glslang::EShTargetSpv_1_0;
782     } else if (env_version == glslang::EShTargetVulkan_1_1) {
783         return glslang::EShTargetSpv_1_3;
784     } else if (env_version == glslang::EShTargetVulkan_1_2) {
785         return glslang::EShTargetSpv_1_5;
786 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
787     } else if (env_version == glslang::EShTargetVulkan_1_3) {
788         return glslang::EShTargetSpv_1_6;
789 #endif
790     } else {
791         return glslang::EShTargetSpv_1_0;
792     }
793 }
794 
preProcessShader(std::string_view aSource,ShaderKind aKind,std::string_view aSourceName,const CompilationSettings & settings)795 std::string preProcessShader(
796     std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
797 {
798     glslang::EShTargetLanguageVersion languageVersion;
799     glslang::EShTargetClientVersion version;
800     EShLanguage stage;
801     switch (aKind) {
802         case ShaderKind::VERTEX:
803             stage = EShLanguage::EShLangVertex;
804             break;
805         case ShaderKind::FRAGMENT:
806             stage = EShLanguage::EShLangFragment;
807             break;
808         case ShaderKind::COMPUTE:
809             stage = EShLanguage::EShLangCompute;
810             break;
811         default:
812             LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderKind not recognized");
813             return {};
814     }
815 
816     switch (settings.shaderEnv) {
817         case ShaderEnv::version_vulkan_1_0:
818             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
819             break;
820         case ShaderEnv::version_vulkan_1_1:
821             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
822             break;
823         case ShaderEnv::version_vulkan_1_2:
824             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
825             break;
826 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
827         case ShaderEnv::version_vulkan_1_3:
828             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
829             break;
830 #endif
831         default:
832             LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderEnv not recognized");
833             return {};
834     }
835 
836     languageVersion = ToSpirVVersion(version);
837 
838     FileIncluder includer(settings);
839     glslang::TShader shader(stage);
840     const char* shader_strings = aSource.data();
841     const int shader_lengths = static_cast<int>(aSource.size());
842     const char* string_names = aSourceName.data();
843     std::string_view preamble = "#extension GL_GOOGLE_include_directive : enable\n";
844     shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
845     shader.setPreamble(preamble.data());
846     shader.setEntryPoint("main");
847     shader.setAutoMapBindings(false);
848     shader.setAutoMapLocations(false);
849     shader.setShiftImageBinding(0);
850     shader.setShiftSamplerBinding(0);
851     shader.setShiftTextureBinding(0);
852     shader.setShiftUboBinding(0);
853     shader.setShiftSsboBinding(0);
854     shader.setShiftUavBinding(0);
855     shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
856     shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
857     shader.setInvertY(false);
858     shader.setNanMinMaxClamp(false);
859 
860     std::string output;
861     const EShMessages rules =
862         static_cast<EShMessages>(EShMsgOnlyPreprocessor | EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
863     if (!shader.preprocess(
864             &kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules, &output, includer)) {
865         LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
866         LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
867 
868         output = { output.begin() + preamble.size(), output.end() };
869         return {};
870     }
871 
872     output = { output.begin() + preamble.size(), output.end() };
873     return output;
874 }
875 
compileShaderToSpirvBinary(std::string_view aSource,ShaderKind aKind,std::string_view aSourceName,const CompilationSettings & settings)876 std::vector<uint32_t> compileShaderToSpirvBinary(
877     std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
878 {
879     glslang::EShTargetLanguageVersion languageVersion;
880     glslang::EShTargetClientVersion version;
881     EShLanguage stage;
882     switch (aKind) {
883         case ShaderKind::VERTEX:
884             stage = EShLanguage::EShLangVertex;
885             break;
886         case ShaderKind::FRAGMENT:
887             stage = EShLanguage::EShLangFragment;
888             break;
889         case ShaderKind::COMPUTE:
890             stage = EShLanguage::EShLangCompute;
891             break;
892         default:
893             LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderKind not recognized");
894             return {};
895     }
896 
897     switch (settings.shaderEnv) {
898         case ShaderEnv::version_vulkan_1_0:
899             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
900             break;
901         case ShaderEnv::version_vulkan_1_1:
902             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
903             break;
904         case ShaderEnv::version_vulkan_1_2:
905             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
906             break;
907 #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
908         case ShaderEnv::version_vulkan_1_3:
909             version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
910             break;
911 #endif
912         default:
913             LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderEnv not recognized");
914             return {};
915     }
916 
917     languageVersion = ToSpirVVersion(version);
918 
919     glslang::TShader shader(stage);
920     const char* shader_strings = aSource.data();
921     const int shader_lengths = static_cast<int>(aSource.size());
922     const char* string_names = aSourceName.data();
923     shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
924     shader.setPreamble("#extension GL_GOOGLE_include_directive : enable\n");
925     shader.setEntryPoint("main");
926     shader.setAutoMapBindings(false);
927     shader.setAutoMapLocations(false);
928     shader.setShiftImageBinding(0);
929     shader.setShiftSamplerBinding(0);
930     shader.setShiftTextureBinding(0);
931     shader.setShiftUboBinding(0);
932     shader.setShiftSsboBinding(0);
933     shader.setShiftUavBinding(0);
934     shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
935     shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
936     shader.setInvertY(false);
937     shader.setNanMinMaxClamp(false);
938 
939     const EShMessages rules = static_cast<EShMessages>(EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
940     if (!shader.parse(&kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules)) {
941         LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
942         LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
943         return {};
944     }
945 
946     glslang::TProgram program;
947     program.addShader(&shader);
948     if (!program.link(EShMsgDefault) || !program.mapIO()) {
949         LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoLog());
950         LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoDebugLog());
951         return {};
952     }
953 
954     std::vector<unsigned int> spirv;
955     glslang::SpvOptions spv_options;
956     spv_options.generateDebugInfo = false;
957     spv_options.disableOptimizer = true;
958     spv_options.optimizeSize = false;
959     spv::SpvBuildLogger logger;
960     glslang::TIntermediate* intermediate = program.getIntermediate(stage);
961     glslang::GlslangToSpv(*intermediate, spirv, &logger, &spv_options);
962 
963     const uint32_t shadercGeneratorWord = 13; // From SPIR-V XML Registry
964     const uint32_t generatorWordIndex = 2;    // SPIR-V 2.3: Physical layout
965     assert(spirv.size() > generatorWordIndex);
966     spirv[generatorWordIndex] = (spirv[generatorWordIndex] & 0xffff) | (shadercGeneratorWord << 16u);
967     return spirv;
968 }
969 
processResource(const spirv_cross::Compiler & compiler,const spirv_cross::Resource & resource,ShaderStageFlags shaderStateFlags,DescriptorType type,DescriptorSetLayout * layouts)970 void processResource(const spirv_cross::Compiler& compiler, const spirv_cross::Resource& resource,
971     ShaderStageFlags shaderStateFlags, DescriptorType type, DescriptorSetLayout* layouts)
972 {
973     const uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
974 
975     assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
976     if (set >= PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT) {
977         return;
978     }
979     DescriptorSetLayout& layout = layouts[set];
980     layout.set = set;
981 
982     // Collect bindings.
983     const uint32_t bindingIndex = compiler.get_decoration(resource.id, spv::DecorationBinding);
984     auto& bindings = layout.bindings;
985     if (auto pos = std::find_if(bindings.begin(), bindings.end(),
986             [bindingIndex](const DescriptorSetLayoutBinding& binding) { return binding.binding == bindingIndex; });
987         pos == bindings.end()) {
988         const spirv_cross::SPIRType& spirType = compiler.get_type(resource.type_id);
989 
990         DescriptorSetLayoutBinding binding;
991         binding.binding = bindingIndex;
992         binding.descriptorType = type;
993         binding.descriptorCount = spirType.array.empty() ? 1 : spirType.array[0];
994         binding.shaderStageFlags = shaderStateFlags;
995 
996         bindings.emplace_back(binding);
997     } else {
998         pos->shaderStageFlags |= shaderStateFlags;
999     }
1000 }
1001 
reflectDescriptorSets(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,DescriptorSetLayout * layouts)1002 void reflectDescriptorSets(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1003     ShaderStageFlags shaderStateFlags, DescriptorSetLayout* layouts)
1004 {
1005     for (const auto& ref : resources.sampled_images) {
1006         processResource(compiler, ref, shaderStateFlags, DescriptorType::COMBINED_IMAGE_SAMPLER, layouts);
1007     }
1008 
1009     for (const auto& ref : resources.separate_samplers) {
1010         processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLER, layouts);
1011     }
1012 
1013     for (const auto& ref : resources.separate_images) {
1014         processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLED_IMAGE, layouts);
1015     }
1016 
1017     for (const auto& ref : resources.storage_images) {
1018         processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_IMAGE, layouts);
1019     }
1020 
1021     for (const auto& ref : resources.uniform_buffers) {
1022         processResource(compiler, ref, shaderStateFlags, DescriptorType::UNIFORM_BUFFER, layouts);
1023     }
1024 
1025     for (const auto& ref : resources.storage_buffers) {
1026         processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_BUFFER, layouts);
1027     }
1028 
1029     for (const auto& ref : resources.subpass_inputs) {
1030         processResource(compiler, ref, shaderStateFlags, DescriptorType::INPUT_ATTACHMENT, layouts);
1031     }
1032 
1033     for (const auto& ref : resources.acceleration_structures) {
1034         processResource(compiler, ref, shaderStateFlags, DescriptorType::ACCELERATION_STRUCTURE, layouts);
1035     }
1036 
1037     std::sort(layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT,
1038         [](const DescriptorSetLayout& lhs, const DescriptorSetLayout& rhs) { return (lhs.set < rhs.set); });
1039 
1040     std::for_each(
1041         layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT, [](DescriptorSetLayout& layout) {
1042             std::sort(layout.bindings.begin(), layout.bindings.end(),
1043                 [](const DescriptorSetLayoutBinding& lhs, const DescriptorSetLayoutBinding& rhs) {
1044                     return (lhs.binding < rhs.binding);
1045                 });
1046         });
1047 }
1048 
reflectPushContants(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,PushConstant & pushConstant)1049 void reflectPushContants(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1050     ShaderStageFlags shaderStateFlags, PushConstant& pushConstant)
1051 {
1052     // NOTE: support for only one push constant
1053     if (resources.push_constant_buffers.size() > 0) {
1054         pushConstant.shaderStageFlags |= shaderStateFlags;
1055 
1056         const auto ranges = compiler.get_active_buffer_ranges(resources.push_constant_buffers[0].id);
1057         const uint32_t byteSize = std::accumulate(
1058             ranges.begin(), ranges.end(), 0u, [](uint32_t byteSize, const spirv_cross::BufferRange& range) {
1059                 return byteSize + static_cast<uint32_t>(range.range);
1060             });
1061         pushConstant.byteSize = std::max(pushConstant.byteSize, byteSize);
1062     }
1063 }
1064 
reflectSpecializationConstants(const spirv_cross::Compiler & compiler,ShaderStageFlags shaderStateFlags)1065 std::vector<ShaderSpecializationConstant> reflectSpecializationConstants(
1066     const spirv_cross::Compiler& compiler, ShaderStageFlags shaderStateFlags)
1067 {
1068     std::vector<ShaderSpecializationConstant> specializationConstants;
1069     uint32_t offset = 0;
1070     for (auto const& constant : compiler.get_specialization_constants()) {
1071         if (constant.constant_id < RESERVED_CONSTANT_ID_INDEX) {
1072             const spirv_cross::SPIRConstant& spirvConstant = compiler.get_constant(constant.id);
1073             const auto type = compiler.get_type(spirvConstant.constant_type);
1074             ShaderSpecializationConstant::Type constantType = ShaderSpecializationConstant::Type::INVALID;
1075             if (type.basetype == spirv_cross::SPIRType::Boolean) {
1076                 constantType = ShaderSpecializationConstant::Type::BOOL;
1077             } else if (type.basetype == spirv_cross::SPIRType::UInt) {
1078                 constantType = ShaderSpecializationConstant::Type::UINT32;
1079             } else if (type.basetype == spirv_cross::SPIRType::Int) {
1080                 constantType = ShaderSpecializationConstant::Type::INT32;
1081             } else if (type.basetype == spirv_cross::SPIRType::Float) {
1082                 constantType = ShaderSpecializationConstant::Type::FLOAT;
1083             } else {
1084                 assert(false && "Unhandled specialization constant type");
1085             }
1086             const uint32_t size = spirvConstant.vector_size() * spirvConstant.columns() * sizeof(uint32_t);
1087             specializationConstants.push_back(
1088                 ShaderSpecializationConstant { shaderStateFlags, constant.constant_id, constantType, offset });
1089             offset += size;
1090         }
1091     }
1092     // sorted based on offset due to offset mapping with shader combinations
1093     // NOTE: id and name indexing
1094     std::sort(specializationConstants.begin(), specializationConstants.end(),
1095         [](const auto& lhs, const auto& rhs) { return (lhs.offset < rhs.offset); });
1096 
1097     return specializationConstants;
1098 }
1099 
convertToVertexInputFormat(const spirv_cross::SPIRType & type)1100 Format convertToVertexInputFormat(const spirv_cross::SPIRType& type)
1101 {
1102     using BaseType = spirv_cross::SPIRType::BaseType;
1103 
1104     // ivecn: a vector of signed integers
1105     if (type.basetype == BaseType::Int) {
1106         switch (type.vecsize) {
1107             case 1:
1108                 return Format::R32_SINT;
1109             case 2:
1110                 return Format::R32G32_SINT;
1111             case 3:
1112                 return Format::R32G32B32_SINT;
1113             case 4:
1114                 return Format::R32G32B32A32_SINT;
1115         }
1116     }
1117 
1118     // uvecn: a vector of unsigned integers
1119     if (type.basetype == BaseType::UInt) {
1120         switch (type.vecsize) {
1121             case 1:
1122                 return Format::R32_UINT;
1123             case 2:
1124                 return Format::R32G32_UINT;
1125             case 3:
1126                 return Format::R32G32B32_UINT;
1127             case 4:
1128                 return Format::R32G32B32A32_UINT;
1129         }
1130     }
1131 
1132     // halfn: a vector of half-precision floating-point numbers
1133     if (type.basetype == BaseType::Half) {
1134         switch (type.vecsize) {
1135             case 1:
1136                 return Format::R16_SFLOAT;
1137             case 2:
1138                 return Format::R16G16_SFLOAT;
1139             case 3:
1140                 return Format::R16G16B16_SFLOAT;
1141             case 4:
1142                 return Format::R16G16B16A16_SFLOAT;
1143         }
1144     }
1145 
1146     // vecn: a vector of single-precision floating-point numbers
1147     if (type.basetype == BaseType::Float) {
1148         switch (type.vecsize) {
1149             case 1:
1150                 return Format::R32_SFLOAT;
1151             case 2:
1152                 return Format::R32G32_SFLOAT;
1153             case 3:
1154                 return Format::R32G32B32_SFLOAT;
1155             case 4:
1156                 return Format::R32G32B32A32_SFLOAT;
1157         }
1158     }
1159 
1160     return Format::UNDEFINED;
1161 }
1162 
reflectVertexInputs(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags,std::vector<VertexInputDeclaration::VertexInputAttributeDescription> & vertexInputAttributes)1163 void reflectVertexInputs(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1164     ShaderStageFlags /* shaderStateFlags */,
1165     std::vector<VertexInputDeclaration::VertexInputAttributeDescription>& vertexInputAttributes)
1166 {
1167     std::vector<VertexAttributeInfo> vertexAttributeInfos;
1168 
1169     // Vertex input attributes.
1170     for (auto& attr : resources.stage_inputs) {
1171         const spirv_cross::SPIRType attributeType = compiler.get_type(attr.type_id);
1172 
1173         VertexAttributeInfo info;
1174 
1175         // For now, assume that every vertex attribute comes from it's own binding which equals the location.
1176         info.description.location = compiler.get_decoration(attr.id, spv::DecorationLocation);
1177         info.description.binding = info.description.location;
1178         info.description.format = convertToVertexInputFormat(attributeType);
1179         info.description.offset = 0;
1180 
1181         info.byteSize = attributeType.vecsize * (attributeType.width / 8);
1182 
1183         vertexAttributeInfos.emplace_back(std::move(info));
1184     }
1185 
1186     // Sort input attributes by binding and location.
1187     std::sort(std::begin(vertexAttributeInfos), std::end(vertexAttributeInfos),
1188         [](const VertexAttributeInfo& aA, const VertexAttributeInfo& aB) {
1189             if (aA.description.binding < aB.description.binding) {
1190                 return true;
1191             }
1192 
1193             return aA.description.location < aB.description.location;
1194         });
1195 
1196     // Create final attributes.
1197     if (!vertexAttributeInfos.empty()) {
1198         for (auto& info : vertexAttributeInfos) {
1199             vertexInputAttributes.push_back(info.description);
1200         }
1201     }
1202 }
1203 
/**
 * Appends an integral value to the buffer in little-endian byte order.
 *
 * At most the four lowest bytes are written; types narrower than four bytes write sizeof(T) bytes.
 *
 * @param buffer Byte buffer to append to.
 * @param data Value to serialize.
 */
template<typename T>
void push(std::vector<uint8_t>& buffer, T data)
{
    constexpr unsigned maxBytes = 4;
    constexpr unsigned byteCount = (sizeof(T) < maxBytes) ? static_cast<unsigned>(sizeof(T)) : maxBytes;
    for (unsigned byteIndex = 0; byteIndex < byteCount; ++byteIndex) {
        // Least significant byte first.
        buffer.push_back((data >> (8 * byteIndex)) & 0xff);
    }
}
1218 
reflectSpvBinary(const std::vector<uint32_t> & aBinary,ShaderKind aKind)1219 std::vector<uint8_t> reflectSpvBinary(const std::vector<uint32_t>& aBinary, ShaderKind aKind)
1220 {
1221     const spirv_cross::Compiler compiler(aBinary);
1222 
1223     const auto shaderStateFlags = ShaderStageFlags(aKind);
1224 
1225     const spirv_cross::ShaderResources resources = compiler.get_shader_resources();
1226 
1227     PipelineLayout pipelineLayout;
1228     reflectDescriptorSets(compiler, resources, shaderStateFlags, pipelineLayout.descriptorSetLayouts);
1229     pipelineLayout.descriptorSetCount =
1230         static_cast<uint32_t>(std::count_if(std::begin(pipelineLayout.descriptorSetLayouts),
1231             std::end(pipelineLayout.descriptorSetLayouts), [](const DescriptorSetLayout& layout) {
1232                 return layout.set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT;
1233             }));
1234     reflectPushContants(compiler, resources, shaderStateFlags, pipelineLayout.pushConstant);
1235 
1236     // some additional information mainly for GL
1237     std::vector<Gles::PushConstantReflection> pushConstantReflection;
1238     for (auto& remap : resources.push_constant_buffers) {
1239         const auto& blockType = compiler.get_type(remap.base_type_id);
1240         auto name = compiler.get_name(remap.id);
1241         (void)(blockType);
1242         assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
1243         Gles::ProcessStruct(std::string_view(name.data(), name.size()), 0, compiler, remap.base_type_id,
1244             pushConstantReflection, shaderStateFlags);
1245     }
1246 
1247     auto specializationConstants = reflectSpecializationConstants(compiler, shaderStateFlags);
1248 
1249     // NOTE: this is done for all although the name is 'Vertex'InputAttributes
1250     std::vector<VertexInputDeclaration::VertexInputAttributeDescription> vertexInputAttributes;
1251     reflectVertexInputs(compiler, resources, shaderStateFlags, vertexInputAttributes);
1252 
1253     std::vector<uint8_t> reflection;
1254     reflection.reserve(512u);
1255     static constexpr uint8_t tag[] = { 'r', 'f', 'l', 0 }; // last one is version
1256     uint16_t type = 0;
1257     uint16_t offsetPushConstants = 0;
1258     uint16_t offsetSpecializationConstants = 0;
1259     uint16_t offsetDescriptorSets = 0;
1260     uint16_t offsetInputs = 0;
1261     uint16_t offsetLocalSize = 0;
1262     // tag
1263     {
1264         reflection.insert(reflection.end(), std::begin(tag), std::end(tag));
1265     }
1266     // shader type
1267     {
1268         push(reflection, static_cast<uint16_t>(shaderStateFlags.flags));
1269     }
1270     // offsets
1271     {
1272         reflection.resize(reflection.size() + sizeof(uint16_t) * 5);
1273     }
1274     // push constants
1275     {
1276         offsetPushConstants = static_cast<uint16_t>(reflection.size());
1277         if (pipelineLayout.pushConstant.byteSize) {
1278             reflection.push_back(1);
1279             push(reflection, static_cast<uint16_t>(pipelineLayout.pushConstant.byteSize));
1280 
1281             push(reflection, static_cast<uint8_t>(pushConstantReflection.size()));
1282             for (const auto& refl : pushConstantReflection) {
1283                 push(reflection, refl.type);
1284                 push(reflection, static_cast<uint16_t>(refl.offset));
1285                 push(reflection, static_cast<uint16_t>(refl.size));
1286                 push(reflection, static_cast<uint16_t>(refl.arraySize));
1287                 push(reflection, static_cast<uint16_t>(refl.arrayStride));
1288                 push(reflection, static_cast<uint16_t>(refl.matrixStride));
1289                 push(reflection, static_cast<uint16_t>(refl.name.size()));
1290                 reflection.insert(reflection.end(), std::begin(refl.name), std::end(refl.name));
1291             }
1292         } else {
1293             reflection.push_back(0);
1294         }
1295     }
1296     // specialization constants
1297     {
1298         offsetSpecializationConstants = static_cast<uint16_t>(reflection.size());
1299         {
1300             const auto size = static_cast<uint32_t>(specializationConstants.size());
1301             push(reflection, static_cast<uint32_t>(specializationConstants.size()));
1302         }
1303         for (auto const& constant : specializationConstants) {
1304             push(reflection, static_cast<uint32_t>(constant.id));
1305             push(reflection, static_cast<uint32_t>(constant.type));
1306         }
1307     }
1308     // descriptor sets
1309     {
1310         offsetDescriptorSets = static_cast<uint16_t>(reflection.size());
1311         {
1312             push(reflection, static_cast<uint16_t>(pipelineLayout.descriptorSetCount));
1313         }
1314         auto begin = std::begin(pipelineLayout.descriptorSetLayouts);
1315         auto end = begin;
1316         std::advance(end, pipelineLayout.descriptorSetCount);
1317         std::for_each(begin, end, [&reflection](const DescriptorSetLayout& layout) {
1318             push(reflection, static_cast<uint16_t>(layout.set));
1319             push(reflection, static_cast<uint16_t>(layout.bindings.size()));
1320             for (const auto& binding : layout.bindings) {
1321                 push(reflection, static_cast<uint16_t>(binding.binding));
1322                 push(reflection, static_cast<uint16_t>(binding.descriptorType));
1323                 push(reflection, static_cast<uint16_t>(binding.descriptorCount));
1324             }
1325         });
1326     }
1327     // inputs
1328     {
1329         offsetInputs = static_cast<uint16_t>(reflection.size());
1330         const auto size = static_cast<uint16_t>(vertexInputAttributes.size());
1331         push(reflection, size);
1332         for (const auto& input : vertexInputAttributes) {
1333             push(reflection, static_cast<uint16_t>(input.location));
1334             push(reflection, static_cast<uint16_t>(input.format));
1335         }
1336     }
1337     // local size
1338     if (shaderStateFlags & ShaderStageFlagBits::COMPUTE_BIT) {
1339         offsetLocalSize = static_cast<uint16_t>(reflection.size());
1340         uint32_t size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 0);
1341         push(reflection, size);
1342 
1343         size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 1);
1344         push(reflection, size);
1345 
1346         size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 2);
1347         push(reflection, size);
1348     }
1349     // update offsets to real values
1350     {
1351         auto ptr = reflection.data() + (sizeof(tag) + sizeof(type));
1352         *ptr++ = offsetPushConstants & 0xff;
1353         *ptr++ = (offsetPushConstants >> 8) & 0xff;
1354         *ptr++ = offsetSpecializationConstants & 0xff;
1355         *ptr++ = (offsetSpecializationConstants >> 8) & 0xff;
1356         *ptr++ = offsetDescriptorSets & 0xff;
1357         *ptr++ = (offsetDescriptorSets >> 8) & 0xff;
1358         *ptr++ = offsetInputs & 0xff;
1359         *ptr++ = (offsetInputs >> 8) & 0xff;
1360         *ptr++ = offsetLocalSize & 0xff;
1361         *ptr++ = (offsetLocalSize >> 8) & 0xff;
1362     }
1363 
1364     return reflection;
1365 }
1366 
/** Descriptor set and binding index pair of a resource, narrowed to 8 bits each. */
struct Binding {
    /** Descriptor set index. */
    uint8_t set;
    /** Binding index inside the set. */
    uint8_t bind;
};
1371 
get_binding(Gles::CoreCompiler & compiler,spirv_cross::ID id)1372 Binding get_binding(Gles::CoreCompiler& compiler, spirv_cross::ID id)
1373 {
1374     const uint32_t dset = compiler.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
1375     const uint32_t dbind = compiler.get_decoration(id, spv::Decoration::DecorationBinding);
1376     assert(dset < Gles::ResourceLimits::MAX_SETS);
1377     assert(dbind < Gles::ResourceLimits::MAX_BIND_IN_SET);
1378     const uint8_t set = static_cast<uint8_t>(dset);
1379     const uint8_t bind = static_cast<uint8_t>(dbind);
1380     return { set, bind };
1381 }
1382 
SortSets(PipelineLayout & pipelineLayout)1383 void SortSets(PipelineLayout& pipelineLayout)
1384 {
1385     pipelineLayout.descriptorSetCount = 0;
1386     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
1387         DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
1388         if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
1389             pipelineLayout.descriptorSetCount++;
1390             std::sort(currSet.bindings.begin(), currSet.bindings.end(),
1391                 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
1392         }
1393     }
1394 }
1395 
Collect(Gles::CoreCompiler & compiler,const spirv_cross::SmallVector<spirv_cross::Resource> & resources,const uint32_t forceBinding=0)1396 void Collect(Gles::CoreCompiler& compiler, const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
1397     const uint32_t forceBinding = 0)
1398 {
1399     std::string name;
1400 
1401     for (const auto& remap : resources) {
1402         const auto binding = get_binding(compiler, remap.id);
1403 
1404         name.resize(name.capacity() - 1);
1405         const auto nameLen = sprintf(name.data(), "s%u_b%u", binding.set, binding.bind);
1406         name.resize(nameLen);
1407 
1408         // if name is empty it's a block and we need to rename the base_type_id i.e.
1409         // uniform <base_type_id> { vec4 foo; } <id>;
1410         if (auto origname = compiler.get_name(remap.id); origname.empty()) {
1411             compiler.set_name(remap.base_type_id, name);
1412             name.insert(name.begin(), '_');
1413             compiler.set_name(remap.id, name);
1414         } else {
1415             // uniform <id> vec4 foo;
1416             compiler.set_name(remap.id, name);
1417         }
1418 
1419         compiler.unset_decoration(remap.id, spv::DecorationDescriptorSet);
1420         compiler.unset_decoration(remap.id, spv::DecorationBinding);
1421         if (forceBinding > 0) {
1422             compiler.set_decoration(
1423                 remap.id, spv::DecorationBinding, forceBinding - 1); // will be over-written later. (special handling)
1424         }
1425     }
1426 }
1427 
/** GLES specific data gathered for a shader module. */
struct ShaderModulePlatformDataGLES {
    /** Push constant reflection entries. */
    std::vector<Gles::PushConstantReflection> infos;
};
1431 
CollectRes(Gles::CoreCompiler & compiler,const spirv_cross::ShaderResources & res,ShaderModulePlatformDataGLES & plat_)1432 void CollectRes(
1433     Gles::CoreCompiler& compiler, const spirv_cross::ShaderResources& res, ShaderModulePlatformDataGLES& plat_)
1434 {
1435     // collect names for later linkage
1436     static constexpr uint32_t defaultBinding = 11;
1437     Collect(compiler, res.storage_buffers, defaultBinding + 1);
1438     Collect(compiler, res.storage_images, defaultBinding + 1);
1439     Collect(compiler, res.uniform_buffers, 0); // 0 == remove binding decorations (let's the compiler decide)
1440     Collect(compiler, res.subpass_inputs, 0);  // 0 == remove binding decorations (let's the compiler decide)
1441 
1442     // handle the real sampled images.
1443     Collect(compiler, res.sampled_images, 0); // 0 == remove binding decorations (let's the compiler decide)
1444 
1445     // and now the "generated ones" (separate image/sampler)
1446     std::string imageName;
1447     std::string samplerName;
1448     std::string temp;
1449     for (auto& remap : compiler.get_combined_image_samplers()) {
1450         const auto imageBinding = get_binding(compiler, remap.image_id);
1451         {
1452             imageName.resize(imageName.capacity() - 1);
1453             const auto nameLen = sprintf(imageName.data(), "s%u_b%u", imageBinding.set, imageBinding.bind);
1454             imageName.resize(nameLen);
1455         }
1456         const auto samplerBinding = get_binding(compiler, remap.sampler_id);
1457         {
1458             samplerName.resize(samplerName.capacity() - 1);
1459             const auto nameLen = sprintf(samplerName.data(), "s%u_b%u", samplerBinding.set, samplerBinding.bind);
1460             samplerName.resize(nameLen);
1461         }
1462 
1463         temp.reserve(imageName.size() + samplerName.size() + 1);
1464         temp.clear();
1465         temp.append(imageName);
1466         temp.append(1, '_');
1467         temp.append(samplerName);
1468         compiler.set_name(remap.combined_id, temp);
1469     }
1470 }
1471 
/** Device backend type, selects the target of the cross compiled shader source */
enum class DeviceBackendType {
    /** Vulkan backend */
    VULKAN,
    /** GLES backend */
    OPENGLES,
    /** OpenGL backend */
    OPENGL
};
1481 
SetupSpirvCross(ShaderStageFlags stage,Gles::CoreCompiler * compiler,DeviceBackendType backend,bool ovrEnabled)1482 void SetupSpirvCross(ShaderStageFlags stage, Gles::CoreCompiler* compiler, DeviceBackendType backend, bool ovrEnabled)
1483 {
1484     spirv_cross::CompilerGLSL::Options options;
1485 
1486     if (backend == DeviceBackendType::OPENGLES) {
1487         options.version = 320;
1488         options.es = true;
1489         options.fragment.default_float_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1490         options.fragment.default_int_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1491     }
1492 
1493     if (backend == DeviceBackendType::OPENGL) {
1494         options.version = 450;
1495         options.es = false;
1496     }
1497 
1498 #if defined(CORE_USE_SEPARATE_SHADER_OBJECTS) && (CORE_USE_SEPARATE_SHADER_OBJECTS == 1)
1499     if (stage & (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT)) {
1500         options.separate_shader_objects = true;
1501     }
1502 #endif
1503 
1504     options.ovr_multiview_view_count = ovrEnabled ? 1 : 0;
1505 
1506     compiler->set_common_options(options);
1507 }
1508 
/** Input and output state for cross compiling one shader module. */
struct Shader {
    /** Stage(s) of the shader. */
    ShaderStageFlags shaderStageFlags_;
    /** Backend the source is generated for. */
    DeviceBackendType backend_;
    /** Push constant reflection gathered during processing. */
    ShaderModulePlatformDataGLES plat_;
    /** Enables OVR multiview (passed on to SetupSpirvCross). */
    bool ovrEnabled;

    /** Generated shader source. */
    std::string source_;
};
1517 
ProcessShaderModule(Shader & me,const ShaderModuleCreateInfo & createInfo)1518 void ProcessShaderModule(Shader& me, const ShaderModuleCreateInfo& createInfo)
1519 {
1520     // perform reflection.
1521     auto compiler = Gles::CoreCompiler(reinterpret_cast<const uint32_t*>(createInfo.spvData.data()),
1522         static_cast<uint32_t>(createInfo.spvData.size() / sizeof(uint32_t)));
1523     // Set some options.
1524     SetupSpirvCross(me.shaderStageFlags_, &compiler, me.backend_, me.ovrEnabled);
1525 
1526     // first step in converting CORE_FLIP_NDC to regular uniform. (specializationconstant -> constant) this makes the
1527     // compiled glsl more readable, and simpler to post process later.
1528     Gles::ConvertSpecConstToConstant(compiler, "CORE_FLIP_NDC");
1529     // const auto& res = compiler.get_shader_resources();
1530 
1531     auto active = compiler.get_active_interface_variables();
1532     const auto& res = compiler.get_shader_resources(active);
1533     compiler.set_enabled_interface_variables(std::move(active));
1534 
1535     Gles::ReflectPushConstants(compiler, res, me.plat_.infos, me.shaderStageFlags_);
1536     compiler.build_combined_image_samplers();
1537     CollectRes(compiler, res, me.plat_);
1538 
1539     // set "CORE_BACKEND_TYPE" specialization to 1.
1540     Gles::SetSpecMacro(compiler, "CORE_BACKEND_TYPE", 1U);
1541 
1542     me.source_ = compiler.compile();
1543     Gles::ConvertConstantToUniform(compiler, me.source_, "CORE_FLIP_NDC");
1544 }
1545 
1546 template<typename T>
writeToFile(const array_view<T> & data,std::filesystem::path aDestinationFile)1547 bool writeToFile(const array_view<T>& data, std::filesystem::path aDestinationFile)
1548 {
1549     std::ofstream outputStream(aDestinationFile, std::ios::out | std::ios::binary);
1550     if (outputStream.is_open()) {
1551         outputStream.write(reinterpret_cast<const char*>(data.data()), data.size() * sizeof(T));
1552         outputStream.close();
1553         return true;
1554     } else {
1555         LUME_LOG_E("Could not write file: '%s'", aDestinationFile.string().c_str());
1556         return false;
1557     }
1558 }
1559 
runAllCompilationStages(std::string_view inputFilename,CompilationSettings & settings)1560 bool runAllCompilationStages(std::string_view inputFilename, CompilationSettings& settings)
1561 {
1562     try {
1563         // std::string inputFilename = aFile;
1564         const std::filesystem::path relativeInputFilename =
1565             std::filesystem::relative(inputFilename, settings.shaderSourcePath);
1566         const std::string relativeFilename = relativeInputFilename.string();
1567         const std::string extension = std::filesystem::path(inputFilename).extension().string();
1568         std::filesystem::path outputFilename = settings.compiledShaderDestinationPath / relativeInputFilename;
1569 
1570         // Make sure the output dir hierarchy exists.
1571         std::filesystem::create_directories(outputFilename.parent_path());
1572 
1573         // Just copying json files to the destination dir.
1574         if (extension == ".json") {
1575             if (!std::filesystem::exists(outputFilename) ||
1576                 !std::filesystem::equivalent(inputFilename, outputFilename)) {
1577                 LUME_LOG_I("  %s", relativeFilename.c_str());
1578                 std::filesystem::copy(inputFilename, outputFilename, std::filesystem::copy_options::overwrite_existing);
1579             }
1580             return true;
1581         } else {
1582             LUME_LOG_I("  %s", relativeFilename.c_str());
1583             outputFilename += ".spv";
1584 
1585             LUME_LOG_V("    input: '%s'", inputFilename.data());
1586             LUME_LOG_V("      dst: '%s'", settings.compiledShaderDestinationPath.string().c_str());
1587             LUME_LOG_V(" relative: '%s'", relativeFilename.c_str());
1588             LUME_LOG_V("   output: '%s'", outputFilename.string().c_str());
1589 
1590             if (std::string shaderSource = readFileToString(inputFilename); !shaderSource.empty()) {
1591                 ShaderKind shaderKind;
1592                 if (extension == ".vert") {
1593                     shaderKind = ShaderKind::VERTEX;
1594                 } else if (extension == ".frag") {
1595                     shaderKind = ShaderKind::FRAGMENT;
1596                 } else if (extension == ".comp") {
1597                     shaderKind = ShaderKind::COMPUTE;
1598                 } else {
1599                     return false;
1600                 }
1601 
1602                 if (std::string preProcessedShader =
1603                         preProcessShader(shaderSource, shaderKind, relativeFilename, settings);
1604                     !preProcessedShader.empty()) {
1605                     if (true) {
1606                         auto reflectionFile = outputFilename;
1607                         reflectionFile += ".pre";
1608                         if (!writeToFile(
1609                                 array_view(preProcessedShader.data(), preProcessedShader.size()), reflectionFile)) {
1610                             LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1611                         }
1612                     }
1613 
1614                     if (std::vector<uint32_t> spvBinary =
1615                             compileShaderToSpirvBinary(preProcessedShader, shaderKind, relativeFilename, settings);
1616                         !spvBinary.empty()) {
1617                         const auto reflection = reflectSpvBinary(spvBinary, shaderKind);
1618                         if (reflection.empty()) {
1619                             LUME_LOG_E("Failed to reflect %s", inputFilename.data());
1620                         } else {
1621                             auto reflectionFile = outputFilename;
1622                             reflectionFile += ".lsb";
1623                             if (!writeToFile(array_view(reflection.data(), reflection.size()), reflectionFile)) {
1624                                 LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1625                             }
1626                         }
1627 
1628                         if (settings.optimizer) {
1629                             // spirv-opt resets the passes everytime so then need to be setup
1630                             settings.optimizer->RegisterPerformancePasses();
1631                             if (!settings.optimizer->Run(spvBinary.data(), spvBinary.size(), &spvBinary)) {
1632                                 LUME_LOG_E("Failed to optimize %s", inputFilename.data());
1633                             }
1634                         }
1635 
1636                         if (writeToFile(array_view(spvBinary.data(), spvBinary.size()), outputFilename)) {
1637                             LUME_LOG_D("  -> %s", outputFilename.string().c_str());
1638                             auto glFile = outputFilename;
1639                             glFile += ".gl";
1640                             try {
1641                                 bool multiviewEnabled = false;
1642                                 if (shaderKind == ShaderKind::VERTEX) {
1643                                     static constexpr const std::string_view multiview = "GL_EXT_multiview";
1644                                     for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1645                                          pos = shaderSource.find(multiview, pos + multiview.size())) {
1646                                         if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1647                                             (shaderSource.find("enabled", pos + multiview.size()) !=
1648                                                 std::string::npos)) {
1649                                             multiviewEnabled = true;
1650                                             break;
1651                                         }
1652                                     }
1653                                 }
1654                                 Shader shader;
1655                                 shader.shaderStageFlags_ =
1656                                     shaderKind == ShaderKind::VERTEX
1657                                         ? ShaderStageFlagBits::VERTEX_BIT
1658                                         : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1659                                                                               : ShaderStageFlagBits::COMPUTE_BIT);
1660 
1661                                 shader.backend_ = DeviceBackendType::OPENGL;
1662                                 shader.ovrEnabled = multiviewEnabled;
1663                                 ShaderModuleCreateInfo info;
1664                                 info.shaderStageFlags =
1665                                     shaderKind == ShaderKind::VERTEX
1666                                         ? ShaderStageFlagBits::VERTEX_BIT
1667                                         : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1668                                                                               : ShaderStageFlagBits::COMPUTE_BIT);
1669                                 info.spvData =
1670                                     array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1671                                         spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1672                                 info.reflectionData.reflectionData =
1673                                     array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1674                                         reflection.size() * sizeof(decltype(reflection)::value_type));
1675                                 ProcessShaderModule(shader, info);
1676                                 writeToFile(array_view(static_cast<const uint8_t*>(
1677                                                            static_cast<const void*>(shader.source_.data())),
1678                                                 shader.source_.size()),
1679                                     glFile);
1680                             } catch (std::exception const& e) {
1681                                 LUME_LOG_E("Failed to generate GL shader for '%s': %s", inputFilename.data(), e.what());
1682                             }
1683 
1684                             auto glesFile = glFile;
1685                             glesFile += "es";
1686                             try {
1687                                 bool multiviewEnabled = false;
1688                                 if (shaderKind == ShaderKind::VERTEX) {
1689                                     static constexpr const std::string_view multiview = "GL_EXT_multiview";
1690                                     for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1691                                          pos = shaderSource.find(multiview, pos + multiview.size())) {
1692                                         if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1693                                             (shaderSource.find("enabled", pos + multiview.size()) !=
1694                                                 std::string::npos)) {
1695                                             multiviewEnabled = true;
1696                                             break;
1697                                         }
1698                                     }
1699                                 }
1700                                 Shader shader;
1701                                 shader.shaderStageFlags_ =
1702                                     shaderKind == ShaderKind::VERTEX
1703                                         ? ShaderStageFlagBits::VERTEX_BIT
1704                                         : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1705                                                                               : ShaderStageFlagBits::COMPUTE_BIT);
1706 
1707                                 shader.backend_ = DeviceBackendType::OPENGLES;
1708                                 shader.ovrEnabled = multiviewEnabled;
1709                                 ShaderModuleCreateInfo info;
1710                                 info.shaderStageFlags =
1711                                     shaderKind == ShaderKind::VERTEX
1712                                         ? ShaderStageFlagBits::VERTEX_BIT
1713                                         : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1714                                                                               : ShaderStageFlagBits::COMPUTE_BIT);
1715                                 info.spvData =
1716                                     array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1717                                         spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1718                                 info.reflectionData.reflectionData =
1719                                     array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1720                                         reflection.size() * sizeof(decltype(reflection)::value_type));
1721                                 ProcessShaderModule(shader, info);
1722                                 writeToFile(array_view(static_cast<const uint8_t*>(
1723                                                            static_cast<const void*>(shader.source_.data())),
1724                                                 shader.source_.size()),
1725                                     glesFile);
1726                             } catch (std::exception const& e) {
1727                                 LUME_LOG_E(
1728                                     "Failed to generate GLES shader for '%s': %s", inputFilename.data(), e.what());
1729                             }
1730 
1731                             return true;
1732                         }
1733                     }
1734                 }
1735             }
1736         }
1737     } catch (std::exception const& e) {
1738         LUME_LOG_E("Processing file failed '%s': %s", inputFilename.data(), e.what());
1739     }
1740     return false;
1741 }
1742 
/**
 * Prints the command line help to the standard output.
 *
 * The text lists every option parsed by main.
 */
void show_usage()
{
    std::cout << "LumeShaderCompiler - Supported shader types: vertex (.vert), fragment (.frag), compute (.comp)\n\n"
                 "How to use: \n"
                 "LumeShaderCompiler.exe --source <source path> (sets destination path to same as source)\n"
                 "LumeShaderCompiler.exe --source <source path> --destination <destination path>\n"
                 "LumeShaderCompiler.exe --monitor (monitors changes in the source files)\n"
                 "\n"
                 "Other options: \n"
                 "  --sourceFile <file>   compiles a single file (source path is the directory of the file)\n"
                 "  --include <path>      adds an include search path, can be given several times\n"
                 "  --optimize            runs the spirv-opt performance passes on the generated SPIR-V\n"
                 "  --vulkan <version>    target environment: 1.0 (default), 1.1, 1.2 or 1.3\n"
                 "                        (1.3 needs glslang 12.2.0 or newer)\n"
                 "  --help                shows this text\n";
}
1751 
/**
 * Picks the file names whose extension is one of the given extensions.
 *
 * The extension of each file name is converted to lower case before it is compared, the entries of
 * aIncludeExtensions are compared as they are (so they should be given in lower case, including the dot).
 *
 * @param aFilenames File names to filter.
 * @param aIncludeExtensions Accepted extensions, e.g. ".vert".
 * @return The accepted file names, unmodified and in their original order.
 */
std::vector<std::string> filterByExtension(
    const std::vector<std::string>& aFilenames, const std::vector<std::string_view>& aIncludeExtensions)
{
    std::vector<std::string> filtered;
    for (auto const& file : aFilenames) {
        std::string lowercaseFileExt = std::filesystem::path(file).extension().string();
        // tolower is only defined for values representable as unsigned char (and EOF). Passing a plain char
        // directly is undefined behaviour for non-ASCII bytes when char is signed.
        std::transform(lowercaseFileExt.begin(), lowercaseFileExt.end(), lowercaseFileExt.begin(),
            [](unsigned char character) { return static_cast<char>(tolower(character)); });

        const bool included = std::any_of(aIncludeExtensions.cbegin(), aIncludeExtensions.cend(),
            [&lowercaseFileExt](std::string_view extension) { return lowercaseFileExt == extension; });
        if (included) {
            filtered.push_back(file);
        }
    }

    return filtered;
}
1770 
main(int argc,char * argv[])1771 int main(int argc, char* argv[])
1772 {
1773     if (argc == 1) {
1774         show_usage();
1775         return 0;
1776     }
1777 
1778     std::filesystem::path const currentFolder = std::filesystem::current_path();
1779     std::filesystem::path shaderSourcesPath = currentFolder;
1780     std::filesystem::path compiledShaderDestinationPath;
1781     std::vector<std::filesystem::path> shaderIncludePaths;
1782     std::filesystem::path sourceFile;
1783 
1784     bool monitorChanges = false;
1785     bool optimizeSpirv = false;
1786     ShaderEnv envVersion = ShaderEnv::version_vulkan_1_0;
1787     for (int i = 1; i < argc; ++i) {
1788         const auto arg = std::string_view(argv[i]);
1789         if (arg == "--help") {
1790             show_usage();
1791             return 0;
1792         } else if (arg == "--sourceFile") {
1793             if (i + 1 < argc) {
1794                 sourceFile = argv[++i];
1795                 sourceFile.make_preferred();
1796                 shaderSourcesPath = sourceFile;
1797                 shaderSourcesPath.remove_filename();
1798                 if (compiledShaderDestinationPath.empty()) {
1799                     compiledShaderDestinationPath = shaderSourcesPath;
1800                 }
1801             } else {
1802                 LUME_LOG_E("--sourceFile option requires one argument.\n");
1803                 return 1;
1804             }
1805         } else if (arg == "--source") {
1806             if (i + 1 < argc) {
1807                 shaderSourcesPath = argv[++i];
1808                 shaderSourcesPath.make_preferred();
1809                 if (compiledShaderDestinationPath.empty()) {
1810                     compiledShaderDestinationPath = shaderSourcesPath;
1811                 }
1812             } else {
1813                 LUME_LOG_E("--source option requires one argument.");
1814                 return 1;
1815             }
1816         } else if (arg == "--destination") {
1817             if (i + 1 < argc) {
1818                 compiledShaderDestinationPath = argv[++i];
1819                 compiledShaderDestinationPath.make_preferred();
1820             } else {
1821                 LUME_LOG_E("--destination option requires one argument.");
1822                 return 1;
1823             }
1824         } else if (arg == "--include") {
1825             if (i + 1 < argc) {
1826                 shaderIncludePaths.emplace_back(argv[++i]).make_preferred();
1827             } else {
1828                 LUME_LOG_E("--include option requires one argument.");
1829                 return 1;
1830             }
1831 
1832         } else if (arg == "--monitor") {
1833             monitorChanges = true;
1834         } else if (arg == "--optimize") {
1835             optimizeSpirv = true;
1836         } else if (arg == "--vulkan") {
1837             if (i + 1 < argc) {
1838                 const auto version = std::string_view(argv[++i]);
1839                 if (version == "1.0") {
1840                     envVersion = ShaderEnv::version_vulkan_1_0;
1841                 } else if (version == "1.1") {
1842                     envVersion = ShaderEnv::version_vulkan_1_1;
1843                 } else if (version == "1.2") {
1844                     envVersion = ShaderEnv::version_vulkan_1_2;
1845 #ifdef GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
1846                 } else if (version == "1.3") {
1847                     envVersion = ShaderEnv::version_vulkan_1_3;
1848 #endif
1849                 } else {
1850                     LUME_LOG_E("Invalid argument for option --vulkan.");
1851                     return 1;
1852                 }
1853             } else {
1854                 LUME_LOG_E("--vulkan option requires one argument.");
1855                 return 1;
1856             }
1857         }
1858     }
1859 
1860     if (compiledShaderDestinationPath.empty()) {
1861         compiledShaderDestinationPath = currentFolder;
1862     }
1863 
1864     ige::FileMonitor fileMonitor;
1865 
1866     if (!std::filesystem::exists(shaderSourcesPath)) {
1867         LUME_LOG_E("Source path does not exist: '%s'", shaderSourcesPath.string().c_str());
1868         return 1;
1869     }
1870 
1871     // Make sure the destination dir exists.
1872     std::filesystem::create_directories(compiledShaderDestinationPath);
1873 
1874     if (!std::filesystem::exists(compiledShaderDestinationPath)) {
1875         LUME_LOG_E("Destination path does not exist: '%s'", compiledShaderDestinationPath.string().c_str());
1876         return 1;
1877     }
1878 
1879     fileMonitor.AddPath(shaderSourcesPath.string());
1880     std::vector<std::string> fileList = [&]() {
1881         std::vector<std::string> list;
1882         if (!sourceFile.empty()) {
1883             list.push_back(sourceFile.u8string());
1884         } else {
1885             list = fileMonitor.getMonitoredFiles();
1886         }
1887         return list;
1888     }();
1889 
1890     const std::vector<std::string_view> supportedFileTypes = { ".vert", ".frag", ".comp", ".json" };
1891     fileList = filterByExtension(fileList, supportedFileTypes);
1892 
1893     LUME_LOG_I("     Source path: '%s'", std::filesystem::absolute(shaderSourcesPath).string().c_str());
1894     for (auto const& path : shaderIncludePaths) {
1895         LUME_LOG_I("    Include path: '%s'", std::filesystem::absolute(path).string().c_str());
1896     }
1897     LUME_LOG_I("Destination path: '%s'", std::filesystem::absolute(compiledShaderDestinationPath).string().c_str());
1898     LUME_LOG_I("");
1899     LUME_LOG_I("Processing:");
1900 
1901     int errorCount = 0;
1902     scope scope([]() { glslang::InitializeProcess(); }, []() { glslang::FinalizeProcess(); });
1903 
1904     std::vector<std::filesystem::path> searchPath;
1905     searchPath.reserve(searchPath.size() + 1 + shaderIncludePaths.size());
1906     searchPath.emplace_back(shaderSourcesPath.string());
1907     for (auto const& path : shaderIncludePaths) {
1908         searchPath.emplace_back(path.string());
1909     }
1910 
1911     auto settings =
1912         CompilationSettings { envVersion, searchPath, {}, shaderSourcesPath, compiledShaderDestinationPath };
1913 
1914     if (optimizeSpirv) {
1915         spv_target_env targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1916         switch (envVersion) {
1917             case ShaderEnv::version_vulkan_1_0:
1918                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1919                 break;
1920             case ShaderEnv::version_vulkan_1_1:
1921                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_1;
1922                 break;
1923             case ShaderEnv::version_vulkan_1_2:
1924                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_2;
1925                 break;
1926             case ShaderEnv::version_vulkan_1_3:
1927                 targetEnv = spv_target_env::SPV_ENV_VULKAN_1_3;
1928                 break;
1929             default:
1930                 break;
1931         }
1932         settings.optimizer.emplace(targetEnv);
1933     }
1934 
1935     // Startup compilation.
1936     for (auto const& file : fileList) {
1937         std::string relativeFilename = std::filesystem::relative(file, shaderSourcesPath).string();
1938         LUME_LOG_D("Tracked source file: '%s'", relativeFilename.c_str());
1939         if (!runAllCompilationStages(file, settings)) {
1940             errorCount++;
1941         }
1942     }
1943 
1944     if (errorCount == 0) {
1945         LUME_LOG_I("Success.");
1946     } else {
1947         LUME_LOG_E("Failed: %d", errorCount);
1948     }
1949 
1950     if (monitorChanges) {
1951         LUME_LOG_I("Monitoring file changes.");
1952     }
1953 
1954     // Main loop.
1955     while (monitorChanges) {
1956         std::vector<std::string> addedFiles, removedFiles, modifiedFiles;
1957         fileMonitor.scanModifications(addedFiles, removedFiles, modifiedFiles);
1958         modifiedFiles = filterByExtension(modifiedFiles, supportedFileTypes);
1959 
1960         if (sourceFile.empty()) {
1961             addedFiles = filterByExtension(addedFiles, supportedFileTypes);
1962             removedFiles = filterByExtension(removedFiles, supportedFileTypes);
1963 
1964             if (!addedFiles.empty()) {
1965                 LUME_LOG_I("Files added:");
1966                 for (auto const& addedFile : addedFiles) {
1967                     runAllCompilationStages(addedFile, settings);
1968                 }
1969             }
1970 
1971             if (!removedFiles.empty()) {
1972                 LUME_LOG_I("Files removed:");
1973                 for (auto const& removedFile : removedFiles) {
1974                     std::string relativeFilename = std::filesystem::relative(removedFile, shaderSourcesPath).string();
1975                     LUME_LOG_I("  %s", relativeFilename.c_str());
1976                 }
1977             }
1978 
1979             if (!modifiedFiles.empty()) {
1980                 LUME_LOG_I("Files modified:");
1981                 for (auto const& modifiedFile : modifiedFiles) {
1982                     runAllCompilationStages(modifiedFile, settings);
1983                 }
1984             }
1985         } else if (!modifiedFiles.empty()) {
1986             if (auto pos = std::find_if(modifiedFiles.cbegin(), modifiedFiles.cend(),
1987                     [&sourceFile](const std::string& modified) { return modified == sourceFile; });
1988                 pos != modifiedFiles.cend()) {
1989                 runAllCompilationStages(*pos, settings);
1990             }
1991         }
1992 
1993         std::this_thread::sleep_for(std::chrono::seconds(1));
1994     }
1995 
1996     return errorCount;
1997 }