intel/nir: Allow splitting a single load into up to 32 loads
[mesa.git] / src / intel / compiler / brw_nir_lower_mem_access_bit_sizes.c
index 0396f5ffcc04e2e13adf2f4d9eb0ea1a61d714a5..c26ea0bb7783c3633d22cdd1b322dc64e26ab9ce 100644 (file)
@@ -74,19 +74,23 @@ dup_mem_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
 }
 
 static bool
-lower_mem_load_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_mem_load_bit_size(nir_builder *b, nir_intrinsic_instr *intrin,
+                        const struct gen_device_info *devinfo)
 {
-   assert(intrin->dest.is_ssa);
-   if (intrin->dest.ssa.bit_size == 32)
-      return false;
+   const bool needs_scalar =
+      intrin->intrinsic == nir_intrinsic_load_scratch;
 
+   assert(intrin->dest.is_ssa);
    const unsigned bit_size = intrin->dest.ssa.bit_size;
    const unsigned num_components = intrin->dest.ssa.num_components;
    const unsigned bytes_read = num_components * (bit_size / 8);
    const unsigned align = nir_intrinsic_align(intrin);
 
-   nir_ssa_def *result[4] = { NULL, };
+   if (bit_size == 32 && align >= 32 &&
+       (!needs_scalar || intrin->num_components == 1))
+      return false;
 
+   nir_ssa_def *result;
    nir_src *offset_src = nir_get_io_offset_src(intrin);
    if (bit_size < 32 && nir_src_is_const(*offset_src)) {
       /* The offset is constant so we can use a 32-bit load and just shift it
@@ -102,21 +106,14 @@ lower_mem_load_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
 
       nir_ssa_def *load = dup_mem_intrinsic(b, intrin, NULL, -load_offset,
                                             load_comps32, 32, 4);
-      nir_ssa_def *unpacked[3];
-      for (unsigned i = 0; i < load_comps32; i++)
-         unpacked[i] = nir_unpack_bits(b, nir_channel(b, load, i), bit_size);
-
-      assert(load_offset % (bit_size / 8) == 0);
-      const unsigned divisor = 32 / bit_size;
-
-      for (unsigned i = 0; i < num_components; i++) {
-         unsigned load_i = i + load_offset / (bit_size / 8);
-         result[i] = nir_channel(b, unpacked[load_i / divisor],
-                                    load_i % divisor);
-      }
+      result = nir_extract_bits(b, &load, 1, load_offset * 8,
+                                num_components, bit_size);
    } else {
-      /* Otherwise, we have to break it into smaller loads */
-      unsigned res_idx = 0;
+      /* Otherwise, we have to break it into smaller loads.  We could end up
+       * with as many as 32 loads if we're loading a u64vec16 from scratch.
+       */
+      nir_ssa_def *loads[32];
+      unsigned num_loads = 0;
       int load_offset = 0;
       while (load_offset < bytes_read) {
          const unsigned bytes_left = bytes_read - load_offset;
@@ -128,34 +125,35 @@ lower_mem_load_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
          } else {
             assert(load_offset % 4 == 0);
             load_bit_size = 32;
-            load_comps = DIV_ROUND_UP(MIN2(bytes_left, 16), 4);
+            load_comps = needs_scalar ? 1 :
+                         DIV_ROUND_UP(MIN2(bytes_left, 16), 4);
          }
 
-         nir_ssa_def *load = dup_mem_intrinsic(b, intrin, NULL, load_offset,
-                                               load_comps, load_bit_size,
-                                               align);
-
-         nir_ssa_def *unpacked = nir_bitcast_vector(b, load, bit_size);
-         for (unsigned i = 0; i < unpacked->num_components; i++) {
-            if (res_idx < num_components)
-               result[res_idx++] = nir_channel(b, unpacked, i);
-         }
+         loads[num_loads++] = dup_mem_intrinsic(b, intrin, NULL, load_offset,
+                                                load_comps, load_bit_size,
+                                                align);
 
          load_offset += load_comps * (load_bit_size / 8);
       }
+      assert(num_loads <= ARRAY_SIZE(loads));
+      result = nir_extract_bits(b, loads, num_loads, 0,
+                                num_components, bit_size);
    }
 
-   nir_ssa_def *vec_result = nir_vec(b, result, num_components);
    nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
-                            nir_src_for_ssa(vec_result));
+                            nir_src_for_ssa(result));
    nir_instr_remove(&intrin->instr);
 
    return true;
 }
 
 static bool
-lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin,
+                         const struct gen_device_info *devinfo)
 {
+   const bool needs_scalar =
+      intrin->intrinsic == nir_intrinsic_store_scratch;
+
    assert(intrin->src[0].is_ssa);
    nir_ssa_def *value = intrin->src[0].ssa;
 
@@ -171,7 +169,9 @@ lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
    assert(writemask < (1 << num_components));
 
    if ((value->bit_size <= 32 && num_components == 1) ||
-       (value->bit_size == 32 && writemask == (1 << num_components) - 1))
+       (value->bit_size == 32 && align >= 32 &&
+        writemask == (1 << num_components) - 1 &&
+        !needs_scalar))
       return false;
 
    nir_src *offset_src = nir_get_io_offset_src(intrin);
@@ -179,20 +179,23 @@ lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
    const unsigned const_offset =
       offset_is_const ? nir_src_as_uint(*offset_src) : 0;
 
