nir: Add a lowering pass to split 64bit phis
[mesa.git] / src / compiler / nir / nir_lower_to_source_mods.c
index 6c4e1f0d3f349e838c5470c6716a9066ec91483a..ab9581dcc3d98f46b4d94cb6a1a8e3481728ab80 100644 (file)
  * easier to not have them when we're doing optimizations.
  */
 
+/* Fold an absolute-value modifier into an ALU source by setting its abs
+ * flag.  The source's negate flag is left exactly as it was.
+ */
+static void
+alu_src_consume_abs(nir_alu_src *src)
+{
+   src->abs = true;
+}
+
+/* Fold a negate modifier into an ALU source.  The source's negate flag is
+ * toggled (so two folded negates cancel out), except when the source
+ * already has abs set, in which case the source is left unchanged.
+ */
+static void
+alu_src_consume_negate(nir_alu_src *src)
+{
+   /* If abs is set on the source, the negate goes away */
+   if (!src->abs)
+      src->negate = !src->negate;
+}
+
 static bool
-nir_lower_to_source_mods_block(nir_block *block, void *state)
+nir_lower_to_source_mods_block(nir_block *block,
+                               nir_lower_to_source_mods_flags options)
 {
-   nir_foreach_instr(block, instr) {
+   bool progress = false;
+
+   nir_foreach_instr(instr, block) {
       if (instr->type != nir_instr_type_alu)
          continue;
 
       nir_alu_instr *alu = nir_instr_as_alu(instr);
 
+      bool lower_abs = (nir_op_infos[alu->op].num_inputs < 3) ||
+            (options & nir_lower_triop_abs);
+
       for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
          if (!alu->src[i].src.is_ssa)
             continue;
@@ -54,13 +74,17 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
          if (parent->dest.saturate)
             continue;
 
-         switch (nir_op_infos[alu->op].input_types[i]) {
+         switch (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i])) {
          case nir_type_float:
-            if (parent->op != nir_op_fmov)
+            if (!(options & nir_lower_float_source_mods))
+               continue;
+            if (parent->op != nir_op_fabs && parent->op != nir_op_fneg)
                continue;
             break;
          case nir_type_int:
-            if (parent->op != nir_op_imov)
+            if (!(options & nir_lower_int_source_mods))
+               continue;
+            if (parent->op != nir_op_iabs && parent->op != nir_op_ineg)
                continue;
             break;
          default:
@@ -74,13 +98,23 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
          if (!parent->src[0].src.is_ssa)
             continue;
 
+         if (!lower_abs && (parent->op == nir_op_fabs ||
+                            parent->op == nir_op_iabs))
+            continue;
+
          nir_instr_rewrite_src(instr, &alu->src[i].src, parent->src[0].src);
-         if (alu->src[i].abs) {
-            /* abs trumps both neg and abs, do nothing */
-         } else {
-            alu->src[i].negate = (alu->src[i].negate != parent->src[0].negate);
-            alu->src[i].abs |= parent->src[0].abs;
-         }
+
+         /* Apply any modifiers that come from the parent opcode */
+         if (parent->op == nir_op_fneg || parent->op == nir_op_ineg)
+            alu_src_consume_negate(&alu->src[i]);
+         if (parent->op == nir_op_fabs || parent->op == nir_op_iabs)
+            alu_src_consume_abs(&alu->src[i]);
+
+         /* Apply modifiers from the parent source */
+         if (parent->src[0].negate)
+            alu_src_consume_negate(&alu->src[i]);
+         if (parent->src[0].abs)
+            alu_src_consume_abs(&alu->src[i]);
 
          for (int j = 0; j < 4; ++j) {
             if (!nir_alu_instr_channel_used(alu, i, j))
@@ -88,36 +122,11 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
             alu->src[i].swizzle[j] = parent->src[0].swizzle[alu->src[i].swizzle[j]];
          }
 
-         if (list_empty(&parent->dest.dest.ssa.uses) &&
-             list_empty(&parent->dest.dest.ssa.if_uses))
+         if (list_is_empty(&parent->dest.dest.ssa.uses) &&
+             list_is_empty(&parent->dest.dest.ssa.if_uses))
             nir_instr_remove(&parent->instr);
-      }
 
-      switch (alu->op) {
-      case nir_op_fsat:
-         alu->op = nir_op_fmov;
-         alu->dest.saturate = true;
-         break;
-      case nir_op_ineg:
-         alu->op = nir_op_imov;
-         alu->src[0].negate = !alu->src[0].negate;
-         break;
-      case nir_op_fneg:
-         alu->op = nir_op_fmov;
-         alu->src[0].negate = !alu->src[0].negate;
-         break;
-      case nir_op_iabs:
-         alu->op = nir_op_imov;
-         alu->src[0].abs = true;
-         alu->src[0].negate = false;
-         break;
-      case nir_op_fabs:
-         alu->op = nir_op_fmov;
-         alu->src[0].abs = true;
-         alu->src[0].negate = false;
-         break;
-      default:
-         break;
+         progress = true;
       }
 
       /* We've covered sources.  Now we're going to try and saturate the
@@ -128,14 +137,18 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
          continue;
 
       /* We can only saturate float destinations */
-      if (nir_op_infos[alu->op].output_type != nir_type_float)
+      if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) !=
+          nir_type_float)
+         continue;
+
+      if (!(options & nir_lower_float_source_mods))
          continue;
 
