nir/opt_deref: Remove restrictive alignment information from casts
[mesa.git] / src / compiler / nir / nir_deref.c
index 5b9f97b6620063cdcdc53a24b702a4dcca4b0778..0d83266c9cdf4626bb6a161e009408435c631476 100644 (file)
@@ -848,6 +848,76 @@ is_deref_ptr_as_array(nir_instr *instr)
           nir_instr_as_deref(instr)->deref_type == nir_deref_type_ptr_as_array;
 }
 
+static bool
+opt_remove_restricting_cast_alignments(nir_deref_instr *cast)
+{
+   assert(cast->deref_type == nir_deref_type_cast);
+   if (cast->cast.align_mul == 0)
+      return false;
+
+   nir_deref_instr *parent = nir_src_as_deref(cast->parent);
+   if (parent == NULL)
+      return false;
+
+   /* Don't use any default alignment for this check.  We don't want to fall
+    * back to type alignment too early in case we find out later that we're
+    * somehow a child of a packed struct.
+    */
+   uint32_t parent_mul, parent_offset;
+   if (!nir_get_explicit_deref_align(parent, false /* default_to_type_align */,
+                                     &parent_mul, &parent_offset))
+      return false;
+
+   /* If this cast increases the alignment, we want to keep it.
+    *
+    * There is a possibility that the larger alignment provided by this cast
+    * somehow disagrees with the smaller alignment further up the deref chain.
+    * In that case, we choose to favor the alignment closer to the actual
+    * memory operation which, in this case, is the cast and not its parent so
+    * keeping the cast alignment is the right thing to do.
+    */
+   if (parent_mul < cast->cast.align_mul)
+      return false;
+
+   /* If we've gotten here, we have a parent deref with an align_mul at least
+    * as large as ours so we can potentially throw away the alignment
+    * information on this deref.  There are two cases to consider here:
+    *
+    *  1. We can chase the deref all the way back to the variable.  In this
+    *     case, we have "perfect" knowledge, modulo indirect array derefs.
+    *     Unless we've done something wrong in our indirect/wildcard stride
+    *     calculations, our knowledge from the deref walk is better than the
+    *     client's.
+    *
+    *  2. We can't chase it all the way back to the variable.  In this case,
+    *     because our call to nir_get_explicit_deref_align(parent, ...) above
+    *     above passes default_to_type_align=false, the only way we can even
+    *     get here is if something further up the deref chain has a cast with
+    *     an alignment which can only happen if we get an alignment from the
+    *     client (most likely a decoration in the SPIR-V).  If the client has
+    *     provided us with two conflicting alignments in the deref chain,
+    *     that's their fault and we can do whatever we want.
+    *
+    * In either case, we should be without our rights, at this point, to throw
+    * away the alignment information on this deref.  However, to be "nice" to
+    * weird clients, we do one more check.  It really shouldn't happen but
+    * it's possible that the parent's alignment offset disagrees with the
+    * cast's alignment offset.  In this case, we consider the cast as
+    * providing more information (or at least more valid information) and keep
+    * it even if the align_mul from the parent is larger.
+    */
+   assert(cast->cast.align_mul <= parent_mul);
+   if (parent_offset % cast->cast.align_mul != cast->cast.align_offset)
+      return false;
+
+   /* If we got here, the parent has better alignment information than the
+    * child and we can get rid of the child alignment information.
+    */
+   cast->cast.align_mul = 0;
+   cast->cast.align_offset = 0;
+   return true;
+}
+
 /**
  * Remove casts that just wrap other casts.
  */
@@ -945,7 +1015,9 @@ opt_replace_struct_wrapper_cast(nir_builder *b, nir_deref_instr *cast)
 static bool
 opt_deref_cast(nir_builder *b, nir_deref_instr *cast)
 {
-   bool progress;
+   bool progress = false;
+
+   progress |= opt_remove_restricting_cast_alignments(cast);
 
    if (opt_replace_struct_wrapper_cast(b, cast))
       return true;
@@ -953,7 +1025,7 @@ opt_deref_cast(nir_builder *b, nir_deref_instr *cast)
    if (opt_remove_sampler_cast(cast))
       return true;
 
-   progress = opt_remove_cast_cast(cast);
+   progress |= opt_remove_cast_cast(cast);
    if (!is_trivial_deref_cast(cast))
       return progress;