nir: Use a single list for all shader variables
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 40a0b267a258228aa5db430a9395e6500a803935..68a5d00a3506c3233e773cdea859e9f1586b432e 100644 (file)
@@ -36,15 +36,15 @@ static unsigned slot_pack_map[] = {
    [VARYING_SLOT_COL0] = 0, /* input/output */
    [VARYING_SLOT_COL1] = 1, /* input/output */
    [VARYING_SLOT_FOGC] = 2, /* input/output */
-   /* TEX0-7 are translated to VAR0-7 by nir, so we don't need to reserve */
-   [VARYING_SLOT_TEX0] = UINT_MAX, /* input/output */
-   [VARYING_SLOT_TEX1] = UINT_MAX,
-   [VARYING_SLOT_TEX2] = UINT_MAX,
-   [VARYING_SLOT_TEX3] = UINT_MAX,
-   [VARYING_SLOT_TEX4] = UINT_MAX,
-   [VARYING_SLOT_TEX5] = UINT_MAX,
-   [VARYING_SLOT_TEX6] = UINT_MAX,
-   [VARYING_SLOT_TEX7] = UINT_MAX,
+   /* TEX0-7 are deprecated, so we put them at the end of the range and hope nobody uses them all */
+   [VARYING_SLOT_TEX0] = VARYING_SLOT_VAR0 - 1, /* input/output */
+   [VARYING_SLOT_TEX1] = VARYING_SLOT_VAR0 - 2,
+   [VARYING_SLOT_TEX2] = VARYING_SLOT_VAR0 - 3,
+   [VARYING_SLOT_TEX3] = VARYING_SLOT_VAR0 - 4,
+   [VARYING_SLOT_TEX4] = VARYING_SLOT_VAR0 - 5,
+   [VARYING_SLOT_TEX5] = VARYING_SLOT_VAR0 - 6,
+   [VARYING_SLOT_TEX6] = VARYING_SLOT_VAR0 - 7,
+   [VARYING_SLOT_TEX7] = VARYING_SLOT_VAR0 - 8,
 
    /* PointSize is builtin */
    [VARYING_SLOT_PSIZ] = UINT_MAX,
@@ -90,6 +90,8 @@ static unsigned slot_pack_map[] = {
 #define NTV_MIN_RESERVED_SLOTS 11
 
 struct ntv_context {
+   void *mem_ctx;
+
    struct spirv_builder builder;
 
    SpvId GLSL_std_450;
@@ -122,6 +124,9 @@ struct ntv_context {
    SpvId loop_break, loop_cont;
 
    SpvId front_face_var, instance_id_var, vertex_id_var;
+#ifndef NDEBUG
+   bool seen_texcoord[8]; //whether we've seen a VARYING_SLOT_TEX[n] this pass
+#endif
 };
 
 static SpvId
@@ -284,6 +289,28 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
    unreachable("we shouldn't get here, I think...");
 }
 
+static inline unsigned
+handle_slot(struct ntv_context *ctx, unsigned slot)
+{
+   unsigned orig = slot;
+   if (slot < VARYING_SLOT_VAR0) {
+#ifndef NDEBUG
+      if (slot >= VARYING_SLOT_TEX0 && slot <= VARYING_SLOT_TEX7)
+         ctx->seen_texcoord[slot - VARYING_SLOT_TEX0] = true;
+#endif
+      slot = slot_pack_map[slot];
+      if (slot == UINT_MAX)
+         debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(orig));
+   } else {
+      slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
+      assert(slot <= VARYING_SLOT_VAR0 - 8 ||
+             !ctx->seen_texcoord[VARYING_SLOT_VAR0 - slot - 1]);
+
+   }
+   assert(slot < VARYING_SLOT_VAR0);
+   return slot;
+}
+
 #define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN) \
       case VARYING_SLOT_##SLOT: \
          spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
@@ -316,13 +343,7 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
       HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
 
       default:
-         if (slot < VARYING_SLOT_VAR0) {
-            slot = slot_pack_map[slot];
-            if (slot == UINT_MAX)
-               debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
-         } else
-            slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
-         assert(slot < VARYING_SLOT_VAR0);
+         slot = handle_slot(ctx, slot);
          spirv_builder_emit_location(&ctx->builder, var_id, slot);
       }
    } else {
@@ -378,31 +399,21 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
          break;
 
       default:
-         if (slot < VARYING_SLOT_VAR0) {
-            slot = slot_pack_map[slot];
-            if (slot == UINT_MAX)
-               debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(var->data.location));
-         } else
-            slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
-         assert(slot < VARYING_SLOT_VAR0);
+         slot = handle_slot(ctx, slot);
          spirv_builder_emit_location(&ctx->builder, var_id, slot);
