zink: do not lower io
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 20c72fd214734d01fbeefa3eae5e3803ef4713df..8534ed1300a4e180035fbb8250bf69bf4201f78e 100644 (file)
@@ -27,6 +27,7 @@
 #include "nir.h"
 #include "pipe/p_state.h"
 #include "util/u_memory.h"
+#include "util/hash_table.h"
 
 struct ntv_context {
    struct spirv_builder builder;
@@ -34,10 +35,7 @@ struct ntv_context {
    SpvId GLSL_std_450;
 
    gl_shader_stage stage;
-   SpvId inputs[PIPE_MAX_SHADER_INPUTS][4];
-   SpvId input_types[PIPE_MAX_SHADER_INPUTS][4];
-   SpvId outputs[PIPE_MAX_SHADER_OUTPUTS][4];
-   SpvId output_types[PIPE_MAX_SHADER_OUTPUTS][4];
+   int var_location;
 
    SpvId ubos[128];
    size_t num_ubos;
@@ -49,7 +47,10 @@ struct ntv_context {
    SpvId *defs;
    size_t num_defs;
 
-   struct hash_table *vars;
+   SpvId *regs;
+   size_t num_regs;
+
+   struct hash_table *vars; /* nir_variable -> SpvId */
 
    const SpvId *block_ids;
    size_t num_blocks;
@@ -181,10 +182,10 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
 static void
 emit_input(struct ntv_context *ctx, struct nir_variable *var)
 {
-   SpvId vec_type = get_glsl_type(ctx, var->type);
+   SpvId var_type = get_glsl_type(ctx, var->type);
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassInput,
-                                                   vec_type);
+                                                   var_type);
    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
                                          SpvStorageClassInput);
 
@@ -192,19 +193,24 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
 
    if (ctx->stage == MESA_SHADER_FRAGMENT) {
-      switch (var->data.location) {
-      case VARYING_SLOT_POS:
-         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
-         break;
-
-      case VARYING_SLOT_PNTC:
-         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
-         break;
-
-      default:
+      if (var->data.location >= VARYING_SLOT_VAR0 ||
+          (var->data.location >= VARYING_SLOT_COL0 &&
+           var->data.location <= VARYING_SLOT_TEX7)) {
          spirv_builder_emit_location(&ctx->builder, var_id,
-                                     var->data.driver_location);
-         break;
+                                     ctx->var_location++);
+      } else {
+         switch (var->data.location) {
+         case VARYING_SLOT_POS:
+            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragCoord);
+            break;
+
+         case VARYING_SLOT_PNTC:
+            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointCoord);
+            break;
+
+         default:
+            unreachable("unknown varying slot");
+         }
       }
    } else {
       spirv_builder_emit_location(&ctx->builder, var_id,
@@ -218,11 +224,7 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
    if (var->data.interpolation == INTERP_MODE_FLAT)
       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
 
-   assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
-   assert(var->data.location_frac < 4);
-   assert(ctx->inputs[var->data.driver_location][var->data.location_frac] == 0);
-   ctx->inputs[var->data.driver_location][var->data.location_frac] = var_id;
-   ctx->input_types[var->data.driver_location][var->data.location_frac] = vec_type;
+   _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
 
    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
@@ -231,10 +233,10 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
 static void
 emit_output(struct ntv_context *ctx, struct nir_variable *var)
 {
-   SpvId vec_type = get_glsl_type(ctx, var->type);
+   SpvId var_type = get_glsl_type(ctx, var->type);
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassOutput,
-                                                   vec_type);
+                                                   var_type);
    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
                                          SpvStorageClassOutput);
    if (var->name)
