1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "gpu_program_vk.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <vulkan/vulkan_core.h>
21 
22 #include <base/containers/array_view.h>
23 #include <base/containers/vector.h>
24 #include <render/device/gpu_resource_desc.h>
25 #include <render/device/pipeline_layout_desc.h>
26 #include <render/namespace.h>
27 
28 #include "device/device.h"
29 #include "device/gpu_program_util.h"
30 #include "device/shader_module.h"
31 #include "util/log.h"
32 #include "vulkan/device_vk.h"
33 #include "vulkan/shader_module_vk.h"
34 #include "vulkan/validate_vk.h"
35 
36 using namespace BASE_NS;
37 
RENDER_BEGIN_NAMESPACE()38 RENDER_BEGIN_NAMESPACE()
39 GpuShaderProgramVk::GpuShaderProgramVk(Device& device, const GpuShaderProgramCreateData& createData)
40     : GpuShaderProgram()
41 {
42     PLUGIN_ASSERT(createData.vertShaderModule);
43     PLUGIN_ASSERT(createData.fragShaderModule);
44 
45     // combine vertex and fragment shader data
46     if (createData.vertShaderModule && createData.fragShaderModule) {
47         vertShaderModule_ = static_cast<ShaderModuleVk*>(createData.vertShaderModule);
48         fragShaderModule_ = static_cast<ShaderModuleVk*>(createData.fragShaderModule);
49         auto& pipelineLayout = reflection_.pipelineLayout;
50 
51         { // vert
52             const ShaderModuleVk& mod = *vertShaderModule_;
53             plat_.vert = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
54             pipelineLayout = mod.GetPipelineLayout();
55             const auto& sscv = mod.GetSpecilization();
56             // has sort inside
57             GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
58 
59             // not owned, directly reflected from vertex shader module
60             const auto& vidv = mod.GetVertexInputDeclaration();
61             reflection_.vertexInputDeclarationView.bindingDescriptions = vidv.bindingDescriptions;
62             reflection_.vertexInputDeclarationView.attributeDescriptions = vidv.attributeDescriptions;
63         }
64         { // frag
65             const ShaderModuleVk& mod = *fragShaderModule_;
66             plat_.frag = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
67 
68             const auto& sscv = mod.GetSpecilization();
69             // has sort inside
70             GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
71 
72             const auto& reflPl = mod.GetPipelineLayout();
73             // has sort inside
74             GpuProgramUtil::CombinePipelineLayouts({ &reflPl, 1u }, pipelineLayout);
75         }
76 
77         reflection_.shaderSpecializationConstantView.constants =
78             array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
79     }
80 }
81 
GetPlatformData() const82 const GpuShaderProgramPlatformDataVk& GpuShaderProgramVk::GetPlatformData() const
83 {
84     return plat_;
85 }
86 
GetReflection() const87 const ShaderReflection& GpuShaderProgramVk::GetReflection() const
88 {
89     return reflection_;
90 }
91 
GpuComputeProgramVk(Device & device,const GpuComputeProgramCreateData & createData)92 GpuComputeProgramVk::GpuComputeProgramVk(Device& device, const GpuComputeProgramCreateData& createData)
93     : GpuComputeProgram()
94 {
95     PLUGIN_ASSERT(createData.compShaderModule);
96 
97     if (createData.compShaderModule) {
98         shaderModule_ = static_cast<ShaderModuleVk*>(createData.compShaderModule);
99         {
100             const ShaderModuleVk& mod = *shaderModule_;
101             plat_.comp = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
102             // copy needed data
103             reflection_.pipelineLayout = mod.GetPipelineLayout();
104             const auto& tgs = mod.GetThreadGroupSize();
105             reflection_.threadGroupSizeX = Math::max(1u, tgs.x);
106             reflection_.threadGroupSizeY = Math::max(1u, tgs.y);
107             reflection_.threadGroupSizeZ = Math::max(1u, tgs.z);
108             const auto& sscv = mod.GetSpecilization();
109             constants_ =
110                 vector<ShaderSpecialization::Constant>(sscv.constants.cbegin().ptr(), sscv.constants.cend().ptr());
111         }
112 
113         reflection_.shaderSpecializationConstantView.constants =
114             array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
115     }
116 }
117 
GetPlatformData() const118 const GpuComputeProgramPlatformDataVk& GpuComputeProgramVk::GetPlatformData() const
119 {
120     return plat_;
121 }
122 
GetReflection() const123 const ComputeShaderReflection& GpuComputeProgramVk::GetReflection() const
124 {
125     return reflection_;
126 }
127 RENDER_END_NAMESPACE()
128