1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "gpu_program_vk.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <vulkan/vulkan_core.h>
21
22 #include <base/containers/array_view.h>
23 #include <base/containers/vector.h>
24 #include <render/device/gpu_resource_desc.h>
25 #include <render/device/pipeline_layout_desc.h>
26 #include <render/namespace.h>
27
28 #include "device/device.h"
29 #include "device/gpu_program_util.h"
30 #include "device/shader_module.h"
31 #include "util/log.h"
32 #include "vulkan/device_vk.h"
33 #include "vulkan/shader_module_vk.h"
34 #include "vulkan/validate_vk.h"
35
36 using namespace BASE_NS;
37
RENDER_BEGIN_NAMESPACE()38 RENDER_BEGIN_NAMESPACE()
39 GpuShaderProgramVk::GpuShaderProgramVk(Device& device, const GpuShaderProgramCreateData& createData)
40 : GpuShaderProgram()
41 {
42 PLUGIN_ASSERT(createData.vertShaderModule);
43 PLUGIN_ASSERT(createData.fragShaderModule);
44
45 // combine vertex and fragment shader data
46 if (createData.vertShaderModule && createData.fragShaderModule) {
47 vertShaderModule_ = static_cast<ShaderModuleVk*>(createData.vertShaderModule);
48 fragShaderModule_ = static_cast<ShaderModuleVk*>(createData.fragShaderModule);
49 auto& pipelineLayout = reflection_.pipelineLayout;
50
51 { // vert
52 const ShaderModuleVk& mod = *vertShaderModule_;
53 plat_.vert = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
54 pipelineLayout = mod.GetPipelineLayout();
55 const auto& sscv = mod.GetSpecilization();
56 // has sort inside
57 GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
58
59 // not owned, directly reflected from vertex shader module
60 const auto& vidv = mod.GetVertexInputDeclaration();
61 reflection_.vertexInputDeclarationView.bindingDescriptions = vidv.bindingDescriptions;
62 reflection_.vertexInputDeclarationView.attributeDescriptions = vidv.attributeDescriptions;
63 }
64 { // frag
65 const ShaderModuleVk& mod = *fragShaderModule_;
66 plat_.frag = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
67
68 const auto& sscv = mod.GetSpecilization();
69 // has sort inside
70 GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
71
72 const auto& reflPl = mod.GetPipelineLayout();
73 // has sort inside
74 GpuProgramUtil::CombinePipelineLayouts({ &reflPl, 1u }, pipelineLayout);
75 }
76
77 reflection_.shaderSpecializationConstantView.constants =
78 array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
79 }
80 }
81
GetPlatformData() const82 const GpuShaderProgramPlatformDataVk& GpuShaderProgramVk::GetPlatformData() const
83 {
84 return plat_;
85 }
86
GetReflection() const87 const ShaderReflection& GpuShaderProgramVk::GetReflection() const
88 {
89 return reflection_;
90 }
91
GpuComputeProgramVk(Device & device,const GpuComputeProgramCreateData & createData)92 GpuComputeProgramVk::GpuComputeProgramVk(Device& device, const GpuComputeProgramCreateData& createData)
93 : GpuComputeProgram()
94 {
95 PLUGIN_ASSERT(createData.compShaderModule);
96
97 if (createData.compShaderModule) {
98 shaderModule_ = static_cast<ShaderModuleVk*>(createData.compShaderModule);
99 {
100 const ShaderModuleVk& mod = *shaderModule_;
101 plat_.comp = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
102 // copy needed data
103 reflection_.pipelineLayout = mod.GetPipelineLayout();
104 const auto& tgs = mod.GetThreadGroupSize();
105 reflection_.threadGroupSizeX = Math::max(1u, tgs.x);
106 reflection_.threadGroupSizeY = Math::max(1u, tgs.y);
107 reflection_.threadGroupSizeZ = Math::max(1u, tgs.z);
108 const auto& sscv = mod.GetSpecilization();
109 constants_ =
110 vector<ShaderSpecialization::Constant>(sscv.constants.cbegin().ptr(), sscv.constants.cend().ptr());
111 }
112
113 reflection_.shaderSpecializationConstantView.constants =
114 array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
115 }
116 }
117
GetPlatformData() const118 const GpuComputeProgramPlatformDataVk& GpuComputeProgramVk::GetPlatformData() const
119 {
120 return plat_;
121 }
122
GetReflection() const123 const ComputeShaderReflection& GpuComputeProgramVk::GetReflection() const
124 {
125 return reflection_;
126 }
127 RENDER_END_NAMESPACE()
128