@@ -242,18 +244,24 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
 
 
    if (ctx->stage == MESA_SHADER_VERTEX) {
-      switch (var->data.location) {
-      case VARYING_SLOT_POS:
-         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
-         break;
-
-      case VARYING_SLOT_PSIZ:
-         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
-         break;
-
-      default:
+      if (var->data.location >= VARYING_SLOT_VAR0 ||
+          (var->data.location >= VARYING_SLOT_COL0 &&
+           var->data.location <= VARYING_SLOT_TEX7)) {
          spirv_builder_emit_location(&ctx->builder, var_id,
-                                     var->data.driver_location - 1);
+                                     ctx->var_location++);
+      } else {
+         switch (var->data.location) {
+         case VARYING_SLOT_POS:
+            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPosition);
+            break;
+
+         case VARYING_SLOT_PSIZ:
+            spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInPointSize);
+            break;
+
+         default:
+            unreachable("unknown varying slot");
+         }
       }
    } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
       switch (var->data.location) {
@@ -271,11 +279,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
       spirv_builder_emit_component(&ctx->builder, var_id,
                                    var->data.location_frac);
 
-   assert(var->data.driver_location < PIPE_MAX_SHADER_INPUTS);
-   assert(var->data.location_frac < 4);
-   assert(ctx->outputs[var->data.driver_location][var->data.location_frac] == 0);
-   ctx->outputs[var->data.driver_location][var->data.location_frac] = var_id;
-   ctx->output_types[var->data.driver_location][var->data.location_frac] = vec_type;
+   _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
 
    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
@@ -395,19 +399,9 @@ get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
 static SpvId
 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
 {
-   struct hash_entry *he = _mesa_hash_table_search(ctx->vars, reg);
-   if (!he) {
-      SpvId type = get_uvec_type(ctx, reg->bit_size, reg->num_components);
-      SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
-                                                      SpvStorageClassFunction,
-                                                      type);
-
-      SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
-                                         SpvStorageClassFunction);
-
-      he = _mesa_hash_table_insert(ctx->vars, reg, (void *)(intptr_t)var);
-   }
-   return (SpvId)(intptr_t)he->data;
+   assert(reg->index < ctx->num_regs);
+   assert(ctx->regs[reg->index] != 0);
+   return ctx->regs[reg->index];
 }
 
 static SpvId
@@ -926,31 +920,6 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
    store_ssa_def_uint(ctx, &load_const->def, constant);
 }
 
-static void
-emit_load_input(struct ntv_context *ctx, nir_intrinsic_instr *intr)
-{
-   nir_const_value *const_offset = nir_src_as_const_value(intr->src[0]);
-   if (const_offset) {
-      int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
-      assert(driver_location < PIPE_MAX_SHADER_INPUTS);
-      int location_frac = nir_intrinsic_component(intr);
-      assert(location_frac < 4);
-
-      SpvId ptr = ctx->inputs[driver_location][location_frac];
-      SpvId type = ctx->input_types[driver_location][location_frac];
-      assert(ptr && type);
-
-      SpvId result = spirv_builder_emit_load(&ctx->builder, type, ptr);
-
-      unsigned num_components = nir_dest_num_components(intr->dest);
-      unsigned bit_size = nir_dest_bit_size(intr->dest);
-      result = bitcast_to_uvec(ctx, result, bit_size, num_components);
-
-      store_dest_uint(ctx, &intr->dest, result);
-   } else
-      unreachable("input-addressing not yet supported");
-}
-
 static void
 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
@@ -1002,27 +971,6 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       unreachable("uniform-addressing not yet supported");
 }
 
-static void
-emit_store_output(struct ntv_context *ctx, nir_intrinsic_instr *intr)
-{
-   nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
-   if (const_offset) {
-      int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
-      assert(driver_location < PIPE_MAX_SHADER_OUTPUTS);
-      int location_frac = nir_intrinsic_component(intr);
-      assert(location_frac < 4);
-
-      SpvId ptr = ctx->outputs[driver_location][location_frac];
-      assert(ptr > 0);
-
-      SpvId src = get_src_uint(ctx, &intr->src[0]);
-      SpvId spirv_type = ctx->output_types[driver_location][location_frac];
-      SpvId result = emit_unop(ctx, SpvOpBitcast, spirv_type, src);
-      spirv_builder_emit_store(&ctx->builder, ptr, result);
-   } else
-      unreachable("output-addressing not yet supported");
-}
-
 static void
 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
