1  /*
2   * Copyright (C) 2023 Huawei Device Co., Ltd.
3   * Licensed under the Apache License, Version 2.0 (the "License");
4   * you may not use this file except in compliance with the License.
5   * You may obtain a copy of the License at
6   *
7   *     http://www.apache.org/licenses/LICENSE-2.0
8   *
9   * Unless required by applicable law or agreed to in writing, software
10   * distributed under the License is distributed on an "AS IS" BASIS,
11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12   * See the License for the specific language governing permissions and
13   * limitations under the License.
14   */
15  
#include "default_limits.h"
#include <glslang/Public/ShaderLang.h>
#include <SPIRV/GlslangToSpv.h>
#include <SPIRV/SpvTools.h>
#include <spirv-tools/optimizer.hpp>

#include "spirv_cross.hpp"

// #include "preprocess/preprocess.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

#include "io/dev/FileMonitor.h"
#include "lume/Log.h"
#include "shader_type.h"
#include "spirv_cross_helpers_gles.h"
41  
42  using namespace std::chrono_literals;
43  
// Local copies of engine enumerations whose numeric values must stay identical to the engine's definitions:
// Format, DescriptorType, ShaderStageFlagBits.
/**
 * Data format, used here for vertex input attributes (VertexInputAttributeDescription::format).
 *
 * Values read from reflection data are cast directly to this type (ShaderReflectionData::GetInputDescriptions),
 * so existing entries must never be renumbered. The numbering appears to follow VkFormat — confirm before adding
 * entries. The sequence has gaps (e.g. 32-33, 46-47, 59-61, 110-121, 128 and 130 are not declared here).
 */
enum class Format {
    /** Undefined */
    UNDEFINED = 0,
    /** R4G4 UNORM PACK8 */
    R4G4_UNORM_PACK8 = 1,
    /** R4G4B4A4 UNORM PACK16 */
    R4G4B4A4_UNORM_PACK16 = 2,
    /** B4G4R4A4 UNORM PACK16 */
    B4G4R4A4_UNORM_PACK16 = 3,
    /** R5G6B5 UNORM PACK16 */
    R5G6B5_UNORM_PACK16 = 4,
    /** B5G6R5 UNORM PACK16 */
    B5G6R5_UNORM_PACK16 = 5,
    /** R5G5B5A1 UNORM PACK16 */
    R5G5B5A1_UNORM_PACK16 = 6,
    /** B5G5R5A1 UNORM PACK16 */
    B5G5R5A1_UNORM_PACK16 = 7,
    /** A1R5G5B5 UNORM PACK16 */
    A1R5G5B5_UNORM_PACK16 = 8,
    /** R8 UNORM */
    R8_UNORM = 9,
    /** R8 SNORM */
    R8_SNORM = 10,
    /** R8 USCALED */
    R8_USCALED = 11,
    /** R8 SSCALED */
    R8_SSCALED = 12,
    /** R8 UINT */
    R8_UINT = 13,
    /** R8 SINT */
    R8_SINT = 14,
    /** R8 SRGB */
    R8_SRGB = 15,
    /** R8G8 UNORM */
    R8G8_UNORM = 16,
    /** R8G8 SNORM */
    R8G8_SNORM = 17,
    /** R8G8 USCALED */
    R8G8_USCALED = 18,
    /** R8G8 SSCALED */
    R8G8_SSCALED = 19,
    /** R8G8 UINT */
    R8G8_UINT = 20,
    /** R8G8 SINT */
    R8G8_SINT = 21,
    /** R8G8 SRGB */
    R8G8_SRGB = 22,
    /** R8G8B8 UNORM */
    R8G8B8_UNORM = 23,
    /** R8G8B8 SNORM */
    R8G8B8_SNORM = 24,
    /** R8G8B8 USCALED */
    R8G8B8_USCALED = 25,
    /** R8G8B8 SSCALED */
    R8G8B8_SSCALED = 26,
    /** R8G8B8 UINT */
    R8G8B8_UINT = 27,
    /** R8G8B8 SINT */
    R8G8B8_SINT = 28,
    /** R8G8B8 SRGB */
    R8G8B8_SRGB = 29,
    /** B8G8R8 UNORM */
    B8G8R8_UNORM = 30,
    /** B8G8R8 SNORM */
    B8G8R8_SNORM = 31,
    /** B8G8R8 UINT */
    B8G8R8_UINT = 34,
    /** B8G8R8 SINT */
    B8G8R8_SINT = 35,
    /** B8G8R8 SRGB */
    B8G8R8_SRGB = 36,
    /** R8G8B8A8 UNORM */
    R8G8B8A8_UNORM = 37,
    /** R8G8B8A8 SNORM */
    R8G8B8A8_SNORM = 38,
    /** R8G8B8A8 USCALED */
    R8G8B8A8_USCALED = 39,
    /** R8G8B8A8 SSCALED */
    R8G8B8A8_SSCALED = 40,
    /** R8G8B8A8 UINT */
    R8G8B8A8_UINT = 41,
    /** R8G8B8A8 SINT */
    R8G8B8A8_SINT = 42,
    /** R8G8B8A8 SRGB */
    R8G8B8A8_SRGB = 43,
    /** B8G8R8A8 UNORM */
    B8G8R8A8_UNORM = 44,
    /** B8G8R8A8 SNORM */
    B8G8R8A8_SNORM = 45,
    /** B8G8R8A8 UINT */
    B8G8R8A8_UINT = 48,
    /** B8G8R8A8 SINT */
    B8G8R8A8_SINT = 49,
    /** B8G8R8A8 SRGB */
    B8G8R8A8_SRGB = 50,
    /** A8B8G8R8 UNORM PACK32 */
    A8B8G8R8_UNORM_PACK32 = 51,
    /** A8B8G8R8 SNORM PACK32 */
    A8B8G8R8_SNORM_PACK32 = 52,
    /** A8B8G8R8 USCALED PACK32 */
    A8B8G8R8_USCALED_PACK32 = 53,
    /** A8B8G8R8 SSCALED PACK32 */
    A8B8G8R8_SSCALED_PACK32 = 54,
    /** A8B8G8R8 UINT PACK32 */
    A8B8G8R8_UINT_PACK32 = 55,
    /** A8B8G8R8 SINT PACK32 */
    A8B8G8R8_SINT_PACK32 = 56,
    /** A8B8G8R8 SRGB PACK32 */
    A8B8G8R8_SRGB_PACK32 = 57,
    /** A2R10G10B10 UNORM PACK32 */
    A2R10G10B10_UNORM_PACK32 = 58,
    /** A2R10G10B10 UINT PACK32 */
    A2R10G10B10_UINT_PACK32 = 62,
    /** A2R10G10B10 SINT PACK32 */
    A2R10G10B10_SINT_PACK32 = 63,
    /** A2B10G10R10 UNORM PACK32 */
    A2B10G10R10_UNORM_PACK32 = 64,
    /** A2B10G10R10 SNORM PACK32 */
    A2B10G10R10_SNORM_PACK32 = 65,
    /** A2B10G10R10 USCALED PACK32 */
    A2B10G10R10_USCALED_PACK32 = 66,
    /** A2B10G10R10 SSCALED PACK32 */
    A2B10G10R10_SSCALED_PACK32 = 67,
    /** A2B10G10R10 UINT PACK32 */
    A2B10G10R10_UINT_PACK32 = 68,
    /** A2B10G10R10 SINT PACK32 */
    A2B10G10R10_SINT_PACK32 = 69,
    /** R16 UNORM */
    R16_UNORM = 70,
    /** R16 SNORM */
    R16_SNORM = 71,
    /** R16 USCALED */
    R16_USCALED = 72,
    /** R16 SSCALED */
    R16_SSCALED = 73,
    /** R16 UINT */
    R16_UINT = 74,
    /** R16 SINT */
    R16_SINT = 75,
    /** R16 SFLOAT */
    R16_SFLOAT = 76,
    /** R16G16 UNORM */
    R16G16_UNORM = 77,
    /** R16G16 SNORM */
    R16G16_SNORM = 78,
    /** R16G16 USCALED */
    R16G16_USCALED = 79,
    /** R16G16 SSCALED */
    R16G16_SSCALED = 80,
    /** R16G16 UINT */
    R16G16_UINT = 81,
    /** R16G16 SINT */
    R16G16_SINT = 82,
    /** R16G16 SFLOAT */
    R16G16_SFLOAT = 83,
    /** R16G16B16 UNORM */
    R16G16B16_UNORM = 84,
    /** R16G16B16 SNORM */
    R16G16B16_SNORM = 85,
    /** R16G16B16 USCALED */
    R16G16B16_USCALED = 86,
    /** R16G16B16 SSCALED */
    R16G16B16_SSCALED = 87,
    /** R16G16B16 UINT */
    R16G16B16_UINT = 88,
    /** R16G16B16 SINT */
    R16G16B16_SINT = 89,
    /** R16G16B16 SFLOAT */
    R16G16B16_SFLOAT = 90,
    /** R16G16B16A16 UNORM */
    R16G16B16A16_UNORM = 91,
    /** R16G16B16A16 SNORM */
    R16G16B16A16_SNORM = 92,
    /** R16G16B16A16 USCALED */
    R16G16B16A16_USCALED = 93,
    /** R16G16B16A16 SSCALED */
    R16G16B16A16_SSCALED = 94,
    /** R16G16B16A16 UINT */
    R16G16B16A16_UINT = 95,
    /** R16G16B16A16 SINT */
    R16G16B16A16_SINT = 96,
    /** R16G16B16A16 SFLOAT */
    R16G16B16A16_SFLOAT = 97,
    /** R32 UINT */
    R32_UINT = 98,
    /** R32 SINT */
    R32_SINT = 99,
    /** R32 SFLOAT */
    R32_SFLOAT = 100,
    /** R32G32 UINT */
    R32G32_UINT = 101,
    /** R32G32 SINT */
    R32G32_SINT = 102,
    /** R32G32 SFLOAT */
    R32G32_SFLOAT = 103,
    /** R32G32B32 UINT */
    R32G32B32_UINT = 104,
    /** R32G32B32 SINT */
    R32G32B32_SINT = 105,
    /** R32G32B32 SFLOAT */
    R32G32B32_SFLOAT = 106,
    /** R32G32B32A32 UINT */
    R32G32B32A32_UINT = 107,
    /** R32G32B32A32 SINT */
    R32G32B32A32_SINT = 108,
    /** R32G32B32A32 SFLOAT */
    R32G32B32A32_SFLOAT = 109,
    /** B10G11R11 UFLOAT PACK32 */
    B10G11R11_UFLOAT_PACK32 = 122,
    /** E5B9G9R9 UFLOAT PACK32 */
    E5B9G9R9_UFLOAT_PACK32 = 123,
    /** D16 UNORM */
    D16_UNORM = 124,
    /** X8 D24 UNORM PACK32 */
    X8_D24_UNORM_PACK32 = 125,
    /** D32 SFLOAT */
    D32_SFLOAT = 126,
    /** S8 UINT */
    S8_UINT = 127,
    /** D24 UNORM S8 UINT */
    D24_UNORM_S8_UINT = 129,
    /** BC1 RGB UNORM BLOCK */
    BC1_RGB_UNORM_BLOCK = 131,
    /** BC1 RGB SRGB BLOCK */
    BC1_RGB_SRGB_BLOCK = 132,
    /** BC1 RGBA UNORM BLOCK */
    BC1_RGBA_UNORM_BLOCK = 133,
    /** BC1 RGBA SRGB BLOCK */
    BC1_RGBA_SRGB_BLOCK = 134,
    /** BC2 UNORM BLOCK */
    BC2_UNORM_BLOCK = 135,
    /** BC2 SRGB BLOCK */
    BC2_SRGB_BLOCK = 136,
    /** BC3 UNORM BLOCK */
    BC3_UNORM_BLOCK = 137,
    /** BC3 SRGB BLOCK */
    BC3_SRGB_BLOCK = 138,
    /** BC4 UNORM BLOCK */
    BC4_UNORM_BLOCK = 139,
    /** BC4 SNORM BLOCK */
    BC4_SNORM_BLOCK = 140,
    /** BC5 UNORM BLOCK */
    BC5_UNORM_BLOCK = 141,
    /** BC5 SNORM BLOCK */
    BC5_SNORM_BLOCK = 142,
    /** BC6H UFLOAT BLOCK */
    BC6H_UFLOAT_BLOCK = 143,
    /** BC6H SFLOAT BLOCK */
    BC6H_SFLOAT_BLOCK = 144,
    /** BC7 UNORM BLOCK */
    BC7_UNORM_BLOCK = 145,
    /** BC7 SRGB BLOCK */
    BC7_SRGB_BLOCK = 146,
    /** ETC2 R8G8B8 UNORM BLOCK */
    ETC2_R8G8B8_UNORM_BLOCK = 147,
    /** ETC2 R8G8B8 SRGB BLOCK */
    ETC2_R8G8B8_SRGB_BLOCK = 148,
    /** ETC2 R8G8B8A1 UNORM BLOCK */
    ETC2_R8G8B8A1_UNORM_BLOCK = 149,
    /** ETC2 R8G8B8A1 SRGB BLOCK */
    ETC2_R8G8B8A1_SRGB_BLOCK = 150,
    /** ETC2 R8G8B8A8 UNORM BLOCK */
    ETC2_R8G8B8A8_UNORM_BLOCK = 151,
    /** ETC2 R8G8B8A8 SRGB BLOCK */
    ETC2_R8G8B8A8_SRGB_BLOCK = 152,
    /** EAC R11 UNORM BLOCK */
    EAC_R11_UNORM_BLOCK = 153,
    /** EAC R11 SNORM BLOCK */
    EAC_R11_SNORM_BLOCK = 154,
    /** EAC R11G11 UNORM BLOCK */
    EAC_R11G11_UNORM_BLOCK = 155,
    /** EAC R11G11 SNORM BLOCK */
    EAC_R11G11_SNORM_BLOCK = 156,
    /** ASTC 4x4 UNORM BLOCK */
    ASTC_4x4_UNORM_BLOCK = 157,
    /** ASTC 4x4 SRGB BLOCK */
    ASTC_4x4_SRGB_BLOCK = 158,
    /** ASTC 5x4 UNORM BLOCK */
    ASTC_5x4_UNORM_BLOCK = 159,
    /** ASTC 5x4 SRGB BLOCK */
    ASTC_5x4_SRGB_BLOCK = 160,
    /** ASTC 5x5 UNORM BLOCK */
    ASTC_5x5_UNORM_BLOCK = 161,
    /** ASTC 5x5 SRGB BLOCK */
    ASTC_5x5_SRGB_BLOCK = 162,
    /** ASTC 6x5 UNORM BLOCK */
    ASTC_6x5_UNORM_BLOCK = 163,
    /** ASTC 6x5 SRGB BLOCK */
    ASTC_6x5_SRGB_BLOCK = 164,
    /** ASTC 6x6 UNORM BLOCK */
    ASTC_6x6_UNORM_BLOCK = 165,
    /** ASTC 6x6 SRGB BLOCK */
    ASTC_6x6_SRGB_BLOCK = 166,
    /** ASTC 8x5 UNORM BLOCK */
    ASTC_8x5_UNORM_BLOCK = 167,
    /** ASTC 8x5 SRGB BLOCK */
    ASTC_8x5_SRGB_BLOCK = 168,
    /** ASTC 8x6 UNORM BLOCK */
    ASTC_8x6_UNORM_BLOCK = 169,
    /** ASTC 8x6 SRGB BLOCK */
    ASTC_8x6_SRGB_BLOCK = 170,
    /** ASTC 8x8 UNORM BLOCK */
    ASTC_8x8_UNORM_BLOCK = 171,
    /** ASTC 8x8 SRGB BLOCK */
    ASTC_8x8_SRGB_BLOCK = 172,
    /** ASTC 10x5 UNORM BLOCK */
    ASTC_10x5_UNORM_BLOCK = 173,
    /** ASTC 10x5 SRGB BLOCK */
    ASTC_10x5_SRGB_BLOCK = 174,
    /** ASTC 10x6 UNORM BLOCK */
    ASTC_10x6_UNORM_BLOCK = 175,
    /** ASTC 10x6 SRGB BLOCK */
    ASTC_10x6_SRGB_BLOCK = 176,
    /** ASTC 10x8 UNORM BLOCK */
    ASTC_10x8_UNORM_BLOCK = 177,
    /** ASTC 10x8 SRGB BLOCK */
    ASTC_10x8_SRGB_BLOCK = 178,
    /** ASTC 10x10 UNORM BLOCK */
    ASTC_10x10_UNORM_BLOCK = 179,
    /** ASTC 10x10 SRGB BLOCK */
    ASTC_10x10_SRGB_BLOCK = 180,
    /** ASTC 12x10 UNORM BLOCK */
    ASTC_12x10_UNORM_BLOCK = 181,
    /** ASTC 12x10 SRGB BLOCK */
    ASTC_12x10_SRGB_BLOCK = 182,
    /** ASTC 12x12 UNORM BLOCK */
    ASTC_12x12_UNORM_BLOCK = 183,
    /** ASTC 12x12 SRGB BLOCK */
    ASTC_12x12_SRGB_BLOCK = 184,
    /** G8B8G8R8 422 UNORM */
    G8B8G8R8_422_UNORM = 1000156000,
    /** B8G8R8G8 422 UNORM */
    B8G8R8G8_422_UNORM = 1000156001,
    /** G8 B8 R8 3PLANE 420 UNORM */
    G8_B8_R8_3PLANE_420_UNORM = 1000156002,
    /** G8 B8R8 2PLANE 420 UNORM */
    G8_B8R8_2PLANE_420_UNORM = 1000156003,
    /** G8 B8 R8 3PLANE 422 UNORM */
    G8_B8_R8_3PLANE_422_UNORM = 1000156004,
    /** G8 B8R8 2PLANE 422 UNORM */
    G8_B8R8_2PLANE_422_UNORM = 1000156005,
    /** Max enumeration (sentinel; also forces a 32-bit value range) */
    MAX_ENUM = 0x7FFFFFFF
};
390  
/**
 * Descriptor type of a descriptor set layout binding.
 *
 * Values are decoded straight from reflection data (ShaderReflectionData::GetPipelineLayout), so the numbering
 * must not change.
 */
