from functools import lru_cache, total_ordering
from io import StringIO
from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
- Mapping, Sequence, TypeVar, overload)
+ Mapping, Sequence, TypeVar, Union, overload)
from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
for op in self.ops:
op.pre_ra_sim(state)
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ # emit assembly for every op in this function, in program order,
+ # by delegating to each Op's own gen_asm
+ for op in self.ops:
+ op.gen_asm(state)
+
def pre_ra_insert_copies(self):
# type: () -> None
orig_ops = list(self.ops)
copied_outputs = {} # type: dict[SSAVal, SSAVal]
+ setvli_outputs = {} # type: dict[SSAVal, Op]
self.ops.clear()
for op in orig_ops:
for i in range(len(op.input_vals)):
op.input_vals[i] = mv.outputs[0]
elif inp.ty.base_ty is BaseTy.CA \
or inp.ty.base_ty is BaseTy.VL_MAXVL:
- # all copies would be no-ops, so we don't need to copy
+ # all copies would be no-ops, so we don't need to copy,
+ # though we do need to rematerialize SetVLI ops right
+ # before the op's VL
+ if inp in setvli_outputs:
+ setvl = self.append_new_op(
+ OpKind.SetVLI,
+ immediates=setvli_outputs[inp].immediates,
+ name=f"{op.name}.inp{i}.setvl")
+ inp = setvl.outputs[0]
op.input_vals[i] = inp
else:
assert_never(inp.ty.base_ty)
self.ops.append(op)
for i, out in enumerate(op.outputs):
+ if op.kind is OpKind.SetVLI:
+ setvli_outputs[out] = op
if out.ty.base_ty is BaseTy.I64:
maxvl = out.ty.reg_len
if out.ty.reg_len != 1:
return f"<range:{start}..{stop}>"
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
__slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
# type: () -> int
return hash(self.fn)
+ def __repr__(self):
+ # type: () -> str
+ # constant minimal repr -- the plain_data-generated repr is
+ # disabled (repr=False) for this class
+ return "<FnAnalysis>"
+
@unique
@final
def loc_count(self):
# type: () -> int
if self is LocKind.StackI64:
- return 1024
+ return 512
if self is LocKind.GPR or self is LocKind.CA \
or self is LocKind.VL_MAXVL:
return self.base_ty.max_reg_len
# type: (Ty, int) -> Loc
if subloc_ty.base_ty != self.kind.base_ty:
raise ValueError("BaseTy mismatch")
- start = self.start + offset
- if offset < 0 or start + subloc_ty.reg_len > self.reg_len:
+ if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
raise ValueError("invalid sub-Loc: offset and/or "
"subloc_ty.reg_len out of range")
- return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len)
+ return Loc(kind=self.kind,
+ start=self.start + offset, reg_len=subloc_ty.reg_len)
SPECIAL_GPRS = (
def __len__(self):
return self.__len
+ __HASHES = {} # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int]
+
@cached_property
def __hash(self):
- return super()._hash()
+ # cache hashes to avoid slow LocSet iteration
+ key = self.ty, self.starts
+ retval = self.__HASHES.get(key, None)
+ if retval is None:
+ self.__HASHES[key] = retval = super(LocSet, self)._hash()
+ return retval
def __hash__(self):
return self.__hash
+ def __eq__(self, __other):
+ # type: (LocSet | Any) -> bool
+ # LocSet-vs-LocSet equality compares only (ty, starts) -- the
+ # same key used by the cached __hash -- falling back to the
+ # superclass comparison for non-LocSet operands
+ if isinstance(__other, LocSet):
+ return self.ty == __other.ty and self.starts == __other.starts
+ return super().__eq__(__other)
+
@lru_cache(maxsize=None, typed=True)
def max_conflicts_with(self, other):
# type: (LocSet | Loc) -> int
raise ValueError("loc_set_before_spread must not be empty")
self.loc_set_before_spread = loc_set_before_spread
self.tied_input_index = tied_input_index
- if self.tied_input_index is not None and self.spread_index is not None:
+ if self.tied_input_index is not None and spread_index is not None:
raise ValueError("operand can't be both spread and tied")
self.spread_index = spread_index
self.write_stage = write_stage
SvAddE = GenericOpProperties(
demo_asm="sv.adde *RT, *RA, *RB",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
_GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
SvSubFE = GenericOpProperties(
demo_asm="sv.subfe *RT, *RA, *RB",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
@staticmethod
- def __veccopytoreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
- dest_loc = state.loc(op.outputs[0], LocKind.GPR)
- RT = state.vgpr(dest_loc)
+ def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
+ # type: (Loc, Loc, bool, GenAsmState) -> None
+ sv = "sv." if is_vec else ""
+ rev = ""
+ if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+ rev = "/mrr"
if src_loc == dest_loc:
return # no-op
- assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \
- "checked by loc()"
+ if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
+ if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
if src_loc.kind is LocKind.StackI64:
+ if dest_loc.kind is LocKind.StackI64:
+ raise ValueError(
+ f"can't copy from stack to stack: {src_loc} {dest_loc}")
+ elif dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
src = state.stack(src_loc)
- state.writeln(f"sv.ld {RT}, {src}")
- return
- elif src_loc.kind is not LocKind.GPR:
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}ld {dest}, {src}")
+ elif dest_loc.kind is LocKind.StackI64:
+ if src_loc.kind is not LocKind.GPR:
+ assert_never(src_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.stack(dest_loc)
+ state.writeln(f"{sv}std {src}, {dest}")
+ elif src_loc.kind is LocKind.GPR:
+ if dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
+ else:
assert_never(src_loc.kind)
- rev = ""
- if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
- rev = "/mrr"
- src = state.vgpr(src_loc)
- state.writeln(f"sv.or{rev} {RT}, {src}, {src}")
+
+ @staticmethod
+ def __veccopytoreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ # vector copy into a register: source may live in a GPR or a
+ # StackI64 slot, destination must be a GPR
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
VecCopyToReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __veccopyfromreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ # vector copy out of a register: source must be a GPR,
+ # destination may be a GPR or a StackI64 slot
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=True, state=state)
VecCopyFromReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[OD_EXTRA3_VGPR, OD_VL],
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __copytoreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=False, state=state)
CopyToReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __copyfromreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=False, state=state)
CopyFromReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
state.ssa_vals[op.outputs[0]] = tuple(
state.ssa_vals[i][0] for i in op.input_vals[:-1])
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __concat_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
Concat = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[GenericOperandDesc(
for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
state.ssa_vals[op.outputs[idx]] = inp,
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __spread_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(op.outputs, LocKind.GPR),
+ is_vec=True, state=state)
Spread = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[OD_EXTRA3_VGPR, OD_VL],
RT.append(v & GPR_VALUE_MASK)
state.ssa_vals[op.outputs[0]] = tuple(RT)
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __svld_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.vgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"sv.ld {RT}, {imm}({RA})")
SvLd = GenericOpProperties(
demo_asm="sv.ld *RT, imm(RA)",
inputs=[OD_EXTRA3_SGPR, OD_VL],
v = state.load(addr)
state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __ld_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.sgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"ld {RT}, {imm}({RA})")
Ld = GenericOpProperties(
demo_asm="ld RT, imm(RA)",
inputs=[OD_BASE_SGPR],
for i in range(VL):
state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __svstd_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RS = state.vgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"sv.std {RS}, {imm}({RA})")
SvStd = GenericOpProperties(
demo_asm="sv.std *RS, imm(RA)",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
addr = RA + op.immediates[0]
state.store(addr, value=RS)
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __std_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RS = state.sgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"std {RS}, {imm}({RA})")
Std = GenericOpProperties(
- demo_asm="std RT, imm(RA)",
+ demo_asm="std RS, imm(RA)",
inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
outputs=[],
immediates=[IMM_S16],
f"expected {out.ty.reg_len} found "
f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ # look up the Loc of every input and output first -- state.loc
+ # raises on a missing or mismatched allocation -- before
+ # delegating the actual text emission to the op kind
+ all_loc_kinds = tuple(LocKind)
+ for inp in self.input_vals:
+ state.loc(inp, expected_kinds=all_loc_kinds)
+ for out in self.outputs:
+ state.loc(out, expected_kinds=all_loc_kinds)
+ self.kind.gen_asm(self, state)
+
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
class GenAsmState:
__slots__ = "allocated_locs", "output"
- def __init__(self, allocated_locs, output):
+ def __init__(self, allocated_locs, output=None):
# type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
super().__init__()
self.allocated_locs = FMap(allocated_locs)
output = []
self.output = output
- def loc(self, ssa_val_or_loc, expected_kinds):
- # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc
- if isinstance(ssa_val_or_loc, SSAVal):
- retval = self.allocated_locs[ssa_val_or_loc]
- else:
- retval = ssa_val_or_loc
+ __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
+
+ def loc(self, ssa_val_or_locs, expected_kinds):
+ # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
+ if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
+ ssa_val_or_locs = [ssa_val_or_locs]
+ locs = [] # type: list[Loc]
+ for i in ssa_val_or_locs:
+ if isinstance(i, SSAVal):
+ locs.append(self.allocated_locs[i])
+ else:
+ locs.append(i)
+ if len(locs) == 0:
+ raise ValueError("invalid Loc sequence: must not be empty")
+ retval = locs[0].try_concat(*locs[1:])
+ if retval is None:
+ raise ValueError("invalid Loc sequence: try_concat failed")
if isinstance(expected_kinds, LocKind):
expected_kinds = expected_kinds,
if retval.kind not in expected_kinds:
if len(expected_kinds) == 1:
expected_kinds = expected_kinds[0]
- raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found "
+ raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
f"{retval.kind} expected {expected_kinds}")
return retval
- def gpr(self, ssa_val_or_loc, is_vec):
- # type: (SSAVal | Loc, bool) -> str
- loc = self.loc(ssa_val_or_loc, LocKind.GPR)
+ def gpr(self, ssa_val_or_locs, is_vec):
+ # type: (__SSA_VAL_OR_LOCS, bool) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.GPR)
vec_str = "*" if is_vec else ""
return vec_str + str(loc.start)
- def sgpr(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- return self.gpr(ssa_val_or_loc, is_vec=False)
+ def sgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=False)
- def vgpr(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- return self.gpr(ssa_val_or_loc, is_vec=True)
+ def vgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=True)
- def stack(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- loc = self.loc(ssa_val_or_loc, LocKind.StackI64)
+ def stack(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
return f"{loc.start}(1)"
def writeln(self, *line_segments):