@@ -1033,26 +981,59 @@ emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
 }
 
+static void
+emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   nir_variable *var = nir_intrinsic_get_var(intr, 0);
+   struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var);
+   assert(he);
+   SpvId ptr = (SpvId)(intptr_t)he->data;
+
+   // SpvId ptr = get_src_uint(ctx, intr->src); /* uint is a bit of a lie here; it's really just a pointer */
+   SpvId result = spirv_builder_emit_load(&ctx->builder,
+                                          get_glsl_type(ctx, var->type),
+                                          ptr);
+   unsigned num_components = nir_dest_num_components(intr->dest);
+   unsigned bit_size = nir_dest_bit_size(intr->dest);
+   result = bitcast_to_uvec(ctx, result, bit_size, num_components);
+   store_dest_uint(ctx, &intr->dest, result);
+}
+
+static void
+emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   nir_variable *var = nir_intrinsic_get_var(intr, 0);
+   struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var);
+   assert(he);
+   SpvId ptr = (SpvId)(intptr_t)he->data;
+   // SpvId ptr = get_src_uint(ctx, &intr->src[0]); /* uint is a bit of a lie here; it's really just a pointer */
+
+   SpvId src = get_src_uint(ctx, &intr->src[1]);
+   SpvId result = emit_unop(ctx, SpvOpBitcast, get_glsl_type(ctx, var->type),
+                            src);
+   spirv_builder_emit_store(&ctx->builder, ptr, result);
+}
+
 static void
 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
    switch (intr->intrinsic) {
-   case nir_intrinsic_load_input:
-      emit_load_input(ctx, intr);
-      break;
-
    case nir_intrinsic_load_ubo:
       emit_load_ubo(ctx, intr);
       break;
 
-   case nir_intrinsic_store_output:
-      emit_store_output(ctx, intr);
-      break;
-
    case nir_intrinsic_discard:
       emit_discard(ctx, intr);
       break;
 
+   case nir_intrinsic_load_deref:
+      emit_load_deref(ctx, intr);
+      break;
+
+   case nir_intrinsic_store_deref:
+      emit_store_deref(ctx, intr);
+      break;
+
    default:
       fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
               nir_intrinsic_infos[intr->intrinsic].name);
@@ -1063,7 +1044,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 static void
 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
 {
-   SpvId type = get_fvec_type(ctx, undef->def.bit_size,
+   SpvId type = get_uvec_type(ctx, undef->def.bit_size,
                               undef->def.num_components);
 
    store_ssa_def_uint(ctx, &undef->def,
@@ -1086,8 +1067,8 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
    assert(nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
    assert(tex->texture_index == tex->sampler_index);
 
-   bool has_proj = false;
-   SpvId coord = 0, proj;
+   bool has_proj = false, has_lod = false;
+   SpvId coord = 0, proj, lod;
    unsigned coord_components;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
@@ -1102,12 +1083,23 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          assert(nir_src_num_components(tex->src[i].src) == 1);
          break;
 
+      case nir_tex_src_lod:
+         has_lod = true;
+         lod = get_src_float(ctx, &tex->src[i].src);
+         assert(nir_src_num_components(tex->src[i].src) == 1);
+         break;
+
       default:
          fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
          unreachable("unknown texture source");
       }
    }
 
+   if (!has_lod && ctx->stage != MESA_SHADER_FRAGMENT) {
+      has_lod = true;
+      lod = spirv_builder_const_float(&ctx->builder, 32, 0);
+   }
+
    bool is_ms;
    SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
    SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
@@ -1141,14 +1133,29 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
                                                             constituents,
                                                             coord_components);
 
-      result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
-                                                                 dest_type,
-                                                                 load,
-                                                                 merged);
-   } else
-      result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
-                                                            dest_type, load,
-                                                            coord);
+      if (has_lod)
+         result = spirv_builder_emit_image_sample_proj_explicit_lod(&ctx->builder,
+                                                                    dest_type,
+                                                                    load,
+                                                                    merged,
+                                                                    lod);
+      else
+         result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
+                                                                    dest_type,
+                                                                    load,
+                                                                    merged);
+   } else {
+      if (has_lod)
+         result = spirv_builder_emit_image_sample_explicit_lod(&ctx->builder,
+                                                               dest_type,
+                                                               load,
+                                                               coord, lod);
+      else
+         result = spirv_builder_emit_image_sample_implicit_lod(&ctx->builder,
+                                                               dest_type,
+                                                               load,
+                                                               coord);
+   }
    spirv_builder_emit_decoration(&ctx->builder, result,
                                  SpvDecorationRelaxedPrecision);
 