-   assert(num_components * (bit_size / 8) <= 32);
-   uint32_t byte_mask = 0;
+   const unsigned byte_size = bit_size / 8;
+   assert(byte_size <= sizeof(uint64_t));
+
+   BITSET_DECLARE(mask, NIR_MAX_VEC_COMPONENTS * sizeof(uint64_t));
+   BITSET_ZERO(mask);
+
    for (unsigned i = 0; i < num_components; i++) {
-      if (writemask & (1 << i))
-         byte_mask |= ((1 << (bit_size / 8)) - 1) << i * (bit_size / 8);
+      if (writemask & (1u << i))
+         BITSET_SET_RANGE(mask, i * byte_size, ((i + 1) * byte_size) - 1);
    }
 
-   while (byte_mask) {
-      const int start = ffs(byte_mask) - 1;
-      assert(start % (bit_size / 8) == 0);
+   while (BITSET_FFS(mask) != 0) {
+      const int start = BITSET_FFS(mask) - 1;
 
       int end;
       for (end = start + 1; end < bytes_written; end++) {
-         if (!(byte_mask & (1 << end)))
+         if (!(BITSET_TEST(mask, end)))
             break;
       }
       /* The size of the current contiguous chunk in bytes */
@@ -206,7 +209,7 @@ lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
       if (chunk_bytes >= 4 && is_dword_aligned) {
          store_align = MAX2(align, 4);
          store_bit_size = 32;
-         store_comps = MIN2(chunk_bytes, 16) / 4;
+         store_comps = needs_scalar ? 1 : MIN2(chunk_bytes, 16) / 4;
       } else {
          store_align = align;
          store_comps = 1;
@@ -215,24 +218,15 @@ lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
          if (store_bit_size == 24)
             store_bit_size = 16;
       }
-
       const unsigned store_bytes = store_comps * (store_bit_size / 8);
-      assert(store_bytes % (bit_size / 8) == 0);
-      const unsigned store_first_src_comp = start / (bit_size / 8);
-      const unsigned store_src_comps = store_bytes / (bit_size / 8);
-      assert(store_first_src_comp + store_src_comps <= num_components);
-
-      unsigned src_swiz[4] = { 0, };
-      for (unsigned i = 0; i < store_src_comps; i++)
-         src_swiz[i] = store_first_src_comp + i;
-      nir_ssa_def *store_value =
-         nir_swizzle(b, value, src_swiz, store_src_comps, false);
-      nir_ssa_def *packed = nir_bitcast_vector(b, store_value, store_bit_size);
+
+      nir_ssa_def *packed = nir_extract_bits(b, &value, 1, start * 8,
+                                             store_comps, store_bit_size);
 
       dup_mem_intrinsic(b, intrin, packed, start,
                         store_comps, store_bit_size, store_align);
 
-      byte_mask &= ~(((1u << store_bytes) - 1) << start);
+      BITSET_CLEAR_RANGE(mask, start, (start + store_bytes - 1));
    }
 
    nir_instr_remove(&intrin->instr);
@@ -241,7 +235,8 @@ lower_mem_store_bit_size(nir_builder *b, nir_intrinsic_instr *intrin)
 }
 
 static bool
-lower_mem_access_bit_sizes_impl(nir_function_impl *impl)
+lower_mem_access_bit_sizes_impl(nir_function_impl *impl,
+                                const struct gen_device_info *devinfo)
 {
    bool progress = false;
 
@@ -260,14 +255,16 @@ lower_mem_access_bit_sizes_impl(nir_function_impl *impl)
          case nir_intrinsic_load_global:
          case nir_intrinsic_load_ssbo:
          case nir_intrinsic_load_shared:
-            if (lower_mem_load_bit_size(&b, intrin))
+         case nir_intrinsic_load_scratch:
+            if (lower_mem_load_bit_size(&b, intrin, devinfo))
                progress = true;
             break;
 
          case nir_intrinsic_store_global:
          case nir_intrinsic_store_ssbo:
          case nir_intrinsic_store_shared:
-            if (lower_mem_store_bit_size(&b, intrin))
+         case nir_intrinsic_store_scratch:
+            if (lower_mem_store_bit_size(&b, intrin, devinfo))
                progress = true;
             break;
 
@@ -280,6 +277,8 @@ lower_mem_access_bit_sizes_impl(nir_function_impl *impl)
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_block_index |
                                   nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
    }
 
    return progress;
@@ -300,14 +299,21 @@ lower_mem_access_bit_sizes_impl(nir_function_impl *impl)
  * all nir load/store intrinsics into a series of either 8 or 32-bit
  * load/store intrinsics with a number of components that we can directly
  * handle in hardware and with a trivial write-mask.
+ *
+ * For scratch access, additional consideration has to be made due to the way
+ * that we swizzle the memory addresses to achieve decent cache locality.  In
+ * particular, even though untyped surface read/write messages exist and work,
+ * we can't use them to load multiple components in a single SEND.  For more
+ * detail on the scratch swizzle, see fs_visitor::swizzle_nir_scratch_addr.
  */
 bool
-brw_nir_lower_mem_access_bit_sizes(nir_shader *shader)
+brw_nir_lower_mem_access_bit_sizes(nir_shader *shader,
+                                   const struct gen_device_info *devinfo)
 {
    bool progress = false;
 
    nir_foreach_function(func, shader) {
-      if (func->impl && lower_mem_access_bit_sizes_impl(func->impl))
+      if (func->impl && lower_mem_access_bit_sizes_impl(func->impl, devinfo))
          progress = true;
    }