self._class_relation = IntEquivalenceRelation()
def validate(self, search, replace):
- dst_class = self._propagate_bit_size_up(search)
- if dst_class == 0:
- dst_class = self._new_class()
- self._propagate_bit_class_down(search, dst_class)
+ search_dst_class = self._propagate_bit_size_up(search)
+ if search_dst_class == 0:
+ search_dst_class = self._new_class()
+ self._propagate_bit_class_down(search, search_dst_class)
- validate_dst_class = self._validate_bit_class_up(replace)
- assert validate_dst_class == 0 or validate_dst_class == dst_class
- self._validate_bit_class_down(replace, dst_class)
+ replace_dst_class = self._validate_bit_class_up(replace)
+ assert replace_dst_class == 0 or replace_dst_class == search_dst_class
+ self._validate_bit_class_down(replace, search_dst_class)
def _new_class(self):
self._num_classes += 1
return -self._num_classes
- def _set_var_bit_class(self, var_id, bit_class):
+ def _set_var_bit_class(self, var, bit_class):
assert bit_class != 0
- var_class = self._var_classes[var_id]
+ var_class = self._var_classes[var.index]
if var_class == 0:
- self._var_classes[var_id] = bit_class
+ self._var_classes[var.index] = bit_class
else:
canon_class = self._class_relation.get_canonical(var_class)
assert canon_class < 0 or canon_class == bit_class
var_class = self._class_relation.add_equiv(var_class, bit_class)
- self._var_classes[var_id] = var_class
+ self._var_classes[var.index] = var_class
- def _get_var_bit_class(self, var_id):
- return self._class_relation.get_canonical(self._var_classes[var_id])
+ def _get_var_bit_class(self, var):
+ return self._class_relation.get_canonical(self._var_classes[var.index])
def _propagate_bit_size_up(self, val):
if isinstance(val, (Constant, Variable)):
elif isinstance(val, Variable):
assert val.bit_size == 0 or val.bit_size == bit_class
- self._set_var_bit_class(val.index, bit_class)
+ self._set_var_bit_class(val, bit_class)
elif isinstance(val, Expression):
nir_op = opcodes[val.opcode]
return val.bit_size
elif isinstance(val, Variable):
- var_class = self._get_var_bit_class(val.index)
+ var_class = self._get_var_bit_class(val)
# By the time we get to validation, every variable should have a class
assert var_class != 0