glsl: make loop unrolling more like the nir unrolling path
authorTimothy Arceri <tarceri@itsqueeze.com>
Mon, 4 Sep 2017 03:11:49 +0000 (13:11 +1000)
committerTimothy Arceri <tarceri@itsqueeze.com>
Mon, 9 Oct 2017 23:05:37 +0000 (10:05 +1100)
The old code assumed that loop terminators will always be at
the start of the loop, resulting in otherwise unrollable
loops not being unrolled at all. For example the current
code would unroll:

  int j = 0;
  do {
     if (j > 5)
        break;

     ... do stuff ...

     j++;
  } while (j < 4);

But would fail to unroll the following as no iteration limit was
calculated because it failed to find the terminator:

  int j = 0;
  do {
     ... do stuff ...

     j++;
  } while (j < 4);

Also we would fail to unroll the following as we ended up
calculating the iteration limit as 6 rather than 4. The unroll
code then assumed we had 3 terminators rather the 2 as it
wasn't able to determine that "if (j > 5)" was redundant.

  int j = 0;
  do {
     if (j > 5)
        break;

     ... do stuff ...

     if (bool(i))
        break;

     j++;
  } while (j < 4);

This patch changes this pass to be more like the NIR unrolling pass.
With this change we handle loop terminators correctly and also
handle cases where the terminators have instructions in their
branches other than a break.

V2:
- fixed regression where loops with a break in else were never
  unrolled in v1.
- fixed confusing/wrong naming of bools in complex unrolling.

Reviewed-by: Nicolai Hähnle <nicolai.haehnle@amd.com>
Tested-by: Dieter Nützel <Dieter@nuetzel-hh.de>
src/compiler/glsl/loop_analysis.cpp
src/compiler/glsl/loop_analysis.h
src/compiler/glsl/loop_unroll.cpp

index 78279844dc84609247e39eb7c452e44f414967d6..2979e09433fbd8d4b0c0370625c6dc792fabb79d 100644 (file)
@@ -25,7 +25,7 @@
 #include "loop_analysis.h"
 #include "ir_hierarchical_visitor.h"
 
-static bool is_loop_terminator(ir_if *ir);
+static void try_add_loop_terminator(loop_variable_state *ls, ir_if *ir);
 
 static bool all_expression_operands_are_loop_constant(ir_rvalue *,
                                                      hash_table *);
