[Impeller] Fix 1-d grid computation for compute (#42516)

Note that the 2d grid case is still incorrect. Consider: the grid size should be the number of compute units required, but the threadgroup size is a minimum number of compute units.

If I need to process a 50x50 image, I should be able to set a grid size of 50x50. Since the minimum threadgroup size is probably bigger (say 1024), this should turn into one dispatch of size (1, 1, 1). However with the current implementation, we will make a dispatch of (50, 50, 1), which essentially squares the amount of work - doing one thread group per unit of compute.

The correct implementation for 2d compute should take the mod of each grid dimension with the threadgroup size in that dimension. I did not fix this case as we do not have a use for 2d compute yet.
diff --git a/ci/licenses_golden/licenses_flutter b/ci/licenses_golden/licenses_flutter
index 11b6dc4..e2eb8ee 100644
--- a/ci/licenses_golden/licenses_flutter
+++ b/ci/licenses_golden/licenses_flutter
@@ -1603,6 +1603,7 @@
 ORIGIN: ../../../flutter/impeller/renderer/stroke.comp + ../../../flutter/LICENSE
 ORIGIN: ../../../flutter/impeller/renderer/surface.cc + ../../../flutter/LICENSE
 ORIGIN: ../../../flutter/impeller/renderer/surface.h + ../../../flutter/LICENSE
+ORIGIN: ../../../flutter/impeller/renderer/threadgroup_sizing_test.comp + ../../../flutter/LICENSE
 ORIGIN: ../../../flutter/impeller/renderer/vertex_buffer_builder.cc + ../../../flutter/LICENSE
 ORIGIN: ../../../flutter/impeller/renderer/vertex_buffer_builder.h + ../../../flutter/LICENSE
 ORIGIN: ../../../flutter/impeller/renderer/vertex_descriptor.cc + ../../../flutter/LICENSE
@@ -4276,6 +4277,7 @@
 FILE: ../../../flutter/impeller/renderer/stroke.comp
 FILE: ../../../flutter/impeller/renderer/surface.cc
 FILE: ../../../flutter/impeller/renderer/surface.h
+FILE: ../../../flutter/impeller/renderer/threadgroup_sizing_test.comp
 FILE: ../../../flutter/impeller/renderer/vertex_buffer_builder.cc
 FILE: ../../../flutter/impeller/renderer/vertex_buffer_builder.h
 FILE: ../../../flutter/impeller/renderer/vertex_descriptor.cc
diff --git a/impeller/renderer/BUILD.gn b/impeller/renderer/BUILD.gn
index ec516cd..b43c71a 100644
--- a/impeller/renderer/BUILD.gn
+++ b/impeller/renderer/BUILD.gn
@@ -23,6 +23,7 @@
       "stroke.comp",
       "path_polyline.comp",
       "prefix_sum_test.comp",
+      "threadgroup_sizing_test.comp",
     ]
   }
 
diff --git a/impeller/renderer/backend/metal/compute_pass_mtl.mm b/impeller/renderer/backend/metal/compute_pass_mtl.mm
index 864b7d7..5cdf4a5 100644
--- a/impeller/renderer/backend/metal/compute_pass_mtl.mm
+++ b/impeller/renderer/backend/metal/compute_pass_mtl.mm
@@ -258,8 +258,10 @@
 
     // Special case for linear processing.
     if (height == 1) {
-      int64_t threadGroups =
-          std::max(width / maxTotalThreadsPerThreadgroup, 1LL);
+      int64_t threadGroups = std::max(
+          static_cast<int64_t>(
+              std::ceil(width * 1.0 / maxTotalThreadsPerThreadgroup * 1.0)),
+          1LL);
       [encoder dispatchThreadgroups:MTLSizeMake(threadGroups, 1, 1)
               threadsPerThreadgroup:MTLSizeMake(maxTotalThreadsPerThreadgroup,
                                                 1, 1)];
diff --git a/impeller/renderer/backend/vulkan/compute_pass_vk.cc b/impeller/renderer/backend/vulkan/compute_pass_vk.cc
index bb1651c..cb1074f 100644
--- a/impeller/renderer/backend/vulkan/compute_pass_vk.cc
+++ b/impeller/renderer/backend/vulkan/compute_pass_vk.cc
@@ -252,14 +252,22 @@
       int64_t width = grid_size.width;
       int64_t height = grid_size.height;
 
-      while (width > max_wg_size[0]) {
-        width = std::max(static_cast<int64_t>(1), width / 2);
+      // Special case for linear processing.
+      if (height == 1) {
+        int64_t minimum = 1;
+        int64_t threadGroups = std::max(
+            static_cast<int64_t>(std::ceil(width * 1.0 / max_wg_size[0] * 1.0)),
+            minimum);
+        cmd_buffer.dispatch(threadGroups, 1, 1);
+      } else {
+        while (width > max_wg_size[0]) {
+          width = std::max(static_cast<int64_t>(1), width / 2);
+        }
+        while (height > max_wg_size[1]) {
+          height = std::max(static_cast<int64_t>(1), height / 2);
+        }
+        cmd_buffer.dispatch(width, height, 1);
       }
-      while (height > max_wg_size[1]) {
-        height = std::max(static_cast<int64_t>(1), height / 2);
-      }
-
-      cmd_buffer.dispatch(width, height, 1);
     }
   }
 
