From df779d330220e88b03491f32ce66d34392739b5b Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 7 Nov 2022 02:54:55 -0800 Subject: [PATCH] working on code --- .../_tests/test_toom_cook.py | 422 +++++++++++------- src/bigint_presentation_code/compiler_ir2.py | 62 ++- .../register_allocator2.py | 40 +- src/bigint_presentation_code/toom_cook.py | 79 +++- src/bigint_presentation_code/util.py | 47 +- 5 files changed, 399 insertions(+), 251 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 6fff570..994c951 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,31 +1,38 @@ import unittest -from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, - GPRRangeType, OpCopy, - OpFuncArg, OpInputMem, - OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit, - generate_assembly) -from bigint_presentation_code.register_allocator import allocate_registers +from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn, + GenAsmState, OpKind, + PreRASimState) +from bigint_presentation_code.register_allocator2 import allocate_registers from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul -from bigint_presentation_code.util import FMap class SimpleMul192x192: def __init__(self): + super().__init__() self.fn = fn = Fn() - self.mem_in = mem = OpInputMem(fn).out - self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out - self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out - self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out - dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest - vl = OpSetVLImm(fn, 3).out - lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest - rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest - retval = simple_mul(fn, lhs, rhs) - vl = OpSetVLImm(fn, 6).out - self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0, - mem_in=mem, vl=vl).mem_out + self.dest_offset = 0 + self.lhs_offset = 48 + self.dest_offset + self.rhs_offset = 24 + self.lhs_offset + self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3, + name="ptr_in").outputs[0] + setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3], + maxvl=3, name="setvl3") + load_lhs = fn.append_new_op( + kind=OpKind.SvLd, immediates=[self.lhs_offset], + input_vals=[self.ptr_in, setvl3.outputs[0]], + name="load_lhs", maxvl=3) + load_rhs = fn.append_new_op( + kind=OpKind.SvLd, immediates=[self.rhs_offset], + input_vals=[self.ptr_in, setvl3.outputs[0]], + name="load_rhs", maxvl=3) + retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0]) + setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6], + maxvl=6, name="setvl6") + fn.append_new_op( + kind=OpKind.SvStd, + input_vals=[retval, self.ptr_in, setvl6.outputs[0]], + immediates=[self.dest_offset], maxvl=6, name="store_dest") class TestToomCook(unittest.TestCase): @@ -207,95 +214,248 @@ class TestToomCook(unittest.TestCase): # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", # 'little') code = SimpleMul192x192() - dest_ptr = 0x100 - state = PreRASimState( - gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()}, - stack_slots={}, fixed_gprs={ - code.dest_ptr_in: (dest_ptr,), - code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e, - 0x000191acb262e15b), - code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6, - 0x4a37c0567bcbab53) - }) + ptr_in = 0x100 + dest_ptr = ptr_in + code.dest_offset + lhs_ptr = ptr_in + code.lhs_offset + rhs_ptr = ptr_in + code.rhs_offset + state = PreRASimState(ssa_vals={code.ptr_in: (ptr_in,)}, memory={}) + state.store(lhs_ptr, 0x821a2342132c5b57) + state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e) + state.store(lhs_ptr + 16, 0x000191acb262e15b) + state.store(rhs_ptr, 0x208a49071aeec507) + state.store(rhs_ptr + 8, 0xcf1f597598194ae6) + state.store(rhs_ptr + 16, 0x4a37c0567bcbab53) code.fn.pre_ra_sim(state) expected_bytes = b"arbitrary 192x192->384-bit multiplication test" OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') - mem_out = state.global_mems[code.mem_out] out_bytes = bytes( - mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT)) + state.load_byte(dest_ptr + i) for i in range(OUT_BYTE_COUNT)) self.assertEqual(out_bytes, expected_bytes) def test_simple_mul_192x192_ops(self): code = SimpleMul192x192() fn = code.fn self.assertEqual([repr(v) for v in fn.ops], [ - 'OpInputMem(#0, <#0.out: GlobalMemType()>)', - 'OpFuncArg(#1, <#1.out: )>>)', - 'OpFuncArg(#2, <#2.out: )>>)', - 'OpFuncArg(#3, <#3.out: )>>)', - 'OpCopy(#4, <#4.dest: >, src=<#1.out: )>>, ' - 'vl=None)', - 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)', - 'OpCopy(#6, <#6.dest: >, ' - 'src=<#2.out: )>>, ' - 'vl=<#5.out: KnownVLType(length=3)>)', - 'OpCopy(#7, <#7.dest: >, ' - 'src=<#3.out: )>>, ' - 'vl=<#5.out: KnownVLType(length=3)>)', - 'OpSplit(#8, results=(<#8.results[0]: >, ' - '<#8.results[1]: >, <#8.results[2]: >), ' - 'src=<#7.dest: >)', - 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)', - 'OpLI(#10, <#10.out: >, value=0, vl=None)', - 'OpBigIntMulDiv(#11, <#11.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[0]: >, ' - 'RC=<#10.out: >, <#11.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpConcat(#12, <#12.dest: >, sources=(' - '<#11.RT: >, <#11.RS: >))', - 'OpBigIntMulDiv(#13, <#13.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[1]: >, ' - 'RC=<#10.out: >, <#13.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpSplit(#14, results=(<#14.results[0]: >, ' - '<#14.results[1]: >), src=<#12.dest: >)', - 'OpSetCA(#15, <#15.out: CAType()>, value=False)', - 'OpBigIntAddSub(#16, <#16.out: >, ' - 'lhs=<#13.RT: >, rhs=<#14.results[1]: >, ' - 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpBigIntAddSub(#17, <#17.out: >, ' - 'lhs=<#13.RS: >, rhs=<#10.out: >, ' - 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, ' - 'is_sub=False, vl=None)', - 'OpConcat(#18, <#18.dest: >, sources=(' - '<#14.results[0]: >, <#16.out: >, ' - '<#17.out: >))', - 'OpBigIntMulDiv(#19, <#19.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[2]: >, ' - 'RC=<#10.out: >, <#19.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpSplit(#20, results=(<#20.results[0]: >, ' - '<#20.results[1]: >), src=<#18.dest: >)', - 'OpSetCA(#21, <#21.out: CAType()>, value=False)', - 'OpBigIntAddSub(#22, <#22.out: >, ' - 'lhs=<#19.RT: >, rhs=<#20.results[1]: >, ' - 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpBigIntAddSub(#23, <#23.out: >, ' - 'lhs=<#19.RS: >, rhs=<#10.out: >, ' - 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, ' - 'is_sub=False, vl=None)', - 'OpConcat(#24, <#24.dest: >, sources=(' - '<#20.results[0]: >, <#22.out: >, ' - '<#23.out: >))', - 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)', - 'OpStore(#26, RS=<#24.dest: >, ' - 'RA=<#4.dest: >, offset=0, ' - 'mem_in=<#0.out: GlobalMemType()>, ' - '<#26.mem_out: GlobalMemType()>, ' - 'vl=<#25.out: KnownVLType(length=6)>)' + "Op(kind=OpKind.FuncArgR3, " + "input_vals=[], " + "input_uses=(), immediates=[], " + "outputs=(>,), " + "name='ptr_in')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[3], " + "outputs=(>,), " + "name='setvl3')", + "Op(kind=OpKind.SvLd, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[48], " + "outputs=(>,), " + "name='load_lhs')", + "Op(kind=OpKind.SvLd, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[72], " + "outputs=(>,), " + "name='load_rhs')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[3], " + "outputs=(>,), " + "name='rhs_setvl')", + "Op(kind=OpKind.Spread, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">, " + ">), " + "name='rhs_spread')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[3], " + "outputs=(>,), " + "name='lhs_setvl')", + "Op(kind=OpKind.LI, " + "input_vals=[], " + "input_uses=(), immediates=[0], " + "outputs=(>,), " + "name='zero')", + "Op(kind=OpKind.SvMAddEDU, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='mul0')", + "Op(kind=OpKind.Spread, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">, " + ">), " + "name='mul0_rt_spread')", + "Op(kind=OpKind.SvMAddEDU, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='mul1')", + "Op(kind=OpKind.Concat, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>,), " + "name='add1_rb_concat')", + "Op(kind=OpKind.ClearCA, " + "input_vals=[], " + "input_uses=(), immediates=[], " + "outputs=(>,), " + "name='clear_ca1')", + "Op(kind=OpKind.SvAddE, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='add1')", + "Op(kind=OpKind.Spread, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">, " + ">), " + "name='add1_rt_spread')", + "Op(kind=OpKind.AddZE, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='add_hi1')", + "Op(kind=OpKind.SvMAddEDU, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='mul2')", + "Op(kind=OpKind.Concat, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>,), " + "name='add2_rb_concat')", + "Op(kind=OpKind.ClearCA, " + "input_vals=[], " + "input_uses=(), immediates=[], " + "outputs=(>,), " + "name='clear_ca2')", + "Op(kind=OpKind.SvAddE, " + "input_vals=[>, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='add2')", + "Op(kind=OpKind.Spread, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">, " + ">), " + "name='add2_rt_spread')", + "Op(kind=OpKind.AddZE, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), immediates=[], " + "outputs=(>, " + ">), " + "name='add_hi2')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[6], " + "outputs=(>,), " + "name='retval_setvl')", + "Op(kind=OpKind.Concat, " + "input_vals=[>, " + ">, " + ">, " + ">, " + ">, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">, " + ">, " + ">, " + ">, " + ">), immediates=[], " + "outputs=(>,), " + "name='concat_retval')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[6], " + "outputs=(>,), " + "name='setvl6')", + "Op(kind=OpKind.SvStd, " + "input_vals=[>, " + ">, " + ">], " + "input_uses=(>, " + ">, " + ">), immediates=[0], " + "outputs=(), " + "name='store_dest')", ]) # FIXME: register allocator currently allocates wrong registers @@ -303,73 +463,21 @@ class TestToomCook(unittest.TestCase): def test_simple_mul_192x192_reg_alloc(self): code = SimpleMul192x192() fn = code.fn - assigned_registers = allocate_registers(fn.ops) + assigned_registers = allocate_registers(fn) self.assertEqual(assigned_registers, { - fn.ops[13].RS: GPRRange(9), # type: ignore - fn.ops[14].results[0]: GPRRange(6), # type: ignore - fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore - fn.ops[15].out: XERBit.CA, # type: ignore - fn.ops[16].out: GPRRange(7, length=3), # type: ignore - fn.ops[16].CA_out: XERBit.CA, # type: ignore - fn.ops[17].out: GPRRange(10), # type: ignore - fn.ops[17].CA_out: XERBit.CA, # type: ignore - fn.ops[18].dest: GPRRange(6, length=5), # type: ignore - fn.ops[19].RT: GPRRange(3, length=3), # type: ignore - fn.ops[19].RS: GPRRange(9), # type: ignore - fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore - fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore - fn.ops[21].out: XERBit.CA, # type: ignore - fn.ops[22].out: GPRRange(8, length=3), # type: ignore - fn.ops[22].CA_out: XERBit.CA, # type: ignore - fn.ops[23].out: GPRRange(11), # type: ignore - fn.ops[23].CA_out: XERBit.CA, # type: ignore - fn.ops[24].dest: GPRRange(6, length=6), # type: ignore - fn.ops[25].out: VL.VL_MAXVL, # type: ignore - fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore - fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore - fn.ops[1].out: GPRRange(3), # type: ignore - fn.ops[2].out: GPRRange(4, length=3), # type: ignore - fn.ops[3].out: GPRRange(7, length=3), # type: ignore - fn.ops[4].dest: GPRRange(12), # type: ignore - fn.ops[5].out: VL.VL_MAXVL, # type: ignore - fn.ops[6].dest: GPRRange(17, length=3), # type: ignore - fn.ops[7].dest: GPRRange(14, length=3), # type: ignore - fn.ops[8].results[0]: GPRRange(14), # type: ignore - fn.ops[8].results[1]: GPRRange(15), # type: ignore - fn.ops[8].results[2]: GPRRange(16), # type: ignore - fn.ops[9].out: VL.VL_MAXVL, # type: ignore - fn.ops[10].out: GPRRange(9), # type: ignore - fn.ops[11].RT: GPRRange(6, length=3), # type: ignore - fn.ops[11].RS: GPRRange(9), # type: ignore - fn.ops[12].dest: GPRRange(6, length=4), # type: ignore - fn.ops[13].RT: GPRRange(3, length=3) # type: ignore }) self.fail("register allocator currently allocates wrong registers") # FIXME: register allocator currently allocates wrong registers @unittest.expectedFailure def test_simple_mul_192x192_asm(self): + self.skipTest("WIP") code = SimpleMul192x192() - asm = generate_assembly(code.fn.ops) - self.assertEqual(asm, [ - 'or 12, 3, 3', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *17, *4, *4', - 'sv.or *14, *7, *7', - 'setvl 0, 0, 3, 0, 1, 1', - 'addi 9, 0, 0', - 'sv.maddedu *6, *17, 14, 9', - 'sv.maddedu *3, *17, 15, 9', - 'addic 0, 0, 0', - 'sv.adde *7, *3, *7', - 'adde 10, 9, 9', - 'sv.maddedu *3, *17, 16, 9', - 'addic 0, 0, 0', - 'sv.adde *8, *3, *8', - 'adde 11, 9, 9', - 'setvl 0, 0, 6, 0, 1, 1', - 'sv.std *6, 0(12)', - 'bclr 20, 0, 0' + fn = code.fn + assigned_registers = allocate_registers(fn) + gen_asm_state = GenAsmState(assigned_registers) + fn.gen_asm(gen_asm_state) + self.assertEqual(gen_asm_state.output, [ ]) self.fail("register allocator currently allocates wrong registers") diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index b256a07..bd3d38c 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -12,7 +12,7 @@ from nmutil.plain_data import fields, plain_data from bigint_presentation_code.type_util import (Literal, Self, assert_never, final) -from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet +from bigint_presentation_code.util import BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet @final @@ -180,7 +180,7 @@ assert OpStage.Early < OpStage.Late, "early must be less than late" @plain_data(frozen=True, unsafe_hash=True, repr=False) @final @total_ordering -class ProgramPoint: +class ProgramPoint(metaclass=InternedMeta): __slots__ = "op_index", "stage" def __init__(self, op_index, stage): @@ -225,7 +225,7 @@ class ProgramPoint: @plain_data(frozen=True, unsafe_hash=True, repr=False) @final -class ProgramRange(Sequence[ProgramPoint]): +class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta): __slots__ = "start", "stop" def __init__(self, start, stop): @@ -389,7 +389,7 @@ class BaseTy(Enum): @plain_data(frozen=True, unsafe_hash=True, repr=False) @final -class Ty: +class Ty(metaclass=InternedMeta): __slots__ = "base_ty", "reg_len" @staticmethod @@ -529,7 +529,7 @@ class LocSubKind(Enum): @plain_data(frozen=True, unsafe_hash=True) @final -class GenericTy: +class GenericTy(metaclass=InternedMeta): __slots__ = "base_ty", "is_vec" def __init__(self, base_ty, is_vec): @@ -557,7 +557,7 @@ class GenericTy: @plain_data(frozen=True, unsafe_hash=True) @final -class Loc: +class Loc(metaclass=InternedMeta): __slots__ = "kind", "start", "reg_len" @staticmethod @@ -643,7 +643,7 @@ SPECIAL_GPRS = ( @plain_data(frozen=True, eq=False) @final -class LocSet(AbstractSet[Loc]): +class LocSet(AbstractSet[Loc], metaclass=InternedMeta): __slots__ = "starts", "ty" def __init__(self, __locs=()): @@ -746,19 +746,8 @@ class LocSet(AbstractSet[Loc]): def __len__(self): return self.__len - __HASHES = {} # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int] - - @cached_property - def __hash(self): - # cache hashes to avoid slow LocSet iteration - key = self.ty, self.starts - retval = self.__HASHES.get(key, None) - if retval is None: - self.__HASHES[key] = retval = super(LocSet, self)._hash() - return retval - def __hash__(self): - return self.__hash + return super()._hash() def __eq__(self, __other): # type: (LocSet | Any) -> bool @@ -780,7 +769,7 @@ class LocSet(AbstractSet[Loc]): @plain_data(frozen=True, unsafe_hash=True) @final -class GenericOperandDesc: +class GenericOperandDesc(metaclass=InternedMeta): """generic Op operand descriptor""" __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread", "write_stage") @@ -882,7 +871,7 @@ class GenericOperandDesc: @plain_data(frozen=True, unsafe_hash=True) @final -class OperandDesc: +class OperandDesc(metaclass=InternedMeta): """Op operand descriptor""" __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index", "write_stage") @@ -955,7 +944,7 @@ OD_VL = GenericOperandDesc( @plain_data(frozen=True, unsafe_hash=True) @final -class GenericOpProperties: +class GenericOpProperties(metaclass=InternedMeta): __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "is_copy", "is_load_immediate", "has_side_effects") @@ -1007,7 +996,7 @@ class GenericOpProperties: @plain_data(frozen=True, unsafe_hash=True) @final -class OpProperties: +class OpProperties(metaclass=InternedMeta): __slots__ = "kind", "inputs", "outputs", "maxvl" def __init__(self, kind, maxvl): @@ -1159,6 +1148,31 @@ class OpKind(Enum): _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm + @staticmethod + def __addze_pre_ra_sim(op, state): + # type: (Op, PreRASimState) -> None + RA, = state.ssa_vals[op.input_vals[0]] + carry, = state.ssa_vals[op.input_vals[1]] + v = RA + carry + RT = v & GPR_VALUE_MASK + carry = (v >> GPR_SIZE_IN_BITS) != 0 + state.ssa_vals[op.outputs[0]] = RT, + state.ssa_vals[op.outputs[1]] = carry, + + @staticmethod + def __addze_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + state.writeln(f"addze {RT}, {RA}") + AddZE = GenericOpProperties( + demo_asm="addze RT, RA", + inputs=[OD_BASE_SGPR, OD_CA], + outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)], + ) + _PRE_RA_SIMS[AddZE] = lambda: OpKind.__addze_pre_ra_sim + _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm + @staticmethod def __svsubfe_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None @@ -1611,7 +1625,7 @@ class OpKind(Enum): @plain_data(frozen=True, unsafe_hash=True, repr=False) -class SSAValOrUse(metaclass=ABCMeta): +class SSAValOrUse(metaclass=InternedMeta): __slots__ = "op", "operand_idx" def __init__(self, op, operand_idx): diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py index aefa291..198eebc 100644 --- a/src/bigint_presentation_code/register_allocator2.py +++ b/src/bigint_presentation_code/register_allocator2.py @@ -15,43 +15,7 @@ from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc, LocSet, ProgramRange, SSAVal, Ty) from bigint_presentation_code.type_util import final -from bigint_presentation_code.util import FMap, OFSet, OSet - - -@plain_data(unsafe_hash=True, order=True, frozen=True) -class LiveInterval: - __slots__ = "first_write", "last_use" - - def __init__(self, first_write, last_use=None): - # type: (int, int | None) -> None - super().__init__() - if last_use is None: - last_use = first_write - if last_use < first_write: - raise ValueError("uses must be after first_write") - if first_write < 0 or last_use < 0: - raise ValueError("indexes must be nonnegative") - self.first_write = first_write - self.last_use = last_use - - def overlaps(self, other): - # type: (LiveInterval) -> bool - if self.first_write == other.first_write: - return True - return self.last_use > other.first_write \ - and other.last_use > self.first_write - - def __add__(self, use): - # type: (int) -> LiveInterval - last_use = max(self.last_use, use) - return LiveInterval(first_write=self.first_write, last_use=last_use) - - @property - def live_after_op_range(self): - """the range of op indexes where self is live immediately after the - Op at each index - """ - return range(self.first_write, self.last_use) +from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet class BadMergedSSAVal(ValueError): @@ -60,7 +24,7 @@ class BadMergedSSAVal(ValueError): @plain_data(frozen=True, repr=False) @final -class MergedSSAVal: +class MergedSSAVal(metaclass=InternedMeta): """a set of `SSAVal`s along with their offsets, all register allocated as a single unit. diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 246f654..aa18967 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -8,10 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, TypeVar, Union from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub, - OpBigIntMulDiv, OpConcat, - OpLI, OpSetCA, OpSetVLImm, - OpSplit, SSAGPRRange) +from bigint_presentation_code.compiler_ir2 import (Fn, OpKind, SSAVal) from bigint_presentation_code.matrix import Matrix from bigint_presentation_code.type_util import Literal, final @@ -190,6 +187,7 @@ class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]): def __init__(self, lhs, rhs): # type: (_EvalOpLHS, _EvalOpRHS) -> None + super().__init__() self.lhs = lhs self.rhs = rhs self.poly = self._make_poly() @@ -442,32 +440,65 @@ class ToomCookInstance: def simple_mul(fn, lhs, rhs): - # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange + # type: (Fn, SSAVal, SSAVal) -> SSAVal """ simple O(n^2) big-int unsigned multiply """ - if lhs.ty.length < rhs.ty.length: + if lhs.ty.reg_len < rhs.ty.reg_len: lhs, rhs = rhs, lhs # split rhs into elements - rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results - retval = None - vl = OpSetVLImm(fn, lhs.ty.length).out - zero = OpLI(fn, 0).out + rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI, + immediates=[rhs.ty.reg_len], name="rhs_setvl") + rhs_spread = fn.append_new_op( + kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]], + maxvl=rhs.ty.reg_len, name="rhs_spread") + rhs_words = rhs_spread.outputs + spread_retval = None # type: tuple[SSAVal, ...] | None + maxvl = lhs.ty.reg_len + lhs_setvl = fn.append_new_op(kind=OpKind.SetVLI, + immediates=[lhs.ty.reg_len], name="lhs_setvl") + vl = lhs_setvl.outputs[0] + zero_op = fn.append_new_op(kind=OpKind.LI, immediates=[0], name="zero") + zero = zero_op.outputs[0] for shift, rhs_word in enumerate(rhs_words): - mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero, - is_div=False, vl=vl) - if retval is None: - retval = OpConcat(fn, [mul.RT, mul.RS]).dest + mul = fn.append_new_op(kind=OpKind.SvMAddEDU, + input_vals=[lhs, rhs_word, zero, vl], + maxvl=maxvl, name=f"mul{shift}") + if spread_retval is None: + mul_rt_spread = fn.append_new_op( + kind=OpKind.Spread, input_vals=[mul.outputs[0], vl], + name=f"mul{shift}_rt_spread", maxvl=maxvl) + spread_retval = (*mul_rt_spread.outputs, mul.outputs[1]) else: - first_part, last_part = OpSplit(fn, retval, [shift]).results - add = OpBigIntAddSub( - fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out, - is_sub=False, vl=vl) - add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out, - is_sub=False) - retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest - assert retval is not None - return retval + first_part = spread_retval[:shift] # type: tuple[SSAVal, ...] + last_part = spread_retval[shift:] + + add_rb_concat = fn.append_new_op( + kind=OpKind.Concat, input_vals=[*last_part, vl], + name=f"add{shift}_rb_concat", maxvl=maxvl) + clear_ca = fn.append_new_op(kind=OpKind.ClearCA, + name=f"clear_ca{shift}") + add = fn.append_new_op( + kind=OpKind.SvAddE, input_vals=[ + mul.outputs[0], add_rb_concat.outputs[0], + clear_ca.outputs[0], vl], + maxvl=maxvl, name=f"add{shift}") + add_rt_spread = fn.append_new_op( + kind=OpKind.Spread, input_vals=[add.outputs[0], vl], + name=f"add{shift}_rt_spread", maxvl=maxvl) + add_hi = fn.append_new_op( + kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]], + name=f"add_hi{shift}") + spread_retval = ( + *first_part, *add_rt_spread.outputs, add_hi.outputs[0]) + assert spread_retval is not None + lhs_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[len(spread_retval)], + name="retval_setvl") + concat_retval = fn.append_new_op( + kind=OpKind.Concat, input_vals=[*spread_retval, lhs_setvl.outputs[0]], + name="concat_retval", maxvl=len(spread_retval)) + return concat_retval.outputs[0] def toom_cook_mul(fn, lhs, rhs, instances): - # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange + # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal raise NotImplementedError diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index 757f267..4b7399f 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -1,4 +1,5 @@ -from abc import abstractmethod +from abc import ABCMeta, abstractmethod +from functools import lru_cache from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, TypeVar, overload) @@ -17,12 +18,42 @@ __all__ = [ "OSet", "top_set_bit_index", "trailing_zero_count", + "InternedMeta", ] -class OFSet(AbstractSet[_T_co]): +class InternedMeta(ABCMeta): + def __init__(self, *args, **kwargs): + # type: (*Any, **Any) -> None + super().__init__(*args, **kwargs) + self.__INTERN_TABLE = {} # type: dict[Any, Any] + + def __intern(self, value): + # type: (_T) -> _T + value = self.__INTERN_TABLE.setdefault(value, value) + if value.__dict__.get("_InternedMeta__interned", False): + return value + value.__dict__["_InternedMeta__interned"] = True + hash_v = hash(value) + value.__dict__["__hash__"] = lambda: hash_v + old_eq = value.__eq__ + + def __eq__(__o): + # type: (_T) -> bool + if value.__class__ is __o.__class__: + return value is __o + return old_eq(__o) + value.__dict__["__eq__"] = __eq__ + return value + + def __call__(self, *args, **kwargs): + # type: (*Any, **Any) -> Any + return self.__intern(super().__call__(*args, **kwargs)) + + +class OFSet(AbstractSet[_T_co], metaclass=InternedMeta): """ ordered frozen set """ - __slots__ = "__items", + __slots__ = "__items", "__dict__", "__weakref__" def __init__(self, items=()): # type: (Iterable[_T_co]) -> None @@ -54,7 +85,7 @@ class OFSet(AbstractSet[_T_co]): class OSet(MutableSet[_T]): """ ordered mutable set """ - __slots__ = "__items", + __slots__ = "__items", "__dict__" def __init__(self, items=()): # type: (Iterable[_T]) -> None @@ -88,9 +119,9 @@ class OSet(MutableSet[_T]): return f"OSet({list(self)})" -class FMap(Mapping[_T, _T_co]): +class FMap(Mapping[_T, _T_co], metaclass=InternedMeta): """ordered frozen hashable mapping""" - __slots__ = "__items", "__hash" + __slots__ = "__items", "__hash", "__dict__", "__weakref__" @overload def __init__(self, items): @@ -167,7 +198,7 @@ except AttributeError: class BaseBitSet(AbstractSet[int]): - __slots__ = "__bits", + __slots__ = "__bits", "__dict__", "__weakref__" @classmethod @abstractmethod @@ -402,7 +433,7 @@ class BitSet(BaseBitSet, MutableSet[int]): return super().__isub__(it) -class FBitSet(BaseBitSet): +class FBitSet(BaseBitSet, metaclass=InternedMeta): """Frozen Bit Set""" @final -- 2.30.2