enum class DescriptorType {
    SAMPLER = 0,                         ///< Sampler
    COMBINED_IMAGE_SAMPLER = 1,          ///< Combined image sampler
    SAMPLED_IMAGE = 2,                   ///< Sampled image
    STORAGE_IMAGE = 3,                   ///< Storage image
    UNIFORM_TEXEL_BUFFER = 4,            ///< Uniform texel buffer
    STORAGE_TEXEL_BUFFER = 5,            ///< Storage texel buffer
    UNIFORM_BUFFER = 6,                  ///< Uniform buffer
    STORAGE_BUFFER = 7,                  ///< Storage buffer
    UNIFORM_BUFFER_DYNAMIC = 8,          ///< Dynamic uniform buffer
    STORAGE_BUFFER_DYNAMIC = 9,          ///< Dynamic storage buffer
    INPUT_ATTACHMENT = 10,               ///< Input attachment
    ACCELERATION_STRUCTURE = 1000150000, ///< Acceleration structure
    MAX_ENUM = 0x7FFFFFFF                ///< Max enumeration (sentinel)
};
419  
/** Rate at which a vertex buffer binding advances: once per vertex or once per instance. */
enum class VertexInputRate {
    VERTEX = 0,           ///< Advance per vertex
    INSTANCE = 1,         ///< Advance per instance
    MAX_ENUM = 0x7FFFFFFF ///< Max enumeration (sentinel / "unset")
};
429  
/** Compile-time limits and sentinels shared by the pipeline layout structures below. */
struct PipelineLayoutConstants {
    /** Number of descriptor set slots in a pipeline layout (the size of PipelineLayout::descriptorSetLayouts). */
    static constexpr uint32_t MAX_DESCRIPTOR_SET_COUNT = 4u;
    /** Upper bound on dynamic descriptor offsets. */
    static constexpr uint32_t MAX_DYNAMIC_DESCRIPTOR_OFFSET_COUNT = 16u;
    /** Sentinel meaning "no index assigned" (all bits set). */
    static constexpr uint32_t INVALID_INDEX = ~0u;
    /** Upper bound on push constant data, in bytes. */
    static constexpr uint32_t MAX_PUSH_CONSTANT_BYTE_SIZE = 128u;
};
441  
442  /** Descriptor set layout binding */
443  struct DescriptorSetLayoutBinding {
444      /** Binding */
445      uint32_t binding { PipelineLayoutConstants::INVALID_INDEX };
446      /** Descriptor type */
447      DescriptorType descriptorType { DescriptorType::MAX_ENUM };
448      /** Descriptor count */
449      uint32_t descriptorCount { 0 };
450      /** Stage flags */
451      ShaderStageFlags shaderStageFlags;
452  };
453  
454  /** Descriptor set layout */
455  struct DescriptorSetLayout {
456      /** Set */
457      uint32_t set { PipelineLayoutConstants::INVALID_INDEX };
458      /** Bindings */
459      std::vector<DescriptorSetLayoutBinding> bindings;
460  };
461  
462  /** Push constant */
463  struct PushConstant {
464      /** Shader stage flags */
465      ShaderStageFlags shaderStageFlags;
466      /** Byte size */
467      uint32_t byteSize { 0 };
468  };
469  
470  /** Pipeline layout */
471  struct PipelineLayout {
472      /** Push constant */
473      PushConstant pushConstant;
474      /** Descriptor set count */
475      uint32_t descriptorSetCount { 0 };
476      /** Descriptor sets */
477      DescriptorSetLayout descriptorSetLayouts[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] {};
478  };
479  
/** Reserved specialization-constant id index (256). Its exact role is defined by code outside this view — TODO confirm. */
constexpr uint32_t RESERVED_CONSTANT_ID_INDEX = 256u;
481  
482  /** Vertex input declaration */
483  struct VertexInputDeclaration {
484      /** Vertex input binding description */
485      struct VertexInputBindingDescription {
486          /** Binding */
487          uint32_t binding { ~0u };
488          /** Stride */
489          uint32_t stride { 0u };
490          /** Vertex input rate */
491          VertexInputRate vertexInputRate { VertexInputRate::MAX_ENUM };
492      };
493  
494      /** Vertex input attribute description */
495      struct VertexInputAttributeDescription {
496          /** Location */
497          uint32_t location { ~0u };
498          /** Binding */
499          uint32_t binding { ~0u };
500          /** Format */
501          Format format { Format::UNDEFINED };
502          /** Offset */
503          uint32_t offset { 0u };
504      };
505  };
506  
507  struct VertexAttributeInfo {
508      uint32_t byteSize { 0 };
509      VertexInputDeclaration::VertexInputAttributeDescription description;
510  };
511  
/**
 * Three unsigned 32-bit components (e.g. the value returned by ShaderReflectionData::GetLocalSize).
 *
 * Members default to zero so a default-constructed UVec3 never holds indeterminate values; it is still an
 * aggregate, so UVec3 { x, y, z } keeps working.
 */
