// Copyright 2013 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "impeller/renderer/backend/vulkan/compute_pass_vk.h"
#include "flutter/fml/trace_event.h"
#include "impeller/renderer/backend/vulkan/binding_helpers_vk.h"
#include "impeller/renderer/backend/vulkan/command_buffer_vk.h"
#include "impeller/renderer/backend/vulkan/compute_pipeline_vk.h"
#include "impeller/renderer/backend/vulkan/texture_vk.h"
namespace impeller {

ComputePassVK::ComputePassVK(std::weak_ptr<const Context> context,
                             std::weak_ptr<CommandBufferVK> command_buffer)
    : ComputePass(std::move(context)),
      command_buffer_(std::move(command_buffer)) {
  is_valid_ = true;
}

ComputePassVK::~ComputePassVK() = default;

bool ComputePassVK::IsValid() const {
  return is_valid_;
}

void ComputePassVK::OnSetLabel(const std::string& label) {
  if (label.empty()) {
    return;
  }
  label_ = label;
}
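
// Transitions every sampled image referenced by the bindings to the
// shader-read-only-optimal layout, inserting a barrier from prior transfer
// writes to compute-shader reads so the compute shader can safely sample it.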
static bool UpdateBindingLayouts(const Bindings& bindings,
                                 const vk::CommandBuffer& buffer) {
  BarrierVK barrier;
  barrier.cmd_buffer = buffer;
  barrier.src_access = vk::AccessFlagBits::eTransferWrite;
  barrier.src_stage = vk::PipelineStageFlagBits::eTransfer;
  barrier.dst_access = vk::AccessFlagBits::eShaderRead;
  barrier.dst_stage = vk::PipelineStageFlagBits::eComputeShader;
  barrier.new_layout = vk::ImageLayout::eShaderReadOnlyOptimal;

  for (const auto& [_, data] : bindings.sampled_images) {
    if (!TextureVK::Cast(*data.texture.resource).SetLayout(barrier)) {
      return false;
    }
  }
  return true;
}

static bool UpdateBindingLayouts(const ComputeCommand& command,
                                 const vk::CommandBuffer& buffer) {
  return UpdateBindingLayouts(command.bindings, buffer);
}

static bool UpdateBindingLayouts(const std::vector<ComputeCommand>& commands,
                                 const vk::CommandBuffer& buffer) {
  for (const auto& command : commands) {
    if (!UpdateBindingLayouts(command, buffer)) {
      return false;
    }
  }
  return true;
}

bool ComputePassVK::OnEncodeCommands(const Context& context,
                                     const ISize& grid_size,
                                     const ISize& thread_group_size) const {
  TRACE_EVENT0("impeller", "ComputePassVK::EncodeCommands");
  if (!IsValid()) {
    return false;
  }

  FML_DCHECK(!grid_size.IsEmpty() && !thread_group_size.IsEmpty());

  const auto& vk_context = ContextVK::Cast(context);
  auto command_buffer = command_buffer_.lock();
  if (!command_buffer) {
    VALIDATION_LOG << "Command buffer died before commands could be encoded.";
    return false;
  }
  auto encoder = command_buffer->GetEncoder();
  if (!encoder) {
    return false;
  }
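
  // If the pass has a label, wrap the encoded commands in a debug group; the
  // scoped closure pops the group when this method returns.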
  fml::ScopedCleanupClosure pop_marker(
      [&encoder]() { encoder->PopDebugGroup(); });
  if (!label_.empty()) {
    encoder->PushDebugGroup(label_.c_str());
  } else {
    pop_marker.Release();
  }

  auto cmd_buffer = encoder->GetCommandBuffer();

  if (!UpdateBindingLayouts(commands_, cmd_buffer)) {
    VALIDATION_LOG << "Could not update binding layouts for compute pass.";
    return false;
  }
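
  // Allocate the descriptor sets for all commands up front; each command
  // binds its set by index in the loop below.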
  auto desc_sets_result =
      AllocateAndBindDescriptorSets(vk_context, encoder, commands_);
  if (!desc_sets_result.ok()) {
    return false;
  }
  auto desc_sets = desc_sets_result.value();

  TRACE_EVENT0("impeller", "EncodeComputePassCommands");

  size_t desc_index = 0;
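  // For each command: bind its compute pipeline and descriptor set, then
  // compute and issue a dispatch sized from the grid and the device's
  // maximum work group size.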
  for (const auto& command : commands_) {
    const auto& pipeline_vk = ComputePipelineVK::Cast(*command.pipeline);

    cmd_buffer.bindPipeline(vk::PipelineBindPoint::eCompute,
                            pipeline_vk.GetPipeline());
    cmd_buffer.bindDescriptorSets(
        vk::PipelineBindPoint::eCompute,             // bind point
        pipeline_vk.GetPipelineLayout(),             // layout
        0,                                           // first set
        {vk::DescriptorSet{desc_sets[desc_index]}},  // sets
        nullptr                                      // offsets
    );

    // TODO(dnfield): This should be moved to caps. But for now keeping this
    // in parallel with Metal.
    auto device_properties = vk_context.GetPhysicalDevice().getProperties();
    auto max_wg_size = device_properties.limits.maxComputeWorkGroupSize;

    int64_t width = grid_size.width;
    int64_t height = grid_size.height;

    // Special case for linear processing.
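    // Dispatch enough one-dimensional thread groups to cover the full width,
    // i.e. ceil(width / max_wg_size[0]).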
    if (height == 1) {
      int64_t minimum = 1;
      int64_t thread_groups = std::max(
          static_cast<int64_t>(std::ceil(width * 1.0 / max_wg_size[0])),
          minimum);
      cmd_buffer.dispatch(thread_groups, 1, 1);
    } else {
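      // Clamp the two-dimensional grid to the device's maximum work group
      // size by halving each dimension until it fits.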
      while (width > max_wg_size[0]) {
        width = std::max(static_cast<int64_t>(1), width / 2);
      }
      while (height > max_wg_size[1]) {
        height = std::max(static_cast<int64_t>(1), height / 2);
      }
      cmd_buffer.dispatch(width, height, 1);
    }
    desc_index += 1;
  }

  return true;
}

}  // namespace impeller