From f5deed138a0b4765438135367248f1d8f0649975 Mon Sep 17 00:00:00 2001
From: Jason Ekstrand
Date: Thu, 9 Apr 2020 17:09:10 -0500
Subject: [PATCH] spirv,nir: Move the SPIR-V vector insert code to NIR

This also makes spirv_to_nir a bit simpler because the new
nir_vector_insert helper automatically handles a constant component
selector like nir_vector_extract does.

Reviewed-by: Caio Marcelo de Oliveira Filho
Part-of:
---
 src/compiler/nir/nir_builder.h     | 55 ++++++++++++++++++++++++++++++
 src/compiler/spirv/spirv_to_nir.c  | 46 +++----------------------
 src/compiler/spirv/vtn_private.h   |  5 ---
 src/compiler/spirv/vtn_variables.c |  8 ++---
 4 files changed, 61 insertions(+), 53 deletions(-)

diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 481ea6382bf..52fcf9e2250 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -601,6 +601,61 @@ nir_vector_extract(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *c)
    }
 }
 
+/** Replaces the component of `vec` specified by `c` with `scalar` */
+static inline nir_ssa_def *
+nir_vector_insert_imm(nir_builder *b, nir_ssa_def *vec,
+                      nir_ssa_def *scalar, unsigned c)
+{
+   assert(scalar->num_components == 1);
+   assert(c < vec->num_components);
+
+   nir_op vec_op = nir_op_vec(vec->num_components);
+   nir_alu_instr *vec_instr = nir_alu_instr_create(b->shader, vec_op);
+
+   for (unsigned i = 0; i < vec->num_components; i++) {
+      if (i == c) {
+         vec_instr->src[i].src = nir_src_for_ssa(scalar);
+         vec_instr->src[i].swizzle[0] = 0;
+      } else {
+         vec_instr->src[i].src = nir_src_for_ssa(vec);
+         vec_instr->src[i].swizzle[0] = i;
+      }
+   }
+
+   return nir_builder_alu_instr_finish_and_insert(b, vec_instr);
+}
+
+/** Replaces the component of `vec` specified by `c` with `scalar` */
+static inline nir_ssa_def *
+nir_vector_insert(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *scalar,
+                  nir_ssa_def *c)
+{
+   assert(scalar->num_components == 1);
+   assert(c->num_components == 1);
+
+   nir_src c_src = nir_src_for_ssa(c);
+   if (nir_src_is_const(c_src)) {
+      uint64_t c_const = nir_src_as_uint(c_src);
+      if (c_const < vec->num_components)
+         return nir_vector_insert_imm(b, vec, scalar, c_const);
+      else
+         return vec;
+   } else {
+      nir_const_value per_comp_idx_const[NIR_MAX_VEC_COMPONENTS];
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         per_comp_idx_const[i] = nir_const_value_for_int(i, c->bit_size);
+      nir_ssa_def *per_comp_idx =
+         nir_build_imm(b, vec->num_components,
+                       c->bit_size, per_comp_idx_const);
+
+      /* nir_builder will automatically splat out scalars to vectors so an
+       * insert is as simple as "if I'm the channel, replace me with the
+       * scalar."
+       */
+      return nir_bcsel(b, nir_ieq(b, c, per_comp_idx), scalar, vec);
+   }
+}
+
 static inline nir_ssa_def *
 nir_i2i(nir_builder *build, nir_ssa_def *x, unsigned dest_bit_size)
 {
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 2cc8f2570c7..3cac23433f2 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -3311,44 +3311,6 @@ vtn_ssa_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
    return dest;
 }
 
-nir_ssa_def *
-vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
-                  unsigned index)
-{
-   nir_alu_instr *vec = create_vec(b, src->num_components,
-                                   src->bit_size);
-
-   for (unsigned i = 0; i < src->num_components; i++) {
-      if (i == index) {
-         vec->src[i].src = nir_src_for_ssa(insert);
-      } else {
-         vec->src[i].src = nir_src_for_ssa(src);
-         vec->src[i].swizzle[0] = i;
-      }
-   }
-
-   nir_builder_instr_insert(&b->nb, &vec->instr);
-
-   return &vec->dest.dest.ssa;
-}
-
-nir_ssa_def *
-vtn_vector_insert_dynamic(struct vtn_builder *b, nir_ssa_def *src,
-                          nir_ssa_def *insert, nir_ssa_def *index)
-{
-   nir_const_value per_comp_idx_const[NIR_MAX_VEC_COMPONENTS];
-   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-      per_comp_idx_const[i] = nir_const_value_for_int(i, index->bit_size);
-   nir_ssa_def *per_comp_idx =
-      nir_build_imm(&b->nb, src->num_components,
-                    index->bit_size, per_comp_idx_const);
-
-   /* nir_builder will automatically splat out scalars to vectors so an insert
-    * is as simple as "if I'm the channel, replace me with the scalar."
-    */
-   return nir_bcsel(&b->nb, nir_ieq(&b->nb, index, per_comp_idx), insert, src);
-}
-
 static nir_ssa_def *
 vtn_vector_shuffle(struct vtn_builder *b, unsigned num_components,
                    nir_ssa_def *src0, nir_ssa_def *src1,
@@ -3462,7 +3424,7 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
           * the component granularity. In that case, the last index will be
           * the index to insert the scalar into the vector.
           */
-         cur->def = vtn_vector_insert(b, cur->def, insert->def, indices[i]);
+         cur->def = nir_vector_insert_imm(&b->nb, cur->def, insert->def, indices[i]);
       } else {
          vtn_fail_if(indices[i] >= glsl_get_length(cur->type),
                      "All indices in an OpCompositeInsert must be in-bounds");
@@ -3516,9 +3478,9 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpVectorInsertDynamic:
-      ssa->def = vtn_vector_insert_dynamic(b, vtn_ssa_value(b, w[3])->def,
-                                           vtn_ssa_value(b, w[4])->def,
-                                           vtn_ssa_value(b, w[5])->def);
+      ssa->def = nir_vector_insert(&b->nb, vtn_ssa_value(b, w[3])->def,
+                                   vtn_ssa_value(b, w[4])->def,
+                                   vtn_ssa_value(b, w[5])->def);
       break;
 
    case SpvOpVectorShuffle:
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index 3624ac3fa7e..f4e6201febe 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -797,11 +797,6 @@ struct vtn_ssa_value *vtn_create_ssa_value(struct vtn_builder *b,
 struct vtn_ssa_value *vtn_ssa_transpose(struct vtn_builder *b,
                                         struct vtn_ssa_value *src);
 
-nir_ssa_def *vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src,
-                               nir_ssa_def *insert, unsigned index);
-nir_ssa_def *vtn_vector_insert_dynamic(struct vtn_builder *b, nir_ssa_def *src,
-                                       nir_ssa_def *insert, nir_ssa_def *index);
-
 nir_deref_instr *vtn_nir_deref(struct vtn_builder *b, uint32_t id);
 
 struct vtn_pointer *vtn_pointer_for_variable(struct vtn_builder *b,
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index 8bb00a5dd40..9dc2c755ca5 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -751,12 +751,8 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
       struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
       _vtn_local_load_store(b, true, dest_tail, val, access);
 
-      if (nir_src_is_const(dest->arr.index))
-         val->def = vtn_vector_insert(b, val->def, src->def,
-                                      nir_src_as_uint(dest->arr.index));
-      else
-         val->def = vtn_vector_insert_dynamic(b, val->def, src->def,
-                                              dest->arr.index.ssa);
+      val->def = nir_vector_insert(&b->nb, val->def, src->def,
+                                   dest->arr.index.ssa);
       _vtn_local_load_store(b, false, dest_tail, val, access);
    } else {
       _vtn_local_load_store(b, false, dest_tail, src, access);
-- 
2.30.2
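
Usage sketch (editor's illustration, not part of the patch): the hypothetical
helper below shows how nir_builder-based code might call the two new entry
points. nir_vector_insert_imm emits one vecN ALU instruction that re-swizzles
the old vector and drops the scalar into the chosen channel, while
nir_vector_insert forwards a constant selector to the _imm variant and
otherwise falls back to the bcsel-against-per-channel-indices pattern added
above. The function name example_rewrite is made up for the example.

#include "nir_builder.h"

/* Example only: replace channel 2 of a vec4 with 1.0, then replace a
 * dynamically selected channel with the same scalar.
 */
static nir_ssa_def *
example_rewrite(nir_builder *b, nir_ssa_def *vec4, nir_ssa_def *dyn_idx)
{
   nir_ssa_def *one = nir_imm_float(b, 1.0f);

   /* Constant selector: becomes a single vec4(v.x, v.y, 1.0, v.w). */
   nir_ssa_def *tmp = nir_vector_insert_imm(b, vec4, one, 2);

   /* Run-time selector: folds to the _imm variant if dyn_idx happens to be
    * constant, otherwise emits bcsel(ieq(dyn_idx, <0,1,2,3>), scalar, vec).
    */
   return nir_vector_insert(b, tmp, one, dyn_idx);
}

With the constant/dynamic split living in nir_builder.h, callers such as
vtn_local_store no longer need their own nir_src_is_const check.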