From c9f33dda41a61e0ee817e907ea5eb66c0fddcab1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Sun, 23 Oct 2022 00:25:43 -0700 Subject: [PATCH] 192x192->384-bit O(n^2) mul works in SSA form, reg-alloc gives incorrect results though --- src/bigint_presentation_code/compiler_ir.py | 278 ++++++++++++++++-- .../test_register_allocator.py | 69 +++-- .../test_toom_cook.py | 214 +++++++++++++- src/bigint_presentation_code/toom_cook.py | 35 ++- 4 files changed, 535 insertions(+), 61 deletions(-) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 517e542..17f5f3f 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -12,7 +12,7 @@ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.util import OFSet, OSet, final +from bigint_presentation_code.util import FMap, OFSet, OSet, final class ABCEnumMeta(EnumMeta, ABCMeta): @@ -41,7 +41,7 @@ class RegLoc(metaclass=ABCMeta): GPR_COUNT = 128 -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(frozen=True, unsafe_hash=True, repr=False) @final class GPRRange(RegLoc, Sequence["GPRRange"]): __slots__ = "start", "length" @@ -114,6 +114,11 @@ class GPRRange(RegLoc, Sequence["GPRRange"]): raise ValueError(f"sub-register offset is out of range: {offset}") return GPRRange(self.start + offset, subreg_type.length) + def __repr__(self): + if self.length == 1: + return f"" + return f"" + SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) @@ -191,7 +196,7 @@ _RegType = TypeVar("_RegType", bound=RegType) _RegLoc = TypeVar("_RegLoc", bound=RegLoc) -@plain_data(frozen=True, eq=False) +@plain_data(frozen=True, eq=False, repr=False) @final class GPRRangeType(RegType): __slots__ = "length", @@ -230,12 +235,15 @@ class GPRRangeType(RegType): def __hash__(self): return hash(self.length) + def __repr__(self): + return f"" + GPRType = GPRRangeType """a length=1 GPRRangeType""" -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(frozen=True, unsafe_hash=True, repr=False) @final class FixedGPRRangeType(RegType): __slots__ = "reg", @@ -254,6 +262,9 @@ class FixedGPRRangeType(RegType): # type: () -> int return self.reg.length + def __repr__(self): + return f"" + @plain_data(frozen=True, unsafe_hash=True) @final @@ -411,20 +422,8 @@ class SSAVal(Generic[_RegType]): def __hash__(self): return hash((id(self.op), self.arg_name)) - def __repr__(self, long=False): - if not long: - return f"<#{self.op.id}.{self.arg_name}>" - fields_list = [] - for name in fields(self): - v = getattr(self, name, None) - if v is not None: - if name == "op": - v = v.__repr__(just_id=True) - else: - v = repr(v) - fields_list.append(f"{name}={v}") - fields_str = ", ".join(fields_list) - return f"SSAVal({fields_str})" + def __repr__(self): + return f"<#{self.op.id}.{self.arg_name}: {self.ty}>" SSAGPRRange = SSAVal[GPRRangeType] @@ -459,6 +458,11 @@ class Fn: ops = ", ".join(op.__repr__(just_id=True) for op in self.ops) return f"" + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + for op in self.ops: + op.pre_ra_sim(state) + class _NotSet: """ helper for __repr__ for when fields aren't set """ @@ -641,6 +645,36 @@ class AsmContext: return False +GPR_SIZE_IN_BYTES = 8 +GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8 +GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 + + +@plain_data(frozen=True) +@final +class PreRASimState: + __slots__ = ("gprs", "VLs", "CAs", + "global_mems", "stack_slots", + "fixed_gprs") + + def __init__( + self, + gprs, # type: dict[SSAGPRRange, tuple[int, ...]] + VLs, # type: dict[SSAKnownVL, int] + CAs, # type: dict[SSAVal[CAType], bool] + global_mems, # type: dict[SSAVal[GlobalMemType], FMap[int, int]] + stack_slots, # type: dict[SSAVal[StackSlotType], tuple[int, ...]] + fixed_gprs, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]] + ): + # type: (...) -> None + self.gprs = gprs + self.VLs = VLs + self.CAs = CAs + self.global_mems = global_mems + self.stack_slots = stack_slots + self.fixed_gprs = fixed_gprs + + @plain_data(unsafe_hash=True, frozen=True, repr=False) class Op(metaclass=ABCMeta): __slots__ = "id", "fn" @@ -681,15 +715,14 @@ class Op(metaclass=ABCMeta): pass if not just_id: for name in fields(self): + if name in ("id", "fn"): + continue v = getattr(self, name, _NOT_SET) - if ((outputs is None or name in outputs) - and isinstance(v, SSAVal)): - v = v.__repr__(long=True) - elif isinstance(v, Fn): - v = v.__repr__(short=True) + if (outputs is not None and name in outputs + and outputs[name] is v): + fields_list.append(repr(v)) else: - v = repr(v) - fields_list.append(f"{name}={v}") + fields_list.append(f"{name}={v!r}") fields_str = ', '.join(fields_list) return f"{self.__class__.__name__}({fields_str})" @@ -699,6 +732,12 @@ class Op(metaclass=ABCMeta): """get the lines of assembly for this Op""" ... + @abstractmethod + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + """simulate op before register allocation""" + ... + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -732,6 +771,11 @@ class OpLoadFromStackSlot(Op): return [f"sv.ld {dest}, {src.start_byte}(1)"] return [f"ld {dest}, {src.start_byte}(1)"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + """simulate op before register allocation""" + state.gprs[self.dest] = state.stack_slots[self.src] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -765,6 +809,11 @@ class OpStoreToStackSlot(Op): return [f"sv.std {src}, {dest.start_byte}(1)"] return [f"std {src}, {dest.start_byte}(1)"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + """simulate op before register allocation""" + state.stack_slots[self.dest] = state.gprs[self.src] + _RegSrcType = TypeVar("_RegSrcType", bound=RegType) @@ -805,6 +854,8 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]): elif src.ty != dest_ty: raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") + elif isinstance(src.ty, StackSlotType): + raise ValueError("can't use OpCopy on stack slots") elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)): length = src.ty.length else: @@ -834,6 +885,37 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]): return [f"or {dest_s}, {src_s}, {src_s}"] raise NotImplementedError + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and + isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))): + if isinstance(self.src.ty, GPRRangeType): + v = state.gprs[self.src] # type: ignore + else: + v = state.fixed_gprs[self.src] # type: ignore + if isinstance(self.dest.ty, GPRRangeType): + state.gprs[self.dest] = v # type: ignore + else: + state.fixed_gprs[self.dest] = v # type: ignore + elif (isinstance(self.src.ty, FixedGPRRangeType) and + isinstance(self.dest.ty, GPRRangeType)): + state.gprs[self.dest] = state.fixed_gprs[self.src] # type: ignore + elif (isinstance(self.src.ty, GPRRangeType) and + isinstance(self.dest.ty, FixedGPRRangeType)): + state.fixed_gprs[self.dest] = state.gprs[self.src] # type: ignore + elif (isinstance(self.src.ty, CAType) and + self.src.ty == self.dest.ty): + state.CAs[self.dest] = state.CAs[self.src] # type: ignore + elif (isinstance(self.src.ty, KnownVLType) and + self.src.ty == self.dest.ty): + state.VLs[self.dest] = state.VLs[self.src] # type: ignore + elif (isinstance(self.src.ty, GlobalMemType) and + self.src.ty == self.dest.ty): + v = state.global_mems[self.src] # type: ignore + state.global_mems[self.dest] = v # type: ignore + else: + raise NotImplementedError + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -864,6 +946,13 @@ class OpConcat(Op): # type: (AsmContext) -> list[str] return [] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + v = [] + for src in self.sources: + v.extend(state.gprs[src]) + state.gprs[self.dest] = tuple(v) + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -892,7 +981,7 @@ class OpSplit(Op): ranges.append(GPRRangeType(src.ty.length - last)) self.src = src self.results = tuple( - SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges)) + SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges)) def get_equality_constraints(self): # type: () -> Iterable[EqualityConstraint] @@ -902,6 +991,13 @@ class OpSplit(Op): # type: (AsmContext) -> list[str] return [] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + rest = state.gprs[self.src] + for dest in reversed(self.results): + state.gprs[dest] = rest[-dest.ty.length:] + rest = rest[:-dest.ty.length] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -956,6 +1052,19 @@ class OpBigIntAddSub(Op): return [f"sv.{mnemonic} {out}, {RA}, {RB}"] return [f"{mnemonic} {out}, {RA}, {RB}"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + carry = state.CAs[self.CA_in] + out = [] # type: list[int] + for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]): + if self.is_sub: + r = r ^ GPR_VALUE_MASK + s = l + r + carry + carry = s != (s & GPR_VALUE_MASK) + out.append(s & GPR_VALUE_MASK) + state.CAs[self.CA_out] = carry + state.gprs[self.out] = tuple(out) + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1013,6 +1122,29 @@ class OpBigIntMulDiv(Op): mnemonic = "divmod2du/mrr" return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + carry = state.gprs[self.RC][0] + RA = state.gprs[self.RA] + RB = state.gprs[self.RB][0] + RT = [0] * self.RT.ty.length + if self.is_div: + for i in reversed(range(self.RT.ty.length)): + if carry < RB and RB != 0: + div, mod = divmod((carry << 64) | RA[i], RB) + RT[i] = div & GPR_VALUE_MASK + carry = mod & GPR_VALUE_MASK + else: + RT[i] = GPR_VALUE_MASK + carry = 0 + else: + for i in range(self.RT.ty.length): + v = RA[i] * RB + carry + carry = v >> 64 + RT[i] = v & GPR_VALUE_MASK + state.gprs[self.RS] = carry, + state.gprs[self.RT] = tuple(RT) + @final @unique @@ -1101,6 +1233,27 @@ class OpBigIntShift(Op): RB = ctx.sgpr(self.sh) return [f"sv.dsrd {RT}, {RA}, {RB}, 1"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + out = [0] * self.out.ty.length + carry = state.gprs[self.carry_in][0] + sh = state.gprs[self.sh][0] % 64 + if self.kind is ShiftKind.Sl: + inp = carry, *state.gprs[self.inp] + for i in reversed(range(self.out.ty.length)): + v = inp[i] | (inp[i + 1] << 64) + v <<= sh + out[i] = (v >> 64) & GPR_VALUE_MASK + else: + assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra + inp = *state.gprs[self.inp], carry + for i in range(self.out.ty.length): + v = inp[i] | (inp[i + 1] << 64) + v >>= sh + out[i] = v & GPR_VALUE_MASK + # state.gprs[self._out_padding] is intentionally not written + state.gprs[self.out] = tuple(out) + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1150,6 +1303,25 @@ class OpShiftImm(Op): return [f"sv.{mnemonic} {out}, {inp}, {args}"] return [f"{mnemonic} {out}, {inp}, {args}"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + inp = state.gprs[self.inp][0] + if self.kind is ShiftKind.Sl: + assert self.ca_out is None + out = inp << self.sh + elif self.kind is ShiftKind.Sr: + assert self.ca_out is None + out = inp >> self.sh + else: + assert self.kind is ShiftKind.Sra + assert self.ca_out is not None + if inp & (1 << 63): # sign extend + inp -= 1 << 64 + out = inp >> self.sh + ca = inp < 0 and (out << self.sh) != inp + state.CAs[self.ca_out] = ca + state.gprs[self.out] = out, + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1189,6 +1361,11 @@ class OpLI(Op): return [f"sv.addi {out}, 0, {self.value}"] return [f"addi {out}, 0, {self.value}"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + value = self.value & GPR_VALUE_MASK + state.gprs[self.out] = (value,) * self.out.ty.length + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1215,6 +1392,10 @@ class OpSetCA(Op): return ["subfic 0, 0, -1"] return ["addic 0, 0, 0"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + state.CAs[self.out] = self.value + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1265,6 +1446,22 @@ class OpLoad(Op): return [f"sv.ld {RT}, {self.offset}({RA})"] return [f"ld {RT}, {self.offset}({RA})"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + addr = state.gprs[self.RA][0] + addr += self.offset + RT = [0] * self.RT.ty.length + mem = state.global_mems[self.mem] + for i in range(self.RT.ty.length): + cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK + if cur_addr % GPR_SIZE_IN_BYTES != 0: + raise ValueError(f"can't load from unaligned address: " + f"{cur_addr:#x}") + for j in range(GPR_SIZE_IN_BYTES): + byte_val = mem.get(cur_addr + j, 0) & 0xFF + RT[i] |= byte_val << (j * 8) + state.gprs[self.RT] = tuple(RT) + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1308,6 +1505,21 @@ class OpStore(Op): return [f"sv.std {RS}, {self.offset}({RA})"] return [f"std {RS}, {self.offset}({RA})"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + mem = dict(state.global_mems[self.mem_in]) + addr = state.gprs[self.RA][0] + addr += self.offset + RS = state.gprs[self.RS] + for i in range(self.RS.ty.length): + cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK + if cur_addr % GPR_SIZE_IN_BYTES != 0: + raise ValueError(f"can't store to unaligned address: " + f"{cur_addr:#x}") + for j in range(GPR_SIZE_IN_BYTES): + mem[cur_addr + j] = (RS[i] >> (j * 8)) & 0xFF + state.global_mems[self.mem_out] = FMap(mem) + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1331,6 +1543,11 @@ class OpFuncArg(Op): # type: (AsmContext) -> list[str] return [] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + if self.out not in state.fixed_gprs: + state.fixed_gprs[self.out] = (0,) * self.out.ty.length + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1354,6 +1571,11 @@ class OpInputMem(Op): # type: (AsmContext) -> list[str] return [] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + if self.out not in state.global_mems: + state.global_mems[self.out] = FMap() + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -1377,6 +1599,10 @@ class OpSetVLImm(Op): # type: (AsmContext) -> list[str] return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"] + def pre_ra_sim(self, state): + # type: (PreRASimState) -> None + state.VLs[self.out] = self.out.ty.length + def op_set_to_list(ops): # type: (Iterable[Op]) -> list[Op] diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py index 6558264..1eff254 100644 --- a/src/bigint_presentation_code/test_register_allocator.py +++ b/src/bigint_presentation_code/test_register_allocator.py @@ -56,43 +56,62 @@ class TestRegisterAllocator(unittest.TestCase): repr(reg_assignments), "AllocationFailed(" "node=IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " "edges={}, reg=None), " - "live_intervals=LiveIntervals(" - "live_intervals={" - "MergedRegSet([(<#0.out>, 0)]): " + "live_intervals=LiveIntervals(live_intervals={" + "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): " "LiveInterval(first_write=0, last_use=1), " - "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): " + "MergedRegSet([(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]): " "LiveInterval(first_write=1, last_use=4), " - "MergedRegSet([(<#2.out>, 0)]): " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): " "LiveInterval(first_write=2, last_use=3)}, " "merged_reg_sets=MergedRegSets(data={" - "<#0.out>: MergedRegSet([(<#0.out>, 0)]), " - "<#1.out>: MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " - "<#2.out>: MergedRegSet([(<#2.out>, 0)]), " - "<#3.out>: MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " - "<#4.dest>: MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), " + "<#0.out: KnownVLType(length=52)>: " + "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), " + "<#1.out: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "<#2.out: KnownVLType(length=64)>: " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), " + "<#3.out: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "<#4.dest: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])}), " "reg_sets_live_after={" - "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), " + "0: OFSet([MergedRegSet([" + "(<#0.out: KnownVLType(length=52)>, 0)])]), " "1: OFSet([MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), " + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])]), " "2: OFSet([MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " - "MergedRegSet([(<#2.out>, 0)])]), " + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), " "3: OFSet([MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), " + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])]), " "4: OFSet()}), " "interference_graph=InterferenceGraph(nodes={" - "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), " - "edges={}, reg=None), " + "...: IGNode(#0, merged_reg_set=MergedRegSet([" + "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), " "...: IGNode(#1, merged_reg_set=MergedRegSet([" - "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " - "edges={}, reg=None), " - "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), " - "edges={}, reg=None)}))" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), edges={}, reg=None), " + "...: IGNode(#2, merged_reg_set=MergedRegSet([" + "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))" ) def test_try_alloc_bigint_inc(self): diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py index 656c8d7..6fff570 100644 --- a/src/bigint_presentation_code/test_toom_cook.py +++ b/src/bigint_presentation_code/test_toom_cook.py @@ -1,10 +1,37 @@ import unittest -from bigint_presentation_code.toom_cook import ToomCookInstance +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn, + GlobalMem, GPRRange, + GPRRangeType, OpCopy, + OpFuncArg, OpInputMem, + OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit, + generate_assembly) +from bigint_presentation_code.register_allocator import allocate_registers +from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul +from bigint_presentation_code.util import FMap + + +class SimpleMul192x192: + def __init__(self): + self.fn = fn = Fn() + self.mem_in = mem = OpInputMem(fn).out + self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out + self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out + self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out + dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest + vl = OpSetVLImm(fn, 3).out + lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest + rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest + retval = simple_mul(fn, lhs, rhs) + vl = OpSetVLImm(fn, 6).out + self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0, + mem_in=mem, vl=vl).mem_out class TestToomCook(unittest.TestCase): - def test_toom_2(self): + maxDiff = None + + def test_toom_2_repr(self): TOOM_2 = ToomCookInstance.make_toom_2() # print(repr(repr(TOOM_2))) self.assertEqual( @@ -42,7 +69,7 @@ class TestToomCook(unittest.TestCase): "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))" ) - def test_toom_2_5(self): + def test_toom_2_5_repr(self): TOOM_2_5 = ToomCookInstance.make_toom_2_5() # print(repr(repr(TOOM_2_5))) self.assertEqual( @@ -107,9 +134,9 @@ class TestToomCook(unittest.TestCase): "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" ) - def test_reversed_toom_2_5(self): + def test_reversed_toom_2_5_repr(self): TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed() - print(repr(repr(TOOM_2_5))) + # print(repr(repr(TOOM_2_5))) self.assertEqual( repr(TOOM_2_5), "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, " @@ -169,6 +196,183 @@ class TestToomCook(unittest.TestCase): "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" ) + def test_simple_mul_192x192_pre_ra_sim(self): + # test multiplying: + # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 + # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 + # == + # int("0x00074736574206e_6f69746163696c70" + # "_69746c756d207469_622d3438333e2d32" + # "_3931783239312079_7261727469627261", base=0) + # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", + # 'little') + code = SimpleMul192x192() + dest_ptr = 0x100 + state = PreRASimState( + gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()}, + stack_slots={}, fixed_gprs={ + code.dest_ptr_in: (dest_ptr,), + code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e, + 0x000191acb262e15b), + code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6, + 0x4a37c0567bcbab53) + }) + code.fn.pre_ra_sim(state) + expected_bytes = b"arbitrary 192x192->384-bit multiplication test" + OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES + expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') + mem_out = state.global_mems[code.mem_out] + out_bytes = bytes( + mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT)) + self.assertEqual(out_bytes, expected_bytes) + + def test_simple_mul_192x192_ops(self): + code = SimpleMul192x192() + fn = code.fn + self.assertEqual([repr(v) for v in fn.ops], [ + 'OpInputMem(#0, <#0.out: GlobalMemType()>)', + 'OpFuncArg(#1, <#1.out: )>>)', + 'OpFuncArg(#2, <#2.out: )>>)', + 'OpFuncArg(#3, <#3.out: )>>)', + 'OpCopy(#4, <#4.dest: >, src=<#1.out: )>>, ' + 'vl=None)', + 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)', + 'OpCopy(#6, <#6.dest: >, ' + 'src=<#2.out: )>>, ' + 'vl=<#5.out: KnownVLType(length=3)>)', + 'OpCopy(#7, <#7.dest: >, ' + 'src=<#3.out: )>>, ' + 'vl=<#5.out: KnownVLType(length=3)>)', + 'OpSplit(#8, results=(<#8.results[0]: >, ' + '<#8.results[1]: >, <#8.results[2]: >), ' + 'src=<#7.dest: >)', + 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)', + 'OpLI(#10, <#10.out: >, value=0, vl=None)', + 'OpBigIntMulDiv(#11, <#11.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[0]: >, ' + 'RC=<#10.out: >, <#11.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpConcat(#12, <#12.dest: >, sources=(' + '<#11.RT: >, <#11.RS: >))', + 'OpBigIntMulDiv(#13, <#13.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[1]: >, ' + 'RC=<#10.out: >, <#13.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpSplit(#14, results=(<#14.results[0]: >, ' + '<#14.results[1]: >), src=<#12.dest: >)', + 'OpSetCA(#15, <#15.out: CAType()>, value=False)', + 'OpBigIntAddSub(#16, <#16.out: >, ' + 'lhs=<#13.RT: >, rhs=<#14.results[1]: >, ' + 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpBigIntAddSub(#17, <#17.out: >, ' + 'lhs=<#13.RS: >, rhs=<#10.out: >, ' + 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, ' + 'is_sub=False, vl=None)', + 'OpConcat(#18, <#18.dest: >, sources=(' + '<#14.results[0]: >, <#16.out: >, ' + '<#17.out: >))', + 'OpBigIntMulDiv(#19, <#19.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[2]: >, ' + 'RC=<#10.out: >, <#19.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpSplit(#20, results=(<#20.results[0]: >, ' + '<#20.results[1]: >), src=<#18.dest: >)', + 'OpSetCA(#21, <#21.out: CAType()>, value=False)', + 'OpBigIntAddSub(#22, <#22.out: >, ' + 'lhs=<#19.RT: >, rhs=<#20.results[1]: >, ' + 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpBigIntAddSub(#23, <#23.out: >, ' + 'lhs=<#19.RS: >, rhs=<#10.out: >, ' + 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, ' + 'is_sub=False, vl=None)', + 'OpConcat(#24, <#24.dest: >, sources=(' + '<#20.results[0]: >, <#22.out: >, ' + '<#23.out: >))', + 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)', + 'OpStore(#26, RS=<#24.dest: >, ' + 'RA=<#4.dest: >, offset=0, ' + 'mem_in=<#0.out: GlobalMemType()>, ' + '<#26.mem_out: GlobalMemType()>, ' + 'vl=<#25.out: KnownVLType(length=6)>)' + ]) + + # FIXME: register allocator currently allocates wrong registers + @unittest.expectedFailure + def test_simple_mul_192x192_reg_alloc(self): + code = SimpleMul192x192() + fn = code.fn + assigned_registers = allocate_registers(fn.ops) + self.assertEqual(assigned_registers, { + fn.ops[13].RS: GPRRange(9), # type: ignore + fn.ops[14].results[0]: GPRRange(6), # type: ignore + fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore + fn.ops[15].out: XERBit.CA, # type: ignore + fn.ops[16].out: GPRRange(7, length=3), # type: ignore + fn.ops[16].CA_out: XERBit.CA, # type: ignore + fn.ops[17].out: GPRRange(10), # type: ignore + fn.ops[17].CA_out: XERBit.CA, # type: ignore + fn.ops[18].dest: GPRRange(6, length=5), # type: ignore + fn.ops[19].RT: GPRRange(3, length=3), # type: ignore + fn.ops[19].RS: GPRRange(9), # type: ignore + fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore + fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore + fn.ops[21].out: XERBit.CA, # type: ignore + fn.ops[22].out: GPRRange(8, length=3), # type: ignore + fn.ops[22].CA_out: XERBit.CA, # type: ignore + fn.ops[23].out: GPRRange(11), # type: ignore + fn.ops[23].CA_out: XERBit.CA, # type: ignore + fn.ops[24].dest: GPRRange(6, length=6), # type: ignore + fn.ops[25].out: VL.VL_MAXVL, # type: ignore + fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore + fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore + fn.ops[1].out: GPRRange(3), # type: ignore + fn.ops[2].out: GPRRange(4, length=3), # type: ignore + fn.ops[3].out: GPRRange(7, length=3), # type: ignore + fn.ops[4].dest: GPRRange(12), # type: ignore + fn.ops[5].out: VL.VL_MAXVL, # type: ignore + fn.ops[6].dest: GPRRange(17, length=3), # type: ignore + fn.ops[7].dest: GPRRange(14, length=3), # type: ignore + fn.ops[8].results[0]: GPRRange(14), # type: ignore + fn.ops[8].results[1]: GPRRange(15), # type: ignore + fn.ops[8].results[2]: GPRRange(16), # type: ignore + fn.ops[9].out: VL.VL_MAXVL, # type: ignore + fn.ops[10].out: GPRRange(9), # type: ignore + fn.ops[11].RT: GPRRange(6, length=3), # type: ignore + fn.ops[11].RS: GPRRange(9), # type: ignore + fn.ops[12].dest: GPRRange(6, length=4), # type: ignore + fn.ops[13].RT: GPRRange(3, length=3) # type: ignore + }) + self.fail("register allocator currently allocates wrong registers") + + # FIXME: register allocator currently allocates wrong registers + @unittest.expectedFailure + def test_simple_mul_192x192_asm(self): + code = SimpleMul192x192() + asm = generate_assembly(code.fn.ops) + self.assertEqual(asm, [ + 'or 12, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *17, *4, *4', + 'sv.or *14, *7, *7', + 'setvl 0, 0, 3, 0, 1, 1', + 'addi 9, 0, 0', + 'sv.maddedu *6, *17, 14, 9', + 'sv.maddedu *3, *17, 15, 9', + 'addic 0, 0, 0', + 'sv.adde *7, *3, *7', + 'adde 10, 9, 9', + 'sv.maddedu *3, *17, 16, 9', + 'addic 0, 0, 0', + 'sv.adde *8, *3, *8', + 'adde 11, 9, 9', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.std *6, 0(12)', + 'bclr 20, 0, 0' + ]) + self.fail("register allocator currently allocates wrong registers") + if __name__ == "__main__": unittest.main() diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 9786624..9e3ec74 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -8,7 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir import Fn, Op +from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange from bigint_presentation_code.matrix import Matrix from bigint_presentation_code.util import Literal, OSet, final @@ -438,8 +438,33 @@ class ToomCookInstance: # TODO: add make_toom_3 -def toom_cook_mul(fn, word_count, instances): - # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op] - retval = OSet() # type: OSet[Op] - raise NotImplementedError +def simple_mul(fn, lhs, rhs): + # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange + """ simple O(n^2) big-int unsigned multiply """ + if lhs.ty.length < rhs.ty.length: + lhs, rhs = rhs, lhs + # split rhs into elements + rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results + retval = None + vl = OpSetVLImm(fn, lhs.ty.length).out + zero = OpLI(fn, 0).out + for shift, rhs_word in enumerate(rhs_words): + mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero, + is_div=False, vl=vl) + if retval is None: + retval = OpConcat(fn, [mul.RT, mul.RS]).dest + else: + first_part, last_part = OpSplit(fn, retval, [shift]).results + add = OpBigIntAddSub( + fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out, + is_sub=False, vl=vl) + add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out, + is_sub=False) + retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest + assert retval is not None return retval + + +def toom_cook_mul(fn, lhs, rhs, instances): + # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange + raise NotImplementedError -- 2.30.2