From 868456fbf7418f318ea965c2ce151781dbe42e67 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 28 Jun 2017 16:31:06 -0700 Subject: [PATCH] nir/spirv: Implement OpPtrAccessChain for buffers Reviewed-by: Iago Toral Quiroga --- src/compiler/spirv/spirv_to_nir.c | 4 +++- src/compiler/spirv/vtn_private.h | 11 ++++++++--- src/compiler/spirv/vtn_variables.c | 23 +++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 89ebc5f674c..7038bd97ced 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -600,7 +600,8 @@ type_decoration_cb(struct vtn_builder *b, switch (dec->decoration) { case SpvDecorationArrayStride: assert(type->base_type == vtn_base_type_matrix || - type->base_type == vtn_base_type_array); + type->base_type == vtn_base_type_array || + type->base_type == vtn_base_type_pointer); type->stride = dec->literals[0]; break; case SpvDecorationBlock: @@ -3067,6 +3068,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpCopyMemory: case SpvOpCopyMemorySized: case SpvOpAccessChain: + case SpvOpPtrAccessChain: case SpvOpInBoundsAccessChain: case SpvOpArrayLength: vtn_handle_variables(b, opcode, w, count); diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 7cb503568fe..2f96c0904ac 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -220,15 +220,15 @@ struct vtn_type { /* Specifies the length of complex types. */ unsigned length; + /* for arrays, matrices and pointers, the array stride */ + unsigned stride; + union { /* Members for scalar, vector, and array-like types */ struct { /* for arrays, the vtn_type for the elements of the array */ struct vtn_type *array_element; - /* for arrays and matrices, the array stride */ - unsigned stride; - /* for matrices, whether the matrix is stored row-major */ bool row_major:1; @@ -308,6 +308,11 @@ struct vtn_access_link { struct vtn_access_chain { uint32_t length; + /** Whether or not to treat the base pointer as an array. This is only + * true if this access chain came from an OpPtrAccessChain. + */ + bool ptr_as_array; + /** Struct elements and array offsets. * * This is an array of 1 so that it can conveniently be created on the diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index 4f21fdd4cac..a9ba39247c8 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -67,6 +67,12 @@ vtn_access_chain_pointer_dereference(struct vtn_builder *b, vtn_access_chain_extend(b, base->chain, deref_chain->length); struct vtn_type *type = base->type; + /* OpPtrAccessChain is only allowed on things which support variable + * pointers. For everything else, the client is expected to just pass us + * the right access chain. + */ + assert(!deref_chain->ptr_as_array); + unsigned start = base->chain ? base->chain->length : 0; for (unsigned i = 0; i < deref_chain->length; i++) { chain->link[start + i] = deref_chain->link[i]; @@ -135,6 +141,21 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, struct vtn_type *type = base->type; unsigned idx = 0; + if (deref_chain->ptr_as_array) { + /* We need ptr_type for the stride */ + assert(base->ptr_type); + /* This must be a pointer to an actual element somewhere */ + assert(block_index && offset); + /* We need at least one element in the chain */ + assert(deref_chain->length >= 1); + + nir_ssa_def *elem_offset = + vtn_access_link_as_ssa(b, deref_chain->link[idx], + base->ptr_type->stride); + offset = nir_iadd(&b->nb, offset, elem_offset); + idx++; + } + if (!block_index) { assert(base->var); if (glsl_type_is_array(type->type)) { @@ -1699,8 +1720,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, } case SpvOpAccessChain: + case SpvOpPtrAccessChain: case SpvOpInBoundsAccessChain: { struct vtn_access_chain *chain = vtn_access_chain_create(b, count - 4); + chain->ptr_as_array = (opcode == SpvOpPtrAccessChain); unsigned idx = 0; for (int i = 4; i < count; i++) { -- 2.30.2