aco: Use common argument handling
[mesa.git] / src / amd / compiler / aco_instruction_selection_setup.cpp
index fbab89417cd3e7b8833ecffda48c0eeeb167159e..16b53725408bea87f7405302bde0e2d4f18d63c2 100644 (file)
@@ -28,6 +28,7 @@
 #include "nir.h"
 #include "vulkan/radv_shader.h"
 #include "vulkan/radv_descriptor_set.h"
+#include "vulkan/radv_shader_args.h"
 #include "sid.h"
 #include "ac_exp_param.h"
 #include "ac_shader_util.h"
 
 namespace aco {
 
-enum fs_input {
-   persp_sample_p1,
-   persp_sample_p2,
-   persp_center_p1,
-   persp_center_p2,
-   persp_centroid_p1,
-   persp_centroid_p2,
-   persp_pull_model,
-   linear_sample_p1,
-   linear_sample_p2,
-   linear_center_p1,
-   linear_center_p2,
-   linear_centroid_p1,
-   linear_centroid_p2,
-   line_stipple,
-   frag_pos_0,
-   frag_pos_1,
-   frag_pos_2,
-   frag_pos_3,
-   front_face,
-   ancillary,
-   sample_coverage,
-   fixed_pt,
-   max_inputs,
-};
-
 struct vs_output_state {
    uint8_t mask[VARYING_SLOT_VAR31 + 1];
    Temp outputs[VARYING_SLOT_VAR31 + 1][4];
@@ -71,6 +46,7 @@ struct vs_output_state {
 
 struct isel_context {
    const struct radv_nir_compiler_options *options;
+   struct radv_shader_args *args;
    Program *program;
    nir_shader *shader;
    uint32_t constant_data_offset;
@@ -95,51 +71,30 @@ struct isel_context {
       bool exec_potentially_empty = false;
    } cf_info;
 
+   Temp arg_temps[AC_MAX_ARGS];
+
    /* inputs common for merged stages */
    Temp merged_wave_info = Temp(0, s1);
 
    /* FS inputs */
-   bool fs_vgpr_args[fs_input::max_inputs];
-   Temp fs_inputs[fs_input::max_inputs];
-   Temp prim_mask = Temp(0, s1);
-   Temp descriptor_sets[MAX_SETS];
-   Temp push_constants = Temp(0, s1);
-   Temp inline_push_consts[MAX_INLINE_PUSH_CONSTS];
-   unsigned num_inline_push_consts = 0;
-   unsigned base_inline_push_consts = 0;
+   Temp persp_centroid, linear_centroid;
 
    /* VS inputs */
-   Temp vertex_buffers = Temp(0, s1);
-   Temp base_vertex = Temp(0, s1);
-   Temp start_instance = Temp(0, s1);
-   Temp draw_id = Temp(0, s1);
-   Temp view_index = Temp(0, s1);
-   Temp es2gs_offset = Temp(0, s1);
-   Temp vertex_id = Temp(0, v1);
-   Temp rel_auto_id = Temp(0, v1);
-   Temp instance_id = Temp(0, v1);
-   Temp vs_prim_id = Temp(0, v1);
    bool needs_instance_id;
 
-   /* CS inputs */
-   Temp num_workgroups = Temp(0, s3);
-   Temp workgroup_ids[3] = {Temp(0, s1), Temp(0, s1), Temp(0, s1)};
-   Temp tg_size = Temp(0, s1);
-   Temp local_invocation_ids = Temp(0, v3);
-
    /* VS output information */
    unsigned num_clip_distances;
    unsigned num_cull_distances;
    vs_output_state vs_output;
-
-   /* Streamout */
-   Temp streamout_buffers = Temp(0, s1);
-   Temp streamout_write_idx = Temp(0, s1);
-   Temp streamout_config = Temp(0, s1);
-   Temp streamout_offset[4] = {Temp(0, s1), Temp(0, s1), Temp(0, s1), Temp(0, s1)};
 };
 
-fs_input get_interp_input(nir_intrinsic_op intrin, enum glsl_interp_mode interp)
+Temp get_arg(isel_context *ctx, struct ac_arg arg)
+{
+   assert(arg.used);
+   return ctx->arg_temps[arg.arg_index];
+}
+
+unsigned get_interp_input(nir_intrinsic_op intrin, enum glsl_interp_mode interp)
 {
    switch (interp) {
    case INTERP_MODE_SMOOTH:
@@ -147,24 +102,24 @@ fs_input get_interp_input(nir_intrinsic_op intrin, enum glsl_interp_mode interp)
       if (intrin == nir_intrinsic_load_barycentric_pixel ||
           intrin == nir_intrinsic_load_barycentric_at_sample ||
           intrin == nir_intrinsic_load_barycentric_at_offset)
-         return fs_input::persp_center_p1;
+         return S_0286CC_PERSP_CENTER_ENA(1);
       else if (intrin == nir_intrinsic_load_barycentric_centroid)
-         return fs_input::persp_centroid_p1;
+         return S_0286CC_PERSP_CENTROID_ENA(1);
       else if (intrin == nir_intrinsic_load_barycentric_sample)
-         return fs_input::persp_sample_p1;
+         return S_0286CC_PERSP_SAMPLE_ENA(1);
       break;
    case INTERP_MODE_NOPERSPECTIVE:
       if (intrin == nir_intrinsic_load_barycentric_pixel)
-         return fs_input::linear_center_p1;
+         return S_0286CC_LINEAR_CENTER_ENA(1);
       else if (intrin == nir_intrinsic_load_barycentric_centroid)
-         return fs_input::linear_centroid_p1;
+         return S_0286CC_LINEAR_CENTROID_ENA(1);
       else if (intrin == nir_intrinsic_load_barycentric_sample)
-         return fs_input::linear_sample_p1;
+         return S_0286CC_LINEAR_SAMPLE_ENA(1);
       break;
    default:
       break;
    }
-   return fs_input::max_inputs;
+   return 0;
 }
 
 void init_context(isel_context *ctx, nir_shader *shader)
@@ -175,7 +130,8 @@ void init_context(isel_context *ctx, nir_shader *shader)
    ctx->divergent_vals = nir_divergence_analysis(shader, nir_divergence_view_index_uniform);
 
    std::unique_ptr<Temp[]> allocated{new Temp[impl->ssa_alloc]()};
-   memset(&ctx->fs_vgpr_args, false, sizeof(ctx->fs_vgpr_args));
+
+   unsigned spi_ps_inputs = 0;
 
    bool done = false;
    while (!done) {
@@ -457,28 +413,28 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_intrinsic_load_barycentric_at_sample:
                   case nir_intrinsic_load_barycentric_at_offset: {
                      glsl_interp_mode mode = (glsl_interp_mode)nir_intrinsic_interp_mode(intrinsic);
-                     ctx->fs_vgpr_args[get_interp_input(intrinsic->intrinsic, mode)] = true;
+                     spi_ps_inputs |= get_interp_input(intrinsic->intrinsic, mode);
                      break;
                   }
                   case nir_intrinsic_load_front_face:
-                     ctx->fs_vgpr_args[fs_input::front_face] = true;
+                     spi_ps_inputs |= S_0286CC_FRONT_FACE_ENA(1);
                      break;
                   case nir_intrinsic_load_frag_coord:
                   case nir_intrinsic_load_sample_pos: {
                      uint8_t mask = nir_ssa_def_components_read(&intrinsic->dest.ssa);
                      for (unsigned i = 0; i < 4; i++) {
                         if (mask & (1 << i))
-                           ctx->fs_vgpr_args[fs_input::frag_pos_0 + i] = true;
+                           spi_ps_inputs |= S_0286CC_POS_X_FLOAT_ENA(1) << i;
 
                      }
                      break;
                   }
                   case nir_intrinsic_load_sample_id:
-                     ctx->fs_vgpr_args[fs_input::ancillary] = true;
+                     spi_ps_inputs |= S_0286CC_ANCILLARY_ENA(1);
                      break;
                   case nir_intrinsic_load_sample_mask_in:
-                     ctx->fs_vgpr_args[fs_input::ancillary] = true;
-                     ctx->fs_vgpr_args[fs_input::sample_coverage] = true;
+                     spi_ps_inputs |= S_0286CC_ANCILLARY_ENA(1);
+                     spi_ps_inputs |= S_0286CC_SAMPLE_COVERAGE_ENA(1);
                      break;
                   default:
                      break;
@@ -555,479 +511,81 @@ void init_context(isel_context *ctx, nir_shader *shader)
       }
    }
 
-   for (unsigned i = 0; i < impl->ssa_alloc; i++)
-      allocated[i] = Temp(ctx->program->allocateId(), allocated[i].regClass());
-
-   ctx->allocated.reset(allocated.release());
-}
-
-struct user_sgpr_info {
-   uint8_t num_sgpr;
-   uint8_t remaining_sgprs;
-   uint8_t user_sgpr_idx;
-   bool need_ring_offsets;
-   bool indirect_all_descriptor_sets;
-};
-
-static void allocate_inline_push_consts(isel_context *ctx,
-                                        user_sgpr_info& user_sgpr_info)
-{
-   uint8_t remaining_sgprs = user_sgpr_info.remaining_sgprs;
-
-   /* Only supported if shaders use push constants. */
-   if (ctx->program->info->min_push_constant_used == UINT8_MAX)
-      return;
-
-   /* Only supported if shaders don't have indirect push constants. */
-   if (ctx->program->info->has_indirect_push_constants)
-      return;
-
-   /* Only supported for 32-bit push constants. */
-   //TODO: it's possible that some day, the load/store vectorization could make this inaccurate
-   if (!ctx->program->info->has_only_32bit_push_constants)
-      return;
-
-   uint8_t num_push_consts =
-      (ctx->program->info->max_push_constant_used -
-       ctx->program->info->min_push_constant_used) / 4;
-
-   /* Check if the number of user SGPRs is large enough. */
-   if (num_push_consts < remaining_sgprs) {
-      ctx->program->info->num_inline_push_consts = num_push_consts;
-   } else {
-      ctx->program->info->num_inline_push_consts = remaining_sgprs;
-   }
-
-   /* Clamp to the maximum number of allowed inlined push constants. */
-   if (ctx->program->info->num_inline_push_consts > MAX_INLINE_PUSH_CONSTS)
-      ctx->program->info->num_inline_push_consts = MAX_INLINE_PUSH_CONSTS;
-
-   if (ctx->program->info->num_inline_push_consts == num_push_consts &&
-       !ctx->program->info->loads_dynamic_offsets) {
-      /* Disable the default push constants path if all constants are
-       * inlined and if shaders don't use dynamic descriptors.
-       */
-      ctx->program->info->loads_push_constants = false;
-      user_sgpr_info.num_sgpr--;
-      user_sgpr_info.remaining_sgprs++;
-   }
-
-   ctx->program->info->base_inline_push_consts =
-      ctx->program->info->min_push_constant_used / 4;
-
-   user_sgpr_info.num_sgpr += ctx->program->info->num_inline_push_consts;
-   user_sgpr_info.remaining_sgprs -= ctx->program->info->num_inline_push_consts;
-}
-
-static void allocate_user_sgprs(isel_context *ctx,
-                                bool needs_view_index, user_sgpr_info& user_sgpr_info)
-{
-   memset(&user_sgpr_info, 0, sizeof(struct user_sgpr_info));
-   uint32_t user_sgpr_count = 0;
-
-   /* until we sort out scratch/global buffers always assign ring offsets for gs/vs/es */
-   if (ctx->stage != fragment_fs &&
-       ctx->stage != compute_cs
-       /*|| ctx->is_gs_copy_shader */)
-      user_sgpr_info.need_ring_offsets = true;
-
-   if (ctx->stage == fragment_fs &&
-       ctx->program->info->ps.needs_sample_positions)
-      user_sgpr_info.need_ring_offsets = true;
-
-   /* 2 user sgprs will nearly always be allocated for scratch/rings */
-   user_sgpr_count += 2;
-
-   switch (ctx->stage) {
-   case vertex_vs:
-   /* if (!ctx->is_gs_copy_shader) */ {
-         if (ctx->program->info->vs.has_vertex_buffers)
-            user_sgpr_count++;
-         user_sgpr_count += ctx->program->info->vs.needs_draw_id ? 3 : 2;
-      }
-      break;
-   case fragment_fs:
-      //user_sgpr_count += ctx->program->info->ps.needs_sample_positions;
-      break;
-   case compute_cs:
-      if (ctx->program->info->cs.uses_grid_size)
-         user_sgpr_count += 3;
-      break;
-   default:
-      unreachable("Shader stage not implemented");
-   }
-
-   if (needs_view_index)
-      user_sgpr_count++;
-
-   if (ctx->program->info->loads_push_constants)
-      user_sgpr_count += 1; /* we use 32bit pointers */
-
-   if (ctx->program->info->so.num_outputs)
-      user_sgpr_count += 1; /* we use 32bit pointers */
-
-   uint32_t available_sgprs = ctx->options->chip_class >= GFX9 && !(ctx->stage & hw_cs) ? 32 : 16;
-   uint32_t remaining_sgprs = available_sgprs - user_sgpr_count;
-   uint32_t num_desc_set = util_bitcount(ctx->program->info->desc_set_used_mask);
-
-   if (available_sgprs < user_sgpr_count + num_desc_set) {
-      user_sgpr_info.indirect_all_descriptor_sets = true;
-      user_sgpr_info.num_sgpr = user_sgpr_count + 1;
-      user_sgpr_info.remaining_sgprs = remaining_sgprs - 1;
-   } else {
-      user_sgpr_info.num_sgpr = user_sgpr_count + num_desc_set;
-      user_sgpr_info.remaining_sgprs = remaining_sgprs - num_desc_set;
-   }
-
-   allocate_inline_push_consts(ctx, user_sgpr_info);
-}
-
-#define MAX_ARGS 64
-struct arg_info {
-   RegClass types[MAX_ARGS];
-   Temp *assign[MAX_ARGS];
-   PhysReg reg[MAX_ARGS];
-   unsigned array_params_mask;
-   uint8_t count;
-   uint8_t sgpr_count;
-   uint8_t num_sgprs_used;
-   uint8_t num_vgprs_used;
-};
-
-static void
-add_arg(arg_info *info, RegClass rc, Temp *param_ptr, unsigned reg)
-{
-   assert(info->count < MAX_ARGS);
-
-   info->assign[info->count] = param_ptr;
-   info->types[info->count] = rc;
-
-   if (rc.type() == RegType::sgpr) {
-      info->num_sgprs_used += rc.size();
-      info->sgpr_count++;
-      info->reg[info->count] = PhysReg{reg};
-   } else {
-      assert(rc.type() == RegType::vgpr);
-      info->num_vgprs_used += rc.size();
-      info->reg[info->count] = PhysReg{reg + 256};
-   }
-   info->count++;
-}
-
-static void
-set_loc(struct radv_userdata_info *ud_info, uint8_t *sgpr_idx, uint8_t num_sgprs)
-{
-   ud_info->sgpr_idx = *sgpr_idx;
-   ud_info->num_sgprs = num_sgprs;
-   *sgpr_idx += num_sgprs;
-}
-
-static void
-set_loc_shader(isel_context *ctx, int idx, uint8_t *sgpr_idx,
-               uint8_t num_sgprs)
-{
-   struct radv_userdata_info *ud_info = &ctx->program->info->user_sgprs_locs.shader_data[idx];
-   assert(ud_info);
-
-   set_loc(ud_info, sgpr_idx, num_sgprs);
-}
-
-static void
-set_loc_shader_ptr(isel_context *ctx, int idx, uint8_t *sgpr_idx)
-{
-   bool use_32bit_pointers = idx != AC_UD_SCRATCH_RING_OFFSETS;
-
-   set_loc_shader(ctx, idx, sgpr_idx, use_32bit_pointers ? 1 : 2);
-}
-
-static void
-set_loc_desc(isel_context *ctx, int idx,  uint8_t *sgpr_idx)
-{
-   struct radv_userdata_locations *locs = &ctx->program->info->user_sgprs_locs;
-   struct radv_userdata_info *ud_info = &locs->descriptor_sets[idx];
-   assert(ud_info);
-
-   set_loc(ud_info, sgpr_idx, 1);
-   locs->descriptor_sets_enabled |= 1 << idx;
-}
-
-static void
-declare_global_input_sgprs(isel_context *ctx,
-                           /* bool has_previous_stage, gl_shader_stage previous_stage, */
-                           user_sgpr_info *user_sgpr_info,
-                           struct arg_info *args,
-                           Temp *desc_sets)
-{
-   /* 1 for each descriptor set */
-   if (!user_sgpr_info->indirect_all_descriptor_sets) {
-      uint32_t mask = ctx->program->info->desc_set_used_mask;
-      while (mask) {
-         int i = u_bit_scan(&mask);
-         add_arg(args, s1, &desc_sets[i], user_sgpr_info->user_sgpr_idx);
-         set_loc_desc(ctx, i, &user_sgpr_info->user_sgpr_idx);
-      }
-      /* NIR->LLVM might have set this to true if RADV_DEBUG=compiletime */
-      ctx->program->info->need_indirect_descriptor_sets = false;
-   } else {
-      add_arg(args, s1, desc_sets, user_sgpr_info->user_sgpr_idx);
-      set_loc_shader_ptr(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS, &user_sgpr_info->user_sgpr_idx);
-      ctx->program->info->need_indirect_descriptor_sets = true;
-   }
-
-   if (ctx->program->info->loads_push_constants) {
-      /* 1 for push constants and dynamic descriptors */
-      add_arg(args, s1, &ctx->push_constants, user_sgpr_info->user_sgpr_idx);
-      set_loc_shader_ptr(ctx, AC_UD_PUSH_CONSTANTS, &user_sgpr_info->user_sgpr_idx);
-   }
-
-   if (ctx->program->info->num_inline_push_consts) {
-      unsigned count = ctx->program->info->num_inline_push_consts;
-      for (unsigned i = 0; i < count; i++)
-         add_arg(args, s1, &ctx->inline_push_consts[i], user_sgpr_info->user_sgpr_idx + i);
-      set_loc_shader(ctx, AC_UD_INLINE_PUSH_CONSTANTS, &user_sgpr_info->user_sgpr_idx, count);
-
-      ctx->num_inline_push_consts = ctx->program->info->num_inline_push_consts;
-      ctx->base_inline_push_consts = ctx->program->info->base_inline_push_consts;
-   }
-
-   if (ctx->program->info->so.num_outputs) {
-      add_arg(args, s1, &ctx->streamout_buffers, user_sgpr_info->user_sgpr_idx);
-      set_loc_shader_ptr(ctx, AC_UD_STREAMOUT_BUFFERS, &user_sgpr_info->user_sgpr_idx);
-   }
-}
-
-static void
-declare_vs_input_vgprs(isel_context *ctx, struct arg_info *args)
-{
-   unsigned vgpr_idx = 0;
-   add_arg(args, v1, &ctx->vertex_id, vgpr_idx++);
-   if (ctx->options->chip_class >= GFX10) {
-      add_arg(args, v1, NULL, vgpr_idx++); /* unused */
-      add_arg(args, v1, &ctx->vs_prim_id, vgpr_idx++);
-      add_arg(args, v1, &ctx->instance_id, vgpr_idx++);
-   } else {
-      if (ctx->options->key.vs.out.as_ls) {
-         add_arg(args, v1, &ctx->rel_auto_id, vgpr_idx++);
-         add_arg(args, v1, &ctx->instance_id, vgpr_idx++);
-      } else {
-         add_arg(args, v1, &ctx->instance_id, vgpr_idx++);
-         add_arg(args, v1, &ctx->vs_prim_id, vgpr_idx++);
-      }
-      add_arg(args, v1, NULL, vgpr_idx); /* unused */
+   if (G_0286CC_POS_W_FLOAT_ENA(spi_ps_inputs)) {
+      /* If POS_W_FLOAT (11) is enabled, at least one of PERSP_* must be enabled too */
+      spi_ps_inputs |= S_0286CC_PERSP_CENTER_ENA(1);
    }
-}
-
-static void
-declare_streamout_sgprs(isel_context *ctx, struct arg_info *args, unsigned *idx)
-{
-   /* Streamout SGPRs. */
-   if (ctx->program->info->so.num_outputs) {
-      assert(ctx->stage & hw_vs);
-
-      if (ctx->stage != tess_eval_vs) {
-         add_arg(args, s1, &ctx->streamout_config, (*idx)++);
-      } else {
-         args->assign[args->count - 1] = &ctx->streamout_config;
-         args->types[args->count - 1] = s1;
-      }
 
-      add_arg(args, s1, &ctx->streamout_write_idx, (*idx)++);
+   if (!(spi_ps_inputs & 0x7F)) {
+      /* At least one of PERSP_* (0xF) or LINEAR_* (0x70) must be enabled */
+      spi_ps_inputs |= S_0286CC_PERSP_CENTER_ENA(1);
    }
 
-   /* A streamout buffer offset is loaded if the stride is non-zero. */
-   for (unsigned i = 0; i < 4; i++) {
-      if (!ctx->program->info->so.strides[i])
-         continue;
+   ctx->program->config->spi_ps_input_ena = spi_ps_inputs;
+   ctx->program->config->spi_ps_input_addr = spi_ps_inputs;
 
-      add_arg(args, s1, &ctx->streamout_offset[i], (*idx)++);
-   }
-}
-
-static bool needs_view_index_sgpr(isel_context *ctx)
-{
-   switch (ctx->stage) {
-   case vertex_vs:
-      return ctx->program->info->needs_multiview_view_index || ctx->options->key.has_multiview_view_index;
-   case tess_eval_vs:
-      return ctx->program->info->needs_multiview_view_index && ctx->options->key.has_multiview_view_index;
-   case vertex_ls:
-   case vertex_es:
-   case vertex_tess_control_hs:
-   case vertex_geometry_gs:
-   case tess_control_hs:
-   case tess_eval_es:
-   case tess_eval_geometry_gs:
-   case geometry_gs:
-      return ctx->program->info->needs_multiview_view_index;
-   default:
-      return false;
-   }
-}
-
-static inline bool
-add_fs_arg(isel_context *ctx, arg_info *args, unsigned &vgpr_idx, fs_input input, unsigned value, bool enable_next = false, RegClass rc = v1)
-{
-   if (!ctx->fs_vgpr_args[input])
-      return false;
-
-   add_arg(args, rc, &ctx->fs_inputs[input], vgpr_idx);
-   vgpr_idx += rc.size();
-
-   if (enable_next) {
-      add_arg(args, rc, &ctx->fs_inputs[input + 1], vgpr_idx);
-      vgpr_idx += rc.size();
-   }
+   for (unsigned i = 0; i < impl->ssa_alloc; i++)
+      allocated[i] = Temp(ctx->program->allocateId(), allocated[i].regClass());
 
-   ctx->program->config->spi_ps_input_addr |= value;
-   ctx->program->config->spi_ps_input_ena |= value;
-   return true;
+   ctx->allocated.reset(allocated.release());
 }
 
 Pseudo_instruction *add_startpgm(struct isel_context *ctx)
 {
-   user_sgpr_info user_sgpr_info;
-   bool needs_view_index = needs_view_index_sgpr(ctx);
-   allocate_user_sgprs(ctx, needs_view_index, user_sgpr_info);
-   arg_info args = {};
-
-   /* this needs to be in sgprs 0 and 1 */
-   add_arg(&args, s2, &ctx->program->private_segment_buffer, 0);
-   set_loc_shader_ptr(ctx, AC_UD_SCRATCH_RING_OFFSETS, &user_sgpr_info.user_sgpr_idx);
-
-   unsigned vgpr_idx = 0;
-   switch (ctx->stage) {
-   case vertex_vs: {
-      declare_global_input_sgprs(ctx, &user_sgpr_info, &args, ctx->descriptor_sets);
-      if (ctx->program->info->vs.has_vertex_buffers) {
-         add_arg(&args, s1, &ctx->vertex_buffers, user_sgpr_info.user_sgpr_idx);
-         set_loc_shader_ptr(ctx, AC_UD_VS_VERTEX_BUFFERS, &user_sgpr_info.user_sgpr_idx);
-      }
-      add_arg(&args, s1, &ctx->base_vertex, user_sgpr_info.user_sgpr_idx);
-      add_arg(&args, s1, &ctx->start_instance, user_sgpr_info.user_sgpr_idx + 1);
-      if (ctx->program->info->vs.needs_draw_id) {
-         add_arg(&args, s1, &ctx->draw_id, user_sgpr_info.user_sgpr_idx + 2);
-         set_loc_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, &user_sgpr_info.user_sgpr_idx, 3);
-      } else
-         set_loc_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, &user_sgpr_info.user_sgpr_idx, 2);
-
-      if (needs_view_index) {
-         add_arg(&args, s1, &ctx->view_index, user_sgpr_info.user_sgpr_idx);
-         set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_info.user_sgpr_idx, 1);
-      }
-
-      assert(user_sgpr_info.user_sgpr_idx == user_sgpr_info.num_sgpr);
-      unsigned idx = user_sgpr_info.user_sgpr_idx;
-      if (ctx->options->key.vs.out.as_es)
-         add_arg(&args, s1, &ctx->es2gs_offset, idx++);
-      else
-         declare_streamout_sgprs(ctx, &args, &idx);
-
-      add_arg(&args, s1, &ctx->program->scratch_offset, idx++);
-
-      declare_vs_input_vgprs(ctx, &args);
-      break;
-   }
-   case fragment_fs: {
-      declare_global_input_sgprs(ctx, &user_sgpr_info, &args, ctx->descriptor_sets);
-
-      assert(user_sgpr_info.user_sgpr_idx == user_sgpr_info.num_sgpr);
-      add_arg(&args, s1, &ctx->prim_mask, user_sgpr_info.user_sgpr_idx);
-
-      add_arg(&args, s1, &ctx->program->scratch_offset, user_sgpr_info.user_sgpr_idx + 1);
-
-      ctx->program->config->spi_ps_input_addr = 0;
-      ctx->program->config->spi_ps_input_ena = 0;
-
-      bool has_interp_mode = false;
-
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_sample_p1, S_0286CC_PERSP_SAMPLE_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_center_p1, S_0286CC_PERSP_CENTER_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_centroid_p1, S_0286CC_PERSP_CENTROID_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_pull_model, S_0286CC_PERSP_PULL_MODEL_ENA(1), false, v3);
-
-      if (!has_interp_mode && ctx->fs_vgpr_args[fs_input::frag_pos_3]) {
-         /* If POS_W_FLOAT (11) is enabled, at least one of PERSP_* must be enabled too */
-         ctx->fs_vgpr_args[fs_input::persp_center_p1] = true;
-         has_interp_mode = add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_center_p1, S_0286CC_PERSP_CENTER_ENA(1), true);
-      }
-
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::linear_sample_p1, S_0286CC_LINEAR_SAMPLE_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::linear_center_p1, S_0286CC_LINEAR_CENTER_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::linear_centroid_p1, S_0286CC_LINEAR_CENTROID_ENA(1), true);
-      has_interp_mode |= add_fs_arg(ctx, &args, vgpr_idx, fs_input::line_stipple, S_0286CC_LINE_STIPPLE_TEX_ENA(1));
-
-      if (!has_interp_mode) {
-         /* At least one of PERSP_* (0xF) or LINEAR_* (0x70) must be enabled */
-         ctx->fs_vgpr_args[fs_input::persp_center_p1] = true;
-         has_interp_mode = add_fs_arg(ctx, &args, vgpr_idx, fs_input::persp_center_p1, S_0286CC_PERSP_CENTER_ENA(1), true);
-      }
-
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::frag_pos_0, S_0286CC_POS_X_FLOAT_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::frag_pos_1, S_0286CC_POS_Y_FLOAT_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::frag_pos_2, S_0286CC_POS_Z_FLOAT_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::frag_pos_3, S_0286CC_POS_W_FLOAT_ENA(1));
-
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::front_face, S_0286CC_FRONT_FACE_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::ancillary, S_0286CC_ANCILLARY_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::sample_coverage, S_0286CC_SAMPLE_COVERAGE_ENA(1));
-      add_fs_arg(ctx, &args, vgpr_idx, fs_input::fixed_pt, S_0286CC_POS_FIXED_PT_ENA(1));
-
-      ASSERTED bool unset_interp_mode = !(ctx->program->config->spi_ps_input_addr & 0x7F) ||
-                                        (G_0286CC_POS_W_FLOAT_ENA(ctx->program->config->spi_ps_input_addr)
-                                        && !(ctx->program->config->spi_ps_input_addr & 0xF));
-
-      assert(has_interp_mode);
-      assert(!unset_interp_mode);
-      break;
-   }
-   case compute_cs: {
-      declare_global_input_sgprs(ctx, &user_sgpr_info, &args, ctx->descriptor_sets);
+   unsigned arg_count = ctx->args->ac.arg_count;
+   if (ctx->stage == fragment_fs) {
+      /* LLVM optimizes away unused FS inputs and computes spi_ps_input_addr
+       * itself and then communicates the results back via the ELF binary.
+       * Mirror what LLVM does by re-mapping the VGPR arguments here.
+       *
+       * TODO: If we made the FS input scanning code into a separate pass that
+       * could run before argument setup, then this wouldn't be necessary
+       * anymore.
+       */
+      struct ac_shader_args *args = &ctx->args->ac;
+      arg_count = 0;
+      for (unsigned i = 0, vgpr_arg = 0, vgpr_reg = 0; i < args->arg_count; i++) {
+         if (args->args[i].file != AC_ARG_VGPR) {
+            arg_count++;
+            continue;
+         }
 
-      if (ctx->program->info->cs.uses_grid_size) {
-         add_arg(&args, s3, &ctx->num_workgroups, user_sgpr_info.user_sgpr_idx);
-         set_loc_shader(ctx, AC_UD_CS_GRID_SIZE, &user_sgpr_info.user_sgpr_idx, 3);
-      }
-      assert(user_sgpr_info.user_sgpr_idx == user_sgpr_info.num_sgpr);
-      unsigned idx = user_sgpr_info.user_sgpr_idx;
-      for (unsigned i = 0; i < 3; i++) {
-         if (ctx->program->info->cs.uses_block_id[i])
-            add_arg(&args, s1, &ctx->workgroup_ids[i], idx++);
+         if (!(ctx->program->config->spi_ps_input_addr & (1 << vgpr_arg))) {
+            args->args[i].skip = true;
+         } else {
+            args->args[i].offset = vgpr_reg;
+            vgpr_reg += args->args[i].size;
+            arg_count++;
+         }
+         vgpr_arg++;
       }
-
-      if (ctx->program->info->cs.uses_local_invocation_idx)
-         add_arg(&args, s1, &ctx->tg_size, idx++);
-      add_arg(&args, s1, &ctx->program->scratch_offset, idx++);
-
-      add_arg(&args, v3, &ctx->local_invocation_ids, vgpr_idx++);
-      break;
    }
-   default:
-      unreachable("Shader stage not implemented");
-   }
-
-   ctx->program->info->num_input_vgprs = 0;
-   ctx->program->info->num_input_sgprs = args.num_sgprs_used;
-   ctx->program->info->num_user_sgprs = user_sgpr_info.num_sgpr;
-   ctx->program->info->num_input_vgprs = args.num_vgprs_used;
 
-   if (ctx->stage == fragment_fs) {
-      /* Verify that we have a correct assumption about input VGPR count */
-      ASSERTED unsigned input_vgpr_cnt = ac_get_fs_input_vgpr_cnt(ctx->program->config, nullptr, nullptr);
-      assert(input_vgpr_cnt == ctx->program->info->num_input_vgprs);
-   }
+   aco_ptr<Pseudo_instruction> startpgm{create_instruction<Pseudo_instruction>(aco_opcode::p_startpgm, Format::PSEUDO, 0, arg_count + 1)};
+   for (unsigned i = 0, arg = 0; i < ctx->args->ac.arg_count; i++) {
+      if (ctx->args->ac.args[i].skip)
+         continue;
 
-   aco_ptr<Pseudo_instruction> startpgm{create_instruction<Pseudo_instruction>(aco_opcode::p_startpgm, Format::PSEUDO, 0, args.count + 1)};
-   for (unsigned i = 0; i < args.count; i++) {
-      if (args.assign[i]) {
-         *args.assign[i] = Temp{ctx->program->allocateId(), args.types[i]};
-         startpgm->definitions[i] = Definition(*args.assign[i]);
-         startpgm->definitions[i].setFixed(args.reg[i]);
-      }
-   }
-   startpgm->definitions[args.count] = Definition{ctx->program->allocateId(), exec, s2};
+      enum ac_arg_regfile file = ctx->args->ac.args[i].file;
+      unsigned size = ctx->args->ac.args[i].size;
+      unsigned reg = ctx->args->ac.args[i].offset;
+      RegClass type = RegClass(file == AC_ARG_SGPR ? RegType::sgpr : RegType::vgpr, size);
+      Temp dst = Temp{ctx->program->allocateId(), type};
+      ctx->arg_temps[i] = dst;
+      startpgm->definitions[arg] = Definition(dst);
+      startpgm->definitions[arg].setFixed(PhysReg{file == AC_ARG_SGPR ? reg : reg + 256});
+      arg++;
+   }
+   startpgm->definitions[arg_count] = Definition{ctx->program->allocateId(), exec, s2};
    Pseudo_instruction *instr = startpgm.get();
    ctx->block->instructions.push_back(std::move(startpgm));
 
+   /* Stash these in the program so that they can be accessed later when
+    * handling spilling.
+    */
+   ctx->program->private_segment_buffer = get_arg(ctx, ctx->args->ring_offsets);
+   ctx->program->scratch_offset = get_arg(ctx, ctx->args->scratch_offset);
+
    return instr;
 }
 