@@ -87,7 +87,7 @@ find_initial_value(ir_loop *loop, ir_variable *var)
 
 static int
 calculate_iterations(ir_rvalue *from, ir_rvalue *to, ir_rvalue *increment,
-                     enum ir_expression_operation op)
+                     enum ir_expression_operation op, bool continue_from_then)
 {
    if (from == NULL || to == NULL || increment == NULL)
       return -1;
@@ -154,8 +154,10 @@ calculate_iterations(ir_rvalue *from, ir_rvalue *to, ir_rvalue *increment,
       ir_expression *const add =
          new(mem_ctx) ir_expression(ir_binop_add, mul->type, mul, from);
 
-      ir_expression *const cmp =
+      ir_expression *cmp =
          new(mem_ctx) ir_expression(op, glsl_type::bool_type, add, to);
+      if (continue_from_then)
+         cmp = new(mem_ctx) ir_expression(ir_unop_logic_not, cmp);
 
       ir_constant *const cmp_result = cmp->constant_expression_value(mem_ctx);
 
@@ -306,12 +308,14 @@ loop_variable_state::insert(ir_variable *var)
 
 
 loop_terminator *
-loop_variable_state::insert(ir_if *if_stmt)
+loop_variable_state::insert(ir_if *if_stmt, bool continue_from_then)
 {
    void *mem_ctx = ralloc_parent(this);
    loop_terminator *t = new(mem_ctx) loop_terminator();
 
    t->ir = if_stmt;
+   t->continue_from_then = continue_from_then;
+
    this->terminators.push_tail(t);
 
    return t;
@@ -468,10 +472,8 @@ loop_analysis::visit_leave(ir_loop *ir)
 
       ir_if *if_stmt = ((ir_instruction *) node)->as_if();
 
-      if ((if_stmt != NULL) && is_loop_terminator(if_stmt))
-        ls->insert(if_stmt);
-      else
-        break;
+      if (if_stmt != NULL)
+         try_add_loop_terminator(ls, if_stmt);
    }
 
 
@@ -614,7 +616,7 @@ loop_analysis::visit_leave(ir_loop *ir)
          loop_variable *lv = ls->get(var);
          if (lv != NULL && lv->is_induction_var()) {
             t->iterations = calculate_iterations(init, limit, lv->increment,
-                                                 cmp);
+                                                 cmp, t->continue_from_then);
 
             if (incremented_before_terminator(ir, var, t->ir)) {
                t->iterations--;
@@ -781,31 +783,26 @@ get_basic_induction_increment(ir_assignment *ir, hash_table *var_hash)
 
 
 /**
- * Detect whether an if-statement is a loop terminating condition
+ * Detect whether an if-statement is a loop terminating condition, if so
+ * add it to the list of loop terminators.
  *
  * Detects if-statements of the form
  *
- *  (if (expression bool ...) (break))
+ *  (if (expression bool ...) (...then_instrs...break))
+ *
+ *     or
+ *
+ *  (if (expression bool ...) ... (...else_instrs...break))
  */
-bool
-is_loop_terminator(ir_if *ir)
+void
+try_add_loop_terminator(loop_variable_state *ls, ir_if *ir)
 {
-   if (!ir->else_instructions.is_empty())
-      return false;
-
-   ir_instruction *const inst =
-      (ir_instruction *) ir->then_instructions.get_head();
-   if (inst == NULL)
-      return false;
-
-   if (inst->ir_type != ir_type_loop_jump)
-      return false;
-
-   ir_loop_jump *const jump = (ir_loop_jump *) inst;
-   if (jump->mode != ir_loop_jump::jump_break)
-      return false;
+   ir_instruction *inst = (ir_instruction *) ir->then_instructions.get_tail();
+   ir_instruction *else_inst =
+      (ir_instruction *) ir->else_instructions.get_tail();
 
-   return true;
+   if (is_break(inst) || is_break(else_inst))
+      ls->insert(ir, is_break(else_inst));
 }
 
 
index 99b6bf756389db3b3400cedd6ecbf589cbf924f6..4e1100184611f5dbec7399b01d0db4e78609d904 100644 (file)
@@ -55,7 +55,7 @@ public:
    class loop_variable *get(const ir_variable *);
    class loop_variable *insert(ir_variable *);
    class loop_variable *get_or_insert(ir_variable *, bool in_assignee);
-   class loop_terminator *insert(ir_if *);
+   class loop_terminator *insert(ir_if *, bool continue_from_then);
 
 
    /**
@@ -210,6 +210,9 @@ public:
     * terminate the loop (if that is a fixed value).  Otherwise -1.
     */
    int iterations;
+
+   /* Does the if continue from the then branch or the else branch */
+   bool continue_from_then;
 };
 
 
index 358cbf10af474dce563aed466607b20b7c8ea96d..6e06a30fb9150f95a3f8042357da1c15e9bcb33a 100644 (file)
@@ -42,7 +42,9 @@ public:
    virtual ir_visitor_status visit_leave(ir_loop *ir);
    void simple_unroll(ir_loop *ir, int iterations);
    void complex_unroll(ir_loop *ir, int iterations,
-                       bool continue_from_then_branch);
+                       bool continue_from_then_branch,
+                       bool limiting_term_first,
+                       bool lt_continue_from_then_branch);
    void splice_post_if_instructions(ir_if *ir_if, exec_list *splice_dest);
 
    loop_state *state;
@@ -176,6 +178,51 @@ void
 loop_unroll_visitor::simple_unroll(ir_loop *ir, int iterations)
 {
    void *const mem_ctx = ralloc_parent(ir);
+   loop_variable_state *const ls = this->state->get(ir);
+
+   ir_instruction *first_ir =
+      (ir_instruction *) ir->body_instructions.get_head();
+
+   if (!first_ir) {
+      /* The loop is empty remove it and return */
+      ir->remove();
+      return;
+   }
+
+   ir_if *limit_if = NULL;
+   bool exit_branch_has_instructions = false;
+   if (ls->limiting_terminator) {
+      limit_if = ls->limiting_terminator->ir;
+      ir_instruction *ir_if_last = (ir_instruction *)
+         limit_if->then_instructions.get_tail();
+
+      if (is_break(ir_if_last)) {
+         if (ir_if_last != limit_if->then_instructions.get_head())
+            exit_branch_has_instructions = true;
+
+         splice_post_if_instructions(limit_if, &limit_if->else_instructions);
+         ir_if_last->remove();
+      } else {
+         ir_if_last = (ir_instruction *)
+            limit_if->else_instructions.get_tail();
+         assert(is_break(ir_if_last));
+
+         if (ir_if_last != limit_if->else_instructions.get_head())
+            exit_branch_has_instructions = true;
+
+         splice_post_if_instructions(limit_if, &limit_if->then_instructions);
+         ir_if_last->remove();
+      }
+   }
+
+   /* Because 'iterations' is the number of times we pass over the *entire*
+    * loop body before hitting the first break, we need to bump the number of
+    * iterations if the limiting terminator is not the first instruction in
+    * the loop, or it the exit branch contains instructions. This ensures we
+    * execute any instructions before the terminator or in its exit branch.
+    */
+   if (limit_if != first_ir->as_if() || exit_branch_has_instructions)
+      iterations++;
 
    for (int i = 0; i < iterations; i++) {
       exec_list copy_list;
@@ -227,11 +274,22 @@ loop_unroll_visitor::simple_unroll(ir_loop *ir, int iterations)
  */
 void
 loop_unroll_visitor::complex_unroll(ir_loop *ir, int iterations,
-                                    bool continue_from_then_branch)
+                                    bool second_term_then_continue,
+                                    bool extra_iteration_required,
+                                    bool first_term_then_continue)
 {
    void *const mem_ctx = ralloc_parent(ir);
    ir_instruction *ir_to_replace = ir;
 
+   /* Because 'iterations' is the number of times we pass over the *entire*
+    * loop body before hitting the first break, we need to bump the number of
+    * iterations if the limiting terminator is not the first instruction in
+    * the loop, or it the exit branch contains instructions. This ensures we
+    * execute any instructions before the terminator or in its exit branch.
+    */
+   if (extra_iteration_required)
+      iterations++;
+
    for (int i = 0; i < iterations; i++) {
       exec_list copy_list;
 
@@ -241,6 +299,10 @@ loop_unroll_visitor::complex_unroll(ir_loop *ir, int iterations,
       ir_if *ir_if = ((ir_instruction *) copy_list.get_tail())->as_if();
       assert(ir_if != NULL);
 
+      exec_list *const first_list = first_term_then_continue
+         ? &ir_if->then_instructions : &ir_if->else_instructions;
+      ir_if = ((ir_instruction *) first_list->get_tail())->as_if();
+
       ir_to_replace->insert_before(&copy_list);
       ir_to_replace->remove();
 
@@ -248,10 +310,10 @@ loop_unroll_visitor::complex_unroll(ir_loop *ir, int iterations,
       ir_to_replace =
          new(mem_ctx) ir_loop_jump(ir_loop_jump::jump_continue);
 
-      exec_list *const list = (continue_from_then_branch)
+      exec_list *const second_term_continue_list = second_term_then_continue
          ? &ir_if->then_instructions : &ir_if->else_instructions;
 
-      list->push_tail(ir_to_replace);
+      second_term_continue_list->push_tail(ir_to_replace);
    }
 
    ir_to_replace->remove();
@@ -293,6 +355,21 @@ loop_unroll_visitor::splice_post_if_instructions(ir_if *ir_if,
    }
 }
 
+static bool
+exit_branch_has_instructions(ir_if *term_if, bool lt_then_continue)
+{
+   if (lt_then_continue) {
+      if (term_if->else_instructions.get_head() ==
+          term_if->else_instructions.get_tail())
+         return false;
+   } else {
+      if (term_if->then_instructions.get_head() ==
+          term_if->then_instructions.get_tail())
+         return false;
+   }
+
+   return true;
+}
 
 ir_visitor_status
 loop_unroll_visitor::visit_leave(ir_loop *ir)
@@ -417,7 +494,6 @@ loop_unroll_visitor::visit_leave(ir_loop *ir)
       return visit_continue;
 
    if (predicted_num_loop_jumps == 0) {
-      ls->limiting_terminator->ir->remove();
       simple_unroll(ir, iterations);
       return visit_continue;
    }
@@ -432,51 +508,71 @@ loop_unroll_visitor::visit_leave(ir_loop *ir)
        */
       last_ir->remove();
 
-      ls->limiting_terminator->ir->remove();
       simple_unroll(ir, 1);
       return visit_continue;
    }
 
-   /* recognize loops in the form produced by ir_lower_jumps */
-   foreach_in_list(ir_instruction, cur_ir, &ir->body_instructions) {
-      /* Skip the limiting terminator, since it will go away when we
-       * unroll.
-       */
-      if (cur_ir == ls->limiting_terminator->ir)
-         continue;
+   /* Complex unrolling can only handle two terminators. One with an unknown
+    * iteration count and one with a known iteration count. We have already
+    * made sure we have a known iteration count above and removed any
+    * unreachable terminators with a known count. Here we make sure there
+    * isn't any additional unknown terminators, or any other jumps nested
+    * inside futher ifs.
+    */
+   if (ls->num_loop_jumps != 2)
+      return visit_continue;
 
-      ir_if *ir_if = cur_ir->as_if();
-      if (ir_if != NULL) {
-         /* Determine which if-statement branch, if any, ends with a
-          * break.  The branch that did *not* have the break will get a
-          * temporary continue inserted in each iteration of the loop
-          * unroll.
-          *
-          * Note that since ls->num_loop_jumps is <= 1, it is impossible
-          * for both branches to end with a break.
-          */
-         ir_instruction *ir_if_last =
-            (ir_instruction *) ir_if->then_instructions.get_tail();
+   ir_instruction *first_ir =
+      (ir_instruction *) ir->body_instructions.get_head();
 
+   unsigned term_count = 0;
+   bool first_term_then_continue = false;
+   foreach_in_list(loop_terminator, t, &ls->terminators) {
+      assert(term_count < 2);
+
+      ir_if *ir_if = t->ir->as_if();
+      assert(ir_if != NULL);
+
+      ir_instruction *ir_if_last =
+         (ir_instruction *) ir_if->then_instructions.get_tail();
+
+      if (is_break(ir_if_last)) {
+         splice_post_if_instructions(ir_if, &ir_if->else_instructions);
+         ir_if_last->remove();
+         if (term_count == 1) {
+            bool ebi =
+               exit_branch_has_instructions(ls->limiting_terminator->ir,
+                                            first_term_then_continue);
+            complex_unroll(ir, iterations, false,
+                           first_ir->as_if() != ls->limiting_terminator->ir ||
+                           ebi,
+                           first_term_then_continue);
+            return visit_continue;
+         }
+      } else {
+         ir_if_last =
+            (ir_instruction *) ir_if->else_instructions.get_tail();
+
+         assert(is_break(ir_if_last));
          if (is_break(ir_if_last)) {
-            ls->limiting_terminator->ir->remove();
-            splice_post_if_instructions(ir_if, &ir_if->else_instructions);
+            splice_post_if_instructions(ir_if, &ir_if->then_instructions);
             ir_if_last->remove();
-            complex_unroll(ir, iterations, false);
-            return visit_continue;
-         } else {
-            ir_if_last =
-               (ir_instruction *) ir_if->else_instructions.get_tail();
-
-            if (is_break(ir_if_last)) {
-               ls->limiting_terminator->ir->remove();
-               splice_post_if_instructions(ir_if, &ir_if->then_instructions);
-               ir_if_last->remove();
-               complex_unroll(ir, iterations, true);
+            if (term_count == 1) {
+               bool ebi =
+                  exit_branch_has_instructions(ls->limiting_terminator->ir,
+                                               first_term_then_continue);
+               complex_unroll(ir, iterations, true,
+                              first_ir->as_if() != ls->limiting_terminator->ir ||
+                              ebi,
+                              first_term_then_continue);
                return visit_continue;
+            } else {
+               first_term_then_continue = true;
             }
          }
       }
+
+      term_count++;
    }
 
    /* Did not find the break statement.  It must be in a complex if-nesting,