struct UVec3 {
    uint32_t x { 0u };
    uint32_t y { 0u };
    uint32_t z { 0u };
};
517  
518  struct ShaderReflectionData {
519      array_view<const uint8_t> reflectionData;
520  
521      bool IsValid() const;
522      ShaderStageFlags GetStageFlags() const;
523      PipelineLayout GetPipelineLayout() const;
524      std::vector<ShaderSpecializationConstant> GetSpecializationConstants() const;
525      std::vector<VertexInputDeclaration::VertexInputAttributeDescription> GetInputDescriptions() const;
526      UVec3 GetLocalSize() const;
527  };
528  
/** Inputs for creating a shader module: stage, binary code and the accompanying reflection data. */
struct ShaderModuleCreateInfo {
    /** Stage(s) of the shader. */
    ShaderStageFlags shaderStageFlags;
    /** Shader binary bytes (presumably SPIR-V, given the name — confirm). */
    array_view<const uint8_t> spvData;
    /** Reflection blob describing the shader's interface. */
    ShaderReflectionData reflectionData;
};
534  
/**
 * Options shared by shader preprocessing and compilation.
 *
 * The two path members are references: the referenced paths must outlive this object, and the struct cannot be
 * copy-assigned.
 */
struct CompilationSettings {
    /** Target Vulkan environment; mapped to a glslang client version by preProcessShader. */
    ShaderEnv shaderEnv;
    /** Directories searched, in order, for #include files (after the includer's directory for local includes). */
    std::vector<std::filesystem::path> shaderIncludePaths;
    /** Optional SPIR-V optimizer; not used in the code visible here — presumably applied after compilation. */
    std::optional<spvtools::Optimizer> optimizer;
    /** Base directory that includer names are resolved against for local (relative) includes. */
    std::filesystem::path& shaderSourcePath;
    /** Destination directory for compiled shaders (not used in the code visible here). */
    std::filesystem::path& compiledShaderDestinationPath;
};
542  
/** Magic bytes ('r', 'f', 'l', NUL) at the start of every reflection blob; checked by IsValid(). */
constexpr uint8_t REFLECTION_TAG[] = { 'r', 'f', 'l', 0 };
/**
 * Fixed header at the start of a reflection blob. The accessors reinterpret the blob bytes directly as this
 * struct, so member order and sizes define the serialized layout, and these fields are read in host byte order.
 * The section payloads themselves are decoded explicitly as little-endian.
 *
 * Each offset* field is a byte offset from the start of the blob to its section; 0 means the section is absent.
 */
struct ReflectionHeader {
    /** Must equal REFLECTION_TAG. */
    uint8_t tag[sizeof(REFLECTION_TAG)];
    /** Shader stage bits (converted to ShaderStageFlagBits by the accessors). */
    uint16_t type;
    /** Offset of the push constant section. */
    uint16_t offsetPushConstants;
    /** Offset of the specialization constant section. */
    uint16_t offsetSpecializationConstants;
    /** Offset of the descriptor set section. */
    uint16_t offsetDescriptorSets;
    /** Offset of the vertex input section. */
    uint16_t offsetInputs;
    /** Offset of the local size section (three uint32 values). */
    uint16_t offsetLocalSize;
};
553  
/**
 * Ties a pair of actions to the lifetime of a local object: the initializer runs immediately on construction and
 * the deinitializer runs when the object is destroyed.
 *
 * The class is non-copyable, because a copy would run the deinitializer a second time. Empty std::function
 * arguments are skipped instead of throwing std::bad_function_call, which matters most in the implicitly-noexcept
 * destructor, where a throw would terminate the program.
 */
class scope {
private:
    std::function<void()> init;
    std::function<void()> deinit;

public:
    /**
     * @param initializer   Invoked once, from the constructor (skipped if empty).
     * @param deinitalizer  Invoked once, from the destructor (skipped if empty).
     */
    scope(std::function<void()> initializer, std::function<void()> deinitalizer)
        : init(std::move(initializer)), deinit(std::move(deinitalizer))
    {
        if (init) {
            init();
        }
    }

    scope(const scope&) = delete;
    scope& operator=(const scope&) = delete;

