From abfe674c54bee6f8fdcae411b07db89c10b9d530 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Fri, 14 Dec 2018 11:06:07 -0600 Subject: [PATCH] spirv: Handle arbitrary bit sizes for deref array indices MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit We already had code in link_as_ssa to handle bit sizes; we just need to use it. While we're at it we clean up link_as_ssa a bit and add an explicit bit_size parameter in preparation for a day when we have derefs that aren't 32 bit. Cc: mesa-stable@lists.freedesktop.org Reviewed-by: Alejandro Piñeiro Reviewed-by: Caio Marcelo de Oliveira Filho --- src/compiler/spirv/vtn_private.h | 2 +- src/compiler/spirv/vtn_variables.c | 74 +++++++++++++++++------------- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index defcbb8e69d..35739255510 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -390,7 +390,7 @@ enum vtn_access_mode { struct vtn_access_link { enum vtn_access_mode mode; - uint32_t id; + int64_t id; }; struct vtn_access_chain { diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index d50e445778e..70bec69a052 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -65,6 +65,23 @@ vtn_pointer_is_external_block(struct vtn_builder *b, b->options->lower_workgroup_access_to_offsets); } +static nir_ssa_def * +vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link, + unsigned stride, unsigned bit_size) +{ + vtn_assert(stride > 0); + if (link.mode == vtn_access_mode_literal) { + return nir_imm_intN_t(&b->nb, link.id * stride, bit_size); + } else { + nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def; + if (ssa->bit_size != bit_size) + ssa = nir_i2i(&b->nb, ssa, bit_size); + if (stride != 1) + ssa = nir_imul_imm(&b->nb, ssa, stride); + return ssa; + } +} + /* Dereference the given base pointer by the access chain */ static struct vtn_pointer * vtn_nir_deref_pointer_dereference(struct vtn_builder *b, @@ -95,13 +112,8 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b, tail = nir_build_deref_struct(&b->nb, tail, idx); type = type->members[idx]; } else { - nir_ssa_def *index; - if (deref_chain->link[i].mode == vtn_access_mode_literal) { - index = nir_imm_int(&b->nb, deref_chain->link[i].id); - } else { - vtn_assert(deref_chain->link[i].mode == vtn_access_mode_id); - index = vtn_ssa_value(b, deref_chain->link[i].id)->def; - } + nir_ssa_def *index = vtn_access_link_as_ssa(b, deref_chain->link[i], 1, + tail->dest.ssa.bit_size); tail = nir_build_deref_array(&b->nb, tail, index); type = type->array_element; } @@ -119,26 +131,6 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b, return ptr; } -static nir_ssa_def * -vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link, - unsigned stride) -{ - vtn_assert(stride > 0); - if (link.mode == vtn_access_mode_literal) { - return nir_imm_int(&b->nb, link.id * stride); - } else if (stride == 1) { - nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def; - if (ssa->bit_size != 32) - ssa = nir_i2i32(&b->nb, ssa); - return ssa; - } else { - nir_ssa_def *src0 = vtn_ssa_value(b, link.id)->def; - if (src0->bit_size != 32) - src0 = nir_i2i32(&b->nb, src0); - return nir_imul_imm(&b->nb, src0, stride); - } -} - static nir_ssa_def * vtn_variable_resource_index(struct vtn_builder *b, struct vtn_variable *var, nir_ssa_def *desc_array_index) @@ -196,7 +188,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, if (glsl_type_is_array(type->type)) { if (deref_chain->length >= 1) { desc_arr_idx = - vtn_access_link_as_ssa(b, deref_chain->link[0], 1); + vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32); idx++; /* This consumes a level of type */ type = type->array_element; @@ -212,7 +204,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, } else if (deref_chain->ptr_as_array) { /* You can't have a zero-length OpPtrAccessChain */ vtn_assert(deref_chain->length >= 1); - desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1); + desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32); } else { /* We have a regular non-array SSBO. */ desc_arr_idx = NULL; @@ -244,7 +236,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, */ vtn_assert(deref_chain->length >= 1); nir_ssa_def *offset_index = - vtn_access_link_as_ssa(b, deref_chain->link[0], 1); + vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32); idx++; block_index = vtn_resource_reindex(b, block_index, offset_index); @@ -298,7 +290,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, nir_ssa_def *elem_offset = vtn_access_link_as_ssa(b, deref_chain->link[idx], - base->ptr_type->stride); + base->ptr_type->stride, offset->bit_size); offset = nir_iadd(&b->nb, offset, elem_offset); idx++; } @@ -319,7 +311,8 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b, case GLSL_TYPE_BOOL: case GLSL_TYPE_ARRAY: { nir_ssa_def *elem_offset = - vtn_access_link_as_ssa(b, deref_chain->link[idx], type->stride); + vtn_access_link_as_ssa(b, deref_chain->link[idx], + type->stride, offset->bit_size); offset = nir_iadd(&b->nb, offset, elem_offset); type = type->array_element; access |= type->access; @@ -1911,7 +1904,22 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode, struct vtn_value *link_val = vtn_untyped_value(b, w[i]); if (link_val->value_type == vtn_value_type_constant) { chain->link[idx].mode = vtn_access_mode_literal; - chain->link[idx].id = link_val->constant->values[0].u32[0]; + switch (glsl_get_bit_size(link_val->type->type)) { + case 8: + chain->link[idx].id = link_val->constant->values[0].i8[0]; + break; + case 16: + chain->link[idx].id = link_val->constant->values[0].i16[0]; + break; + case 32: + chain->link[idx].id = link_val->constant->values[0].i32[0]; + break; + case 64: + chain->link[idx].id = link_val->constant->values[0].i64[0]; + break; + default: + vtn_fail("Invalid bit size"); + } } else { chain->link[idx].mode = vtn_access_mode_id; chain->link[idx].id = w[i]; -- 2.30.2