glsl: Improve the local dead code optimization to eliminate unused channels.
[mesa.git] / src / glsl / opt_dead_code_local.cpp
index a81a38fff0fb097077a85a847870a3a7b511f32a..4af78a72cc387c1d328bb0d3d6b18e9f5d23f65a 100644 (file)
@@ -43,16 +43,20 @@ static bool debug = false;
 class assignment_entry : public exec_node
 {
 public:
-   assignment_entry(ir_variable *lhs, ir_instruction *ir)
+   assignment_entry(ir_variable *lhs, ir_assignment *ir)
    {
       assert(lhs);
       assert(ir);
       this->lhs = lhs;
       this->ir = ir;
+      this->available = ir->write_mask;
    }
 
    ir_variable *lhs;
-   ir_instruction *ir;
+   ir_assignment *ir;
+
+   /* bitmask of xyzw channels written that haven't been used so far. */
+   int available;
 };
 
 class kill_for_derefs_visitor : public ir_hierarchical_visitor {
@@ -62,23 +66,52 @@ public:
       this->assignments = assignments;
    }
 
-   virtual ir_visitor_status visit(ir_dereference_variable *ir)
+   void kill_channels(ir_variable *const var, int used)
    {
-      ir_variable *const var = ir->variable_referenced();
-
       foreach_iter(exec_list_iterator, iter, *this->assignments) {
         assignment_entry *entry = (assignment_entry *)iter.get();
 
         if (entry->lhs == var) {
-           if (debug)
-              printf("kill %s\n", entry->lhs->name);
-           entry->remove();
+           if (var->type->is_scalar() || var->type->is_vector()) {
+              if (debug)
+                 printf("kill %s (0x%01x - 0x%01x)\n", entry->lhs->name,
+                        entry->available, used);
+              entry->available &= ~used;
+              if (!entry->available)
+                 entry->remove();
+           } else {
+              if (debug)
+                 printf("kill %s\n", entry->lhs->name);
+              entry->remove();
+           }
         }
       }
+   }
+
+   virtual ir_visitor_status visit(ir_dereference_variable *ir)
+   {
+      kill_channels(ir->var, ~0);
 
       return visit_continue;
    }
 
+   virtual ir_visitor_status visit(ir_swizzle *ir)
+   {
+      ir_dereference_variable *deref = ir->val->as_dereference_variable();
+      if (!deref)
+        return visit_continue;
+
+      int used = 0;
+      used |= 1 << ir->mask.x;
+      used |= 1 << ir->mask.y;
+      used |= 1 << ir->mask.z;
+      used |= 1 << ir->mask.w;
+
+      kill_channels(deref->var, used);
+
+      return visit_continue_with_parent;
+   }
+
 private:
    exec_list *assignments;
 };
@@ -130,21 +163,91 @@ process_assignment(void *ctx, ir_assignment *ir, exec_list *assignments)
    assert(var);
 
    /* Now, check if we did a whole-variable assignment. */
-   if (!ir->condition && (ir->whole_variable_written() != NULL)) {
-      /* We did a whole-variable assignment.  So, any instruction in
-       * the assignment list with the same LHS is dead.
-       */
-      if (debug)
-        printf("looking for %s to remove\n", var->name);
-      foreach_iter(exec_list_iterator, iter, *assignments) {
-        assignment_entry *entry = (assignment_entry *)iter.get();
+   if (!ir->condition) {
+      ir_dereference_variable *deref_var = ir->lhs->as_dereference_variable();
 
-        if (entry->lhs == var) {
-           if (debug)
-              printf("removing %s\n", var->name);
-           entry->ir->remove();
-           entry->remove();
-           progress = true;
+      /* If it's a vector type, we can do per-channel elimination of
+       * use of the RHS.
+       */
+      if (deref_var && (deref_var->var->type->is_scalar() ||
+                       deref_var->var->type->is_vector())) {
+
+        if (debug)
+           printf("looking for %s.0x%01x to remove\n", var->name,
+                  ir->write_mask);
+
+        foreach_iter(exec_list_iterator, iter, *assignments) {
+           assignment_entry *entry = (assignment_entry *)iter.get();
+
+           if (entry->lhs != var)
+              continue;
+
+           int remove = entry->available & ir->write_mask;
+           if (debug) {
+              printf("%s 0x%01x - 0x%01x = 0x%01x\n",
+                     var->name,
+                     entry->ir->write_mask,
+                     remove, entry->ir->write_mask & ~remove);
+           }
+           if (remove) {
+              progress = true;
+
+              if (debug) {
+                 printf("rewriting:\n  ");
+                 entry->ir->print();
+                 printf("\n");
+              }
+
+              entry->ir->write_mask &= ~remove;
+              entry->available &= ~remove;
+              if (entry->ir->write_mask == 0) {
+                 /* Delete the dead assignment. */
+                 entry->ir->remove();
+                 entry->remove();
+              } else {
+                 void *mem_ctx = ralloc_parent(entry->ir);
+                 /* Reswizzle the RHS arguments according to the new
+                  * write_mask.
+                  */
+                 unsigned components[4];
+                 unsigned channels = 0;
+                 unsigned next = 0;
+
+                 for (int i = 0; i < 4; i++) {
+                    if ((entry->ir->write_mask | remove) & (1 << i)) {
+                       if (!(remove & (1 << i)))
+                          components[channels++] = next;
+                       next++;
+                    }
+                 }
+
+                 entry->ir->rhs = new(mem_ctx) ir_swizzle(entry->ir->rhs,
+                                                          components,
+                                                          channels);
+                 if (debug) {
+                    printf("to:\n  ");
+                    entry->ir->print();
+                    printf("\n");
+                 }
+              }
+           }
+        }
+      } else if (ir->whole_variable_written() != NULL) {
+        /* We did a whole-variable assignment.  So, any instruction in
+         * the assignment list with the same LHS is dead.
+         */
+        if (debug)
+           printf("looking for %s to remove\n", var->name);
+        foreach_iter(exec_list_iterator, iter, *assignments) {
+           assignment_entry *entry = (assignment_entry *)iter.get();
+
+           if (entry->lhs == var) {
+              if (debug)
+                 printf("removing %s\n", var->name);
+              entry->ir->remove();
+              entry->remove();
+              progress = true;
+           }
         }
       }
    }
@@ -160,7 +263,7 @@ process_assignment(void *ctx, ir_assignment *ir, exec_list *assignments)
       foreach_iter(exec_list_iterator, iter, *assignments) {
         assignment_entry *entry = (assignment_entry *)iter.get();
 
-        printf("    %s\n", entry->lhs->name);
+        printf("    %s (0x%01x)\n", entry->lhs->name, entry->available);
       }
    }