self.assertEqual(
repr(fn_analysis.all_program_points),
"<range:ops[0]:Early..ops[7]:Early>")
+ self.assertEqual(repr(fn_analysis.copies), "FMap({})")
+ self.assertEqual(
+ repr(fn_analysis.const_ssa_vals),
+ "FMap({"
+ "<vl.outputs[0]: <VL_MAXVL>>: (32,), "
+ "<li.outputs[0]: <I64*32>>: ("
+ "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "
+ "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), "
+ "<ca.outputs[0]: <CA>>: (1,)})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.const_ssa_val_sub_regs),
+ "FMap({"
+ "<vl.outputs[0]: <VL_MAXVL>>[0]: 32, "
+ "<li.outputs[0]: <I64*32>>[0]: 0, "
+ "<li.outputs[0]: <I64*32>>[1]: 0, "
+ "<li.outputs[0]: <I64*32>>[2]: 0, "
+ "<li.outputs[0]: <I64*32>>[3]: 0, "
+ "<li.outputs[0]: <I64*32>>[4]: 0, "
+ "<li.outputs[0]: <I64*32>>[5]: 0, "
+ "<li.outputs[0]: <I64*32>>[6]: 0, "
+ "<li.outputs[0]: <I64*32>>[7]: 0, "
+ "<li.outputs[0]: <I64*32>>[8]: 0, "
+ "<li.outputs[0]: <I64*32>>[9]: 0, "
+ "<li.outputs[0]: <I64*32>>[10]: 0, "
+ "<li.outputs[0]: <I64*32>>[11]: 0, "
+ "<li.outputs[0]: <I64*32>>[12]: 0, "
+ "<li.outputs[0]: <I64*32>>[13]: 0, "
+ "<li.outputs[0]: <I64*32>>[14]: 0, "
+ "<li.outputs[0]: <I64*32>>[15]: 0, "
+ "<li.outputs[0]: <I64*32>>[16]: 0, "
+ "<li.outputs[0]: <I64*32>>[17]: 0, "
+ "<li.outputs[0]: <I64*32>>[18]: 0, "
+ "<li.outputs[0]: <I64*32>>[19]: 0, "
+ "<li.outputs[0]: <I64*32>>[20]: 0, "
+ "<li.outputs[0]: <I64*32>>[21]: 0, "
+ "<li.outputs[0]: <I64*32>>[22]: 0, "
+ "<li.outputs[0]: <I64*32>>[23]: 0, "
+ "<li.outputs[0]: <I64*32>>[24]: 0, "
+ "<li.outputs[0]: <I64*32>>[25]: 0, "
+ "<li.outputs[0]: <I64*32>>[26]: 0, "
+ "<li.outputs[0]: <I64*32>>[27]: 0, "
+ "<li.outputs[0]: <I64*32>>[28]: 0, "
+ "<li.outputs[0]: <I64*32>>[29]: 0, "
+ "<li.outputs[0]: <I64*32>>[30]: 0, "
+ "<li.outputs[0]: <I64*32>>[31]: 0, "
+ "<ca.outputs[0]: <CA>>[0]: 1})"
+ )
def test_repr(self):
fn, _arg = self.make_add_fn()
" <spread.outputs[1]: <I64>>, <spread.outputs[0]: <I64>>,\n"
" <vl.outputs[0]: <VL_MAXVL>>)"
)
+ fn_analysis = FnAnalysis(fn)
+ self.assertEqual(
+ repr(fn_analysis.copies),
+ "FMap({"
+ "<spread.outputs[0]: <I64>>[0]: <li.outputs[0]: <I64*4>>[0], "
+ "<spread.outputs[1]: <I64>>[0]: <li.outputs[0]: <I64*4>>[1], "
+ "<spread.outputs[2]: <I64>>[0]: <li.outputs[0]: <I64*4>>[2], "
+ "<spread.outputs[3]: <I64>>[0]: <li.outputs[0]: <I64*4>>[3], "
+ "<concat.outputs[0]: <I64*4>>[0]: <li.outputs[0]: <I64*4>>[3], "
+ "<concat.outputs[0]: <I64*4>>[1]: <li.outputs[0]: <I64*4>>[2], "
+ "<concat.outputs[0]: <I64*4>>[2]: <li.outputs[0]: <I64*4>>[1], "
+ "<concat.outputs[0]: <I64*4>>[3]: <li.outputs[0]: <I64*4>>[0]})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.const_ssa_val_sub_regs),
+ "FMap({"
+ "<vl.outputs[0]: <VL_MAXVL>>[0]: 4, "
+ "<li.outputs[0]: <I64*4>>[0]: 0, "
+ "<li.outputs[0]: <I64*4>>[1]: 0, "
+ "<li.outputs[0]: <I64*4>>[2]: 0, "
+ "<li.outputs[0]: <I64*4>>[3]: 0, "
+ "<spread.outputs[0]: <I64>>[0]: 0, "
+ "<spread.outputs[1]: <I64>>[0]: 0, "
+ "<spread.outputs[2]: <I64>>[0]: 0, "
+ "<spread.outputs[3]: <I64>>[0]: 0, "
+ "<concat.outputs[0]: <I64*4>>[0]: 0, "
+ "<concat.outputs[0]: <I64*4>>[1]: 0, "
+ "<concat.outputs[0]: <I64*4>>[2]: 0, "
+ "<concat.outputs[0]: <I64*4>>[3]: 0})"
+ )
if __name__ == "__main__":
return f"<range:{start}..{stop}>"
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAValSubReg(metaclass=InternedMeta):
__slots__ = "ssa_val", "reg_idx"
self.ssa_val = ssa_val
self.reg_idx = reg_idx
+ def __repr__(self):
+ # type: () -> str
+ return f"{self.ssa_val}[{self.reg_idx}]"
+
@plain_data(frozen=True, eq=False, repr=False)
@final
retval[SSAValSubReg(ssa_val, reg_idx)] = v
return FMap(retval)
- def are_always_equal(self, a, b):
+ def is_always_equal(self, a, b):
# type: (SSAValSubReg, SSAValSubReg) -> bool
"""check if a and b are known to be always equal to each other.
This means they can be allocated to the same location if other
@abstractmethod
def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
+ # type: (SSAVal, Iterable[int]) -> None
...
raise KeyError("SSAVal has no value set", ssa_val)
def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
+ # type: (SSAVal, Iterable[int]) -> None
+ value = tuple(map(int, value))
if len(value) != ssa_val.ty.reg_len:
raise ValueError("value has wrong len")
self.ssa_vals[ssa_val] = value
return tuple(retval)
def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
+ # type: (SSAVal, Iterable[int]) -> None
+ value = tuple(map(int, value))
if len(value) != ssa_val.ty.reg_len:
raise ValueError("value has wrong len")
loc = self.ssa_val_to_loc_map[ssa_val]
from nmutil.plain_data import plain_data
from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
- LocSet, ProgramRange, SSAVal,
- Ty)
+ LocSet, Op, ProgramRange,
+ SSAVal, SSAValSubReg, Ty)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
spread arguments are one of the things that can force two values to
illegally overlap.
"""
- # pick an arbitrary Loc, any Loc will do
- loc = self.first_loc
- ops = sorted(OSet(i.op for i in self.ssa_vals),
- key=self.fn_analysis.op_indexes.__getitem__)
- vals = {} # type: dict[Loc, tuple[SSAVal, int]]
+ ops = OSet() # type: Iterable[Op]
+ for ssa_val in self.ssa_vals:
+ ops.add(ssa_val.op)
+ for use in self.fn_analysis.uses[ssa_val]:
+ ops.add(use.op)
+ ops = sorted(ops, key=self.fn_analysis.op_indexes.__getitem__)
+ vals = {} # type: dict[int, SSAValSubReg]
for op in ops:
for inp in op.input_vals:
- pass
- # FIXME: finish checking using FnAnalysis.are_always_equal
- # also check that two different outputs of the same
- # instruction aren't merged
+ try:
+ ssa_val_offset = self.ssa_val_offsets[inp]
+ except KeyError:
+ continue
+ for orig_reg in inp.ssa_val_sub_regs:
+ reg_offset = ssa_val_offset + orig_reg.reg_idx
+ replaced_reg = vals[reg_offset]
+ if not self.fn_analysis.is_always_equal(
+ orig_reg, replaced_reg):
+ raise BadMergedSSAVal(
+ f"attempting to merge values that aren't known to "
+ f"be always equal: {orig_reg} != {replaced_reg}")
+ output_offsets = dict.fromkeys(range(
+ self.offset, self.offset + self.ty.reg_len))
+ for out in op.outputs:
+ try:
+ ssa_val_offset = self.ssa_val_offsets[out]
+ except KeyError:
+ continue
+ for reg in out.ssa_val_sub_regs:
+ reg_offset = ssa_val_offset + reg.reg_idx
+ try:
+ del output_offsets[reg_offset]
+ except KeyError:
+ raise BadMergedSSAVal("attempted to merge two outputs "
+ "of the same instruction")
+ vals[reg_offset] = reg
@cached_property
def __hash(self):