    ~scope()
    {
        if (deinit) {
            deinit();
        }
    }
};
571  
IsValid() const572  bool ShaderReflectionData::IsValid() const
573  {
574      if (reflectionData.size() < sizeof(ReflectionHeader)) {
575          return false;
576      }
577      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
578      return memcmp(header.tag, REFLECTION_TAG, sizeof(REFLECTION_TAG)) == 0;
579  }
580  
GetStageFlags() const581  ShaderStageFlags ShaderReflectionData::GetStageFlags() const
582  {
583      ShaderStageFlags flags;
584      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
585      flags = static_cast<ShaderStageFlagBits>(header.type);
586      return flags;
587  }
588  
GetPipelineLayout() const589  PipelineLayout ShaderReflectionData::GetPipelineLayout() const
590  {
591      PipelineLayout pipelineLayout;
592      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
593      if (header.offsetPushConstants && header.offsetPushConstants < reflectionData.size()) {
594          auto ptr = reflectionData.data() + header.offsetPushConstants;
595          const auto constants = *ptr;
596          if (constants) {
597              pipelineLayout.pushConstant.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
598              pipelineLayout.pushConstant.byteSize = static_cast<uint32_t>(*(ptr + 1) | (*(ptr + 2) << 8));
599          }
600      }
601      if (header.offsetDescriptorSets && header.offsetDescriptorSets < reflectionData.size()) {
602          auto ptr = reflectionData.data() + header.offsetDescriptorSets;
603          pipelineLayout.descriptorSetCount = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
604          ptr += 2;
605          for (auto i = 0u; i < pipelineLayout.descriptorSetCount; ++i) {
606              // write to correct set location
607              const uint32_t set = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
608              assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
609              auto& layout = pipelineLayout.descriptorSetLayouts[set];
610              layout.set = set;
611              ptr += 2;
612              const auto bindings = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
613              ptr += 2;
614              for (auto j = 0u; j < bindings; ++j) {
615                  DescriptorSetLayoutBinding binding;
616                  binding.binding = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
617                  ptr += 2;
618                  binding.descriptorType = static_cast<DescriptorType>(*ptr | (*(ptr + 1) << 8));
619                  ptr += 2;
620                  binding.descriptorCount = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
621                  ptr += 2;
622                  binding.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
623                  layout.bindings.push_back(binding);
624              }
625          }
626      }
627      return pipelineLayout;
628  }
629  
GetSpecializationConstants() const630  std::vector<ShaderSpecializationConstant> ShaderReflectionData::GetSpecializationConstants() const
631  {
632      std::vector<ShaderSpecializationConstant> constants;
633      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
634      if (header.offsetSpecializationConstants && header.offsetSpecializationConstants < reflectionData.size()) {
635          auto ptr = reflectionData.data() + header.offsetSpecializationConstants;
636          const auto size = *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24;
637          ptr += 4;
638          for (auto i = 0; i < size; ++i) {
639              ShaderSpecializationConstant constant;
640              constant.shaderStage = static_cast<ShaderStageFlagBits>(header.type);
641              constant.id = static_cast<uint32_t>(*ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
642              ptr += 4;
643              constant.type = static_cast<ShaderSpecializationConstant::Type>(
644                  *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
645              ptr += 4;
646              constant.offset = 0;
647              constants.push_back(constant);
648          }
649      }
650      return constants;
651  }
652  
GetInputDescriptions() const653  std::vector<VertexInputDeclaration::VertexInputAttributeDescription> ShaderReflectionData::GetInputDescriptions() const
654  {
655      std::vector<VertexInputDeclaration::VertexInputAttributeDescription> inputs;
656      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
657      if (header.offsetInputs && header.offsetInputs < reflectionData.size()) {
658          auto ptr = reflectionData.data() + header.offsetInputs;
659          const auto size = *(ptr) | (*(ptr + 1) << 8);
660          ptr += 2;
661          for (auto i = 0; i < size; ++i) {
662              VertexInputDeclaration::VertexInputAttributeDescription desc;
663              desc.location = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
664              ptr += 2;
665              desc.binding = desc.location;
666              desc.format = static_cast<Format>(*(ptr) | (*(ptr + 1) << 8));
667              ptr += 2;
668              desc.offset = 0;
669              inputs.push_back(desc);
670          }
671      }
672      return inputs;
673  }
674  
GetLocalSize() const675  UVec3 ShaderReflectionData::GetLocalSize() const
676  {
677      UVec3 sizes;
678      const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
679      if (header.offsetLocalSize && header.offsetLocalSize < reflectionData.size()) {
680          auto ptr = reflectionData.data() + header.offsetLocalSize;
681          sizes.x = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
682          ptr += 4;
683          sizes.y = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
684          ptr += 4;
685          sizes.z = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
686          ptr += 4;
687      }
688      return sizes;
689  }
690  
readFileToString(std::string_view aFilename)691  std::string readFileToString(std::string_view aFilename)
692  {
693      std::stringstream ss;
694      std::ifstream file;
695  
696      file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
697      try {
698          file.open(aFilename.data(), std::ios::in);
699  
700          if (!file.fail()) {
701              ss << file.rdbuf();
702              return ss.str();
703          }
704      } catch (std::exception const& ex) {
705          LUME_LOG_E("Error reading file: '%s': %s", aFilename.data(), ex.what());
706          return {};
707      }
708      return {};
709  }
710  
711  class FileIncluder : public glslang::TShader::Includer {
712  public:
713      const CompilationSettings& settings;
FileIncluder(const CompilationSettings & compilationSettings)714      FileIncluder(const CompilationSettings& compilationSettings) : settings(compilationSettings) {}
715  
716  private:
include(const char * headerName,const char * includerName,size_t inclusionDepth,bool relative)717      virtual IncludeResult* include(
718          const char* headerName, const char* includerName, size_t inclusionDepth, bool relative)
719      {
720          std::filesystem::path path;
721          bool found = false;
722          if (relative == true) {
723              path.append(settings.shaderSourcePath.c_str());
724              path.append(includerName);
725              path = path.parent_path();
726              path.append(headerName);
727              found = std::filesystem::exists(path);
728          }
729  
730          for (int i = 0; i < settings.shaderIncludePaths.size() && found == false; ++i) {
731              path.assign(settings.shaderIncludePaths[i]);
732              path.append(headerName);
733              found = std::filesystem::exists(path);
734          }
735  
736          if (found == true) {
737              auto str = path.string();
738  
739              std::ifstream file(path);
740              file.seekg(0, file.end);
741              std::streampos length = file.tellg();
742              file.seekg(0, file.beg);
743  
744              char* memory = new (std::nothrow) char[length + std::streampos(1)];
745              if (memory == 0) {
746                  return nullptr;
747              }
748  
749              char* last = std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), memory);
750              IncludeResult* result = new (std::nothrow) IncludeResult(str, memory, std::distance(memory, last), 0);
751              if (result == 0) {
752                  delete memory;
753                  return nullptr;
754              }
755  
756              return result;
757          }
758  
759          return nullptr;
760      }
761  
includeSystem(const char * headerName,const char * includerName,size_t inclusionDepth)762      virtual IncludeResult* includeSystem(const char* headerName, const char* includerName, size_t inclusionDepth)
763      {
764          return include(headerName, includerName, inclusionDepth, false);
765      }
766  
includeLocal(const char * headerName,const char * includerName,size_t inclusionDepth)767      virtual IncludeResult* includeLocal(const char* headerName, const char* includerName, size_t inclusionDepth)
768      {
769          return include(headerName, includerName, inclusionDepth, true);
770      }
771  
releaseInclude(IncludeResult * include)772      virtual void releaseInclude(IncludeResult* include)
773      {
774          delete include;
775      }
776  };
777  
ToSpirVVersion(glslang::EShTargetClientVersion env_version)778  glslang::EShTargetLanguageVersion ToSpirVVersion(glslang::EShTargetClientVersion env_version)
779  {
780      if (env_version == glslang::EShTargetVulkan_1_0) {
781          return glslang::EShTargetSpv_1_0;
782      } else if (env_version == glslang::EShTargetVulkan_1_1) {
783          return glslang::EShTargetSpv_1_3;
784      } else if (env_version == glslang::EShTargetVulkan_1_2) {
785          return glslang::EShTargetSpv_1_5;
786  #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
787      } else if (env_version == glslang::EShTargetVulkan_1_3) {
788          return glslang::EShTargetSpv_1_6;
789  #endif
790      } else {
791          return glslang::EShTargetSpv_1_0;
792      }
793  }
794  
preProcessShader(std::string_view aSource,ShaderKind aKind,std::string_view aSourceName,const CompilationSettings & settings)795  std::string preProcessShader(
796      std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
797  {
798      glslang::EShTargetLanguageVersion languageVersion;
799      glslang::EShTargetClientVersion version;
800      EShLanguage stage;
801      switch (aKind) {
802          case ShaderKind::VERTEX:
803              stage = EShLanguage::EShLangVertex;
804              break;
805          case ShaderKind::FRAGMENT:
806              stage = EShLanguage::EShLangFragment;
807              break;
808          case ShaderKind::COMPUTE:
809              stage = EShLanguage::EShLangCompute;
810              break;
811          default:
812              LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderKind not recognized");
813              return {};
814      }
815  
816      switch (settings.shaderEnv) {
817          case ShaderEnv::version_vulkan_1_0:
818              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
819              break;
820          case ShaderEnv::version_vulkan_1_1:
821              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
822              break;
823          case ShaderEnv::version_vulkan_1_2:
824              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
825              break;
826  #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
827          case ShaderEnv::version_vulkan_1_3:
828              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
829              break;
830  #endif
831          default:
832              LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderEnv not recognized");
833              return {};
834      }
835  
836      languageVersion = ToSpirVVersion(version);
837  
838      FileIncluder includer(settings);
839      glslang::TShader shader(stage);
840      const char* shader_strings = aSource.data();
841      const int shader_lengths = static_cast<int>(aSource.size());
842      const char* string_names = aSourceName.data();
843      std::string_view preamble = "#extension GL_GOOGLE_include_directive : enable\n";
844      shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
845      shader.setPreamble(preamble.data());
846      shader.setEntryPoint("main");
847      shader.setAutoMapBindings(false);
848      shader.setAutoMapLocations(false);
849      shader.setShiftImageBinding(0);
850      shader.setShiftSamplerBinding(0);
851      shader.setShiftTextureBinding(0);
852      shader.setShiftUboBinding(0);
853      shader.setShiftSsboBinding(0);
854      shader.setShiftUavBinding(0);
855      shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
856      shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
857      shader.setInvertY(false);
858      shader.setNanMinMaxClamp(false);
859  
860      std::string output;
861      const EShMessages rules =
862          static_cast<EShMessages>(EShMsgOnlyPreprocessor | EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
863      if (!shader.preprocess(
864              &kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules, &output, includer)) {
865          LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
866          LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
867  
868          output = { output.begin() + preamble.size(), output.end() };
869          return {};
870      }
871  
872      output = { output.begin() + preamble.size(), output.end() };
873      return output;
874  }
875  
compileShaderToSpirvBinary(std::string_view aSource,ShaderKind aKind,std::string_view aSourceName,const CompilationSettings & settings)876  std::vector<uint32_t> compileShaderToSpirvBinary(
877      std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
878  {
879      glslang::EShTargetLanguageVersion languageVersion;
880      glslang::EShTargetClientVersion version;
881      EShLanguage stage;
882      switch (aKind) {
883          case ShaderKind::VERTEX:
884              stage = EShLanguage::EShLangVertex;
885              break;
886          case ShaderKind::FRAGMENT:
887              stage = EShLanguage::EShLangFragment;
888              break;
889          case ShaderKind::COMPUTE:
890              stage = EShLanguage::EShLangCompute;
891              break;
892          default:
893              LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderKind not recognized");
894              return {};
895      }
896  
897      switch (settings.shaderEnv) {
898          case ShaderEnv::version_vulkan_1_0:
899              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
900              break;
901          case ShaderEnv::version_vulkan_1_1:
902              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
903              break;
904          case ShaderEnv::version_vulkan_1_2:
905              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
906              break;
907  #if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
908          case ShaderEnv::version_vulkan_1_3:
909              version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
910              break;
911  #endif
912          default:
913              LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderEnv not recognized");
914              return {};
915      }
916  
917      languageVersion = ToSpirVVersion(version);
918  
919      glslang::TShader shader(stage);
920      const char* shader_strings = aSource.data();
921      const int shader_lengths = static_cast<int>(aSource.size());
922      const char* string_names = aSourceName.data();
923      shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
924      shader.setPreamble("#extension GL_GOOGLE_include_directive : enable\n");
925      shader.setEntryPoint("main");
926      shader.setAutoMapBindings(false);
927      shader.setAutoMapLocations(false);
928      shader.setShiftImageBinding(0);
929      shader.setShiftSamplerBinding(0);
930      shader.setShiftTextureBinding(0);
931      shader.setShiftUboBinding(0);
932      shader.setShiftSsboBinding(0);
933      shader.setShiftUavBinding(0);
934      shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
935      shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
936      shader.setInvertY(false);
937      shader.setNanMinMaxClamp(false);
938  
939      const EShMessages rules = static_cast<EShMessages>(EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
940      if (!shader.parse(&kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules)) {
941          LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
942          LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
943          return {};
944      }
945  
946      glslang::TProgram program;
947      program.addShader(&shader);
948      if (!program.link(EShMsgDefault) || !program.mapIO()) {
949          LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoLog());
950          LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoDebugLog());
951          return {};
952      }
953  
954      std::vector<unsigned int> spirv;
955      glslang::SpvOptions spv_options;
956      spv_options.generateDebugInfo = false;
957      spv_options.disableOptimizer = true;
958      spv_options.optimizeSize = false;
959      spv::SpvBuildLogger logger;
960      glslang::TIntermediate* intermediate = program.getIntermediate(stage);
961      glslang::GlslangToSpv(*intermediate, spirv, &logger, &spv_options);
962  
963      const uint32_t shadercGeneratorWord = 13; // From SPIR-V XML Registry
964      const uint32_t generatorWordIndex = 2;    // SPIR-V 2.3: Physical layout
965      assert(spirv.size() > generatorWordIndex);
966      spirv[generatorWordIndex] = (spirv[generatorWordIndex] & 0xffff) | (shadercGeneratorWord << 16u);
967      return spirv;
968  }
969  
processResource(const spirv_cross::Compiler & compiler,const spirv_cross::Resource & resource,ShaderStageFlags shaderStateFlags,DescriptorType type,DescriptorSetLayout * layouts)970  void processResource(const spirv_cross::Compiler& compiler, const spirv_cross::Resource& resource,
971      ShaderStageFlags shaderStateFlags, DescriptorType type, DescriptorSetLayout* layouts)
972  {
973      const uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
974  
975      assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
976      if (set >= PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT) {
977          return;
978      }
979      DescriptorSetLayout& layout = layouts[set];
980      layout.set = set;
981  
982      // Collect bindings.
983      const uint32_t bindingIndex = compiler.get_decoration(resource.id, spv::DecorationBinding);
984      auto& bindings = layout.bindings;
985      if (auto pos = std::find_if(bindings.begin(), bindings.end(),
986              [bindingIndex](const DescriptorSetLayoutBinding& binding) { return binding.binding == bindingIndex; });
987          pos == bindings.end()) {
988          const spirv_cross::SPIRType& spirType = compiler.get_type(resource.type_id);
989  
990          DescriptorSetLayoutBinding binding;
991          binding.binding = bindingIndex;
992          binding.descriptorType = type;
993          binding.descriptorCount = spirType.array.empty() ? 1 : spirType.array[0];
994          binding.shaderStageFlags = shaderStateFlags;
995  
996          bindings.emplace_back(binding);
997      } else {
998          pos->shaderStageFlags |= shaderStateFlags;
999      }
1000  }
1001  
reflectDescriptorSets(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,DescriptorSetLayout * layouts)1002  void reflectDescriptorSets(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1003      ShaderStageFlags shaderStateFlags, DescriptorSetLayout* layouts)
1004  {
1005      for (const auto& ref : resources.sampled_images) {
1006          processResource(compiler, ref, shaderStateFlags, DescriptorType::COMBINED_IMAGE_SAMPLER, layouts);
1007      }
1008  
1009      for (const auto& ref : resources.separate_samplers) {
1010          processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLER, layouts);
1011      }
1012  
1013      for (const auto& ref : resources.separate_images) {
1014          processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLED_IMAGE, layouts);
1015      }
1016  
1017      for (const auto& ref : resources.storage_images) {
1018          processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_IMAGE, layouts);
1019      }
1020  
1021      for (const auto& ref : resources.uniform_buffers) {
1022          processResource(compiler, ref, shaderStateFlags, DescriptorType::UNIFORM_BUFFER, layouts);
1023      }
1024  
1025      for (const auto& ref : resources.storage_buffers) {
1026          processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_BUFFER, layouts);
1027      }
1028  
1029      for (const auto& ref : resources.subpass_inputs) {
1030          processResource(compiler, ref, shaderStateFlags, DescriptorType::INPUT_ATTACHMENT, layouts);
1031      }
1032  
1033      for (const auto& ref : resources.acceleration_structures) {
1034          processResource(compiler, ref, shaderStateFlags, DescriptorType::ACCELERATION_STRUCTURE, layouts);
1035      }
1036  
1037      std::sort(layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT,
1038          [](const DescriptorSetLayout& lhs, const DescriptorSetLayout& rhs) { return (lhs.set < rhs.set); });
1039  
1040      std::for_each(
1041          layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT, [](DescriptorSetLayout& layout) {
1042              std::sort(layout.bindings.begin(), layout.bindings.end(),
1043                  [](const DescriptorSetLayoutBinding& lhs, const DescriptorSetLayoutBinding& rhs) {
1044                      return (lhs.binding < rhs.binding);
1045                  });
1046          });
1047  }
1048  
reflectPushContants(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags shaderStateFlags,PushConstant & pushConstant)1049  void reflectPushContants(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1050      ShaderStageFlags shaderStateFlags, PushConstant& pushConstant)
1051  {
1052      // NOTE: support for only one push constant
1053      if (resources.push_constant_buffers.size() > 0) {
1054          pushConstant.shaderStageFlags |= shaderStateFlags;
1055  
1056          const auto ranges = compiler.get_active_buffer_ranges(resources.push_constant_buffers[0].id);
1057          const uint32_t byteSize = std::accumulate(
1058              ranges.begin(), ranges.end(), 0u, [](uint32_t byteSize, const spirv_cross::BufferRange& range) {
1059                  return byteSize + static_cast<uint32_t>(range.range);
1060              });
1061          pushConstant.byteSize = std::max(pushConstant.byteSize, byteSize);
1062      }
1063  }
1064  
reflectSpecializationConstants(const spirv_cross::Compiler & compiler,ShaderStageFlags shaderStateFlags)1065  std::vector<ShaderSpecializationConstant> reflectSpecializationConstants(
1066      const spirv_cross::Compiler& compiler, ShaderStageFlags shaderStateFlags)
1067  {
1068      std::vector<ShaderSpecializationConstant> specializationConstants;
1069      uint32_t offset = 0;
1070      for (auto const& constant : compiler.get_specialization_constants()) {
1071          if (constant.constant_id < RESERVED_CONSTANT_ID_INDEX) {
1072              const spirv_cross::SPIRConstant& spirvConstant = compiler.get_constant(constant.id);
1073              const auto type = compiler.get_type(spirvConstant.constant_type);
1074              ShaderSpecializationConstant::Type constantType = ShaderSpecializationConstant::Type::INVALID;
1075              if (type.basetype == spirv_cross::SPIRType::Boolean) {
1076                  constantType = ShaderSpecializationConstant::Type::BOOL;
1077              } else if (type.basetype == spirv_cross::SPIRType::UInt) {
1078                  constantType = ShaderSpecializationConstant::Type::UINT32;
1079              } else if (type.basetype == spirv_cross::SPIRType::Int) {
1080                  constantType = ShaderSpecializationConstant::Type::INT32;
1081              } else if (type.basetype == spirv_cross::SPIRType::Float) {
1082                  constantType = ShaderSpecializationConstant::Type::FLOAT;
1083              } else {
1084                  assert(false && "Unhandled specialization constant type");
1085              }
1086              const uint32_t size = spirvConstant.vector_size() * spirvConstant.columns() * sizeof(uint32_t);
1087              specializationConstants.push_back(
1088                  ShaderSpecializationConstant { shaderStateFlags, constant.constant_id, constantType, offset });
1089              offset += size;
1090          }
1091      }
1092      // sorted based on offset due to offset mapping with shader combinations
1093      // NOTE: id and name indexing
1094      std::sort(specializationConstants.begin(), specializationConstants.end(),
1095          [](const auto& lhs, const auto& rhs) { return (lhs.offset < rhs.offset); });
1096  
1097      return specializationConstants;
1098  }
1099  
convertToVertexInputFormat(const spirv_cross::SPIRType & type)1100  Format convertToVertexInputFormat(const spirv_cross::SPIRType& type)
1101  {
1102      using BaseType = spirv_cross::SPIRType::BaseType;
1103  
1104      // ivecn: a vector of signed integers
1105      if (type.basetype == BaseType::Int) {
1106          switch (type.vecsize) {
1107              case 1:
1108                  return Format::R32_SINT;
1109              case 2:
1110                  return Format::R32G32_SINT;
1111              case 3:
1112                  return Format::R32G32B32_SINT;
1113              case 4:
1114                  return Format::R32G32B32A32_SINT;
1115          }
1116      }
1117  
1118      // uvecn: a vector of unsigned integers
1119      if (type.basetype == BaseType::UInt) {
1120          switch (type.vecsize) {
1121              case 1:
1122                  return Format::R32_UINT;
1123              case 2:
1124                  return Format::R32G32_UINT;
1125              case 3:
1126                  return Format::R32G32B32_UINT;
1127              case 4:
1128                  return Format::R32G32B32A32_UINT;
1129          }
1130      }
1131  
1132      // halfn: a vector of half-precision floating-point numbers
1133      if (type.basetype == BaseType::Half) {
1134          switch (type.vecsize) {
1135              case 1:
1136                  return Format::R16_SFLOAT;
1137              case 2:
1138                  return Format::R16G16_SFLOAT;
1139              case 3:
1140                  return Format::R16G16B16_SFLOAT;
1141              case 4:
1142                  return Format::R16G16B16A16_SFLOAT;
1143          }
1144      }
1145  
1146      // vecn: a vector of single-precision floating-point numbers
1147      if (type.basetype == BaseType::Float) {
1148          switch (type.vecsize) {
1149              case 1:
1150                  return Format::R32_SFLOAT;
1151              case 2:
1152                  return Format::R32G32_SFLOAT;
1153              case 3:
1154                  return Format::R32G32B32_SFLOAT;
1155              case 4:
1156                  return Format::R32G32B32A32_SFLOAT;
1157          }
1158      }
1159  
1160      return Format::UNDEFINED;
1161  }
1162  
reflectVertexInputs(const spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,ShaderStageFlags,std::vector<VertexInputDeclaration::VertexInputAttributeDescription> & vertexInputAttributes)1163  void reflectVertexInputs(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1164      ShaderStageFlags /* shaderStateFlags */,
1165      std::vector<VertexInputDeclaration::VertexInputAttributeDescription>& vertexInputAttributes)
1166  {
1167      std::vector<VertexAttributeInfo> vertexAttributeInfos;
1168  
1169      // Vertex input attributes.
1170      for (auto& attr : resources.stage_inputs) {
1171          const spirv_cross::SPIRType attributeType = compiler.get_type(attr.type_id);
1172  
1173          VertexAttributeInfo info;
1174  
1175          // For now, assume that every vertex attribute comes from it's own binding which equals the location.
1176          info.description.location = compiler.get_decoration(attr.id, spv::DecorationLocation);
1177          info.description.binding = info.description.location;
1178          info.description.format = convertToVertexInputFormat(attributeType);
1179          info.description.offset = 0;
1180  
1181          info.byteSize = attributeType.vecsize * (attributeType.width / 8);
1182  
1183          vertexAttributeInfos.emplace_back(std::move(info));
1184      }
1185  
1186      // Sort input attributes by binding and location.
1187      std::sort(std::begin(vertexAttributeInfos), std::end(vertexAttributeInfos),
1188          [](const VertexAttributeInfo& aA, const VertexAttributeInfo& aB) {
1189              if (aA.description.binding < aB.description.binding) {
1190                  return true;
1191              }
1192  
1193              return aA.description.location < aB.description.location;
1194          });
1195  
1196      // Create final attributes.
1197      if (!vertexAttributeInfos.empty()) {
1198          for (auto& info : vertexAttributeInfos) {
1199              vertexInputAttributes.push_back(info.description);
1200          }
1201      }
1202  }
1203  
/**
 * Appends `data` to `buffer` in little-endian byte order.
 *
 * Writes sizeof(T) bytes, capped at four: wider types only contribute their low 32 bits.
 */
