working on code
author Jacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 10:54:55 +0000 (02:54 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 10:54:55 +0000 (02:54 -0800)
src/bigint_presentation_code/_tests/test_toom_cook.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py
src/bigint_presentation_code/toom_cook.py
src/bigint_presentation_code/util.py

index 6fff570f0a485d49fc39d723da5f25a88a8271cb..994c9510ea797c016c6a1acdbc0c0b333b4cb858 100644 (file)
@@ -1,31 +1,38 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
-                                                  GlobalMem, GPRRange,
-                                                  GPRRangeType, OpCopy,
-                                                  OpFuncArg, OpInputMem,
-                                                  OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
-                                                  generate_assembly)
-from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
+                                                   GenAsmState, OpKind,
+                                                   PreRASimState)
+from bigint_presentation_code.register_allocator2 import allocate_registers
 from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
-from bigint_presentation_code.util import FMap
 
 
 class SimpleMul192x192:
     def __init__(self):
+        super().__init__()
         self.fn = fn = Fn()
-        self.mem_in = mem = OpInputMem(fn).out
-        self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
-        self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
-        self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
-        dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
-        vl = OpSetVLImm(fn, 3).out
-        lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
-        rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
-        retval = simple_mul(fn, lhs, rhs)
-        vl = OpSetVLImm(fn, 6).out
-        self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
-                               mem_in=mem, vl=vl).mem_out
+        self.dest_offset = 0
+        self.lhs_offset = 48 + self.dest_offset
+        self.rhs_offset = 24 + self.lhs_offset
+        self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3,
+                                       name="ptr_in").outputs[0]
+        setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3],
+                                  maxvl=3, name="setvl3")
+        load_lhs = fn.append_new_op(
+            kind=OpKind.SvLd, immediates=[self.lhs_offset],
+            input_vals=[self.ptr_in, setvl3.outputs[0]],
+            name="load_lhs", maxvl=3)
+        load_rhs = fn.append_new_op(
+            kind=OpKind.SvLd, immediates=[self.rhs_offset],
+            input_vals=[self.ptr_in, setvl3.outputs[0]],
+            name="load_rhs", maxvl=3)
+        retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
+        setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6],
+                                  maxvl=6, name="setvl6")
+        fn.append_new_op(
+            kind=OpKind.SvStd,
+            input_vals=[retval, self.ptr_in, setvl6.outputs[0]],
+            immediates=[self.dest_offset], maxvl=6, name="store_dest")
 
 
 class TestToomCook(unittest.TestCase):
@@ -207,95 +214,248 @@ class TestToomCook(unittest.TestCase):
         # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
         #                   'little')
         code = SimpleMul192x192()
-        dest_ptr = 0x100
-        state = PreRASimState(
-            gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
-            stack_slots={}, fixed_gprs={
-                code.dest_ptr_in: (dest_ptr,),
-                code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
-                              0x000191acb262e15b),
-                code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
-                              0x4a37c0567bcbab53)
-            })
+        ptr_in = 0x100
+        dest_ptr = ptr_in + code.dest_offset
+        lhs_ptr = ptr_in + code.lhs_offset
+        rhs_ptr = ptr_in + code.rhs_offset
+        state = PreRASimState(ssa_vals={code.ptr_in: (ptr_in,)}, memory={})
+        state.store(lhs_ptr, 0x821a2342132c5b57)
+        state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e)
+        state.store(lhs_ptr + 16, 0x000191acb262e15b)
+        state.store(rhs_ptr, 0x208a49071aeec507)
+        state.store(rhs_ptr + 8, 0xcf1f597598194ae6)
+        state.store(rhs_ptr + 16, 0x4a37c0567bcbab53)
         code.fn.pre_ra_sim(state)
         expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
         OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
         expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
