aco: Implement tessellation control shader input/output.
authorTimur Kristóf <timur.kristof@gmail.com>
Fri, 21 Feb 2020 16:46:15 +0000 (17:46 +0100)
committerMarge Bot <eric+marge@anholt.net>
Wed, 11 Mar 2020 08:34:10 +0000 (08:34 +0000)
Tessellation control shaders can have per-vertex inputs,
and both per-vertex and per-patch outputs. TCS can not only store,
but also load their outputs.

The TCS outputs are stored in RING_HS_TESS_OFFCHIP in VMEM, which
is where the TES reads them from. Additionally, the are also stored
in LDS to make sure they can be loaded fast when read by the TCS.

Tessellation factors are always just stored in LDS.
At the end of the shader, the first shader invocation reads these
from LDS and writes them to RING_HS_TESS_FACTOR in VMEM, and
additionally to RING_HS_TESS_OFFCHIP when they are read by
the Tessellation Evaluation Shader.

This implementation matches the memory layouts used by radv_nir_to_llvm.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3964>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp

index ba798eb9eacc0388051f2354412c099a5c003955..ed95655993fbae335019ac01e75163a91c9e0859 100644 (file)
@@ -3156,6 +3156,122 @@ std::pair<Temp, unsigned> get_intrinsic_io_basic_offset(isel_context *ctx, nir_i
    return get_intrinsic_io_basic_offset(ctx, instr, stride, stride);
 }
 
