op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
vl = op1.outputs[0]
op2 = fn.append_new_op(
- OpKind.SvLd, inputs=[arg, vl], immediates=[0], maxvl=MAXVL,
+ OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
name="ld")
a = op2.outputs[0]
- op3 = fn.append_new_op(
- OpKind.SvLI, inputs=[vl], immediates=[0], maxvl=MAXVL, name="li")
+ op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ maxvl=MAXVL, name="li")
b = op3.outputs[0]
op4 = fn.append_new_op(OpKind.SetCA, name="ca")
ca = op4.outputs[0]
op5 = fn.append_new_op(
- OpKind.SvAddE, inputs=[a, b, ca, vl], maxvl=MAXVL, name="add")
+ OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
s = op5.outputs[0]
- fn.append_new_op(
- OpKind.SvStd, inputs=[s, arg, vl], immediates=[0], maxvl=MAXVL,
- name="st")
+ fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+ immediates=[0], maxvl=MAXVL, name="st")
return fn, arg
def test_repr(self):
fn, _arg = self.make_add_fn()
self.assertEqual([repr(i) for i in fn.ops], [
"Op(kind=OpKind.FuncArgR3, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[], "
"outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
"outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
"Op(kind=OpKind.SvLd, "
- "inputs=[<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), "
"immediates=[0], "
"outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
"Op(kind=OpKind.SvLI, "
- "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
"immediates=[0], "
"outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
"Op(kind=OpKind.SetCA, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[], "
"outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
"Op(kind=OpKind.SvAddE, "
- "inputs=[<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
- "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<ld.outputs[0]: <I64*32>>, "
+ "<li.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), "
"immediates=[], "
"outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
"name='add')",
"Op(kind=OpKind.SvStd, "
- "inputs=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+ "input_vals=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
"<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
"immediates=[0], "
"outputs=(), name='st')",
])
fn.pre_ra_insert_copies()
self.assertEqual([repr(i) for i in fn.ops], [
"Op(kind=OpKind.FuncArgR3, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[], "
"outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
"Op(kind=OpKind.CopyFromReg, "
- "inputs=[<arg.outputs[0]: <I64>>], "
+ "input_vals=[<arg.outputs[0]: <I64>>], "
+ "input_uses=(<arg.out0.copy.input_uses[0]: <I64>>,), "
"immediates=[], "
- "outputs=(<2.outputs[0]: <I64>>,), name='2')",
+ "outputs=(<arg.out0.copy.outputs[0]: <I64>>,), "
+ "name='arg.out0.copy')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
"outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
"Op(kind=OpKind.CopyToReg, "
- "inputs=[<2.outputs[0]: <I64>>], "
+ "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+ "input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
"immediates=[], "
- "outputs=(<3.outputs[0]: <I64>>,), name='3')",
+ "outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
"Op(kind=OpKind.SvLd, "
- "inputs=[<3.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), "
"immediates=[0], "
"outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<4.outputs[0]: <VL_MAXVL>>,), name='4')",
+ "outputs=(<ld.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='ld.out0.setvl')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<ld.outputs[0]: <I64*32>>, <4.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<ld.outputs[0]: <I64*32>>, "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.out0.copy.input_uses[0]: <I64*32>>, "
+ "<ld.out0.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<5.outputs[0]: <I64*32>>,), name='5')",
+ "outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='ld.out0.copy')",
"Op(kind=OpKind.SvLI, "
- "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
"immediates=[0], "
"outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<6.outputs[0]: <VL_MAXVL>>,), name='6')",
+ "outputs=(<li.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='li.out0.setvl')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<li.outputs[0]: <I64*32>>, <6.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<li.outputs[0]: <I64*32>>, "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.out0.copy.input_uses[0]: <I64*32>>, "
+ "<li.out0.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<7.outputs[0]: <I64*32>>,), name='7')",
+ "outputs=(<li.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='li.out0.copy')",
"Op(kind=OpKind.SetCA, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[], "
"outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<8.outputs[0]: <VL_MAXVL>>,), name='8')",
+ "outputs=(<add.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp0.setvl')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<5.outputs[0]: <I64*32>>, <8.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<ld.out0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.inp0.copy.input_uses[0]: <I64*32>>, "
+ "<add.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<9.outputs[0]: <I64*32>>,), name='9')",
+ "outputs=(<add.inp0.copy.outputs[0]: <I64*32>>,), "
+ "name='add.inp0.copy')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<10.outputs[0]: <VL_MAXVL>>,), name='10')",
+ "outputs=(<add.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp1.setvl')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<7.outputs[0]: <I64*32>>, <10.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<li.out0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.inp1.copy.input_uses[0]: <I64*32>>, "
+ "<add.inp1.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<11.outputs[0]: <I64*32>>,), name='11')",
+ "outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
+ "name='add.inp1.copy')",
"Op(kind=OpKind.SvAddE, "
- "inputs=[<9.outputs[0]: <I64*32>>, <11.outputs[0]: <I64*32>>, "
- "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), "
"immediates=[], "
"outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
"name='add')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<12.outputs[0]: <VL_MAXVL>>,), name='12')",
+ "outputs=(<add.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.out0.setvl')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<add.outputs[0]: <I64*32>>, "
- "<12.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<add.outputs[0]: <I64*32>>, "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.out0.copy.input_uses[0]: <I64*32>>, "
+ "<add.out0.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<13.outputs[0]: <I64*32>>,), name='13')",
+ "outputs=(<add.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='add.out0.copy')",
"Op(kind=OpKind.SetVLI, "
- "inputs=[], "
+ "input_vals=[], "
+ "input_uses=(), "
"immediates=[32], "
- "outputs=(<14.outputs[0]: <VL_MAXVL>>,), name='14')",
+ "outputs=(<st.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='st.inp0.setvl')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<13.outputs[0]: <I64*32>>, <14.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<add.out0.copy.outputs[0]: <I64*32>>, "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.inp0.copy.input_uses[0]: <I64*32>>, "
+ "<st.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
"immediates=[], "
- "outputs=(<15.outputs[0]: <I64*32>>,), name='15')",
+ "outputs=(<st.inp0.copy.outputs[0]: <I64*32>>,), "
+ "name='st.inp0.copy')",
"Op(kind=OpKind.CopyToReg, "
- "inputs=[<2.outputs[0]: <I64>>], "
+ "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+ "input_uses=(<st.inp1.copy.input_uses[0]: <I64>>,), "
"immediates=[], "
- "outputs=(<16.outputs[0]: <I64>>,), name='16')",
+ "outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
+ "name='st.inp1.copy')",
"Op(kind=OpKind.SvStd, "
- "inputs=[<15.outputs[0]: <I64*32>>, <16.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
+ "<st.inp1.copy.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
"immediates=[0], "
"outputs=(), name='st')",
])
raise ValueError("can't add Op to wrong Fn")
self.ops.append(op)
- def append_new_op(self, kind, inputs=(), immediates=(), name="", maxvl=1):
+ def append_new_op(self, kind, input_vals=(), immediates=(), name="",
+ maxvl=1):
# type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
- inputs=inputs, immediates=immediates, name=name)
+ input_vals=input_vals, immediates=immediates, name=name)
self.append_op(retval)
return retval
copied_outputs = {} # type: dict[SSAVal, SSAVal]
self.ops.clear()
for op in orig_ops:
- for i in range(len(op.inputs)):
- inp = copied_outputs[op.inputs[i]]
+ for i in range(len(op.input_vals)):
+ inp = copied_outputs[op.input_vals[i]]
if inp.ty.base_ty is BaseTy.I64:
maxvl = inp.ty.reg_len
if inp.ty.reg_len != 1:
- setvl = self.append_new_op(OpKind.SetVLI,
- immediates=[maxvl])
+ setvl = self.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ name=f"{op.name}.inp{i}.setvl")
vl = setvl.outputs[0]
- mv = self.append_new_op(OpKind.VecCopyToReg,
- inputs=[inp, vl], maxvl=maxvl)
+ mv = self.append_new_op(
+ OpKind.VecCopyToReg, input_vals=[inp, vl],
+ maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
else:
- mv = self.append_new_op(OpKind.CopyToReg, inputs=[inp])
- op.inputs[i] = mv.outputs[0]
+ mv = self.append_new_op(
+ OpKind.CopyToReg, input_vals=[inp],
+ name=f"{op.name}.inp{i}.copy")
+ op.input_vals[i] = mv.outputs[0]
elif inp.ty.base_ty is BaseTy.CA \
or inp.ty.base_ty is BaseTy.VL_MAXVL:
# all copies would be no-ops, so we don't need to copy
- op.inputs[i] = inp
+ op.input_vals[i] = inp
else:
assert_never(inp.ty.base_ty)
self.ops.append(op)
- for out in op.outputs:
+ for i, out in enumerate(op.outputs):
if out.ty.base_ty is BaseTy.I64:
maxvl = out.ty.reg_len
if out.ty.reg_len != 1:
- setvl = self.append_new_op(OpKind.SetVLI,
- immediates=[maxvl])
+ setvl = self.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ name=f"{op.name}.out{i}.setvl")
vl = setvl.outputs[0]
- mv = self.append_new_op(OpKind.VecCopyFromReg,
- inputs=[out, vl], maxvl=maxvl)
+ mv = self.append_new_op(
+ OpKind.VecCopyFromReg, input_vals=[out, vl],
+ maxvl=maxvl, name=f"{op.name}.out{i}.copy")
else:
- mv = self.append_new_op(OpKind.CopyFromReg,
- inputs=[out])
+ mv = self.append_new_op(
+ OpKind.CopyFromReg, input_vals=[out],
+ name=f"{op.name}.out{i}.copy")
copied_outputs[out] = mv.outputs[0]
elif out.ty.base_ty is BaseTy.CA \
or out.ty.base_ty is BaseTy.VL_MAXVL:
self.fn = fn
retval = {} # type: dict[SSAVal, OSet[SSAUse]]
for op in fn.ops:
- for idx, inp in enumerate(op.inputs):
+ for idx, inp in enumerate(op.input_vals):
retval[inp].add(SSAUse(op, idx))
for out in op.outputs:
retval[out] = OSet()
def instantiate(self, maxvl):
# type: (int) -> Iterable[OperandDesc]
+ # assumes all spread operands have ty.reg_len = 1
rep_count = 1
if self.spread:
rep_count = maxvl
def ty(self):
""" Ty after any spread is applied """
if self.spread_index is not None:
+ # assumes all spread operands have ty.reg_len = 1
return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
return self.ty_before_spread
+ @property
+ def reg_offset_in_unspread(self):
+ # type: () -> int
+ """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+ e.g. if the unspread Loc containing self is:
+ `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+ and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+ then reg_offset_in_unspread == 2 == 10 - 8
+
+ always 0 when self isn't spread (`spread_index is None`), otherwise
+ `spread_index * ty.reg_len`.
+ """
+ if self.spread_index is None:
+ return 0
+ return self.spread_index * self.ty.reg_len
+
OD_BASE_SGPR = GenericOperandDesc(
ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
@staticmethod
def __svadde_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.inputs[0]]
- RB = state.ssa_vals[op.inputs[1]]
- carry, = state.ssa_vals[op.inputs[2]]
- VL, = state.ssa_vals[op.inputs[3]]
+ RA = state.ssa_vals[op.input_vals[0]]
+ RB = state.ssa_vals[op.input_vals[1]]
+ carry, = state.ssa_vals[op.input_vals[2]]
+ VL, = state.ssa_vals[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = RA[i] + RB[i] + carry
@staticmethod
def __svsubfe_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.inputs[0]]
- RB = state.ssa_vals[op.inputs[1]]
- carry, = state.ssa_vals[op.inputs[2]]
- VL, = state.ssa_vals[op.inputs[3]]
+ RA = state.ssa_vals[op.input_vals[0]]
+ RB = state.ssa_vals[op.input_vals[1]]
+ carry, = state.ssa_vals[op.input_vals[2]]
+ VL, = state.ssa_vals[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
@staticmethod
def __svmaddedu_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.inputs[0]]
- RB, = state.ssa_vals[op.inputs[1]]
- carry, = state.ssa_vals[op.inputs[2]]
- VL, = state.ssa_vals[op.inputs[3]]
+ RA = state.ssa_vals[op.input_vals[0]]
+ RB, = state.ssa_vals[op.input_vals[1]]
+ carry, = state.ssa_vals[op.input_vals[2]]
+ VL, = state.ssa_vals[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = RA[i] * RB + carry
@staticmethod
def __svli_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- VL, = state.ssa_vals[op.inputs[0]]
+ VL, = state.ssa_vals[op.input_vals[0]]
imm = op.immediates[0] & GPR_VALUE_MASK
state.ssa_vals[op.outputs[0]] = (imm,) * VL
SvLI = GenericOpProperties(
@staticmethod
def __veccopytoreg_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+ state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
VecCopyToReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[GenericOperandDesc(
@staticmethod
def __veccopyfromreg_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+ state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
VecCopyFromReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[OD_EXTRA3_VGPR, OD_VL],
@staticmethod
def __copytoreg_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+ state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
CopyToReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
@staticmethod
def __copyfromreg_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+ state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
CopyFromReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
def __concat_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = tuple(
- state.ssa_vals[i][0] for i in op.inputs[:-1])
+ state.ssa_vals[i][0] for i in op.input_vals[:-1])
Concat = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[GenericOperandDesc(
@staticmethod
def __spread_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- for idx, inp in enumerate(state.ssa_vals[op.inputs[0]]):
+ for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
state.ssa_vals[op.outputs[idx]] = inp,
Spread = GenericOpProperties(
demo_asm="sv.mv dest, src",
@staticmethod
def __svld_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RA, = state.ssa_vals[op.inputs[0]]
- VL, = state.ssa_vals[op.inputs[1]]
+ RA, = state.ssa_vals[op.input_vals[0]]
+ VL, = state.ssa_vals[op.input_vals[1]]
addr = RA + op.immediates[0]
RT = [] # type: list[int]
for i in range(VL):
@staticmethod
def __ld_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RA, = state.ssa_vals[op.inputs[0]]
+ RA, = state.ssa_vals[op.input_vals[0]]
addr = RA + op.immediates[0]
v = state.load(addr)
state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
@staticmethod
def __svstd_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RS = state.ssa_vals[op.inputs[0]]
- RA, = state.ssa_vals[op.inputs[1]]
- VL, = state.ssa_vals[op.inputs[2]]
+ RS = state.ssa_vals[op.input_vals[0]]
+ RA, = state.ssa_vals[op.input_vals[1]]
+ VL, = state.ssa_vals[op.input_vals[2]]
addr = RA + op.immediates[0]
for i in range(VL):
state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
@staticmethod
def __std_pre_ra_sim(op, state):
# type: (Op, PreRASimState) -> None
- RS, = state.ssa_vals[op.inputs[0]]
- RA, = state.ssa_vals[op.inputs[1]]
+ RS, = state.ssa_vals[op.input_vals[0]]
+ RA, = state.ssa_vals[op.input_vals[1]]
addr = RA + op.immediates[0]
state.store(addr, value=RS)
Std = GenericOpProperties(
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=ABCMeta):
- __slots__ = "op",
+ __slots__ = "op", "operand_idx"
- def __init__(self, op):
- # type: (Op) -> None
+ def __init__(self, op, operand_idx):
+ # type: (Op, int) -> None
self.op = op
+ if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
+ raise ValueError("invalid operand_idx")
+ self.operand_idx = operand_idx
@abstractmethod
def __repr__(self):
@property
@abstractmethod
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ ...
+
+ @property
def defining_descriptor(self):
# type: () -> OperandDesc
- ...
+ return self.descriptor_array[self.operand_idx]
@cached_property
def ty(self):
# type: () -> BaseTy
return self.ty_before_spread.base_ty
+ @property
+ def reg_offset_in_unspread(self):
+ # type: () -> int
+ """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+ e.g. if the unspread Loc containing self is:
+ `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+ and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+ then reg_offset_in_unspread == 2 == 10 - 8
+
+ this just forwards to `defining_descriptor.reg_offset_in_unspread`.
+ """
+ return self.defining_descriptor.reg_offset_in_unspread
+
+ @property
+ def unspread_start_idx(self):
+ # type: () -> int
+ """ `operand_idx` minus the defining descriptor's `spread_index`
+ (treated as 0 when `spread_index` is None), so it equals `operand_idx`
+ for operands that aren't spread.
+
+ NOTE(review): presumably this is the index of the first operand of the
+ spread that self belongs to -- confirm against `spread_index`'s
+ definition.
+ """
+ return self.operand_idx - (self.defining_descriptor.spread_index or 0)
+
+ @property
+ def unspread_start(self):
+ # type: () -> Self
+ """ a new instance of self's own class for the same `op`, with
+ `operand_idx` set to `unspread_start_idx`.
+ """
+ return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
+
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
- __slots__ = "output_idx",
-
- def __init__(self, op, output_idx):
- # type: (Op, int) -> None
- super().__init__(op)
- if output_idx < 0 or output_idx >= len(op.properties.outputs):
- raise ValueError("invalid output_idx")
- self.output_idx = output_idx
+ __slots__ = ()
def __repr__(self):
# type: () -> str
- return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>"
+ return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
@cached_property
def def_loc_set_before_spread(self):
return self.defining_descriptor.loc_set_before_spread
@cached_property
- def defining_descriptor(self):
- # type: () -> OperandDesc
- return self.op.properties.outputs[self.output_idx]
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.outputs
+
+ @cached_property
+ def tied_input(self):
+ # type: () -> None | SSAUse
+ """ the `SSAUse` for the input of `self.op` selected by the defining
+ descriptor's `tied_input_index`, or None when that index is None.
+ """
+ if self.defining_descriptor.tied_input_index is None:
+ return None
+ return SSAUse(op=self.op,
+ operand_idx=self.defining_descriptor.tied_input_index)
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
- __slots__ = "input_idx",
-
- def __init__(self, op, input_idx):
- # type: (Op, int) -> None
- super().__init__(op)
- self.input_idx = input_idx
- if input_idx < 0 or input_idx >= len(op.inputs):
- raise ValueError("input_idx out of range")
+ __slots__ = ()
@cached_property
def use_loc_set_before_spread(self):
return self.defining_descriptor.loc_set_before_spread
@cached_property
- def defining_descriptor(self):
- # type: () -> OperandDesc
- return self.op.properties.inputs[self.input_idx]
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.inputs
def __repr__(self):
# type: () -> str
- return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>"
+ return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
+
+ @property
+ def ssa_val(self):
+ # type: () -> SSAVal
+ """ the `SSAVal` stored in `self.op.input_vals` at `operand_idx`.
+ assigning to this property overwrites that `input_vals` entry.
+ """
+ return self.op.input_vals[self.operand_idx]
+
+ @ssa_val.setter
+ def ssa_val(self, ssa_val):
+ # type: (SSAVal) -> None
+ # writes through to the op, the SSAUse itself holds no SSAVal
+ self.op.input_vals[self.operand_idx] = ssa_val
_T = TypeVar("_T")
@final
-class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
+class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
def _get_descriptors(self):
# type: () -> tuple[OperandDesc, ...]
return self.op.properties.inputs
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
- __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"
+ __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
+ "outputs", "name")
- def __init__(self, fn, properties, inputs, immediates, name=""):
+ def __init__(self, fn, properties, input_vals, immediates, name=""):
# type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
self.fn = fn
self.properties = properties
- self.inputs = OpInputs(inputs, op=self)
+ self.input_vals = OpInputVals(input_vals, op=self)
+ inputs_len = len(self.properties.inputs)
+ self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
self.immediates = OpImmediates(immediates, op=self)
outputs_len = len(self.properties.outputs)
self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
def pre_ra_sim(self, state):
# type: (PreRASimState) -> None
- for inp in self.inputs:
+ for inp in self.input_vals:
if inp not in state.ssa_vals:
raise ValueError(f"SSAVal {inp} not yet assigned when "
f"running {self}")
"""
from itertools import combinations
-from functools import reduce
-from typing import Generic, Iterable, Mapping
-from cached_property import cached_property
-import operator
+from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet
+from cached_property import cached_property
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir2 import (
- Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses)
-from bigint_presentation_code.type_util import final, Self
-from bigint_presentation_code.util import OFSet, OSet, FMap
+from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc,
+ LocSet, Op, SSAVal, Ty)
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, OFSet, OSet
@plain_data(unsafe_hash=True, order=True, frozen=True)
pass
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True)
@final
class MergedSSAVal:
"""a set of `SSAVal`s along with their offsets, all register allocated as
* `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
* `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
"""
- __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"
+ __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set"
def __init__(self, fn_with_uses, ssa_val_offsets):
# type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
if isinstance(ssa_val_offsets, SSAVal):
ssa_val_offsets = {ssa_val_offsets: 0}
self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
- base_ty = None
- for ssa_val in self.ssa_val_offsets.keys():
- base_ty = ssa_val.base_ty
+ first_ssa_val = None
+ for ssa_val in self.ssa_vals:
+ first_ssa_val = ssa_val
break
- if base_ty is None:
+ if first_ssa_val is None:
raise BadMergedSSAVal("MergedSSAVal can't be empty")
- self.base_ty = base_ty # type: BaseTy
+ self.first_ssa_val = first_ssa_val # type: SSAVal
# self.ty checks for mismatched base_ty
reg_len = self.ty.reg_len
loc_set = None # type: None | LocSet
assert loc_set.ty == self.ty, "logic error somewhere"
self.loc_set = loc_set # type: LocSet
+ @cached_property
+ def __hash(self):
+ # type: () -> int
+ # computed once and cached; the hash covers only `fn_with_uses` and
+ # `ssa_val_offsets`
+ return hash((self.fn_with_uses, self.ssa_val_offsets))
+
+ def __hash__(self):
+ # type: () -> int
+ # explicit __hash__ that returns the cached value from `__hash`
+ return self.__hash
+
@cached_property
def offset(self):
# type: () -> int
return min(self.ssa_val_offsets_before_spread.values())
+ @property
+ def base_ty(self):
+ # type: () -> BaseTy
+ """ the `base_ty` of `first_ssa_val` """
+ return self.first_ssa_val.base_ty
+
+ @cached_property
+ def ssa_vals(self):
+ # type: () -> OFSet[SSAVal]
+ """ the keys of `ssa_val_offsets` as an `OFSet`, computed once """
+ return OFSet(self.ssa_val_offsets.keys())
+
@cached_property
def ty(self):
# type: () -> Ty
# type: () -> FMap[SSAVal, int]
retval = {} # type: dict[SSAVal, int]
for ssa_val, offset in self.ssa_val_offsets.items():
- offset_before_spread = offset
- spread_index = ssa_val.defining_descriptor.spread_index
- if spread_index is not None:
- assert ssa_val.ty.reg_len == 1, (
- "this function assumes spreading always converts a vector "
- "to a contiguous sequence of scalars, if that's changed "
- "in the future, then this function needs to be adjusted")
- offset_before_spread -= spread_index
- retval[ssa_val] = offset_before_spread
+ retval[ssa_val] = (
+ offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
return FMap(retval)
def offset_by(self, amount):
# type: () -> MergedSSAVal
return self.offset_by(-self.offset)
- def with_offset_to_match(self, target):
- # type: (MergedSSAVal) -> MergedSSAVal
+ def with_offset_to_match(self, target, additional_offset=0):
+ # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
+ if isinstance(target, MergedSSAVal):
+ ssa_val_offsets = target.ssa_val_offsets
+ else:
+ ssa_val_offsets = {target: 0}
for ssa_val, offset in self.ssa_val_offsets.items():
- if ssa_val in target.ssa_val_offsets:
- return self.offset_by(target.ssa_val_offsets[ssa_val] - offset)
+ if ssa_val in ssa_val_offsets:
+ return self.offset_by(
+ ssa_val_offsets[ssa_val] + additional_offset - offset)
raise ValueError("can't change offset to match unrelated MergedSSAVal")
+ def merged(self, *others):
+ # type: (*MergedSSAVal) -> MergedSSAVal
+ """ build a new `MergedSSAVal` holding the `SSAVal`s of `self` and of
+ every entry in `others`, each with the offset it already has. no
+ offsets are adjusted here.
+
+ raises `ValueError` if an entry of `others` has a different
+ `fn_with_uses` than `self`.
+ raises `BadMergedSSAVal` if an `SSAVal` shows up with two different
+ offsets.
+ """
+ retval = dict(self.ssa_val_offsets)
+ for other in others:
+ if other.fn_with_uses != self.fn_with_uses:
+ raise ValueError("fn_with_uses mismatch")
+ for ssa_val, offset in other.ssa_val_offsets.items():
+ # an SSAVal present on both sides must agree on its offset
+ if ssa_val in retval and retval[ssa_val] != offset:
+ raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
+ f"{retval[ssa_val]} != {offset}")
+ retval[ssa_val] = offset
+ return MergedSSAVal(fn_with_uses=self.fn_with_uses,
+ ssa_val_offsets=retval)
+
+
+@final
+class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
+ """ read-only `Mapping` from `SSAVal` to `MergedSSAVal`.
+
+ the backing dict is shared with the `MergedSSAValsSet` returned by
+ `values_set`; this class has no mutating methods of its own.
+ """
+
+ def __init__(self):
+ # type: (...) -> None
+ self.__merge_map = {} # type: dict[SSAVal, MergedSSAVal]
+ # the set gets the very same dict object, not a copy
+ self.__values_set = MergedSSAValsSet(
+ _private_merge_map=self.__merge_map,
+ _private_values_set=OSet())
+
+ def __getitem__(self, __key):
+ # type: (SSAVal) -> MergedSSAVal
+ return self.__merge_map[__key]
+
+ def __iter__(self):
+ # type: () -> Iterator[SSAVal]
+ return iter(self.__merge_map)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.__merge_map)
+
+ @property
+ def values_set(self):
+ # type: () -> MergedSSAValsSet
+ """ the `MergedSSAValsSet` created in `__init__` """
+ return self.__values_set
+
+ def __repr__(self):
+ # type: () -> str
+ # shows the `MergedSSAVal`s of the values set, not the individual keys
+ s = ",\n".join(repr(v) for v in self.__values_set)
+ return f"MergedSSAValsMap({{{s}}})"
+
@final
-class MergedSSAVals(OFSet[MergedSSAVal]):
- def __init__(self, merged_ssa_vals=()):
- # type: (Iterable[MergedSSAVal]) -> None
- super().__init__(merged_ssa_vals)
- merge_map = {} # type: dict[SSAVal, MergedSSAVal]
- for merged_ssa_val in self:
- for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
- if ssa_val in merge_map:
+class MergedSSAValsSet(MutableSet[MergedSSAVal]):
+ """ mutable set of `MergedSSAVal`s that also maintains a dict from each
+ contained `SSAVal` to the `MergedSSAVal` holding it.
+
+ both containers are passed in by the creator (see the `_private_*`
+ constructor arguments) and are updated by `add` and `discard`.
+ """
+
+ def __init__(self, *,
+ _private_merge_map, # type: dict[SSAVal, MergedSSAVal]
+ _private_values_set, # type: OSet[MergedSSAVal]
+ ):
+ # type: (...) -> None
+ self.__merge_map = _private_merge_map
+ self.__values_set = _private_values_set
+
+ @classmethod
+ def _from_iterable(cls, it):
+ # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
+ # returns a plain OSet rather than an instance of this class
+ return OSet(it)
+
+ def __contains__(self, value):
+ # type: (MergedSSAVal | Any) -> bool
+ return value in self.__values_set
+
+ def __iter__(self):
+ # type: () -> Iterator[MergedSSAVal]
+ return iter(self.__values_set)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.__values_set)
+
+ def add(self, value):
+ # type: (MergedSSAVal) -> None
+ """ add `value` and map every `SSAVal` in it to `value`.
+ does nothing if `value` is already in the set.
+
+ raises `ValueError` if one of `value`'s `SSAVal`s is already mapped;
+ the map entries written so far for `value` are removed again before
+ the exception propagates.
+ """
+ if value in self:
+ return
+ # number of map entries written so far, None once `value` is fully added
+ added = 0 # type: int | None
+ try:
+ for ssa_val in value.ssa_vals:
+ if ssa_val in self.__merge_map:
raise ValueError(
f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
- f"{merged_ssa_val} and {merge_map[ssa_val]}")
- merge_map[ssa_val] = merged_ssa_val
- self.__merge_map = FMap(merge_map)
+ f"{value} and {self.__merge_map[ssa_val]}")
+ self.__merge_map[ssa_val] = value
+ added += 1
+ self.__values_set.add(value)
+ added = None
+ finally:
+ if added is not None:
+ # remove partially added stuff
+ for idx, ssa_val in enumerate(value.ssa_vals):
+ if idx >= added:
+ break
+ del self.__merge_map[ssa_val]
+
+ def discard(self, value):
+ # type: (MergedSSAVal) -> None
+ """ remove `value` and the map entries of all its `SSAVal`s.
+ does nothing if `value` isn't in the set.
+ """
+ if value not in self:
+ return
+ self.__values_set.discard(value)
+ for ssa_val in value.ssa_val_offsets.keys():
+ del self.__merge_map[ssa_val]
- @cached_property
- def merge_map(self):
- # type: () -> FMap[SSAVal, MergedSSAVal]
- return self.__merge_map
+ def __repr__(self):
+ # type: () -> str
+ s = ",\n".join(repr(v) for v in self.__values_set)
+ return f"MergedSSAValsSet({{{s}}})"
-# FIXME: work on code from here
+
+@plain_data(frozen=True)
+@final
+class MergedSSAVals:
+ """ the `MergedSSAVal`s of one function.
+
+ `merge_map` maps each `SSAVal` to the `MergedSSAVal` containing it and
+ `merged_ssa_vals` is the matching set of `MergedSSAVal`s; both are views
+ of the same `MergedSSAValsMap`.
+ """
+ __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals"
+
+ def __init__(self, fn_with_uses, merged_ssa_vals):
+ # type: (FnWithUses, Iterable[MergedSSAVal]) -> None
+ self.fn_with_uses = fn_with_uses
+ self.merge_map = MergedSSAValsMap()
+ self.merged_ssa_vals = self.merge_map.values_set
+ for i in merged_ssa_vals:
+ self.merged_ssa_vals.add(i)
+
+ def merge(self, ssa_val1, ssa_val2, additional_offset=0):
+ # type: (SSAVal, SSAVal, int) -> MergedSSAVal
+ """ merge the `MergedSSAVal`s containing `ssa_val1` and `ssa_val2` so
+ that `ssa_val2` ends up `additional_offset` after `ssa_val1`.
+
+ both `SSAVal`s must already be in `merge_map` (`KeyError` otherwise).
+ the two old `MergedSSAVal`s are replaced by the returned one; merging
+ two values that already share a `MergedSSAVal` is allowed as long as
+ the offsets agree (`MergedSSAVal.merged` raises otherwise).
+ """
+ merged1 = self.merge_map[ssa_val1]
+ merged2 = self.merge_map[ssa_val2]
+ merged = merged1.with_offset_to_match(ssa_val1)
+ merged = merged.merged(merged2.with_offset_to_match(
+ ssa_val2, additional_offset=additional_offset))
+ self.merged_ssa_vals.remove(merged1)
+ # merged2 is the same object as merged1 when ssa_val1 and ssa_val2 are
+ # already merged together; a second `remove` would then raise KeyError
+ # after merged1 is gone, so use `discard`, which handles both cases.
+ self.merged_ssa_vals.discard(merged2)
+ self.merged_ssa_vals.add(merged)
+ return merged
@staticmethod
def minimally_merged(fn_with_uses):
# type: (FnWithUses) -> MergedSSAVals
+ """ build the `MergedSSAVals` for `fn_with_uses.fn` where every output
+ `SSAVal` is registered and only the merges visible here are done:
+ operands that are part of a spread are merged with the start of that
+ spread, and outputs are merged with their tied input.
+
+ assumes every input was produced by an earlier op in `fn.ops`, so it
+ is already registered when it is used -- TODO confirm.
+ """
- merge_map = {} # type: dict[SSAVal, MergedSSAVal]
+ retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=())
for op in fn_with_uses.fn.ops:
- for fn
- for val in (*op.inputs().values(), *op.outputs().values()):
- if val not in merged_sets:
- merged_sets[val] = MergedRegSet(val)
- for e in op.get_equality_constraints():
- lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
- rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
- items = [] # type: list[tuple[SSAVal, int]]
- for i in e.lhs:
- s = merged_sets[i].with_offset_to_match(lhs_set)
- items.extend(s.items())
- for i in e.rhs:
- s = merged_sets[i].with_offset_to_match(rhs_set)
- items.extend(s.items())
- full_set = MergedRegSet(items)
- for val in full_set.keys():
- merged_sets[val] = full_set
-
- self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+ for inp in op.input_uses:
+ if inp.unspread_start != inp:
+ retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
+ additional_offset=inp.reg_offset_in_unspread)
+ for out in op.outputs:
+ # register `out` as its own single-element MergedSSAVal first:
+ # `merge` looks both arguments up in `merge_map`, which starts out
+ # empty, so without this every `merge` call raises KeyError and
+ # values that are never merged would be missing from the result.
+ retval.merged_ssa_vals.add(MergedSSAVal(
+ fn_with_uses=fn_with_uses, ssa_val_offsets=out))
+ if out.unspread_start != out:
+ retval.merge(out.unspread_start, out,
+ additional_offset=out.reg_offset_in_unspread)
+ if out.tied_input is not None:
+ retval.merge(out.tied_input.ssa_val, out)
+ return retval
+
+
+# FIXME: work on code from here
@final
-class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
- def __init__(self, ops):
+class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]):
+ def __init__(self, merged_ssa_vals):
# type: (list[Op]) -> None
self.__merged_reg_sets = MergedRegSets(ops)
live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]