1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "pipeline_state_object_vk.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <vulkan/vulkan_core.h>
22 
23 #include <base/util/formats.h>
24 #include <render/device/pipeline_layout_desc.h>
25 #include <render/device/pipeline_state_desc.h>
26 #include <render/namespace.h>
27 
28 #include "device/gpu_program.h"
29 #include "device/gpu_program_util.h"
30 #include "device/gpu_resource_handle_util.h"
31 #include "util/log.h"
32 #include "vulkan/create_functions_vk.h"
33 #include "vulkan/device_vk.h"
34 #include "vulkan/gpu_program_vk.h"
35 #include "vulkan/pipeline_create_functions_vk.h"
36 #include "vulkan/validate_vk.h"
37 
38 using namespace BASE_NS;
39 
40 RENDER_BEGIN_NAMESPACE()
41 namespace {
42 constexpr uint32_t MAX_DYNAMIC_STATE_COUNT { 10u };
43 
44 constexpr VkDescriptorSetLayoutCreateInfo EMPTY_LAYOUT_INFO {
45     VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
46     nullptr,                                             // pNext
47     0U,                                                  // flags
48     0U,                                                  // bindingCount
49     nullptr,                                             // pBindings
50 };
51 
GetVertexInputs(const VertexInputDeclarationView & vertexInputDeclaration,vector<VkVertexInputBindingDescription> & vertexInputBindingDescriptions,vector<VkVertexInputAttributeDescription> & vertexInputAttributeDescriptions)52 void GetVertexInputs(const VertexInputDeclarationView& vertexInputDeclaration,
53     vector<VkVertexInputBindingDescription>& vertexInputBindingDescriptions,
54     vector<VkVertexInputAttributeDescription>& vertexInputAttributeDescriptions)
55 {
56     vertexInputBindingDescriptions.resize(vertexInputDeclaration.bindingDescriptions.size());
57     vertexInputAttributeDescriptions.resize(vertexInputDeclaration.attributeDescriptions.size());
58 
59     for (size_t idx = 0; idx < vertexInputBindingDescriptions.size(); ++idx) {
60         const auto& bindingRef = vertexInputDeclaration.bindingDescriptions[idx];
61 
62         const VkVertexInputRate vertexInputRate = (VkVertexInputRate)bindingRef.vertexInputRate;
63         vertexInputBindingDescriptions[idx] = {
64             bindingRef.binding, // binding
65             bindingRef.stride,  // stride
66             vertexInputRate,    // inputRate
67         };
68     }
69 
70     for (size_t idx = 0; idx < vertexInputAttributeDescriptions.size(); ++idx) {
71         const auto& attributeRef = vertexInputDeclaration.attributeDescriptions[idx];
72         const VkFormat vertexInputFormat = (VkFormat)attributeRef.format;
73         vertexInputAttributeDescriptions[idx] = {
74             attributeRef.location, // location
75             attributeRef.binding,  // binding
76             vertexInputFormat,     // format
77             attributeRef.offset,   // offset
78         };
79     }
80 }
81 
82 struct DescriptorSetFillData {
83     uint32_t descriptorSetCount { 0 };
84     uint32_t pushConstantRangeCount { 0u };
85     VkPushConstantRange pushConstantRanges[PipelineLayoutConstants::MAX_PUSH_CONSTANT_RANGE_COUNT];
86     VkDescriptorSetLayout descriptorSetLayouts[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] { VK_NULL_HANDLE,
87         VK_NULL_HANDLE, VK_NULL_HANDLE, VK_NULL_HANDLE };
88     // the layout can be coming for special descriptor sets (with e.g. platform formats and immutable samplers)
89     bool descriptorSetLayoutOwnership[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] { true, true, true, true };
90 };
91 
GetDescriptorSetFillData(const PipelineLayout & pipelineLayout,const LowLevelPipelineLayoutData & pipelineLayoutData,const VkDevice device,const VkShaderStageFlags neededShaderStageFlags,DescriptorSetFillData & ds)92 void GetDescriptorSetFillData(const PipelineLayout& pipelineLayout,
93     const LowLevelPipelineLayoutData& pipelineLayoutData, const VkDevice device,
94     const VkShaderStageFlags neededShaderStageFlags, DescriptorSetFillData& ds)
95 {
96     // NOTE: support for only one push constant
97     ds.pushConstantRangeCount = (pipelineLayout.pushConstant.byteSize > 0) ? 1u : 0u;
98     const LowLevelPipelineLayoutDataVk& pipelineLayoutDataVk =
99         static_cast<const LowLevelPipelineLayoutDataVk&>(pipelineLayoutData);
100     // uses the same temp array for all bindings in all sets
101     VkDescriptorSetLayoutBinding descriptorSetLayoutBindings[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_BINDING_COUNT];
102     for (uint32_t operationIdx = 0; operationIdx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++operationIdx) {
103         const auto& descRef = pipelineLayout.descriptorSetLayouts[operationIdx];
104         if ((ds.descriptorSetCount >= pipelineLayout.descriptorSetCount) &&
105             (descRef.set == PipelineLayoutConstants::INVALID_INDEX)) {
106             continue;
107         }
108         ds.descriptorSetCount++;
109         const uint32_t setIdx = operationIdx;
110         const auto& descSetLayoutData = pipelineLayoutDataVk.descriptorSetLayouts[setIdx];
111         // NOTE: we are currently only doing handling of special (immutable sampler needing) layouts
112         // with the descriptor set layout coming from the real descriptor set
113         if (descSetLayoutData.flags & LowLevelDescriptorSetVk::DESCRIPTOR_SET_LAYOUT_IMMUTABLE_SAMPLER_BIT) {
114             PLUGIN_ASSERT(descSetLayoutData.descriptorSetLayout);
115             ds.descriptorSetLayouts[setIdx] = descSetLayoutData.descriptorSetLayout;
116             ds.descriptorSetLayoutOwnership[setIdx] = false; // not owned, cannot be destroyed
117         } else {
118             constexpr VkDescriptorSetLayoutCreateFlags descriptorSetLayoutCreateFlags { 0 };
119             if (descRef.set == PipelineLayoutConstants::INVALID_INDEX) {
120                 // provide empty layout for empty set
121                 VALIDATE_VK_RESULT(vkCreateDescriptorSetLayout(device, // device
122                     &EMPTY_LAYOUT_INFO,                                // pCreateInfo
123                     nullptr,                                           // pAllocator
124                     &ds.descriptorSetLayouts[setIdx]));                // pSetLayout
125             } else {
126                 const uint32_t bindingCount = static_cast<uint32_t>(descRef.bindings.size());
127                 PLUGIN_ASSERT(bindingCount <= PipelineLayoutConstants::MAX_DESCRIPTOR_SET_BINDING_COUNT);
128                 for (uint32_t bindingOpIdx = 0; bindingOpIdx < bindingCount; ++bindingOpIdx) {
129                     const auto& bindingRef = descRef.bindings[bindingOpIdx];
130                     const VkShaderStageFlags shaderStageFlags = (VkShaderStageFlags)bindingRef.shaderStageFlags;
131                     const uint32_t bindingIdx = bindingRef.binding;
132                     const VkDescriptorType descriptorType = (VkDescriptorType)bindingRef.descriptorType;
133                     const uint32_t descriptorCount = bindingRef.descriptorCount;
134 
135                     PLUGIN_ASSERT((shaderStageFlags & neededShaderStageFlags) > 0);
136                     descriptorSetLayoutBindings[bindingOpIdx] = {
137                         bindingIdx,       // binding
138                         descriptorType,   // descriptorType
139                         descriptorCount,  // descriptorCount
140                         shaderStageFlags, // stageFlags
141                         nullptr,          // pImmutableSamplers
142                     };
143                 }
144 
145                 const VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo {
146                     VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
147                     nullptr,                                             // pNext
148                     descriptorSetLayoutCreateFlags,                      // flags
149                     bindingCount,                                        // bindingCount
150                     descriptorSetLayoutBindings,                         // pBindings
151                 };
152 
153                 VALIDATE_VK_RESULT(vkCreateDescriptorSetLayout(device, // device
154                     &descriptorSetLayoutCreateInfo,                    // pCreateInfo
155                     nullptr,                                           // pAllocator
156                     &ds.descriptorSetLayouts[setIdx]));                // pSetLayout
157             }
158         }
159     }
160 
161     if (ds.pushConstantRangeCount == 1) {
162         const VkShaderStageFlags shaderStageFlags = (VkShaderStageFlags)pipelineLayout.pushConstant.shaderStageFlags;
163         PLUGIN_ASSERT((shaderStageFlags & neededShaderStageFlags) > 0);
164         const uint32_t bytesize = pipelineLayout.pushConstant.byteSize;
165         ds.pushConstantRanges[0] = {
166             shaderStageFlags, // stageFlags
167             0,                // offset
168             bytesize,         // size
169         };
170     }
171 }
172 } // namespace
173 
GraphicsPipelineStateObjectVk(Device & device,const GpuShaderProgram & gpuShaderProgram,const GraphicsState & graphicsState,const PipelineLayout & pipelineLayout,const VertexInputDeclarationView & vertexInputDeclaration,const ShaderSpecializationConstantDataView & specializationConstants,const array_view<const DynamicStateEnum> dynamicStates,const RenderPassDesc & renderPassDesc,const array_view<const RenderPassSubpassDesc> & renderPassSubpassDescs,const uint32_t subpassIndex,const LowLevelRenderPassData & renderPassData,const LowLevelPipelineLayoutData & pipelineLayoutData)174 GraphicsPipelineStateObjectVk::GraphicsPipelineStateObjectVk(Device& device, const GpuShaderProgram& gpuShaderProgram,
175     const GraphicsState& graphicsState, const PipelineLayout& pipelineLayout,
176     const VertexInputDeclarationView& vertexInputDeclaration,
177     const ShaderSpecializationConstantDataView& specializationConstants,
178     const array_view<const DynamicStateEnum> dynamicStates, const RenderPassDesc& renderPassDesc,
179     const array_view<const RenderPassSubpassDesc>& renderPassSubpassDescs, const uint32_t subpassIndex,
180     const LowLevelRenderPassData& renderPassData, const LowLevelPipelineLayoutData& pipelineLayoutData)
181     : GraphicsPipelineStateObject(), device_(device)
182 {
183     PLUGIN_ASSERT(!renderPassSubpassDescs.empty());
184 
185     const LowLevelRenderPassDataVk& lowLevelRenderPassDataVk = (const LowLevelRenderPassDataVk&)renderPassData;
186 
187     const DeviceVk& deviceVk = (const DeviceVk&)device_;
188     const DevicePlatformDataVk& devicePlatVk = (const DevicePlatformDataVk&)deviceVk.GetPlatformData();
189     const VkDevice vkDevice = devicePlatVk.device;
190 
191     const GpuShaderProgramVk& program = static_cast<const GpuShaderProgramVk&>(gpuShaderProgram);
192     const GpuShaderProgramPlatformDataVk& platData = program.GetPlatformData();
193 
194     vector<VkVertexInputBindingDescription> vertexInputBindingDescriptions;
195     vector<VkVertexInputAttributeDescription> vertexInputAttributeDescriptions;
196     GetVertexInputs(vertexInputDeclaration, vertexInputBindingDescriptions, vertexInputAttributeDescriptions);
197 
198     constexpr VkPipelineVertexInputStateCreateFlags pipelineVertexInputStateCreateFlags { 0 };
199     const VkPipelineVertexInputStateCreateInfo pipelineVertexInputStateCreateInfo {
200         VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, // sType
201         nullptr,                                                   // pNext
202         pipelineVertexInputStateCreateFlags,                       // flags
203         static_cast<uint32_t>(vertexInputBindingDescriptions.size()),           // vertexBindingDescriptionCount
204         vertexInputBindingDescriptions.data(),                     // pVertexBindingDescriptions
205         static_cast<uint32_t>(vertexInputAttributeDescriptions.size()),         // vertexAttributeDescriptionCount
206         vertexInputAttributeDescriptions.data(),                   // pVertexAttributeDescriptions
207     };
208 
209     const GraphicsState::InputAssembly& inputAssembly = graphicsState.inputAssembly;
210 
211     constexpr VkPipelineInputAssemblyStateCreateFlags pipelineInputAssemblyStateCreateFlags { 0 };
212     const VkPipelineInputAssemblyStateCreateInfo pipelineInputAssemblyStateCreateInfo {
213         VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO, // sType
214         nullptr,                                                     // pNext
215         pipelineInputAssemblyStateCreateFlags,                       // flags
216         (VkPrimitiveTopology)inputAssembly.primitiveTopology,        // topology
217         (VkBool32)inputAssembly.enablePrimitiveRestart,              // primitiveRestartEnable
218     };
219 
220     const GraphicsState::RasterizationState& rasterizationState = graphicsState.rasterizationState;
221 
222     constexpr VkPipelineRasterizationStateCreateFlags pipelineRasterizationStateCreateFlags { 0 };
223     const VkPipelineRasterizationStateCreateInfo pipelineRasterizationStateCreateInfo {
224         VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO, // sType
225         nullptr,                                                    // Next
226         pipelineRasterizationStateCreateFlags,                      // flags
227         (VkBool32)rasterizationState.enableDepthClamp,              // depthClampEnable
228         (VkBool32)rasterizationState.enableRasterizerDiscard,       // rasterizerDiscardEnable
229         (VkPolygonMode)rasterizationState.polygonMode,              // polygonMode
230         (VkCullModeFlags)rasterizationState.cullModeFlags,          // cullMode
231         (VkFrontFace)rasterizationState.frontFace,                  // frontFace
232         (VkBool32)rasterizationState.enableDepthBias,               // depthBiasEnable
233         rasterizationState.depthBiasConstantFactor,                 // depthBiasConstantFactor
234         rasterizationState.depthBiasClamp,                          // depthBiasClamp
235         rasterizationState.depthBiasSlopeFactor,                    // depthBiasSlopeFactor
236         rasterizationState.lineWidth,                               // lineWidth
237     };
238 
239     const GraphicsState::DepthStencilState& depthStencilState = graphicsState.depthStencilState;
240 
241     const GraphicsState::StencilOpState& frontStencilOpState = depthStencilState.frontStencilOpState;
242     const VkStencilOpState frontStencilOpStateVk {
243         (VkStencilOp)frontStencilOpState.failOp,      // failOp
244         (VkStencilOp)frontStencilOpState.passOp,      // passOp
245         (VkStencilOp)frontStencilOpState.depthFailOp, // depthFailOp
246         (VkCompareOp)frontStencilOpState.compareOp,   // compareOp
247         frontStencilOpState.compareMask,              // compareMask
248         frontStencilOpState.writeMask,                // writeMask
249         frontStencilOpState.reference,                // reference
250     };
251     const GraphicsState::StencilOpState& backStencilOpState = depthStencilState.backStencilOpState;
252     const VkStencilOpState backStencilOpStateVk {
253         (VkStencilOp)backStencilOpState.failOp,      // failOp
254         (VkStencilOp)backStencilOpState.passOp,      // passOp
255         (VkStencilOp)backStencilOpState.depthFailOp, // depthFailOp
256         (VkCompareOp)backStencilOpState.compareOp,   // compareOp
257         backStencilOpState.compareMask,              // compareMask
258         backStencilOpState.writeMask,                // writeMask
259         backStencilOpState.reference,                // reference
260     };
261 
262     constexpr VkPipelineDepthStencilStateCreateFlags pipelineDepthStencilStateCreateFlags { 0 };
263     const VkPipelineDepthStencilStateCreateInfo pipelineDepthStencilStateCreateInfo {
264         VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO, // sType
265         nullptr,                                                    // pNext
266         pipelineDepthStencilStateCreateFlags,                       // flags
267         (VkBool32)depthStencilState.enableDepthTest,                // depthTestEnable
268         (VkBool32)depthStencilState.enableDepthWrite,               // depthWriteEnable
269         (VkCompareOp)depthStencilState.depthCompareOp,              // depthCompareOp
270         (VkBool32)depthStencilState.enableDepthBoundsTest,          // depthBoundsTestEnable
271         (VkBool32)depthStencilState.enableStencilTest,              // stencilTestEnable
272         frontStencilOpStateVk,                                      // front
273         backStencilOpStateVk,                                       // back
274         depthStencilState.minDepthBounds,                           // minDepthBounds
275         depthStencilState.maxDepthBounds,                           // maxDepthBounds
276     };
277 
278     const GraphicsState::ColorBlendState& colorBlendState = graphicsState.colorBlendState;
279 
280     VkPipelineColorBlendAttachmentState
281         pipelineColorBlendAttachmentStates[PipelineStateConstants::MAX_COLOR_ATTACHMENT_COUNT];
282     const uint32_t colAttachmentCount = renderPassSubpassDescs[subpassIndex].colorAttachmentCount;
283     PLUGIN_ASSERT(colAttachmentCount <= PipelineStateConstants::MAX_COLOR_ATTACHMENT_COUNT);
284     for (size_t idx = 0; idx < colAttachmentCount; ++idx) {
285         const GraphicsState::ColorBlendState::Attachment& attachmentBlendState = colorBlendState.colorAttachments[idx];
286 
287         pipelineColorBlendAttachmentStates[idx] = {
288             (VkBool32)attachmentBlendState.enableBlend,                 // blendEnable
289             (VkBlendFactor)attachmentBlendState.srcColorBlendFactor,    // srcColorBlendFactor
290             (VkBlendFactor)attachmentBlendState.dstColorBlendFactor,    // dstColorBlendFactor
291             (VkBlendOp)attachmentBlendState.colorBlendOp,               // colorBlendOp
292             (VkBlendFactor)attachmentBlendState.srcAlphaBlendFactor,    // srcAlphaBlendFactor
293             (VkBlendFactor)attachmentBlendState.dstAlphaBlendFactor,    // dstAlphaBlendFactor
294             (VkBlendOp)attachmentBlendState.alphaBlendOp,               // alphaBlendOp
295             (VkColorComponentFlags)attachmentBlendState.colorWriteMask, // colorWriteMask
296         };
297     }
298 
299     constexpr VkPipelineColorBlendStateCreateFlags pipelineColorBlendStateCreateFlags { 0 };
300     const float* bc = colorBlendState.colorBlendConstants;
301     const VkPipelineColorBlendStateCreateInfo pipelineColorBlendStateCreateInfo {
302         VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO, // sType
303         nullptr,                                                  // pNext
304         pipelineColorBlendStateCreateFlags,                       // flags
305         (VkBool32)colorBlendState.enableLogicOp,                  // logicOpEnable
306         (VkLogicOp)colorBlendState.logicOp,                       // logicOp
307         colAttachmentCount,                                       // attachmentCount
308         pipelineColorBlendAttachmentStates,                       // pAttachments
309         { bc[0u], bc[1u], bc[2u], bc[3u] },                       // blendConstants[4]
310     };
311 
312     VkPipelineDynamicStateCreateInfo pipelineDynamicStateCreateInfo {};
313     pipelineDynamicStateCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO;
314 
315     VkDynamicState vkDynamicStates[MAX_DYNAMIC_STATE_COUNT];
316     uint32_t dynamicStateCount = Math::min(MAX_DYNAMIC_STATE_COUNT, static_cast<uint32_t>(dynamicStates.size()));
317     if (dynamicStateCount > 0) {
318         for (uint32_t idx = 0; idx < dynamicStateCount; ++idx) {
319             vkDynamicStates[idx] = (VkDynamicState)dynamicStates[idx];
320         }
321 
322         constexpr VkPipelineDynamicStateCreateFlags pipelineDynamicStateCreateFlags { 0 };
323         pipelineDynamicStateCreateInfo = {
324             VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO, // sType
325             nullptr,                                              // pNext
326             pipelineDynamicStateCreateFlags,                      // flags
327             dynamicStateCount,                                    // dynamicStateCount
328             vkDynamicStates,                                      // pDynamicStates
329         };
330     }
331 
332     constexpr uint32_t maxViewportCount { 1 };
333     constexpr uint32_t maxScissorCount { 1 };
334     VkPipelineViewportStateCreateInfo viewportStateCreateInfo {
335         VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, // sType
336         nullptr,                                               // pNext
337         0,                                                     // flags
338         maxViewportCount,                                      // viewportCount
339         &lowLevelRenderPassDataVk.viewport,                    // pViewports
340         maxScissorCount,                                       // scissorCount
341         &lowLevelRenderPassDataVk.scissor,                     // pScissors
342     };
343 
344     // reserve max
345     vector<VkSpecializationMapEntry> vertexStageSpecializations;
346     vector<VkSpecializationMapEntry> fragmentStageSpecializations;
347     vertexStageSpecializations.reserve(specializationConstants.constants.size());
348     fragmentStageSpecializations.reserve(specializationConstants.constants.size());
349 
350     uint32_t vertexDataSize = 0;
351     uint32_t fragmentDataSize = 0;
352 
353     for (auto const& constant : specializationConstants.constants) {
354         const auto constantSize = GpuProgramUtil::SpecializationByteSize(constant.type);
355         const VkSpecializationMapEntry entry {
356             static_cast<uint32_t>(constant.id), // constantID
357             constant.offset,                    // offset
358             constantSize                        // entry.size
359         };
360         if (constant.shaderStage & CORE_SHADER_STAGE_VERTEX_BIT) {
361             vertexStageSpecializations.push_back(entry);
362             vertexDataSize = std::max(vertexDataSize, constant.offset + constantSize);
363         }
364         if (constant.shaderStage & CORE_SHADER_STAGE_FRAGMENT_BIT) {
365             fragmentStageSpecializations.push_back(entry);
366             fragmentDataSize = std::max(fragmentDataSize, constant.offset + constantSize);
367         }
368     }
369 
370     const VkSpecializationInfo vertexSpecializationInfo {
371         static_cast<uint32_t>(vertexStageSpecializations.size()), // mapEntryCount
372         vertexStageSpecializations.data(),                        // pMapEntries
373         vertexDataSize,                                           // dataSize
374         specializationConstants.data.data()                       // pData
375     };
376 
377     const VkSpecializationInfo fragmentSpecializationInfo {
378         static_cast<uint32_t>(fragmentStageSpecializations.size()), // mapEntryCount
379         fragmentStageSpecializations.data(),                        // pMapEntries
380         fragmentDataSize,                                           // dataSize
381         specializationConstants.data.data()                         // pData
382     };
383 
384     constexpr uint32_t stageCount { 2 };
385     const VkPipelineShaderStageCreateInfo pipelineShaderStageCreateInfos[stageCount] {
386         {
387             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sTypeoldHandle
388             nullptr,                                             // pNextoldHandle
389             0,                                                   // flags
390             VkShaderStageFlagBits::VK_SHADER_STAGE_VERTEX_BIT,   // stage
391             platData.vert,                                       // module
392             "main",                                              // pName
393             &vertexSpecializationInfo,                           // pSpecializationInfo
394         },
395         {
396             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
397             nullptr,                                             // pNext
398             0,                                                   // flags
399             VkShaderStageFlagBits::VK_SHADER_STAGE_FRAGMENT_BIT, // stage
400             platData.frag,                                       // module
401             "main",                                              // pName
402             &fragmentSpecializationInfo                          // pSpecializationInfo
403         },
404     };
405 
406     // NOTE: support for only one push constant
407     DescriptorSetFillData ds;
408     GetDescriptorSetFillData(pipelineLayout, pipelineLayoutData, vkDevice,
409         VkShaderStageFlagBits::VK_SHADER_STAGE_VERTEX_BIT | VkShaderStageFlagBits::VK_SHADER_STAGE_FRAGMENT_BIT, ds);
410 
411     const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo {
412         VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
413         nullptr,                                       // pNext
414         0,                                             // flags
415         ds.descriptorSetCount,                         // setLayoutCount
416         ds.descriptorSetLayouts,                       // pSetLayouts
417         ds.pushConstantRangeCount,                     // pushConstantRangeCount
418         ds.pushConstantRanges,                         // pPushConstantRanges
419     };
420 
421     VALIDATE_VK_RESULT(vkCreatePipelineLayout(vkDevice, // device
422         &pipelineLayoutCreateInfo,                      // pCreateInfo,
423         nullptr,                                        // pAllocator
424         &plat_.pipelineLayout));                        // pPipelineLayout
425 
426     constexpr VkPipelineMultisampleStateCreateFlags pipelineMultisampleStateCreateFlags { 0 };
427 
428     VkSampleCountFlagBits sampleCountFlagBits { VK_SAMPLE_COUNT_1_BIT };
429     if (renderPassSubpassDescs[subpassIndex].colorAttachmentCount > 0) {
430         const auto& ref = lowLevelRenderPassDataVk.renderPassCompatibilityDesc.attachments[0];
431         sampleCountFlagBits = (ref.sampleCountFlags == 0) ? VkSampleCountFlagBits::VK_SAMPLE_COUNT_1_BIT
432                                                           : (VkSampleCountFlagBits)ref.sampleCountFlags;
433     }
434 
435     VkBool32 sampleShadingEnable = VK_FALSE;
436     float minSampleShading = 0.0f;
437 
438     const bool msaaEnabled =
439         (VkBool32)((sampleCountFlagBits != VkSampleCountFlagBits::VK_SAMPLE_COUNT_1_BIT) &&
440                    (sampleCountFlagBits != VkSampleCountFlagBits::VK_SAMPLE_COUNT_FLAG_BITS_MAX_ENUM))
441             ? true
442             : false;
443     if (msaaEnabled) {
444         if (devicePlatVk.enabledPhysicalDeviceFeatures.sampleRateShading) {
445             sampleShadingEnable = VK_TRUE;
446             minSampleShading = deviceVk.GetFeatureConfigurations().minSampleShading;
447         }
448     }
449 
450     // NOTE: alpha to coverage
451     constexpr VkBool32 alphaToCoverageEnable { false };
452     constexpr VkBool32 alphaToOneEnable { false };
453 
454     const VkPipelineMultisampleStateCreateInfo pipelineMultisampleStateCreateInfo {
455         VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO, // sType
456         nullptr,                                                  // pNext
457         pipelineMultisampleStateCreateFlags,                      // flags
458         sampleCountFlagBits,                                      // rasterizationSamples
459         sampleShadingEnable,                                      // sampleShadingEnable
460         minSampleShading,                                         // minSampleShading
461         nullptr,                                                  // pSampleMask
462         alphaToCoverageEnable,                                    // alphaToCoverageEnable
463         alphaToOneEnable,                                         // alphaToOneEnable
464     };
465 
466     // needs nullptr if no dynamic states
467     VkPipelineDynamicStateCreateInfo* ptrPipelineDynamicStateCreateInfo = nullptr;
468     if (dynamicStateCount > 0) {
469         ptrPipelineDynamicStateCreateInfo = &pipelineDynamicStateCreateInfo;
470     }
471 
472     constexpr VkPipelineCreateFlags pipelineCreateFlags { 0 };
473     const VkRenderPass renderPass = lowLevelRenderPassDataVk.renderPassCompatibility;
474     const VkGraphicsPipelineCreateInfo graphicsPipelineCreateInfo {
475         VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO, // sType
476         nullptr,                                         // pNext
477         pipelineCreateFlags,                             // flags
478         stageCount,                                      // stageCount
479         pipelineShaderStageCreateInfos,                  // pStages
480         &pipelineVertexInputStateCreateInfo,             // pVertexInputState
481         &pipelineInputAssemblyStateCreateInfo,           // pInputAssemblyState
482         nullptr,                                         // pTessellationState
483         &viewportStateCreateInfo,                        // pViewportState
484         &pipelineRasterizationStateCreateInfo,           // pRasterizationState
485         &pipelineMultisampleStateCreateInfo,             // pMultisampleState
486         &pipelineDepthStencilStateCreateInfo,            // pDepthStencilState
487         &pipelineColorBlendStateCreateInfo,              // pColorBlendState
488         ptrPipelineDynamicStateCreateInfo,               // pDynamicState
489         plat_.pipelineLayout,                            // layout
490         renderPass,                                      // renderPass
491         subpassIndex,                                    // subpass
492         VK_NULL_HANDLE,                                  // basePipelineHandle
493         0,                                               // basePipelineIndex
494     };
495 
496     VALIDATE_VK_RESULT(vkCreateGraphicsPipelines(vkDevice, // device
497         devicePlatVk.pipelineCache,                        // pipelineCache
498         1,                                                 // createInfoCount
499         &graphicsPipelineCreateInfo,                       // pCreateInfos
500         nullptr,                                           // pAllocator
501         &plat_.pipeline));                                 // pPipelines
502 
503     // NOTE: direct destruction here
504     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
505         const auto& descRef = ds.descriptorSetLayouts[idx];
506         if (descRef && ds.descriptorSetLayoutOwnership[idx]) {
507             vkDestroyDescriptorSetLayout(vkDevice, // device
508                 descRef,                           // descriptorSetLayout
509                 nullptr);                          // pAllocator
510         }
511     }
512 }
513 
~GraphicsPipelineStateObjectVk()514 GraphicsPipelineStateObjectVk::~GraphicsPipelineStateObjectVk()
515 {
516     const VkDevice device = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
517     if (plat_.pipeline) {
518         vkDestroyPipeline(device, // device
519             plat_.pipeline,       // pipeline
520             nullptr);             // pAllocator
521     }
522     if (plat_.pipelineLayout) {
523         vkDestroyPipelineLayout(device, // device
524             plat_.pipelineLayout,       // pipelineLayout
525             nullptr);                   // pAllocator
526     }
527 }
528 
GetPlatformData() const529 const PipelineStateObjectPlatformDataVk& GraphicsPipelineStateObjectVk::GetPlatformData() const
530 {
531     return plat_;
532 }
533 
ComputePipelineStateObjectVk(Device & device,const GpuComputeProgram & gpuComputeProgram,const PipelineLayout & pipelineLayout,const ShaderSpecializationConstantDataView & specializationConstants,const LowLevelPipelineLayoutData & pipelineLayoutData)534 ComputePipelineStateObjectVk::ComputePipelineStateObjectVk(Device& device, const GpuComputeProgram& gpuComputeProgram,
535     const PipelineLayout& pipelineLayout, const ShaderSpecializationConstantDataView& specializationConstants,
536     const LowLevelPipelineLayoutData& pipelineLayoutData)
537     : ComputePipelineStateObject(), device_(device)
538 {
539     const DeviceVk& deviceVk = (const DeviceVk&)device_;
540     const DevicePlatformDataVk& devicePlatVk = (const DevicePlatformDataVk&)deviceVk.GetPlatformData();
541     const VkDevice vkDevice = devicePlatVk.device;
542 
543     const GpuComputeProgramVk& program = static_cast<const GpuComputeProgramVk&>(gpuComputeProgram);
544     const auto& platData = program.GetPlatformData();
545     const VkShaderModule shaderModule = platData.comp;
546 
547     vector<VkSpecializationMapEntry> computeStateSpecializations;
548     computeStateSpecializations.reserve(specializationConstants.constants.size());
549     uint32_t computeDataSize = 0;
550     for (auto const& constant : specializationConstants.constants) {
551         const auto constantSize = GpuProgramUtil::SpecializationByteSize(constant.type);
552         const VkSpecializationMapEntry entry {
553             static_cast<uint32_t>(constant.id), // constantID
554             constant.offset,                    // offset
555             constantSize                        // entry.size
556         };
557         if (constant.shaderStage & CORE_SHADER_STAGE_COMPUTE_BIT) {
558             computeStateSpecializations.push_back(entry);
559             computeDataSize = std::max(computeDataSize, constant.offset + constantSize);
560         }
561     }
562 
563     const VkSpecializationInfo computeSpecializationInfo {
564         static_cast<uint32_t>(computeStateSpecializations.size()), // mapEntryCount
565         computeStateSpecializations.data(),                        // pMapEntries
566         computeDataSize,                                           // dataSize
567         specializationConstants.data.data()                        // pData
568     };
569 
570     const VkPipelineShaderStageCreateInfo pipelineShaderStageCreateInfo {
571         VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
572         nullptr,                                             // pNext
573         0,                                                   // flags
574         VkShaderStageFlagBits::VK_SHADER_STAGE_COMPUTE_BIT,  // stage
575         shaderModule,                                        // module
576         "main",                                              // pName
577         &computeSpecializationInfo,                          // pSpecializationInfo
578     };
579 
580     // NOTE: support for only one push constant
581     DescriptorSetFillData ds;
582     GetDescriptorSetFillData(
583         pipelineLayout, pipelineLayoutData, vkDevice, VkShaderStageFlagBits::VK_SHADER_STAGE_COMPUTE_BIT, ds);
584 
585     const VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo {
586         VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
587         nullptr,                                       // pNext
588         0,                                             // flags
589         ds.descriptorSetCount,                         // setLayoutCount
590         ds.descriptorSetLayouts,                       // pSetLayouts
591         ds.pushConstantRangeCount,                     // pushConstantRangeCount
592         ds.pushConstantRanges,                         // pPushConstantRanges
593     };
594 
595     VALIDATE_VK_RESULT(vkCreatePipelineLayout(vkDevice, // device
596         &pipelineLayoutCreateInfo,                      // pCreateInfo,
597         nullptr,                                        // pAllocator
598         &plat_.pipelineLayout));                        // pPipelineLayout
599 
600     constexpr VkPipelineCreateFlags pipelineCreateFlags { 0 };
601     const VkComputePipelineCreateInfo computePipelineCreateInfo {
602         VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
603         nullptr,                                        // pNext
604         pipelineCreateFlags,                            // flags
605         pipelineShaderStageCreateInfo,                  // stage
606         plat_.pipelineLayout,                           // layout
607         VK_NULL_HANDLE,                                 // basePipelineHandle
608         0,                                              // basePipelineIndex
609     };
610 
611     VALIDATE_VK_RESULT(vkCreateComputePipelines(vkDevice, // device
612         devicePlatVk.pipelineCache,                       // pipelineCache
613         1,                                                // createInfoCount
614         &computePipelineCreateInfo,                       // pCreateInfos
615         nullptr,                                          // pAllocator
616         &plat_.pipeline));                                // pPipelines
617 
618     // NOTE: direct destruction here
619     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
620         const auto& descRef = ds.descriptorSetLayouts[idx];
621         if (descRef && ds.descriptorSetLayoutOwnership[idx]) {
622             vkDestroyDescriptorSetLayout(vkDevice, // device
623                 descRef,                           // descriptorSetLayout
624                 nullptr);                          // pAllocator
625         }
626     }
627 }
628 
~ComputePipelineStateObjectVk()629 ComputePipelineStateObjectVk::~ComputePipelineStateObjectVk()
630 {
631     const VkDevice device = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
632     vkDestroyPipeline(device,       // device
633         plat_.pipeline,             // pipeline
634         nullptr);                   // pAllocator
635     vkDestroyPipelineLayout(device, // device
636         plat_.pipelineLayout,       // pipelineLayout
637         nullptr);                   // pAllocator
638 }
639 
GetPlatformData() const640 const PipelineStateObjectPlatformDataVk& ComputePipelineStateObjectVk::GetPlatformData() const
641 {
642     return plat_;
643 }
644 RENDER_END_NAMESPACE()
645