aco: add emit_load helper
authorRhys Perry <pendingchaos02@gmail.com>
Thu, 16 Apr 2020 18:51:02 +0000 (19:51 +0100)
committerMarge Bot <eric+marge@anholt.net>
Fri, 24 Apr 2020 18:52:54 +0000 (18:52 +0000)
This helper is used for recombining split loads, passing the result to
p_as_uniform, aligning the offset down and shifting it right if needed and
handling large constant offsets.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4639>

src/amd/compiler/aco_instruction_selection.cpp

index c86d69baae9ee0ca5466821d676612bb07af6c33..de28601744b72f5c8cb085b5ff61353fe89e221b 100644 (file)
@@ -3051,6 +3051,291 @@ uint32_t widen_mask(uint32_t mask, unsigned multiplier)
    return new_mask;
 }
 
+void byte_align_vector(isel_context *ctx, Temp vec, Operand offset, Temp dst)
+{
+   Builder bld(ctx->program, ctx->block);
+   if (offset.isTemp()) {
+      Temp tmp[3] = {vec, vec, vec};
+
+      if (vec.size() == 3) {
+         tmp[0] = bld.tmp(v1), tmp[1] = bld.tmp(v1), tmp[2] = bld.tmp(v1);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(tmp[0]), Definition(tmp[1]), Definition(tmp[2]), vec);
+      } else if (vec.size() == 2) {
+         tmp[0] = bld.tmp(v1), tmp[1] = bld.tmp(v1), tmp[2] = tmp[1];
+         bld.pseudo(aco_opcode::p_split_vector, Definition(tmp[0]), Definition(tmp[1]), vec);
+      }
+      for (unsigned i = 0; i < dst.size(); i++)
+         tmp[i] = bld.vop3(aco_opcode::v_alignbyte_b32, bld.def(v1), tmp[i + 1], tmp[i], offset);
+
+      vec = tmp[0];
+      if (dst.size() == 2)
+         vec = bld.pseudo(aco_opcode::p_create_vector, bld.def(v2), tmp[0], tmp[1]);
+
+      offset = Operand(0u);
+   }
+
+   if (vec.bytes() == dst.bytes() && offset.constantValue() == 0)
+      bld.copy(Definition(dst), vec);
+   else
+      trim_subdword_vector(ctx, vec, dst, vec.bytes(), ((1 << dst.bytes()) - 1) << offset.constantValue());
+}
+
+struct LoadEmitInfo {
+   Operand offset;
+   Temp dst;
+   unsigned num_components;
+   unsigned component_size;
+   Temp resource = Temp(0, s1);
+   unsigned component_stride = 0;
+   unsigned const_offset = 0;
+   unsigned align_mul = 0;
+   unsigned align_offset = 0;
+
+   bool glc = false;
+   unsigned swizzle_component_size = 0;
+   barrier_interaction barrier = barrier_none;
+   bool can_reorder = true;
+   Temp soffset = Temp(0, s1);
+};
+
+using LoadCallback = Temp(*)(
+   Builder& bld, const LoadEmitInfo* info, Temp offset, unsigned bytes_needed,
+   unsigned align, unsigned const_offset, Temp dst_hint);
+
+template <LoadCallback callback, bool byte_align_loads, bool supports_8bit_16bit_loads, unsigned max_const_offset_plus_one>
+void emit_load(isel_context *ctx, Builder& bld, const LoadEmitInfo *info)
+{
+   unsigned load_size = info->num_components * info->component_size;
+   unsigned component_size = info->component_size;
+
+   unsigned num_vals = 0;
+   Temp vals[info->dst.bytes()];
+
+   unsigned const_offset = info->const_offset;
+
+   unsigned align_mul = info->align_mul ? info->align_mul : component_size;
+   unsigned align_offset = (info->align_offset + const_offset) % align_mul;
+
+   unsigned bytes_read = 0;
+   while (bytes_read < load_size) {
+      unsigned bytes_needed = load_size - bytes_read;
+
+      /* add buffer for unaligned loads */
+      int byte_align = align_mul % 4 == 0 ? align_offset % 4 : -1;
+
+      if (byte_align) {
+         if ((bytes_needed > 2 || !supports_8bit_16bit_loads) && byte_align_loads) {
+            if (info->component_stride) {
+               assert(supports_8bit_16bit_loads && "unimplemented");
+               bytes_needed = 2;
+               byte_align = 0;
+            } else {
+               bytes_needed += byte_align == -1 ? 4 - info->align_mul : byte_align;
+               bytes_needed = align(bytes_needed, 4);
+            }
+         } else {
+            byte_align = 0;
+         }
+      }
+
+      if (info->swizzle_component_size)
+         bytes_needed = MIN2(bytes_needed, info->swizzle_component_size);
+      if (info->component_stride)
+         bytes_needed = MIN2(bytes_needed, info->component_size);
+
+      bool need_to_align_offset = byte_align && (align_mul % 4 || align_offset % 4);
+
+      /* reduce constant offset */
+      Operand offset = info->offset;
+      unsigned reduced_const_offset = const_offset;
+      bool remove_const_offset_completely = need_to_align_offset;
+      if (const_offset && (remove_const_offset_completely || const_offset >= max_const_offset_plus_one)) {
+         unsigned to_add = const_offset;
+         if (remove_const_offset_completely) {
+            reduced_const_offset = 0;
+         } else {
+            to_add = const_offset / max_const_offset_plus_one * max_const_offset_plus_one;
+            reduced_const_offset %= max_const_offset_plus_one;
+         }
+         Temp offset_tmp = offset.isTemp() ? offset.getTemp() : Temp();
+         if (offset.isConstant()) {
+            offset = Operand(offset.constantValue() + to_add);
+         } else if (offset_tmp.regClass() == s1) {
+            offset = bld.sop2(aco_opcode::s_add_i32, bld.def(s1), bld.def(s1, scc),
+                              offset_tmp, Operand(to_add));
+         } else if (offset_tmp.regClass() == v1) {
+            offset = bld.vadd32(bld.def(v1), offset_tmp, Operand(to_add));
+         } else {
+            Temp lo = bld.tmp(offset_tmp.type(), 1);
+            Temp hi = bld.tmp(offset_tmp.type(), 1);
+            bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), offset_tmp);
+
+            if (offset_tmp.regClass() == s2) {
+               Temp carry = bld.tmp(s1);
+               lo = bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.scc(Definition(carry)), lo, Operand(to_add));
+               hi = bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), hi, carry);
+               offset = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), lo, hi);
+            } else {
+               Temp new_lo = bld.tmp(v1);
+               Temp carry = bld.vadd32(Definition(new_lo), lo, Operand(to_add), true).def(1).getTemp();
+               hi = bld.vadd32(bld.def(v1), hi, Operand(0u), false, carry);
+               offset = bld.pseudo(aco_opcode::p_create_vector, bld.def(v2), new_lo, hi);
+            }
+         }
+      }
+
+      /* align offset down if needed */
+      Operand aligned_offset = offset;
+      if (need_to_align_offset) {
+         Temp offset_tmp = offset.isTemp() ? offset.getTemp() : Temp();
+         if (offset.isConstant()) {
+            aligned_offset = Operand(offset.constantValue() & 0xfffffffcu);
+         } else if (offset_tmp.regClass() == s1) {
+            aligned_offset = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0xfffffffcu), offset_tmp);
+         } else if (offset_tmp.regClass() == s2) {
+            aligned_offset = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), Operand((uint64_t)0xfffffffffffffffcllu), offset_tmp);
+         } else if (offset_tmp.regClass() == v1) {
+            aligned_offset = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xfffffffcu), offset_tmp);
+         } else if (offset_tmp.regClass() == v2) {
+            Temp hi = bld.tmp(v1), lo = bld.tmp(v1);
+            bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), offset_tmp);
+            lo = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xfffffffcu), lo);
+            aligned_offset = bld.pseudo(aco_opcode::p_create_vector, bld.def(v2), lo, hi);
+         }
+      }
+      Temp aligned_offset_tmp = aligned_offset.isTemp() ? aligned_offset.getTemp() :
+                                bld.copy(bld.def(s1), aligned_offset);
+
+      unsigned align = align_offset ? 1 << (ffs(align_offset) - 1) : align_mul;
+      Temp val = callback(bld, info, aligned_offset_tmp, bytes_needed, align,
+                          reduced_const_offset, byte_align ? Temp() : info->dst);
+
+      /* shift result right if needed */
+      if (byte_align) {
+         Operand align((uint32_t)byte_align);
+         if (byte_align == -1) {
+            if (offset.isConstant())
+               align = Operand(offset.constantValue() % 4u);
+            else if (offset.size() == 2)
+               align = Operand(emit_extract_vector(ctx, offset.getTemp(), 0, RegClass(offset.getTemp().type(), 1)));
+            else
+               align = offset;
+         }
+
+         if (align.isTemp() || align.constantValue()) {
+            assert(val.bytes() >= load_size && "unimplemented");
+            Temp new_val = bld.tmp(RegClass::get(val.type(), load_size));
+            if (val.type() == RegType::sgpr)
+               byte_align_scalar(ctx, val, align, new_val);
+            else
+               byte_align_vector(ctx, val, align, new_val);
+            val = new_val;
+         }
+      }
+
+      /* add result to list and advance */
+      if (info->component_stride) {
+         assert(val.bytes() == info->component_size && "unimplemented");
+         const_offset += info->component_stride;
+         align_offset = (align_offset + info->component_stride) % align_mul;
+      } else {
+         const_offset += val.bytes();
+         align_offset = (align_offset + val.bytes()) % align_mul;
+      }
+      bytes_read += val.bytes();
+      vals[num_vals++] = val;
+   }
+
+   /* the callback wrote directly to dst */
+   if (vals[0] == info->dst) {
+      assert(num_vals == 1);
+      emit_split_vector(ctx, info->dst, info->num_components);
+      return;
+   }
+
+   /* create array of components */
+   unsigned components_split = 0;
+   std::array<Temp, NIR_MAX_VEC_COMPONENTS> allocated_vec;
+   bool has_vgprs = false;
+   for (unsigned i = 0; i < num_vals;) {
+      Temp tmp[num_vals];
+      unsigned num_tmps = 0;
+      unsigned tmp_size = 0;
+      RegType reg_type = RegType::sgpr;
+      while ((!tmp_size || (tmp_size % component_size)) && i < num_vals) {
+         if (vals[i].type() == RegType::vgpr)
+            reg_type = RegType::vgpr;
+         tmp_size += vals[i].bytes();
+         tmp[num_tmps++] = vals[i++];
+      }
+      if (num_tmps > 1) {
+         aco_ptr<Pseudo_instruction> vec{create_instruction<Pseudo_instruction>(
+            aco_opcode::p_create_vector, Format::PSEUDO, num_tmps, 1)};
+         for (unsigned i = 0; i < num_vals; i++)
+            vec->operands[i] = Operand(tmp[i]);
+         tmp[0] = bld.tmp(RegClass::get(reg_type, tmp_size));
+         vec->definitions[0] = Definition(tmp[0]);
+         bld.insert(std::move(vec));
+      }
+
+      if (tmp[0].bytes() % component_size) {
+         /* trim tmp[0] */
+         assert(i == num_vals);
+         RegClass new_rc = RegClass::get(reg_type, tmp[0].bytes() / component_size * component_size);
+         tmp[0] = bld.pseudo(aco_opcode::p_extract_vector, bld.def(new_rc), tmp[0], Operand(0u));
+      }
+
+      RegClass elem_rc = RegClass::get(reg_type, component_size);
+
+      unsigned start = components_split;
+
+      if (tmp_size == elem_rc.bytes()) {
+         allocated_vec[components_split++] = tmp[0];
+      } else {
+         assert(tmp_size % elem_rc.bytes() == 0);
+         aco_ptr<Pseudo_instruction> split{create_instruction<Pseudo_instruction>(
+            aco_opcode::p_split_vector, Format::PSEUDO, 1, tmp_size / elem_rc.bytes())};
+         for (unsigned i = 0; i < split->definitions.size(); i++) {
+            Temp component = bld.tmp(elem_rc);
+            allocated_vec[components_split++] = component;
+            split->definitions[i] = Definition(component);
+         }
+         split->operands[0] = Operand(tmp[0]);
+         bld.insert(std::move(split));
+      }
+
+      /* try to p_as_uniform early so we can create more optimizable code and
+       * also update allocated_vec */
+      for (unsigned j = start; j < components_split; j++) {
+         if (allocated_vec[j].bytes() % 4 == 0 && info->dst.type() == RegType::sgpr)
+            allocated_vec[j] = bld.as_uniform(allocated_vec[j]);
+         has_vgprs |= allocated_vec[j].type() == RegType::vgpr;
+      }
+   }
+
+   /* concatenate components and p_as_uniform() result if needed */
+   if (info->dst.type() == RegType::vgpr || !has_vgprs)
+      ctx->allocated_vec.emplace(info->dst.id(), allocated_vec);
+
+   int padding_bytes = MAX2((int)info->dst.bytes() - int(allocated_vec[0].bytes() * info->num_components), 0);
+
+   aco_ptr<Pseudo_instruction> vec{create_instruction<Pseudo_instruction>(
+      aco_opcode::p_create_vector, Format::PSEUDO, info->num_components + !!padding_bytes, 1)};
+   for (unsigned i = 0; i < info->num_components; i++)
+      vec->operands[i] = Operand(allocated_vec[i]);
+   if (padding_bytes)
+      vec->operands[info->num_components] = Operand(RegClass::get(RegType::vgpr, padding_bytes));
+   if (info->dst.type() == RegType::sgpr && has_vgprs) {
+      Temp tmp = bld.tmp(RegType::vgpr, info->dst.size());
+      vec->definitions[0] = Definition(tmp);
+      bld.insert(std::move(vec));
+      bld.pseudo(aco_opcode::p_as_uniform, Definition(info->dst), tmp);
+   } else {
+      vec->definitions[0] = Definition(info->dst);
+      bld.insert(std::move(vec));
+   }
+}
+
 Operand load_lds_size_m0(isel_context *ctx)
 {
    /* TODO: m0 does not need to be initialized on GFX9+ */