From 00af1128a9d555a09e18eb3fd8ce1829d94509b7 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Fri, 29 May 2020 14:40:12 -0500 Subject: [PATCH] spirv/subgroups: Refactor to use vtn_push_ssa Reviewed-by: Caio Marcelo de Oliveira Filho Part-of: --- src/compiler/spirv/vtn_subgroup.c | 78 +++++++++++++++++-------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index aa8ddff5654..8e4c3f2ba92 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -23,10 +23,9 @@ #include "vtn_private.h" -static void +static struct vtn_ssa_value * vtn_build_subgroup_instr(struct vtn_builder *b, nir_intrinsic_op nir_op, - struct vtn_ssa_value *dst, struct vtn_ssa_value *src0, nir_ssa_def *index, unsigned const_idx0, @@ -39,14 +38,16 @@ vtn_build_subgroup_instr(struct vtn_builder *b, if (index && index->bit_size != 32) index = nir_u2u32(&b->nb, index); + struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type); + vtn_assert(dst->type == src0->type); if (!glsl_type_is_vector_or_scalar(dst->type)) { for (unsigned i = 0; i < glsl_get_length(dst->type); i++) { - vtn_build_subgroup_instr(b, nir_op, dst->elems[i], - src0->elems[i], index, - const_idx0, const_idx1); + dst->elems[0] = + vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index, + const_idx0, const_idx1); } - return; + return dst; } nir_intrinsic_instr *intrin = @@ -65,33 +66,33 @@ vtn_build_subgroup_instr(struct vtn_builder *b, nir_builder_instr_insert(&b->nb, &intrin->instr); dst->def = &intrin->dest.ssa; + + return dst; } void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) { - struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); - - val->ssa = vtn_create_ssa_value(b, val->type->type); + struct vtn_type *dest_type = vtn_get_type(b, w[1]); switch (opcode) { case SpvOpGroupNonUniformElect: { - vtn_fail_if(val->type->type != glsl_bool_type(), + vtn_fail_if(dest_type->type != glsl_bool_type(), "OpGroupNonUniformElect must return a Bool"); nir_intrinsic_instr *elect = nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect); nir_ssa_dest_init_for_type(&elect->instr, &elect->dest, - val->type->type, NULL); + dest_type->type, NULL); nir_builder_instr_insert(&b->nb, &elect->instr); - val->ssa->def = &elect->dest.ssa; + vtn_push_nir_ssa(b, w[2], &elect->dest.ssa); break; } case SpvOpGroupNonUniformBallot: case SpvOpSubgroupBallotKHR: { bool has_scope = (opcode != SpvOpSubgroupBallotKHR); - vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4), + vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4), "OpGroupNonUniformBallot must return a uvec4"); nir_intrinsic_instr *ballot = nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot); @@ -99,7 +100,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL); ballot->num_components = 4; nir_builder_instr_insert(&b->nb, &ballot->instr); - val->ssa->def = &ballot->dest.ssa; + vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa); break; } @@ -116,10 +117,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb)); nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, - val->type->type, NULL); + dest_type->type, NULL); nir_builder_instr_insert(&b->nb, &intrin->instr); - val->ssa->def = &intrin->dest.ssa; + vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa); break; } @@ -171,19 +172,20 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, intrin->src[1] = nir_src_for_ssa(src1); nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, - val->type->type, NULL); + dest_type->type, NULL); nir_builder_instr_insert(&b->nb, &intrin->instr); - val->ssa->def = &intrin->dest.ssa; + vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa); break; } case SpvOpGroupNonUniformBroadcastFirst: case SpvOpSubgroupFirstInvocationKHR: { bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR); - vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation, - val->ssa, vtn_ssa_value(b, w[3 + has_scope]), - NULL, 0, 0); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation, + vtn_ssa_value(b, w[3 + has_scope]), + NULL, 0, 0)); break; } @@ -191,9 +193,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpGroupBroadcast: case SpvOpSubgroupReadInvocationKHR: { bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR); - vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, - val->ssa, vtn_ssa_value(b, w[3 + has_scope]), - vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, + vtn_ssa_value(b, w[3 + has_scope]), + vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0)); break; } @@ -205,7 +208,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpSubgroupAllKHR: case SpvOpSubgroupAnyKHR: case SpvOpSubgroupAllEqualKHR: { - vtn_fail_if(val->type->type != glsl_bool_type(), + vtn_fail_if(dest_type->type != glsl_bool_type(), "OpGroupNonUniform(All|Any|AllEqual) must return a bool"); nir_intrinsic_op op; switch (opcode) { @@ -262,10 +265,10 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, intrin->num_components = src0->num_components; intrin->src[0] = nir_src_for_ssa(src0); nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, - val->type->type, NULL); + dest_type->type, NULL); nir_builder_instr_insert(&b->nb, &intrin->instr); - val->ssa->def = &intrin->dest.ssa; + vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa); break; } @@ -290,15 +293,17 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, default: unreachable("Invalid opcode"); } - vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), - vtn_get_nir_ssa(b, w[5]), 0, 0); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), + vtn_get_nir_ssa(b, w[5]), 0, 0)); break; } case SpvOpGroupNonUniformQuadBroadcast: - vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast, - val->ssa, vtn_ssa_value(b, w[4]), - vtn_get_nir_ssa(b, w[5]), 0, 0); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast, + vtn_ssa_value(b, w[4]), + vtn_get_nir_ssa(b, w[5]), 0, 0)); break; case SpvOpGroupNonUniformQuadSwap: { @@ -317,8 +322,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, default: vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap"); } - vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), - NULL, 0, 0); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0)); break; } @@ -439,8 +444,9 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, unreachable("Invalid group operation"); } - vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]), - NULL, reduction_op, cluster_size); + vtn_push_ssa_value(b, w[2], + vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL, + reduction_op, cluster_size)); break; } -- 2.30.2