template<typename T>
void push(std::vector<uint8_t>& buffer, T data)
{
    constexpr unsigned byteCount = (sizeof(T) < 4u) ? static_cast<unsigned>(sizeof(T)) : 4u;
    for (unsigned byteIndex = 0u; byteIndex < byteCount; ++byteIndex) {
        buffer.push_back((data >> (8u * byteIndex)) & 0xff);
    }
}
1218  
reflectSpvBinary(const std::vector<uint32_t> & aBinary,ShaderKind aKind)1219  std::vector<uint8_t> reflectSpvBinary(const std::vector<uint32_t>& aBinary, ShaderKind aKind)
1220  {
1221      const spirv_cross::Compiler compiler(aBinary);
1222  
1223      const auto shaderStateFlags = ShaderStageFlags(aKind);
1224  
1225      const spirv_cross::ShaderResources resources = compiler.get_shader_resources();
1226  
1227      PipelineLayout pipelineLayout;
1228      reflectDescriptorSets(compiler, resources, shaderStateFlags, pipelineLayout.descriptorSetLayouts);
1229      pipelineLayout.descriptorSetCount =
1230          static_cast<uint32_t>(std::count_if(std::begin(pipelineLayout.descriptorSetLayouts),
1231              std::end(pipelineLayout.descriptorSetLayouts), [](const DescriptorSetLayout& layout) {
1232                  return layout.set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT;
1233              }));
1234      reflectPushContants(compiler, resources, shaderStateFlags, pipelineLayout.pushConstant);
1235  
1236      // some additional information mainly for GL
1237      std::vector<Gles::PushConstantReflection> pushConstantReflection;
1238      for (auto& remap : resources.push_constant_buffers) {
1239          const auto& blockType = compiler.get_type(remap.base_type_id);
1240          auto name = compiler.get_name(remap.id);
1241          (void)(blockType);
1242          assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
1243          Gles::ProcessStruct(std::string_view(name.data(), name.size()), 0, compiler, remap.base_type_id,
1244              pushConstantReflection, shaderStateFlags);
1245      }
1246  
1247      auto specializationConstants = reflectSpecializationConstants(compiler, shaderStateFlags);
1248  
1249      // NOTE: this is done for all although the name is 'Vertex'InputAttributes
1250      std::vector<VertexInputDeclaration::VertexInputAttributeDescription> vertexInputAttributes;
1251      reflectVertexInputs(compiler, resources, shaderStateFlags, vertexInputAttributes);
1252  
1253      std::vector<uint8_t> reflection;
1254      reflection.reserve(512u);
1255      static constexpr uint8_t tag[] = { 'r', 'f', 'l', 0 }; // last one is version
1256      uint16_t type = 0;
1257      uint16_t offsetPushConstants = 0;
1258      uint16_t offsetSpecializationConstants = 0;
1259      uint16_t offsetDescriptorSets = 0;
1260      uint16_t offsetInputs = 0;
1261      uint16_t offsetLocalSize = 0;
1262      // tag
1263      {
1264          reflection.insert(reflection.end(), std::begin(tag), std::end(tag));
1265      }
1266      // shader type
1267      {
1268          push(reflection, static_cast<uint16_t>(shaderStateFlags.flags));
1269      }
1270      // offsets
1271      {
1272          reflection.resize(reflection.size() + sizeof(uint16_t) * 5);
1273      }
1274      // push constants
1275      {
1276          offsetPushConstants = static_cast<uint16_t>(reflection.size());
1277          if (pipelineLayout.pushConstant.byteSize) {
1278              reflection.push_back(1);
1279              push(reflection, static_cast<uint16_t>(pipelineLayout.pushConstant.byteSize));
1280  
1281              push(reflection, static_cast<uint8_t>(pushConstantReflection.size()));
1282              for (const auto& refl : pushConstantReflection) {
1283                  push(reflection, refl.type);
1284                  push(reflection, static_cast<uint16_t>(refl.offset));
1285                  push(reflection, static_cast<uint16_t>(refl.size));
1286                  push(reflection, static_cast<uint16_t>(refl.arraySize));
1287                  push(reflection, static_cast<uint16_t>(refl.arrayStride));
1288                  push(reflection, static_cast<uint16_t>(refl.matrixStride));
1289                  push(reflection, static_cast<uint16_t>(refl.name.size()));
1290                  reflection.insert(reflection.end(), std::begin(refl.name), std::end(refl.name));
1291              }
1292          } else {
1293              reflection.push_back(0);
1294          }
1295      }
1296      // specialization constants
1297      {
1298          offsetSpecializationConstants = static_cast<uint16_t>(reflection.size());
1299          {
1300              const auto size = static_cast<uint32_t>(specializationConstants.size());
1301              push(reflection, static_cast<uint32_t>(specializationConstants.size()));
1302          }
1303          for (auto const& constant : specializationConstants) {
1304              push(reflection, static_cast<uint32_t>(constant.id));
1305              push(reflection, static_cast<uint32_t>(constant.type));
1306          }
1307      }
1308      // descriptor sets
1309      {
1310          offsetDescriptorSets = static_cast<uint16_t>(reflection.size());
1311          {
1312              push(reflection, static_cast<uint16_t>(pipelineLayout.descriptorSetCount));
1313          }
1314          auto begin = std::begin(pipelineLayout.descriptorSetLayouts);
1315          auto end = begin;
1316          std::advance(end, pipelineLayout.descriptorSetCount);
1317          std::for_each(begin, end, [&reflection](const DescriptorSetLayout& layout) {
1318              push(reflection, static_cast<uint16_t>(layout.set));
1319              push(reflection, static_cast<uint16_t>(layout.bindings.size()));
1320              for (const auto& binding : layout.bindings) {
1321                  push(reflection, static_cast<uint16_t>(binding.binding));
1322                  push(reflection, static_cast<uint16_t>(binding.descriptorType));
1323                  push(reflection, static_cast<uint16_t>(binding.descriptorCount));
1324              }
1325          });
1326      }
1327      // inputs
1328      {
1329          offsetInputs = static_cast<uint16_t>(reflection.size());
1330          const auto size = static_cast<uint16_t>(vertexInputAttributes.size());
1331          push(reflection, size);
1332          for (const auto& input : vertexInputAttributes) {
1333              push(reflection, static_cast<uint16_t>(input.location));
1334              push(reflection, static_cast<uint16_t>(input.format));
1335          }
1336      }
1337      // local size
1338      if (shaderStateFlags & ShaderStageFlagBits::COMPUTE_BIT) {
1339          offsetLocalSize = static_cast<uint16_t>(reflection.size());
1340          uint32_t size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 0);
1341          push(reflection, size);
1342  
1343          size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 1);
1344          push(reflection, size);
1345  
1346          size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 2);
1347          push(reflection, size);
1348      }
1349      // update offsets to real values
1350      {
1351          auto ptr = reflection.data() + (sizeof(tag) + sizeof(type));
1352          *ptr++ = offsetPushConstants & 0xff;
1353          *ptr++ = (offsetPushConstants >> 8) & 0xff;
1354          *ptr++ = offsetSpecializationConstants & 0xff;
1355          *ptr++ = (offsetSpecializationConstants >> 8) & 0xff;
1356          *ptr++ = offsetDescriptorSets & 0xff;
1357          *ptr++ = (offsetDescriptorSets >> 8) & 0xff;
1358          *ptr++ = offsetInputs & 0xff;
1359          *ptr++ = (offsetInputs >> 8) & 0xff;
1360          *ptr++ = offsetLocalSize & 0xff;
1361          *ptr++ = (offsetLocalSize >> 8) & 0xff;
1362      }
1363  
1364      return reflection;
1365  }
1366  
/** A (descriptor set, binding) pair, each narrowed to one byte. */
struct Binding {
    uint8_t set;  // descriptor set index
    uint8_t bind; // binding index within the set
};
1371  
get_binding(Gles::CoreCompiler & compiler,spirv_cross::ID id)1372  Binding get_binding(Gles::CoreCompiler& compiler, spirv_cross::ID id)
1373  {
1374      const uint32_t dset = compiler.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
1375      const uint32_t dbind = compiler.get_decoration(id, spv::Decoration::DecorationBinding);
1376      assert(dset < Gles::ResourceLimits::MAX_SETS);
1377      assert(dbind < Gles::ResourceLimits::MAX_BIND_IN_SET);
1378      const uint8_t set = static_cast<uint8_t>(dset);
1379      const uint8_t bind = static_cast<uint8_t>(dbind);
1380      return { set, bind };
1381  }
1382  
SortSets(PipelineLayout & pipelineLayout)1383  void SortSets(PipelineLayout& pipelineLayout)
1384  {
1385      pipelineLayout.descriptorSetCount = 0;
1386      for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
1387          DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
1388          if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
1389              pipelineLayout.descriptorSetCount++;
1390              std::sort(currSet.bindings.begin(), currSet.bindings.end(),
1391                  [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
1392          }
1393      }
1394  }
1395  
Collect(Gles::CoreCompiler & compiler,const spirv_cross::SmallVector<spirv_cross::Resource> & resources,const uint32_t forceBinding=0)1396  void Collect(Gles::CoreCompiler& compiler, const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
1397      const uint32_t forceBinding = 0)
1398  {
1399      std::string name;
1400  
1401      for (const auto& remap : resources) {
1402          const auto binding = get_binding(compiler, remap.id);
1403  
1404          name.resize(name.capacity() - 1);
1405          const auto nameLen = sprintf(name.data(), "s%u_b%u", binding.set, binding.bind);
1406          name.resize(nameLen);
1407  
1408          // if name is empty it's a block and we need to rename the base_type_id i.e.
1409          // uniform <base_type_id> { vec4 foo; } <id>;
1410          if (auto origname = compiler.get_name(remap.id); origname.empty()) {
1411              compiler.set_name(remap.base_type_id, name);
1412              name.insert(name.begin(), '_');
1413              compiler.set_name(remap.id, name);
1414          } else {
1415              // uniform <id> vec4 foo;
1416              compiler.set_name(remap.id, name);
1417          }
1418  
1419          compiler.unset_decoration(remap.id, spv::DecorationDescriptorSet);
1420          compiler.unset_decoration(remap.id, spv::DecorationBinding);
1421          if (forceBinding > 0) {
1422              compiler.set_decoration(
1423                  remap.id, spv::DecorationBinding, forceBinding - 1); // will be over-written later. (special handling)
1424          }
1425      }
1426  }
1427  
1428  struct ShaderModulePlatformDataGLES {
1429      std::vector<Gles::PushConstantReflection> infos;
1430  };
1431  
CollectRes(Gles::CoreCompiler & compiler,const spirv_cross::ShaderResources & res,ShaderModulePlatformDataGLES & plat_)1432  void CollectRes(
1433      Gles::CoreCompiler& compiler, const spirv_cross::ShaderResources& res, ShaderModulePlatformDataGLES& plat_)
1434  {
1435      // collect names for later linkage
1436      static constexpr uint32_t defaultBinding = 11;
1437      Collect(compiler, res.storage_buffers, defaultBinding + 1);
1438      Collect(compiler, res.storage_images, defaultBinding + 1);
1439      Collect(compiler, res.uniform_buffers, 0); // 0 == remove binding decorations (let's the compiler decide)
1440      Collect(compiler, res.subpass_inputs, 0);  // 0 == remove binding decorations (let's the compiler decide)
1441  
1442      // handle the real sampled images.
1443      Collect(compiler, res.sampled_images, 0); // 0 == remove binding decorations (let's the compiler decide)
1444  
1445      // and now the "generated ones" (separate image/sampler)
1446      std::string imageName;
1447      std::string samplerName;
1448      std::string temp;
1449      for (auto& remap : compiler.get_combined_image_samplers()) {
1450          const auto imageBinding = get_binding(compiler, remap.image_id);
1451          {
1452              imageName.resize(imageName.capacity() - 1);
1453              const auto nameLen = sprintf(imageName.data(), "s%u_b%u", imageBinding.set, imageBinding.bind);
1454              imageName.resize(nameLen);
1455          }
1456          const auto samplerBinding = get_binding(compiler, remap.sampler_id);
1457          {
1458              samplerName.resize(samplerName.capacity() - 1);
1459              const auto nameLen = sprintf(samplerName.data(), "s%u_b%u", samplerBinding.set, samplerBinding.bind);
1460              samplerName.resize(nameLen);
1461          }
1462  
1463          temp.reserve(imageName.size() + samplerName.size() + 1);
1464          temp.clear();
1465          temp.append(imageName);
1466          temp.append(1, '_');
1467          temp.append(samplerName);
1468          compiler.set_name(remap.combined_id, temp);
1469      }
1470  }
1471  
/**
 * Graphics API backend a shader is being prepared for.
 */
