1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "render_node_compute_generic.h"
17 
18 #include <base/math/mathf.h>
19 #include <base/math/vector.h>
20 #include <render/datastore/intf_render_data_store_manager.h>
21 #include <render/datastore/intf_render_data_store_pod.h>
22 #include <render/device/intf_gpu_resource_manager.h>
23 #include <render/device/intf_shader_manager.h>
24 #include <render/namespace.h>
25 #include <render/nodecontext/intf_node_context_descriptor_set_manager.h>
26 #include <render/nodecontext/intf_node_context_pso_manager.h>
27 #include <render/nodecontext/intf_pipeline_descriptor_set_binder.h>
28 #include <render/nodecontext/intf_render_command_list.h>
29 #include <render/nodecontext/intf_render_node_context_manager.h>
30 #include <render/nodecontext/intf_render_node_parser_util.h>
31 #include <render/nodecontext/intf_render_node_util.h>
32 
33 #include "util/log.h"
34 
35 using namespace BASE_NS;
36 
37 RENDER_BEGIN_NAMESPACE()
38 namespace {
39 struct DispatchResources {
40     RenderHandle buffer {};
41     RenderHandle image {};
42 };
43 
GetDispatchResources(const RenderNodeHandles::InputResources & ir)44 DispatchResources GetDispatchResources(const RenderNodeHandles::InputResources& ir)
45 {
46     DispatchResources dr;
47     if (!ir.customInputBuffers.empty()) {
48         dr.buffer = ir.customInputBuffers[0].handle;
49     }
50     if (!ir.customInputImages.empty()) {
51         dr.image = ir.customInputImages[0].handle;
52     }
53     return dr;
54 }
55 } // namespace
56 
InitNode(IRenderNodeContextManager & renderNodeContextMgr)57 void RenderNodeComputeGeneric::InitNode(IRenderNodeContextManager& renderNodeContextMgr)
58 {
59     renderNodeContextMgr_ = &renderNodeContextMgr;
60     ParseRenderNodeInputs();
61 
62     useDataStoreShaderSpecialization_ = !jsonInputs_.renderDataStoreSpecialization.dataStoreName.empty();
63 
64     auto& shaderMgr = renderNodeContextMgr.GetShaderManager();
65     const auto& renderNodeUtil = renderNodeContextMgr.GetRenderNodeUtil();
66     if (RenderHandleUtil::GetHandleType(shader_) != RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
67         PLUGIN_LOG_E("RenderNodeComputeGeneric needs a valid compute shader handle");
68     }
69     pipelineLayout_ = renderNodeContextMgr.GetRenderNodeUtil().CreatePipelineLayout(shader_);
70     threadGroupSize_ = shaderMgr.GetReflectionThreadGroupSize(shader_);
71 
72     if (dispatchResources_.customInputBuffers.empty() && dispatchResources_.customInputImages.empty()) {
73         PLUGIN_LOG_W("RenderNodeComputeGeneric: dispatchResources (GPU buffer or GPU image) needed");
74     }
75 
76     if (useDataStoreShaderSpecialization_) {
77         const ShaderSpecializationConstantView sscv = shaderMgr.GetReflectionSpecialization(shader_);
78         shaderSpecializationData_.constants.resize(sscv.constants.size());
79         shaderSpecializationData_.data.resize(sscv.constants.size());
80         for (size_t idx = 0; idx < shaderSpecializationData_.constants.size(); ++idx) {
81             shaderSpecializationData_.constants[idx] = sscv.constants[idx];
82             shaderSpecializationData_.data[idx] = ~0u;
83         }
84         useDataStoreShaderSpecialization_ = !sscv.constants.empty();
85     }
86     psoHandle_ = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(shader_, pipelineLayout_, {});
87 
88     {
89         const DescriptorCounts dc = renderNodeUtil.GetDescriptorCounts(pipelineLayout_);
90         renderNodeContextMgr.GetDescriptorSetManager().ResetAndReserve(dc);
91     }
92 
93     pipelineDescriptorSetBinder_ = renderNodeUtil.CreatePipelineDescriptorSetBinder(pipelineLayout_);
94     renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
95 
96     useDataStorePushConstant_ = (pipelineLayout_.pushConstant.byteSize > 0) &&
97                                 (!jsonInputs_.renderDataStore.dataStoreName.empty()) &&
98                                 (!jsonInputs_.renderDataStore.configurationName.empty());
99 }
100 
PreExecuteFrame()101 void RenderNodeComputeGeneric::PreExecuteFrame()
102 {
103     // re-create needed gpu resources
104 }
105 
ExecuteFrame(IRenderCommandList & cmdList)106 void RenderNodeComputeGeneric::ExecuteFrame(IRenderCommandList& cmdList)
107 {
108     if (!RenderHandleUtil::IsValid(shader_)) {
109         return; // invalid shader
110     }
111 
112     const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
113     if (jsonInputs_.hasChangeableResourceHandles) {
114         inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
115         renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
116     }
117     if (jsonInputs_.hasChangeableDispatchHandles) {
118         dispatchResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.dispatchResources);
119     }
120     const DispatchResources dr = GetDispatchResources(dispatchResources_);
121     if ((!RenderHandleUtil::IsValid(dr.buffer)) && (!RenderHandleUtil::IsValid(dr.image))) {
122 #if (RENDER_VALIDATION_ENABLED == 1)
123         PLUGIN_LOG_ONCE_W(renderNodeContextMgr_->GetName() + "_no_dr",
124             "RENDER_VALIDATION: RN: %s, no valid dispatch resource", renderNodeContextMgr_->GetName().data());
125 #endif
126         return; // no way to evaluate dispatch size
127     }
128     const uint32_t firstSetIndex = pipelineDescriptorSetBinder_->GetFirstSet();
129     {
130         const auto setIndices = pipelineDescriptorSetBinder_->GetSetIndices();
131         for (auto refIndex : setIndices) {
132             const auto descHandle = pipelineDescriptorSetBinder_->GetDescriptorSetHandle(refIndex);
133             const auto bindings = pipelineDescriptorSetBinder_->GetDescriptorSetLayoutBindingResources(refIndex);
134             cmdList.UpdateDescriptorSet(descHandle, bindings);
135         }
136 #if (RENDER_VALIDATION_ENABLED == 1)
137         if (!pipelineDescriptorSetBinder_->GetPipelineDescriptorSetLayoutBindingValidity()) {
138             PLUGIN_LOG_ONCE_E(renderNodeContextMgr_->GetName() + "_bindings_missing",
139                 "RENDER_VALIDATION: RenderNodeComputeGeneric: bindings missing (RN: %s)",
140                 renderNodeContextMgr_->GetName().data());
141         }
142 #endif
143     }
144 
145     const RenderHandle psoHandle = GetPsoHandle(*renderNodeContextMgr_);
146     cmdList.BindPipeline(psoHandle);
147 
148     // bind all sets
149     {
150         const auto descHandles = pipelineDescriptorSetBinder_->GetDescriptorSetHandles();
151         cmdList.BindDescriptorSets(firstSetIndex, descHandles);
152     }
153 
154     // push constants
155     if (useDataStorePushConstant_) {
156         const auto& renderDataStoreMgr = renderNodeContextMgr_->GetRenderDataStoreManager();
157         const auto dataStore = static_cast<IRenderDataStorePod const*>(
158             renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStore.dataStoreName.c_str()));
159         if (dataStore) {
160             const auto dataView = dataStore->Get(jsonInputs_.renderDataStore.configurationName);
161             if (!dataView.empty()) {
162                 cmdList.PushConstant(pipelineLayout_.pushConstant, dataView.data());
163             }
164         }
165     }
166 
167     if (RenderHandleUtil::IsValid(dr.buffer)) {
168         cmdList.DispatchIndirect(dr.buffer, 0);
169     } else if (RenderHandleUtil::IsValid(dr.image)) {
170         const IRenderNodeGpuResourceManager& gpuResourceMgr = renderNodeContextMgr_->GetGpuResourceManager();
171         const GpuImageDesc desc = gpuResourceMgr.GetImageDescriptor(dr.image);
172         const Math::UVec3 targetSize = { desc.width, desc.height, desc.depth };
173         cmdList.Dispatch((targetSize.x + threadGroupSize_.x - 1u) / threadGroupSize_.x,
174             (targetSize.y + threadGroupSize_.y - 1u) / threadGroupSize_.y,
175             (targetSize.z + threadGroupSize_.z - 1u) / threadGroupSize_.z);
176     }
177 }
178 
GetPsoHandle(IRenderNodeContextManager & renderNodeContextMgr)179 RenderHandle RenderNodeComputeGeneric::GetPsoHandle(IRenderNodeContextManager& renderNodeContextMgr)
180 {
181     if (useDataStoreShaderSpecialization_) {
182         const auto& renderDataStoreMgr = renderNodeContextMgr.GetRenderDataStoreManager();
183         const auto dataStore = static_cast<IRenderDataStorePod const*>(
184             renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStoreSpecialization.dataStoreName.c_str()));
185         if (dataStore) {
186             const auto dataView = dataStore->Get(jsonInputs_.renderDataStoreSpecialization.configurationName);
187             if (dataView.data() && (dataView.size_bytes() == sizeof(ShaderSpecializationRenderPod))) {
188                 const auto* spec = reinterpret_cast<const ShaderSpecializationRenderPod*>(dataView.data());
189                 bool valuesChanged = false;
190                 const auto specializationCount = Math::min(
191                     ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT,
192                     Math::min((uint32_t)shaderSpecializationData_.constants.size(), spec->specializationConstantCount));
193                 const auto constantsView = array_view(shaderSpecializationData_.constants.data(), specializationCount);
194                 for (const auto& ref : constantsView) {
195                     const uint32_t constantId = ref.offset / sizeof(uint32_t);
196                     const uint32_t specId = ref.id;
197                     if (specId < ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT) {
198                         if (shaderSpecializationData_.data[constantId] != spec->specializationFlags[specId].value) {
199                             shaderSpecializationData_.data[constantId] = spec->specializationFlags[specId].value;
200                             valuesChanged = true;
201                         }
202                     }
203                 }
204                 if (valuesChanged) {
205                     const ShaderSpecializationConstantDataView specialization {
206                         constantsView,
207                         { shaderSpecializationData_.data.data(), specializationCount },
208                     };
209                     psoHandle_ = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(
210                         shader_, pipelineLayout_, specialization);
211                 }
212             } else {
213 #if (RENDER_VALIDATION_ENABLED == 1)
214                 const string logName = "RenderNodeComputeGeneric_ShaderSpecialization" +
215                                        string(jsonInputs_.renderDataStoreSpecialization.configurationName);
216                 PLUGIN_LOG_ONCE_E(logName.c_str(),
217                     "RENDER_VALIDATION: RenderNodeComputeGeneric shader specilization render data store size mismatch, "
218                     "name: %s, size:%u, podsize%u",
219                     jsonInputs_.renderDataStoreSpecialization.configurationName.c_str(),
220                     static_cast<uint32_t>(sizeof(ShaderSpecializationRenderPod)),
221                     static_cast<uint32_t>(dataView.size_bytes()));
222 #endif
223             }
224         }
225     }
226     return psoHandle_;
227 }
228 
ParseRenderNodeInputs()229 void RenderNodeComputeGeneric::ParseRenderNodeInputs()
230 {
231     const IRenderNodeParserUtil& parserUtil = renderNodeContextMgr_->GetRenderNodeParserUtil();
232     const auto jsonVal = renderNodeContextMgr_->GetNodeJson();
233     jsonInputs_.resources = parserUtil.GetInputResources(jsonVal, "resources");
234     jsonInputs_.dispatchResources = parserUtil.GetInputResources(jsonVal, "dispatchResources");
235     jsonInputs_.renderDataStore = parserUtil.GetRenderDataStore(jsonVal, "renderDataStore");
236     jsonInputs_.renderDataStoreSpecialization =
237         parserUtil.GetRenderDataStore(jsonVal, "renderDataStoreShaderSpecialization");
238 
239     const auto shaderName = parserUtil.GetStringValue(jsonVal, "shader");
240     const IRenderNodeShaderManager& shaderMgr = renderNodeContextMgr_->GetShaderManager();
241     shader_ = shaderMgr.GetShaderHandle(shaderName);
242 
243     const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
244     inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
245     dispatchResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.dispatchResources);
246     jsonInputs_.hasChangeableResourceHandles = renderNodeUtil.HasChangeableResources(jsonInputs_.resources);
247     jsonInputs_.hasChangeableDispatchHandles = renderNodeUtil.HasChangeableResources(jsonInputs_.dispatchResources);
248 }
249 
250 // for plugin / factory interface
Create()251 IRenderNode* RenderNodeComputeGeneric::Create()
252 {
253     return new RenderNodeComputeGeneric();
254 }
255 
Destroy(IRenderNode * instance)256 void RenderNodeComputeGeneric::Destroy(IRenderNode* instance)
257 {
258     delete static_cast<RenderNodeComputeGeneric*>(instance);
259 }
260 RENDER_END_NAMESPACE()
261