1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "shader_module_vk.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <vulkan/vulkan_core.h>
21
22 #include <render/device/pipeline_layout_desc.h>
23 #include <render/namespace.h>
24
25 #include "device/device.h"
26 #include "device/gpu_program_util.h"
27 #include "device/shader_manager.h"
28 #include "util/log.h"
29 #include "vulkan/device_vk.h"
30 #include "vulkan/validate_vk.h"
31
32 using namespace BASE_NS;
33
34 RENDER_BEGIN_NAMESPACE()
35 namespace {
CreateShaderModule(const VkDevice device,array_view<const uint8_t> data)36 VkShaderModule CreateShaderModule(const VkDevice device, array_view<const uint8_t> data)
37 {
38 PLUGIN_ASSERT(data.size() > 0);
39 VkShaderModule shaderModule { VK_NULL_HANDLE };
40
41 constexpr VkShaderModuleCreateFlags shaderModuleCreateFlags { 0 };
42 const VkShaderModuleCreateInfo shaderModuleCreateInfo {
43 VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
44 nullptr, // pNext
45 shaderModuleCreateFlags, // flags
46 static_cast<uint32_t>(data.size()), // codeSize
47 reinterpret_cast<const uint32_t*>(data.data()) // pCode
48 };
49
50 VALIDATE_VK_RESULT(vkCreateShaderModule(device, // device
51 &shaderModuleCreateInfo, // pCreateInfo
52 nullptr, // pAllocator
53 &shaderModule)); // pShaderModule
54
55 return shaderModule;
56 }
57 } // namespace
58
ShaderModuleVk(Device & device,const ShaderModuleCreateInfo & createInfo)59 ShaderModuleVk::ShaderModuleVk(Device& device, const ShaderModuleCreateInfo& createInfo)
60 : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
61 {
62 PLUGIN_ASSERT(createInfo.spvData.size() > 0);
63 PLUGIN_ASSERT(createInfo.shaderStageFlags & (ShaderStageFlagBits::CORE_SHADER_STAGE_VERTEX_BIT |
64 ShaderStageFlagBits::CORE_SHADER_STAGE_FRAGMENT_BIT |
65 ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT));
66
67 bool valid = false;
68 if (createInfo.reflectionData.IsValid()) {
69 valid = true;
70 pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
71
72 constants_ = createInfo.reflectionData.GetSpecializationConstants();
73 sscv_.constants = constants_;
74
75 if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_VERTEX_BIT) {
76 vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
77 for (const auto& attrib : vertexInputAttributeDescriptions_) {
78 VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
79 bindingDesc.binding = attrib.binding;
80 bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
81 bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
82 vertexInputBindingDescriptions_.push_back(bindingDesc);
83 }
84 vidv_.bindingDescriptions = vertexInputBindingDescriptions_;
85 vidv_.attributeDescriptions = vertexInputAttributeDescriptions_;
86 } else if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_FRAGMENT_BIT) {
87 } else if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT) {
88 const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
89 stg_.x = tgs[0u];
90 stg_.y = tgs[1u];
91 stg_.z = tgs[2u];
92 } else {
93 PLUGIN_LOG_E("invalid shader stage flags for module creation");
94 valid = false;
95 }
96 }
97
98 // NOTE: sorting not needed?
99
100 if (valid) {
101 const VkDevice vkDevice = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
102 plat_.shaderModule = CreateShaderModule(vkDevice, createInfo.spvData);
103 } else {
104 PLUGIN_LOG_E("invalid vulkan shader module");
105 }
106 }
107
~ShaderModuleVk()108 ShaderModuleVk::~ShaderModuleVk()
109 {
110 const VkDevice device = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
111 if (plat_.shaderModule != VK_NULL_HANDLE) {
112 vkDestroyShaderModule(device, // device
113 plat_.shaderModule, // shaderModule
114 nullptr); // pAllocator
115 }
116 }
117
GetShaderStageFlags() const118 ShaderStageFlags ShaderModuleVk::GetShaderStageFlags() const
119 {
120 return shaderStageFlags_;
121 }
122
GetPlatformData() const123 const ShaderModulePlatformData& ShaderModuleVk::GetPlatformData() const
124 {
125 return plat_;
126 }
127
GetPipelineLayout() const128 const PipelineLayout& ShaderModuleVk::GetPipelineLayout() const
129 {
130 return pipelineLayout_;
131 }
132
GetSpecilization() const133 ShaderSpecializationConstantView ShaderModuleVk::GetSpecilization() const
134 {
135 return sscv_;
136 }
137
GetVertexInputDeclaration() const138 VertexInputDeclarationView ShaderModuleVk::GetVertexInputDeclaration() const
139 {
140 return vidv_;
141 }
142
GetThreadGroupSize() const143 ShaderThreadGroup ShaderModuleVk::GetThreadGroupSize() const
144 {
145 return stg_;
146 }
147 RENDER_END_NAMESPACE()
148