enum class DeviceBackendType {
    VULKAN,   ///< Vulkan backend
    OPENGLES, ///< GLES backend
    OPENGL,   ///< OpenGL backend
};
1481  
SetupSpirvCross(ShaderStageFlags stage,Gles::CoreCompiler * compiler,DeviceBackendType backend,bool ovrEnabled)1482  void SetupSpirvCross(ShaderStageFlags stage, Gles::CoreCompiler* compiler, DeviceBackendType backend, bool ovrEnabled)
1483  {
1484      spirv_cross::CompilerGLSL::Options options;
1485  
1486      if (backend == DeviceBackendType::OPENGLES) {
1487          options.version = 320;
1488          options.es = true;
1489          options.fragment.default_float_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1490          options.fragment.default_int_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1491      }
1492  
1493      if (backend == DeviceBackendType::OPENGL) {
1494          options.version = 450;
1495          options.es = false;
1496      }
1497  
1498  #if defined(CORE_USE_SEPARATE_SHADER_OBJECTS) && (CORE_USE_SEPARATE_SHADER_OBJECTS == 1)
1499      if (stage & (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT)) {
1500          options.separate_shader_objects = true;
1501      }
1502  #endif
1503  
1504      options.ovr_multiview_view_count = ovrEnabled ? 1 : 0;
1505  
1506      compiler->set_common_options(options);
1507  }
1508  
1509  struct Shader {
1510      ShaderStageFlags shaderStageFlags_;
1511      DeviceBackendType backend_;
1512      ShaderModulePlatformDataGLES plat_;
1513      bool ovrEnabled;
1514  
1515      std::string source_;
1516  };
1517  
ProcessShaderModule(Shader & me,const ShaderModuleCreateInfo & createInfo)1518  void ProcessShaderModule(Shader& me, const ShaderModuleCreateInfo& createInfo)
1519  {
1520      // perform reflection.
1521      auto compiler = Gles::CoreCompiler(reinterpret_cast<const uint32_t*>(createInfo.spvData.data()),
1522          static_cast<uint32_t>(createInfo.spvData.size() / sizeof(uint32_t)));
1523      // Set some options.
1524      SetupSpirvCross(me.shaderStageFlags_, &compiler, me.backend_, me.ovrEnabled);
1525  
1526      // first step in converting CORE_FLIP_NDC to regular uniform. (specializationconstant -> constant) this makes the
1527      // compiled glsl more readable, and simpler to post process later.
1528      Gles::ConvertSpecConstToConstant(compiler, "CORE_FLIP_NDC");
1529      // const auto& res = compiler.get_shader_resources();
1530  
1531      auto active = compiler.get_active_interface_variables();
1532      const auto& res = compiler.get_shader_resources(active);
1533      compiler.set_enabled_interface_variables(std::move(active));
1534  
1535      Gles::ReflectPushConstants(compiler, res, me.plat_.infos, me.shaderStageFlags_);
1536      compiler.build_combined_image_samplers();
1537      CollectRes(compiler, res, me.plat_);
1538  
1539      // set "CORE_BACKEND_TYPE" specialization to 1.
1540      Gles::SetSpecMacro(compiler, "CORE_BACKEND_TYPE", 1U);
1541  
1542      me.source_ = compiler.compile();
1543      Gles::ConvertConstantToUniform(compiler, me.source_, "CORE_FLIP_NDC");
1544  }
1545  
1546  template<typename T>
writeToFile(const array_view<T> & data,std::filesystem::path aDestinationFile)1547  bool writeToFile(const array_view<T>& data, std::filesystem::path aDestinationFile)
1548  {
1549      std::ofstream outputStream(aDestinationFile, std::ios::out | std::ios::binary);
1550      if (outputStream.is_open()) {
1551          outputStream.write(reinterpret_cast<const char*>(data.data()), data.size() * sizeof(T));
1552          outputStream.close();
1553          return true;
1554      } else {
1555          LUME_LOG_E("Could not write file: '%s'", aDestinationFile.string().c_str());
1556          return false;
1557      }
1558  }
1559  
runAllCompilationStages(std::string_view inputFilename,CompilationSettings & settings)1560  bool runAllCompilationStages(std::string_view inputFilename, CompilationSettings& settings)
1561  {
1562      try {
1563          // std::string inputFilename = aFile;
1564          const std::filesystem::path relativeInputFilename =
1565              std::filesystem::relative(inputFilename, settings.shaderSourcePath);
1566          const std::string relativeFilename = relativeInputFilename.string();
1567          const std::string extension = std::filesystem::path(inputFilename).extension().string();
1568          std::filesystem::path outputFilename = settings.compiledShaderDestinationPath / relativeInputFilename;
1569  
1570          // Make sure the output dir hierarchy exists.
1571          std::filesystem::create_directories(outputFilename.parent_path());
1572  
1573          // Just copying json files to the destination dir.
1574          if (extension == ".json") {
1575              if (!std::filesystem::exists(outputFilename) ||
1576                  !std::filesystem::equivalent(inputFilename, outputFilename)) {
1577                  LUME_LOG_I("  %s", relativeFilename.c_str());
1578                  std::filesystem::copy(inputFilename, outputFilename, std::filesystem::copy_options::overwrite_existing);
1579              }
1580              return true;
1581          } else {
1582              LUME_LOG_I("  %s", relativeFilename.c_str());
1583              outputFilename += ".spv";
1584  
1585              LUME_LOG_V("    input: '%s'", inputFilename.data());
1586              LUME_LOG_V("      dst: '%s'", settings.compiledShaderDestinationPath.string().c_str());
1587              LUME_LOG_V(" relative: '%s'", relativeFilename.c_str());
1588              LUME_LOG_V("   output: '%s'", outputFilename.string().c_str());
1589  
1590              if (std::string shaderSource = readFileToString(inputFilename); !shaderSource.empty()) {
1591                  ShaderKind shaderKind;
1592                  if (extension == ".vert") {
1593                      shaderKind = ShaderKind::VERTEX;
1594                  } else if (extension == ".frag") {
1595                      shaderKind = ShaderKind::FRAGMENT;
1596                  } else if (extension == ".comp") {
1597                      shaderKind = ShaderKind::COMPUTE;
1598                  } else {
1599                      return false;
1600                  }
1601  
1602                  if (std::string preProcessedShader =
1603                          preProcessShader(shaderSource, shaderKind, relativeFilename, settings);
1604                      !preProcessedShader.empty()) {
1605                      if (true) {
1606                          auto reflectionFile = outputFilename;
1607                          reflectionFile += ".pre";
1608                          if (!writeToFile(
1609                                  array_view(preProcessedShader.data(), preProcessedShader.size()), reflectionFile)) {
1610                              LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1611                          }
1612                      }
1613  
1614                      if (std::vector<uint32_t> spvBinary =
1615                              compileShaderToSpirvBinary(preProcessedShader, shaderKind, relativeFilename, settings);
1616                          !spvBinary.empty()) {
1617                          const auto reflection = reflectSpvBinary(spvBinary, shaderKind);
1618                          if (reflection.empty()) {
1619                              LUME_LOG_E("Failed to reflect %s", inputFilename.data());
1620                          } else {
1621                              auto reflectionFile = outputFilename;
1622                              reflectionFile += ".lsb";
1623                              if (!writeToFile(array_view(reflection.data(), reflection.size()), reflectionFile)) {
1624                                  LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1625                              }
1626                          }
1627  
1628                          if (settings.optimizer) {
1629                              // spirv-opt resets the passes everytime so then need to be setup
1630                              settings.optimizer->RegisterPerformancePasses();
1631                              if (!settings.optimizer->Run(spvBinary.data(), spvBinary.size(), &spvBinary)) {
1632                                  LUME_LOG_E("Failed to optimize %s", inputFilename.data());
1633                              }
1634                          }
1635  
1636                          if (writeToFile(array_view(spvBinary.data(), spvBinary.size()), outputFilename)) {
1637                              LUME_LOG_D("  -> %s", outputFilename.string().c_str());
1638                              auto glFile = outputFilename;
1639                              glFile += ".gl";
1640                              try {
1641                                  bool multiviewEnabled = false;
1642                                  if (shaderKind == ShaderKind::VERTEX) {
1643                                      static constexpr const std::string_view multiview = "GL_EXT_multiview";
1644                                      for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1645                                           pos = shaderSource.find(multiview, pos + multiview.size())) {
1646                                          if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1647                                              (shaderSource.find("enabled", pos + multiview.size()) !=
1648                                                  std::string::npos)) {
1649                                              multiviewEnabled = true;
1650                                              break;
1651                                          }
1652                                      }
1653                                  }
1654                                  Shader shader;
1655                                  shader.shaderStageFlags_ =
1656                                      shaderKind == ShaderKind::VERTEX
1657                                          ? ShaderStageFlagBits::VERTEX_BIT
1658                                          : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1659                                                                                : ShaderStageFlagBits::COMPUTE_BIT);
1660  
1661                                  shader.backend_ = DeviceBackendType::OPENGL;
1662                                  shader.ovrEnabled = multiviewEnabled;
1663                                  ShaderModuleCreateInfo info;
1664                                  info.shaderStageFlags =
1665                                      shaderKind == ShaderKind::VERTEX
1666                                          ? ShaderStageFlagBits::VERTEX_BIT
1667                                          : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1668                                                                                : ShaderStageFlagBits::COMPUTE_BIT);
1669                                  info.spvData =
1670                                      array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1671                                          spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1672                                  info.reflectionData.reflectionData =
1673                                      array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1674                                          reflection.size() * sizeof(decltype(reflection)::value_type));
1675                                  ProcessShaderModule(shader, info);
1676                                  writeToFile(array_view(static_cast<const uint8_t*>(
1677                                                             static_cast<const void*>(shader.source_.data())),
1678                                                  shader.source_.size()),
1679                                      glFile);
1680                              } catch (std::exception const& e) {
1681                                  LUME_LOG_E("Failed to generate GL shader for '%s': %s", inputFilename.data(), e.what());
1682                              }
1683  
1684                              auto glesFile = glFile;
1685                              glesFile += "es";
1686                              try {
1687                                  bool multiviewEnabled = false;
1688                                  if (shaderKind == ShaderKind::VERTEX) {
1689                                      static constexpr const std::string_view multiview = "GL_EXT_multiview";
1690                                      for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1691                                           pos = shaderSource.find(multiview, pos + multiview.size())) {
1692                                          if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1693                                              (shaderSource.find("enabled", pos + multiview.size()) !=
1694                                                  std::string::npos)) {
1695                                              multiviewEnabled = true;
1696                                              break;
1697                                          }
1698                                      }
1699                                  }
1700                                  Shader shader;
1701                                  shader.shaderStageFlags_ =
1702                                      shaderKind == ShaderKind::VERTEX
1703                                          ? ShaderStageFlagBits::VERTEX_BIT
1704                                          : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1705                                                                                : ShaderStageFlagBits::COMPUTE_BIT);
1706  
1707                                  shader.backend_ = DeviceBackendType::OPENGLES;
1708                                  shader.ovrEnabled = multiviewEnabled;
1709                                  ShaderModuleCreateInfo info;
1710                                  info.shaderStageFlags =
1711                                      shaderKind == ShaderKind::VERTEX
1712                                          ? ShaderStageFlagBits::VERTEX_BIT
1713                                          : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1714                                                                                : ShaderStageFlagBits::COMPUTE_BIT);
1715                                  info.spvData =
1716                                      array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1717                                          spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1718                                  info.reflectionData.reflectionData =
1719                                      array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1720                                          reflection.size() * sizeof(decltype(reflection)::value_type));
1721                                  ProcessShaderModule(shader, info);
1722                                  writeToFile(array_view(static_cast<const uint8_t*>(
1723                                                             static_cast<const void*>(shader.source_.data())),
1724                                                  shader.source_.size()),
1725                                      glesFile);
1726                              } catch (std::exception const& e) {
1727                                  LUME_LOG_E(
1728                                      "Failed to generate GLES shader for '%s': %s", inputFilename.data(), e.what());
1729                              }
1730  
1731                              return true;
1732                          }
1733                      }
1734                  }
1735              }
1736          }
1737      } catch (std::exception const& e) {
1738          LUME_LOG_E("Processing file failed '%s': %s", inputFilename.data(), e.what());
1739      }
1740      return false;
1741  }
1742  
show_usage()1743  void show_usage()
1744  {
1745      std::cout << "LumeShaderCompiler - Supported shader types: vertex (.vert), fragment (.frag), compute (.comp)\n\n"
1746                   "How to use: \n"
1747                   "LumeShaderCompiler.exe --source <source path> (sets destination path to same as source)\n"
1748                   "LumeShaderCompiler.exe --source <source path> --destination <destination path>\n"
1749                   "LumeShaderCompiler.exe --monitor (monitors changes in the source files)\n";
1750  }
1751  
filterByExtension(const std::vector<std::string> & aFilenames,const std::vector<std::string_view> & aIncludeExtensions)1752  std::vector<std::string> filterByExtension(
1753      const std::vector<std::string>& aFilenames, const std::vector<std::string_view>& aIncludeExtensions)
1754  {
1755      std::vector<std::string> filtered;
1756      for (auto const& file : aFilenames) {
1757          std::string lowercaseFileExt = std::filesystem::path(file).extension().string();
1758          std::transform(lowercaseFileExt.begin(), lowercaseFileExt.end(), lowercaseFileExt.begin(), tolower);
1759  
1760          for (auto const& extension : aIncludeExtensions) {
1761              if (lowercaseFileExt == extension) {
1762                  filtered.push_back(file);
1763                  break;
1764              }
1765          }
1766      }
1767  
1768      return filtered;
1769  }
1770  
main(int argc,char * argv[])1771  int main(int argc, char* argv[])
1772  {
1773      if (argc == 1) {
1774          show_usage();
1775          return 0;
1776      }
1777  
1778      std::filesystem::path const currentFolder = std::filesystem::current_path();
1779      std::filesystem::path shaderSourcesPath = currentFolder;
1780      std::filesystem::path compiledShaderDestinationPath;
1781      std::vector<std::filesystem::path> shaderIncludePaths;
1782      std::filesystem::path sourceFile;
1783  
1784      bool monitorChanges = false;
1785      bool optimizeSpirv = false;
1786      ShaderEnv envVersion = ShaderEnv::version_vulkan_1_0;
1787      for (int i = 1; i < argc; ++i) {
1788          const auto arg = std::string_view(argv[i]);
1789          if (arg == "--help") {
1790              show_usage();
1791              return 0;
1792          } else if (arg == "--sourceFile") {
1793              if (i + 1 < argc) {
1794                  sourceFile = argv[++i];
1795                  sourceFile.make_preferred();
1796                  shaderSourcesPath = sourceFile;
1797                  shaderSourcesPath.remove_filename();
1798                  if (compiledShaderDestinationPath.empty()) {
1799                      compiledShaderDestinationPath = shaderSourcesPath;
1800                  }
1801              } else {
1802                  LUME_LOG_E("--sourceFile option requires one argument.\n");
1803                  return 1;
1804              }
1805          } else if (arg == "--source") {
1806              if (i + 1 < argc) {
1807                  shaderSourcesPath = argv[++i];
1808                  shaderSourcesPath.make_preferred();
1809                  if (compiledShaderDestinationPath.empty()) {
1810                      compiledShaderDestinationPath = shaderSourcesPath;
1811                  }
1812              } else {
1813                  LUME_LOG_E("--source option requires one argument.");
1814                  return 1;
1815              }
1816          } else if (arg == "--destination") {
1817              if (i + 1 < argc) {
1818                  compiledShaderDestinationPath = argv[++i];
1819                  compiledShaderDestinationPath.make_preferred();
1820              } else {
1821                  LUME_LOG_E("--destination option requires one argument.");
1822                  return 1;
1823              }
1824          } else if (arg == "--include") {
1825              if (i + 1 < argc) {
1826                  shaderIncludePaths.emplace_back(argv[++i]).make_preferred();
1827              } else {
1828                  LUME_LOG_E("--include option requires one argument.");
1829                  return 1;
1830              }
1831  
1832          } else if (arg == "--monitor") {
1833              monitorChanges = true;
1834          } else if (arg == "--optimize") {
1835              optimizeSpirv = true;
1836          } else if (arg == "--vulkan") {
1837              if (i + 1 < argc) {
1838                  const auto version = std::string_view(argv[++i]);
1839                  if (version == "1.0") {
1840                      envVersion = ShaderEnv::version_vulkan_1_0;
1841                  } else if (version == "1.1") {
1842                      envVersion = ShaderEnv::version_vulkan_1_1;
1843                  } else if (version == "1.2") {
1844                      envVersion = ShaderEnv::version_vulkan_1_2;
1845  #ifdef GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
1846                  } else if (version == "1.3") {
1847                      envVersion = ShaderEnv::version_vulkan_1_3;
1848  #endif
1849                  } else {
1850                      LUME_LOG_E("Invalid argument for option --vulkan.");
1851                      return 1;
1852                  }
1853              } else {
1854                  LUME_LOG_E("--vulkan option requires one argument.");
1855                  return 1;
1856              }
1857          }
1858      }
1859  
1860      if (compiledShaderDestinationPath.empty()) {
1861          compiledShaderDestinationPath = currentFolder;
1862      }
1863  
1864      ige::FileMonitor fileMonitor;
1865  
1866      if (!std::filesystem::exists(shaderSourcesPath)) {
1867          LUME_LOG_E("Source path does not exist: '%s'", shaderSourcesPath.string().c_str());
1868          return 1;
1869      }
1870  
1871      // Make sure the destination dir exists.
1872      std::filesystem::create_directories(compiledShaderDestinationPath);
1873  
1874      if (!std::filesystem::exists(compiledShaderDestinationPath)) {
1875          LUME_LOG_E("Destination path does not exist: '%s'", compiledShaderDestinationPath.string().c_str());
1876          return 1;
1877      }
1878  
1879      fileMonitor.AddPath(shaderSourcesPath.string());
1880      std::vector<std::string> fileList = [&]() {
1881          std::vector<std::string> list;
1882          if (!sourceFile.empty()) {
1883              list.push_back(sourceFile.u8string());
1884          } else {
1885              list = fileMonitor.getMonitoredFiles();
1886          }
1887          return list;
1888      }();
1889  
1890      const std::vector<std::string_view> supportedFileTypes = { ".vert", ".frag", ".comp", ".json" };
1891      fileList = filterByExtension(fileList, supportedFileTypes);
1892  
1893      LUME_LOG_I("     Source path: '%s'", std::filesystem::absolute(shaderSourcesPath).string().c_str());
1894      for (auto const& path : shaderIncludePaths) {
1895          LUME_LOG_I("    Include path: '%s'", std::filesystem::absolute(path).string().c_str());
1896      }
1897      LUME_LOG_I("Destination path: '%s'", std::filesystem::absolute(compiledShaderDestinationPath).string().c_str());
1898      LUME_LOG_I("");
1899      LUME_LOG_I("Processing:");
1900  
1901      int errorCount = 0;
1902      scope scope([]() { glslang::InitializeProcess(); }, []() { glslang::FinalizeProcess(); });
1903  
1904      std::vector<std::filesystem::path> searchPath;
1905      searchPath.reserve(searchPath.size() + 1 + shaderIncludePaths.size());
1906      searchPath.emplace_back(shaderSourcesPath.string());
1907      for (auto const& path : shaderIncludePaths) {
1908          searchPath.emplace_back(path.string());
1909      }
1910  
1911      auto settings =
1912          CompilationSettings { envVersion, searchPath, {}, shaderSourcesPath, compiledShaderDestinationPath };
1913  
1914      if (optimizeSpirv) {
1915          spv_target_env targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1916          switch (envVersion) {
1917              case ShaderEnv::version_vulkan_1_0:
1918                  targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1919                  break;
1920              case ShaderEnv::version_vulkan_1_1:
1921                  targetEnv = spv_target_env::SPV_ENV_VULKAN_1_1;
1922                  break;
1923              case ShaderEnv::version_vulkan_1_2:
1924                  targetEnv = spv_target_env::SPV_ENV_VULKAN_1_2;
1925                  break;
1926              case ShaderEnv::version_vulkan_1_3:
1927                  targetEnv = spv_target_env::SPV_ENV_VULKAN_1_3;
1928                  break;
1929              default:
1930                  break;
1931          }
1932          settings.optimizer.emplace(targetEnv);
1933      }
1934  
1935      // Startup compilation.
1936      for (auto const& file : fileList) {
1937          std::string relativeFilename = std::filesystem::relative(file, shaderSourcesPath).string();
1938          LUME_LOG_D("Tracked source file: '%s'", relativeFilename.c_str());
1939          if (!runAllCompilationStages(file, settings)) {
1940              errorCount++;
1941          }
1942      }
1943  
1944      if (errorCount == 0) {
1945          LUME_LOG_I("Success.");
1946      } else {
1947          LUME_LOG_E("Failed: %d", errorCount);
1948      }
1949  
1950      if (monitorChanges) {
1951          LUME_LOG_I("Monitoring file changes.");
1952      }
1953  
1954      // Main loop.
1955      while (monitorChanges) {
1956          std::vector<std::string> addedFiles, removedFiles, modifiedFiles;
1957          fileMonitor.scanModifications(addedFiles, removedFiles, modifiedFiles);
1958          modifiedFiles = filterByExtension(modifiedFiles, supportedFileTypes);
1959  
1960          if (sourceFile.empty()) {
1961              addedFiles = filterByExtension(addedFiles, supportedFileTypes);
1962              removedFiles = filterByExtension(removedFiles, supportedFileTypes);
1963  
1964              if (!addedFiles.empty()) {
1965                  LUME_LOG_I("Files added:");
1966                  for (auto const& addedFile : addedFiles) {
1967                      runAllCompilationStages(addedFile, settings);
1968                  }
1969              }
1970  
1971              if (!removedFiles.empty()) {
1972                  LUME_LOG_I("Files removed:");
1973                  for (auto const& removedFile : removedFiles) {
1974                      std::string relativeFilename = std::filesystem::relative(removedFile, shaderSourcesPath).string();
1975                      LUME_LOG_I("  %s", relativeFilename.c_str());
1976                  }
1977              }
1978  
1979              if (!modifiedFiles.empty()) {
1980                  LUME_LOG_I("Files modified:");
1981                  for (auto const& modifiedFile : modifiedFiles) {
1982                      runAllCompilationStages(modifiedFile, settings);
1983                  }
1984              }
1985          } else if (!modifiedFiles.empty()) {
1986              if (auto pos = std::find_if(modifiedFiles.cbegin(), modifiedFiles.cend(),
1987                      [&sourceFile](const std::string& modified) { return modified == sourceFile; });
1988                  pos != modifiedFiles.cend()) {
1989                  runAllCompilationStages(*pos, settings);
1990              }
1991          }
1992  
1993          std::this_thread::sleep_for(std::chrono::seconds(1));
1994      }
1995  
1996      return errorCount;
1997  }