2 # Copyright (C) 2014 Intel Corporation
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 # Jason Ekstrand (jason@jlekstrand.net)
26 from __future__
import print_function
28 from collections
import OrderedDict
36 from nir_opcodes
import opcodes
# Splits an ALU type name into an optional base type ("int", "uint", "bool",
# "float") followed by an optional decimal bit-size suffix.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a type string, or 0 if it is unsized.

    Parameters:
        type_str: a type name such as "float32", "uint64" or "bool".

    Returns:
        The integer bit-size suffix, or 0 when the string carries no suffix.

    Raises:
        AssertionError: if the string does not start with one of the
            recognised base types.
    """
    # Every group in the pattern is optional, so match() always succeeds;
    # the real validity check is whether a base type was captured.
    m = _type_re.match(type_str)
    assert m.group('type')

    bits = m.group('bits')
    if bits is None:
        # No size suffix: the type is unsized.
        return 0
    return int(bits)
# Represents a set of variables, each with a unique id
# NOTE(review): the class header is not visible in this excerpt; the name
# `VarSet` is assumed from the `varset` parameters used elsewhere - confirm.
class VarSet(object):
    """Map variable names to unique, sequentially allocated integer ids.

    Attributes:
        names: dict from variable name to its id.
        ids: counter that yields the next unused id.
        immutable: when True, looking up an unknown name is an error
            instead of allocating a new id.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for `name`, allocating one on first use.

        Raises:
            AssertionError: if `name` is unknown and the set is immutable.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Freeze the set: later lookups of unknown names will assert."""
        self.immutable = True
def create(val, name_base, varset):
    """Build the search-value object that corresponds to a Python literal.

    Parameters:
        val: a tuple (becomes an Expression), an existing Expression
            (returned unchanged), a string (becomes a Variable) or a
            bool/int/float (becomes a Constant).
        name_base: name given to the created value.
        varset: variable set handed to Expression/Variable so they can
            resolve variable names to ids.

    Returns:
        The created (or passed-through) value. Any other kind of `val`
        falls through every branch and yields None.
    """
    # NOTE(review): this is invoked as `Value.create(...)` without an
    # instance, so it takes no `self`; no decorator line is visible in this
    # excerpt - confirm it is bound as a static method.

    # `unicode` and `long` exist only on Python 2; fall back to the
    # Python 3 equivalents so the same script runs on both.
    try:
        string_types = (str, unicode)
        integer_types = (int, long)
    except NameError:
        string_types = (str,)
        integer_types = (int,)

    if isinstance(val, tuple):
        return Expression(val, name_base, varset)
    elif isinstance(val, Expression):
        # Already built: reuse it as-is.
        return val
    elif isinstance(val, string_types):
        return Variable(val, name_base, varset)
    elif isinstance(val, (bool, float) + integer_types):
        return Constant(val, name_base)
78 __template
= mako
.template
.Template("""
79 static const ${val.c_type} ${val.name} = {
80 { ${val.type_enum}, ${val.bit_size} },
81 % if isinstance(val, Constant):
82 ${val.type()}, { ${val.hex()} /* ${val.value} */ },
83 % elif isinstance(val, Variable):
84 ${val.index}, /* ${val.var_name} */
85 ${'true' if val.is_constant else 'false'},
86 ${val.type() or 'nir_type_invalid' },
87 ${val.cond if val.cond else 'NULL'},
88 % elif isinstance(val, Expression):
89 ${'true' if val.inexact else 'false'},
91 { ${', '.join(src.c_ptr for src in val.sources)} },
92 ${val.cond if val.cond else 'NULL'},
96 def __init__(self
, name
, type_str
):
98 self
.type_str
= type_str
102 return "nir_search_value_" + self
.type_str
106 return "nir_search_" + self
.type_str
110 return "&{0}.value".format(self
.name
)
113 return self
.__template
.render(val
=self
,
116 Expression
=Expression
)
# A constant written as a string: the literal value, optionally followed by
# "@<bits>" to pin its bit size.
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal bool/int/float operand of a search or replace expression.

    Attributes:
        value: the Python value of the constant.
        bit_size: explicit bit size, or 0 when none was given.
    """

    # `long` exists only on Python 2.
    try:
        _integer_types = (int, long)
    except NameError:
        _integer_types = (int,)

    def __init__(self, val, name):
        """Create a constant from a Python value or a "value@bits" string.

        Raises:
            AssertionError: if a string does not match the constant syntax,
                or a boolean is given a bit size other than 32.
        """
        Value.__init__(self, name, "constant")

        if isinstance(val, (str)):
            m = _constant_re.match(val)
            assert m is not None, "Malformed constant: " + val
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            # NOTE(review): booleans are emitted as nir_type_bool32, so the
            # size is pinned to 32 here; this assignment is not visible in
            # the excerpt - confirm.
            self.bit_size = 32

    def hex(self):
        """Return the C initialiser text for this constant's value."""
        # bool must be tested before the integer types: bool is an int.
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, self._integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the double's bytes as an unsigned 64-bit integer.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on
            # Python 3. Adding it explicitly makes the generated file
            # identical, regardless of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False

    def type(self):
        """Return the type enumerator name for this constant's value."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool32"
        elif isinstance(self.value, self._integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
# Variable syntax: optional "#" (must be a constant), the name, an optional
# "@<type><bits>" requirement, and an optional "(<condition>)" suffix.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named placeholder in a search or replace expression.

    Attributes:
        var_name: the variable's name.
        is_constant: True when the name was prefixed with "#".
        cond: the parenthesised condition text, or None.
        required_type: "int", "uint", "bool", "float", or None.
        bit_size: explicit bit size, or 0 when none was given.
        index: the id assigned to `var_name` by the variable set.
    """

    def __init__(self, val, name, varset):
        """Parse the variable string `val` and register it in `varset`.

        Raises:
            AssertionError: if `val` does not parse, a bool variable is
                given a size other than 32, or `varset` rejects the name.
        """
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        if self.required_type == 'bool':
            assert self.bit_size == 0 or self.bit_size == 32
            # NOTE(review): bool variables are emitted as nir_type_bool32, so
            # the size is pinned to 32; this assignment is not visible in
            # the excerpt - confirm.
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the required type's enumerator name, or None if untyped."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
# Opcode syntax: optional "~" (inexact), the opcode name, an optional
# "@<bits>" size and an optional "(<condition>)" suffix.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An operation applied to a list of source values.

    Attributes:
        opcode: the opcode name.
        bit_size: explicit bit size, or 0 when none was given.
        inexact: True when the opcode was prefixed with "~".
        cond: the parenthesised condition text, or None.
        sources: list of values built from the remaining tuple elements.
    """

    def __init__(self, expr, name_base, varset):
        """Build the expression tree for the tuple `expr`.

        `expr[0]` is the opcode string; every following element is turned
        into a source value named "<name_base>_<index>".

        Raises:
            AssertionError: if `expr` is not a tuple or its opcode string
                does not parse.
        """
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [
            Value.create(src, "{0}_{1}".format(name_base, i), varset)
            for (i, src) in enumerate(expr[1:])
        ]

    def render(self):
        """Return the rendered sources followed by this expression itself."""
        # Sources come first because this expression's rendered text
        # refers to them.
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which
    it is equivalent. Two integers are equivalent precisely when they have
    the same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps a non-canonical integer to a strictly larger equivalent one.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points to a strictly larger integer, so the chain is
        # acyclic and this loop terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        root_a = self.get_canonical(a)
        root_b = self.get_canonical(b)
        c = max(root_a, root_b)

        # Link the smaller *canonical* form to the larger one, so that
        # everything already equivalent to a or b stays in the merged class.
        if root_a != c:
            self._remap[root_a] = c
        if root_b != c:
            self._remap[root_b] = c

        return c
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected
        by everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match. They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and
    you know the second source is 16-bit, then you know that the other source
    and the destination must also be 16-bit. There are, however, cases where
    this inference can be ambiguous or contradictory. Consider, for instance,
    the following transformation:

        (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow
    is well-defined for any bit-size of integer. However, b2i always
    generates a 32-bit result so it could end up replacing a 64-bit
    expression with one that takes two 64-bit values and produces a 32-bit
    value. As another example, consider this expression:

        (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end
    up trying to and a 32-bit value with a 64-bit value, which would be
    invalid.

    This class solves that problem by providing a validation layer that
    proves that a given search-and-replace operation is 100% well-defined
    before we generate any code. This ensures that bugs are caught at compile
    time rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree
    in nir_search, only a little more subtle. Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer. A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size. The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable. If it ever comes across an inconsistency, it
    assert-fails. Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        """Create a validator with one (initially unknown) class per variable."""
        self._num_classes = 0
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that `replace` is well-sized wherever `search` matches.

        Raises:
            AssertionError: on any inconsistent or undetermined bit size.
        """
        # Stage one: learn the classes of all variables from the search side.
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            # The search does not fix the destination size; give it a
            # generic class of its own.
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Stage two: check the replacement against what was learned.
        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh generic bit class (a new negative integer)."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable `var_id` belongs to `bit_class`."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            # First time this variable is seen.
            self._var_classes[var_id] = bit_class
        else:
            # Seen before: the old class may be generic (negative), or must
            # already be the same actual size.
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable `var_id`."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Return the bit size implied by `val` and its sources (0 if none).

        For an expression this also stores the size shared by its unsized
        sources in `val.common_size`.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    # This source tells us nothing.
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    # Sized input: the source must be exactly that size.
                    assert src_bits == src_type_bits
                else:
                    # Unsized input: all of these must agree.
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push `bit_class` down through `val`, assigning variable classes."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            # Only read inside the loop below, which does not run when the
            # operation has no inputs.
            common_class = None
            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i],
                                                   src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i],
                                                   common_class)

    def _validate_bit_class_up(self, val):
        """Return the bit class implied by `val` in the replacement (0 if none).

        For an expression this also stores the class shared by its unsized
        sources in `val.common_class`.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a
            # class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert (val.common_class == 0 or
                            src_class == val.common_class)
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert (val.bit_size == 0 or
                            val.bit_size == val.common_class)
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check that `val` can be built with destination class `bit_class`."""
        # At this point, everything *must* have a bit class. Otherwise, we
        # have a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i],
                                                  src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i],
                                                  val.common_class)
