From: Jason Ekstrand Date: Tue, 22 Jan 2019 00:20:46 +0000 (-0600) Subject: spirv: Implement OpConvertPtrToU and OpConvertUToPtr X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=fb282a68bc46a1e28eaedb2670be241401f2b9da;p=mesa.git spirv: Implement OpConvertPtrToU and OpConvertUToPtr This only implements the actual opcodes and does not implement support for using them with specialization constants. Reviewed-by: Bas Nieuwenhuizen --- diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 8c85eac5875..022a90eff7e 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4031,6 +4031,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpPtrAccessChain: case SpvOpInBoundsAccessChain: case SpvOpArrayLength: + case SpvOpConvertPtrToU: + case SpvOpConvertUToPtr: vtn_handle_variables(b, opcode, w, count); break; @@ -4187,8 +4189,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpSConvert: case SpvOpFConvert: case SpvOpQuantizeToF16: - case SpvOpConvertPtrToU: - case SpvOpConvertUToPtr: case SpvOpPtrCastToGeneric: case SpvOpGenericCastToPtr: case SpvOpBitcast: diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 17f067133dd..4f7e2a15af9 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -2150,6 +2150,44 @@ vtn_assert_types_equal(struct vtn_builder *b, SpvOp opcode, glsl_get_type_name(src_type->type)); } +static nir_ssa_def * +nir_shrink_zero_pad_vec(nir_builder *b, nir_ssa_def *val, + unsigned num_components) +{ + if (val->num_components == num_components) + return val; + + nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS]; + for (unsigned i = 0; i < num_components; i++) { + if (i < val->num_components) + comps[i] = nir_channel(b, val, i); + else + comps[i] = nir_imm_intN_t(b, 0, val->bit_size); + } + return nir_vec(b, comps, num_components); +} + +static nir_ssa_def * +nir_sloppy_bitcast(nir_builder *b, nir_ssa_def *val, + const struct glsl_type *type) +{ + const unsigned num_components = glsl_get_vector_elements(type); + const unsigned bit_size = glsl_get_bit_size(type); + + /* First, zero-pad to ensure that the value is big enough that when we + * bit-cast it, we don't loose anything. + */ + if (val->bit_size < bit_size) { + const unsigned src_num_components_needed = + vtn_align_u32(val->num_components, bit_size / val->bit_size); + val = nir_shrink_zero_pad_vec(b, val, src_num_components_needed); + } + + val = nir_bitcast_vector(b, val, bit_size); + + return nir_shrink_zero_pad_vec(b, val, num_components); +} + void vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -2352,6 +2390,41 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, break; } + case SpvOpConvertPtrToU: { + struct vtn_value *u_val = vtn_push_value(b, w[2], vtn_value_type_ssa); + + vtn_fail_if(u_val->type->base_type != vtn_base_type_vector && + u_val->type->base_type != vtn_base_type_scalar, + "OpConvertPtrToU can only be used to cast to a vector or " + "scalar type"); + + /* The pointer will be converted to an SSA value automatically */ + nir_ssa_def *ptr_ssa = vtn_ssa_value(b, w[3])->def; + + u_val->ssa = vtn_create_ssa_value(b, u_val->type->type); + u_val->ssa->def = nir_sloppy_bitcast(&b->nb, ptr_ssa, u_val->type->type); + break; + } + + case SpvOpConvertUToPtr: { + struct vtn_value *ptr_val = + vtn_push_value(b, w[2], vtn_value_type_pointer); + struct vtn_value *u_val = vtn_value(b, w[3], vtn_value_type_ssa); + + vtn_fail_if(ptr_val->type->type == NULL, + "OpConvertUToPtr can only be used on physical pointers"); + + vtn_fail_if(u_val->type->base_type != vtn_base_type_vector && + u_val->type->base_type != vtn_base_type_scalar, + "OpConvertUToPtr can only be used to cast from a vector or " + "scalar type"); + + nir_ssa_def *ptr_ssa = nir_sloppy_bitcast(&b->nb, u_val->ssa->def, + ptr_val->type->type); + ptr_val->pointer = vtn_pointer_from_ssa(b, ptr_ssa, ptr_val->type); + break; + } + case SpvOpCopyMemorySized: default: vtn_fail("Unhandled opcode");