
WebGPU prefix-sum: issue with bind-group ping-pong

I'm implementing a simple Hillis and Steele prefix-sum algorithm in WebGPU.

It deliberately doesn't do anything fancy. I'm working with small arrays, so I can simply dispatch as many workgroups as there are entries. For each iteration of the prefix-sum algorithm, I'm ping-ponging between two bind groups, so that the output of one iteration becomes the input to the next.
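When debugging GPU results like these, a CPU reference of the same algorithm is handy for comparison. The sketch below is an illustration, not part of the original code (the function name `hillisSteeleScan` is hypothetical). It applies the same per-element rule as the WGSL `main` function, with offsets p = 1, 2, 4, ... while p is smaller than the array length, and swaps two arrays after every pass to mirror the bind-group ping-pong.

```typescript
// CPU reference for an inclusive Hillis-Steele scan using two ping-pong arrays.
function hillisSteeleScan(input: number[]): number[] {
  let src = Float32Array.from(input); // plays the role of `inputBuffer`
  let dst = new Float32Array(input.length); // plays the role of `outputBuffer`
  // Offsets p = 1, 2, 4, ... while p < n, i.e. ceil(log2(n)) passes.
  for (let p = 1; p < input.length; p *= 2) {
    for (let j = 0; j < input.length; j++) {
      // Same rule as the shader: copy below p, otherwise add the element p back.
      dst[j] = j < p ? src[j] : src[j] + src[j - p];
    }
    [src, dst] = [dst, src]; // swap: this pass's output feeds the next pass
  }
  return Array.from(src); // after the final swap, `src` holds the latest output
}
```

For `[1, 2, 3, 4, 5, 6, 7, 8]` this returns `[1, 3, 6, 10, 15, 21, 28, 36]`, and its first pass produces the same `[1, 3, 5, 7, 9, 11, 13, 15]` as the 0th GPU iteration.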

export async function prefix_sum(input: number[]) {
  const adapter = await navigator.gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) throw new Error(`No support for WebGPU`);

  /**
   * Shader
   */

  const shader = /*wgsl*/ `
        @group(0) @binding(0) var<storage, read> input: array<f32>;
        @group(0) @binding(1) var<storage, read_write> output: array<f32>;
        @group(0) @binding(2) var<uniform> iteration: u32;

        @compute @workgroup_size(1) fn main(
            @builtin(workgroup_id) workgroup_id : vec3<u32>,
            @builtin(global_invocation_id) global_invocation_id: vec3<u32>,
            @builtin(local_invocation_id) local_invocation_id: vec3<u32>
        ) {
            let i = iteration;
            let j = global_invocation_id.x;

            let p = u32(pow(2.0, f32(i)));
            if (j < p) {
                output[j] = input[j];
            } else {
                output[j] = input[j] + input[j - p];
            }
        }
  `;

  /**
   * Data
   */

  const l = input.length;
  const iterations = Math.floor(Math.log2(l));

  const iterationsData = new Uint32Array([iterations]);
  const iterationBuffer = device.createBuffer({
    label: `prefix sum iterations count buffer`,
    size: iterationsData.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(iterationBuffer, 0, iterationsData);

  const inputData = new Float32Array(input);
  const inputBuffer = device.createBuffer({
    label: `prefix sum input buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });
  device.queue.writeBuffer(inputBuffer, 0, inputData);

  const outputBuffer = device.createBuffer({
    label: `prefix sum output buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });

  const readBuffer = device.createBuffer({
    label: 'prefix sum read buffer',
    size: inputData.byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });

  const module = device.createShaderModule({
    code: shader,
  });

  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: {
      module,
      entryPoint: 'main',
    },
  });

  const bindGroup1 = device.createBindGroup({
    label: `prefix sum bind group 1`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: inputBuffer } },
      { binding: 1, resource: { buffer: outputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  const bindGroup2 = device.createBindGroup({
    label: `prefix sum bind group 2`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: outputBuffer } },
      { binding: 1, resource: { buffer: inputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  /**
   * Execute
   */

  const encoder = device.createCommandEncoder({ label: `prefix sum encoder` });

  for (let iteration = 0; iteration <= iterations; iteration++) {
    device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
    const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
    if (iteration % 2 === 0) {
      pass.setBindGroup(0, bindGroup1);
    } else {
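      pass.setBindGroup(0, bindGroup2);
    }
    pass.setPipeline(pipeline);
    pass.dispatchWorkgroups(l); // one workgroup (of size 1) per entry
    pass.end();
  }

  /**
   * Output
   */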

  const lastOutputBuffer = iterations % 2 === 0 ? inputBuffer : outputBuffer;
  encoder.copyBufferToBuffer(lastOutputBuffer, 0, readBuffer, 0, readBuffer.size);
  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  await readBuffer.mapAsync(GPUMapMode.READ);
  const result = new Float32Array(readBuffer.getMappedRange());
  //   readBuffer.unmap();

  return structuredClone(result);
}

I'm observing some weird behaviour, though.

For the input [1, 2, 3, 4, 5, 6, 7, 8], I'm getting:

  • 0th iteration:
    • input: [1, 2, 3, 4, 5, 6, 7, 8] as expected,
    • output: [1, 3, 5, 7, 9, 11, 13, 15] as expected
  • 1st iteration:
    • input: [1, 2, 4, 6, 8, 10, 12, 14]
      • should be identical to the output from the 0th iteration
      • instead is as if the pass (with the uniform iteration = 1) had been applied with the input-buffer from the 0th iteration as input and the input-buffer to the 1st iteration as output
    • output: [1, 2, 5, 8, 12, 16, 20, 24] ... as expected given the false input
  • 2nd iteration:
    • input: [1, 2, 3, 4, 7, 10, 13, 16]
      • should be identical to the output from the 1st iteration
      • instead is as if the pass (with the uniform iteration = 2) had been applied with the input-buffer from the 0th iteration as input and the input-buffer to the 2nd iteration as output
    • output: [1, 2, 3, 4, 8, 12, 16, 20] ... again as expected given the false input

I tried a few things that I thought might have caused my issue.

  • Hypothesis 1: you can't have multiple passes in an encoder
    • so I moved to one encoder per iteration, but that didn't work either and yielded the same results.
    • also, I know from here that multiple passes per encoder are allowed, and so is bind-group ping-pong.
  • Hypothesis 2: must not share in/out buffers between passes
    • so I've created separate buffers for each pass and copied the data from the last iteration's output to current iteration's input. Still doesn't work, though, and yields the same results.


  • I don't know if this is the only issue, but one issue I see is that you need an iterationBuffer per iteration.

    As it is, this code

      for (let iteration = 0; iteration < iterations; iteration++) {
        device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
        const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
        if (iteration % 2 === 0) {
          pass.setBindGroup(0, bindGroup1);
        } else {
          pass.setBindGroup(0, bindGroup2);
        }
        // ... set the pipeline, dispatch, end the pass
      }

    Is effectively doing this

      const iterationBuffer = [];
      for (let iteration = 0; iteration < iterations; iteration++) {
        iterationBuffer[0] = iteration;
        // encode commands
      }
      // execute (submit) commands

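    At the "execute commands" point above, iterationBuffer only contains the last value you wrote into it. This is because device.queue.writeBuffer() is a queue operation that is scheduled as soon as you call it, while the passes you encode only run once the command buffer is submitted, so every write has already happened before the first pass executes. If you want a different value for iteration each time you call dispatchWorkgroups in the same submit, you need a different buffer to hold each value.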
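A variation on this advice, sketched here under assumptions rather than taken from the answer, is to pack every iteration value into a single uniform buffer and select one per pass with a dynamic offset. WebGPU requires dynamic uniform offsets to be multiples of the device's `minUniformBufferOffsetAlignment` limit (256 bytes by default), and `layout: 'auto'` does not produce dynamic-offset bindings. So this assumes an explicit bind group layout whose binding 2 is declared as `buffer: { type: 'uniform', hasDynamicOffset: true }`, a pipeline created from it via `device.createPipelineLayout`, and bind-group entries for binding 2 that specify `size: 4`. The helper name `packIterationUniforms` is hypothetical.

```typescript
// Hypothetical helper (not a WebGPU API): packs the `iteration` uniform for
// every pass into one buffer, one value per `alignment`-byte slot, so pass k
// can select its value with the dynamic offset k * alignment.
function packIterationUniforms(
  passCount: number,
  // Assumed default: WebGPU's minUniformBufferOffsetAlignment limit is 256
  // unless a different limit was requested for the device.
  alignment = 256,
): { data: Uint32Array; offsets: number[] } {
  if (alignment <= 0 || alignment % 4 !== 0) {
    throw new Error("alignment must be a positive multiple of 4 bytes");
  }
  const wordsPerSlot = alignment / 4; // u32 words per slot
  const data = new Uint32Array(passCount * wordsPerSlot); // zero-filled padding
  const offsets: number[] = [];
  for (let k = 0; k < passCount; k++) {
    data[k * wordsPerSlot] = k; // first u32 of slot k is `iteration` for pass k
    offsets.push(k * alignment); // byte offset for setBindGroup's dynamicOffsets
  }
  return { data, offsets };
}
```

The packed data would be written once with `device.queue.writeBuffer(iterationBuffer, 0, data)` (the buffer needs at least `data.byteLength` bytes), and pass k would call `pass.setBindGroup(0, bindGroup, [offsets[k]])`. Creating one small uniform buffer and a matching pair of bind groups per iteration achieves the same result with less layout setup.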