3 # Copyright (C) 2014 Intel Corporation
5 # Permission is hereby granted, free of charge, to any person obtaining a
6 # copy of this software and associated documentation files (the "Software"),
7 # to deal in the Software without restriction, including without limitation
8 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 # and/or sell copies of the Software, and to permit persons to whom the
10 # Software is furnished to do so, subject to the following conditions:
12 # The above copyright notice and this permission notice (including the next
13 # paragraph) shall be included in all copies or substantial portions of the
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
25 # Jason Ekstrand (jason@jlekstrand.net)
27 from __future__
import print_function
36 from nir_opcodes
import opcodes
38 _type_re
= re
.compile(r
"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
40 def type_bits(type_str
):
41 m
= _type_re
.match(type_str
)
42 assert m
.group('type')
44 if m
.group('bits') is None:
47 return int(m
.group('bits'))
49 # Represents a set of variables, each with a unique id
53 self
.ids
= itertools
.count()
54 self
.immutable
= False;
56 def __getitem__(self
, name
):
57 if name
not in self
.names
:
58 assert not self
.immutable
, "Unknown replacement variable: " + name
59 self
.names
[name
] = self
.ids
.next()
61 return self
.names
[name
]
68 def create(val
, name_base
, varset
):
69 if isinstance(val
, tuple):
70 return Expression(val
, name_base
, varset
)
71 elif isinstance(val
, Expression
):
73 elif isinstance(val
, (str, unicode)):
74 return Variable(val
, name_base
, varset
)
75 elif isinstance(val
, (bool, int, long, float)):
76 return Constant(val
, name_base
)
78 __template
= mako
.template
.Template("""
79 #include "compiler/nir/nir_search_helpers.h"
80 static const ${val.c_type} ${val.name} = {
81 { ${val.type_enum}, ${val.bit_size} },
82 % if isinstance(val, Constant):
83 ${val.type()}, { ${hex(val)} /* ${val.value} */ },
84 % elif isinstance(val, Variable):
85 ${val.index}, /* ${val.var_name} */
86 ${'true' if val.is_constant else 'false'},
87 ${val.type() or 'nir_type_invalid' },
88 ${val.cond if val.cond else 'NULL'},
89 % elif isinstance(val, Expression):
90 ${'true' if val.inexact else 'false'},
92 { ${', '.join(src.c_ptr for src in val.sources)} },
93 ${val.cond if val.cond else 'NULL'},
97 def __init__(self
, name
, type_str
):
99 self
.type_str
= type_str
103 return "nir_search_value_" + self
.type_str
107 return "nir_search_" + self
.type_str
111 return "&{0}.value".format(self
.name
)
114 return self
.__template
.render(val
=self
,
117 Expression
=Expression
)
119 _constant_re
= re
.compile(r
"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
121 class Constant(Value
):
122 def __init__(self
, val
, name
):
123 Value
.__init
__(self
, name
, "constant")
125 if isinstance(val
, (str)):
126 m
= _constant_re
.match(val
)
127 self
.value
= ast
.literal_eval(m
.group('value'))
128 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
133 if isinstance(self
.value
, bool):
134 assert self
.bit_size
== 0 or self
.bit_size
== 32
138 if isinstance(self
.value
, (bool)):
139 return 'NIR_TRUE' if self
.value
else 'NIR_FALSE'
140 if isinstance(self
.value
, (int, long)):
141 return hex(self
.value
)
142 elif isinstance(self
.value
, float):
143 return hex(struct
.unpack('Q', struct
.pack('d', self
.value
))[0])
148 if isinstance(self
.value
, (bool)):
149 return "nir_type_bool32"
150 elif isinstance(self
.value
, (int, long)):
151 return "nir_type_int"
152 elif isinstance(self
.value
, float):
153 return "nir_type_float"
155 _var_name_re
= re
.compile(r
"(?P<const>#)?(?P<name>\w+)"
156 r
"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
157 r
"(?P<cond>\([^\)]+\))?")
159 class Variable(Value
):
160 def __init__(self
, val
, name
, varset
):
161 Value
.__init
__(self
, name
, "variable")
163 m
= _var_name_re
.match(val
)
164 assert m
and m
.group('name') is not None
166 self
.var_name
= m
.group('name')
167 self
.is_constant
= m
.group('const') is not None
168 self
.cond
= m
.group('cond')
169 self
.required_type
= m
.group('type')
170 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
172 if self
.required_type
== 'bool':
173 assert self
.bit_size
== 0 or self
.bit_size
== 32
176 if self
.required_type
is not None:
177 assert self
.required_type
in ('float', 'bool', 'int', 'uint')
179 self
.index
= varset
[self
.var_name
]
182 if self
.required_type
== 'bool':
183 return "nir_type_bool32"
184 elif self
.required_type
in ('int', 'uint'):
185 return "nir_type_int"
186 elif self
.required_type
== 'float':
187 return "nir_type_float"
189 _opcode_re
= re
.compile(r
"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
190 r
"(?P<cond>\([^\)]+\))?")
192 class Expression(Value
):
193 def __init__(self
, expr
, name_base
, varset
):
194 Value
.__init
__(self
, name_base
, "expression")
195 assert isinstance(expr
, tuple)
197 m
= _opcode_re
.match(expr
[0])
198 assert m
and m
.group('opcode') is not None
200 self
.opcode
= m
.group('opcode')
201 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
202 self
.inexact
= m
.group('inexact') is not None
203 self
.cond
= m
.group('cond')
204 self
.sources
= [ Value
.create(src
, "{0}_{1}".format(name_base
, i
), varset
)
205 for (i
, src
) in enumerate(expr
[1:]) ]
208 srcs
= "\n".join(src
.render() for src
in self
.sources
)
209 return srcs
+ super(Expression
, self
).render()
211 class IntEquivalenceRelation(object):
212 """A class representing an equivalence relation on integers.
214 Each integer has a canonical form which is the maximum integer to which it
215 is equivalent. Two integers are equivalent precisely when they have the
218 The convention of maximum is explicitly chosen to make using it in
219 BitSizeValidator easier because it means that an actual bit_size (if any)
220 will always be the canonical form.
225 def get_canonical(self
, x
):
226 """Get the canonical integer corresponding to x."""
228 return self
.get_canonical(self
._remap
[x
])
232 def add_equiv(self
, a
, b
):
233 """Add an equivalence and return the canonical form."""
234 c
= max(self
.get_canonical(a
), self
.get_canonical(b
))
245 class BitSizeValidator(object):
246 """A class for validating bit sizes of expressions.
248 NIR supports multiple bit-sizes on expressions in order to handle things
249 such as fp64. The source and destination of every ALU operation is
250 assigned a type and that type may or may not specify a bit size. Sources
251 and destinations whose type does not specify a bit size are considered
252 "unsized" and automatically take on the bit size of the corresponding
253 register or SSA value. NIR has two simple rules for bit sizes that are
254 validated by nir_validator:
256 1) A given SSA def or register has a single bit size that is respected by
257 everything that reads from it or writes to it.
259 2) The bit sizes of all unsized inputs/outputs on any given ALU
260 instruction must match. They need not match the sized inputs or
261 outputs but they must match each other.
263 In order to keep nir_algebraic relatively simple and easy-to-use,
264 nir_search supports a type of bit-size inference based on the two rules
265 above. This is similar to type inference in many common programming
266 languages. If, for instance, you are constructing an add operation and you
267 know the second source is 16-bit, then you know that the other source and
268 the destination must also be 16-bit. There are, however, cases where this
269 inference can be ambiguous or contradictory. Consider, for instance, the
270 following transformation:
272 (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
274 This transformation can potentially cause a problem because usub_borrow is
275 well-defined for any bit-size of integer. However, b2i always generates a
276 32-bit result so it could end up replacing a 64-bit expression with one
277 that takes two 64-bit values and produces a 32-bit value. As another
278 example, consider this expression:
280 (('bcsel', a, b, 0), ('iand', a, b))
282 In this case, in the search expression a must be 32-bit but b can
283 potentially have any bit size. If we had a 64-bit b value, we would end up
284 trying to and a 32-bit value with a 64-bit value which would be invalid
286 This class solves that problem by providing a validation layer that proves
287 that a given search-and-replace operation is 100% well-defined before we
288 generate any code. This ensures that bugs are caught at compile time
289 rather than at run time.
291 The basic operation of the validator is very similar to the bitsize_tree in
292 nir_search only a little more subtle. Instead of simply tracking bit
293 sizes, it tracks "bit classes" where each class is represented by an
294 integer. A value of 0 means we don't know anything yet, positive values
295 are actual bit-sizes, and negative values are used to track equivalence
296 classes of sizes that must be the same but have yet to receive an actual
297 size. The first stage uses the bitsize_tree algorithm to assign bit
298 classes to each variable. If it ever comes across an inconsistency, it
299 assert-fails. Then the second stage uses that information to prove that
300 the resulting expression can always validly be constructed.
303 def __init__(self
, varset
):
304 self
._num
_classes
= 0
305 self
._var
_classes
= [0] * len(varset
.names
)
306 self
._class
_relation
= IntEquivalenceRelation()
308 def validate(self
, search
, replace
):
309 dst_class
= self
._propagate
_bit
_size
_up
(search
)
311 dst_class
= self
._new
_class
()
312 self
._propagate
_bit
_class
_down
(search
, dst_class
)
314 validate_dst_class
= self
._validate
_bit
_class
_up
(replace
)
315 assert validate_dst_class
== 0 or validate_dst_class
== dst_class
316 self
._validate
_bit
_class
_down
(replace
, dst_class
)
318 def _new_class(self
):
319 self
._num
_classes
+= 1
320 return -self
._num
_classes
322 def _set_var_bit_class(self
, var_id
, bit_class
):
323 assert bit_class
!= 0
324 var_class
= self
._var
_classes
[var_id
]
326 self
._var
_classes
[var_id
] = bit_class
328 canon_class
= self
._class
_relation
.get_canonical(var_class
)
329 assert canon_class
< 0 or canon_class
== bit_class
330 var_class
= self
._class
_relation
.add_equiv(var_class
, bit_class
)
331 self
._var
_classes
[var_id
] = var_class
333 def _get_var_bit_class(self
, var_id
):
334 return self
._class
_relation
.get_canonical(self
._var
_classes
[var_id
])
336 def _propagate_bit_size_up(self
, val
):
337 if isinstance(val
, (Constant
, Variable
)):
340 elif isinstance(val
, Expression
):
341 nir_op
= opcodes
[val
.opcode
]
343 for i
in range(nir_op
.num_inputs
):
344 src_bits
= self
._propagate
_bit
_size
_up
(val
.sources
[i
])
348 src_type_bits
= type_bits(nir_op
.input_types
[i
])
349 if src_type_bits
!= 0:
350 assert src_bits
== src_type_bits
352 assert val
.common_size
== 0 or src_bits
== val
.common_size
353 val
.common_size
= src_bits
355 dst_type_bits
= type_bits(nir_op
.output_type
)
356 if dst_type_bits
!= 0:
357 assert val
.bit_size
== 0 or val
.bit_size
== dst_type_bits
360 if val
.common_size
!= 0:
361 assert val
.bit_size
== 0 or val
.bit_size
== val
.common_size
363 val
.common_size
= val
.bit_size
364 return val
.common_size
366 def _propagate_bit_class_down(self
, val
, bit_class
):
367 if isinstance(val
, Constant
):
368 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
370 elif isinstance(val
, Variable
):
371 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
372 self
._set
_var
_bit
_class
(val
.index
, bit_class
)
374 elif isinstance(val
, Expression
):
375 nir_op
= opcodes
[val
.opcode
]
376 dst_type_bits
= type_bits(nir_op
.output_type
)
377 if dst_type_bits
!= 0:
378 assert bit_class
== 0 or bit_class
== dst_type_bits
380 assert val
.common_size
== 0 or val
.common_size
== bit_class
381 val
.common_size
= bit_class
384 common_class
= val
.common_size
385 elif nir_op
.num_inputs
:
386 # If we got here then we have no idea what the actual size is.
387 # Instead, we use a generic class
388 common_class
= self
._new
_class
()
390 for i
in range(nir_op
.num_inputs
):
391 src_type_bits
= type_bits(nir_op
.input_types
[i
])
392 if src_type_bits
!= 0:
393 self
._propagate
_bit
_class
_down
(val
.sources
[i
], src_type_bits
)
395 self
._propagate
_bit
_class
_down
(val
.sources
[i
], common_class
)
397 def _validate_bit_class_up(self
, val
):
398 if isinstance(val
, Constant
):
401 elif isinstance(val
, Variable
):
402 var_class
= self
._get
_var
_bit
_class
(val
.index
)
403 # By the time we get to validation, every variable should have a class
404 assert var_class
!= 0
406 # If we have an explicit size provided by the user, the variable
407 # *must* exactly match the search. It cannot be implicitly sized
408 # because otherwise we could end up with a conflict at runtime.
409 assert val
.bit_size
== 0 or val
.bit_size
== var_class
413 elif isinstance(val
, Expression
):
414 nir_op
= opcodes
[val
.opcode
]
416 for i
in range(nir_op
.num_inputs
):
417 src_class
= self
._validate
_bit
_class
_up
(val
.sources
[i
])
421 src_type_bits
= type_bits(nir_op
.input_types
[i
])
422 if src_type_bits
!= 0:
423 assert src_class
== src_type_bits
425 assert val
.common_class
== 0 or src_class
== val
.common_class
426 val
.common_class
= src_class
428 dst_type_bits
= type_bits(nir_op
.output_type
)
429 if dst_type_bits
!= 0:
430 assert val
.bit_size
== 0 or val
.bit_size
== dst_type_bits
433 if val
.common_class
!= 0:
434 assert val
.bit_size
== 0 or val
.bit_size
== val
.common_class
436 val
.common_class
= val
.bit_size
437 return val
.common_class
439 def _validate_bit_class_down(self
, val
, bit_class
):
440 # At this point, everything *must* have a bit class. Otherwise, we have
441 # a value we don't know how to define.
442 assert bit_class
!= 0
444 if isinstance(val
, Constant
):
445 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
447 elif isinstance(val
, Variable
):
448 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
450 elif isinstance(val
, Expression
):
451 nir_op
= opcodes
[val
.opcode
]
452 dst_type_bits
= type_bits(nir_op
.output_type
)
453 if dst_type_bits
!= 0:
454 assert bit_class
== dst_type_bits
456 assert val
.common_class
== 0 or val
.common_class
== bit_class
457 val
.common_class
= bit_class
459 for i
in range(nir_op
.num_inputs
):
460 src_type_bits
= type_bits(nir_op
.input_types
[i
])
461 if src_type_bits
!= 0:
462 self
._validate
_bit
_class
_down
(val
.sources
[i
], src_type_bits
)
464 self
._validate
_bit
_class
_down
(val
.sources
[i
], val
.common_class
)
466 _optimization_ids
= itertools
.count()
468 condition_list
= ['true']
470 class SearchAndReplace(object):
471 def __init__(self
, transform
):
472 self
.id = _optimization_ids
.next()
474 search
= transform
[0]
475 replace
= transform
[1]
476 if len(transform
) > 2:
477 self
.condition
= transform
[2]
479 self
.condition
= 'true'
481 if self
.condition
not in condition_list
:
482 condition_list
.append(self
.condition
)
483 self
.condition_index
= condition_list
.index(self
.condition
)
486 if isinstance(search
, Expression
):
489 self
.search
= Expression(search
, "search{0}".format(self
.id), varset
)
493 if isinstance(replace
, Value
):
494 self
.replace
= replace
496 self
.replace
= Value
.create(replace
, "replace{0}".format(self
.id), varset
)
498 BitSizeValidator(varset
).validate(self
.search
, self
.replace
)
500 _algebraic_pass_template
= mako
.template
.Template("""
502 #include "nir_search.h"
504 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
505 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
508 const nir_search_expression *search;
509 const nir_search_value *replace;
510 unsigned condition_offset;
515 % for (opcode, xform_list) in xform_dict.iteritems():
516 % for xform in xform_list:
517 ${xform.search.render()}
518 ${xform.replace.render()}
521 static const struct transform ${pass_name}_${opcode}_xforms[] = {
522 % for xform in xform_list:
523 { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
529 ${pass_name}_block(nir_block *block, const bool *condition_flags,
532 bool progress = false;
534 nir_foreach_instr_reverse_safe(instr, block) {
535 if (instr->type != nir_instr_type_alu)
538 nir_alu_instr *alu = nir_instr_as_alu(instr);
539 if (!alu->dest.dest.is_ssa)
543 % for opcode in xform_dict.keys():
544 case nir_op_${opcode}:
545 for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
546 const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
547 if (condition_flags[xform->condition_offset] &&
548 nir_replace_instr(alu, xform->search, xform->replace,
565 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
567 void *mem_ctx = ralloc_parent(impl);
568 bool progress = false;
570 nir_foreach_block_reverse(block, impl) {
571 progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
575 nir_metadata_preserve(impl, nir_metadata_block_index |
576 nir_metadata_dominance);
583 ${pass_name}(nir_shader *shader)
585 bool progress = false;
586 bool condition_flags[${len(condition_list)}];
587 const nir_shader_compiler_options *options = shader->options;
590 % for index, condition in enumerate(condition_list):
591 condition_flags[${index}] = ${condition};
594 nir_foreach_function(function, shader) {
596 progress |= ${pass_name}_impl(function->impl, condition_flags);
603 class AlgebraicPass(object):
604 def __init__(self
, pass_name
, transforms
):
606 self
.pass_name
= pass_name
610 for xform
in transforms
:
611 if not isinstance(xform
, SearchAndReplace
):
613 xform
= SearchAndReplace(xform
)
615 print("Failed to parse transformation:", file=sys
.stderr
)
616 print(" " + str(xform
), file=sys
.stderr
)
617 traceback
.print_exc(file=sys
.stderr
)
618 print('', file=sys
.stderr
)
622 if xform
.search
.opcode
not in self
.xform_dict
:
623 self
.xform_dict
[xform
.search
.opcode
] = []
625 self
.xform_dict
[xform
.search
.opcode
].append(xform
)
631 return _algebraic_pass_template
.render(pass_name
=self
.pass_name
,
632 xform_dict
=self
.xform_dict
,
633 condition_list
=condition_list
)