X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fbigint_presentation_code%2Fcompiler_ir2.py;h=b256a07b7a58868bdcce9d203cdf96389fcd6dc6;hb=e2ad5044997e39b98f990d7409b5dfccd22027bc;hp=9f911969c7f96235d472f7fb7b15cac83c9491a9;hpb=9ff4c2196738457c74e83201a7c087cf04098e5a;p=bigint-presentation-code.git diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 9f91196..b256a07 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -4,7 +4,7 @@ from enum import Enum, unique from functools import lru_cache, total_ordering from io import StringIO from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, - Mapping, Sequence, TypeVar, overload) + Mapping, Sequence, TypeVar, Union, overload) from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property @@ -59,10 +59,16 @@ class Fn: for op in self.ops: op.pre_ra_sim(state) + def gen_asm(self, state): + # type: (GenAsmState) -> None + for op in self.ops: + op.gen_asm(state) + def pre_ra_insert_copies(self): # type: () -> None orig_ops = list(self.ops) copied_outputs = {} # type: dict[SSAVal, SSAVal] + setvli_outputs = {} # type: dict[SSAVal, Op] self.ops.clear() for op in orig_ops: for i in range(len(op.input_vals)): @@ -84,12 +90,22 @@ class Fn: op.input_vals[i] = mv.outputs[0] elif inp.ty.base_ty is BaseTy.CA \ or inp.ty.base_ty is BaseTy.VL_MAXVL: - # all copies would be no-ops, so we don't need to copy + # all copies would be no-ops, so we don't need to copy, + # though we do need to rematerialize SetVLI ops right + # before the ops VL + if inp in setvli_outputs: + setvl = self.append_new_op( + OpKind.SetVLI, + immediates=setvli_outputs[inp].immediates, + name=f"{op.name}.inp{i}.setvl") + inp = setvl.outputs[0] op.input_vals[i] = inp else: assert_never(inp.ty.base_ty) self.ops.append(op) for i, out in enumerate(op.outputs): + if op.kind is OpKind.SetVLI: + setvli_outputs[out] = op if out.ty.base_ty is BaseTy.I64: maxvl = out.ty.reg_len if out.ty.reg_len != 1: @@ -263,7 +279,7 @@ class ProgramRange(Sequence[ProgramPoint]): return f"" -@plain_data(frozen=True, eq=False) +@plain_data(frozen=True, eq=False, repr=False) @final class FnAnalysis: __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at", @@ -335,6 +351,10 @@ class FnAnalysis: # type: () -> int return hash(self.fn) + def __repr__(self): + # type: () -> str + return "" + @unique @final @@ -425,7 +445,7 @@ class LocKind(Enum): def loc_count(self): # type: () -> int if self is LocKind.StackI64: - return 1024 + return 512 if self is LocKind.GPR or self is LocKind.CA \ or self is LocKind.VL_MAXVL: return self.base_ty.max_reg_len @@ -606,11 +626,11 @@ class Loc: # type: (Ty, int) -> Loc if subloc_ty.base_ty != self.kind.base_ty: raise ValueError("BaseTy mismatch") - start = self.start + offset - if offset < 0 or start + subloc_ty.reg_len > self.reg_len: + if offset < 0 or offset + subloc_ty.reg_len > self.reg_len: raise ValueError("invalid sub-Loc: offset and/or " "subloc_ty.reg_len out of range") - return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len) + return Loc(kind=self.kind, + start=self.start + offset, reg_len=subloc_ty.reg_len) SPECIAL_GPRS = ( @@ -726,13 +746,26 @@ class LocSet(AbstractSet[Loc]): def __len__(self): return self.__len + __HASHES = {} # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int] + @cached_property def __hash(self): - return super()._hash() + # cache hashes to avoid slow LocSet iteration + key = self.ty, self.starts + retval = self.__HASHES.get(key, None) + if retval is None: + self.__HASHES[key] = retval = super(LocSet, self)._hash() + return retval def __hash__(self): return self.__hash + def __eq__(self, __other): + # type: (LocSet | Any) -> bool + if isinstance(__other, LocSet): + return self.ty == __other.ty and self.starts == __other.starts + return super().__eq__(__other) + @lru_cache(maxsize=None, typed=True) def max_conflicts_with(self, other): # type: (LocSet | Loc) -> int @@ -861,7 +894,7 @@ class OperandDesc: raise ValueError("loc_set_before_spread must not be empty") self.loc_set_before_spread = loc_set_before_spread self.tied_input_index = tied_input_index - if self.tied_input_index is not None and self.spread_index is not None: + if self.tied_input_index is not None and spread_index is not None: raise ValueError("operand can't be both spread and tied") self.spread_index = spread_index self.write_stage = write_stage @@ -1121,7 +1154,7 @@ class OpKind(Enum): SvAddE = GenericOpProperties( demo_asm="sv.adde *RT, *RA, *RB", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], - outputs=[OD_EXTRA3_VGPR, OD_CA], + outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], ) _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm @@ -1151,7 +1184,7 @@ class OpKind(Enum): SvSubFE = GenericOpProperties( demo_asm="sv.subfe *RT, *RA, *RB", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], - outputs=[OD_EXTRA3_VGPR, OD_CA], + outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], ) _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm @@ -1258,26 +1291,50 @@ class OpKind(Enum): state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] @staticmethod - def __veccopytoreg_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64)) - dest_loc = state.loc(op.outputs[0], LocKind.GPR) - RT = state.vgpr(dest_loc) + def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state): + # type: (Loc, Loc, bool, GenAsmState) -> None + sv = "sv." if is_vec else "" + rev = "" + if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start: + rev = "/mrr" if src_loc == dest_loc: return # no-op - assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \ - "checked by loc()" + if src_loc.kind not in (LocKind.GPR, LocKind.StackI64): + raise ValueError(f"invalid src_loc.kind: {src_loc.kind}") + if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64): + raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}") if src_loc.kind is LocKind.StackI64: + if dest_loc.kind is LocKind.StackI64: + raise ValueError( + f"can't copy from stack to stack: {src_loc} {dest_loc}") + elif dest_loc.kind is not LocKind.GPR: + assert_never(dest_loc.kind) src = state.stack(src_loc) - state.writeln(f"sv.ld {RT}, {src}") - return - elif src_loc.kind is not LocKind.GPR: + dest = state.gpr(dest_loc, is_vec=is_vec) + state.writeln(f"{sv}ld {dest}, {src}") + elif dest_loc.kind is LocKind.StackI64: + if src_loc.kind is not LocKind.GPR: + assert_never(src_loc.kind) + src = state.gpr(src_loc, is_vec=is_vec) + dest = state.stack(dest_loc) + state.writeln(f"{sv}std {src}, {dest}") + elif src_loc.kind is LocKind.GPR: + if dest_loc.kind is not LocKind.GPR: + assert_never(dest_loc.kind) + src = state.gpr(src_loc, is_vec=is_vec) + dest = state.gpr(dest_loc, is_vec=is_vec) + state.writeln(f"{sv}or{rev} {dest}, {src}, {src}") + else: assert_never(src_loc.kind) - rev = "" - if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start: - rev = "/mrr" - src = state.vgpr(src_loc) - state.writeln(f"sv.or{rev} {RT}, {src}, {src}") + + @staticmethod + def __veccopytoreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc( + op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=True, state=state) VecCopyToReg = GenericOpProperties( demo_asm="sv.mv dest, src", @@ -1296,11 +1353,14 @@ class OpKind(Enum): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __veccopyfromreg_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc( + op.outputs[0], (LocKind.GPR, LocKind.StackI64)), + is_vec=True, state=state) VecCopyFromReg = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[OD_EXTRA3_VGPR, OD_VL], @@ -1319,11 +1379,14 @@ class OpKind(Enum): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __copytoreg_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc( + op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=False, state=state) CopyToReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -1346,11 +1409,14 @@ class OpKind(Enum): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __copyfromreg_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc( + op.outputs[0], (LocKind.GPR, LocKind.StackI64)), + is_vec=False, state=state) CopyFromReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -1374,11 +1440,13 @@ class OpKind(Enum): state.ssa_vals[op.outputs[0]] = tuple( state.ssa_vals[i][0] for i in op.input_vals[:-1]) - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __concat_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=True, state=state) Concat = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[GenericOperandDesc( @@ -1398,11 +1466,13 @@ class OpKind(Enum): for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]): state.ssa_vals[op.outputs[idx]] = inp, - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __spread_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc(op.outputs, LocKind.GPR), + is_vec=True, state=state) Spread = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[OD_EXTRA3_VGPR, OD_VL], @@ -1429,11 +1499,13 @@ class OpKind(Enum): RT.append(v & GPR_VALUE_MASK) state.ssa_vals[op.outputs[0]] = tuple(RT) - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __svld_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + RA = state.sgpr(op.input_vals[0]) + RT = state.vgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"sv.ld {RT}, {imm}({RA})") SvLd = GenericOpProperties( demo_asm="sv.ld *RT, imm(RA)", inputs=[OD_EXTRA3_SGPR, OD_VL], @@ -1451,11 +1523,13 @@ class OpKind(Enum): v = state.load(addr) state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK, - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __ld_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + RA = state.sgpr(op.input_vals[0]) + RT = state.sgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"ld {RT}, {imm}({RA})") Ld = GenericOpProperties( demo_asm="ld RT, imm(RA)", inputs=[OD_BASE_SGPR], @@ -1475,11 +1549,13 @@ class OpKind(Enum): for i in range(VL): state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __svstd_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + RS = state.vgpr(op.input_vals[0]) + RA = state.sgpr(op.input_vals[1]) + imm = op.immediates[0] + state.writeln(f"sv.std {RS}, {imm}({RA})") SvStd = GenericOpProperties( demo_asm="sv.std *RS, imm(RA)", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL], @@ -1498,13 +1574,15 @@ class OpKind(Enum): addr = RA + op.immediates[0] state.store(addr, value=RS) - # FIXME: change to correct __*_gen_asm function @staticmethod - def __clearca_gen_asm(op, state): + def __std_gen_asm(op, state): # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") + RS = state.sgpr(op.input_vals[0]) + RA = state.sgpr(op.input_vals[1]) + imm = op.immediates[0] + state.writeln(f"std {RS}, {imm}({RA})") Std = GenericOpProperties( - demo_asm="std RT, imm(RA)", + demo_asm="std RS, imm(RA)", inputs=[OD_BASE_SGPR, OD_BASE_SGPR], outputs=[], immediates=[IMM_S16], @@ -1877,6 +1955,15 @@ class Op: f"expected {out.ty.reg_len} found " f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}") + def gen_asm(self, state): + # type: (GenAsmState) -> None + all_loc_kinds = tuple(LocKind) + for inp in self.input_vals: + state.loc(inp, expected_kinds=all_loc_kinds) + for out in self.outputs: + state.loc(out, expected_kinds=all_loc_kinds) + self.kind.gen_asm(self, state) + GPR_SIZE_IN_BYTES = 8 BITS_IN_BYTE = 8 @@ -1994,7 +2081,7 @@ class PreRASimState: class GenAsmState: __slots__ = "allocated_locs", "output" - def __init__(self, allocated_locs, output): + def __init__(self, allocated_locs, output=None): # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None super().__init__() self.allocated_locs = FMap(allocated_locs) @@ -2006,38 +2093,49 @@ class GenAsmState: output = [] self.output = output - def loc(self, ssa_val_or_loc, expected_kinds): - # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc - if isinstance(ssa_val_or_loc, SSAVal): - retval = self.allocated_locs[ssa_val_or_loc] - else: - retval = ssa_val_or_loc + __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]] + + def loc(self, ssa_val_or_locs, expected_kinds): + # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc + if isinstance(ssa_val_or_locs, (SSAVal, Loc)): + ssa_val_or_locs = [ssa_val_or_locs] + locs = [] # type: list[Loc] + for i in ssa_val_or_locs: + if isinstance(i, SSAVal): + locs.append(self.allocated_locs[i]) + else: + locs.append(i) + if len(locs) == 0: + raise ValueError("invalid Loc sequence: must not be empty") + retval = locs[0].try_concat(*locs[1:]) + if retval is None: + raise ValueError("invalid Loc sequence: try_concat failed") if isinstance(expected_kinds, LocKind): expected_kinds = expected_kinds, if retval.kind not in expected_kinds: if len(expected_kinds) == 1: expected_kinds = expected_kinds[0] - raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found " + raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found " f"{retval.kind} expected {expected_kinds}") return retval - def gpr(self, ssa_val_or_loc, is_vec): - # type: (SSAVal | Loc, bool) -> str - loc = self.loc(ssa_val_or_loc, LocKind.GPR) + def gpr(self, ssa_val_or_locs, is_vec): + # type: (__SSA_VAL_OR_LOCS, bool) -> str + loc = self.loc(ssa_val_or_locs, LocKind.GPR) vec_str = "*" if is_vec else "" return vec_str + str(loc.start) - def sgpr(self, ssa_val_or_loc): - # type: (SSAVal | Loc) -> str - return self.gpr(ssa_val_or_loc, is_vec=False) + def sgpr(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + return self.gpr(ssa_val_or_locs, is_vec=False) - def vgpr(self, ssa_val_or_loc): - # type: (SSAVal | Loc) -> str - return self.gpr(ssa_val_or_loc, is_vec=True) + def vgpr(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + return self.gpr(ssa_val_or_locs, is_vec=True) - def stack(self, ssa_val_or_loc): - # type: (SSAVal | Loc) -> str - loc = self.loc(ssa_val_or_loc, LocKind.StackI64) + def stack(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + loc = self.loc(ssa_val_or_locs, LocKind.StackI64) return f"{loc.start}(1)" def writeln(self, *line_segments):