+Temp get_tess_rel_patch_id(isel_context *ctx)
+{
+   Builder bld(ctx->program, ctx->block);
+
+   switch (ctx->shader->info.stage) {
+   case MESA_SHADER_TESS_CTRL:
+      return bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xffu),
+                      get_arg(ctx, ctx->args->ac.tcs_rel_ids));
+   case MESA_SHADER_TESS_EVAL:
+      return get_arg(ctx, ctx->args->tes_rel_patch_id);
+   default:
+      unreachable("Unsupported stage in get_tess_rel_patch_id");
+   }
+}
+
+std::pair<Temp, unsigned> get_tcs_per_vertex_input_lds_offset(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+   Builder bld(ctx->program, ctx->block);
+
+   uint32_t tcs_in_patch_stride = ctx->args->options->key.tcs.input_vertices * ctx->tcs_num_inputs * 4;
+   uint32_t tcs_in_vertex_stride = ctx->tcs_num_inputs * 4;
+
+   std::pair<Temp, unsigned> offs = get_intrinsic_io_basic_offset(ctx, instr);
+
+   nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+   offs = offset_add_from_nir(ctx, offs, vertex_index_src, tcs_in_vertex_stride);
+
+   Temp rel_patch_id = get_tess_rel_patch_id(ctx);
+   Temp tcs_in_current_patch_offset = bld.v_mul24_imm(bld.def(v1), rel_patch_id, tcs_in_patch_stride);
+   offs = offset_add(ctx, offs, std::make_pair(tcs_in_current_patch_offset, 0));
+
+   return offset_mul(ctx, offs, 4u);
+}
+
+std::pair<Temp, unsigned> get_tcs_output_lds_offset(isel_context *ctx, nir_intrinsic_instr *instr = nullptr, bool per_vertex = false)
+{
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+   Builder bld(ctx->program, ctx->block);
+
+   uint32_t input_patch_size = ctx->args->options->key.tcs.input_vertices * ctx->tcs_num_inputs * 16;
+   uint32_t num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
+   uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->args->shader_info->tcs.patch_outputs_written);
+   uint32_t output_vertex_size = num_tcs_outputs * 16;
+   uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
+   uint32_t output_patch_stride = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
+
+   std::pair<Temp, unsigned> offs = instr
+                                    ? get_intrinsic_io_basic_offset(ctx, instr, 4u)
+                                    : std::make_pair(Temp(), 0u);
+
+   Temp rel_patch_id = get_tess_rel_patch_id(ctx);
+   Temp patch_off = bld.v_mul24_imm(bld.def(v1), rel_patch_id, output_patch_stride);
+
+   if (per_vertex) {
+      assert(instr);
+
+      nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+      offs = offset_add_from_nir(ctx, offs, vertex_index_src, output_vertex_size);
+
+      uint32_t output_patch0_offset = (input_patch_size * ctx->tcs_num_patches);
+      offs = offset_add(ctx, offs, std::make_pair(patch_off, output_patch0_offset));
+   } else {
+      uint32_t output_patch0_patch_data_offset = (input_patch_size * ctx->tcs_num_patches + pervertex_output_patch_size);
+      offs = offset_add(ctx, offs, std::make_pair(patch_off, output_patch0_patch_data_offset));
+   }
+
+   return offs;
+}
+
+std::pair<Temp, unsigned> get_tcs_per_vertex_output_vmem_offset(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   Builder bld(ctx->program, ctx->block);
+
+   unsigned vertices_per_patch = ctx->shader->info.tess.tcs_vertices_out;
+   unsigned attr_stride = vertices_per_patch * ctx->tcs_num_patches;
+
+   std::pair<Temp, unsigned> offs = get_intrinsic_io_basic_offset(ctx, instr, attr_stride * 4u, 4u);
+
+   Temp rel_patch_id = get_tess_rel_patch_id(ctx);
+   Temp patch_off = bld.v_mul24_imm(bld.def(v1), rel_patch_id, vertices_per_patch * 16u);
+   offs = offset_add(ctx, offs, std::make_pair(patch_off, 0u));
+
+   nir_src *vertex_index_src = nir_get_io_vertex_index_src(instr);
+   offs = offset_add_from_nir(ctx, offs, vertex_index_src, 16u);
+
+   return offs;
+}
+
+std::pair<Temp, unsigned> get_tcs_per_patch_output_vmem_offset(isel_context *ctx, nir_intrinsic_instr *instr = nullptr, unsigned const_base_offset = 0u)
+{
+   Builder bld(ctx->program, ctx->block);
+
+   unsigned num_tcs_outputs = ctx->shader->info.stage == MESA_SHADER_TESS_CTRL
+                              ? util_last_bit64(ctx->args->shader_info->tcs.outputs_written)
+                              : ctx->args->options->key.tes.tcs_num_outputs;
+
+   unsigned output_vertex_size = num_tcs_outputs * 16;
+   unsigned per_vertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
+   unsigned per_patch_data_offset = per_vertex_output_patch_size * ctx->tcs_num_patches;
+   unsigned attr_stride = ctx->tcs_num_patches;
+
+   std::pair<Temp, unsigned> offs = instr
+                                    ? get_intrinsic_io_basic_offset(ctx, instr, attr_stride * 4u, 4u)
+                                    : std::make_pair(Temp(), 0u);
+
+   if (const_base_offset)
+      offs.second += const_base_offset * attr_stride;
+
+   Temp rel_patch_id = get_tess_rel_patch_id(ctx);
+   Temp patch_off = bld.v_mul_imm(bld.def(v1), rel_patch_id, 16u);
+   offs = offset_add(ctx, offs, std::make_pair(patch_off, per_patch_data_offset));
+
+   return offs;
+}
+
 void visit_store_ls_or_es_output(isel_context *ctx, nir_intrinsic_instr *instr)
 {
    Builder bld(ctx->program, ctx->block);
@@ -3192,6 +3308,54 @@ void visit_store_ls_or_es_output(isel_context *ctx, nir_intrinsic_instr *instr)
    }
 }
 
