nir/algebraic: trivially enable existing 32-bit patterns for all bit sizes
[mesa.git] / src / compiler / nir / nir_opt_load_store_vectorize.c
index b0d3a7d0d90192ce11cfde94e13bc816d073b092..147b88c35941f77ba47258942e62a99f1d213bc2 100644 (file)
@@ -80,6 +80,8 @@ case nir_intrinsic_##op: {\
    STORE(0, deref, -1, -1, 0, 1)
    LOAD(nir_var_mem_shared, shared, -1, 0, -1)
    STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
+   LOAD(nir_var_mem_global, global, -1, 0, -1)
+   STORE(nir_var_mem_global, global, -1, 1, -1, 0)
    ATOMIC(nir_var_mem_ssbo, ssbo, add, 0, 1, -1, 2)
    ATOMIC(nir_var_mem_ssbo, ssbo, imin, 0, 1, -1, 2)
    ATOMIC(nir_var_mem_ssbo, ssbo, umin, 0, 1, -1, 2)
@@ -122,6 +124,20 @@ case nir_intrinsic_##op: {\
    ATOMIC(nir_var_mem_shared, shared, fmin, -1, 0, -1, 1)
    ATOMIC(nir_var_mem_shared, shared, fmax, -1, 0, -1, 1)
    ATOMIC(nir_var_mem_shared, shared, fcomp_swap, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, add, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, imin, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, umin, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, imax, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, umax, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, and, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, or, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, xor, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, exchange, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, comp_swap, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, fadd, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, fmin, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, fmax, -1, 0, -1, 1)
+   ATOMIC(nir_var_mem_global, global, fcomp_swap, -1, 0, -1, 1)
    default:
       break;
 #undef ATOMIC
@@ -157,7 +173,8 @@ struct entry {
       uint64_t offset; /* sign-extended */
       int64_t offset_signed;
    };
-   uint32_t best_align;
+   uint32_t align_mul;
+   uint32_t align_offset;
 
    nir_instr *instr;
    nir_intrinsic_instr *intrin;
