From: Jason Ekstrand Date: Wed, 27 May 2020 23:28:18 +0000 (-0500) Subject: spirv: Add a vtn_get_nir_ssa helper X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=37ab3234805bc3fa34c12065bde0bcf37fdbdd89;p=mesa.git spirv: Add a vtn_get_nir_ssa helper Reviewed-by: Caio Marcelo de Oliveira Filho Part-of: --- diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 3940c2347f8..e0a92bcf76a 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -302,6 +302,15 @@ vtn_ssa_value(struct vtn_builder *b, uint32_t value_id) } } +nir_ssa_def * +vtn_get_nir_ssa(struct vtn_builder *b, uint32_t value_id) +{ + struct vtn_ssa_value *ssa = vtn_ssa_value(b, value_id); + vtn_fail_if(!glsl_type_is_vector_or_scalar(ssa->type), + "Expected a vector or scalar type"); + return ssa->def; +} + struct vtn_value * vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id, nir_ssa_def *def) { @@ -2219,7 +2228,7 @@ static nir_tex_src vtn_tex_src(struct vtn_builder *b, unsigned index, nir_tex_src_type type) { nir_tex_src src; - src.src = nir_src_for_ssa(vtn_ssa_value(b, index)->def); + src.src = nir_src_for_ssa(vtn_get_nir_ssa(b, index)); src.src_type = type; return src; } @@ -2468,7 +2477,7 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode, if (is_array && texop != nir_texop_lod) coord_components++; - coord = vtn_ssa_value(b, w[idx++])->def; + coord = vtn_get_nir_ssa(b, w[idx++]); p->src = nir_src_for_ssa(nir_channels(&b->nb, coord, (1 << coord_components) - 1)); p->src_type = nir_tex_src_coord; @@ -2707,13 +2716,13 @@ fill_common_atomic_sources(struct vtn_builder *b, SpvOp opcode, case SpvOpAtomicISub: src[0] = - nir_src_for_ssa(nir_ineg(&b->nb, vtn_ssa_value(b, w[6])->def)); + nir_src_for_ssa(nir_ineg(&b->nb, vtn_get_nir_ssa(b, w[6]))); break; case SpvOpAtomicCompareExchange: case SpvOpAtomicCompareExchangeWeak: - src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[8])->def); - src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[7])->def); + src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[8])); + src[1] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[7])); break; case SpvOpAtomicExchange: @@ -2726,7 +2735,7 @@ fill_common_atomic_sources(struct vtn_builder *b, SpvOp opcode, case SpvOpAtomicOr: case SpvOpAtomicXor: case SpvOpAtomicFAddEXT: - src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[6])->def); + src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[6])); break; default: @@ -2737,15 +2746,14 @@ fill_common_atomic_sources(struct vtn_builder *b, SpvOp opcode, static nir_ssa_def * get_image_coord(struct vtn_builder *b, uint32_t value) { - struct vtn_ssa_value *coord = vtn_ssa_value(b, value); + nir_ssa_def *coord = vtn_get_nir_ssa(b, value); /* The image_load_store intrinsics assume a 4-dim coordinate */ - unsigned dim = glsl_get_vector_elements(coord->type); unsigned swizzle[4]; for (unsigned i = 0; i < 4; i++) - swizzle[i] = MIN2(i, dim - 1); + swizzle[i] = MIN2(i, coord->num_components - 1); - return nir_swizzle(&b->nb, coord->def, swizzle, 4); + return nir_swizzle(&b->nb, coord, swizzle, 4); } static nir_ssa_def * @@ -2772,7 +2780,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, val->image->image = vtn_value(b, w[3], vtn_value_type_pointer)->pointer; val->image->coord = get_image_coord(b, w[4]); - val->image->sample = vtn_ssa_value(b, w[5])->def; + val->image->sample = vtn_get_nir_ssa(b, w[5]); val->image->lod = nir_imm_int(&b->nb, 0); return; } @@ -2831,7 +2839,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, if (operands & SpvImageOperandsSampleMask) { uint32_t arg = image_operand_arg(b, w, count, 5, SpvImageOperandsSampleMask); - image.sample = vtn_ssa_value(b, w[arg])->def; + image.sample = vtn_get_nir_ssa(b, w[arg]); } else { image.sample = nir_ssa_undef(&b->nb, 1, 32); } @@ -2848,7 +2856,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, if (operands & SpvImageOperandsLodMask) { uint32_t arg = image_operand_arg(b, w, count, 5, SpvImageOperandsLodMask); - image.lod = vtn_ssa_value(b, w[arg])->def; + image.lod = vtn_get_nir_ssa(b, w[arg]); } else { image.lod = nir_imm_int(&b->nb, 0); } @@ -2871,7 +2879,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, if (operands & SpvImageOperandsSampleMask) { uint32_t arg = image_operand_arg(b, w, count, 4, SpvImageOperandsSampleMask); - image.sample = vtn_ssa_value(b, w[arg])->def; + image.sample = vtn_get_nir_ssa(b, w[arg]); } else { image.sample = nir_ssa_undef(&b->nb, 1, 32); } @@ -2888,7 +2896,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, if (operands & SpvImageOperandsLodMask) { uint32_t arg = image_operand_arg(b, w, count, 4, SpvImageOperandsLodMask); - image.lod = vtn_ssa_value(b, w[arg])->def; + image.lod = vtn_get_nir_ssa(b, w[arg]); } else { image.lod = nir_imm_int(&b->nb, 0); } @@ -2977,7 +2985,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode, case SpvOpAtomicStore: case SpvOpImageWrite: { const uint32_t value_id = opcode == SpvOpAtomicStore ? w[4] : w[3]; - nir_ssa_def *value = vtn_ssa_value(b, value_id)->def; + nir_ssa_def *value = vtn_get_nir_ssa(b, value_id); /* nir_intrinsic_image_deref_store always takes a vec4 value */ assert(op == nir_intrinsic_image_deref_store); intrin->num_components = 4; @@ -3239,7 +3247,7 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode, atomic->num_components = glsl_get_vector_elements(ptr->type->type); nir_intrinsic_set_write_mask(atomic, (1 << atomic->num_components) - 1); nir_intrinsic_set_align(atomic, 4, 0); - atomic->src[src++] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + atomic->src[src++] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4])); if (ptr->mode == vtn_variable_mode_ssbo) atomic->src[src++] = nir_src_for_ssa(index); atomic->src[src++] = nir_src_for_ssa(offset); @@ -3284,7 +3292,7 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode, case SpvOpAtomicStore: atomic->num_components = glsl_get_vector_elements(deref_type); nir_intrinsic_set_write_mask(atomic, (1 << atomic->num_components) - 1); - atomic->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + atomic->src[1] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4])); break; case SpvOpAtomicExchange: @@ -3542,20 +3550,20 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpVectorExtractDynamic: - ssa->def = nir_vector_extract(&b->nb, vtn_ssa_value(b, w[3])->def, - vtn_ssa_value(b, w[4])->def); + ssa->def = nir_vector_extract(&b->nb, vtn_get_nir_ssa(b, w[3]), + vtn_get_nir_ssa(b, w[4])); break; case SpvOpVectorInsertDynamic: - ssa->def = nir_vector_insert(&b->nb, vtn_ssa_value(b, w[3])->def, - vtn_ssa_value(b, w[4])->def, - vtn_ssa_value(b, w[5])->def); + ssa->def = nir_vector_insert(&b->nb, vtn_get_nir_ssa(b, w[3]), + vtn_get_nir_ssa(b, w[4]), + vtn_get_nir_ssa(b, w[5])); break; case SpvOpVectorShuffle: ssa->def = vtn_vector_shuffle(b, glsl_get_vector_elements(type->type), - vtn_ssa_value(b, w[3])->def, - vtn_ssa_value(b, w[4])->def, + vtn_get_nir_ssa(b, w[3]), + vtn_get_nir_ssa(b, w[4]), w + 5); break; @@ -3565,7 +3573,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode, if (glsl_type_is_vector_or_scalar(type->type)) { nir_ssa_def *srcs[NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < elems; i++) - srcs[i] = vtn_ssa_value(b, w[3 + i])->def; + srcs[i] = vtn_get_nir_ssa(b, w[3 + i]); ssa->def = vtn_vector_construct(b, glsl_get_vector_elements(type->type), elems, srcs); @@ -4762,8 +4770,8 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode, &elem_size, &elem_align); def = nir_build_addr_isub(&b->nb, - vtn_ssa_value(b, w[3])->def, - vtn_ssa_value(b, w[4])->def, + vtn_get_nir_ssa(b, w[3]), + vtn_get_nir_ssa(b, w[4]), addr_format); def = nir_idiv(&b->nb, def, nir_imm_intN_t(&b->nb, elem_size, def->bit_size)); def = nir_i2i(&b->nb, def, glsl_get_bit_size(type)); @@ -4773,8 +4781,8 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode, case SpvOpPtrEqual: case SpvOpPtrNotEqual: { def = nir_build_addr_ieq(&b->nb, - vtn_ssa_value(b, w[3])->def, - vtn_ssa_value(b, w[4])->def, + vtn_get_nir_ssa(b, w[3]), + vtn_get_nir_ssa(b, w[4]), addr_format); if (opcode == SpvOpPtrNotEqual) def = nir_inot(&b->nb, def); diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 035bd857f19..e15d67d688d 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -697,10 +697,7 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) */ struct vtn_type *type = vtn_get_type(b, w[1]); - struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]); - struct nir_ssa_def *src = vtn_src->def; - - vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type)); + struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]); vtn_fail_if(src->num_components * src->bit_size != glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type), diff --git a/src/compiler/spirv/vtn_amd.c b/src/compiler/spirv/vtn_amd.c index 32420fccae0..4ba8193b532 100644 --- a/src/compiler/spirv/vtn_amd.c +++ b/src/compiler/spirv/vtn_amd.c @@ -33,10 +33,10 @@ vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode, nir_ssa_def *def; switch ((enum GcnShaderAMD)ext_opcode) { case CubeFaceIndexAMD: - def = nir_cube_face_index(&b->nb, vtn_ssa_value(b, w[5])->def); + def = nir_cube_face_index(&b->nb, vtn_get_nir_ssa(b, w[5])); break; case CubeFaceCoordAMD: - def = nir_cube_face_coord(&b->nb, vtn_ssa_value(b, w[5])->def); + def = nir_cube_face_coord(&b->nb, vtn_get_nir_ssa(b, w[5])); break; case TimeAMD: { nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->nb.shader, @@ -90,7 +90,7 @@ vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode intrin->num_components = intrin->dest.ssa.num_components; for (unsigned i = 0; i < num_args; i++) - intrin->src[i] = nir_src_for_ssa(vtn_ssa_value(b, w[i + 5])->def); + intrin->src[i] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[i + 5])); if (intrin->intrinsic == nir_intrinsic_quad_swizzle_amd) { struct vtn_value *val = vtn_value(b, w[6], vtn_value_type_constant); @@ -124,7 +124,7 @@ vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ex assert(num_inputs == 3); nir_ssa_def *src[3] = { NULL, }; for (unsigned i = 0; i < num_inputs; i++) - src[i] = vtn_ssa_value(b, w[i + 5])->def; + src[i] = vtn_get_nir_ssa(b, w[i + 5]); nir_ssa_def *def; switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) { @@ -198,7 +198,7 @@ vtn_handle_amd_shader_explicit_vertex_parameter_instruction(struct vtn_builder * deref = nir_deref_instr_parent(deref); } intrin->src[0] = nir_src_for_ssa(&deref->dest.ssa); - intrin->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[6])->def); + intrin->src[1] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[6])); intrin->num_components = glsl_get_vector_elements(deref->type); nir_ssa_dest_init(&intrin->instr, &intrin->dest, diff --git a/src/compiler/spirv/vtn_cfg.c b/src/compiler/spirv/vtn_cfg.c index 23a0983e4f4..e97362a9dca 100644 --- a/src/compiler/spirv/vtn_cfg.c +++ b/src/compiler/spirv/vtn_cfg.c @@ -1176,7 +1176,7 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list, bool sw_break = false; nir_if *nif = - nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def); + nir_push_if(&b->nb, vtn_get_nir_ssa(b, vtn_if->condition)); nif->control = vtn_selection_control(b, vtn_if); @@ -1263,7 +1263,7 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list, nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall"); nir_store_var(&b->nb, fall_var, nir_imm_false(&b->nb), 1); - nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def; + nir_ssa_def *sel = vtn_get_nir_ssa(b, vtn_switch->selector); /* Now we can walk the list of cases and actually emit code */ vtn_foreach_cf_node(case_node, &vtn_switch->cases) { diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c index 933cf1d407e..a6334149353 100644 --- a/src/compiler/spirv/vtn_glsl450.c +++ b/src/compiler/spirv/vtn_glsl450.c @@ -322,7 +322,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint, if (vtn_untyped_value(b, w[i + 5])->value_type == vtn_value_type_pointer) continue; - src[i] = vtn_ssa_value(b, w[i + 5])->def; + src[i] = vtn_get_nir_ssa(b, w[i + 5]); } switch (entrypoint) { @@ -598,7 +598,7 @@ handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode, break; case GLSLstd450InterpolateAtSample: case GLSLstd450InterpolateAtOffset: - intrin->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[6])->def); + intrin->src[1] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[6])); break; default: vtn_fail("Invalid opcode"); diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index ba3d00a7ce4..6018184fe21 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -45,7 +45,7 @@ handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, nir_ssa_def *srcs[3] = { NULL }; vtn_assert(num_srcs <= ARRAY_SIZE(srcs)); for (unsigned i = 0; i < num_srcs; i++) { - srcs[i] = vtn_ssa_value(b, w[i + 5])->def; + srcs[i] = vtn_get_nir_ssa(b, w[i + 5]); } nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type); @@ -230,7 +230,7 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, const struct glsl_type *dest_type = type->type; unsigned components = glsl_get_vector_elements(dest_type); - nir_ssa_def *offset = vtn_ssa_value(b, w[5 + a])->def; + nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]); struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer); struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS]; diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index fd1ce33c6be..1e90316aec2 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -785,6 +785,7 @@ vtn_get_type(struct vtn_builder *b, uint32_t value_id) struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id); +nir_ssa_def *vtn_get_nir_ssa(struct vtn_builder *b, uint32_t value_id); struct vtn_value *vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id, nir_ssa_def *def); diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index e8c3db79b9d..9ebc7c14f09 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -94,7 +94,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, "OpGroupNonUniformBallot must return a uvec4"); nir_intrinsic_instr *ballot = nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot); - ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def); + ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3])); nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL); ballot->num_components = 4; nir_builder_instr_insert(&b->nb, &ballot->instr); @@ -111,7 +111,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot_bitfield_extract); - intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def); + intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4])); intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb)); nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, @@ -131,8 +131,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpGroupNonUniformBallotBitExtract: op = nir_intrinsic_ballot_bitfield_extract; - src0 = vtn_ssa_value(b, w[4])->def; - src1 = vtn_ssa_value(b, w[5])->def; + src0 = vtn_get_nir_ssa(b, w[4]); + src1 = vtn_get_nir_ssa(b, w[5]); break; case SpvOpGroupNonUniformBallotBitCount: switch ((SpvGroupOperation)w[4]) { @@ -148,15 +148,15 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, default: unreachable("Invalid group operation"); } - src0 = vtn_ssa_value(b, w[5])->def; + src0 = vtn_get_nir_ssa(b, w[5]); break; case SpvOpGroupNonUniformBallotFindLSB: op = nir_intrinsic_ballot_find_lsb; - src0 = vtn_ssa_value(b, w[4])->def; + src0 = vtn_get_nir_ssa(b, w[4]); break; case SpvOpGroupNonUniformBallotFindMSB: op = nir_intrinsic_ballot_find_msb; - src0 = vtn_ssa_value(b, w[4])->def; + src0 = vtn_get_nir_ssa(b, w[4]); break; default: unreachable("Unhandled opcode"); @@ -188,7 +188,7 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpSubgroupReadInvocationKHR: vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, val->ssa, vtn_ssa_value(b, w[3]), - vtn_ssa_value(b, w[4])->def, 0, 0); + vtn_get_nir_ssa(b, w[4]), 0, 0); break; case SpvOpGroupNonUniformAll: @@ -246,9 +246,9 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll || opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny || opcode == SpvOpGroupNonUniformAllEqual) { - src0 = vtn_ssa_value(b, w[4])->def; + src0 = vtn_get_nir_ssa(b, w[4]); } else { - src0 = vtn_ssa_value(b, w[3])->def; + src0 = vtn_get_nir_ssa(b, w[3]); } nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->nb.shader, op); @@ -285,14 +285,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, unreachable("Invalid opcode"); } vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), - vtn_ssa_value(b, w[5])->def, 0, 0); + vtn_get_nir_ssa(b, w[5]), 0, 0); break; } case SpvOpGroupNonUniformQuadBroadcast: vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast, val->ssa, vtn_ssa_value(b, w[4]), - vtn_ssa_value(b, w[5])->def, 0, 0); + vtn_get_nir_ssa(b, w[5]), 0, 0); break; case SpvOpGroupNonUniformQuadSwap: {