From 32aea77cfef7c4304752f9b0fa385b4588d2a26d Mon Sep 17 00:00:00 2001 From: Erik Faye-Lund Date: Thu, 21 Mar 2019 12:14:53 +0100 Subject: [PATCH] zink/spirv: handle reading registers Acked-by: Jordan Justen --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 59 +++++++++++++++++-- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 6b2c75fa96f..c18a57cd5a5 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -48,6 +48,8 @@ struct ntv_context { SpvId *defs; size_t num_defs; + + struct hash_table *vars; }; static SpvId @@ -371,18 +373,50 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var) } static SpvId -get_src_uint_ssa(struct ntv_context *ctx, nir_ssa_def *ssa) +get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa) { assert(ssa->index < ctx->num_defs); assert(ctx->defs[ssa->index] != 0); return ctx->defs[ssa->index]; } +static SpvId +get_var_from_reg(struct ntv_context *ctx, nir_register *reg) +{ + struct hash_entry *he = _mesa_hash_table_search(ctx->vars, reg); + if (!he) { + SpvId type = get_uvec_type(ctx, reg->bit_size, reg->num_components); + SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, + SpvStorageClassFunction, + type); + + SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type, + SpvStorageClassFunction); + + he = _mesa_hash_table_insert(ctx->vars, reg, (void *)(intptr_t)var); + } + return (SpvId)(intptr_t)he->data; +} + +static SpvId +get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg) +{ + assert(reg->reg); + assert(!reg->indirect); + assert(!reg->base_offset); + + SpvId var = get_var_from_reg(ctx, reg->reg); + SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components); + return spirv_builder_emit_load(&ctx->builder, type, var); +} + static SpvId get_src_uint(struct ntv_context *ctx, nir_src *src) { - assert(src->is_ssa); - return get_src_uint_ssa(ctx, src->ssa); + if (src->is_ssa) + return get_src_uint_ssa(ctx, src->ssa); + else + return get_src_uint_reg(ctx, &src->reg); } static SpvId @@ -504,11 +538,21 @@ bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size, return emit_unop(ctx, SpvOpBitcast, type, value); } +static void +store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result) +{ + SpvId var = get_var_from_reg(ctx, reg->reg); + assert(var); + spirv_builder_emit_store(&ctx->builder, var, result); +} + static void store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result) { - assert(dest->is_ssa); - store_ssa_def_uint(ctx, &dest->ssa, result); + if (dest->is_ssa) + store_ssa_def_uint(ctx, &dest->ssa, result); + else + store_reg_def(ctx, &dest->reg, result); } static void @@ -1241,6 +1285,11 @@ nir_to_spirv(struct nir_shader *s) goto fail; ctx.num_defs = entry->ssa_alloc; + ctx.vars = _mesa_hash_table_create(NULL, _mesa_hash_pointer, + _mesa_key_pointer_equal); + if (!ctx.vars) + goto fail; + emit_cf_list(&ctx, &entry->body); free(ctx.defs); -- 2.30.2