-        mem_out = state.global_mems[code.mem_out]
         out_bytes = bytes(
-            mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
+            state.load_byte(dest_ptr + i) for i in range(OUT_BYTE_COUNT))
         self.assertEqual(out_bytes, expected_bytes)
 
     def test_simple_mul_192x192_ops(self):
         code = SimpleMul192x192()
         fn = code.fn
         self.assertEqual([repr(v) for v in fn.ops], [
-            'OpInputMem(#0, <#0.out: GlobalMemType()>)',
-            'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
-            'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
-            'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
-            'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
-            'vl=None)',
-            'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
-            'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
-            'src=<#2.out: <fixed(<r4..len=3>)>>, '
-            'vl=<#5.out: KnownVLType(length=3)>)',
-            'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
-            'src=<#3.out: <fixed(<r7..len=3>)>>, '
-            'vl=<#5.out: KnownVLType(length=3)>)',
-            'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
-            '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
-            'src=<#7.dest: <gpr_ty[3]>>)',
-            'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
-            'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
-            'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
-            '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
-            'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
-            '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
-            'OpSetCA(#15, <#15.out: CAType()>, value=False)',
-            'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
-            'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
-            'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
-            'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
-            'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
-            'is_sub=False, vl=None)',
-            'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
-            '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
-            '<#17.out: <gpr_ty[1]>>))',
-            'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
-            'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
-            'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
-            '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
-            'OpSetCA(#21, <#21.out: CAType()>, value=False)',
-            'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
-            'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
-            'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
-            'vl=<#9.out: KnownVLType(length=3)>)',
-            'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
-            'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
-            'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
-            'is_sub=False, vl=None)',
-            'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
-            '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
-            '<#23.out: <gpr_ty[1]>>))',
-            'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
-            'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
-            'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
-            'mem_in=<#0.out: GlobalMemType()>, '
-            '<#26.mem_out: GlobalMemType()>, '
-            'vl=<#25.out: KnownVLType(length=6)>)'
+            "Op(kind=OpKind.FuncArgR3, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[], "
+            "outputs=(<ptr_in.outputs[0]: <I64>>,), "
+            "name='ptr_in')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[3], "
+            "outputs=(<setvl3.outputs[0]: <VL_MAXVL>>,), "
+            "name='setvl3')",
+            "Op(kind=OpKind.SvLd, "
+            "input_vals=[<ptr_in.outputs[0]: <I64>>, "
+            "<setvl3.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<load_lhs.input_uses[0]: <I64>>, "
+            "<load_lhs.input_uses[1]: <VL_MAXVL>>), immediates=[48], "
+            "outputs=(<load_lhs.outputs[0]: <I64*3>>,), "
+            "name='load_lhs')",
+            "Op(kind=OpKind.SvLd, "
+            "input_vals=[<ptr_in.outputs[0]: <I64>>, "
+            "<setvl3.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<load_rhs.input_uses[0]: <I64>>, "
+            "<load_rhs.input_uses[1]: <VL_MAXVL>>), immediates=[72], "
+            "outputs=(<load_rhs.outputs[0]: <I64*3>>,), "
+            "name='load_rhs')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[3], "
+            "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='rhs_setvl')",
+            "Op(kind=OpKind.Spread, "
+            "input_vals=[<load_rhs.outputs[0]: <I64*3>>, "
+            "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<rhs_spread.input_uses[0]: <I64*3>>, "
+            "<rhs_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<rhs_spread.outputs[0]: <I64>>, "
+            "<rhs_spread.outputs[1]: <I64>>, "
+            "<rhs_spread.outputs[2]: <I64>>), "
+            "name='rhs_spread')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[3], "
+            "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='lhs_setvl')",
+            "Op(kind=OpKind.LI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[0], "
+            "outputs=(<zero.outputs[0]: <I64>>,), "
+            "name='zero')",
+            "Op(kind=OpKind.SvMAddEDU, "
+            "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
+            "<rhs_spread.outputs[0]: <I64>>, "
+            "<zero.outputs[0]: <I64>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<mul0.input_uses[0]: <I64*3>>, "
+            "<mul0.input_uses[1]: <I64>>, "
+            "<mul0.input_uses[2]: <I64>>, "
+            "<mul0.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<mul0.outputs[0]: <I64*3>>, "
+            "<mul0.outputs[1]: <I64>>), "
+            "name='mul0')",
+            "Op(kind=OpKind.Spread, "
+            "input_vals=[<mul0.outputs[0]: <I64*3>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<mul0_rt_spread.input_uses[0]: <I64*3>>, "
+            "<mul0_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<mul0_rt_spread.outputs[0]: <I64>>, "
+            "<mul0_rt_spread.outputs[1]: <I64>>, "
+            "<mul0_rt_spread.outputs[2]: <I64>>), "
+            "name='mul0_rt_spread')",
+            "Op(kind=OpKind.SvMAddEDU, "
+            "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
+            "<rhs_spread.outputs[1]: <I64>>, "
+            "<zero.outputs[0]: <I64>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<mul1.input_uses[0]: <I64*3>>, "
+            "<mul1.input_uses[1]: <I64>>, "
+            "<mul1.input_uses[2]: <I64>>, "
+            "<mul1.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<mul1.outputs[0]: <I64*3>>, "
+            "<mul1.outputs[1]: <I64>>), "
+            "name='mul1')",
+            "Op(kind=OpKind.Concat, "
+            "input_vals=[<mul0_rt_spread.outputs[1]: <I64>>, "
+            "<mul0_rt_spread.outputs[2]: <I64>>, "
+            "<mul0.outputs[1]: <I64>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add1_rb_concat.input_uses[0]: <I64>>, "
+            "<add1_rb_concat.input_uses[1]: <I64>>, "
+            "<add1_rb_concat.input_uses[2]: <I64>>, "
+            "<add1_rb_concat.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add1_rb_concat.outputs[0]: <I64*3>>,), "
+            "name='add1_rb_concat')",
+            "Op(kind=OpKind.ClearCA, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[], "
+            "outputs=(<clear_ca1.outputs[0]: <CA>>,), "
+            "name='clear_ca1')",
+            "Op(kind=OpKind.SvAddE, "
+            "input_vals=[<mul1.outputs[0]: <I64*3>>, "
+            "<add1_rb_concat.outputs[0]: <I64*3>>, "
+            "<clear_ca1.outputs[0]: <CA>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add1.input_uses[0]: <I64*3>>, "
+            "<add1.input_uses[1]: <I64*3>>, "
+            "<add1.input_uses[2]: <CA>>, "
+            "<add1.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add1.outputs[0]: <I64*3>>, "
+            "<add1.outputs[1]: <CA>>), "
+            "name='add1')",
+            "Op(kind=OpKind.Spread, "
+            "input_vals=[<add1.outputs[0]: <I64*3>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add1_rt_spread.input_uses[0]: <I64*3>>, "
+            "<add1_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add1_rt_spread.outputs[0]: <I64>>, "
+            "<add1_rt_spread.outputs[1]: <I64>>, "
+            "<add1_rt_spread.outputs[2]: <I64>>), "
+            "name='add1_rt_spread')",
+            "Op(kind=OpKind.AddZE, "
+            "input_vals=[<mul1.outputs[1]: <I64>>, "
+            "<add1.outputs[1]: <CA>>], "
+            "input_uses=(<add_hi1.input_uses[0]: <I64>>, "
+            "<add_hi1.input_uses[1]: <CA>>), immediates=[], "
+            "outputs=(<add_hi1.outputs[0]: <I64>>, "
+            "<add_hi1.outputs[1]: <CA>>), "
+            "name='add_hi1')",
+            "Op(kind=OpKind.SvMAddEDU, "
+            "input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
+            "<rhs_spread.outputs[2]: <I64>>, "
+            "<zero.outputs[0]: <I64>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<mul2.input_uses[0]: <I64*3>>, "
+            "<mul2.input_uses[1]: <I64>>, "
+            "<mul2.input_uses[2]: <I64>>, "
+            "<mul2.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<mul2.outputs[0]: <I64*3>>, "
+            "<mul2.outputs[1]: <I64>>), "
+            "name='mul2')",
+            "Op(kind=OpKind.Concat, "
+            "input_vals=[<add1_rt_spread.outputs[1]: <I64>>, "
+            "<add1_rt_spread.outputs[2]: <I64>>, "
+            "<add_hi1.outputs[0]: <I64>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add2_rb_concat.input_uses[0]: <I64>>, "
+            "<add2_rb_concat.input_uses[1]: <I64>>, "
+            "<add2_rb_concat.input_uses[2]: <I64>>, "
+            "<add2_rb_concat.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add2_rb_concat.outputs[0]: <I64*3>>,), "
+            "name='add2_rb_concat')",
+            "Op(kind=OpKind.ClearCA, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[], "
+            "outputs=(<clear_ca2.outputs[0]: <CA>>,), "
+            "name='clear_ca2')",
+            "Op(kind=OpKind.SvAddE, "
+            "input_vals=[<mul2.outputs[0]: <I64*3>>, "
+            "<add2_rb_concat.outputs[0]: <I64*3>>, "
+            "<clear_ca2.outputs[0]: <CA>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add2.input_uses[0]: <I64*3>>, "
+            "<add2.input_uses[1]: <I64*3>>, "
+            "<add2.input_uses[2]: <CA>>, "
+            "<add2.input_uses[3]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add2.outputs[0]: <I64*3>>, "
+            "<add2.outputs[1]: <CA>>), "
+            "name='add2')",
+            "Op(kind=OpKind.Spread, "
+            "input_vals=[<add2.outputs[0]: <I64*3>>, "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add2_rt_spread.input_uses[0]: <I64*3>>, "
+            "<add2_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<add2_rt_spread.outputs[0]: <I64>>, "
+            "<add2_rt_spread.outputs[1]: <I64>>, "
+            "<add2_rt_spread.outputs[2]: <I64>>), "
+            "name='add2_rt_spread')",
+            "Op(kind=OpKind.AddZE, "
+            "input_vals=[<mul2.outputs[1]: <I64>>, "
+            "<add2.outputs[1]: <CA>>], "
+            "input_uses=(<add_hi2.input_uses[0]: <I64>>, "
+            "<add_hi2.input_uses[1]: <CA>>), immediates=[], "
+            "outputs=(<add_hi2.outputs[0]: <I64>>, "
+            "<add_hi2.outputs[1]: <CA>>), "
+            "name='add_hi2')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[6], "
+            "outputs=(<retval_setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='retval_setvl')",
+            "Op(kind=OpKind.Concat, "
+            "input_vals=[<mul0_rt_spread.outputs[0]: <I64>>, "
+            "<add1_rt_spread.outputs[0]: <I64>>, "
+            "<add2_rt_spread.outputs[0]: <I64>>, "
+            "<add2_rt_spread.outputs[1]: <I64>>, "
+            "<add2_rt_spread.outputs[2]: <I64>>, "
+            "<add_hi2.outputs[0]: <I64>>, "
+            "<retval_setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<concat_retval.input_uses[0]: <I64>>, "
+            "<concat_retval.input_uses[1]: <I64>>, "
+            "<concat_retval.input_uses[2]: <I64>>, "
+            "<concat_retval.input_uses[3]: <I64>>, "
+            "<concat_retval.input_uses[4]: <I64>>, "
+            "<concat_retval.input_uses[5]: <I64>>, "
+            "<concat_retval.input_uses[6]: <VL_MAXVL>>), immediates=[], "
+            "outputs=(<concat_retval.outputs[0]: <I64*6>>,), "
+            "name='concat_retval')",
+            "Op(kind=OpKind.SetVLI, "
+            "input_vals=[], "
+            "input_uses=(), immediates=[6], "
+            "outputs=(<setvl6.outputs[0]: <VL_MAXVL>>,), "
+            "name='setvl6')",
+            "Op(kind=OpKind.SvStd, "
+            "input_vals=[<concat_retval.outputs[0]: <I64*6>>, "
+            "<ptr_in.outputs[0]: <I64>>, "
+            "<setvl6.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<store_dest.input_uses[0]: <I64*6>>, "
+            "<store_dest.input_uses[1]: <I64>>, "
+            "<store_dest.input_uses[2]: <VL_MAXVL>>), immediates=[0], "
+            "outputs=(), "
+            "name='store_dest')",
         ])
 
     # FIXME: register allocator currently allocates wrong registers