@@ -171,6 +188,7 @@ struct entry {
 struct vectorize_ctx {
    nir_variable_mode modes;
    nir_should_vectorize_mem_func callback;
+   nir_variable_mode robust_modes;
    struct list_head entries[nir_num_variable_modes];
    struct hash_table *loads[nir_num_variable_modes];
    struct hash_table *stores[nir_num_variable_modes];
@@ -182,20 +200,19 @@ static uint32_t hash_entry_key(const void *key_)
     * the order of the hash table walk is deterministic */
    struct entry_key *key = (struct entry_key*)key_;
 
-   uint32_t hash = _mesa_fnv32_1a_offset_bias;
+   uint32_t hash = 0;
    if (key->resource)
-      hash = _mesa_fnv32_1a_accumulate(hash, key->resource->index);
+      hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
    if (key->var) {
-      hash = _mesa_fnv32_1a_accumulate(hash, key->var->index);
+      hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
       unsigned mode = key->var->data.mode;
-      hash = _mesa_fnv32_1a_accumulate(hash, mode);
+      hash = XXH32(&mode, sizeof(mode), hash);
    }
 
    for (unsigned i = 0; i < key->offset_def_count; i++)
-      hash = _mesa_fnv32_1a_accumulate(hash, key->offset_defs[i]->index);
+      hash = XXH32(&key->offset_defs[i]->index, sizeof(key->offset_defs[i]->index), hash);
 
-   hash = _mesa_fnv32_1a_accumulate_block(
-      hash, key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
+   hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
 
    return hash;
 }
@@ -315,17 +332,6 @@ type_scalar_size_bytes(const struct glsl_type *type)
    return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
 }
 
-static int
-get_array_stride(const struct glsl_type *type)
-{
-   unsigned explicit_stride = glsl_get_explicit_stride(type);
-   if ((glsl_type_is_matrix(type) &&
-        glsl_matrix_type_is_row_major(type)) ||
-       (glsl_type_is_vector(type) && explicit_stride == 0))
-      return type_scalar_size_bytes(type);
-   return explicit_stride;
-}
-
 static uint64_t
 mask_sign_extend(uint64_t val, unsigned bit_size)
 {
@@ -397,11 +403,7 @@ create_entry_key_from_deref(void *mem_ctx,
       case nir_deref_type_ptr_as_array: {
          assert(parent);
          nir_ssa_def *index = deref->arr.index.ssa;
-         uint32_t stride;
-         if (deref->deref_type == nir_deref_type_ptr_as_array)
-            stride = nir_deref_instr_ptr_as_array_stride(deref);
-         else
-            stride = get_array_stride(parent->type);
+         uint32_t stride = nir_deref_instr_array_stride(deref);
 
          nir_ssa_def *base = index;
          uint64_t offset = 0, base_mul = 1;
@@ -515,6 +517,46 @@ get_variable_mode(struct entry *entry)
    return entry->deref->mode;
 }
 
+static unsigned
+mode_to_index(nir_variable_mode mode)
+{
+   assert(util_bitcount(mode) == 1);
+
+   /* Globals and SSBOs should be tracked together */
+   if (mode == nir_var_mem_global)
+      mode = nir_var_mem_ssbo;
+
+   return ffs(mode) - 1;
+}
+
+static nir_variable_mode
+aliasing_modes(nir_variable_mode modes)
+{
+   /* Global and SSBO can alias */
+   if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
+      modes |= nir_var_mem_ssbo | nir_var_mem_global;
+   return modes;
+}
+
+static void
+calc_alignment(struct entry *entry)
+{
+   uint32_t align_mul = 31;
+   for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
+      if (entry->key->offset_defs_mul[i])
+         align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
+   }
+
+   entry->align_mul = 1u << (align_mul - 1);
+   bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
+   if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
+      entry->align_offset = entry->offset % entry->align_mul;
+   } else {
+      entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
+      entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
+   }
+}
+
 static struct entry *
 create_entry(struct vectorize_ctx *ctx,
              const struct intrinsic_info *info,
@@ -524,7 +566,6 @@ create_entry(struct vectorize_ctx *ctx,
    entry->intrin = intrin;
    entry->instr = &intrin->instr;
    entry->info = info;
-   entry->best_align = UINT32_MAX;
    entry->is_store = entry->info->value_src >= 0;
 
    if (entry->info->deref_src >= 0) {
@@ -537,7 +578,7 @@ create_entry(struct vectorize_ctx *ctx,
       nir_ssa_def *base = entry->info->base_src >= 0 ?
                           intrin->src[entry->info->base_src].ssa : NULL;
       uint64_t offset = 0;
-      if (nir_intrinsic_infos[intrin->intrinsic].index_map[NIR_INTRINSIC_BASE])
+      if (nir_intrinsic_has_base(intrin))
          offset += nir_intrinsic_base(intrin);
       entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
       entry->offset = offset;
@@ -549,7 +590,7 @@ create_entry(struct vectorize_ctx *ctx,
    if (entry->info->resource_src >= 0)
       entry->key->resource = intrin->src[entry->info->resource_src].ssa;
 
-   if (nir_intrinsic_infos[intrin->intrinsic].index_map[NIR_INTRINSIC_ACCESS])
+   if (nir_intrinsic_has_access(intrin))
       entry->access = nir_intrinsic_access(intrin);
    else if (entry->key->var)
       entry->access = entry->key->var->data.access;
@@ -561,6 +602,8 @@ create_entry(struct vectorize_ctx *ctx,
    if (get_variable_mode(entry) & restrict_modes)
       entry->access |= ACCESS_RESTRICT;
 
+   calc_alignment(entry);
+
    return entry;
 }
 
@@ -601,40 +644,6 @@ writemask_representable(unsigned write_mask, unsigned old_bit_size, unsigned new
    return true;
 }
 
-static uint64_t
-gcd(uint64_t a, uint64_t b)
-{
-   while (b) {
-      uint64_t old_b = b;
-      b = a % b;
-      a = old_b;
-   }
-   return a;
-}
-
-static uint32_t
-get_best_align(struct entry *entry)
-{
-   if (entry->best_align != UINT32_MAX)
-      return entry->best_align;
-
-   uint64_t best_align = entry->offset;
-   for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
-      if (!best_align)
-         best_align = entry->key->offset_defs_mul[i];
-      else if (entry->key->offset_defs_mul[i])
-         best_align = gcd(best_align, entry->key->offset_defs_mul[i]);
-   }
-
-   if (nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL])
-      best_align = MAX2(best_align, nir_intrinsic_align(entry->intrin));
-
-   /* ensure the result is a power of two that fits in a int32_t */
-   entry->best_align = gcd(best_align, 1u << 30);
-
-   return entry->best_align;
-}
-
 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
  * of "low" and "high". */
 static bool
@@ -658,7 +667,8 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
    if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
       return false;
 
-   if (!ctx->callback(get_best_align(low), new_bit_size, new_num_components,
+   uint32_t align = low->align_offset ? 1 << (ffs(low->align_offset) - 1) : low->align_mul;
+   if (!ctx->callback(align, new_bit_size, new_num_components,
                       high_offset, low->intrin, high->intrin))
       return false;
 
@@ -704,8 +714,8 @@ static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, i
    /* avoid adding another deref to the path */
    if (deref->deref_type == nir_deref_type_ptr_as_array &&
        nir_src_is_const(deref->arr.index) &&
-       offset % nir_deref_instr_ptr_as_array_stride(deref) == 0) {
-      unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
+       offset % nir_deref_instr_array_stride(deref) == 0) {
+      unsigned stride = nir_deref_instr_array_stride(deref);
       nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
                                           deref->dest.ssa.bit_size);
       return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
@@ -727,20 +737,6 @@ static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, i
       b, deref, nir_imm_intN_t(b, -offset, deref->dest.ssa.bit_size));
 }
 
-static bool update_align(struct entry *entry)
-{
-   bool has_align_index =
-      nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
-   if (has_align_index) {
-      unsigned align = get_best_align(entry);
-      if (align != nir_intrinsic_align(entry->intrin)) {
-         nir_intrinsic_set_align(entry->intrin, align, 0);
-         return true;
-      }
-   }
-   return false;
-}
-
 static void
 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
                 struct entry *low, struct entry *high,
@@ -793,7 +789,7 @@ vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
       b->cursor = nir_before_instr(first->instr);
 
       nir_ssa_def *new_base = first->intrin->src[info->base_src].ssa;
-      new_base = nir_iadd(b, new_base, nir_imm_int(b, -(high_start / 8u)));
+      new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u));
 
       nir_instr_rewrite_src(first->instr, &first->intrin->src[info->base_src],
                             nir_src_for_ssa(new_base));
@@ -813,17 +809,14 @@ vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
    }
 
    /* update base/align */
-   bool has_base_index =
-      nir_intrinsic_infos[first->intrin->intrinsic].index_map[NIR_INTRINSIC_BASE];
-
-   if (first != low && has_base_index)
+   if (first != low && nir_intrinsic_has_base(first->intrin))
       nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));
 
    first->key = low->key;
    first->offset = low->offset;
