nir: Add some generic helpers for writing lowering passes
authorJason Ekstrand <jason@jlekstrand.net>
Thu, 11 Jul 2019 18:00:42 +0000 (13:00 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Tue, 16 Jul 2019 16:05:16 +0000 (16:05 +0000)
Reviewed-by: Eric Anholt <eric@anholt.net>
src/compiler/nir/nir.c
src/compiler/nir/nir.h

index 9c252a49137bbc4a577eb837d4459d036eb1d492..3e52d0b3cac3ade30a9e1cac7cbc9ee9c7aac44b 100644 (file)
@@ -26,6 +26,7 @@
  */
 
 #include "nir.h"
+#include "nir_builder.h"
 #include "nir_control_flow_private.h"
 #include "util/half_float.h"
 #include <limits.h>
@@ -1786,6 +1787,152 @@ nir_index_instrs(nir_function_impl *impl)
    return index;
 }
 
+static nir_instr *
+cursor_next_instr(nir_cursor cursor)
+{
+   switch (cursor.option) {
+   case nir_cursor_before_block:
+      for (nir_block *block = cursor.block; block;
+           block = nir_block_cf_tree_next(block)) {
+         nir_instr *instr = nir_block_first_instr(block);
+         if (instr)
+            return instr;
+      }
+      return NULL;
+
+   case nir_cursor_after_block:
+      cursor.block = nir_block_cf_tree_next(cursor.block);
+      if (cursor.block == NULL)
+         return NULL;
+
+      cursor.option = nir_cursor_before_block;
+      return cursor_next_instr(cursor);
+
+   case nir_cursor_before_instr:
+      return cursor.instr;
+
+   case nir_cursor_after_instr:
+      if (nir_instr_next(cursor.instr))
+         return nir_instr_next(cursor.instr);
+
+      cursor.option = nir_cursor_after_block;
+      cursor.block = cursor.instr->block;
+      return cursor_next_instr(cursor);
+   }
+
+   unreachable("Inavlid cursor option");
+}
+
+static bool
+dest_is_ssa(nir_dest *dest, void *_state)
+{
+   return dest->is_ssa;
+}
+
+bool
+nir_function_impl_lower_instructions(nir_function_impl *impl,
+                                     nir_instr_filter_cb filter,
+                                     nir_lower_instr_cb lower,
+                                     void *cb_data)
+{
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_metadata preserved = nir_metadata_block_index ||
+                            nir_metadata_dominance;
+
+   bool progress = false;
+   nir_cursor iter = nir_before_cf_list(&impl->body);
+   nir_instr *instr;
+   while ((instr = cursor_next_instr(iter)) != NULL) {
+      if (filter && !filter(instr, cb_data)) {
+         iter = nir_after_instr(instr);
+         continue;
+      }
+
+      assert(nir_foreach_dest(instr, dest_is_ssa, NULL));
+      nir_ssa_def *old_def = nir_instr_ssa_def(instr);
+      if (old_def == NULL) {
+         iter = nir_after_instr(instr);
+         continue;
+      }
+
+      /* We're about to ask the callback to generate a replacement for instr.
+       * Save off the uses from instr's SSA def so we know what uses to
+       * rewrite later.  If we use nir_ssa_def_rewrite_uses, it fails in the
+       * case where the generated replacement code uses the result of instr
+       * itself.  If we use nir_ssa_def_rewrite_uses_after (which is the
+       * normal solution to this problem), it doesn't work well if control-
+       * flow is inserted as part of the replacement, doesn't handle cases
+       * where the replacement is something consumed by instr, and suffers
+       * from performance issues.  This is the only way to 100% guarantee
+       * that we rewrite the correct set efficiently.
+       */
+      struct list_head old_uses, old_if_uses;
+      list_replace(&old_def->uses, &old_uses);
+      list_inithead(&old_def->uses);
+      list_replace(&old_def->if_uses, &old_if_uses);
+      list_inithead(&old_def->if_uses);
+
+      b.cursor = nir_after_instr(instr);
+      nir_ssa_def *new_def = lower(&b, instr, cb_data);
+      if (new_def) {
+         assert(old_def != NULL);
+         if (new_def->parent_instr->block != instr->block)
+            preserved = nir_metadata_none;
+
+         nir_src new_src = nir_src_for_ssa(new_def);
+         list_for_each_entry_safe(nir_src, use_src, &old_uses, use_link)
+            nir_instr_rewrite_src(use_src->parent_instr, use_src, new_src);
+
+         list_for_each_entry_safe(nir_src, use_src, &old_if_uses, use_link)
+            nir_if_rewrite_condition(use_src->parent_if, new_src);
+
+         if (list_empty(&old_def->uses) && list_empty(&old_def->if_uses)) {
+            iter = nir_instr_remove(instr);
+         } else {
+            iter = nir_after_instr(instr);
+         }
+         progress = true;
+      } else {
+         /* We didn't end up lowering after all.  Put the uses back */
+         if (old_def) {
+            list_replace(&old_uses, &old_def->uses);
+            list_replace(&old_if_uses, &old_def->if_uses);
+         }
+         iter = nir_after_instr(instr);
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, preserved);
+   } else {
+#ifndef NDEBUG
+      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
+#endif
+   }
+
+   return progress;
+}
+
+bool
+nir_shader_lower_instructions(nir_shader *shader,
+                              nir_instr_filter_cb filter,
+                              nir_lower_instr_cb lower,
+                              void *cb_data)
+{
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
+      if (function->impl &&
+          nir_function_impl_lower_instructions(function->impl,
+                                               filter, lower, cb_data))
+         progress = true;
+   }
+
+   return progress;
+}
+
 nir_intrinsic_op
 nir_intrinsic_from_system_value(gl_system_value val)
 {
index 3112151526648440b7ef1880402cf6fe74ef54c8..6805d067a80b8f78ce7fc3120f31fc295239a04d 100644 (file)
@@ -3277,6 +3277,51 @@ static inline bool should_print_nir(void) { return false; }
 
 #define NIR_SKIP(name) should_skip_nir(#name)
 
+/** An instruction filtering callback
+ *
+ * Returns true if the instruction should be processed and false otherwise.
+ */
+typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
+
+/** A simple instruction lowering callback
+ *
+ * Many instruction lowering passes can be written as a simple function which
+ * takes an instruction as its input and returns a sequence of instructions
+ * that implement the consumed instruction.  This function type represents
+ * such a lowering function.  When called, a function with this prototype
+ * should either return NULL indicating that no lowering needs to be done or
+ * emit a sequence of instructions using the provided builder (whose cursor
+ * will already be placed after the instruction to be lowered) and return the
+ * resulting nir_ssa_def.
+ */
+typedef nir_ssa_def *(*nir_lower_instr_cb)(struct nir_builder *,
+                                           nir_instr *, void *);
+
+/** Iterate over all the instructions in a nir_function_impl and lower them
+ *  using the provided callbacks
+ *
+ * This function implements the guts of a standard lowering pass for you.  It
+ * iterates over all of the instructions in a nir_function_impl and calls the
+ * filter callback on each one.  If the filter callback returns true, it then
+ * calls the lowering call back on the instruction.  (Splitting it this way
+ * allows us to avoid some save/restore work for instructions we know won't be
+ * lowered.)  If the instruction is dead after the lowering is complete, it
+ * will be removed.  If new instructions are added, the lowering callback will
+ * also be called on them in case multiple lowerings are required.
+ *
+ * The metadata for the nir_function_impl will also be updated.  If any blocks
+ * are added (they cannot be removed), dominance and block indices will be
+ * invalidated.
+ */
+bool nir_function_impl_lower_instructions(nir_function_impl *impl,
+                                          nir_instr_filter_cb filter,
+                                          nir_lower_instr_cb lower,
+                                          void *cb_data);
+bool nir_shader_lower_instructions(nir_shader *shader,
+                                   nir_instr_filter_cb filter,
+                                   nir_lower_instr_cb lower,
+                                   void *cb_data);
+
 void nir_calc_dominance_impl(nir_function_impl *impl);
 void nir_calc_dominance(nir_shader *shader);