+void visit_store_tcs_output(isel_context *ctx, nir_intrinsic_instr *instr, bool per_vertex)
+{
+   assert(ctx->stage == tess_control_hs || ctx->stage == vertex_tess_control_hs);
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+
+   Builder bld(ctx->program, ctx->block);
+
+   Temp store_val = get_ssa_temp(ctx, instr->src[0].ssa);
+   unsigned elem_size_bytes = instr->src[0].ssa->bit_size / 8;
+   unsigned write_mask = nir_intrinsic_write_mask(instr);
+
+   /* TODO: Only write to VMEM if the output is per-vertex or it's per-patch non tess factor */
+   bool write_to_vmem = true;
+   /* TODO: Only write to LDS if the output is read by the shader, or it's per-patch tess factor */
+   bool write_to_lds = true;
+
+   if (write_to_vmem) {
+      std::pair<Temp, unsigned> vmem_offs = per_vertex
+                                            ? get_tcs_per_vertex_output_vmem_offset(ctx, instr)
+                                            : get_tcs_per_patch_output_vmem_offset(ctx, instr);
+
+      Temp hs_ring_tess_offchip = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), ctx->program->private_segment_buffer, Operand(RING_HS_TESS_OFFCHIP * 16u));
+      Temp oc_lds = get_arg(ctx, ctx->args->oc_lds);
+      store_vmem_mubuf(ctx, store_val, hs_ring_tess_offchip, vmem_offs.first, oc_lds, vmem_offs.second, elem_size_bytes, write_mask, false, false);
+   }
+
+   if (write_to_lds) {
+      std::pair<Temp, unsigned> lds_offs = get_tcs_output_lds_offset(ctx, instr, per_vertex);
+      unsigned lds_align = calculate_lds_alignment(ctx, lds_offs.second);
+      store_lds(ctx, elem_size_bytes, store_val, write_mask, lds_offs.first, lds_offs.second, lds_align);
+   }
+}
+
+void visit_load_tcs_output(isel_context *ctx, nir_intrinsic_instr *instr, bool per_vertex)
+{
+   assert(ctx->stage == tess_control_hs || ctx->stage == vertex_tess_control_hs);
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+
+   Builder bld(ctx->program, ctx->block);
+
+   Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
+   std::pair<Temp, unsigned> lds_offs = get_tcs_output_lds_offset(ctx, instr, per_vertex);
+   unsigned lds_align = calculate_lds_alignment(ctx, lds_offs.second);
+   unsigned elem_size_bytes = instr->src[0].ssa->bit_size / 8;
+
+   load_lds(ctx, elem_size_bytes, dst, lds_offs.first, lds_offs.second, lds_align);
+}
+
 void visit_store_output(isel_context *ctx, nir_intrinsic_instr *instr)
 {
    if (ctx->stage == vertex_vs ||
@@ -3223,11 +3387,18 @@ void visit_store_output(isel_context *ctx, nir_intrinsic_instr *instr)
    } else if (ctx->stage == vertex_es ||
               (ctx->stage == vertex_geometry_gs && ctx->shader->info.stage == MESA_SHADER_VERTEX)) {
       visit_store_ls_or_es_output(ctx, instr);
+   } else if (ctx->shader->info.stage == MESA_SHADER_TESS_CTRL) {
+      visit_store_tcs_output(ctx, instr, false);
    } else {
       unreachable("Shader stage not implemented");
    }
 }
 
+void visit_load_output(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   visit_load_tcs_output(ctx, instr, false);
+}
+
 void emit_interp_instr(isel_context *ctx, unsigned idx, unsigned component, Temp src, Temp dst, Temp prim_mask)
 {
    Temp coord1 = emit_extract_vector(ctx, src, 0, v1);
@@ -3713,17 +3884,46 @@ void visit_load_gs_per_vertex_input(isel_context *ctx, nir_intrinsic_instr *inst
    }
 }
 
+void visit_load_tcs_per_vertex_input(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+
+   Builder bld(ctx->program, ctx->block);
+   Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
+   std::pair<Temp, unsigned> offs = get_tcs_per_vertex_input_lds_offset(ctx, instr);
+   unsigned elem_size_bytes = instr->dest.ssa.bit_size / 8;
+   unsigned lds_align = calculate_lds_alignment(ctx, offs.second);
+
+   load_lds(ctx, elem_size_bytes, dst, offs.first, offs.second, lds_align);
+}
+
 void visit_load_per_vertex_input(isel_context *ctx, nir_intrinsic_instr *instr)
 {
    switch (ctx->shader->info.stage) {
    case MESA_SHADER_GEOMETRY:
       visit_load_gs_per_vertex_input(ctx, instr);
       break;
+   case MESA_SHADER_TESS_CTRL:
+      visit_load_tcs_per_vertex_input(ctx, instr);
+      break;
    default:
       unreachable("Unimplemented shader stage");
    }
 }
 