# Source of the unique id given to each SearchAndReplace instance.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far; a transform refers to its
# condition by position in this list. 'true' is always entry 0.
condition_list = ['true']


class SearchAndReplace(object):
    """One algebraic transformation: a search pattern and its replacement.

    Attributes:
        id: unique integer id of this transformation.
        condition: condition string, 'true' when none was supplied.
        condition_index: position of `condition` in `condition_list`.
        search: the Expression to look for.
        replace: the value to substitute.
    """

    def __init__(self, transform):
        """Build a transformation from `(search, replace[, condition])`.

        Raises:
            AssertionError: if either pattern is malformed, the replacement
                names a variable absent from the search, or bit-size
                validation fails.
        """
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        # NOTE(review): the line creating `varset` is not visible in this
        # excerpt; the class name `VarSet` is assumed - confirm.
        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id),
                                     varset)

        # Freeze the variable set: the replacement may only use variables
        # that the search pattern introduced.
        varset.immutable = True

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace,
                                        "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
508 _algebraic_pass_template
= mako
.template
.Template("""
510 #include "nir_search.h"
511 #include "nir_search_helpers.h"
513 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
514 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
517 const nir_search_expression *search;
518 const nir_search_value *replace;
519 unsigned condition_offset;
524 % for (opcode, xform_list) in xform_dict.items():
525 % for xform in xform_list:
526 ${xform.search.render()}
527 ${xform.replace.render()}
530 static const struct transform ${pass_name}_${opcode}_xforms[] = {
531 % for xform in xform_list:
532 { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
538 ${pass_name}_block(nir_block *block, const bool *condition_flags,
541 bool progress = false;
543 nir_foreach_instr_reverse_safe(instr, block) {
544 if (instr->type != nir_instr_type_alu)
547 nir_alu_instr *alu = nir_instr_as_alu(instr);
548 if (!alu->dest.dest.is_ssa)
552 % for opcode in xform_dict.keys():
553 case nir_op_${opcode}:
554 for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
555 const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
556 if (condition_flags[xform->condition_offset] &&
557 nir_replace_instr(alu, xform->search, xform->replace,
574 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
576 void *mem_ctx = ralloc_parent(impl);
577 bool progress = false;
579 nir_foreach_block_reverse(block, impl) {
580 progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
584 nir_metadata_preserve(impl, nir_metadata_block_index |
585 nir_metadata_dominance);
592 ${pass_name}(nir_shader *shader)
594 bool progress = false;
595 bool condition_flags[${len(condition_list)}];
596 const nir_shader_compiler_options *options = shader->options;
599 % for index, condition in enumerate(condition_list):
600 condition_flags[${index}] = ${condition};
603 nir_foreach_function(function, shader) {
605 progress |= ${pass_name}_impl(function->impl, condition_flags);
class AlgebraicPass(object):
    """A named collection of transformations, grouped by search opcode.

    Attributes:
        xform_dict: ordered mapping from opcode name to the list of
            transformations whose search expression has that opcode.
        pass_name: the name given to the pass.
    """

    def __init__(self, pass_name, transforms):
        """Parse `transforms` and group them by opcode.

        Each entry is either a SearchAndReplace or the raw tuple accepted by
        its constructor. Every entry that fails to parse is reported on
        stderr with a traceback; if any failed, the process exits with
        status 1 once all entries have been tried.
        """
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    # Keep going so every bad transform is reported, but do
                    # not touch `xform.search`: xform is still the raw input.
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            # NOTE(review): the handling after the error report is not
            # visible in this excerpt; aborting is assumed - confirm.
            sys.exit(1)

    # NOTE(review): this method's header is not visible in this excerpt; the
    # name `render` is assumed - confirm.
    def render(self):
        """Return the pass rendered through `_algebraic_pass_template`."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)