diff --git a/impeller/renderer/backend/vulkan/context_vk.h b/impeller/renderer/backend/vulkan/context_vk.h
index 53f5175..4b691c8 100644
--- a/impeller/renderer/backend/vulkan/context_vk.h
+++ b/impeller/renderer/backend/vulkan/context_vk.h
@@ -136,6 +136,11 @@
   struct DeviceHolderImpl : public DeviceHolder {
     // |DeviceHolder|
     const vk::Device& GetDevice() const override { return device.get(); }
+    // |DeviceHolder|
+    const vk::PhysicalDevice& GetPhysicalDevice() const override {
+      return physical_device;
+    }
+
     vk::UniqueInstance instance;
     vk::PhysicalDevice physical_device;
     vk::UniqueDevice device;
diff --git a/impeller/renderer/backend/vulkan/device_holder.h b/impeller/renderer/backend/vulkan/device_holder.h
index cb9fdee..9086666 100644
--- a/impeller/renderer/backend/vulkan/device_holder.h
+++ b/impeller/renderer/backend/vulkan/device_holder.h
@@ -12,6 +12,7 @@
  public:
   virtual ~DeviceHolder() = default;
   virtual const vk::Device& GetDevice() const = 0;
+  virtual const vk::PhysicalDevice& GetPhysicalDevice() const = 0;
 };
 
 }  // namespace impeller
diff --git a/impeller/renderer/backend/vulkan/pipeline_library_vk.cc b/impeller/renderer/backend/vulkan/pipeline_library_vk.cc
index c2e26c7..201f88e 100644
--- a/impeller/renderer/backend/vulkan/pipeline_library_vk.cc
+++ b/impeller/renderer/backend/vulkan/pipeline_library_vk.cc
@@ -357,16 +357,34 @@
     return nullptr;
   }
 
-  vk::PipelineShaderStageCreateInfo info;
-  info.setStage(vk::ShaderStageFlagBits::eCompute);
-  info.setPName("main");
-  info.setModule(ShaderFunctionVK::Cast(entrypoint.get())->GetModule());
-  pipeline_info.setStage(info);
-
   std::shared_ptr<DeviceHolder> strong_device = device_holder_.lock();
   if (!strong_device) {
     return nullptr;
   }
+  auto device_properties = strong_device->GetPhysicalDevice().getProperties();
+  auto max_wg_size = device_properties.limits.maxComputeWorkGroupSize;
+
+  // Give all compute shaders a specialization constant entry for the
+  // workgroup/threadgroup size.
+  vk::SpecializationMapEntry specialization_map_entry[1];
+
+  uint32_t workgroup_size_x = max_wg_size[0];
+  specialization_map_entry[0].constantID = 0;
+  specialization_map_entry[0].offset = 0;
+  specialization_map_entry[0].size = sizeof(uint32_t);
+
+  vk::SpecializationInfo specialization_info;
+  specialization_info.mapEntryCount = 1;
+  specialization_info.pMapEntries = &specialization_map_entry[0];
+  specialization_info.dataSize = sizeof(uint32_t);
+  specialization_info.pData = &workgroup_size_x;
+
+  vk::PipelineShaderStageCreateInfo info;
+  info.setStage(vk::ShaderStageFlagBits::eCompute);
+  info.setPName("main");
+  info.setModule(ShaderFunctionVK::Cast(entrypoint.get())->GetModule());
+  info.setPSpecializationInfo(&specialization_info);
+  pipeline_info.setStage(info);
 
   //----------------------------------------------------------------------------
   /// Pipeline Layout a.k.a the descriptor sets and uniforms.