+void visit_load_per_vertex_output(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   visit_load_tcs_output(ctx, instr, true);
+}
+
+void visit_store_per_vertex_output(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   assert(ctx->stage == tess_control_hs || ctx->stage == vertex_tess_control_hs);
+   assert(ctx->shader->info.stage == MESA_SHADER_TESS_CTRL);
+
+   visit_store_tcs_output(ctx, instr, true);
+}
+
 void visit_load_tess_coord(isel_context *ctx, nir_intrinsic_instr *instr)
 {
    assert(ctx->shader->info.stage == MESA_SHADER_TESS_EVAL);
@@ -6409,9 +6609,18 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    case nir_intrinsic_load_input_vertex:
       visit_load_input(ctx, instr);
       break;
+   case nir_intrinsic_load_output:
+      visit_load_output(ctx, instr);
+      break;
    case nir_intrinsic_load_per_vertex_input:
       visit_load_per_vertex_input(ctx, instr);
       break;
+   case nir_intrinsic_load_per_vertex_output:
+      visit_load_per_vertex_output(ctx, instr);
+      break;
+   case nir_intrinsic_store_per_vertex_output:
+      visit_store_per_vertex_output(ctx, instr);
+      break;
    case nir_intrinsic_load_ubo:
       visit_load_ubo(ctx, instr);
       break;
@@ -8922,6 +9131,101 @@ static void create_fs_exports(isel_context *ctx)
    }
 }
 
