#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)
from __future__ import print_function

import ast
import itertools
import re
import struct
import sys
import traceback
from collections import OrderedDict

import mako.template

from nir_opcodes import opcodes
# Python 2/3 compatibility: the integer and text types accepted when parsing
# transformation tuples differ between interpreter major versions.
if sys.version_info < (3, 0):
    integer_types = (int, long)  # noqa: F821 (Python 2 builtin)
    string_type = unicode  # noqa: F821 (Python 2 builtin)
else:
    integer_types = (int, )
    string_type = str
# Matches a NIR ALU type name such as "float32", "uint" or "bool": an optional
# base type followed by an optional explicit bit size.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the explicit bit size encoded in a NIR type name.

    "float32" gives 32; an unsized name such as "int" gives 0, the value the
    bit-size validator uses to mean "unsized".
    """
    m = _type_re.match(type_str)
    # Every group in _type_re is optional, so match() always succeeds; the
    # meaningful check is that a recognised base type was found.
    assert m.group('type'), "Unknown NIR type: " + type_str

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Map variable names to dense integer indices.

    Indices are handed out in first-use order starting at 0. After lock() is
    called, looking up a name that was never seen is an error. This is how an
    unknown variable in a replacement expression is rejected.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the index for ``name``, allocating one if still unlocked."""
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid allocation of new variable indices."""
        self.immutable = True
class Value(object):
    """Base class for one node of a search or replace pattern.

    Each node is emitted as a ``static const`` C object by render().
    Subclasses pass a type tag ("constant", "variable" or "expression"). The
    tag selects both the ``nir_search_<tag>`` struct type and the
    ``nir_search_value_<tag>`` enumerator written into that struct.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that matches the Python literal ``val``.

        Tuples become Expressions and strings become Variables. Bools, floats
        and integers become Constants. An existing Expression is returned
        unchanged. ``name_base`` names the generated C object, and ``varset``
        assigns the variable indices.

        Raises TypeError if ``val`` is none of the supported kinds.
        """
        if isinstance(val, bytes):
            # Normalise byte strings so they are matched as text below.
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        # Fail here rather than returning None. A None source would otherwise
        # surface later as an unrelated error during validation or rendering.
        raise TypeError("Cannot create a search value from {0!r}".format(val))

    # Mako template for the C initializer of a single value; the branch taken
    # depends on the concrete subclass of ``val``.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        # name: C identifier of the generated object.
        # type_str: subclass tag used to build the C type/enum names.
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """The ``nir_search_value_*`` enumerator for this value's kind."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """The C struct type of the generated object."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression taking the address of the object's ``value`` member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this value."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
# A constant literal with an optional explicit bit size, e.g. "1.0@64".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal boolean, integer or float value in a pattern."""

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if isinstance(val, (str)):
            m = _constant_re.match(val)
            # Matches the malformed-input checks in Variable and Expression;
            # without it a bad string fails with an opaque AttributeError.
            assert m is not None, "Invalid constant: " + val
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            # Booleans are only ever 32-bit here (see nir_type_bool32 below).
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant's value."""
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Emit the raw bit pattern of the value packed as a C double.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            raise TypeError(
                "Unsupported constant value: {0!r}".format(self.value))

    def type(self):
        """Return the nir_alu_type name for this constant, or None."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool32"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
# A variable reference: an optional '#' (sets is_constant), the name, an
# optional "@<type><bits>" requirement and an optional "(cond)" C predicate.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named pattern variable such as "a", "#b" or "c@float32"."""

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        if self.required_type == 'bool':
            # Booleans are only ever 32-bit here (see nir_type_bool32 below).
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Same name within one transformation -> same index.
        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_alu_type required of this variable, or None."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
# An opcode: optional '~' (inexact), the opcode name, an optional "@<bits>"
# destination size and an optional "(cond)" C predicate.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU operation node: ``(opcode, src0, src1, ...)``."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        # Sources are named "<name_base>_<i>" so every generated C object in
        # a pattern gets a unique identifier.
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Return the C definitions of all sources followed by this node."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent. Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps a non-canonical integer to a strictly larger equivalent one;
        # canonical integers never appear as keys.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        if x in self._remap:
            return self.get_canonical(self._remap[x])
        else:
            return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Redirect the classes' canonical representatives, not a and b
        # themselves. Remapping a non-canonical operand would detach it from
        # the rest of its class and break transitivity.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match. They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit. There are, however, cases where this
    inference can be ambiguous or contradictory. Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer. However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value. As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code. This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle. Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer. A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size. The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable. If it ever comes across an inconsistency, it
    assert-fails. Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # One bit class per variable index in ``varset`` (0 = not yet known).
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that ``replace`` is well-sized wherever ``search`` matches."""
        # Stage 1: infer classes over the search pattern (up, then down).
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Stage 2: check the replacement against the inferred classes.
        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh, still-unsized (negative) bit class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            # A variable already pinned to an actual size may not be given a
            # different class.
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Return the bit size of ``val`` implied by explicit sizes (0 = unknown).

        Expressions also get ``common_size``, the size shared by their unsized
        sources and destination.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits
                else:
                    # All unsized sources must agree (rule 2).
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push ``bit_class`` into ``val``, assigning classes to variables."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Return the bit class of replacement ``val`` (0 = unconstrained).

        Expressions also get ``common_class``, the class shared by their
        unsized sources and destination.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check that replacement ``val`` can be built with ``bit_class``."""
        # At this point, everything *must* have a bit class. Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
485 _optimization_ids
= itertools
.count()
487 condition_list
= ['true']
class SearchAndReplace(object):
    """One transformation: a ``(search, replace[, condition])`` tuple.

    The search pattern is parsed first and defines the variables. The
    variable set is then locked, so the replacement may only reference those
    variables. Finally the pair is checked by BitSizeValidator.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        # Share one slot per distinct condition across all transformations.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # Any variable first seen in the replacement is an error.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
# Mako template for the whole generated C pass. It emits one transform table
# per opcode, a per-block worker that tries each enabled transform on every
# matching ALU instruction, and a shader-level entry point that evaluates the
# condition list once into condition_flags[].
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
class AlgebraicPass(object):
    """A named collection of transformations rendered as one C pass.

    Transformations are grouped by the opcode at the root of their search
    pattern, in insertion order. Every transformation that fails to parse is
    reported on stderr, and then the process exits with status 1.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Keep going so that every bad transformation is reported
                    # in one run. Exception (not a bare except) lets Ctrl-C
                    # and SystemExit through.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)