@@ -303,73 +463,21 @@ class TestToomCook(unittest.TestCase):
     def test_simple_mul_192x192_reg_alloc(self):
         code = SimpleMul192x192()
         fn = code.fn
-        assigned_registers = allocate_registers(fn.ops)
+        assigned_registers = allocate_registers(fn)
         self.assertEqual(assigned_registers, {
-            fn.ops[13].RS: GPRRange(9),  # type: ignore
-            fn.ops[14].results[0]: GPRRange(6),  # type: ignore
-            fn.ops[14].results[1]: GPRRange(7, length=3),  # type: ignore
-            fn.ops[15].out: XERBit.CA,  # type: ignore
-            fn.ops[16].out: GPRRange(7, length=3),  # type: ignore
-            fn.ops[16].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[17].out: GPRRange(10),  # type: ignore
-            fn.ops[17].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[18].dest: GPRRange(6, length=5),  # type: ignore
-            fn.ops[19].RT: GPRRange(3, length=3),  # type: ignore
-            fn.ops[19].RS: GPRRange(9),  # type: ignore
-            fn.ops[20].results[0]: GPRRange(6, length=2),  # type: ignore
-            fn.ops[20].results[1]: GPRRange(8, length=3),  # type: ignore
-            fn.ops[21].out: XERBit.CA,  # type: ignore
-            fn.ops[22].out: GPRRange(8, length=3),  # type: ignore
-            fn.ops[22].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[23].out: GPRRange(11),  # type: ignore
-            fn.ops[23].CA_out: XERBit.CA,  # type: ignore
-            fn.ops[24].dest: GPRRange(6, length=6),  # type: ignore
-            fn.ops[25].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[26].mem_out: GlobalMem.GlobalMem,  # type: ignore
-            fn.ops[0].out: GlobalMem.GlobalMem,  # type: ignore
-            fn.ops[1].out: GPRRange(3),  # type: ignore
-            fn.ops[2].out: GPRRange(4, length=3),  # type: ignore
-            fn.ops[3].out: GPRRange(7, length=3),  # type: ignore
-            fn.ops[4].dest: GPRRange(12),  # type: ignore
-            fn.ops[5].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[6].dest: GPRRange(17, length=3),  # type: ignore
-            fn.ops[7].dest: GPRRange(14, length=3),  # type: ignore
-            fn.ops[8].results[0]: GPRRange(14),  # type: ignore
-            fn.ops[8].results[1]: GPRRange(15),  # type: ignore
-            fn.ops[8].results[2]: GPRRange(16),  # type: ignore
-            fn.ops[9].out: VL.VL_MAXVL,  # type: ignore
-            fn.ops[10].out: GPRRange(9),  # type: ignore
-            fn.ops[11].RT: GPRRange(6, length=3),  # type: ignore
-            fn.ops[11].RS: GPRRange(9),  # type: ignore
-            fn.ops[12].dest: GPRRange(6, length=4),  # type: ignore
-            fn.ops[13].RT: GPRRange(3, length=3)  # type: ignore
         })
         self.fail("register allocator currently allocates wrong registers")
 
     # FIXME: register allocator currently allocates wrong registers
     @unittest.expectedFailure
     def test_simple_mul_192x192_asm(self):
