zink/spirv: handle reading registers
authorErik Faye-Lund <erik.faye-lund@collabora.com>
Thu, 21 Mar 2019 11:14:53 +0000 (12:14 +0100)
committerErik Faye-Lund <erik.faye-lund@collabora.com>
Mon, 28 Oct 2019 08:51:43 +0000 (08:51 +0000)
Acked-by: Jordan Justen <jordan.l.justen@intel.com>
src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 6b2c75fa96ff6da2c9009aa152934e2b4adf709b..c18a57cd5a5ad275b9485d69cad734b9ffe78672 100644 (file)
@@ -48,6 +48,8 @@ struct ntv_context {
 
    SpvId *defs;
    size_t num_defs;
+
+   struct hash_table *vars;
 };
 
 static SpvId
@@ -371,18 +373,50 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
 }
 
 static SpvId
-get_src_uint_ssa(struct ntv_context *ctx, nir_ssa_def *ssa)
+get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
 {
    assert(ssa->index < ctx->num_defs);
    assert(ctx->defs[ssa->index] != 0);
    return ctx->defs[ssa->index];
 }
 
+static SpvId
+get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
+{
+   struct hash_entry *he = _mesa_hash_table_search(ctx->vars, reg);
+   if (!he) {
+      SpvId type = get_uvec_type(ctx, reg->bit_size, reg->num_components);
+      SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
+                                                      SpvStorageClassFunction,
+                                                      type);
+
+      SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
+                                         SpvStorageClassFunction);
+
+      he = _mesa_hash_table_insert(ctx->vars, reg, (void *)(intptr_t)var);
+   }
+   return (SpvId)(intptr_t)he->data;
+}
+
+static SpvId
+get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
+{
+   assert(reg->reg);
+   assert(!reg->indirect);
+   assert(!reg->base_offset);
+
+   SpvId var = get_var_from_reg(ctx, reg->reg);
+   SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
+   return spirv_builder_emit_load(&ctx->builder, type, var);
+}
+
 static SpvId
 get_src_uint(struct ntv_context *ctx, nir_src *src)
 {
-   assert(src->is_ssa);
-   return get_src_uint_ssa(ctx, src->ssa);
+   if (src->is_ssa)
+      return get_src_uint_ssa(ctx, src->ssa);
+   else
+      return get_src_uint_reg(ctx, &src->reg);
 }
 
 static SpvId
@@ -504,11 +538,21 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
    return emit_unop(ctx, SpvOpBitcast, type, value);
 }
 
+static void
+store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
+{
+   SpvId var = get_var_from_reg(ctx, reg->reg);
+   assert(var);
+   spirv_builder_emit_store(&ctx->builder, var, result);
+}
+
 static void
 store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
 {
-   assert(dest->is_ssa);
-   store_ssa_def_uint(ctx, &dest->ssa, result);
+   if (dest->is_ssa)
+      store_ssa_def_uint(ctx, &dest->ssa, result);
+   else
+      store_reg_def(ctx, &dest->reg, result);
 }
 
 static void
@@ -1241,6 +1285,11 @@ nir_to_spirv(struct nir_shader *s)
       goto fail;
    ctx.num_defs = entry->ssa_alloc;
 
+   ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
+                                            _mesa_key_pointer_equal);
+   if (!ctx.vars)
+      goto fail;
+
    emit_cf_list(&ctx, &entry->body);
    free(ctx.defs);