-   first->best_align = get_best_align(low);
 
-   update_align(first);
+   first->align_mul = low->align_mul;
+   first->align_offset = low->align_offset;
 
    nir_instr_remove(second->instr);
 }
@@ -898,17 +891,14 @@ vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
    }
 
    /* update base/align */
-   bool has_base_index =
-      nir_intrinsic_infos[second->intrin->intrinsic].index_map[NIR_INTRINSIC_BASE];
-
-   if (second != low && has_base_index)
+   if (second != low && nir_intrinsic_has_base(second->intrin))
       nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
 
    second->key = low->key;
    second->offset = low->offset;
-   second->best_align = get_best_align(low);
 
-   update_align(second);
+   second->align_mul = low->align_mul;
+   second->align_offset = low->align_offset;
 
    list_del(&first->head);
    nir_instr_remove(first->instr);
@@ -952,7 +942,8 @@ compare_entries(struct entry *a, struct entry *b)
 static bool
 may_alias(struct entry *a, struct entry *b)
 {
-   assert(get_variable_mode(a) == get_variable_mode(b));
+   assert(mode_to_index(get_variable_mode(a)) ==
+          mode_to_index(get_variable_mode(b)));
 
    /* if the resources/variables are definitively different and both have
     * ACCESS_RESTRICT, we can assume they do not alias. */
@@ -989,7 +980,7 @@ check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry
                nir_var_mem_push_const | nir_var_mem_ubo))
       return false;
 
-   unsigned mode_index = ffs(mode) - 1;
+   unsigned mode_index = mode_to_index(mode);
    if (first->is_store) {
       /* find first entry that aliases "first" */
       list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
@@ -1015,11 +1006,28 @@ check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry
    return false;
 }
 