+        self.skipTest("WIP")
         code = SimpleMul192x192()
-        asm = generate_assembly(code.fn.ops)
-        self.assertEqual(asm, [
-            'or 12, 3, 3',
-            'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *17, *4, *4',
-            'sv.or *14, *7, *7',
-            'setvl 0, 0, 3, 0, 1, 1',
-            'addi 9, 0, 0',
-            'sv.maddedu *6, *17, 14, 9',
-            'sv.maddedu *3, *17, 15, 9',
-            'addic 0, 0, 0',
-            'sv.adde *7, *3, *7',
-            'adde 10, 9, 9',
-            'sv.maddedu *3, *17, 16, 9',
-            'addic 0, 0, 0',
-            'sv.adde *8, *3, *8',
-            'adde 11, 9, 9',
-            'setvl 0, 0, 6, 0, 1, 1',
-            'sv.std *6, 0(12)',
-            'bclr 20, 0, 0'
+        fn = code.fn
+        assigned_registers = allocate_registers(fn)
+        gen_asm_state = GenAsmState(assigned_registers)
+        fn.gen_asm(gen_asm_state)
+        self.assertEqual(gen_asm_state.output, [
         ])
         self.fail("register allocator currently allocates wrong registers")
 
index b256a07b7a58868bdcce9d203cdf96389fcd6dc6..bd3d38c77f1a7d648387a7d8e16abafe93b5009f 100644 (file)
@@ -12,7 +12,7 @@ from nmutil.plain_data import fields, plain_data
 
 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
                                                 final)
-from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
+from bigint_presentation_code.util import BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet
 
 
 @final
@@ -180,7 +180,7 @@ assert OpStage.Early < OpStage.Late, "early must be less than late"
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 @total_ordering
-class ProgramPoint:
+class ProgramPoint(metaclass=InternedMeta):
     __slots__ = "op_index", "stage"
 
     def __init__(self, op_index, stage):
@@ -225,7 +225,7 @@ class ProgramPoint:
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
-class ProgramRange(Sequence[ProgramPoint]):
+class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
     __slots__ = "start", "stop"
 
     def __init__(self, start, stop):
@@ -389,7 +389,7 @@ class BaseTy(Enum):
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
-class Ty:
+class Ty(metaclass=InternedMeta):
     __slots__ = "base_ty", "reg_len"
 
     @staticmethod
@@ -529,7 +529,7 @@ class LocSubKind(Enum):
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class GenericTy:
+class GenericTy(metaclass=InternedMeta):
     __slots__ = "base_ty", "is_vec"
 
     def __init__(self, base_ty, is_vec):
@@ -557,7 +557,7 @@ class GenericTy:
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class Loc:
+class Loc(metaclass=InternedMeta):
     __slots__ = "kind", "start", "reg_len"
 
     @staticmethod
@@ -643,7 +643,7 @@ SPECIAL_GPRS = (
 
 @plain_data(frozen=True, eq=False)
 @final
-class LocSet(AbstractSet[Loc]):
+class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
     __slots__ = "starts", "ty"
 
     def __init__(self, __locs=()):
@@ -746,19 +746,8 @@ class LocSet(AbstractSet[Loc]):
     def __len__(self):
         return self.__len
 
-    __HASHES = {}  # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int]
-
-    @cached_property
-    def __hash(self):
-        # cache hashes to avoid slow LocSet iteration
-        key = self.ty, self.starts
-        retval = self.__HASHES.get(key, None)
-        if retval is None:
-            self.__HASHES[key] = retval = super(LocSet, self)._hash()
-        return retval
-
     def __hash__(self):
-        return self.__hash
+        return super()._hash()
 
     def __eq__(self, __other):
         # type: (LocSet | Any) -> bool
@@ -780,7 +769,7 @@ class LocSet(AbstractSet[Loc]):
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class GenericOperandDesc:
+class GenericOperandDesc(metaclass=InternedMeta):
     """generic Op operand descriptor"""
     __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                  "write_stage")
@@ -882,7 +871,7 @@ class GenericOperandDesc:
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OperandDesc:
+class OperandDesc(metaclass=InternedMeta):
     """Op operand descriptor"""
     __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                  "write_stage")
