From d8eb6f2499c66e26d7312f59dc052b9f416cc486 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 27 May 2020 17:49:47 -0500 Subject: [PATCH] spirv: Add a vtn_push_nir_ssa helper This makes it easy to write a simple NIR SSA value Reviewed-by: Caio Marcelo de Oliveira Filho Part-of: --- src/compiler/spirv/spirv_to_nir.c | 45 +++++++++++++++--------------- src/compiler/spirv/vtn_alu.c | 6 ++-- src/compiler/spirv/vtn_glsl450.c | 21 ++++---------- src/compiler/spirv/vtn_opencl.c | 8 ++---- src/compiler/spirv/vtn_private.h | 3 ++ src/compiler/spirv/vtn_variables.c | 4 +-- 6 files changed, 37 insertions(+), 50 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 308102f4bf0..3940c2347f8 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -302,6 +302,21 @@ vtn_ssa_value(struct vtn_builder *b, uint32_t value_id) } } +struct vtn_value * +vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id, nir_ssa_def *def) +{ + /* Types for all SPIR-V SSA values are set as part of a pre-pass so the + * type will be valid by the time we get here. + */ + struct vtn_type *type = vtn_get_value_type(b, value_id); + vtn_fail_if(def->num_components != glsl_get_vector_elements(type->type) || + def->bit_size != glsl_get_bit_size(type->type), + "Mismatch between NIR and SPIR-V type."); + struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, type->type); + ssa->def = def; + return vtn_push_ssa(b, value_id, type, ssa); +} + static char * vtn_string_literal(struct vtn_builder *b, const uint32_t *words, unsigned word_count, unsigned *words_used) @@ -2672,11 +2687,9 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode, } } - struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, ret_type->type); - ssa->def = &instr->dest.ssa; - vtn_push_ssa(b, w[2], ret_type, ssa); - nir_builder_instr_insert(&b->nb, &instr->instr); + + vtn_push_nir_ssa(b, w[2], &instr->dest.ssa); } static void @@ -3026,9 +3039,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, if (nir_intrinsic_dest_components(intrin) != dest_components) result = nir_channels(&b->nb, result, (1 << dest_components) - 1); - struct vtn_value *val = - vtn_push_ssa(b, w[2], type, vtn_create_ssa_value(b, type->type)); - val->ssa->def = result; + vtn_push_nir_ssa(b, w[2], result); } else { nir_builder_instr_insert(&b->nb, &intrin->instr); } @@ -3318,10 +3329,7 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode, glsl_get_vector_elements(type->type), glsl_get_bit_size(type->type), NULL); - struct vtn_ssa_value *ssa = rzalloc(b, struct vtn_ssa_value); - ssa->def = &atomic->dest.ssa; - ssa->type = type->type; - vtn_push_ssa(b, w[2], type, ssa); + vtn_push_nir_ssa(b, w[2], &atomic->dest.ssa); } nir_builder_instr_insert(&b->nb, &atomic->instr); @@ -4777,9 +4785,7 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode, unreachable("Invalid ptr operation"); } - struct vtn_ssa_value *ssa_value = vtn_create_ssa_value(b, type); - ssa_value->def = def; - vtn_push_ssa(b, w[2], vtn_type, ssa_value); + vtn_push_nir_ssa(b, w[2], def); } static bool @@ -5134,11 +5140,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL); nir_builder_instr_insert(&b->nb, &intrin->instr); - struct vtn_type *res_type = vtn_get_type(b, w[1]); - struct vtn_ssa_value *val = vtn_create_ssa_value(b, res_type->type); - val->def = &intrin->dest.ssa; - - vtn_push_ssa(b, w[2], res_type, val); + vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa); break; } @@ -5179,10 +5181,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, result = nir_pack_64_2x32(&b->nb, &intrin->dest.ssa); } - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->type = type; - val->ssa = vtn_create_ssa_value(b, dest_type); - val->ssa->def = result; + vtn_push_nir_ssa(b, w[2], result); break; } diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 60e88144ceb..035bd857f19 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -699,7 +699,6 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) struct vtn_type *type = vtn_get_type(b, w[1]); struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]); struct nir_ssa_def *src = vtn_src->def; - struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type); vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type)); @@ -707,6 +706,7 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type), "Source and destination of OpBitcast must have the same " "total number of bits"); - val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type)); - vtn_push_ssa(b, w[2], type, val); + nir_ssa_def *val = + nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type)); + vtn_push_nir_ssa(b, w[2], val); } diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c index 061ffd09282..933cf1d407e 100644 --- a/src/compiler/spirv/vtn_glsl450.c +++ b/src/compiler/spirv/vtn_glsl450.c @@ -557,10 +557,6 @@ static void handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode, const uint32_t *w, unsigned count) { - const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type; - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->ssa = vtn_create_ssa_value(b, dest_type); - nir_intrinsic_op op; switch (opcode) { case GLSLstd450InterpolateAtCentroid: @@ -615,13 +611,11 @@ handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode, nir_builder_instr_insert(&b->nb, &intrin->instr); - if (vec_array_deref) { - assert(vec_deref); - val->ssa->def = nir_vector_extract(&b->nb, &intrin->dest.ssa, - vec_deref->arr.index.ssa); - } else { - val->ssa->def = &intrin->dest.ssa; - } + nir_ssa_def *def = &intrin->dest.ssa; + if (vec_array_deref) + def = nir_vector_extract(&b->nb, def, vec_deref->arr.index.ssa); + + vtn_push_nir_ssa(b, w[2], def); } bool @@ -630,10 +624,7 @@ vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode, { switch ((enum GLSLstd450)ext_opcode) { case GLSLstd450Determinant: { - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->ssa = rzalloc(b, struct vtn_ssa_value); - val->ssa->type = vtn_get_type(b, w[1])->type; - val->ssa->def = build_mat_det(b, vtn_ssa_value(b, w[5])); + vtn_push_nir_ssa(b, w[2], build_mat_det(b, vtn_ssa_value(b, w[5]))); break; } diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index 57d39ee9e64..ba3d00a7ce4 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -50,9 +50,7 @@ handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type); if (result) { - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->ssa = vtn_create_ssa_value(b, dest_type); - val->ssa->def = result; + vtn_push_nir_ssa(b, w[2], result); } else { vtn_assert(dest_type == glsl_void_type()); } @@ -256,9 +254,7 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, } } if (load) { - struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, dest_type); - ssa->def = nir_vec(&b->nb, ncomps, components); - vtn_push_ssa(b, w[2], type, ssa); + vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components)); } } diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 9d9e7883481..fd1ce33c6be 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -785,6 +785,9 @@ vtn_get_type(struct vtn_builder *b, uint32_t value_id) struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id); +struct vtn_value *vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id, + nir_ssa_def *def); + struct vtn_value *vtn_push_pointer(struct vtn_builder *b, uint32_t value_id, struct vtn_pointer *ptr); diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index fb3eef7de71..6bd6bd5873e 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -2735,9 +2735,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, nir_imm_int(&b->nb, 0u)), nir_imm_int(&b->nb, stride)); - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - val->ssa = vtn_create_ssa_value(b, glsl_uint_type()); - val->ssa->def = array_length; + vtn_push_nir_ssa(b, w[2], array_length); break; } -- 2.30.2