nir: Copy "patch" flag from ir_variable to nir_variable.
[mesa.git] / src / glsl / nir / nir_algebraic.py
index 72703beea8a2eb8693a9bb8393e1d17a17e8aaf4..bbf4f08ef92856531c343fa67a692c34a06cf2cb 100644 (file)
@@ -28,19 +28,25 @@ import itertools
 import struct
 import sys
 import mako.template
+import re
 
 # Represents a set of variables, each with a unique id
 class VarSet(object):
    def __init__(self):
       self.names = {}
       self.ids = itertools.count()
+      self.immutable = False;
 
    def __getitem__(self, name):
       if name not in self.names:
+         assert not self.immutable, "Unknown replacement variable: " + name
          self.names[name] = self.ids.next()
 
       return self.names[name]
 
+   def lock(self):
+      self.immutable = True
+
 class Value(object):
    @staticmethod
    def create(val, name_base, varset):
@@ -60,6 +66,8 @@ static const ${val.c_type} ${val.name} = {
    { ${hex(val)} /* ${val.value} */ },
 % elif isinstance(val, Variable):
    ${val.index}, /* ${val.var_name} */
+   ${'true' if val.is_constant else 'false'},
+   nir_type_${ val.required_type or 'invalid' },
 % elif isinstance(val, Expression):
    nir_op_${val.opcode},
    { ${', '.join(src.c_ptr for src in val.sources)} },
@@ -106,12 +114,23 @@ class Constant(Value):
       else:
          assert False
 
+_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)(?:@(?P<type>\w+))?")
+
 class Variable(Value):
    def __init__(self, val, name, varset):
       Value.__init__(self, name, "variable")
-      self.var_name = val
-      self.index = varset[val]
-      self.name = name
+
+      m = _var_name_re.match(val)
+      assert m and m.group('name') is not None
+
+      self.var_name = m.group('name')
+      self.is_constant = m.group('const') is not None
+      self.required_type = m.group('type')
+
+      if self.required_type is not None:
+         assert self.required_type in ('float', 'bool', 'int', 'unsigned')
+
+      self.index = varset[self.var_name]
 
 class Expression(Value):
    def __init__(self, expr, name_base, varset):
@@ -128,16 +147,31 @@ class Expression(Value):
 
 _optimization_ids = itertools.count()
 
+condition_list = ['true']
+
 class SearchAndReplace(object):
-   def __init__(self, search, replace):
+   def __init__(self, transform):
       self.id = _optimization_ids.next()
 
+      search = transform[0]
+      replace = transform[1]
+      if len(transform) > 2:
+         self.condition = transform[2]
+      else:
+         self.condition = 'true'
+
+      if self.condition not in condition_list:
+         condition_list.append(self.condition)
+      self.condition_index = condition_list.index(self.condition)
+
       varset = VarSet()
       if isinstance(search, Expression):
          self.search = search
       else:
          self.search = Expression(search, "search{0}".format(self.id), varset)
 
+      varset.lock()
+
       if isinstance(replace, Value):
          self.replace = replace
       else:
@@ -147,32 +181,36 @@ _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
 #include "nir_search.h"
 
+#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
+#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
+
 struct transform {
    const nir_search_expression *search;
    const nir_search_value *replace;
+   unsigned condition_offset;
+};
+
+struct opt_state {
+   void *mem_ctx;
+   bool progress;
+   const bool *condition_flags;
 };
 
+#endif
+
 % for (opcode, xform_list) in xform_dict.iteritems():
 % for xform in xform_list:
    ${xform.search.render()}
    ${xform.replace.render()}
 % endfor
 
-static const struct {
-   const nir_search_expression *search;
-   const nir_search_value *replace;
-} ${pass_name}_${opcode}_xforms[] = {
+static const struct transform ${pass_name}_${opcode}_xforms[] = {
 % for xform in xform_list:
-   { &${xform.search.name}, ${xform.replace.c_ptr} },
+   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
 % endfor
 };
 % endfor
 
-struct opt_state {
-   void *mem_ctx;
-   bool progress;
-};
-
 static bool
 ${pass_name}_block(nir_block *block, void *void_state)
 {
@@ -190,10 +228,13 @@ ${pass_name}_block(nir_block *block, void *void_state)
       % for opcode in xform_dict.keys():
       case nir_op_${opcode}:
          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
-            if (nir_replace_instr(alu, ${pass_name}_${opcode}_xforms[i].search,
-                                  ${pass_name}_${opcode}_xforms[i].replace,
-                                  state->mem_ctx))
+            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
+            if (state->condition_flags[xform->condition_offset] &&
+                nir_replace_instr(alu, xform->search, xform->replace,
+                                  state->mem_ctx)) {
                state->progress = true;
+               break;
+            }
          }
          break;
       % endfor
@@ -206,12 +247,13 @@ ${pass_name}_block(nir_block *block, void *void_state)
 }
 
 static bool
-${pass_name}_impl(nir_function_impl *impl)
+${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
 {
    struct opt_state state;
 
    state.mem_ctx = ralloc_parent(impl);
    state.progress = false;
+   state.condition_flags = condition_flags;
 
    nir_foreach_block(impl, ${pass_name}_block, &state);
 
@@ -222,14 +264,21 @@ ${pass_name}_impl(nir_function_impl *impl)
    return state.progress;
 }
 
+
 bool
 ${pass_name}(nir_shader *shader)
 {
    bool progress = false;
+   bool condition_flags[${len(condition_list)}];
+   const nir_shader_compiler_options *options = shader->options;
+
+   % for index, condition in enumerate(condition_list):
+   condition_flags[${index}] = ${condition};
+   % endfor
 
    nir_foreach_overload(shader, overload) {
       if (overload->impl)
-         progress |= ${pass_name}_impl(overload->impl);
+         progress |= ${pass_name}_impl(overload->impl, condition_flags);
    }
 
    return progress;
@@ -243,7 +292,7 @@ class AlgebraicPass(object):
 
       for xform in transforms:
          if not isinstance(xform, SearchAndReplace):
-            xform = SearchAndReplace(*xform)
+            xform = SearchAndReplace(xform)
 
          if xform.search.opcode not in self.xform_dict:
             self.xform_dict[xform.search.opcode] = []
@@ -252,4 +301,5 @@ class AlgebraicPass(object):
 
    def render(self):
       return _algebraic_pass_template.render(pass_name=self.pass_name,
-                                             xform_dict=self.xform_dict)
+                                             xform_dict=self.xform_dict,
+                                             condition_list=condition_list)