@@ -1168,8 +726,7 @@ setup_isel_context(Program* program,
                    unsigned shader_count,
                    struct nir_shader *const *shaders,
                    ac_shader_config* config,
-                   radv_shader_info *info,
-                   const radv_nir_compiler_options *options)
+                   struct radv_shader_args *args)
 {
    program->stage = 0;
    for (unsigned i = 0; i < shader_count; i++) {
@@ -1206,23 +763,23 @@ setup_isel_context(Program* program,
       unreachable("Shader stage not implemented");
 
    program->config = config;
-   program->info = info;
-   program->chip_class = options->chip_class;
-   program->family = options->family;
-   program->wave_size = info->wave_size;
+   program->info = args->shader_info;
+   program->chip_class = args->options->chip_class;
+   program->family = args->options->family;
+   program->wave_size = args->shader_info->wave_size;
 
-   program->lds_alloc_granule = options->chip_class >= GFX7 ? 512 : 256;
-   program->lds_limit = options->chip_class >= GFX7 ? 65536 : 32768;
+   program->lds_alloc_granule = args->options->chip_class >= GFX7 ? 512 : 256;
+   program->lds_limit = args->options->chip_class >= GFX7 ? 65536 : 32768;
    program->vgpr_limit = 256;
 
-   if (options->chip_class >= GFX10) {
+   if (args->options->chip_class >= GFX10) {
       program->physical_sgprs = 2560; /* doesn't matter as long as it's at least 128 * 20 */
       program->sgpr_alloc_granule = 127;
       program->sgpr_limit = 106;
    } else if (program->chip_class >= GFX8) {
       program->physical_sgprs = 800;
       program->sgpr_alloc_granule = 15;
-      if (options->family == CHIP_TONGA || options->family == CHIP_ICELAND)
+      if (args->options->family == CHIP_TONGA || args->options->family == CHIP_ICELAND)
          program->sgpr_limit = 94; /* workaround hardware bug */
       else
          program->sgpr_limit = 102;
@@ -1234,28 +791,12 @@ setup_isel_context(Program* program,
    /* TODO: we don't have to allocate VCC if we don't need it */
    program->needs_vcc = true;
 
-   for (unsigned i = 0; i < MAX_SETS; ++i)
-      program->info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
-   for (unsigned i = 0; i < AC_UD_MAX_UD; ++i)
-      program->info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
-
    isel_context ctx = {};
    ctx.program = program;
-   ctx.options = options;
+   ctx.args = args;
+   ctx.options = args->options;
    ctx.stage = program->stage;
 
-   for (unsigned i = 0; i < fs_input::max_inputs; ++i)
-      ctx.fs_inputs[i] = Temp(0, v1);
-   ctx.fs_inputs[fs_input::persp_pull_model] = Temp(0, v3);
-   for (unsigned i = 0; i < MAX_SETS; ++i)
-      ctx.descriptor_sets[i] = Temp(0, s1);
-   for (unsigned i = 0; i < MAX_INLINE_PUSH_CONSTS; ++i)
-      ctx.inline_push_consts[i] = Temp(0, s1);
-   for (unsigned i = 0; i <= VARYING_SLOT_VAR31; ++i) {
-      for (unsigned j = 0; j < 4; ++j)
-         ctx.vs_output.outputs[i][j] = Temp(0, v1);
-   }
-
    for (unsigned i = 0; i < shader_count; i++) {
       nir_shader *nir = shaders[i];
 
@@ -1339,7 +880,7 @@ setup_isel_context(Program* program,
       nir_function_impl *func = nir_shader_get_entrypoint(nir);
       nir_index_ssa_defs(func);
 
-      if (options->dump_preoptir) {
+      if (args->options->dump_preoptir) {
          fprintf(stderr, "NIR shader before instruction selection:\n");
          nir_print_shader(nir, stderr);
       }