-      if (!list_empty(&alu->dest.dest.ssa.if_uses))
+      if (!list_is_empty(&alu->dest.dest.ssa.if_uses))
          continue;
 
       bool all_children_are_sat = true;
-      nir_foreach_use(&alu->dest.dest.ssa, child_src) {
+      nir_foreach_use(child_src, &alu->dest.dest.ssa) {
          assert(child_src->is_ssa);
          nir_instr *child = child_src->parent_instr;
          if (child->type != nir_instr_type_alu) {
@@ -149,8 +162,7 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
             continue;
          }
 
-         if (child_alu->op != nir_op_fsat &&
-             !(child_alu->op == nir_op_fmov && child_alu->dest.saturate)) {
+         if (child_alu->op != nir_op_fsat) {
             all_children_are_sat = false;
             continue;
          }
@@ -160,14 +172,13 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
          continue;
 
       alu->dest.saturate = true;
+      progress = true;
 
-      nir_foreach_use(&alu->dest.dest.ssa, child_src) {
+      nir_foreach_use(child_src, &alu->dest.dest.ssa) {
          assert(child_src->is_ssa);
-         nir_instr *child = child_src->parent_instr;
-         assert(child->type == nir_instr_type_alu);
-         nir_alu_instr *child_alu = nir_instr_as_alu(child);
+         nir_alu_instr *child_alu = nir_instr_as_alu(child_src->parent_instr);
 
-         child_alu->op = nir_op_fmov;
+         child_alu->op = nir_op_mov;
          child_alu->dest.saturate = false;
          /* We could propagate the dest of our instruction to the
           * destinations of the uses here.  However, one quick round of
@@ -177,20 +188,37 @@ nir_lower_to_source_mods_block(nir_block *block, void *state)
       }
    }
 
-   return true;
+   return progress;
 }
 
-static void
-nir_lower_to_source_mods_impl(nir_function_impl *impl)
+/* Run the source-modifier lowering over every block of one function
+ * implementation, passing `options` through to the per-block pass.
+ *
+ * Returns true if any block reported progress.  nir_metadata_preserve() is
+ * only called (with the block-index and dominance flags) when something
+ * actually changed.
+ */
+static bool
+nir_lower_to_source_mods_impl(nir_function_impl *impl,
+                              nir_lower_to_source_mods_flags options)
 {
-   nir_foreach_block(impl, nir_lower_to_source_mods_block, NULL);
+   bool progress = false;
+
+   nir_foreach_block(block, impl) {
+      progress |= nir_lower_to_source_mods_block(block, options);
+   }
+
+   if (progress)
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+
+   return progress;
 }
 
-void
-nir_lower_to_source_mods(nir_shader *shader)
+/* Shader-level entry point: apply the source-modifier lowering to every
+ * function in `shader` that has an implementation; functions without one
+ * are skipped.
+ *
+ * `options` is forwarded unchanged to the per-function pass; the per-block
+ * code tests it against nir_lower_float_source_mods,
+ * nir_lower_int_source_mods and nir_lower_triop_abs.
+ *
+ * Returns true if any function implementation reported progress.
+ */
+bool
+nir_lower_to_source_mods(nir_shader *shader,
+                         nir_lower_to_source_mods_flags options)
 {
-   nir_foreach_function(shader, function) {
-      if (function->impl)
-         nir_lower_to_source_mods_impl(function->impl);
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
+      if (function->impl) {
+         progress |= nir_lower_to_source_mods_impl(function->impl, options);
+      }
    }
+
+   return progress;
 }