From 562353e1f1246bfe0f70315083b51d26d60d994b Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Thu, 16 Apr 2020 19:07:06 +0100 Subject: [PATCH] aco: add helpers for splitting stores MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit split_store_data() splits a vector and p_as_uniforms it if needed. scan_write_mask()/advance_write_mask() are similar to u_bit_scan_consecutive_range(), but makes it easier to only clear part of the range and will also give ranges for zero'd bits. split_buffer_store() is a helper for splitting VMEM/SMEM stores. Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- .../compiler/aco_instruction_selection.cpp | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index ca376e5052b..411744e1c22 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -3698,6 +3698,108 @@ void ds_write_helper(isel_context *ctx, Operand m, Temp address, Temp data, unsi } } +void split_store_data(isel_context *ctx, RegType dst_type, unsigned count, Temp *dst, unsigned *offsets, Temp src) +{ + if (!count) + return; + + Builder bld(ctx->program, ctx->block); + + ASSERTED bool is_subdword = false; + for (unsigned i = 0; i < count; i++) + is_subdword |= offsets[i] % 4; + is_subdword |= (src.bytes() - offsets[count - 1]) % 4; + assert(!is_subdword || dst_type == RegType::vgpr); + + /* count == 1 fast path */ + if (count == 1) { + if (dst_type == RegType::sgpr) + dst[0] = bld.as_uniform(src); + else + dst[0] = as_vgpr(ctx, src); + return; + } + + for (unsigned i = 0; i < count - 1; i++) + dst[i] = bld.tmp(RegClass::get(dst_type, offsets[i + 1] - offsets[i])); + dst[count - 1] = bld.tmp(RegClass::get(dst_type, src.bytes() - offsets[count - 1])); + + if (is_subdword && src.type() == RegType::sgpr) { + src = as_vgpr(ctx, src); + } else { + /* use allocated_vec if possible */ + auto it = ctx->allocated_vec.find(src.id()); + if (it != ctx->allocated_vec.end()) { + unsigned total_size = 0; + for (unsigned i = 0; it->second[i].bytes() && (i < NIR_MAX_VEC_COMPONENTS); i++) + total_size += it->second[i].bytes(); + if (total_size != src.bytes()) + goto split; + + unsigned elem_size = it->second[0].bytes(); + + for (unsigned i = 0; i < count; i++) { + if (offsets[i] % elem_size || dst[i].bytes() % elem_size) + goto split; + } + + for (unsigned i = 0; i < count; i++) { + unsigned start_idx = offsets[i] / elem_size; + unsigned op_count = dst[i].bytes() / elem_size; + if (op_count == 1) { + if (dst_type == RegType::sgpr) + dst[i] = bld.as_uniform(it->second[start_idx]); + else + dst[i] = as_vgpr(ctx, it->second[start_idx]); + continue; + } + + aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, op_count, 1)}; + for (unsigned j = 0; j < op_count; j++) { + Temp tmp = it->second[start_idx + j]; + if (dst_type == RegType::sgpr) + tmp = bld.as_uniform(tmp); + vec->operands[j] = Operand(tmp); + } + vec->definitions[0] = Definition(dst[i]); + bld.insert(std::move(vec)); + } + return; + } + } + + if (dst_type == RegType::sgpr) + src = bld.as_uniform(src); + + split: + /* just split it */ + aco_ptr split{create_instruction(aco_opcode::p_split_vector, Format::PSEUDO, 1, count)}; + split->operands[0] = Operand(src); + for (unsigned i = 0; i < count; i++) + split->definitions[i] = Definition(dst[i]); + bld.insert(std::move(split)); +} + +bool scan_write_mask(uint32_t mask, uint32_t todo_mask, + int *start, int *count) +{ + unsigned start_elem = ffs(todo_mask) - 1; + bool skip = !(mask & (1 << start_elem)); + if (skip) + mask = ~mask & todo_mask; + + mask &= todo_mask; + + u_bit_scan_consecutive_range(&mask, start, count); + + return !skip; +} + +void advance_write_mask(uint32_t *todo_mask, int start, int count) +{ + *todo_mask &= ~u_bit_consecutive(0, count) << start; +} + void store_lds(isel_context *ctx, unsigned elem_size_bytes, Temp data, uint32_t wrmask, Temp address, unsigned base_offset, unsigned align) { @@ -3755,6 +3857,59 @@ unsigned calculate_lds_alignment(isel_context *ctx, unsigned const_offset) } +void split_buffer_store(isel_context *ctx, nir_intrinsic_instr *instr, bool smem, RegType dst_type, + Temp data, unsigned writemask, int swizzle_element_size, + unsigned *write_count, Temp *write_datas, unsigned *offsets) +{ + unsigned write_count_with_skips = 0; + bool skips[16]; + + /* determine how to split the data */ + unsigned todo = u_bit_consecutive(0, data.bytes()); + while (todo) { + int offset, bytes; + skips[write_count_with_skips] = !scan_write_mask(writemask, todo, &offset, &bytes); + offsets[write_count_with_skips] = offset; + if (skips[write_count_with_skips]) { + advance_write_mask(&todo, offset, bytes); + write_count_with_skips++; + continue; + } + + /* only supported sizes are 1, 2, 4, 8, 12 and 16 bytes and can't be + * larger than swizzle_element_size */ + bytes = MIN2(bytes, swizzle_element_size); + if (bytes % 4) + bytes = bytes > 4 ? bytes & ~0x3 : MIN2(bytes, 2); + + /* SMEM and GFX6 VMEM can't emit 12-byte stores */ + if ((ctx->program->chip_class == GFX6 || smem) && bytes == 12) + bytes = 8; + + /* dword or larger stores have to be dword-aligned */ + unsigned align_mul = instr ? nir_intrinsic_align_mul(instr) : 4; + unsigned align_offset = instr ? nir_intrinsic_align_mul(instr) : 0; + bool dword_aligned = (align_offset + offset) % 4 == 0 && align_mul % 4 == 0; + if (bytes >= 4 && !dword_aligned) + bytes = MIN2(bytes, 2); + + advance_write_mask(&todo, offset, bytes); + write_count_with_skips++; + } + + /* actually split data */ + split_store_data(ctx, dst_type, write_count_with_skips, write_datas, offsets, data); + + /* remove skips */ + for (unsigned i = 0; i < write_count_with_skips; i++) { + if (skips[i]) + continue; + write_datas[*write_count] = write_datas[i]; + offsets[*write_count] = offsets[i]; + (*write_count)++; + } +} + Temp create_vec_from_array(isel_context *ctx, Temp arr[], unsigned cnt, RegType reg_type, unsigned elem_size_bytes, unsigned split_cnt = 0u, Temp dst = Temp()) { -- 2.30.2