From 089a2ef4c11af2fa9bf6f7a181d7145d1e61ced9 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 15 Nov 2022 00:11:21 -0800 Subject: [PATCH] working on code --- .../_tests/test_toom_cook.py | 679 +++++++++++++++++- src/bigint_presentation_code/compiler_ir.py | 2 +- .../register_allocator.py | 9 + src/bigint_presentation_code/toom_cook.py | 110 ++- src/bigint_presentation_code/util.py | 11 +- 5 files changed, 746 insertions(+), 65 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 735d7e0..2032cbc 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -990,8 +990,7 @@ class TestToomCook(unittest.TestCase): 'sv.std *4, 0(3)' ]) - def test_toom_2_mul_256x256_asm(self): - self.skipTest("WIP") # FIXME: finish + def toom_2_mul_256x256(self): TOOM_2 = ToomCookInstance.make_toom_2() instances = TOOM_2, TOOM_2 @@ -999,12 +998,686 @@ class TestToomCook(unittest.TestCase): # type: (Fn, SSAVal, SSAVal) -> SSAVal return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs, rhs_signed=False, instances=instances) - code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3) + return Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3) + + def test_toom_2_mul_256x256_pre_ra_sim(self): + self.skipTest("WIP") # FIXME: finish + # maybe use something that multiplies to: + # int.from_bytes( + # b'256x256-bit bigint mul using TOOM-2 -- Karatsuba Multiplication!', + # 'little') + # as the multiplication test... + # known factors (used yafu-1.34): + # P1 = 2 + # P1 = 7 + # P3 = 197 + # P7 = 1319057 + # ***co-factor*** + # C144 = 4812983706140089583461601472550901888754775658675461119771495\ + # 11062521614062442465071845504357495554525178667728633744424201288485\ + # 594266060663587 + + def test_toom_2_mul_256x256_asm(self): + self.skipTest("WIP") # FIXME: finish + code = self.toom_2_mul_256x256() fn = code.fn assigned_registers = allocate_registers(fn) gen_asm_state = GenAsmState(assigned_registers) fn.gen_asm(gen_asm_state) self.assertEqual(gen_asm_state.output, [ + 'or 23, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 6, 23, 23', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.ld *3, 48(6)', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 6, 23, 23', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.ld *3, 72(6)', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *16, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *3, *7, *7', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'or 5, 6, 6', + 'or 4, 7, 7', + 'or 3, 8, 8', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 5, 5', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 10, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 25, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 10, 10', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 25, 25', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *7, *6, *6', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.adde *3, *7, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *14, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *3, *16, *16', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'or 5, 6, 6', + 'or 4, 7, 7', + 'or 3, 8, 8', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 5, 5', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 9, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 24, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 9, 9', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 24, 24', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *7, *6, *6', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.adde *3, *7, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *11, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 9, 9', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 5, 3, 3', + 'addi 3, 0, 0', + 'or 4, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'addi 3, 0, 0', + 'or 6, 10, 10', + 'setvl 0, 0, 1, 0, 1, 1', + 'sv.maddedu *3, *6, 5, 4', + 'or 5, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *21, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *14, *14', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *4, *3, *3', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 10, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 18, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 10, 10', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 18, 18', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *7, *6, *6', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.adde *3, *7, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *16, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *11, *11', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *4, *3, *3', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 9, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 15, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 9, 9', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 15, 15', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 4, 3, 3', + 'addi 3, 0, 0', + 'or 5, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or/mrr *7, *6, *6', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.adde *3, *7, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 9, 9', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 5, 3, 3', + 'addi 3, 0, 0', + 'or 4, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'addi 3, 0, 0', + 'or 6, 10, 10', + 'setvl 0, 0, 1, 0, 1, 1', + 'sv.maddedu *3, *6, 5, 4', + 'or 5, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *19, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *7, *7', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 4, 5, 5', + 'or 11, 6, 6', + 'addi 3, 0, 0', + 'or 10, 3, 3', + 'setvl 0, 0, 2, 0, 1, 1', + 'addi 3, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *8, *16, *16', + 'or 6, 4, 4', + 'or 5, 10, 10', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.maddedu *3, *8, 6, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 14, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 12, 3, 3', + 'or 7, 4, 4', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *8, *16, *16', + 'or 6, 11, 11', + 'or 5, 10, 10', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.maddedu *3, *8, 6, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 11, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 10, 3, 3', + 'or 9, 4, 4', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 7, 7', + 'or 4, 14, 14', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'or 3, 10, 10', + 'or 4, 9, 9', + 'or 5, 11, 11', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *6, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.adde *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 9, 3, 3', + 'or 8, 4, 4', + 'or 7, 5, 5', + 'setvl 0, 0, 4, 0, 1, 1', + 'or 3, 12, 12', + 'or 4, 9, 9', + 'or 5, 8, 8', + 'or 6, 7, 7', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 15, 15', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 5, 3, 3', + 'addi 3, 0, 0', + 'or 4, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'addi 3, 0, 0', + 'or 6, 18, 18', + 'setvl 0, 0, 1, 0, 1, 1', + 'sv.maddedu *3, *6, 5, 4', + 'or 5, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *17, *3, *3', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *3, *7, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *8, *3, *3', + 'or 4, 8, 8', + 'or 7, 9, 9', + 'or 6, 10, 10', + 'or 3, 11, 11', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *14, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *19, *19', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 4, 5, 5', + 'or 7, 6, 6', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *14, *14', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.subfe *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *14, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *17, *17', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 4, 5, 5', + 'or 7, 6, 6', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *14, *14', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.subfe *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *8, *3, *3', + 'addi 3, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *19, *19', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 12, 3, 3', + 'or 7, 4, 4', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *3, *8, *8', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 11, 3, 3', + 'or 10, 4, 4', + 'or 9, 5, 5', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 7, 7', + 'or 4, 6, 6', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'or 3, 11, 11', + 'or 4, 10, 10', + 'or 5, 9, 9', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *6, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.adde *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 9, 3, 3', + 'or 6, 4, 4', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *17, *17', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 8, 3, 3', + 'or 7, 4, 4', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 3, 6, 6', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 3, 8, 8', + 'or 4, 7, 7', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *7, *5, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.adde *3, *7, *5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 8, 3, 3', + 'or 7, 4, 4', + 'setvl 0, 0, 4, 0, 1, 1', + 'or 3, 12, 12', + 'or 4, 9, 9', + 'or 5, 8, 8', + 'or 6, 7, 7', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 3, 24, 24', + 'setvl 0, 0, 1, 0, 1, 1', + 'or 5, 3, 3', + 'addi 3, 0, 0', + 'or 4, 3, 3', + 'setvl 0, 0, 1, 0, 1, 1', + 'addi 3, 0, 0', + 'or 6, 25, 25', + 'setvl 0, 0, 1, 0, 1, 1', + 'sv.maddedu *3, *6, 5, 4', + 'or 5, 4, 4', + 'setvl 0, 0, 1, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 4, 5, 5', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *19, *3, *3', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *3, *7, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *8, *3, *3', + 'or 4, 8, 8', + 'or 7, 9, 9', + 'or 6, 10, 10', + 'or 3, 11, 11', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *14, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *21, *21', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 4, 5, 5', + 'or 7, 6, 6', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *14, *14', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.subfe *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *14, *3, *3', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *19, *19', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *5, *3, *3', + 'or 4, 5, 5', + 'or 7, 6, 6', + 'addi 3, 0, 0', + 'or 6, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 4, 4', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *14, *14', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.subfe *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *7, *3, *3', + 'addi 3, 0, 0', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *21, *21', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 18, 3, 3', + 'or 6, 4, 4', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *3, *7, *7', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 15, 3, 3', + 'or 14, 4, 4', + 'or 12, 5, 5', + 'addi 3, 0, 0', + 'or 7, 3, 3', + 'or 3, 12, 12', + 'sradi 3, 3, 63', + 'or 11, 3, 3', + 'setvl 0, 0, 4, 0, 1, 1', + 'or 3, 6, 6', + 'or 4, 7, 7', + 'or 5, 7, 7', + 'or 6, 7, 7', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'or 3, 15, 15', + 'or 4, 14, 14', + 'or 5, 12, 12', + 'or 6, 11, 11', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *14, *7, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.adde *3, *14, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *8, *3, *3', + 'or 14, 8, 8', + 'or 5, 9, 9', + 'or 7, 10, 10', + 'or 6, 11, 11', + 'setvl 0, 0, 2, 0, 1, 1', + 'setvl 0, 0, 2, 0, 1, 1', + 'sv.or *3, *19, *19', + 'setvl 0, 0, 2, 0, 1, 1', + 'or 11, 3, 3', + 'or 10, 4, 4', + 'addi 3, 0, 0', + 'or 9, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 3, 5, 5', + 'or 4, 7, 7', + 'or 5, 6, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'or 3, 11, 11', + 'or 4, 10, 10', + 'or 5, 9, 9', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'addic 0, 0, 0', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *9, *6, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *6, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.adde *3, *9, *6', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 12, 3, 3', + 'or 11, 4, 4', + 'or 10, 5, 5', + 'addi 3, 0, 0', + 'or 9, 3, 3', + 'setvl 0, 0, 6, 0, 1, 1', + 'or 3, 18, 18', + 'or 4, 14, 14', + 'or 5, 12, 12', + 'or 6, 11, 11', + 'or 7, 10, 10', + 'or 8, 9, 9', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.or/mrr *4, *3, *3', + 'or 3, 23, 23', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.std *4, 0(3)' ]) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 9c7d729..61a5ece 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -1294,7 +1294,7 @@ class OpKind(Enum): def __sradi_gen_asm(op, state): # type: (Op, GenAsmState) -> None RA = state.sgpr(op.outputs[0]) - RS = state.sgpr(op.input_vals[1]) + RS = state.sgpr(op.input_vals[0]) imm = op.immediates[0] state.writeln(f"sradi {RA}, {RS}, {imm}") SRADI = GenericOpProperties( diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index b3f682e..f41d18e 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -500,6 +500,8 @@ def allocate_registers(fn, debug_out=None): nodes_remaining = OSet(interference_graph.nodes.values()) + local_colorability_score_cache = {} # type: dict[IGNode, int] + def local_colorability_score(node): # type: (IGNode) -> int """ returns a positive integer if node is locally colorable, returns @@ -508,10 +510,14 @@ def allocate_registers(fn, debug_out=None): """ if node not in nodes_remaining: raise ValueError() + retval = local_colorability_score_cache.get(node, None) + if retval is not None: + return retval retval = len(node.loc_set) for neighbor in node.edges: if neighbor in nodes_remaining: retval -= node.loc_set.max_conflicts_with(neighbor.loc_set) + local_colorability_score_cache[node] = retval return retval # TODO: implement copy-merging @@ -533,6 +539,9 @@ def allocate_registers(fn, debug_out=None): break node_stack.append(best_node) nodes_remaining.remove(best_node) + local_colorability_score_cache.pop(best_node, None) + for neighbor in best_node.edges: + local_colorability_score_cache.pop(neighbor, None) if debug_out is not None: print(f"After deciding node allocation order:\n" diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index c261d82..75891f2 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -5,7 +5,7 @@ import math from abc import abstractmethod from enum import Enum from fractions import Fraction -from typing import Iterable, Mapping, Union +from typing import Iterable, Mapping, Tuple, Union from cached_property import cached_property from nmutil.plain_data import plain_data @@ -376,8 +376,16 @@ class EvalOpAdd(EvalOp): fn=state.fn, ssa_val=rhs.output, dest_size=output_value_range.output_size, src_signed=rhs.is_signed, name="add_rhs_cast") - - raise NotImplementedError # FIXME: finish + setvl = state.fn.append_new_op( + OpKind.SetVLI, immediates=[output_value_range.output_size], + name="setvl", maxvl=output_value_range.output_size) + clear_ca = state.fn.append_new_op(OpKind.ClearCA, name="clear_ca") + add = state.fn.append_new_op( + OpKind.SvAddE, input_vals=[ + lhs_output, rhs_output, clear_ca.outputs[0], setvl.outputs[0]], + maxvl=output_value_range.output_size, name="add") + return EvalOpGenIrOutput( + output=add.outputs[0], value_range=output_value_range) @plain_data(frozen=True, unsafe_hash=True) @@ -391,7 +399,26 @@ class EvalOpSub(EvalOp): def make_output(self, state, output_value_range): # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput - raise NotImplementedError # FIXME: finish + lhs = state.get_output(self.lhs) + lhs_output = cast_to_size( + fn=state.fn, ssa_val=lhs.output, + dest_size=output_value_range.output_size, src_signed=lhs.is_signed, + name="add_lhs_cast") + rhs = state.get_output(self.rhs) + rhs_output = cast_to_size( + fn=state.fn, ssa_val=rhs.output, + dest_size=output_value_range.output_size, src_signed=rhs.is_signed, + name="add_rhs_cast") + setvl = state.fn.append_new_op( + OpKind.SetVLI, immediates=[output_value_range.output_size], + name="setvl", maxvl=output_value_range.output_size) + set_ca = state.fn.append_new_op(OpKind.SetCA, name="set_ca") + sub = state.fn.append_new_op( + OpKind.SvSubFE, input_vals=[ + rhs_output, lhs_output, set_ca.outputs[0], setvl.outputs[0]], + maxvl=output_value_range.output_size, name="sub") + return EvalOpGenIrOutput( + output=sub.outputs[0], value_range=output_value_range) @plain_data(frozen=True, unsafe_hash=True) @@ -747,9 +774,11 @@ def sum_partial_products(fn, partial_products, retval_size, name): return retval_concat.outputs[0] -def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name): - # type: (Fn, SSAVal, bool, SSAVal, bool, str) -> SSAVal +def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name, retval_size=None): + # type: (Fn, SSAVal, bool, SSAVal, bool, str, int | None) -> SSAVal """ simple O(n^2) big-int multiply """ + if retval_size is None: + retval_size = lhs.ty.reg_len + rhs.ty.reg_len if lhs.ty.reg_len < rhs.ty.reg_len: lhs, rhs = rhs, lhs lhs_signed, rhs_signed = rhs_signed, lhs_signed @@ -820,7 +849,7 @@ def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name): shift_in_words=rhs.ty.reg_len, is_signed=False, subtract=True) return sum_partial_products( fn=fn, partial_products=partial_products(), - retval_size=lhs.ty.reg_len + rhs.ty.reg_len, name=name) + retval_size=retval_size, name=name) def cast_to_size(fn, ssa_val, src_signed, dest_size, name): @@ -908,9 +937,14 @@ def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name): return retval +__TCIs = Tuple[ToomCookInstance, ...] + + def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances, - start_instance_index=0): - # type: (Fn, SSAVal, bool, SSAVal, bool, tuple[ToomCookInstance, ...], int) -> SSAVal + retval_size=None, start_instance_index=0): + # type: (Fn, SSAVal, bool, SSAVal, bool, __TCIs, None | int, int) -> SSAVal + if retval_size is None: + retval_size = lhs.ty.reg_len + rhs.ty.reg_len if start_instance_index < 0: raise ValueError("start_instance_index must be non-negative") instance = None @@ -968,15 +1002,10 @@ def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances, prod_eval_state = EvalOpGenIrState(fn=fn, inputs=prod_inputs) prod_parts = [ prod_eval_state.get_output(i) for i in instance.prod_eval_ops] - retval_size = lhs.ty.reg_len + rhs.ty.reg_len - spread_retval = [] # type: list[SSAVal] - retval_signed = False # type: bool - # FIXME: replace loop with call to sum_partial_products - for part, prod_part in enumerate(prod_parts): - shift = part * part_size - maxvl = 1 + max(len(spread_retval) - shift, - prod_part.output.ty.reg_len) - if part == 0: + + def partial_products(): + # type: () -> Iterable[PartialProduct] + for part, prod_part in enumerate(prod_parts): part_maxvl = prod_part.output.ty.reg_len part_setvl = fn.append_new_op( OpKind.SetVLI, immediates=[part_maxvl], @@ -985,43 +1014,8 @@ def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances, OpKind.Spread, input_vals=[prod_part.output, part_setvl.outputs[0]], name=f"prod_{part}_spread", maxvl=part_maxvl) - spread_retval[:] = spread_part.outputs - else: - cast_retval_spread = cast_to_size_spread( - fn=fn, ssa_vals=spread_retval[shift:], - src_signed=retval_signed, dest_size=maxvl, - name=f"prod_{part}_retval_cast") - cast_prod = cast_to_size( - fn=fn, ssa_val=prod_part.output, - src_signed=prod_part.is_signed, dest_size=maxvl, - name=f"prod_{part}_cast") - part_setvl = fn.append_new_op( - OpKind.SetVLI, immediates=[maxvl], - name=f"prod_{part}_setvl", maxvl=maxvl) - cast_retval = fn.append_new_op( - kind=OpKind.Concat, - input_vals=[*cast_retval_spread, part_setvl.outputs[0]], - name=f"prod_{part}_concat", maxvl=maxvl) - clear_ca = fn.append_new_op(kind=OpKind.ClearCA, - name=f"prod_{part}_clear_ca") - add = fn.append_new_op( - kind=OpKind.SvAddE, input_vals=[ - cast_prod, cast_retval.outputs[0], - clear_ca.outputs[0], part_setvl.outputs[0]], - maxvl=maxvl, name=f"prod_{part}_add") - spread = fn.append_new_op( - kind=OpKind.Spread, - input_vals=[add.outputs[0], part_setvl.outputs[0]], - name=f"prod_{part}_spread", maxvl=maxvl) - spread_retval[shift:] = spread.outputs - retval_signed |= prod_part.is_signed - while len(spread_retval) > retval_size: - spread_retval.pop() - assert len(spread_retval) == retval_size, "logic error" - retval_setvl = fn.append_new_op( - OpKind.SetVLI, immediates=[retval_size], name=f"prod_setvl", - maxvl=retval_size) - retval_concat = fn.append_new_op( - OpKind.Concat, input_vals=[*spread_retval, retval_setvl.outputs[0]], - name="prod_concat", maxvl=retval_size) - return retval_concat.outputs[0] + yield PartialProduct( + spread_part.outputs, shift_in_words=part * part_size, + is_signed=prod_part.is_signed, subtract=False) + return sum_partial_products(fn=fn, partial_products=partial_products(), + retval_size=retval_size, name="prod") diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index 03eaeff..423714a 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -26,13 +26,15 @@ class InternedMeta(ABCMeta): # type: (*Any, **Any) -> None super().__init__(*args, **kwargs) self.__INTERN_TABLE = {} # type: dict[Any, Any] + self._InternedMeta__interned = False def __intern(self, value): # type: (_T) -> _T + if value._InternedMeta__interned: # type: ignore + return value value = self.__INTERN_TABLE.setdefault(value, value) - if value.__dict__.get("_InternedMeta__interned", False): + if value._InternedMeta__interned: # type: ignore return value - value.__dict__["_InternedMeta__interned"] = True hash_v = hash(value) value.__dict__["__hash__"] = lambda: hash_v old_eq = value.__eq__ @@ -40,9 +42,12 @@ class InternedMeta(ABCMeta): def __eq__(__o): # type: (_T) -> bool if value.__class__ is __o.__class__: - return value is __o + if (value._InternedMeta__interned and # type: ignore + __o._InternedMeta__interned): # type: ignore + return value is __o return old_eq(__o) value.__dict__["__eq__"] = __eq__ + value.__dict__["_InternedMeta__interned"] = True return value def __call__(self, *args, **kwargs): -- 2.30.2