spirv: Propagate alignments to deref chains via casts
author Jason Ekstrand <jason@jlekstrand.net>
Thu, 27 Aug 2020 23:34:50 +0000 (18:34 -0500)
committer Marge Bot <eric+marge@anholt.net>
Thu, 3 Sep 2020 18:02:50 +0000 (18:02 +0000)
This commit propagates the alignment information provided either through
the Alignment decoration on pointers or via the alignment mem operands
to OpLoad, OpStore, and OpCopyMemory to the NIR deref chain.  It does so
by wrapping the deref in a cast.  NIR should be able to clean up most of the
unnecessary casts, leaving us with only the useful alignment information.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Boris Brezillon <boris.brezillon@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6472>

src/compiler/nir/nir_builder.h
src/compiler/spirv/vtn_variables.c

index a33a33a870dae6005338b2ddf8a08ab370a3e1dd..f9ac4c830a375c9cfae221fd8381a4190f595ba0 100644 (file)
@@ -1210,6 +1210,29 @@ nir_build_deref_cast(nir_builder *build, nir_ssa_def *parent,
    return deref;
 }
 
+static inline nir_deref_instr *
+nir_alignment_deref_cast(nir_builder *build, nir_deref_instr *parent,
+                         uint32_t align_mul, uint32_t align_offset)
+{
+   nir_deref_instr *deref =
+      nir_deref_instr_create(build->shader, nir_deref_type_cast);
+   /* Wrap `parent` in a new cast deref that keeps the parent's mode and type
+    * and additionally records align_mul/align_offset.  The cast is emitted
+    * through `build` and returned to the caller.
+    */
+   deref->mode = parent->mode;
+   deref->type = parent->type;
+   deref->parent = nir_src_for_ssa(&parent->dest.ssa);
+   deref->cast.ptr_stride = nir_deref_instr_array_stride(deref);
+   deref->cast.align_mul = align_mul;
+   deref->cast.align_offset = align_offset;
+   /* NOTE(review): ptr_stride above is queried on the new cast itself, not
+    * on `parent`; the helper is not visible here, so confirm that is the
+    * intended argument.  The SSA destination below mirrors the parent's
+    * component count and bit size.
+    */
+   nir_ssa_dest_init(&deref->instr, &deref->dest,
+                     parent->dest.ssa.num_components,
+                     parent->dest.ssa.bit_size, NULL);
+
+   nir_builder_instr_insert(build, &deref->instr);
+
+   return deref;
+}
+
 /** Returns a deref that follows another but starting from the given parent
  *
  * The new deref will be the same type and take the same array or struct index
index 2cde9ac3545ad3871bd2a8012b7300115802c58d..820a4c8c97ddea1d1fc155838354e50f071620ce 100644 (file)
 #include "nir_deref.h"
 #include <vulkan/vulkan_core.h>
 
+static struct vtn_pointer*
+vtn_align_pointer(struct vtn_builder *b, struct vtn_pointer *ptr,
+                  unsigned alignment)
+{
+   if (alignment == 0)
+      return ptr;
+   /* An alignment of 0 returns the pointer unchanged (above).  Any other
+    * value that is not a power of two is warned about and rounded down to
+    * its lowest set bit, i.e. the largest power of two that divides it.
+    */
+   if (!util_is_power_of_two_nonzero(alignment)) {
+      vtn_warn("Provided alignment is not a power of two");
+      alignment = 1 << (ffs(alignment) - 1);
+   }
+
+   /* If this pointer doesn't have a deref, bail.  This either means we're
+    * using the old offset+alignment pointers which don't support carrying
+    * alignment information or we're a pointer that is below the block
+    * boundary in our access chain in which case alignment is meaningless.
+    */
+   if (ptr->deref == NULL)
+      return ptr;
+
+   /* Ignore alignment information on logical pointers.  This way, we don't
+    * trip up drivers with unnecessary casts.
+    */
+   nir_address_format addr_format = vtn_mode_to_address_format(b, ptr->mode);
+   if (addr_format == nir_address_format_logical)
+      return ptr;
+   /* Return a copy of the pointer whose deref is wrapped in an alignment
+    * cast (align_mul = alignment, align_offset = 0).  The vtn_pointer passed
+    * in is left as it was.
+    */
+   struct vtn_pointer *copy = ralloc(b, struct vtn_pointer);
+   *copy = *ptr;
+   copy->deref = nir_alignment_deref_cast(&b->nb, ptr->deref, alignment, 0);
+
+   return copy;
+}
+
 static void
 ptr_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
                   const struct vtn_decoration *dec, void *void_ptr)
@@ -46,21 +80,48 @@ ptr_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
    }
 }
 
+struct access_align {
+   enum gl_access_qualifier access; /* flags collected from decorations */
+   uint32_t alignment; /* operand of the Alignment decoration, when seen */
+};
+
+static void
+access_align_cb(struct vtn_builder *b, struct vtn_value *val, int member,
+                const struct vtn_decoration *dec, void *void_ptr)
+{
+   struct access_align *aa = void_ptr;
+   /* Decoration callback: folds a decoration into the struct access_align
+    * passed as void_ptr.  Alignment overwrites the stored alignment with the
+    * decoration's first operand, NonUniformEXT ORs ACCESS_NON_UNIFORM into
+    * the access flags, and every other decoration is ignored.
+    */
+   switch (dec->decoration) {
+   case SpvDecorationAlignment:
+      aa->alignment = dec->operands[0];
+      break;
+
+   case SpvDecorationNonUniformEXT:
+      aa->access |= ACCESS_NON_UNIFORM;
+      break;
+
+   default:
+      break;
+   }
+}
+
 static struct vtn_pointer*
 vtn_decorate_pointer(struct vtn_builder *b, struct vtn_value *val,
                      struct vtn_pointer *ptr)
 {
-   struct vtn_pointer dummy = { .access = 0 };
-   vtn_foreach_decoration(b, val, ptr_decoration_cb, &dummy);
+   struct access_align aa = { 0, };
+   vtn_foreach_decoration(b, val, access_align_cb, &aa);
+
+   ptr = vtn_align_pointer(b, ptr, aa.alignment);
 
    /* If we're adding access flags, make a copy of the pointer.  We could
     * probably just OR them in without doing so but this prevents us from
     * leaking them any further than actually specified in the SPIR-V.
     */
-   if (dummy.access & ~ptr->access) {
+   if (aa.access & ~ptr->access) {
       struct vtn_pointer *copy = ralloc(b, struct vtn_pointer);
       *copy = *ptr;
-      copy->access |= dummy.access;
+      copy->access |= aa.access;
       return copy;
    }
 
@@ -2654,6 +2715,8 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          src_alignment = dest_alignment;
          src_access = dest_access;
       }
+      src = vtn_align_pointer(b, src, src_alignment);
+      dest = vtn_align_pointer(b, dest, dest_alignment);
 
       vtn_emit_make_visible_barrier(b, src_access, src_scope, src->mode);
 
@@ -2676,6 +2739,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
       SpvMemoryAccessMask access;
       SpvScope scope;
       vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
+      src = vtn_align_pointer(b, src, alignment);
 
       vtn_emit_make_visible_barrier(b, access, scope, src->mode);
 
@@ -2717,6 +2781,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
       SpvMemoryAccessMask access;
       SpvScope scope;
       vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
+      dest = vtn_align_pointer(b, dest, alignment);
 
       struct vtn_ssa_value *src = vtn_ssa_value(b, w[2]);
       vtn_variable_store(b, src, dest, spv_access_to_gl_access(access));