diff --git a/impeller/renderer/compute_pipeline_builder.h b/impeller/renderer/compute_pipeline_builder.h
index f4fe1b8..7a7d34c 100644
--- a/impeller/renderer/compute_pipeline_builder.h
+++ b/impeller/renderer/compute_pipeline_builder.h
@@ -45,9 +45,8 @@
     ComputePipelineDescriptor desc;
     if (InitializePipelineDescriptorDefaults(context, desc)) {
       return {std::move(desc)};
-    } else {
-      return std::nullopt;
     }
+    return std::nullopt;
   }
 
   [[nodiscard]] static bool InitializePipelineDescriptorDefaults(
diff --git a/impeller/renderer/compute_unittests.cc b/impeller/renderer/compute_unittests.cc
index 2ebb92c..67517a4 100644
--- a/impeller/renderer/compute_unittests.cc
+++ b/impeller/renderer/compute_unittests.cc
@@ -19,6 +19,7 @@
 #include "impeller/renderer/compute_pipeline_builder.h"
 #include "impeller/renderer/pipeline_library.h"
 #include "impeller/renderer/prefix_sum_test.comp.h"
+#include "impeller/renderer/threadgroup_sizing_test.comp.h"
 
 namespace impeller {
 namespace testing {
@@ -176,6 +177,59 @@
   latch.Wait();
 }
 
+TEST_P(ComputeTest, 1DThreadgroupSizingIsCorrect) {
+  using CS = ThreadgroupSizingTestComputeShader;
+  auto context = GetContext();
+  ASSERT_TRUE(context);
+  ASSERT_TRUE(context->GetCapabilities()->SupportsCompute());
+
+  using SamplePipelineBuilder = ComputePipelineBuilder<CS>;
+  auto pipeline_desc =
+      SamplePipelineBuilder::MakeDefaultPipelineDescriptor(*context);
+  ASSERT_TRUE(pipeline_desc.has_value());
+  auto compute_pipeline =
+      context->GetPipelineLibrary()->GetPipeline(pipeline_desc).Get();
+  ASSERT_TRUE(compute_pipeline);
+
+  auto cmd_buffer = context->CreateCommandBuffer();
+  auto pass = cmd_buffer->CreateComputePass();
+  ASSERT_TRUE(pass && pass->IsValid());
+
+  static constexpr size_t kCount = 2048;
+
+  pass->SetGridSize(ISize(kCount, 1));
+  pass->SetThreadGroupSize(ISize(kCount, 1));
+
+  ComputeCommand cmd;
+  cmd.label = "Compute";
+  cmd.pipeline = compute_pipeline;
+
+  auto output_buffer = CreateHostVisibleDeviceBuffer<CS::OutputData<kCount>>(
+      context, "Output Buffer");
+
+  CS::BindOutputData(cmd, output_buffer->AsBufferView());
+
+  ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
+  ASSERT_TRUE(pass->EncodeCommands());
+
+  fml::AutoResetWaitableEvent latch;
+  ASSERT_TRUE(cmd_buffer->SubmitCommands(
+      [&latch, output_buffer](CommandBuffer::Status status) {
+        EXPECT_EQ(status, CommandBuffer::Status::kCompleted);
+
+        auto view = output_buffer->AsBufferView();
+        EXPECT_EQ(view.range.length, sizeof(CS::OutputData<kCount>));
+
+        CS::OutputData<kCount>* output =
+            reinterpret_cast<CS::OutputData<kCount>*>(view.contents);
+        EXPECT_TRUE(output);
+        EXPECT_EQ(output->data[kCount - 1], kCount - 1);
+        latch.Signal();
+      }));
+
+  latch.Wait();
+}
+
 TEST_P(ComputeTest, CanComputePrefixSumLargeInteractive) {
   using CS = PrefixSumTestComputeShader;
 
diff --git a/impeller/renderer/prefix_sum_test.comp b/impeller/renderer/prefix_sum_test.comp
index 0f8bff2..7cc940f 100644
--- a/impeller/renderer/prefix_sum_test.comp
+++ b/impeller/renderer/prefix_sum_test.comp
@@ -2,9 +2,7 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-// TODO(dnfield): This should not need to be so small,
-// https://github.com/flutter/flutter/issues/119357
-layout(local_size_x = 256, local_size_y = 1) in;
+layout(local_size_x_id = 0) in;
 layout(std430) buffer;
 
 #include <impeller/prefix_sum.glsl>
diff --git a/impeller/renderer/threadgroup_sizing_test.comp b/impeller/renderer/threadgroup_sizing_test.comp
new file mode 100644
index 0000000..3d2e02a
--- /dev/null
+++ b/impeller/renderer/threadgroup_sizing_test.comp
@@ -0,0 +1,18 @@
+// Copyright 2013 The Flutter Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Size is passed in via specialization constant.
+layout(local_size_x_id = 0) in;
+
+layout(std430) buffer;
+
+layout(binding = 1) writeonly buffer OutputData {
+  uint data[];
+}
+output_data;
+
+void main() {
+  uint ident = gl_GlobalInvocationID.x;
+  output_data.data[ident] = ident;
+}
diff --git a/impeller/tools/malioc.json b/impeller/tools/malioc.json
index 68afc86..e99ada0 100644
--- a/impeller/tools/malioc.json
+++ b/impeller/tools/malioc.json
@@ -13569,9 +13569,9 @@
               "load_store"
             ],
             "longest_path_cycles": [
-              2.65625,
+              2.450000047683716,
               0.0,
-              2.65625,
+              2.450000047683716,
               1.0,
               72.0,
               0.0
@@ -13589,9 +13589,9 @@
               "arith_cvt"
             ],
             "shortest_path_cycles": [
-              0.9375,
+              0.762499988079071,
               0.0,
-              0.9375,
+              0.762499988079071,
               0.0,
               0.0,
               0.0
@@ -13600,9 +13600,9 @@
               "load_store"
             ],
             "total_cycles": [
-              2.65625,
+              2.46875,
               0.0,
-              2.65625,
+              2.46875,
               1.0,
               72.0,
               0.0
@@ -13612,7 +13612,7 @@
           "stack_spill_bytes": 0,
           "thread_occupancy": 100,
           "uniform_registers_used": 8,
-          "work_registers_used": 17
+          "work_registers_used": 18
         }
       }
     }