@@ -1204,6 +1211,40 @@ emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
    }
 }
 
+static void
+emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
+{
+   assert(deref->deref_type == nir_deref_type_var);
+
+   SpvStorageClass storage_class;
+   switch (deref->var->data.mode) {
+   case nir_var_shader_in:
+      storage_class = SpvStorageClassInput;
+      break;
+
+   case nir_var_shader_out:
+      storage_class = SpvStorageClassOutput;
+      break;
+
+   default:
+      unreachable("Unsupported nir_variable_mode\n");
+   }
+
+   struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
+   assert(he);
+
+   SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
+                                               storage_class,
+                                               get_glsl_type(ctx, deref->type));
+
+   SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
+                                                  ptr_type,
+                                                  (SpvId)(intptr_t)he->data,
+                                                  NULL, 0);
+   /* uint is a bit of a lie here, it's really just an opaque type */
+   store_dest_uint(ctx, &deref->dest, result);
+}
+
 static void
 emit_block(struct ntv_context *ctx, struct nir_block *block)
 {
@@ -1238,7 +1279,8 @@ emit_block(struct ntv_context *ctx, struct nir_block *block)
          unreachable("nir_instr_type_parallel_copy not supported");
          break;
       case nir_instr_type_deref:
-         unreachable("nir_instr_type_deref not supported");
+         /* these are handled in emit_{load,store}_deref */
+         /* emit_deref(ctx, nir_instr_as_deref(instr)); */
          break;
       }
    }
@@ -1408,6 +1450,9 @@ nir_to_spirv(struct nir_shader *s)
    SpvId entry_point = spirv_builder_new_id(&ctx.builder);
    spirv_builder_emit_name(&ctx.builder, entry_point, "main");
 
+   ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                      _mesa_key_pointer_equal);
+
    nir_foreach_variable(var, &s->inputs)
       emit_input(&ctx, var);
 
@@ -1437,10 +1482,11 @@ nir_to_spirv(struct nir_shader *s)
       goto fail;
    ctx.num_defs = entry->ssa_alloc;
 
-   ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
-                                            _mesa_key_pointer_equal);
-   if (!ctx.vars)
+   nir_index_local_regs(entry);
+   ctx.regs = malloc(sizeof(SpvId) * entry->reg_alloc);
+   if (!ctx.regs)
       goto fail;
+   ctx.num_regs = entry->reg_alloc;
 
    SpvId *block_ids = (SpvId *)malloc(sizeof(SpvId) * entry->num_blocks);
    if (!block_ids)
@@ -1452,6 +1498,19 @@ nir_to_spirv(struct nir_shader *s)
    ctx.block_ids = block_ids;
    ctx.num_blocks = entry->num_blocks;
 
+   /* emit a block only for the variable declarations */
+   start_block(&ctx, spirv_builder_new_id(&ctx.builder));
+   foreach_list_typed(nir_register, reg, node, &entry->registers) {
+      SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
+      SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
+                                                      SpvStorageClassFunction,
+                                                      type);
+      SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
+                                         SpvStorageClassFunction);
+
+      ctx.regs[reg->index] = var;
+   }
+
    emit_cf_list(&ctx, &entry->body);
 
    free(ctx.defs);
@@ -1479,6 +1538,9 @@ fail:
    if (ret)
       spirv_shader_delete(ret);
 
+   if (ctx.vars)
+      _mesa_hash_table_destroy(ctx.vars, NULL);
+
    return NULL;
 }