+static bool
+check_for_robustness(struct vectorize_ctx *ctx, struct entry *low)
+{
+   nir_variable_mode mode = get_variable_mode(low);
+   if (mode & ctx->robust_modes) {
+      unsigned low_bit_size = get_bit_size(low);
+      unsigned low_size = low->intrin->num_components * low_bit_size;
+
+      /* don't attempt to vectorize accesses if the offset can overflow. */
+      /* TODO: handle indirect accesses. */
+      return low->offset_signed < 0 && low->offset_signed + low_size >= 0;
+   }
+
+   return false;
+}
+
 static bool
 is_strided_vector(const struct glsl_type *type)
 {
    if (glsl_type_is_vector(type)) {
-      return glsl_get_explicit_stride(type) !=
+      unsigned explicit_stride = glsl_get_explicit_stride(type);
+      return explicit_stride != 0 && explicit_stride !=
              type_scalar_size_bytes(glsl_get_array_element(type));
    } else {
       return false;
@@ -1031,9 +1039,16 @@ try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
               struct entry *low, struct entry *high,
               struct entry *first, struct entry *second)
 {
+   if (!(get_variable_mode(first) & ctx->modes) ||
+       !(get_variable_mode(second) & ctx->modes))
+      return false;
+
    if (check_for_aliasing(ctx, first, second))
       return false;
 
+   if (check_for_robustness(ctx, low))
+      return false;
+
    /* we can only vectorize non-volatile loads/stores of the same type and with
     * the same access */
    if (first->info != second->info || first->access != second->access ||
@@ -1091,6 +1106,18 @@ try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
    return true;
 }
 
+static bool
+update_align(struct entry *entry)
+{
+   if (nir_intrinsic_has_align_mul(entry->intrin) &&
+       (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
+        entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
+      nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
+      return true;
+   }
+   return false;
+}
+
 static bool
 vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
 {
@@ -1113,10 +1140,8 @@ vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct has
          struct entry *high = *util_dynarray_element(arr, struct entry *, i + 1);
 
          uint64_t diff = high->offset_signed - low->offset_signed;
-         if (diff > get_bit_size(low) / 8u * low->intrin->num_components) {
-            progress |= update_align(low);
+         if (diff > get_bit_size(low) / 8u * low->intrin->num_components)
             continue;
-         }
 
          struct entry *first = low->index < high->index ? low : high;
          struct entry *second = low->index < high->index ? high : low;
@@ -1125,13 +1150,13 @@ vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct has
             *util_dynarray_element(arr, struct entry *, i) = NULL;
             *util_dynarray_element(arr, struct entry *, i + 1) = low->is_store ? second : first;
             progress = true;
-         } else {
-            progress |= update_align(low);
          }
       }
 
-      struct entry *last = *util_dynarray_element(arr, struct entry *, i);
-      progress |= update_align(last);
+      util_dynarray_foreach(arr, struct entry *, elem) {
+         if (*elem)
+            progress |= update_align(*elem);
+      }
    }
 
    _mesa_hash_table_clear(ht, delete_entry_dynarray);
@@ -1163,8 +1188,13 @@ handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *imp
       case nir_intrinsic_memory_barrier_shared:
          modes = nir_var_mem_shared;
          break;
-      case nir_intrinsic_scoped_memory_barrier:
-         modes = nir_intrinsic_memory_modes(intrin);
+      case nir_intrinsic_scoped_barrier:
+         if (nir_intrinsic_memory_scope(intrin) == NIR_SCOPE_NONE)
+            break;
+
+         modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
+                                                       nir_var_mem_shared |
+                                                       nir_var_mem_global);
          acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
          release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
          switch (nir_intrinsic_memory_scope(intrin)) {
@@ -1188,6 +1218,13 @@ handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *imp
 
    while (modes) {
       unsigned mode_index = u_bit_scan(&modes);
+      if ((1 << mode_index) == nir_var_mem_global) {
+         /* Global should be rolled in with SSBO */
+         assert(list_is_empty(&ctx->entries[mode_index]));
+         assert(ctx->loads[mode_index] == NULL);
+         assert(ctx->stores[mode_index] == NULL);
+         continue;
+      }
 
       if (acquire)
          *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
@@ -1230,9 +1267,9 @@ process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *blo
       nir_variable_mode mode = info->mode;
       if (!mode)
          mode = nir_src_as_deref(intrin->src[info->deref_src])->mode;
-      if (!(mode & ctx->modes))
+      if (!(mode & aliasing_modes(ctx->modes)))
          continue;
-      unsigned mode_index = ffs(mode) - 1;
+      unsigned mode_index = mode_to_index(mode);
 
       /* create entry */
       struct entry *entry = create_entry(ctx, info, intrin);
@@ -1277,20 +1314,22 @@ process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *blo
 
 bool
 nir_opt_load_store_vectorize(nir_shader *shader, nir_variable_mode modes,
-                             nir_should_vectorize_mem_func callback)
+                             nir_should_vectorize_mem_func callback,
+                             nir_variable_mode robust_modes)
 {
    bool progress = false;
 
    struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
    ctx->modes = modes;
    ctx->callback = callback;
+   ctx->robust_modes = robust_modes;
 
-   nir_index_vars(shader, NULL, modes);
+   nir_shader_index_vars(shader, modes);
 
    nir_foreach_function(function, shader) {
       if (function->impl) {
          if (modes & nir_var_function_temp)
-            nir_index_vars(shader, function->impl, nir_var_function_temp);
+            nir_function_impl_index_vars(function->impl);
 
          nir_foreach_block(block, function->impl)
             progress |= process_block(function->impl, ctx, block);