@@ -13680,6 +13680,68 @@
       }
     }
   },
+  "flutter/impeller/renderer/threadgroup_sizing_test.comp.vkspv": {
+    "Mali-G78": {
+      "core": "Mali-G78",
+      "filename": "flutter/impeller/renderer/threadgroup_sizing_test.comp.vkspv",
+      "has_uniform_computation": true,
+      "type": "Compute",
+      "variants": {
+        "Main": {
+          "fp16_arithmetic": null,
+          "has_stack_spilling": false,
+          "performance": {
+            "longest_path_bound_pipelines": [
+              "load_store"
+            ],
+            "longest_path_cycles": [
+              0.03125,
+              0.0,
+              0.03125,
+              0.0,
+              1.0,
+              0.0
+            ],
+            "pipelines": [
+              "arith_total",
+              "arith_fma",
+              "arith_cvt",
+              "arith_sfu",
+              "load_store",
+              "texture"
+            ],
+            "shortest_path_bound_pipelines": [
+              "load_store"
+            ],
+            "shortest_path_cycles": [
+              0.03125,
+              0.0,
+              0.03125,
+              0.0,
+              1.0,
+              0.0
+            ],
+            "total_bound_pipelines": [
+              "load_store"
+            ],
+            "total_cycles": [
+              0.03125,
+              0.0,
+              0.03125,
+              0.0,
+              1.0,
+              0.0
+            ]
+          },
+          "shared_storage_used": 0,
+          "stack_spill_bytes": 0,
+          "thread_occupancy": 100,
+          "uniform_registers_used": 2,
+          "work_registers_used": 4
+        }
+      }
+    }
+  },
   "flutter/impeller/scene/shaders/gles/skinned.vert.gles": {
     "Mali-G78": {
       "core": "Mali-G78",