From e4d5c09e1030a24887ea33bff5d39b0151584fc7 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 9 Nov 2022 23:59:00 -0800 Subject: [PATCH] simple_mul works with signed/unsigned mul; also made ir repr easier to read --- .../_tests/test_compiler_ir.py | 454 +++----- .../_tests/test_toom_cook.py | 1027 +++++++++-------- src/bigint_presentation_code/compiler_ir.py | 118 +- src/bigint_presentation_code/toom_cook.py | 120 +- 4 files changed, 903 insertions(+), 816 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py index 9763a07..904f66d 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -94,36 +94,28 @@ class TestCompilerIR(unittest.TestCase): self.assertEqual( repr(fn_analysis.op_indexes), "FMap({" - "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), " - "immediates=[], outputs=(>,), " - "name='arg'): 0, " - "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), " - "immediates=[32], outputs=(>,), " - "name='vl'): 1, " - "Op(kind=OpKind.SvLd, input_vals=[" - ">, >], " - "input_uses=(>, " - ">), immediates=[0], " - "outputs=(>,), name='ld'): 2, " - "Op(kind=OpKind.SvLI, input_vals=[>], " - "input_uses=(>,), immediates=[0], " - "outputs=(>,), name='li'): 3, " - "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), " - "immediates=[], outputs=(>,), name='ca'): 4, " - "Op(kind=OpKind.SvAddE, input_vals=[" - ">, >, " - ">, >], " - "input_uses=(>, " - ">, >, " - ">), immediates=[], outputs=(" - ">, >), " - "name='add'): 5, " - "Op(kind=OpKind.SvStd, input_vals=[" - ">, >, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], outputs=(), name='st'): 6})" + "arg:\n" + " (<...outputs[0]: >) <= FuncArgR3: 0, " + "vl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20): 1, " + "ld:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >, >, 0x0)" + ": 2, " + "li:\n" + " (<...outputs[0]: >) <= SvLI(\n" + " >, 0x0): 3, " + "ca:\n" + " (<...outputs[0]: >) <= SetCA: 4, " + "add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >, >,\n" + " >): 5, " + "st:\n" + " SvStd(>, >,\n" + " >, 0x0): 6" + "})" ) self.assertEqual( repr(fn_analysis.live_ranges), @@ -207,52 +199,53 @@ class TestCompilerIR(unittest.TestCase): def test_repr(self): fn, _arg = self.make_add_fn() - self.assertEqual([repr(i) for i in fn.ops], [ - "Op(kind=OpKind.FuncArgR3, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='arg')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), name='vl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[0], " - "outputs=(>,), name='ld')", - "Op(kind=OpKind.SvLI, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[0], " - "outputs=(>,), name='li')", - "Op(kind=OpKind.SetCA, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='ca')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, >, " - ">], " - "input_uses=(>, " - ">, >, " - ">), " - "immediates=[], " - "outputs=(>, >), " - "name='add')", - "Op(kind=OpKind.SvStd, " - "input_vals=[>, >, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], " - "outputs=(), name='st')", - ]) + self.assertEqual( + fn.ops_to_str(), + "arg:\n" + " (<...outputs[0]: >) <= FuncArgR3\n" + "vl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "ld:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >, >, 0x0)\n" + "li:\n" + " (<...outputs[0]: >) <= SvLI(\n" + " >, 0x0)\n" + "ca:\n" + " (<...outputs[0]: >) <= SetCA\n" + "add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >, >,\n" + " >)\n" + "st:\n" + " SvStd(>, >,\n" + " >, 0x0)" + ) + self.assertEqual( + fn.ops_to_str(as_python_literal=True), r""" + "arg:\n" + " (<...outputs[0]: >) <= FuncArgR3\n" + "vl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "ld:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >, >, 0x0)\n" + "li:\n" + " (<...outputs[0]: >) <= SvLI(\n" + " >, 0x0)\n" + "ca:\n" + " (<...outputs[0]: >) <= SetCA\n" + "add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >, >,\n" + " >)\n" + "st:\n" + " SvStd(>, >,\n" + " >, 0x0)" +"""[1:-1] + ) self.assertEqual([repr(op.properties) for op in fn.ops], [ "OpProperties(kind=OpKind.FuncArgR3, " "inputs=(), " @@ -350,178 +343,85 @@ class TestCompilerIR(unittest.TestCase): def test_pre_ra_insert_copies(self): fn, _arg = self.make_add_fn() fn.pre_ra_insert_copies() - self.assertEqual([repr(i) for i in fn.ops], [ - "Op(kind=OpKind.FuncArgR3, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='arg')", - "Op(kind=OpKind.CopyFromReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), " - "name='arg.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), name='vl')", - "Op(kind=OpKind.CopyToReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), name='ld.inp0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='ld.inp1.setvl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[0], " - "outputs=(>,), name='ld')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='ld.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='ld.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='li.inp0.setvl')", - "Op(kind=OpKind.SvLI, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[0], " - "outputs=(>,), name='li')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='li.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='li.out0.copy')", - "Op(kind=OpKind.SetCA, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='ca')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.inp0.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.inp0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.inp1.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.inp1.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.inp3.setvl')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, >, " - ">], " - "input_uses=(>, " - ">, >, " - ">), " - "immediates=[], " - "outputs=(>, >), " - "name='add')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='st.inp0.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='st.inp0.copy')", - "Op(kind=OpKind.CopyToReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), " - "name='st.inp1.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='st.inp2.setvl')", - "Op(kind=OpKind.SvStd, " - "input_vals=[>, " - ">, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], " - "outputs=(), name='st')", - ]) + self.assertEqual( + fn.ops_to_str(), + "arg:\n" + " (<...outputs[0]: >) <= FuncArgR3\n" + "arg.out0.copy:\n" + " (<...outputs[0]: >) <= CopyFromReg(\n" + " >)\n" + "vl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "ld.inp0.copy:\n" + " (<...outputs[0]: >) <= CopyToReg(\n" + " >)\n" + "ld.inp1.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "ld:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >,\n" + " >, 0x0)\n" + "ld.out0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "ld.out0.copy:\n" + " (<...outputs[0]: >) <= VecCopyFromReg(\n" + " >,\n" + " >)\n" + "li.inp0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "li:\n" + " (<...outputs[0]: >) <= SvLI(\n" + " >, 0x0)\n" + "li.out0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "li.out0.copy:\n" + " (<...outputs[0]: >) <= VecCopyFromReg(\n" + " >,\n" + " >)\n" + "ca:\n" + " (<...outputs[0]: >) <= SetCA\n" + "add.inp0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "add.inp0.copy:\n" + " (<...outputs[0]: >) <= VecCopyToReg(\n" + " >,\n" + " >)\n" + "add.inp1.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "add.inp1.copy:\n" + " (<...outputs[0]: >) <= VecCopyToReg(\n" + " >,\n" + " >)\n" + "add.inp3.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >,\n" + " >,\n" + " >)\n" + "add.out0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "add.out0.copy:\n" + " (<...outputs[0]: >) <= VecCopyFromReg(\n" + " >,\n" + " >)\n" + "st.inp0.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "st.inp0.copy:\n" + " (<...outputs[0]: >) <= VecCopyToReg(\n" + " >,\n" + " >)\n" + "st.inp1.copy:\n" + " (<...outputs[0]: >) <= CopyToReg(\n" + " >)\n" + "st.inp2.setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x20)\n" + "st:\n" + " SvStd(>,\n" + " >,\n" + " >, 0x0)" + ) self.assertEqual([repr(op.properties) for op in fn.ops], [ "OpProperties(kind=OpKind.FuncArgR3, " "inputs=(), " @@ -1045,47 +945,23 @@ class TestCompilerIR(unittest.TestCase): "write_stage=OpStage.Late)," "), maxvl=4)", ]) - self.assertEqual([repr(op) for op in fn.ops], [ - "Op(kind=OpKind.SetVLI, input_vals=[" - "], input_uses=(" - "), immediates=[4], outputs=(" - ">," - "), name='vl')", - "Op(kind=OpKind.SvLI, input_vals=[" - ">" - "], input_uses=(" - ">," - "), immediates=[0], outputs=(" - ">," - "), name='li')", - "Op(kind=OpKind.Spread, input_vals=[" - ">, " - ">" - "], input_uses=(" - ">, " - ">" - "), immediates=[], outputs=(" - ">, " - ">, " - ">, " - ">" - "), name='spread')", - "Op(kind=OpKind.Concat, input_vals=[" - ">, " - ">, " - ">, " - ">, " - ">" - "], input_uses=(" - ">, " - ">, " - ">, " - ">, " - ">" - "), immediates=[], outputs=(" - ">," - "), name='concat')", - ]) + self.assertEqual( + fn.ops_to_str(), + "vl:\n" + " (<...outputs[0]: >) <= SetVLI(0x4)\n" + "li:\n" + " (<...outputs[0]: >) <= SvLI(\n" + " >, 0x0)\n" + "spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >, <...outputs[3]: >) <= Spread(\n" + " >, >)\n" + "concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >, >,\n" + " >, >,\n" + " >)" + ) if __name__ == "__main__": diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 2f548a4..735d7e0 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,9 +1,10 @@ import unittest from typing import Callable -from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, - BaseSimState, Fn, - GenAsmState, OpKind, +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, + GPR_SIZE_IN_BYTES, + GPR_VALUE_MASK, BaseSimState, + Fn, GenAsmState, OpKind, PostRASimState, PreRASimState, SSAVal) from bigint_presentation_code.register_allocator import allocate_registers @@ -14,7 +15,7 @@ from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul, def simple_umul(fn, lhs, rhs): # type: (Fn, SSAVal, SSAVal) -> SSAVal return simple_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs, - rhs_signed=False, name="simple_umul") + rhs_signed=False, name="mul") class Mul: @@ -227,26 +228,33 @@ class TestToomCook(unittest.TestCase): ) def test_simple_mul_192x192_pre_ra_sim(self): - self.skipTest("WIP") # FIXME: finish fixing simple_mul - - def create_sim_state(code): - # type: (Mul) -> BaseSimState - return PreRASimState(ssa_vals={}, memory={}) - self.tst_simple_mul_192x192_sim(create_sim_state) + def get_state_factory(code): + # type: (Mul) -> Callable[[], BaseSimState] + return lambda: PreRASimState(ssa_vals={}, memory={}) + for lhs_signed in False, True: + for rhs_signed in False, True: + self.tst_simple_mul_192x192_sim( + lhs_signed=lhs_signed, rhs_signed=rhs_signed, + get_state_factory=get_state_factory) def test_simple_mul_192x192_post_ra_sim(self): - self.skipTest("WIP") # FIXME: finish fixing simple_mul - - def create_sim_state(code): - # type: (Mul) -> BaseSimState + def get_state_factory(code): + # type: (Mul) -> Callable[[], BaseSimState] ssa_val_to_loc_map = allocate_registers(code.fn) - return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map, - memory={}, loc_values={}) - self.tst_simple_mul_192x192_sim(create_sim_state) + return lambda: PostRASimState( + ssa_val_to_loc_map=ssa_val_to_loc_map, + memory={}, loc_values={}) + for lhs_signed in False, True: + for rhs_signed in False, True: + self.tst_simple_mul_192x192_sim( + lhs_signed=lhs_signed, rhs_signed=rhs_signed, + get_state_factory=get_state_factory) - def tst_simple_mul_192x192_sim(self, create_sim_state): - # type: (Callable[[Mul], BaseSimState]) -> None - self.skipTest("WIP") # FIXME: finish fixing simple_mul + def tst_simple_mul_192x192_sim( + self, lhs_signed, # type: bool + rhs_signed, # type: bool + get_state_factory, # type: Callable[[Mul], Callable[[], BaseSimState]] + ): # test multiplying: # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 @@ -256,260 +264,204 @@ class TestToomCook(unittest.TestCase): # "_3931783239312079_7261727469627261", base=0) # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", # 'little') - code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) - state = create_sim_state(code) + lhs_value = 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 + rhs_value = 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 + prod_value = int.from_bytes( + b"arbitrary 192x192->384-bit multiplication test", 'little') + self.assertEqual(lhs_value * rhs_value, prod_value) + code = Mul( + mul=lambda fn, lhs, rhs: simple_mul( + fn=fn, lhs=lhs, lhs_signed=lhs_signed, + rhs=rhs, rhs_signed=rhs_signed, name="mul"), + lhs_size_in_words=3, rhs_size_in_words=3) + state_factory = get_state_factory(code) ptr_in = 0x100 dest_ptr = ptr_in + code.dest_offset lhs_ptr = ptr_in + code.lhs_offset rhs_ptr = ptr_in + code.rhs_offset - state[code.ptr_in] = ptr_in, - state.store(lhs_ptr, 0x821a2342132c5b57) - state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e) - state.store(lhs_ptr + 16, 0x000191acb262e15b) - state.store(rhs_ptr, 0x208a49071aeec507) - state.store(rhs_ptr + 8, 0xcf1f597598194ae6) - state.store(rhs_ptr + 16, 0x4a37c0567bcbab53) - code.fn.sim(state) - expected_bytes = b"arbitrary 192x192->384-bit multiplication test" - OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES - expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') - out_bytes = bytes( - state.load_byte(dest_ptr + i) for i in range(OUT_BYTE_COUNT)) - self.assertEqual(out_bytes, expected_bytes) + for lhs_neg in False, True: + for rhs_neg in False, True: + if lhs_neg and not lhs_signed: + continue + if rhs_neg and not rhs_signed: + continue + with self.subTest(lhs_signed=lhs_signed, + rhs_signed=rhs_signed, + lhs_neg=lhs_neg, rhs_neg=rhs_neg): + state = state_factory() + state[code.ptr_in] = ptr_in, + lhs = lhs_value + if lhs_neg: + lhs = 2 ** 192 - lhs + rhs = rhs_value + if rhs_neg: + rhs = 2 ** 192 - rhs + for i in range(3): + v = (lhs >> GPR_SIZE_IN_BITS * i) & GPR_VALUE_MASK + state.store(lhs_ptr + i * GPR_SIZE_IN_BYTES, v) + for i in range(3): + v = (rhs >> GPR_SIZE_IN_BITS * i) & GPR_VALUE_MASK + state.store(rhs_ptr + i * GPR_SIZE_IN_BYTES, v) + code.fn.sim(state) + expected = prod_value + if lhs_neg != rhs_neg: + expected = 2 ** 384 - expected + prod = 0 + for i in range(6): + v = state.load(dest_ptr + GPR_SIZE_IN_BYTES * i) + prod += v << (GPR_SIZE_IN_BITS * i) + self.assertEqual(hex(prod), hex(expected)) def test_simple_mul_192x192_ops(self): - self.skipTest("WIP") # FIXME: finish fixing simple_mul code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn - self.assertEqual([repr(v) for v in fn.ops], [ - "Op(kind=OpKind.FuncArgR3, " - "input_vals=[], " - "input_uses=(), immediates=[], " - "outputs=(>,), " - "name='ptr_in')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='lhs_setvl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[48], " - "outputs=(>,), " - "name='load_lhs')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='rhs_setvl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[72], " - "outputs=(>,), " - "name='load_rhs')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='rhs_setvl2')", - "Op(kind=OpKind.Spread, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">, " - ">), " - "name='rhs_spread')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='lhs_setvl3')", - "Op(kind=OpKind.LI, " - "input_vals=[], " - "input_uses=(), immediates=[0], " - "outputs=(>,), " - "name='zero')", - "Op(kind=OpKind.SvMAddEDU, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='mul0')", - "Op(kind=OpKind.Spread, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">, " - ">), " - "name='mul0_rt_spread')", - "Op(kind=OpKind.SvMAddEDU, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='mul1')", - "Op(kind=OpKind.Concat, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>,), " - "name='add1_rb_concat')", - "Op(kind=OpKind.ClearCA, " - "input_vals=[], " - "input_uses=(), immediates=[], " - "outputs=(>,), " - "name='clear_ca1')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='add1')", - "Op(kind=OpKind.Spread, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">, " - ">), " - "name='add1_rt_spread')", - "Op(kind=OpKind.AddZE, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='add_hi1')", - "Op(kind=OpKind.SvMAddEDU, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='mul2')", - "Op(kind=OpKind.Concat, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>,), " - "name='add2_rb_concat')", - "Op(kind=OpKind.ClearCA, " - "input_vals=[], " - "input_uses=(), immediates=[], " - "outputs=(>,), " - "name='clear_ca2')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='add2')", - "Op(kind=OpKind.Spread, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">, " - ">), " - "name='add2_rt_spread')", - "Op(kind=OpKind.AddZE, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), immediates=[], " - "outputs=(>, " - ">), " - "name='add_hi2')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[6], " - "outputs=(>,), " - "name='retval_setvl')", - "Op(kind=OpKind.Concat, " - "input_vals=[>, " - ">, " - ">, " - ">, " - ">, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">, " - ">, " - ">, " - ">, " - ">), immediates=[], " - "outputs=(>,), " - "name='concat_retval')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), immediates=[6], " - "outputs=(>,), " - "name='dest_setvl')", - "Op(kind=OpKind.SvStd, " - "input_vals=[>, " - ">, " - ">], " - "input_uses=(>, " - ">, " - ">), immediates=[0], " - "outputs=(), " - "name='store_dest')", - ]) + self.assertEqual( + fn.ops_to_str(), + "ptr_in:\n" + " (<...outputs[0]: >) <= FuncArgR3\n" + "lhs_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x3)\n" + "load_lhs:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >,\n" + " >, 0x30)\n" + "rhs_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x3)\n" + "load_rhs:\n" + " (<...outputs[0]: >) <= SvLd(\n" + " >,\n" + " >, 0x48)\n" + "mul_rhs_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x3)\n" + "mul_rhs_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_zero:\n" + " (<...outputs[0]: >) <= LI(0x0)\n" + "mul_lhs_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x3)\n" + "mul_zero2:\n" + " (<...outputs[0]: >) <= LI(0x0)\n" + "mul_0_mul:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvMAddEDU(>,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_0_mul_rt_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_1_mul:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvMAddEDU(>,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_1_mul_rt_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_1_cast_retval_zero:\n" + " (<...outputs[0]: >) <= LI(0x0)\n" + "mul_1_cast_pp_zero:\n" + " (<...outputs[0]: >) <= LI(0x0)\n" + "mul_1_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x5)\n" + "mul_1_retval_concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_1_pp_concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_1_clear_ca:\n" + " (<...outputs[0]: >) <= ClearCA\n" + "mul_1_add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_1_sum_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >, <...outputs[3]: >,\n" + " <...outputs[4]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_2_mul:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvMAddEDU(>,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_2_mul_rt_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_2_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x4)\n" + "mul_2_retval_concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_2_pp_concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_2_clear_ca:\n" + " (<...outputs[0]: >) <= ClearCA\n" + "mul_2_add:\n" + " (<...outputs[0]: >, <...outputs[1]: >\n" + " ) <= SvAddE(>,\n" + " >,\n" + " >,\n" + " >)\n" + "mul_2_sum_spread:\n" + " (<...outputs[0]: >, <...outputs[1]: >,\n" + " <...outputs[2]: >, <...outputs[3]: >) <= Spread(\n" + " >,\n" + " >)\n" + "mul_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x6)\n" + "mul_concat:\n" + " (<...outputs[0]: >) <= Concat(\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >,\n" + " >)\n" + "dest_setvl:\n" + " (<...outputs[0]: >) <= SetVLI(0x6)\n" + "store_dest:\n" + " SvStd(>,\n" + " >,\n" + " >, 0x0)" + ) def test_simple_mul_192x192_reg_alloc(self): - self.skipTest("WIP") # FIXME: finish fixing simple_mul code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn assigned_registers = allocate_registers(fn) @@ -525,251 +477,339 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=6), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=6), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=8, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=9, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=10, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=11, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=12, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=6, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=9, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=6, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=4), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=4), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=8, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=11, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=12, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=15, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=16, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=6, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=9, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=5), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=5), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=6, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=5), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=5), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=3), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=3), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=8, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=11, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=12, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=17, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=21, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=15, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=8, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=18, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=19, reg_len=1), " - ">: " + ">: " + "Loc(kind=LocKind.GPR, start=22, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " "Loc(kind=LocKind.GPR, start=14, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " @@ -784,7 +824,7 @@ class TestToomCook(unittest.TestCase): ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " - "Loc(kind=LocKind.GPR, start=20, reg_len=3), " + "Loc(kind=LocKind.GPR, start=24, reg_len=3), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " @@ -796,28 +836,27 @@ class TestToomCook(unittest.TestCase): ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " - "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + "Loc(kind=LocKind.GPR, start=27, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1)" "}") def test_simple_mul_192x192_asm(self): - self.skipTest("WIP") # FIXME: finish fixing simple_mul code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn assigned_registers = allocate_registers(fn) gen_asm_state = GenAsmState(assigned_registers) fn.gen_asm(gen_asm_state) self.assertEqual(gen_asm_state.output, [ - 'or 23, 3, 3', + 'or 27, 3, 3', 'setvl 0, 0, 3, 0, 1, 1', - 'or 6, 23, 23', + 'or 6, 27, 27', 'setvl 0, 0, 3, 0, 1, 1', 'sv.ld *3, 48(6)', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *20, *3, *3', + 'sv.or *24, *3, *3', 'setvl 0, 0, 3, 0, 1, 1', - 'or 6, 23, 23', + 'or 6, 27, 27', 'setvl 0, 0, 3, 0, 1, 1', 'sv.ld *3, 72(6)', 'setvl 0, 0, 3, 0, 1, 1', @@ -827,86 +866,116 @@ class TestToomCook(unittest.TestCase): 'sv.or/mrr *5, *3, *3', 'or 4, 5, 5', 'or 14, 6, 6', - 'or 19, 7, 7', + 'or 23, 7, 7', + 'addi 3, 0, 0', + 'or 22, 3, 3', 'setvl 0, 0, 3, 0, 1, 1', 'addi 3, 0, 0', - 'or 18, 3, 3', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *20, *20', + 'sv.or *8, *24, *24', 'or 7, 4, 4', - 'or 6, 18, 18', + 'or 6, 22, 22', 'setvl 0, 0, 3, 0, 1, 1', 'sv.maddedu *3, *8, 7, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'or 15, 6, 6', + 'or 19, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', - 'or 17, 3, 3', + 'or 21, 3, 3', 'or 12, 4, 4', 'or 11, 5, 5', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *20, *20', + 'sv.or *8, *24, *24', 'or 7, 14, 14', - 'or 3, 18, 18', + 'or 6, 22, 22', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.maddedu *3, *8, 7, 6', + 'setvl 0, 0, 3, 0, 1, 1', + 'or 18, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.maddedu *4, *8, 7, 3', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or/mrr *6, *4, *4', + 'or 17, 3, 3', + 'or 16, 4, 4', + 'or 15, 5, 5', + 'addi 3, 0, 0', + 'or 8, 3, 3', + 'addi 3, 0, 0', 'or 14, 3, 3', + 'setvl 0, 0, 5, 0, 1, 1', 'or 3, 12, 12', 'or 4, 11, 11', + 'or 5, 19, 19', + 'or 6, 8, 8', + 'or 7, 8, 8', + 'setvl 0, 0, 5, 0, 1, 1', + 'setvl 0, 0, 5, 0, 1, 1', + 'sv.or *8, *3, *3', + 'or 3, 17, 17', + 'or 4, 16, 16', 'or 5, 15, 15', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', + 'or 6, 18, 18', + 'or 7, 14, 14', + 'setvl 0, 0, 5, 0, 1, 1', + 'setvl 0, 0, 5, 0, 1, 1', 'addic 0, 0, 0', + 'setvl 0, 0, 5, 0, 1, 1', + 'sv.or *14, *8, *8', + 'setvl 0, 0, 5, 0, 1, 1', + 'sv.or *8, *3, *3', + 'setvl 0, 0, 5, 0, 1, 1', + 'sv.adde *3, *14, *8', + 'setvl 0, 0, 5, 0, 1, 1', + 'setvl 0, 0, 5, 0, 1, 1', + 'setvl 0, 0, 5, 0, 1, 1', + 'or 20, 3, 3', + 'or 19, 4, 4', + 'or 18, 5, 5', + 'or 17, 6, 6', + 'or 16, 7, 7', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *8, *24, *24', + 'or 7, 23, 23', + 'or 6, 22, 22', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *9, *6, *6', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *6, *3, *3', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.adde *3, *9, *6', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', - 'or 16, 3, 3', - 'or 15, 4, 4', - 'or 12, 5, 5', - 'or 4, 14, 14', - 'addze *3, *4', - 'or 11, 3, 3', + 'sv.maddedu *3, *8, 7, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *20, *20', - 'or 7, 19, 19', - 'or 3, 18, 18', + 'or 15, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.maddedu *4, *8, 7, 3', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or/mrr *6, *4, *4', 'or 14, 3, 3', - 'or 3, 15, 15', + 'or 12, 4, 4', + 'or 11, 5, 5', + 'setvl 0, 0, 4, 0, 1, 1', + 'or 3, 19, 19', + 'or 4, 18, 18', + 'or 5, 17, 17', + 'or 6, 16, 16', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'or 3, 14, 14', 'or 4, 12, 12', 'or 5, 11, 11', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', + 'or 6, 15, 15', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', 'addic 0, 0, 0', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *9, *6, *6', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *6, *3, *3', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.adde *3, *9, *6', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', - 'setvl 0, 0, 3, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *14, *7, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.or *7, *3, *3', + 'setvl 0, 0, 4, 0, 1, 1', + 'sv.adde *3, *14, *7', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', + 'setvl 0, 0, 4, 0, 1, 1', 'or 12, 3, 3', 'or 11, 4, 4', 'or 10, 5, 5', - 'or 4, 14, 14', - 'addze *3, *4', - 'or 9, 3, 3', + 'or 9, 6, 6', 'setvl 0, 0, 6, 0, 1, 1', - 'or 3, 17, 17', - 'or 4, 16, 16', + 'or 3, 21, 21', + 'or 4, 20, 20', 'or 5, 12, 12', 'or 6, 11, 11', 'or 7, 10, 10', @@ -916,7 +985,7 @@ class TestToomCook(unittest.TestCase): 'setvl 0, 0, 6, 0, 1, 1', 'setvl 0, 0, 6, 0, 1, 1', 'sv.or/mrr *4, *3, *3', - 'or 3, 23, 23', + 'or 3, 27, 27', 'setvl 0, 0, 6, 0, 1, 1', 'sv.std *4, 0(3)' ]) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index ba5ada9..9c7d729 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -15,7 +15,6 @@ from bigint_presentation_code.type_util import (Literal, Self, assert_never, from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet) - GPR_SIZE_IN_BYTES = 8 BITS_IN_BYTE = 8 GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE @@ -47,6 +46,31 @@ class Fn: # type: () -> str return "" + def ops_to_str(self, as_python_literal=False, wrap_width=63, + python_indent=" ", indent=" "): + # type: (bool, int, str, str) -> str + l = [] # type: list[str] + for op in self.ops: + l.append(op.__repr__(wrap_width=wrap_width, indent=indent)) + retval = "\n".join(l) + if as_python_literal: + l = [python_indent + "\""] + for ch in retval: + if ch == "\n": + l.append(f"\\n\"\n{python_indent}\"") + elif ch in "\"\\": + l.append("\\" + ch) + elif ch.isascii() and ch.isprintable(): + l.append(ch) + else: + l.append(repr(ch).strip("\"'")) + l.append("\"") + retval = "".join(l) + empty_end = f"\"\n{python_indent}\"\"" + if retval.endswith(empty_end): + retval = retval[:-len(empty_end)] + return retval + def append_op(self, op): # type: (Op) -> None if op.fn is not self: @@ -1196,6 +1220,32 @@ class OpKind(Enum): _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm + @staticmethod + def __svandvs_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB, = state[op.input_vals[1]] + VL, = state[op.input_vals[2]] + RT = [] # type: list[int] + for i in range(VL): + RT.append(RA[i] & RB & GPR_VALUE_MASK) + state[op.outputs[0]] = tuple(RT) + + @staticmethod + def __svandvs_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.sgpr(op.input_vals[1]) + state.writeln(f"sv.and {RT}, {RA}, {RB}") + SvAndVS = GenericOpProperties( + demo_asm="sv.and *RT, *RA, RB", + inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL], + outputs=[OD_EXTRA3_VGPR], + ) + _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim + _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm + @staticmethod def __svmaddedu_sim(op, state): # type: (Op, BaseSimState) -> None @@ -1946,24 +1996,54 @@ class Op: # type: () -> int return object.__hash__(self) - def __repr__(self): - # type: () -> str - field_vals = [] # type: list[str] - for name in fields(self): - if name == "properties": - name = "kind" - elif name == "fn": - continue - try: - value = getattr(self, name) - except AttributeError: - field_vals.append(f"{name}=") - continue - if isinstance(value, OpInputSeq): - value = list(value) # type: ignore - field_vals.append(f"{name}={value!r}") - field_vals_str = ", ".join(field_vals) - return f"Op({field_vals_str})" + def __repr__(self, wrap_width=63, indent=" "): + # type: (int, str) -> str + WRAP_POINT = "\u200B" # zero-width space + items = [f"{self.name}:\n"] + for i, out in enumerate(self.outputs): + item = f"<...outputs[{i}]: {out.ty}>" + if i == 0: + item = "(" + WRAP_POINT + item + if i != len(self.outputs) - 1: + item += ", " + WRAP_POINT + else: + item += WRAP_POINT + ") <= " + items.append(item) + items.append(self.kind._name_) + if len(self.input_vals) + len(self.immediates) != 0: + items[-1] += "(" + items[-1] += WRAP_POINT + for i, inp in enumerate(self.input_vals): + item = repr(inp) + if i != len(self.input_vals) - 1 or len(self.immediates) != 0: + item += ", " + WRAP_POINT + else: + item += ") " + WRAP_POINT + items.append(item) + for i, imm in enumerate(self.immediates): + item = hex(imm) + if i != len(self.immediates) - 1: + item += ", " + WRAP_POINT + else: + item += ") " + WRAP_POINT + items.append(item) + lines = [] # type: list[str] + for i, line_in in enumerate("".join(items).splitlines()): + if i != 0: + line_in = indent + line_in + line_out = "" + for part in line_in.split(WRAP_POINT): + if line_out == "": + line_out = part + continue + trial_line_out = line_out + part + if len(trial_line_out.rstrip()) > wrap_width: + lines.append(line_out.rstrip()) + line_out = indent + part + else: + line_out = trial_line_out + lines.append(line_out.rstrip()) + return "\n".join(lines) def sim(self, state): # type: (BaseSimState) -> None diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index de5b0f6..c261d82 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -457,7 +457,7 @@ class EvalOpInput(EvalOp): output = cast_to_size( fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed, dest_size=output_value_range.output_size, - name="input_{self.part_index}_cast") + name=f"input_{self.part_index}_cast") return EvalOpGenIrOutput(output=output, value_range=output_value_range) @@ -648,10 +648,10 @@ class ToomCookInstance: @plain_data(frozen=True, unsafe_hash=True) @final class PartialProduct: - __slots__ = "ssa_val_spread", "shift_in_words", "is_signed" + __slots__ = "ssa_val_spread", "shift_in_words", "is_signed", "subtract" - def __init__(self, ssa_val_spread, shift_in_words, is_signed): - # type: (Iterable[SSAVal], int, bool) -> None + def __init__(self, ssa_val_spread, shift_in_words, is_signed, subtract): + # type: (Iterable[SSAVal], int, bool, bool) -> None if shift_in_words < 0: raise ValueError("invalid shift_in_words") self.ssa_val_spread = tuple(ssa_val_spread) @@ -660,10 +660,11 @@ class PartialProduct: raise ValueError("invalid ssa_val.ty") self.shift_in_words = shift_in_words self.is_signed = is_signed + self.subtract = subtract -def sum_partial_products(fn, partial_products, name): - # type: (Fn, Iterable[PartialProduct], str) -> SSAVal +def sum_partial_products(fn, partial_products, retval_size, name): + # type: (Fn, Iterable[PartialProduct], int, str) -> SSAVal retval_spread = [] # type: list[SSAVal] retval_signed = False zero = fn.append_new_op(OpKind.LI, immediates=[0], @@ -672,7 +673,8 @@ def sum_partial_products(fn, partial_products, name): for idx, partial_product in enumerate(partial_products): shift_in_words = partial_product.shift_in_words spread = list(partial_product.ssa_val_spread) - if not retval_signed and shift_in_words >= len(retval_spread): + if (not retval_signed and shift_in_words >= len(retval_spread) + and not partial_product.subtract): retval_spread.extend( [zero] * (shift_in_words - len(retval_spread))) retval_spread.extend(spread) @@ -680,10 +682,21 @@ def sum_partial_products(fn, partial_products, name): has_carry_word = False continue assert len(retval_spread) != 0, "logic error" - maxvl = max(len(retval_spread) - shift_in_words, len(spread)) + retval_hi_len = len(retval_spread) - shift_in_words + if retval_hi_len <= len(spread): + maxvl = len(spread) + 1 + has_carry_word = True + elif has_carry_word: + maxvl = retval_hi_len + else: + maxvl = retval_hi_len + 1 + has_carry_word = True if not has_carry_word: maxvl += 1 has_carry_word = True + if maxvl > retval_size - shift_in_words: + maxvl = retval_size - shift_in_words + has_carry_word = False retval_spread = cast_to_size_spread( fn=fn, ssa_vals=retval_spread, src_signed=retval_signed, dest_size=maxvl + shift_in_words, name=f"{name}_{idx}_cast_retval") @@ -699,26 +712,38 @@ def sum_partial_products(fn, partial_products, name): name=f"{name}_{idx}_retval_concat", maxvl=maxvl) pp_concat = fn.append_new_op( kind=OpKind.Concat, - input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]], + input_vals=[*spread, setvl.outputs[0]], name=f"{name}_{idx}_pp_concat", maxvl=maxvl) - clear_ca = fn.append_new_op(kind=OpKind.ClearCA, - name=f"{name}_{idx}_clear_ca") - add = fn.append_new_op( - kind=OpKind.SvAddE, input_vals=[ - retval_concat.outputs[0], pp_concat.outputs[0], - clear_ca.outputs[0], setvl.outputs[0]], - maxvl=maxvl, name=f"{name}_{idx}_add") + if partial_product.subtract: + set_ca = fn.append_new_op(kind=OpKind.SetCA, + name=f"{name}_{idx}_set_ca") + add_sub = fn.append_new_op( + kind=OpKind.SvSubFE, input_vals=[ + pp_concat.outputs[0], retval_concat.outputs[0], + set_ca.outputs[0], setvl.outputs[0]], + maxvl=maxvl, name=f"{name}_{idx}_sub") + else: + clear_ca = fn.append_new_op(kind=OpKind.ClearCA, + name=f"{name}_{idx}_clear_ca") + add_sub = fn.append_new_op( + kind=OpKind.SvAddE, input_vals=[ + retval_concat.outputs[0], pp_concat.outputs[0], + clear_ca.outputs[0], setvl.outputs[0]], + maxvl=maxvl, name=f"{name}_{idx}_add") retval_spread[shift_in_words:] = fn.append_new_op( kind=OpKind.Spread, - input_vals=[add.outputs[0], setvl.outputs[0]], + input_vals=[add_sub.outputs[0], setvl.outputs[0]], name=f"{name}_{idx}_sum_spread", maxvl=maxvl).outputs + retval_spread = cast_to_size_spread( + fn=fn, ssa_vals=retval_spread, src_signed=retval_signed, + dest_size=retval_size, name=f"{name}_retval_cast") retval_setvl = fn.append_new_op( - OpKind.SetVLI, immediates=[len(retval_spread)], - maxvl=len(retval_spread), name=f"{name}_setvl") + OpKind.SetVLI, immediates=[retval_size], + maxvl=retval_size, name=f"{name}_setvl") retval_concat = fn.append_new_op( kind=OpKind.Concat, input_vals=[*retval_spread, retval_setvl.outputs[0]], - name=f"{name}_concat", maxvl=len(retval_spread)) + name=f"{name}_concat", maxvl=retval_size) return retval_concat.outputs[0] @@ -729,20 +754,20 @@ def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name): lhs, rhs = rhs, lhs lhs_signed, rhs_signed = rhs_signed, lhs_signed # split rhs into elements - rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI, - immediates=[rhs.ty.reg_len], name="rhs_setvl") + rhs_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[rhs.ty.reg_len], + name=f"{name}_rhs_setvl") rhs_spread = fn.append_new_op( kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]], - maxvl=rhs.ty.reg_len, name="rhs_spread") + maxvl=rhs.ty.reg_len, name=f"{name}_rhs_spread") rhs_words = rhs_spread.outputs zero = fn.append_new_op( kind=OpKind.LI, immediates=[0], name=f"{name}_zero").outputs[0] maxvl = lhs.ty.reg_len lhs_setvl = fn.append_new_op( - kind=OpKind.SetVLI, immediates=[maxvl], name="lhs_setvl", maxvl=maxvl) + kind=OpKind.SetVLI, immediates=[maxvl], name=f"{name}_lhs_setvl", + maxvl=maxvl) vl = lhs_setvl.outputs[0] - if lhs_signed or rhs_signed: - raise NotImplementedError # FIXME: implement signed multiply def partial_products(): # type: () -> Iterable[PartialProduct] @@ -756,9 +781,46 @@ def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name): yield PartialProduct( ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]], shift_in_words=shift_in_words, - is_signed=False) - return sum_partial_products(fn=fn, partial_products=partial_products(), - name=name) + is_signed=False, subtract=False) + if lhs_signed: + lhs_spread = fn.append_new_op( + kind=OpKind.Spread, input_vals=[lhs, lhs_setvl.outputs[0]], + maxvl=lhs.ty.reg_len, name=f"{name}_lhs_spread") + rhs_mask = fn.append_new_op( + kind=OpKind.SRADI, input_vals=[lhs_spread.outputs[-1]], + immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_rhs_mask") + lhs_and = fn.append_new_op( + kind=OpKind.SvAndVS, + input_vals=[rhs, rhs_mask.outputs[0], rhs_setvl.outputs[0]], + maxvl=rhs.ty.reg_len, name=f"{name}_rhs_and") + rhs_and_spread = fn.append_new_op( + kind=OpKind.Spread, + input_vals=[lhs_and.outputs[0], rhs_setvl.outputs[0]], + name=f"{name}_rhs_and_spread", maxvl=rhs.ty.reg_len) + yield PartialProduct( + ssa_val_spread=rhs_and_spread.outputs, + shift_in_words=lhs.ty.reg_len, is_signed=False, subtract=True) + if rhs_signed: + rhs_spread = fn.append_new_op( + kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]], + maxvl=rhs.ty.reg_len, name=f"{name}_rhs_spread") + lhs_mask = fn.append_new_op( + kind=OpKind.SRADI, input_vals=[rhs_spread.outputs[-1]], + immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_lhs_mask") + rhs_and = fn.append_new_op( + kind=OpKind.SvAndVS, + input_vals=[lhs, lhs_mask.outputs[0], lhs_setvl.outputs[0]], + maxvl=lhs.ty.reg_len, name=f"{name}_lhs_and") + lhs_and_spread = fn.append_new_op( + kind=OpKind.Spread, + input_vals=[rhs_and.outputs[0], lhs_setvl.outputs[0]], + name=f"{name}_lhs_and_spread", maxvl=lhs.ty.reg_len) + yield PartialProduct( + ssa_val_spread=lhs_and_spread.outputs, + shift_in_words=rhs.ty.reg_len, is_signed=False, subtract=True) + return sum_partial_products( + fn=fn, partial_products=partial_products(), + retval_size=lhs.ty.reg_len + rhs.ty.reg_len, name=name) def cast_to_size(fn, ssa_val, src_signed, dest_size, name): -- 2.30.2