-         /* non-builtins get location incremented by VARYING_SLOT_VAR0 in vtn, so
-          * use driver_location for non-builtins with defined slots to avoid overlap
-          */
       }
       ctx->outputs[var->data.location] = var_id;
       ctx->so_output_gl_types[var->data.location] = var->type;
       ctx->so_output_types[var->data.location] = var_type;
    } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
-      if (var->data.location >= FRAG_RESULT_DATA0)
+      if (var->data.location >= FRAG_RESULT_DATA0) {
          spirv_builder_emit_location(&ctx->builder, var_id,
                                      var->data.location - FRAG_RESULT_DATA0);
-      else {
+         spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
+      } else {
          switch (var->data.location) {
          case FRAG_RESULT_COLOR:
-            spirv_builder_emit_location(&ctx->builder, var_id, 0);
-            spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
-            break;
+            unreachable("gl_FragColor should be lowered by now");
 
          case FRAG_RESULT_DEPTH:
             spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
@@ -411,6 +422,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
          default:
             spirv_builder_emit_location(&ctx->builder, var_id,
                                         var->data.driver_location);
+            spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
          }
       }
    }
@@ -620,6 +632,17 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
    }
 }
 
+static SpvId
+get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_components)
+{
+   if (bit_size == 1)
+      return get_bvec_type(ctx, num_components);
+   if (bit_size == 32)
+      return get_uvec_type(ctx, bit_size, num_components);
+   unreachable("unhandled register bit size");
+   return 0;
+}
+
 static SpvId
 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
 {
@@ -644,7 +667,7 @@ get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
    assert(!reg->base_offset);
 
    SpvId var = get_var_from_reg(ctx, reg->reg);
-   SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
+   SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components);
    return spirv_builder_emit_load(&ctx->builder, type, var);
 }
 
@@ -895,6 +918,8 @@ emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
       if (max_output_location >= VARYING_SLOT_VAR0)
          location = max_output_location - VARYING_SLOT_VAR0 + 1 + i;
       assert(location < VARYING_SLOT_VAR0);
