spirv: Propagate alignments to deref chains via casts
[mesa.git] / src / compiler / spirv / vtn_variables.c
index 2cde9ac3545ad3871bd2a8012b7300115802c58d..820a4c8c97ddea1d1fc155838354e50f071620ce 100644 (file)
 #include "nir_deref.h"
 #include <vulkan/vulkan_core.h>
 
+static struct vtn_pointer*
+vtn_align_pointer(struct vtn_builder *b, struct vtn_pointer *ptr,
+                  unsigned alignment)
+{
+   if (alignment == 0)
+      return ptr;
+
+   if (!util_is_power_of_two_nonzero(alignment)) {
+      vtn_warn("Provided alignment is not a power of two");
+      alignment = 1 << (ffs(alignment) - 1);
+   }
+
+   /* If this pointer doesn't have a deref, bail.  This either means we're
+    * using the old offset+alignment pointers which don't support carrying
+    * alignment information or we're a pointer that is below the block
+    * boundary in our access chain in which case alignment is meaningless.
+    */
+   if (ptr->deref == NULL)
+      return ptr;
+
+   /* Ignore alignment information on logical pointers.  This way, we don't
+    * trip up drivers with unnecessary casts.
+    */
+   nir_address_format addr_format = vtn_mode_to_address_format(b, ptr->mode);
+   if (addr_format == nir_address_format_logical)
+      return ptr;
+
+   struct vtn_pointer *copy = ralloc(b, struct vtn_pointer);
+   *copy = *ptr;
+   copy->deref = nir_alignment_deref_cast(&b->nb, ptr->deref, alignment, 0);
+
+   return copy;
+}
+
 static void
 ptr_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
                   const struct vtn_decoration *dec, void *void_ptr)
@@ -46,21 +80,48 @@ ptr_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
    }
 }
 
+struct access_align {
+   enum gl_access_qualifier access;
+   uint32_t alignment;
+};
+
+static void
+access_align_cb(struct vtn_builder *b, struct vtn_value *val, int member,
+                const struct vtn_decoration *dec, void *void_ptr)
+{
+   struct access_align *aa = void_ptr;
+
+   switch (dec->decoration) {
+   case SpvDecorationAlignment:
+      aa->alignment = dec->operands[0];
+      break;
+
+   case SpvDecorationNonUniformEXT:
+      aa->access |= ACCESS_NON_UNIFORM;
+      break;
+
+   default:
+      break;
+   }
+}
+
 static struct vtn_pointer*
 vtn_decorate_pointer(struct vtn_builder *b, struct vtn_value *val,
                      struct vtn_pointer *ptr)
 {
-   struct vtn_pointer dummy = { .access = 0 };
-   vtn_foreach_decoration(b, val, ptr_decoration_cb, &dummy);
+   struct access_align aa = { 0, };
+   vtn_foreach_decoration(b, val, access_align_cb, &aa);
+
+   ptr = vtn_align_pointer(b, ptr, aa.alignment);
 
    /* If we're adding access flags, make a copy of the pointer.  We could
     * probably just OR them in without doing so but this prevents us from
     * leaking them any further than actually specified in the SPIR-V.
     */
-   if (dummy.access & ~ptr->access) {
+   if (aa.access & ~ptr->access) {
       struct vtn_pointer *copy = ralloc(b, struct vtn_pointer);
       *copy = *ptr;
-      copy->access |= dummy.access;
+      copy->access |= aa.access;
       return copy;
    }
 
@@ -2654,6 +2715,8 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          src_alignment = dest_alignment;
          src_access = dest_access;
       }
+      src = vtn_align_pointer(b, src, src_alignment);
+      dest = vtn_align_pointer(b, dest, dest_alignment);
 
       vtn_emit_make_visible_barrier(b, src_access, src_scope, src->mode);
 
@@ -2676,6 +2739,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
       SpvMemoryAccessMask access;
       SpvScope scope;
       vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
+      src = vtn_align_pointer(b, src, alignment);
 
       vtn_emit_make_visible_barrier(b, access, scope, src->mode);
 
@@ -2717,6 +2781,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
       SpvMemoryAccessMask access;
       SpvScope scope;
       vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
+      dest = vtn_align_pointer(b, dest, alignment);
 
       struct vtn_ssa_value *src = vtn_ssa_value(b, w[2]);
       vtn_variable_store(b, src, dest, spv_access_to_gl_access(access));