# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)
from __future__ import print_function

import ast
import itertools
import re
import struct
import sys
import traceback

import mako.template

from nir_opcodes import opcodes
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a NIR type string, or 0 if unsized.

    For example ``"float32"`` gives 32 and ``"int"`` gives 0.  The string must
    begin with one of the base types int/uint/bool/float.
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Assigns consecutive integer ids (0, 1, ...) to variable names.

    Looking up a name that was already seen returns its existing id.  After
    lock() is called, looking up a new name is an error.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating one if the set is unlocked.

        Raises AssertionError if ``name`` is unknown and the set is locked.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # next() works on both Python 2.6+ and Python 3 iterators, unlike
            # the Python-2-only iterator.next() method.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid allocating ids for any further names."""
        self.immutable = True
class Value(object):
    """Base class for the nodes of a search or replace pattern.

    Each value renders (via a Mako template) to a static C initializer of the
    matching ``nir_search_<type_str>`` struct named ``name``.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that represents one pattern element.

        :param val: a tuple (an expression), an existing Expression, a string
            (a variable reference), or a bool/int/long/float literal.
        :param name_base: C identifier for the generated struct.
        :param varset: VarSet that assigns indices to variable names.
        :raises TypeError: if ``val`` is none of the supported kinds.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, (str, unicode)):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, int, long, float)):
            return Constant(val, name_base)

        # Fail at the offending pattern element instead of returning None and
        # crashing later, far away, during validation or rendering.
        raise TypeError(
            "Unsupported value in algebraic pattern: {0!r}".format(val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
% endif
};""")

    def __init__(self, name, type_str):
        """:param name: C identifier of the generated struct.
        :param type_str: "constant", "variable" or "expression".
        """
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """The nir_search_value_* enumerant for this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """The C struct type used for this value's initializer."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression pointing at the embedded generic value header."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C initializer for this value (not its sources)."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
_constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal operand of a search or replace pattern.

    ``val`` is either a Python literal, or a string of the form
    ``"<python literal>[@<bits>]"`` parsed with ast.literal_eval.  Boolean
    constants are always 32 bits wide.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if not isinstance(val, (str)):
            self.value, self.bit_size = val, 0
        else:
            parsed = _constant_re.match(val)
            self.value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            self.bit_size = int(bits) if bits else 0

        if isinstance(self.value, bool):
            assert self.bit_size in (0, 32)
            self.bit_size = 32

    def __hex__(self):
        """Return the C literal for the constant's bits (used via hex())."""
        value = self.value
        # bool must be tested before int because bool subclasses int.
        if isinstance(value, (bool)):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, (int, long)):
            return hex(value)
        if isinstance(value, float):
            # Reinterpret the double's bytes as an unsigned 64-bit integer.
            (raw_bits,) = struct.unpack('Q', struct.pack('d', value))
            return hex(raw_bits)
        assert False

    def type(self):
        """Return the nir_alu_type name for the literal, or None."""
        value = self.value
        if isinstance(value, (bool)):
            return "nir_type_bool32"
        if isinstance(value, (int, long)):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")


class Variable(Value):
    """A named pattern variable such as ``a``, ``#b`` or ``c@float32``.

    A leading ``#`` sets ``is_constant``; an optional ``@<type><bits>`` suffix
    sets ``required_type`` and/or ``bit_size``.  Boolean variables are always
    32 bits wide.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        self.var_name = parsed.group('name')
        self.is_constant = parsed.group('const') is not None
        self.required_type = parsed.group('type')
        bits = parsed.group('bits')
        self.bit_size = int(bits) if bits else 0

        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Every occurrence of the same name gets the same index from varset.
        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_alu_type name for required_type, or None if unset."""
        return {
            'bool': "nir_type_bool32",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }.get(self.required_type)
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")


class Expression(Value):
    """An ALU expression node written as ``(opcode[@bits], src0, src1, ...)``.

    A leading ``~`` on the opcode sets ``inexact``.  Sources are named
    ``<name_base>_<i>`` and are built with Value.create.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        parsed = _opcode_re.match(expr[0])
        assert parsed and parsed.group('opcode') is not None

        self.opcode = parsed.group('opcode')
        bits = parsed.group('bits')
        self.bit_size = int(bits) if bits else 0
        self.inexact = parsed.group('inexact') is not None

        sources = []
        for index, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, index)
            sources.append(Value.create(src, src_name, varset))
        self.sources = sources

    def render(self):
        """Render all source structs first, then this expression's struct."""
        rendered_sources = [src.render() for src in self.sources]
        return "\n".join(rendered_sources) + super(Expression, self).render()
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps a non-canonical integer to a strictly larger equivalent one,
        # so following the chain always terminates at the canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the old canonical representatives (not a and b themselves):
        # remapping a or b directly would overwrite an existing link and
        # strand every other member of their classes.
        for canon in (canon_a, canon_b):
            if canon != c:
                assert canon < c
                self._remap[canon] = c

        return c
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  Classes are unified through an IntEquivalenceRelation, so a
    generic class may later turn out to equal an actual size.  The first
    stage uses the bitsize_tree algorithm to assign bit classes to each
    variable.  If it ever comes across an inconsistency, it assert-fails.
    Then the second stage uses that information to prove that the resulting
    expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # Bit class of each variable, indexed by the id its VarSet assigned;
        # 0 means nothing is known about it yet.
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that ``replace`` is well-sized wherever ``search`` matches.

        Raises AssertionError on any bit-size inconsistency.
        """
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Propagating down may have unified a generic destination class with
        # an actual size (or another class).  Variable lookups below return
        # canonical classes, so compare against the canonical form too.
        dst_class = self._class_relation.get_canonical(dst_class)

        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh generic (negative) bit class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` must have bit class ``bit_class``."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            canon_bit_class = self._class_relation.get_canonical(bit_class)
            # Two different actual sizes can never be reconciled, but a
            # generic (negative) class may be unified with anything.  Allowing
            # that in both directions keeps the result independent of the
            # order in which a variable's uses are visited.
            assert (canon_class < 0 or canon_bit_class < 0 or
                    canon_class == canon_bit_class)
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Infer val's bit size bottom-up from explicit sizes and opcode types.

        Returns the inferred size, or 0 when nothing is known.  For an
        Expression, also stores the size shared by its unsized sources and
        destination in ``val.common_size`` (0 if unknown).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits
                else:
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push ``bit_class`` down into ``val``, assigning variable classes."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Compute the bit class of a replace value bottom-up (0 if unknown).

        For an Expression, also stores the class shared by its unsized sources
        and destination in ``val.common_class``.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check top-down that every replace value gets a consistent class."""
        # At this point, everything *must* have a bit class.  Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
# Source of unique ids used to name each transform's generated C structs.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far.  Each transform stores the
# index of its condition here, and the generated pass evaluates each entry
# once into its condition_flags array.  Index 0 is always 'true'.
condition_list = ['true']


class SearchAndReplace(object):
    """One algebraic transform: a search pattern, its replacement and an
    optional C condition.

    ``transform`` is a tuple ``(search, replace[, condition])``.  Construction
    runs BitSizeValidator and raises AssertionError if the transform is not
    well-sized.
    """

    def __init__(self, transform):
        # next() works on Python 2.6+ and 3, unlike the Python-2-only
        # iterator.next() method.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            # NOTE(review): a pre-built Expression registered its variables
            # in some other VarSet, so this fresh varset stays empty on this
            # path; confirm no caller relies on passing an Expression here.
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only reference variables bound by the search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
# Mako template for the generated C pass.  It emits the search/replace
# structs and one transform table per opcode, a per-block worker that tries
# the transforms for each SSA ALU instruction's opcode, a per-impl driver,
# and the public ${pass_name}(nir_shader *) entry point.  The entry point
# evaluates every condition_list entry once into condition_flags.
#
# dict.items() is used instead of the Python-2-only iteritems() so the
# template also renders under Python 3.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

struct opt_state {
   void *mem_ctx;
   bool progress;
   const bool *condition_flags;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, void *void_state)
{
   struct opt_state *state = void_state;

   nir_foreach_instr_reverse_safe(block, instr) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (state->condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  state->mem_ctx)) {
               state->progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return true;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   struct opt_state state;

   state.mem_ctx = ralloc_parent(impl);
   state.progress = false;
   state.condition_flags = condition_flags;

   nir_foreach_block_reverse_call(impl, ${pass_name}_block, &state);

   if (state.progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return state.progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(shader, function) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
class AlgebraicPass(object):
    """Collects the transforms of one pass and renders its C source.

    :param pass_name: C function name of the generated pass.
    :param transforms: iterable of SearchAndReplace objects or raw
        ``(search, replace[, condition])`` tuples.

    Every transform that fails to parse is reported on stderr.  If any fail,
    the process exits with status 1.
    """

    def __init__(self, pass_name, transforms):
        # Maps a search opcode name to the transforms rooted at that opcode.
        self.xform_dict = {}
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                # Exception (not a bare except) so Ctrl-C / SystemExit are
                # not misreported as a malformed transformation.
                except Exception:
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)