@@ -955,7 +944,7 @@ OD_VL = GenericOperandDesc(
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class GenericOpProperties:
+class GenericOpProperties(metaclass=InternedMeta):
     __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                  "is_copy", "is_load_immediate", "has_side_effects")
 
@@ -1007,7 +996,7 @@ class GenericOpProperties:
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
-class OpProperties:
+class OpProperties(metaclass=InternedMeta):
     __slots__ = "kind", "inputs", "outputs", "maxvl"
 
     def __init__(self, kind, maxvl):
@@ -1159,6 +1148,31 @@ class OpKind(Enum):
     _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
     _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
 
+    @staticmethod
+    def __addze_pre_ra_sim(op, state):  # simulate AddZE on `state`: RT = RA + CA
+        # type: (Op, PreRASimState) -> None
+        RA, = state.ssa_vals[op.input_vals[0]]  # input 0 is a 1-tuple: the GPR operand
+        carry, = state.ssa_vals[op.input_vals[1]]  # input 1 is a 1-tuple: incoming CA
+        v = RA + carry  # unbounded Python int, so the carry-out bit is not lost
+        RT = v & GPR_VALUE_MASK  # result truncated by the GPR mask
+        carry = (v >> GPR_SIZE_IN_BITS) != 0  # carry out: any bit at/above GPR width
+        state.ssa_vals[op.outputs[0]] = RT,  # outputs are stored as 1-tuples too
+        state.ssa_vals[op.outputs[1]] = carry,  # NOTE(review): stored as bool -- confirm other CA producers agree
+
+    @staticmethod
+    def __addze_gen_asm(op, state):  # write the `addze` line for `op` to `state`
+        # type: (Op, GenAsmState) -> None
+        RT = state.vgpr(op.outputs[0])  # presumably the register text for output 0 -- confirm vgpr()
+        RA = state.vgpr(op.input_vals[0])  # likewise for input 0
+        state.writeln(f"addze {RT}, {RA}")  # CA operands (input 1 / output 1) are not named in the text
+    AddZE = GenericOpProperties(  # OpKind member describing the addze op
+        demo_asm="addze RT, RA",
+        inputs=[OD_BASE_SGPR, OD_CA],  # input 0: GPR operand, input 1: CA
+        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],  # output 1 (CA) is tied to input 1
+    )
+    _PRE_RA_SIMS[AddZE] = lambda: OpKind.__addze_pre_ra_sim  # lambda defers the lookup until called
+    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm  # same deferred-lookup pattern as the other ops
+
     @staticmethod
     def __svsubfe_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
@@ -1611,7 +1625,7 @@ class OpKind(Enum):
 
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
-class SSAValOrUse(metaclass=ABCMeta):
+class SSAValOrUse(metaclass=InternedMeta):
     __slots__ = "op", "operand_idx"
 
     def __init__(self, op, operand_idx):
index aefa2914f2cc58ad6a55994f04fbd113cf6758d8..198eebcc9593f13050232e216b58e6699ad2e9c5 100644 (file)
@@ -15,43 +15,7 @@ from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
                                                    LocSet, ProgramRange,
                                                    SSAVal, Ty)
 from bigint_presentation_code.type_util import final
-from bigint_presentation_code.util import FMap, OFSet, OSet
-
-
-@plain_data(unsafe_hash=True, order=True, frozen=True)
-class LiveInterval:
-    __slots__ = "first_write", "last_use"
-
-    def __init__(self, first_write, last_use=None):
-        # type: (int, int | None) -> None
-        super().__init__()
-        if last_use is None:
-            last_use = first_write
-        if last_use < first_write:
-            raise ValueError("uses must be after first_write")
-        if first_write < 0 or last_use < 0:
-            raise ValueError("indexes must be nonnegative")
-        self.first_write = first_write
-        self.last_use = last_use
-
-    def overlaps(self, other):
-        # type: (LiveInterval) -> bool
-        if self.first_write == other.first_write:
-            return True
-        return self.last_use > other.first_write \
-            and other.last_use > self.first_write
-
-    def __add__(self, use):
-        # type: (int) -> LiveInterval
-        last_use = max(self.last_use, use)
-        return LiveInterval(first_write=self.first_write, last_use=last_use)
-
-    @property
-    def live_after_op_range(self):
-        """the range of op indexes where self is live immediately after the
-        Op at each index
-        """
-        return range(self.first_write, self.last_use)
+from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
 
 
 class BadMergedSSAVal(ValueError):
@@ -60,7 +24,7 @@ class BadMergedSSAVal(ValueError):
 
 @plain_data(frozen=True, repr=False)
 @final
-class MergedSSAVal:
+class MergedSSAVal(metaclass=InternedMeta):
     """a set of `SSAVal`s along with their offsets, all register allocated as
     a single unit.
 
index 246f65454a861be15fbfb00954f9f9a37dac4899..aa18967e3443353851439002f2ed7458ab7a37c9 100644 (file)
@@ -8,10 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub,
-                                                  OpBigIntMulDiv, OpConcat,
-                                                  OpLI, OpSetCA, OpSetVLImm,
-                                                  OpSplit, SSAGPRRange)
+from bigint_presentation_code.compiler_ir2 import (Fn, OpKind, SSAVal)
 from bigint_presentation_code.matrix import Matrix
 from bigint_presentation_code.type_util import Literal, final
 
@@ -190,6 +187,7 @@ class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
 
     def __init__(self, lhs, rhs):
         # type: (_EvalOpLHS, _EvalOpRHS) -> None
+        super().__init__()
         self.lhs = lhs
         self.rhs = rhs
         self.poly = self._make_poly()
@@ -442,32 +440,76 @@ class ToomCookInstance:
 
 
 def simple_mul(fn, lhs, rhs):
-    # type: (Fn, SSAGPRRange, SSAGPRRange) -> SSAGPRRange
+    # type: (Fn, SSAVal, SSAVal) -> SSAVal
     """ simple O(n^2) big-int unsigned multiply """
-    if lhs.ty.length < rhs.ty.length:
+    if lhs.ty.reg_len < rhs.ty.reg_len:
         lhs, rhs = rhs, lhs
     # split rhs into elements
-    rhs_words = OpSplit(fn, rhs, range(1, rhs.ty.length)).results
-    retval = None
-    vl = OpSetVLImm(fn, lhs.ty.length).out
-    zero = OpLI(fn, 0).out
+    rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
+                                 immediates=[rhs.ty.reg_len], name="rhs_setvl")
+    rhs_spread = fn.append_new_op(
+        kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
+        maxvl=rhs.ty.reg_len, name="rhs_spread")
+    rhs_words = rhs_spread.outputs
+    # the running product is kept as a tuple of separate SSA values so it
+    # can be sliced at a different offset (`shift`) for every rhs word
+    spread_retval = None  # type: tuple[SSAVal, ...] | None
+    maxvl = lhs.ty.reg_len
+    lhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
+                                 immediates=[lhs.ty.reg_len], name="lhs_setvl")
+    vl = lhs_setvl.outputs[0]
+    zero_op = fn.append_new_op(kind=OpKind.LI, immediates=[0], name="zero")
+    zero = zero_op.outputs[0]
     for shift, rhs_word in enumerate(rhs_words):
-        mul = OpBigIntMulDiv(fn, RA=lhs, RB=rhs_word, RC=zero,
-                             is_div=False, vl=vl)
-        if retval is None:
-            retval = OpConcat(fn, [mul.RT, mul.RS]).dest
+        # multiply lhs by this rhs word; outputs[0] and outputs[1] take the
+        # place of the old `mul.RT` and `mul.RS`
+        mul = fn.append_new_op(kind=OpKind.SvMAddEDU,
+                               input_vals=[lhs, rhs_word, zero, vl],
+                               maxvl=maxvl, name=f"mul{shift}")
+        if spread_retval is None:
+            # first rhs word: there is no running product to add to yet
+            mul_rt_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[mul.outputs[0], vl],
+                name=f"mul{shift}_rt_spread", maxvl=maxvl)
+            spread_retval = (*mul_rt_spread.outputs, mul.outputs[1])
         else:
-            first_part, last_part = OpSplit(fn, retval, [shift]).results
-            add = OpBigIntAddSub(
-                fn, lhs=mul.RT, rhs=last_part, CA_in=OpSetCA(fn, False).out,
-                is_sub=False, vl=vl)
-            add_hi = OpBigIntAddSub(fn, lhs=mul.RS, rhs=zero, CA_in=add.CA_out,
-                                    is_sub=False)
-            retval = OpConcat(fn, [first_part, add.out, add_hi.out]).dest
-    assert retval is not None
-    return retval
+            first_part = spread_retval[:shift]  # type: tuple[SSAVal, ...]
+            last_part = spread_retval[shift:]
+
+            # re-join the values from index `shift` upward so they can be
+            # added to `mul.outputs[0]` by a single vector add
+            add_rb_concat = fn.append_new_op(
+                kind=OpKind.Concat, input_vals=[*last_part, vl],
+                name=f"add{shift}_rb_concat", maxvl=maxvl)
+            clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
+                                        name=f"clear_ca{shift}")
+            add = fn.append_new_op(
+                kind=OpKind.SvAddE, input_vals=[
+                    mul.outputs[0], add_rb_concat.outputs[0],
+                    clear_ca.outputs[0], vl],
+                maxvl=maxvl, name=f"add{shift}")
+            add_rt_spread = fn.append_new_op(
+                kind=OpKind.Spread, input_vals=[add.outputs[0], vl],
+                name=f"add{shift}_rt_spread", maxvl=maxvl)
+            # new top value: `mul.outputs[1]` combined with `add.outputs[1]`
+            # (mirrors the old `lhs=mul.RS, ..., CA_in=add.CA_out`)
+            add_hi = fn.append_new_op(
+                kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]],
+                name=f"add_hi{shift}")
+            spread_retval = (
+                *first_part, *add_rt_spread.outputs, add_hi.outputs[0])
+    assert spread_retval is not None
+    # join the separate values back into the single SSAVal that is returned
+    retval_setvl = fn.append_new_op(
+        kind=OpKind.SetVLI, immediates=[len(spread_retval)],
+        name="retval_setvl")
+    concat_retval = fn.append_new_op(
+        kind=OpKind.Concat,
+        input_vals=[*spread_retval, retval_setvl.outputs[0]],
+        name="concat_retval", maxvl=len(spread_retval))
+    return concat_retval.outputs[0]
 
 
 def toom_cook_mul(fn, lhs, rhs, instances):
-    # type: (Fn, SSAGPRRange, SSAGPRRange, list[ToomCookInstance]) -> SSAGPRRange
+    # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal
     raise NotImplementedError
index 757f267ecdfabcc923980e3d863666b5c55d0305..4b7399fc83400ac7ea94bb5517251b5de7d5ce6f 100644 (file)
@@ -1,4 +1,5 @@
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
+from functools import lru_cache
 from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
                     TypeVar, overload)
 
@@ -17,12 +18,44 @@ __all__ = [
     "OSet",
     "top_set_bit_index",
     "trailing_zero_count",
+    "InternedMeta",
 ]
 
 
-class OFSet(AbstractSet[_T_co]):
+class InternedMeta(ABCMeta):
+    """metaclass that interns instances: constructing a value equal to a
+    previously constructed instance of the same class returns that earlier
+    instance instead of the new one.
+
+    instances are used as keys of the per-class intern table, so they must
+    be hashable and must not change their hash/equality after construction.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # type: (*Any, **Any) -> None
+        super().__init__(*args, **kwargs)
+        # one table per class created with this metaclass, so equal values
+        # of different classes are never merged.
+        # NOTE(review): entries are strong references and are never removed.
+        self.__INTERN_TABLE = {}  # type: dict[Any, Any]
+
+    def __intern(self, value):
+        # type: (_T) -> _T
+        # `hash()` and `==` look special methods up on the type and ignore
+        # the instance `__dict__`, so per-instance `__hash__`/`__eq__`
+        # overrides would have no effect on them -- the table lookup is all
+        # that is needed, and it also works for `__slots__`-only classes.
+        return self.__INTERN_TABLE.setdefault(value, value)
+
+    def __call__(self, *args, **kwargs):
+        # type: (*Any, **Any) -> Any
+        # construct normally, then swap in the canonical equal instance
+        return self.__intern(super().__call__(*args, **kwargs))
+
+
+class OFSet(AbstractSet[_T_co], metaclass=InternedMeta):
     """ ordered frozen set """
-    __slots__ = "__items",
+    __slots__ = "__items", "__dict__", "__weakref__"
 
     def __init__(self, items=()):
         # type: (Iterable[_T_co]) -> None
@@ -54,7 +85,7 @@ class OFSet(AbstractSet[_T_co]):
 
 class OSet(MutableSet[_T]):
     """ ordered mutable set """
-    __slots__ = "__items",
+    __slots__ = "__items", "__dict__"
 
     def __init__(self, items=()):
         # type: (Iterable[_T]) -> None
@@ -88,9 +119,9 @@ class OSet(MutableSet[_T]):
         return f"OSet({list(self)})"
 
 
-class FMap(Mapping[_T, _T_co]):
+class FMap(Mapping[_T, _T_co], metaclass=InternedMeta):
     """ordered frozen hashable mapping"""
-    __slots__ = "__items", "__hash"
+    __slots__ = "__items", "__hash", "__dict__", "__weakref__"
 
     @overload
     def __init__(self, items):
@@ -167,7 +198,7 @@ except AttributeError:
 
 
 class BaseBitSet(AbstractSet[int]):
-    __slots__ = "__bits",
+    __slots__ = "__bits", "__dict__", "__weakref__"
 
     @classmethod
     @abstractmethod
@@ -402,7 +433,7 @@ class BitSet(BaseBitSet, MutableSet[int]):
         return super().__isub__(it)
 
 
-class FBitSet(BaseBitSet):
+class FBitSet(BaseBitSet, metaclass=InternedMeta):
     """Frozen Bit Set"""
 
     @final