+      assert(location <= VARYING_SLOT_VAR0 - 8 ||
+             !ctx->seen_texcoord[VARYING_SLOT_VAR0 - location - 1]);
       spirv_builder_emit_location(&ctx->builder, var_id, location);
 
       /* note: gl_ClipDistance[4] can the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
@@ -903,7 +928,7 @@ emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
       if (so_output.start_component)
          spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
 
-      uint32_t *key = ralloc_size(NULL, sizeof(uint32_t));
+      uint32_t *key = ralloc_size(ctx->mem_ctx, sizeof(uint32_t));
       *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
       _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
 
@@ -1239,6 +1264,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
    BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
    BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
+   BUILTIN_UNOP(nir_op_isign, GLSLstd450SSign)
    BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
    BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
 #undef BUILTIN_UNOP
@@ -1290,7 +1316,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BINOP(nir_op_flt, SpvOpFOrdLessThan)
    BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
    BINOP(nir_op_feq, SpvOpFOrdEqual)
-   BINOP(nir_op_fne, SpvOpFOrdNotEqual)
+   BINOP(nir_op_fne, SpvOpFUnordNotEqual)
    BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
    BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
    BINOP(nir_op_ushr, SpvOpShiftRightLogical)
@@ -1488,19 +1514,17 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
    SpvId constant;
    if (num_components > 1) {
       SpvId components[num_components];
-      SpvId type;
+      SpvId type = get_vec_from_bit_size(ctx, bit_size, num_components);
       if (bit_size == 1) {
          for (int i = 0; i < num_components; i++)
             components[i] = spirv_builder_const_bool(&ctx->builder,
                                                      load_const->value[i].b);
 
-         type = get_bvec_type(ctx, num_components);
       } else {
          for (int i = 0; i < num_components; i++)
             components[i] = emit_uint_const(ctx, bit_size,
                                             load_const->value[i].u32);
 
-         type = get_uvec_type(ctx, bit_size, num_components);
       }
       constant = spirv_builder_const_composite(&ctx->builder, type,
                                                components, num_components);
@@ -1585,9 +1609,8 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
    SpvId ptr = get_src(ctx, intr->src);
 
-   nir_variable *var = nir_intrinsic_get_var(intr, 0);
    SpvId result = spirv_builder_emit_load(&ctx->builder,
-                                          get_glsl_type(ctx, var->type),
+                                          get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type),
                                           ptr);
    unsigned num_components = nir_dest_num_components(intr->dest);
    unsigned bit_size = nir_dest_bit_size(intr->dest);
@@ -1601,8 +1624,7 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId ptr = get_src(ctx, &intr->src[0]);
    SpvId src = get_src(ctx, &intr->src[1]);
 
-   nir_variable *var = nir_intrinsic_get_var(intr, 0);
-   SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
+   SpvId type = get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type);
    SpvId result = emit_bitcast(ctx, type, src);
    spirv_builder_emit_store(&ctx->builder, ptr, result);
 }
@@ -1740,6 +1762,42 @@ get_src_int(struct ntv_context *ctx, nir_src *src)
    return bitcast_to_ivec(ctx, def, bit_size, num_components);
 }
 
+static inline bool
+tex_instr_is_lod_allowed(nir_tex_instr *tex)
+{
+   /* This can only be used with an OpTypeImage that has a Dim operand of 1D, 2D, 3D, or Cube
+    * - SPIR-V: 3.14. Image Operands
+    */
+
+   return (tex->sampler_dim == GLSL_SAMPLER_DIM_1D ||
+           tex->sampler_dim == GLSL_SAMPLER_DIM_2D ||
+           tex->sampler_dim == GLSL_SAMPLER_DIM_3D ||
+           tex->sampler_dim == GLSL_SAMPLER_DIM_CUBE);
+}
+
+static SpvId
+pad_coord_vector(struct ntv_context *ctx, SpvId orig, unsigned old_size, unsigned new_size)
+{
+    SpvId int_type = spirv_builder_type_int(&ctx->builder, 32);
+    SpvId type = get_ivec_type(ctx, 32, new_size);
+    SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
+    SpvId zero = emit_int_const(ctx, 32, 0);
+    assert(new_size < NIR_MAX_VEC_COMPONENTS);
+
+    if (old_size == 1)
+       constituents[0] = orig;
+    else {
+       for (unsigned i = 0; i < old_size; i++)
+          constituents[i] = spirv_builder_emit_vector_extract(&ctx->builder, int_type, orig, i);
+    }
+
+    for (unsigned i = old_size; i < new_size; i++)
+       constituents[i] = zero;
+
+    return spirv_builder_emit_composite_construct(&ctx->builder, type,
+                                                  constituents, new_size);
+}
+
 static void
 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 {
@@ -1754,7 +1812,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
    SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
          offset = 0, sample = 0;
-   unsigned coord_components = 0;
+   unsigned coord_components = 0, coord_bitsize = 0, offset_components = 0;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
       case nir_tex_src_coord:
@@ -1764,6 +1822,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          else
             coord = get_src_float(ctx, &tex->src[i].src);
          coord_components = nir_src_num_components(tex->src[i].src);
+         coord_bitsize = nir_src_bit_size(tex->src[i].src);
          break;
 
       case nir_tex_src_projector:
@@ -1774,6 +1833,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
       case nir_tex_src_offset:
          offset = get_src_int(ctx, &tex->src[i].src);
+         offset_components = nir_src_num_components(tex->src[i].src);
          break;
 
       case nir_tex_src_bias:
@@ -1835,6 +1895,8 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
    SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
 
+   if (!tex_instr_is_lod_allowed(tex))
+      lod = 0;
    if (tex->op == nir_texop_txs) {
       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
       SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
@@ -1875,6 +1937,19 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
    if (tex->op == nir_texop_txf ||
        tex->op == nir_texop_txf_ms) {
       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
+      if (offset) {
+         /* SPIRV requires matched length vectors for OpIAdd, so if a shader
+          * uses vecs of differing sizes we need to make a new vec padded with zeroes
+          * to mimic how GLSL does this implicitly
+          */
+         if (offset_components > coord_components)
+            coord = pad_coord_vector(ctx, coord, coord_components, offset_components);
+         else if (coord_components > offset_components)
+            offset = pad_coord_vector(ctx, offset, offset_components, coord_components);
+         coord = emit_binop(ctx, SpvOpIAdd,
+                            get_ivec_type(ctx, coord_bitsize, coord_components),
+                            coord, offset);
+      }
       result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
                                               image, coord, lod, sample);
    } else {
@@ -2118,7 +2193,9 @@ emit_loop(struct ntv_context *ctx, nir_loop *loop)
    ctx->loop_break = save_break;
    ctx->loop_cont = save_cont;
 
-   branch(ctx, cont_id);
+   /* loop->body may have already ended our block */
+   if (ctx->block_started)
+      branch(ctx, cont_id);
    start_block(ctx, cont_id);
    branch(ctx, header_id);
 
@@ -2155,6 +2232,8 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    struct spirv_shader *ret = NULL;
 
    struct ntv_context ctx = {};
+   ctx.mem_ctx = ralloc_context(NULL);
+   ctx.builder.mem_ctx = ctx.mem_ctx;
 
    switch (s->info.stage) {
    case MESA_SHADER_VERTEX:
@@ -2220,21 +2299,23 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    SpvId entry_point = spirv_builder_new_id(&ctx.builder);
    spirv_builder_emit_name(&ctx.builder, entry_point, "main");
 
-   ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+   ctx.vars = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_pointer,
                                       _mesa_key_pointer_equal);
 
-   ctx.so_outputs = _mesa_hash_table_create(NULL, _mesa_hash_u32,
+   ctx.so_outputs = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_u32,
                                             _mesa_key_u32_equal);
 
-   nir_foreach_variable(var, &s->inputs)
+   nir_foreach_shader_in_variable(var, s)
       emit_input(&ctx, var);
 
-   nir_foreach_variable(var, &s->outputs)
+   nir_foreach_shader_out_variable(var, s)
       emit_output(&ctx, var);
 
    if (so_info)
       emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info, local_so_info);
-   nir_foreach_variable(var, &s->uniforms)
+   nir_foreach_variable_with_modes(var, s, nir_var_uniform |
+                                           nir_var_mem_ubo |
+                                           nir_var_mem_ssbo)
       emit_uniform(&ctx, var);
 
    if (s->info.stage == MESA_SHADER_FRAGMENT) {
@@ -2258,18 +2339,21 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    nir_function_impl *entry = nir_shader_get_entrypoint(s);
    nir_metadata_require(entry, nir_metadata_block_index);
 
-   ctx.defs = (SpvId *)malloc(sizeof(SpvId) * entry->ssa_alloc);
+   ctx.defs = ralloc_array_size(ctx.mem_ctx,
+                                sizeof(SpvId), entry->ssa_alloc);
    if (!ctx.defs)
       goto fail;
    ctx.num_defs = entry->ssa_alloc;
 
    nir_index_local_regs(entry);
-   ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
+   ctx.regs = ralloc_array_size(ctx.mem_ctx,
+                                sizeof(SpvId), entry->reg_alloc);
    if (!ctx.regs)
       goto fail;
    ctx.num_regs = entry->reg_alloc;
 
-   SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
+   SpvId *block_ids = ralloc_array_size(ctx.mem_ctx,
+                                        sizeof(SpvId), entry->num_blocks);
    if (!block_ids)
       goto fail;
 
@@ -2282,7 +2366,7 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    /* emit a block only for the variable declarations */
    start_block(&ctx, spirv_builder_new_id(&ctx.builder));
    foreach_list_typed(nir_register, reg, node, &entry->registers) {
-      SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
+      SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
       SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
                                                       SpvStorageClassFunction,
                                                       type);
@@ -2294,8 +2378,6 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
 
    emit_cf_list(&ctx, &entry->body);
 
-   free(ctx.defs);
-
    if (so_info)
       emit_so_outputs(&ctx, so_info, local_so_info);
 
@@ -2319,19 +2401,16 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
    assert(ret->num_words == num_words);
 
+   ralloc_free(ctx.mem_ctx);
+
    return ret;
 
 fail:
+   ralloc_free(ctx.mem_ctx);
 
    if (ret)
       spirv_shader_delete(ret);
 
-   if (ctx.vars)
-      _mesa_hash_table_destroy(ctx.vars, NULL);
-
-   if (ctx.so_outputs)
-      _mesa_hash_table_destroy(ctx.so_outputs, NULL);
-
    return NULL;
 }