+static void write_tcs_tess_factors(isel_context *ctx)
+{
+   unsigned outer_comps;
+   unsigned inner_comps;
+
+   switch (ctx->args->options->key.tcs.primitive_mode) {
+   case GL_ISOLINES:
+      outer_comps = 2;
+      inner_comps = 0;
+      break;
+   case GL_TRIANGLES:
+      outer_comps = 3;
+      inner_comps = 1;
+      break;
+   case GL_QUADS:
+      outer_comps = 4;
+      inner_comps = 2;
+      break;
+   default:
+      return;
+   }
+
+   const unsigned tess_index_inner = shader_io_get_unique_index(VARYING_SLOT_TESS_LEVEL_INNER);
+   const unsigned tess_index_outer = shader_io_get_unique_index(VARYING_SLOT_TESS_LEVEL_OUTER);
+
+   Builder bld(ctx->program, ctx->block);
+
+   bld.barrier(aco_opcode::p_memory_barrier_shared);
+   unsigned workgroup_size = ctx->tcs_num_patches * ctx->shader->info.tess.tcs_vertices_out;
+   if (unlikely(ctx->program->chip_class != GFX6 && workgroup_size > ctx->program->wave_size))
+      bld.sopp(aco_opcode::s_barrier);
+
+   Temp tcs_rel_ids = get_arg(ctx, ctx->args->ac.tcs_rel_ids);
+   Temp invocation_id = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), tcs_rel_ids, Operand(8u), Operand(5u));
+
+   Temp invocation_id_is_zero = bld.vopc(aco_opcode::v_cmp_eq_u32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), invocation_id);
+   if_context ic_invocation_id_is_zero;
+   begin_divergent_if_then(ctx, &ic_invocation_id_is_zero, invocation_id_is_zero);
+   bld.reset(ctx->block);
+
+   Temp hs_ring_tess_factor = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), ctx->program->private_segment_buffer, Operand(RING_HS_TESS_FACTOR * 16u));
+
+   std::pair<Temp, unsigned> lds_base = get_tcs_output_lds_offset(ctx);
+   unsigned stride = inner_comps + outer_comps;
+   Temp inner[4];
+   Temp outer[4];
+   Temp out[6];
+   assert(inner_comps <= (sizeof(inner) / sizeof(Temp)));
+   assert(outer_comps <= (sizeof(outer) / sizeof(Temp)));
+   assert(stride <= (sizeof(out) / sizeof(Temp)));
+
+   if (ctx->args->options->key.tcs.primitive_mode == GL_ISOLINES) {
+      // LINES reversal
+      outer[0] = out[1] = load_lds(ctx, 4, bld.tmp(v1), lds_base.first, lds_base.second + tess_index_outer * 16 + 0 * 4, 4);
+      outer[1] = out[0] = load_lds(ctx, 4, bld.tmp(v1), lds_base.first, lds_base.second + tess_index_outer * 16 + 1 * 4, 4);
+   } else {
+      for (unsigned i = 0; i < outer_comps; ++i)
+         outer[i] = out[i] = load_lds(ctx, 4, bld.tmp(v1), lds_base.first, lds_base.second + tess_index_outer * 16 + i * 4, 4);
+
+      for (unsigned i = 0; i < inner_comps; ++i)
+         inner[i] = out[outer_comps + i] = load_lds(ctx, 4, bld.tmp(v1), lds_base.first, lds_base.second + tess_index_inner * 16 + i * 4, 4);
+   }
+
+   Temp rel_patch_id = get_tess_rel_patch_id(ctx);
+   Temp tf_base = get_arg(ctx, ctx->args->tess_factor_offset);
+   Temp byte_offset = bld.v_mul_imm(bld.def(v1), rel_patch_id, stride * 4u);
+   unsigned tf_const_offset = 0;
+
+   if (ctx->program->chip_class <= GFX8) {
+      Temp rel_patch_id_is_zero = bld.vopc(aco_opcode::v_cmp_eq_u32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), rel_patch_id);
+      if_context ic_rel_patch_id_is_zero;
+      begin_divergent_if_then(ctx, &ic_rel_patch_id_is_zero, rel_patch_id_is_zero);
+      bld.reset(ctx->block);
+
+      /* Store the dynamic HS control word. */
+      Temp control_word = bld.copy(bld.def(v1), Operand(0x80000000u));
+      bld.mubuf(aco_opcode::buffer_store_dword,
+                /* SRSRC */ hs_ring_tess_factor, /* VADDR */ Operand(v1), /* SOFFSET */ tf_base, /* VDATA */ control_word,
+                /* immediate OFFSET */ 0, /* OFFEN */ false, /* idxen*/ false, /* addr64 */ false,
+                /* disable_wqm */ false, /* glc */ true);
+      tf_const_offset += 4;
+
+      begin_divergent_if_else(ctx, &ic_rel_patch_id_is_zero);
+      end_divergent_if(ctx, &ic_rel_patch_id_is_zero);
+      bld.reset(ctx->block);
+   }
+
+   assert(stride == 2 || stride == 4 || stride == 6);
+   Temp tf_vec = create_vec_from_array(ctx, out, stride, RegType::vgpr);
+   store_vmem_mubuf(ctx, tf_vec, hs_ring_tess_factor, byte_offset, tf_base, tf_const_offset, 4, (1 << stride) - 1, true, false);
+
+   begin_divergent_if_else(ctx, &ic_invocation_id_is_zero);
+   end_divergent_if(ctx, &ic_invocation_id_is_zero);
+}
+
 static void emit_stream_output(isel_context *ctx,
                                Temp const *so_buffers,
                                Temp const *so_write_offset,
@@ -9246,6 +9550,8 @@ void select_program(Program *program,
          Builder bld(ctx.program, ctx.block);
          bld.barrier(aco_opcode::p_memory_barrier_gs_data);
          bld.sopp(aco_opcode::s_sendmsg, bld.m0(ctx.gs_wave_id), -1, sendmsg_gs_done(false, false, 0));
+      } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
+         write_tcs_tess_factors(&ctx);
       }
 
       if (ctx.stage == fragment_fs)
index 5d6fd1b9f531c9919e3beea62e06a34fd8a4a864..46e61e4a77e5aa9aec236831ea5e889ac028f863 100644 (file)
@@ -311,8 +311,10 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_intrinsic_load_sample_id:
                   case nir_intrinsic_load_sample_mask_in:
                   case nir_intrinsic_load_input:
+                  case nir_intrinsic_load_output:
                   case nir_intrinsic_load_input_vertex:
                   case nir_intrinsic_load_per_vertex_input:
+                  case nir_intrinsic_load_per_vertex_output:
                   case nir_intrinsic_load_vertex_id:
                   case nir_intrinsic_load_vertex_id_zero_base:
                   case nir_intrinsic_load_barycentric_sample: