+ if self.opcode in conv_opcode_types:
+ assert self._bit_size is None, \
+ 'Expression cannot use an unsized conversion opcode with ' \
+ 'an explicit size; that\'s silly.'
+
+ self.__index_comm_exprs(0)
+
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+ This is check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ This implementation does not check for equivalence due to commutativity,
+ but it could.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.sources) != len(other.sources):
+ return False
+
+ if self.opcode != other.opcode:
+ return False
+
+ return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
+ def __index_comm_exprs(self, base_idx):
+ """Recursively count and index commutative expressions
+ """
+ self.comm_exprs = 0
+
+ # A note about the explicit "len(self.sources)" check. The list of
+ # sources comes from user input, and that input might be bad. Check
+ # that the expected second source exists before accessing it. Without
+ # this check, a unit test that does "('iadd', 'a')" will crash.
+ if self.opcode not in conv_opcode_types and \
+ "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+ len(self.sources) >= 2 and \
+ not self.sources[0].equivalent(self.sources[1]):
+ self.comm_expr_idx = base_idx
+ self.comm_exprs += 1
+ else:
+ self.comm_expr_idx = -1
+
+ for s in self.sources:
+ if isinstance(s, Expression):
+ s.__index_comm_exprs(base_idx + self.comm_exprs)
+ self.comm_exprs += s.comm_exprs
+
+ return self.comm_exprs
+
+ def c_opcode(self):
+ return get_c_opcode(self.opcode)
+
+ def render(self, cache):
+ srcs = "\n".join(src.render(cache) for src in self.sources)
+ return srcs + super(Expression, self).render(cache)
+
+class BitSizeValidator(object):
+ """A class for validating bit sizes of expressions.
+
+ NIR supports multiple bit-sizes on expressions in order to handle things
+ such as fp64. The source and destination of every ALU operation is
+ assigned a type and that type may or may not specify a bit size. Sources
+ and destinations whose type does not specify a bit size are considered
+ "unsized" and automatically take on the bit size of the corresponding
+ register or SSA value. NIR has two simple rules for bit sizes that are
+ validated by nir_validator:
+
+ 1) A given SSA def or register has a single bit size that is respected by
+ everything that reads from it or writes to it.
+
+ 2) The bit sizes of all unsized inputs/outputs on any given ALU
+ instruction must match. They need not match the sized inputs or
+ outputs but they must match each other.
+
+ In order to keep nir_algebraic relatively simple and easy-to-use,
+ nir_search supports a type of bit-size inference based on the two rules
+ above. This is similar to type inference in many common programming
+ languages. If, for instance, you are constructing an add operation and you
+ know the second source is 16-bit, then you know that the other source and
+ the destination must also be 16-bit. There are, however, cases where this
+ inference can be ambiguous or contradictory. Consider, for instance, the
+ following transformation:
+
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
+
+ This transformation can potentially cause a problem because usub_borrow is
+ well-defined for any bit-size of integer. However, b2i always generates a
+ 32-bit result so it could end up replacing a 64-bit expression with one
+ that takes two 64-bit values and produces a 32-bit value. As another
+ example, consider this expression:
+
+ (('bcsel', a, b, 0), ('iand', a, b))
+
+ In this case, in the search expression a must be 32-bit but b can
+ potentially have any bit size. If we had a 64-bit b value, we would end up
+ trying to and a 32-bit value with a 64-bit value which would be invalid
+
+ This class solves that problem by providing a validation layer that proves
+ that a given search-and-replace operation is 100% well-defined before we
+ generate any code. This ensures that bugs are caught at compile time
+ rather than at run time.
+
+ Each value maintains a "bit-size class", which is either an actual bit size
+ or an equivalence class with other values that must have the same bit size.
+ The validator works by combining bit-size classes with each other according
+ to the NIR rules outlined above, checking that there are no inconsistencies.
+ When doing this for the replacement expression, we make sure to never change
+ the equivalence class of any of the search values. We could make the example
+ transforms above work by doing some extra run-time checking of the search
+ expression, but we make the user specify those constraints themselves, to
+ avoid any surprises. Since the replacement bitsizes can only be connected to
+ the source bitsize via variables (variables must have the same bitsize in
+ the source and replacment expressions) or the roots of the expression (the
+ replacement expression must produce the same bit size as the search
+ expression), we prevent merging a variable with anything when processing the
+ replacement expression, or specializing the search bitsize
+ with anything. The former prevents
+
+ (('bcsel', a, b, 0), ('iand', a, b))
+
+ from being allowed, since we'd have to merge the bitsizes for a and b due to
+ the 'iand', while the latter prevents
+
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
+
+ from being allowed, since the search expression has the bit size of a and b,
+ which can't be specialized to 32 which is the bitsize of the replace
+ expression. It also prevents something like:
+
+ (('b2i', ('i2b', a)), ('ineq', a, 0))
+
+ since the bitsize of 'b2i', which can be anything, can't be specialized to
+ the bitsize of a.
+
+ After doing all this, we check that every subexpression of the replacement
+ was assigned a constant bitsize, the bitsize of a variable, or the bitsize
+ of the search expresssion, since those are the things that are known when
+ constructing the replacement expresssion. Finally, we record the bitsize
+ needed in nir_search_value so that we know what to do when building the
+ replacement expression.
+ """
+
+ def __init__(self, varset):
+ self._var_classes = [None] * len(varset.names)
+
+ def compare_bitsizes(self, a, b):
+ """Determines which bitsize class is a specialization of the other, or
+ whether neither is. When we merge two different bitsizes, the
+ less-specialized bitsize always points to the more-specialized one, so
+ that calling get_bit_size() always gets you the most specialized bitsize.
+ The specialization partial order is given by:
+ - Physical bitsizes are always the most specialized, and a different
+ bitsize can never specialize another.
+ - In the search expression, variables can always be specialized to each
+ other and to physical bitsizes. In the replace expression, we disallow
+ this to avoid adding extra constraints to the search expression that
+ the user didn't specify.
+ - Expressions and constants without a bitsize can always be specialized to
+ each other and variables, but not the other way around.
+
+ We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
+ and None if they are not comparable (neither a <= b nor b <= a).
+ """
+ if isinstance(a, int):
+ if isinstance(b, int):
+ return 0 if a == b else None
+ elif isinstance(b, Variable):
+ return -1 if self.is_search else None
+ else:
+ return -1
+ elif isinstance(a, Variable):
+ if isinstance(b, int):
+ return 1 if self.is_search else None
+ elif isinstance(b, Variable):
+ return 0 if self.is_search or a.index == b.index else None
+ else:
+ return -1
+ else:
+ if isinstance(b, int):
+ return 1
+ elif isinstance(b, Variable):
+ return 1
+ else:
+ return 0
+
+ def unify_bit_size(self, a, b, error_msg):
+ """Record that a must have the same bit-size as b. If both
+ have been assigned conflicting physical bit-sizes, call "error_msg" with
+ the bit-sizes of self and other to get a message and raise an error.
+ In the replace expression, disallow merging variables with other
+ variables and physical bit-sizes as well.
+ """
+ a_bit_size = a.get_bit_size()
+ b_bit_size = b if isinstance(b, int) else b.get_bit_size()
+
+ cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
+
+ assert cmp_result is not None, \
+ error_msg(a_bit_size, b_bit_size)
+
+ if cmp_result < 0:
+ b_bit_size.set_bit_size(a)
+ elif not isinstance(a_bit_size, int):
+ a_bit_size.set_bit_size(b)
+
+ def merge_variables(self, val):
+ """Perform the first part of type inference by merging all the different
+ uses of the same variable. We always do this as if we're in the search
+ expression, even if we're actually not, since otherwise we'd get errors
+ if the search expression specified some constraint but the replace
+ expression didn't, because we'd be merging a variable and a constant.
+ """
+ if isinstance(val, Variable):
+ if self._var_classes[val.index] is None:
+ self._var_classes[val.index] = val
+ else:
+ other = self._var_classes[val.index]
+ self.unify_bit_size(other, val,
+ lambda other_bit_size, bit_size:
+ 'Variable {} has conflicting bit size requirements: ' \
+ 'it must have bit size {} and {}'.format(
+ val.var_name, other_bit_size, bit_size))
+ elif isinstance(val, Expression):
+ for src in val.sources:
+ self.merge_variables(src)
+
+ def validate_value(self, val):
+ """Validate the an expression by performing classic Hindley-Milner
+ type inference on bitsizes. This will detect if there are any conflicting
+ requirements, and unify variables so that we know which variables must
+ have the same bitsize. If we're operating on the replace expression, we
+ will refuse to merge different variables together or merge a variable
+ with a constant, in order to prevent surprises due to rules unexpectedly
+ not matching at runtime.
+ """
+ if not isinstance(val, Expression):
+ return
+
+ # Generic conversion ops are special in that they have a single unsized
+ # source and an unsized destination and the two don't have to match.
+ # This means there's no validation or unioning to do here besides the
+ # len(val.sources) check.
+ if val.opcode in conv_opcode_types:
+ assert len(val.sources) == 1, \
+ "Expression {} has {} sources, expected 1".format(
+ val, len(val.sources))
+ self.validate_value(val.sources[0])
+ return
+
+ nir_op = opcodes[val.opcode]
+ assert len(val.sources) == nir_op.num_inputs, \
+ "Expression {} has {} sources, expected {}".format(
+ val, len(val.sources), nir_op.num_inputs)
+
+ for src in val.sources:
+ self.validate_value(src)
+
+ dst_type_bits = type_bits(nir_op.output_type)
+
+ # First, unify all the sources. That way, an error coming up because two
+ # sources have an incompatible bit-size won't produce an error message
+ # involving the destination.
+ first_unsized_src = None
+ for src_type, src in zip(nir_op.input_types, val.sources):
+ src_type_bits = type_bits(src_type)
+ if src_type_bits == 0:
+ if first_unsized_src is None:
+ first_unsized_src = src
+ continue
+
+ if self.is_search:
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Source {} of {} must have bit size {}, while source {} ' \
+ 'must have incompatible bit size {}'.format(
+ first_unsized_src, val, first_unsized_src_bit_size,
+ src, src_bit_size))
+ else:
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Sources {} (bit size of {}) and {} (bit size of {}) ' \
+ 'of {} may not have the same bit size when building the ' \
+ 'replacement expression.'.format(
+ first_unsized_src, first_unsized_src_bit_size, src,
+ src_bit_size, val))
+ else:
+ if self.is_search:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} must have {} bits, but as a source of nir_op_{} '\
+ 'it must have {} bits'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+ else:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} has the bit size of {}, but as a source of ' \
+ 'nir_op_{} it must have {} bits, which may not be the ' \
+ 'same'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+
+ if dst_type_bits == 0:
+ if first_unsized_src is not None:
+ if self.is_search:
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have the bit size of {}, while its source {} ' \
+ 'must have incompatible bit size {}'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
+ else:
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have {} bits, but its source {} ' \
+ '(bit size of {}) may not have that bit size ' \
+ 'when building the replacement.'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
+ else:
+ self.unify_bit_size(val, dst_type_bits,
+ lambda dst_bit_size, unused:
+ '{} must have {} bits, but as a destination of nir_op_{} ' \
+ 'it must have {} bits'.format(
+ val, dst_bit_size, nir_op.name, dst_type_bits))
+
+ def validate_replace(self, val, search):
+ bit_size = val.get_bit_size()
+ assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
+ bit_size == search.get_bit_size(), \
+ 'Ambiguous bit size for replacement value {}: ' \
+ 'it cannot be deduced from a variable, a fixed bit size ' \
+ 'somewhere, or the search expression.'.format(val)
+
+ if isinstance(val, Expression):
+ for src in val.sources:
+ self.validate_replace(src, search)
+
+ def validate(self, search, replace):
+ self.is_search = True
+ self.merge_variables(search)
+ self.merge_variables(replace)
+ self.validate_value(search)
+
+ self.is_search = False
+ self.validate_value(replace)
+
+ # Check that search is always more specialized than replace. Note that
+ # we're doing this in replace mode, disallowing merging variables.
+ search_bit_size = search.get_bit_size()
+ replace_bit_size = replace.get_bit_size()
+ cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
+
+ assert cmp_result is not None and cmp_result <= 0, \
+ 'The search expression bit size {} and replace expression ' \
+ 'bit size {} may not be the same'.format(
+ search_bit_size, replace_bit_size)
+
+ replace.set_bit_size(search)
+
+ self.validate_replace(replace, search)