nir: Handle all array stride cases in nir_deref_instr_array_stride
authorJason Ekstrand <jason@jlekstrand.net>
Thu, 27 Aug 2020 16:59:54 +0000 (11:59 -0500)
committerMarge Bot <eric+marge@anholt.net>
Thu, 3 Sep 2020 18:02:50 +0000 (18:02 +0000)
This renames it to drop the ptr_as and makes it handle all of the stride
cases.  There's a bit of a tricky bit in here around Booleans but we
currently use 32-bit for those always.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Boris Brezillon <boris.brezillon@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6472>

src/amd/llvm/ac_nir_to_llvm.c
src/compiler/nir/nir.h
src/compiler/nir/nir_deref.c
src/compiler/nir/nir_lower_io.c
src/compiler/nir/nir_opt_load_store_vectorize.c
src/compiler/nir/nir_repair_ssa.c
src/compiler/nir/nir_to_lcssa.c

index 13f12d350dcb80e5a37f6629216e058eeb066275..4f47c4d323da28b42dc91c8a7dc1fb1a5d4fc4b7 100644 (file)
@@ -5167,7 +5167,7 @@ static void visit_deref(struct ac_nir_context *ctx,
                break;
        case nir_deref_type_ptr_as_array:
                if (instr->mode == nir_var_mem_global) {
-                       unsigned stride = nir_deref_instr_ptr_as_array_stride(instr);
+                       unsigned stride = nir_deref_instr_array_stride(instr);
 
                        LLVMValueRef index = get_src(ctx, instr->arr.index);
                        if (LLVMTypeOf(index) != ctx->ac.i64)
index 16733fd5e5067a9264f5dcf292fb2f340c5efed1..63cd06e7883c89cd0abd5668ec5b24b314e9e1ba 100644 (file)
@@ -1501,7 +1501,7 @@ bool nir_deref_instr_has_complex_use(nir_deref_instr *instr);
 
 bool nir_deref_instr_remove_if_unused(nir_deref_instr *instr);
 
-unsigned nir_deref_instr_ptr_as_array_stride(nir_deref_instr *instr);
+unsigned nir_deref_instr_array_stride(nir_deref_instr *instr);
 
 typedef struct {
    nir_instr instr;
index 4191b7b75992102cd0b25b978b8edf12da5c53c2..ad0380aa28b026d45d9f2df4789dd80ac9361f57 100644 (file)
@@ -231,14 +231,32 @@ nir_deref_instr_has_complex_use(nir_deref_instr *deref)
    return false;
 }
 
+static unsigned
+type_scalar_size_bytes(const struct glsl_type *type)
+{
+   assert(glsl_type_is_vector_or_scalar(type) ||
+          glsl_type_is_matrix(type));
+   return glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8;
+}
+
 unsigned
-nir_deref_instr_ptr_as_array_stride(nir_deref_instr *deref)
+nir_deref_instr_array_stride(nir_deref_instr *deref)
 {
    switch (deref->deref_type) {
    case nir_deref_type_array:
-      return glsl_get_explicit_stride(nir_deref_instr_parent(deref)->type);
+   case nir_deref_type_array_wildcard: {
+      const struct glsl_type *arr_type = nir_deref_instr_parent(deref)->type;
+      unsigned stride = glsl_get_explicit_stride(arr_type);
+
+      if ((glsl_type_is_matrix(arr_type) &&
+           glsl_matrix_type_is_row_major(arr_type)) ||
+          (glsl_type_is_vector(arr_type) && stride == 0))
+         stride = type_scalar_size_bytes(arr_type);
+
+      return stride;
+   }
    case nir_deref_type_ptr_as_array:
-      return nir_deref_instr_ptr_as_array_stride(nir_deref_instr_parent(deref));
+      return nir_deref_instr_array_stride(nir_deref_instr_parent(deref));
    case nir_deref_type_cast:
       return deref->cast.ptr_stride;
    default:
@@ -817,7 +835,7 @@ is_trivial_array_deref_cast(nir_deref_instr *cast)
              glsl_get_explicit_stride(nir_deref_instr_parent(parent)->type);
    } else if (parent->deref_type == nir_deref_type_ptr_as_array) {
       return cast->cast.ptr_stride ==
-             nir_deref_instr_ptr_as_array_stride(parent);
+             nir_deref_instr_array_stride(parent);
    } else {
       return false;
    }
index 651c7460cd7b68b7a4fd59baab767ba4aa5176ea..28eb899114cfbd97d3db5dad49a33fb4e9188dde 100644 (file)
@@ -1250,14 +1250,7 @@ nir_explicit_io_address_from_deref(nir_builder *b, nir_deref_instr *deref,
       return build_addr_for_var(b, deref->var, addr_format);
 
    case nir_deref_type_array: {
-      nir_deref_instr *parent = nir_deref_instr_parent(deref);
-
-      unsigned stride = glsl_get_explicit_stride(parent->type);
-      if ((glsl_type_is_matrix(parent->type) &&
-           glsl_matrix_type_is_row_major(parent->type)) ||
-          (glsl_type_is_vector(parent->type) && stride == 0))
-         stride = type_scalar_size_bytes(parent->type);
-
+      unsigned stride = nir_deref_instr_array_stride(deref);
       assert(stride > 0);
 
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
@@ -1269,7 +1262,7 @@ nir_explicit_io_address_from_deref(nir_builder *b, nir_deref_instr *deref,
    case nir_deref_type_ptr_as_array: {
       nir_ssa_def *index = nir_ssa_for_src(b, deref->arr.index, 1);
       index = nir_i2i(b, index, addr_get_offset_bit_size(base_addr, addr_format));
-      unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
+      unsigned stride = nir_deref_instr_array_stride(deref);
       return build_addr_iadd(b, base_addr, addr_format,
                                 nir_amul_imm(b, index, stride));
    }
index 0bb5a95bbe01764a0b134e7cdc481d0487f01be9..224323532ace331fc0797c7bca71e0c00b3491b4 100644 (file)
@@ -331,17 +331,6 @@ type_scalar_size_bytes(const struct glsl_type *type)
    return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
 }
 
-static int
-get_array_stride(const struct glsl_type *type)
-{
-   unsigned explicit_stride = glsl_get_explicit_stride(type);
-   if ((glsl_type_is_matrix(type) &&
-        glsl_matrix_type_is_row_major(type)) ||
-       (glsl_type_is_vector(type) && explicit_stride == 0))
-      return type_scalar_size_bytes(type);
-   return explicit_stride;
-}
-
 static uint64_t
 mask_sign_extend(uint64_t val, unsigned bit_size)
 {
@@ -413,11 +402,7 @@ create_entry_key_from_deref(void *mem_ctx,
       case nir_deref_type_ptr_as_array: {
          assert(parent);
          nir_ssa_def *index = deref->arr.index.ssa;
-         uint32_t stride;
-         if (deref->deref_type == nir_deref_type_ptr_as_array)
-            stride = nir_deref_instr_ptr_as_array_stride(deref);
-         else
-            stride = get_array_stride(parent->type);
+         uint32_t stride = nir_deref_instr_array_stride(deref);
 
          nir_ssa_def *base = index;
          uint64_t offset = 0, base_mul = 1;
@@ -741,8 +726,8 @@ static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, i
    /* avoid adding another deref to the path */
    if (deref->deref_type == nir_deref_type_ptr_as_array &&
        nir_src_is_const(deref->arr.index) &&
-       offset % nir_deref_instr_ptr_as_array_stride(deref) == 0) {
-      unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
+       offset % nir_deref_instr_array_stride(deref) == 0) {
+      unsigned stride = nir_deref_instr_array_stride(deref);
       nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
                                           deref->dest.ssa.bit_size);
       return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
index 8eceadcebb08261ff53b0203ebec2f1d6a9cf29f..ce5465ebf7763892c30ff7c93b38d63bb84c1098 100644 (file)
@@ -128,7 +128,7 @@ repair_ssa_def(nir_ssa_def *def, void *void_state)
          cast->mode = deref->mode;
          cast->type = deref->type;
          cast->parent = nir_src_for_ssa(block_def);
-         cast->cast.ptr_stride = nir_deref_instr_ptr_as_array_stride(deref);
+         cast->cast.ptr_stride = nir_deref_instr_array_stride(deref);
 
          nir_ssa_dest_init(&cast->instr, &cast->dest,
                            def->num_components, def->bit_size, NULL);
index 327de85d36d7e45fa2726974de3d8b09576f5267..de2c7b600bd70197d57759108d38225361db566f 100644 (file)
@@ -253,7 +253,7 @@ convert_loop_exit_for_ssa(nir_ssa_def *def, void *void_state)
       cast->mode = instr->mode;
       cast->type = instr->type;
       cast->parent = nir_src_for_ssa(&phi->dest.ssa);
-      cast->cast.ptr_stride = nir_deref_instr_ptr_as_array_stride(instr);
+      cast->cast.ptr_stride = nir_deref_instr_array_stride(instr);
 
       nir_ssa_dest_init(&cast->instr, &cast->dest,
                         phi->dest.ssa.num_components,