From: Jacob Lifshay Date: Fri, 28 Oct 2022 09:24:23 +0000 (-0700) Subject: working on rewriting compiler ir to fix reg alloc issues X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=861dc2996a6a9e6feb25ca384a8fa0d44982d80e;p=bigint-presentation-code.git working on rewriting compiler ir to fix reg alloc issues --- diff --git a/src/bigint_presentation_code/_tests/__init__.py b/src/bigint_presentation_code/_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py new file mode 100644 index 0000000..820c305 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -0,0 +1,120 @@ +import unittest + +from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, + GlobalMem, GPRRange, GPRType, + OpBigIntAddSub, OpConcat, + OpCopy, OpFuncArg, + OpInputMem, OpLI, OpLoad, + OpSetCA, OpSetVLImm, OpStore, + RegLoc, SSAVal, XERBit, + generate_assembly, + op_set_to_list) + + +class TestCompilerIR(unittest.TestCase): + maxDiff = None + + def test_op_set_to_list(self): + fn = Fn() + op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op1 = OpCopy(fn, op0.out, GPRType()) + arg = op1.dest + op2 = OpInputMem(fn) + mem = op2.out + op3 = OpSetVLImm(fn, 32) + vl = op3.out + op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + a = op4.RT + op5 = OpLI(fn, 1) + b_0 = op5.out + op6 = OpSetVLImm(fn, 31) + vl = op6.out + op7 = OpLI(fn, 0, vl=vl) + b_rest = op7.out + op8 = OpConcat(fn, [b_0, b_rest]) + b = op8.dest + op9 = OpSetVLImm(fn, 32) + vl = op9.out + op10 = OpSetCA(fn, False) + ca = op10.out + op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + s = op11.out + op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + mem = op12.mem_out + + expected_ops = [ + op10, # OpSetCA(fn, False) + op9, # OpSetVLImm(fn, 32) + op6, # OpSetVLImm(fn, 31) + op5, # OpLI(fn, 1) + op3, # OpSetVLImm(fn, 32) + op2, # OpInputMem(fn) + op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op7, # OpLI(fn, 0, vl=vl) + op1, # OpCopy(fn, op0.out, GPRType()) + op8, # OpConcat(fn, [b_0, b_rest]) + op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + ] + + ops = op_set_to_list(fn.ops[::-1]) + if ops != expected_ops: + self.assertEqual(repr(ops), repr(expected_ops)) + + def tst_generate_assembly(self, use_reg_alloc=False): + fn = Fn() + op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op1 = OpCopy(fn, op0.out, GPRType()) + arg = op1.dest + op2 = OpInputMem(fn) + mem = op2.out + op3 = OpSetVLImm(fn, 32) + vl = op3.out + op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + a = op4.RT + op5 = OpLI(fn, 0, vl=vl) + b = op5.out + op6 = OpSetCA(fn, True) + ca = op6.out + op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + s = op7.out + op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + mem = op8.mem_out + + assigned_registers = { + op0.out: GPRRange(start=3, length=1), + op1.dest: GPRRange(start=3, length=1), + op2.out: GlobalMem.GlobalMem, + op3.out: VL.VL_MAXVL, + op4.RT: GPRRange(start=78, length=32), + op5.out: GPRRange(start=46, length=32), + op6.out: XERBit.CA, + op7.out: GPRRange(start=14, length=32), + op7.CA_out: XERBit.CA, + op8.mem_out: GlobalMem.GlobalMem, + } # type: dict[SSAVal, RegLoc] | None + + if use_reg_alloc: + assigned_registers = None + + asm = generate_assembly(fn.ops, assigned_registers) + self.assertEqual(asm, [ + "setvl 0, 0, 32, 0, 1, 1", + "sv.ld *78, 0(3)", + "sv.addi *46, 0, 0", + "subfic 0, 0, -1", + "sv.adde *14, *78, *46", + "sv.std *14, 0(3)", + "bclr 20, 0, 0", + ]) + + def test_generate_assembly(self): + self.tst_generate_assembly() + + def test_generate_assembly_with_register_allocator(self): + self.tst_generate_assembly(use_reg_alloc=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_matrix.py b/src/bigint_presentation_code/_tests/test_matrix.py new file mode 100644 index 0000000..1a56df0 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_matrix.py @@ -0,0 +1,113 @@ +import unittest +from fractions import Fraction + +from bigint_presentation_code.matrix import Matrix, SpecialMatrix + + +class TestMatrix(unittest.TestCase): + def test_repr(self): + self.assertEqual(repr(Matrix(2, 3, [0, 1, 2, + 3, 4, 5])), + 'Matrix(height=2, width=3, data=[\n' + ' 0, 1, 2,\n' + ' 3, 4, 5,\n' + '])') + self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3, + 3, 4, 5])), + 'Matrix(height=2, width=3, data=[\n' + ' 0, 1, Fraction(2, 3),\n' + ' 3, 4, 5,\n' + '])') + self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)') + self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)') + + def test_eq(self): + self.assertFalse(Matrix(1, 1) == 5) + self.assertFalse(5 == Matrix(1, 1)) + self.assertFalse(Matrix(2, 1) == Matrix(1, 1)) + self.assertFalse(Matrix(1, 2) == Matrix(1, 1)) + self.assertTrue(Matrix(1, 1) == Matrix(1, 1)) + self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1])) + self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1])) + + def test_add(self): + self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) + + Matrix(2, 2, [40, 30, 20, 10]), + Matrix(2, 2, [41, 32, 23, 14])) + + def test_identity(self): + self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity), + Matrix(2, 2, [1, 0, + 0, 1])) + self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity), + Matrix(1, 3, [1, 0, 0])) + self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity), + Matrix(2, 3, [1, 0, 0, + 0, 1, 0])) + self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity), + Matrix(3, 3, [1, 0, 0, + 0, 1, 0, + 0, 0, 1])) + + def test_sub(self): + self.assertEqual(Matrix(2, 2, [40, 30, 20, 10]) + - Matrix(2, 2, [-1, -2, -3, -4]), + Matrix(2, 2, [41, 32, 23, 14])) + + def test_neg(self): + self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]), + Matrix(2, 2, [-40, -30, -20, -10])) + + def test_mul(self): + self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2), + Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6])) + self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]), + Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6])) + + def test_matmul(self): + self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) + @ Matrix(2, 2, [4, 3, 2, 1]), + Matrix(2, 2, [8, 5, 20, 13])) + self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1]) + @ Matrix(2, 1, [1, 2]), + Matrix(3, 1, [16, 10, 4])) + + def test_inverse(self): + self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0)) + self.assertEqual(Matrix(1, 1, [2]).inverse(), + Matrix(1, 1, [Fraction(1, 2)])) + self.assertEqual(Matrix(1, 1, [1]).inverse(), + Matrix(1, 1, [1])) + self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(), + Matrix(2, 2, [1, 0, -1, 1])) + self.assertEqual(Matrix(3, 3, [0, 1, 0, + 1, 0, 0, + 0, 0, 1]).inverse(), + Matrix(3, 3, [0, 1, 0, + 1, 0, 0, + 0, 0, 1])) + _1_2 = Fraction(1, 2) + _1_3 = Fraction(1, 3) + _1_6 = Fraction(1, 6) + self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, + 1, -1, 1, -1, 1, + 1, -2, 4, -8, 16, + 0, 0, 0, 0, 1]).inverse(), + Matrix(5, 5, [1, 0, 0, 0, 0, + _1_2, _1_3, -1, _1_6, -2, + -1, _1_2, _1_2, 0, -1, + -_1_2, _1_6, _1_2, -_1_6, 2, + 0, 0, 0, 0, 1])) + with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): + Matrix(1, 1, [0]).inverse() + with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): + Matrix(2, 2, [0, 0, 1, 1]).inverse() + with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): + Matrix(2, 2, [1, 0, 1, 0]).inverse() + with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): + Matrix(2, 2, [1, 1, 1, 1]).inverse() + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py new file mode 100644 index 0000000..1eff254 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_register_allocator.py @@ -0,0 +1,201 @@ +import unittest + +from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, + GlobalMem, GPRRange, GPRType, + OpBigIntAddSub, OpConcat, + OpCopy, OpFuncArg, + OpInputMem, OpLI, OpLoad, + OpSetCA, OpSetVLImm, OpStore, + XERBit) +from bigint_presentation_code.register_allocator import ( + AllocationFailed, MergedRegSet, allocate_registers, + try_allocate_registers_without_spilling) + + +class TestMergedRegSet(unittest.TestCase): + maxDiff = None + + def test_from_equality_constraint(self): + fn = Fn() + li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out) + li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out) + li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out) + self.assertEqual(MergedRegSet.from_equality_constraint([ + li0x1.out, + li0x2.out, + li0x3.out, + ]), MergedRegSet({ + li0x1.out: 0, + li0x2.out: 1, + li0x3.out: 3, + }.items())) + self.assertEqual(MergedRegSet.from_equality_constraint([ + li0x2.out, + li0x1.out, + li0x3.out, + ]), MergedRegSet({ + li0x2.out: 0, + li0x1.out: 2, + li0x3.out: 3, + }.items())) + + +class TestRegisterAllocator(unittest.TestCase): + maxDiff = None + + def test_try_alloc_fail(self): + fn = Fn() + op0 = OpSetVLImm(fn, 52) + op1 = OpLI(fn, 0, vl=op0.out) + op2 = OpSetVLImm(fn, 64) + op3 = OpLI(fn, 0, vl=op2.out) + op4 = OpConcat(fn, [op1.out, op3.out]) + + reg_assignments = try_allocate_registers_without_spilling(fn.ops) + self.assertEqual( + repr(reg_assignments), + "AllocationFailed(" + "node=IGNode(#0, merged_reg_set=MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "edges={}, reg=None), " + "live_intervals=LiveIntervals(live_intervals={" + "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): " + "LiveInterval(first_write=0, last_use=1), " + "MergedRegSet([(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]): " + "LiveInterval(first_write=1, last_use=4), " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): " + "LiveInterval(first_write=2, last_use=3)}, " + "merged_reg_sets=MergedRegSets(data={" + "<#0.out: KnownVLType(length=52)>: " + "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), " + "<#1.out: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "<#2.out: KnownVLType(length=64)>: " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), " + "<#3.out: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "<#4.dest: >: MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])}), " + "reg_sets_live_after={" + "0: OFSet([MergedRegSet([" + "(<#0.out: KnownVLType(length=52)>, 0)])]), " + "1: OFSet([MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])]), " + "2: OFSet([MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), " + "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), " + "3: OFSet([MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)])]), " + "4: OFSet()}), " + "interference_graph=InterferenceGraph(nodes={" + "...: IGNode(#0, merged_reg_set=MergedRegSet([" + "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), " + "...: IGNode(#1, merged_reg_set=MergedRegSet([" + "(<#4.dest: >, 0), " + "(<#1.out: >, 0), " + "(<#3.out: >, 52)]), edges={}, reg=None), " + "...: IGNode(#2, merged_reg_set=MergedRegSet([" + "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))" + ) + + def test_try_alloc_bigint_inc(self): + fn = Fn() + op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op1 = OpCopy(fn, op0.out, GPRType()) + arg = op1.dest + op2 = OpInputMem(fn) + mem = op2.out + op3 = OpSetVLImm(fn, 32) + vl = op3.out + op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + a = op4.RT + op5 = OpLI(fn, 0, vl=vl) + b = op5.out + op6 = OpSetCA(fn, True) + ca = op6.out + op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + s = op7.out + op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + mem = op8.mem_out + + reg_assignments = try_allocate_registers_without_spilling(fn.ops) + + expected_reg_assignments = { + op0.out: GPRRange(start=3, length=1), + op1.dest: GPRRange(start=3, length=1), + op2.out: GlobalMem.GlobalMem, + op3.out: VL.VL_MAXVL, + op4.RT: GPRRange(start=78, length=32), + op5.out: GPRRange(start=46, length=32), + op6.out: XERBit.CA, + op7.out: GPRRange(start=14, length=32), + op7.CA_out: XERBit.CA, + op8.mem_out: GlobalMem.GlobalMem, + } + + self.assertEqual(reg_assignments, expected_reg_assignments) + + def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): + # type: (list[GPRRange], GPRRange) -> None + fn = Fn() + inputs = [] + expected_reg_assignments = {} + for i, r in enumerate(expected_regs): + vl = OpSetVLImm(fn, r.length).out + expected_reg_assignments[vl] = VL.VL_MAXVL + inp = OpLI(fn, i, vl=vl).out + inputs.append(inp) + expected_reg_assignments[inp] = r + concat = OpConcat(fn, inputs) + expected_reg_assignments[concat.dest] = expected_dest_reg + + reg_assignments = try_allocate_registers_without_spilling(fn.ops) + + for inp, reg in zip(inputs, expected_regs): + expected_reg_assignments[inp] = reg + + self.assertEqual(reg_assignments, expected_reg_assignments) + + def test_try_alloc_concat_1(self): + self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3)) + + def test_try_alloc_concat_3(self): + self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3)) + + def test_try_alloc_concat_3_5(self): + self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)], + GPRRange(3, 8)) + + def test_try_alloc_concat_5_3(self): + self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)], + GPRRange(3, 8)) + + def test_try_alloc_concat_1_2_3_4_5_6(self): + self.tst_try_alloc_concat([ + GPRRange(14, 1), + GPRRange(15, 2), + GPRRange(17, 3), + GPRRange(20, 4), + GPRRange(24, 5), + GPRRange(29, 6), + ], GPRRange(14, 21)) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py new file mode 100644 index 0000000..6fff570 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -0,0 +1,378 @@ +import unittest + +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn, + GlobalMem, GPRRange, + GPRRangeType, OpCopy, + OpFuncArg, OpInputMem, + OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit, + generate_assembly) +from bigint_presentation_code.register_allocator import allocate_registers +from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul +from bigint_presentation_code.util import FMap + + +class SimpleMul192x192: + def __init__(self): + self.fn = fn = Fn() + self.mem_in = mem = OpInputMem(fn).out + self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out + self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out + self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out + dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest + vl = OpSetVLImm(fn, 3).out + lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest + rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest + retval = simple_mul(fn, lhs, rhs) + vl = OpSetVLImm(fn, 6).out + self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0, + mem_in=mem, vl=vl).mem_out + + +class TestToomCook(unittest.TestCase): + maxDiff = None + + def test_toom_2_repr(self): + TOOM_2 = ToomCookInstance.make_toom_2() + # print(repr(repr(TOOM_2))) + self.assertEqual( + repr(TOOM_2), + "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, " + "eval_points=(0, 1, POINT_AT_INFINITY), " + "lhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," + " rhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," + " prod_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpSub(lhs=" + "EvalOpSub(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({" + "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), " + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))" + ) + + def test_toom_2_5_repr(self): + TOOM_2_5 = ToomCookInstance.make_toom_2_5() + # print(repr(repr(TOOM_2_5))) + self.assertEqual( + repr(TOOM_2_5), + "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, " + "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=2, rhs=0, " + "poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=1, rhs=0, " + "poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({" + "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), " + "EvalOpSub(lhs=" + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=2, rhs=0, " + "poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=1, rhs=0, " + "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly(" + "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), " + "EvalOpInput(lhs=2, rhs=0, " + "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, " + "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpSub(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=1, rhs=0, " + "poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, " + "poly=EvalOpPoly({1: Fraction(1, 1)}))), " + "prod_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=EvalOpInput(lhs=2, rhs=0, " + "poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), " + "rhs=2, " + "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs=" + "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), " + "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, " + "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), " + "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" + ) + + def test_reversed_toom_2_5_repr(self): + TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed() + # print(repr(repr(TOOM_2_5))) + self.assertEqual( + repr(TOOM_2_5), + "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, " + "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpSub(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," + " rhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), " + "EvalOpSub(lhs=EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), " + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})))," + " prod_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), " + "rhs=2, " + "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs=" + "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), " + "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, " + "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "poly=EvalOpPoly(" + "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), " + "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" + ) + + def test_simple_mul_192x192_pre_ra_sim(self): + # test multiplying: + # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 + # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 + # == + # int("0x00074736574206e_6f69746163696c70" + # "_69746c756d207469_622d3438333e2d32" + # "_3931783239312079_7261727469627261", base=0) + # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", + # 'little') + code = SimpleMul192x192() + dest_ptr = 0x100 + state = PreRASimState( + gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()}, + stack_slots={}, fixed_gprs={ + code.dest_ptr_in: (dest_ptr,), + code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e, + 0x000191acb262e15b), + code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6, + 0x4a37c0567bcbab53) + }) + code.fn.pre_ra_sim(state) + expected_bytes = b"arbitrary 192x192->384-bit multiplication test" + OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES + expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') + mem_out = state.global_mems[code.mem_out] + out_bytes = bytes( + mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT)) + self.assertEqual(out_bytes, expected_bytes) + + def test_simple_mul_192x192_ops(self): + code = SimpleMul192x192() + fn = code.fn + self.assertEqual([repr(v) for v in fn.ops], [ + 'OpInputMem(#0, <#0.out: GlobalMemType()>)', + 'OpFuncArg(#1, <#1.out: )>>)', + 'OpFuncArg(#2, <#2.out: )>>)', + 'OpFuncArg(#3, <#3.out: )>>)', + 'OpCopy(#4, <#4.dest: >, src=<#1.out: )>>, ' + 'vl=None)', + 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)', + 'OpCopy(#6, <#6.dest: >, ' + 'src=<#2.out: )>>, ' + 'vl=<#5.out: KnownVLType(length=3)>)', + 'OpCopy(#7, <#7.dest: >, ' + 'src=<#3.out: )>>, ' + 'vl=<#5.out: KnownVLType(length=3)>)', + 'OpSplit(#8, results=(<#8.results[0]: >, ' + '<#8.results[1]: >, <#8.results[2]: >), ' + 'src=<#7.dest: >)', + 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)', + 'OpLI(#10, <#10.out: >, value=0, vl=None)', + 'OpBigIntMulDiv(#11, <#11.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[0]: >, ' + 'RC=<#10.out: >, <#11.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpConcat(#12, <#12.dest: >, sources=(' + '<#11.RT: >, <#11.RS: >))', + 'OpBigIntMulDiv(#13, <#13.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[1]: >, ' + 'RC=<#10.out: >, <#13.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpSplit(#14, results=(<#14.results[0]: >, ' + '<#14.results[1]: >), src=<#12.dest: >)', + 'OpSetCA(#15, <#15.out: CAType()>, value=False)', + 'OpBigIntAddSub(#16, <#16.out: >, ' + 'lhs=<#13.RT: >, rhs=<#14.results[1]: >, ' + 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpBigIntAddSub(#17, <#17.out: >, ' + 'lhs=<#13.RS: >, rhs=<#10.out: >, ' + 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, ' + 'is_sub=False, vl=None)', + 'OpConcat(#18, <#18.dest: >, sources=(' + '<#14.results[0]: >, <#16.out: >, ' + '<#17.out: >))', + 'OpBigIntMulDiv(#19, <#19.RT: >, ' + 'RA=<#6.dest: >, RB=<#8.results[2]: >, ' + 'RC=<#10.out: >, <#19.RS: >, is_div=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpSplit(#20, results=(<#20.results[0]: >, ' + '<#20.results[1]: >), src=<#18.dest: >)', + 'OpSetCA(#21, <#21.out: CAType()>, value=False)', + 'OpBigIntAddSub(#22, <#22.out: >, ' + 'lhs=<#19.RT: >, rhs=<#20.results[1]: >, ' + 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, ' + 'vl=<#9.out: KnownVLType(length=3)>)', + 'OpBigIntAddSub(#23, <#23.out: >, ' + 'lhs=<#19.RS: >, rhs=<#10.out: >, ' + 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, ' + 'is_sub=False, vl=None)', + 'OpConcat(#24, <#24.dest: >, sources=(' + '<#20.results[0]: >, <#22.out: >, ' + '<#23.out: >))', + 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)', + 'OpStore(#26, RS=<#24.dest: >, ' + 'RA=<#4.dest: >, offset=0, ' + 'mem_in=<#0.out: GlobalMemType()>, ' + '<#26.mem_out: GlobalMemType()>, ' + 'vl=<#25.out: KnownVLType(length=6)>)' + ]) + + # FIXME: register allocator currently allocates wrong registers + @unittest.expectedFailure + def test_simple_mul_192x192_reg_alloc(self): + code = SimpleMul192x192() + fn = code.fn + assigned_registers = allocate_registers(fn.ops) + self.assertEqual(assigned_registers, { + fn.ops[13].RS: GPRRange(9), # type: ignore + fn.ops[14].results[0]: GPRRange(6), # type: ignore + fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore + fn.ops[15].out: XERBit.CA, # type: ignore + fn.ops[16].out: GPRRange(7, length=3), # type: ignore + fn.ops[16].CA_out: XERBit.CA, # type: ignore + fn.ops[17].out: GPRRange(10), # type: ignore + fn.ops[17].CA_out: XERBit.CA, # type: ignore + fn.ops[18].dest: GPRRange(6, length=5), # type: ignore + fn.ops[19].RT: GPRRange(3, length=3), # type: ignore + fn.ops[19].RS: GPRRange(9), # type: ignore + fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore + fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore + fn.ops[21].out: XERBit.CA, # type: ignore + fn.ops[22].out: GPRRange(8, length=3), # type: ignore + fn.ops[22].CA_out: XERBit.CA, # type: ignore + fn.ops[23].out: GPRRange(11), # type: ignore + fn.ops[23].CA_out: XERBit.CA, # type: ignore + fn.ops[24].dest: GPRRange(6, length=6), # type: ignore + fn.ops[25].out: VL.VL_MAXVL, # type: ignore + fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore + fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore + fn.ops[1].out: GPRRange(3), # type: ignore + fn.ops[2].out: GPRRange(4, length=3), # type: ignore + fn.ops[3].out: GPRRange(7, length=3), # type: ignore + fn.ops[4].dest: GPRRange(12), # type: ignore + fn.ops[5].out: VL.VL_MAXVL, # type: ignore + fn.ops[6].dest: GPRRange(17, length=3), # type: ignore + fn.ops[7].dest: GPRRange(14, length=3), # type: ignore + fn.ops[8].results[0]: GPRRange(14), # type: ignore + fn.ops[8].results[1]: GPRRange(15), # type: ignore + fn.ops[8].results[2]: GPRRange(16), # type: ignore + fn.ops[9].out: VL.VL_MAXVL, # type: ignore + fn.ops[10].out: GPRRange(9), # type: ignore + fn.ops[11].RT: GPRRange(6, length=3), # type: ignore + fn.ops[11].RS: GPRRange(9), # type: ignore + fn.ops[12].dest: GPRRange(6, length=4), # type: ignore + fn.ops[13].RT: GPRRange(3, length=3) # type: ignore + }) + self.fail("register allocator currently allocates wrong registers") + + # FIXME: register allocator currently allocates wrong registers + @unittest.expectedFailure + def test_simple_mul_192x192_asm(self): + code = SimpleMul192x192() + asm = generate_assembly(code.fn.ops) + self.assertEqual(asm, [ + 'or 12, 3, 3', + 'setvl 0, 0, 3, 0, 1, 1', + 'sv.or *17, *4, *4', + 'sv.or *14, *7, *7', + 'setvl 0, 0, 3, 0, 1, 1', + 'addi 9, 0, 0', + 'sv.maddedu *6, *17, 14, 9', + 'sv.maddedu *3, *17, 15, 9', + 'addic 0, 0, 0', + 'sv.adde *7, *3, *7', + 'adde 10, 9, 9', + 'sv.maddedu *3, *17, 16, 9', + 'addic 0, 0, 0', + 'sv.adde *8, *3, *8', + 'adde 11, 9, 9', + 'setvl 0, 0, 6, 0, 1, 1', + 'sv.std *6, 0(12)', + 'bclr 20, 0, 0' + ]) + self.fail("register allocator currently allocates wrong registers") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 77e44a2..c574174 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -1,3 +1,4 @@ +# type: ignore """ Compiler IR for Toom-Cook algorithm generator for SVP64 @@ -12,7 +13,8 @@ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.util import FMap, OFSet, OSet, final +from bigint_presentation_code.type_util import final +from bigint_presentation_code.util import FMap, OFSet, OSet class ABCEnumMeta(EnumMeta, ABCMeta): diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index eacceb4..666df14 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -1,20 +1,23 @@ +from collections import defaultdict import enum from enum import Enum, unique -from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload +from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload +from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final -from weakref import WeakValueDictionary +from bigint_presentation_code.type_util import Self, assert_never, final +from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet, + OSet, FMap) +from functools import lru_cache @final class Fn: def __init__(self): self.ops = [] # type: list[Op] - op_names = WeakValueDictionary() - self.__op_names = op_names # type: WeakValueDictionary[str, Op] + self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op] self.__next_name_suffix = 2 def _add_op_with_unused_name(self, op, name=""): @@ -32,278 +35,464 @@ class Fn: self.__next_name_suffix += 1 def __repr__(self): + # type: () -> str return "" @unique @final -class RegKind(Enum): - GPR = enum.auto() +class BaseTy(Enum): + I64 = enum.auto() CA = enum.auto() VL_MAXVL = enum.auto() @cached_property def only_scalar(self): - if self is RegKind.GPR: + # type: () -> bool + if self is BaseTy.I64: return False - elif self is RegKind.CA or self is RegKind.VL_MAXVL: + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: return True else: assert_never(self) @cached_property - def reg_count(self): - if self is RegKind.GPR: + def max_reg_len(self): + # type: () -> int + if self is BaseTy.I64: return 128 - elif self is RegKind.CA or self is RegKind.VL_MAXVL: + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: return 1 else: assert_never(self) def __repr__(self): - return "RegKind." + self._name_ + return "BaseTy." + self._name_ @plain_data(frozen=True, unsafe_hash=True) @final -class OperandType: - __slots__ = "kind", "vec" +class Ty: + __slots__ = "base_ty", "reg_len" - def __init__(self, kind, vec): - # type: (RegKind, bool) -> None - self.kind = kind - if kind.only_scalar and vec: - raise ValueError(f"kind={kind} must have vec=False") - self.vec = vec - - def get_length(self, maxvl): - # type: (int) -> int - # here's where subvl and elwid would be accounted for - if self.vec: - return maxvl - return 1 + @staticmethod + def validate(base_ty, reg_len): + # type: (BaseTy, int) -> str | None + """ return a string with the error if the combination is invalid, + otherwise return None + """ + if base_ty.only_scalar and reg_len != 1: + return f"can't create a vector of an only-scalar type: {base_ty}" + if reg_len < 1 or reg_len > base_ty.max_reg_len: + return "reg_len out of range" + return None + + def __init__(self, base_ty, reg_len): + # type: (BaseTy, int) -> None + msg = self.validate(base_ty=base_ty, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.base_ty = base_ty + self.reg_len = reg_len -@plain_data(frozen=True, unsafe_hash=True) +@unique @final -class RegShape: - __slots__ = "kind", "length" +class LocKind(Enum): + GPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() - def __init__(self, kind, length=1): - # type: (RegKind, int) -> None - self.kind = kind - if length < 1 or length > kind.reg_count: - raise ValueError("invalid length") - self.length = length + @cached_property + def base_ty(self): + # type: () -> BaseTy + if self is LocKind.GPR or self is LocKind.StackI64: + return BaseTy.I64 + if self is LocKind.CA: + return BaseTy.CA + if self is LocKind.VL_MAXVL: + return BaseTy.VL_MAXVL + else: + assert_never(self) - def try_concat(self, *others): - # type: (*RegShape | Reg | RegClass | None) -> RegShape | None - kind = self.kind - length = self.length - for other in others: - if isinstance(other, (Reg, RegClass)): - other = other.shape - if other is None: - return None - if other.kind != self.kind: - return None - length += other.length - if length > kind.reg_count: - return None - return RegShape(kind=kind, length=length) + @cached_property + def loc_count(self): + # type: () -> int + if self is LocKind.StackI64: + return 1024 + if self is LocKind.GPR or self is LocKind.CA \ + or self is LocKind.VL_MAXVL: + return self.base_ty.max_reg_len + else: + assert_never(self) + + def __repr__(self): + return "LocKind." + self._name_ -@plain_data(frozen=True, unsafe_hash=True) @final -class Reg: - __slots__ = "shape", "start" - - def __init__(self, shape, start): - # type: (RegShape, int) -> None - self.shape = shape - if start < 0 or start + shape.length > shape.kind.reg_count: - raise ValueError("start not in valid range") - self.start = start +@unique +class LocSubKind(Enum): + BASE_GPR = enum.auto() + SV_EXTRA2_VGPR = enum.auto() + SV_EXTRA2_SGPR = enum.auto() + SV_EXTRA3_VGPR = enum.auto() + SV_EXTRA3_SGPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() - @property + @cached_property def kind(self): - return self.shape.kind + # type: () -> LocKind + # pyright fails typechecking when using `in` here: + # reported: https://github.com/microsoft/pyright/issues/4102 + if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \ + or self is LocSubKind.SV_EXTRA2_SGPR \ + or self is LocSubKind.SV_EXTRA3_VGPR \ + or self is LocSubKind.SV_EXTRA3_SGPR: + return LocKind.GPR + if self is LocSubKind.StackI64: + return LocKind.StackI64 + if self is LocSubKind.CA: + return LocKind.CA + if self is LocSubKind.VL_MAXVL: + return LocKind.VL_MAXVL + assert_never(self) @property - def length(self): - return self.shape.length + def base_ty(self): + return self.kind.base_ty + + @lru_cache() + def allocatable_locs(self, ty): + # type: (Ty) -> LocSet + if ty.base_ty != self.base_ty: + raise ValueError("type mismatch") + raise NotImplementedError # FIXME: finish + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GenericTy: + __slots__ = "base_ty", "is_vec" + + def __init__(self, base_ty, is_vec): + # type: (BaseTy, bool) -> None + self.base_ty = base_ty + if base_ty.only_scalar and is_vec: + raise ValueError(f"base_ty={base_ty} requires is_vec=False") + self.is_vec = is_vec + + def instantiate(self, maxvl): + # type: (int) -> Ty + # here's where subvl and elwid would be accounted for + if self.is_vec: + return Ty(self.base_ty, maxvl) + return Ty(self.base_ty, 1) + + def can_instantiate_to(self, ty): + # type: (Ty) -> bool + if self.base_ty != ty.base_ty: + return False + if self.is_vec: + return True + return ty.reg_len == 1 + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class Loc: + __slots__ = "kind", "start", "reg_len" + + @staticmethod + def validate(kind, start, reg_len): + # type: (LocKind, int, int) -> str | None + msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len) + if msg is not None: + return msg + if reg_len > kind.loc_count: + return "invalid reg_len" + if start < 0 or start + reg_len > kind.loc_count: + return "start not in valid range" + return None + + @staticmethod + def try_make(kind, start, reg_len): + # type: (LocKind, int, int) -> Loc | None + msg = Loc.validate(kind=kind, start=start, reg_len=reg_len) + if msg is None: + return None + return Loc(kind=kind, start=start, reg_len=reg_len) + + def __init__(self, kind, start, reg_len): + # type: (LocKind, int, int) -> None + msg = self.validate(kind=kind, start=start, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.kind = kind + self.reg_len = reg_len + self.start = start def conflicts(self, other): - # type: (Reg) -> bool - return (self.kind == other.kind + # type: (Loc) -> bool + return (self.kind != other.kind and self.start < other.stop and other.start < self.stop) + @staticmethod + def make_ty(kind, reg_len): + # type: (LocKind, int) -> Ty + return Ty(base_ty=kind.base_ty, reg_len=reg_len) + + @cached_property + def ty(self): + # type: () -> Ty + return self.make_ty(kind=self.kind, reg_len=self.reg_len) + @property def stop(self): - return self.start + self.length + # type: () -> int + return self.start + self.reg_len def try_concat(self, *others): - # type: (*Reg | None) -> Reg | None - shape = self.shape.try_concat(*others) - if shape is None: - return None + # type: (*Loc | None) -> Loc | None + reg_len = self.reg_len stop = self.stop for other in others: - assert other is not None, "already caught by RegShape.try_concat" + if other is None or other.kind != self.kind: + return None if stop != other.start: return None stop = other.stop - return Reg(shape, self.start) + reg_len += other.reg_len + return Loc(kind=self.kind, start=self.start, reg_len=reg_len) +@plain_data(frozen=True, eq=False, repr=False) @final -class RegClass(AbstractSet[Reg]): - def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0): - # type: (Iterable[Reg | int], RegShape | None, int) -> None - for reg_or_start in regs_or_starts: - if isinstance(reg_or_start, Reg): - if shape is None: - shape = reg_or_start.shape - elif shape != reg_or_start.shape: - raise ValueError(f"conflicting RegShapes: {shape} and " - f"{reg_or_start.shape}") - start = reg_or_start.start - else: - start = reg_or_start - if start < 0: - raise ValueError("a Reg's start is out of range") - starts_bitset |= 1 << start - if starts_bitset == 0: - shape = None - self.__shape = shape - self.__starts_bitset = starts_bitset - if shape is None: - if starts_bitset != 0: - raise ValueError("non-empty RegClass must have non-None shape") +class LocSet(AbstractSet[Loc]): + __slots__ = "starts", "ty" + + def __init__(self, __locs=()): + # type: (Iterable[Loc]) -> None + if isinstance(__locs, LocSet): + self.starts = __locs.starts # type: FMap[LocKind, FBitSet] + self.ty = __locs.ty # type: Ty | None return - if self.stops_bitset >= 1 << shape.kind.reg_count: - raise ValueError("a Reg's start is out of range") - - @property - def shape(self): - # type: () -> RegShape | None - return self.__shape - - @property - def starts_bitset(self): - # type: () -> int - return self.__starts_bitset - - @property - def stops_bitset(self): - # type: () -> int - if self.__shape is None: - return 0 - return self.__starts_bitset << self.__shape.length - - @cached_property - def starts(self): - # type: () -> OFSet[int] - if self.length is None: - return OFSet() - # TODO: fixme - # return OFSet(for i in range(self.length)) + starts = {i: BitSet() for i in LocKind} + ty = None + for loc in __locs: + if ty is None: + ty = loc.ty + if ty != loc.ty: + raise ValueError(f"conflicting types: {ty} != {loc.ty}") + starts[loc.kind].add(loc.start) + self.starts = FMap( + (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) + self.ty = ty @cached_property def stops(self): - # type: () -> OFSet[int] - if self.__shape is None: - return OFSet() - return OFSet(i + self.__shape.length for i in self.__starts) + # type: () -> FMap[LocKind, FBitSet] + if self.ty is None: + return FMap() + sh = self.ty.reg_len + return FMap( + (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items()) @property - def kind(self): - if self.__shape is None: + def kinds(self): + # type: () -> AbstractSet[LocKind] + return self.starts.keys() + + @property + def reg_len(self): + # type: () -> int | None + if self.ty is None: return None - return self.__shape.kind + return self.ty.reg_len @property - def length(self): - """length of registers in this RegClass, not to be confused with the number of `Reg`s in self""" - if self.__shape is None: + def base_ty(self): + # type: () -> BaseTy | None + if self.ty is None: return None - return self.__shape.length + return self.ty.base_ty def concat(self, *others): - # type: (*RegClass) -> RegClass - shape = self.__shape - if shape is None: - return RegClass() - shape = shape.try_concat(*others) - if shape is None: - return RegClass() - starts = OSet(self.starts) - offset = shape.length + # type: (*LocSet) -> LocSet + if self.ty is None: + return LocSet() + base_ty = self.ty.base_ty + reg_len = self.ty.reg_len + starts = {k: BitSet(v) for k, v in self.starts.items()} for other in others: - assert other.__shape is not None, \ - "already caught by RegShape.try_concat" - starts &= OSet(i - offset for i in other.starts) - offset += other.__shape.length - return RegClass(starts, shape=shape) - - def __contains__(self, reg): - # type: (Reg) -> bool - return reg.shape == self.shape and reg.start in self.starts + if other.ty is None: + return LocSet() + if other.ty.base_ty != base_ty: + return LocSet() + for kind, other_starts in other.starts.items(): + if kind not in starts: + continue + starts[kind].bits &= other_starts.bits >> reg_len + if starts[kind] == 0: + del starts[kind] + if len(starts) == 0: + return LocSet() + reg_len += other.ty.reg_len + + def locs(): + # type: () -> Iterable[Loc] + for kind, v in starts.items(): + for start in v: + loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len) + if loc is not None: + yield loc + return LocSet(locs()) + + def __contains__(self, loc): + # type: (Loc | Any) -> bool + if not isinstance(loc, Loc) or loc.ty == self.ty: + return False + if loc.kind not in self.starts: + return False + return loc.start in self.starts[loc.kind] def __iter__(self): - # type: () -> Iterator[Reg] - if self.shape is None: + # type: () -> Iterator[Loc] + if self.ty is None: return - for start in self.starts: - yield Reg(shape=self.shape, start=start) + for kind, starts in self.starts.items(): + for start in starts: + yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len) + + @cached_property + def __len(self): + return sum((len(v) for v in self.starts.values()), 0) def __len__(self): - return len(self.starts) + return self.__len - def __hash__(self): + @cached_property + def __hash(self): return super()._hash() + def __hash__(self): + return self.__hash + @plain_data(frozen=True, unsafe_hash=True) @final -class Operand: - __slots__ = "ty", "regs" - - def __init__(self, ty, regs=None): - # type: (OperandType, OFSet[int] | None) -> None - pass - - -OT_VGPR = OperandType(RegKind.GPR, vec=True) -OT_SGPR = OperandType(RegKind.GPR, vec=False) -OT_CA = OperandType(RegKind.CA, vec=False) -OT_VL = OperandType(RegKind.VL_MAXVL, vec=False) +class GenericOperandDesc: + """generic Op operand descriptor""" + __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index" + + def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None): + # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None + self.ty = ty + self.sub_kinds = OFSet(sub_kinds) + if len(self.sub_kinds) == 0: + raise ValueError("sub_kinds can't be empty") + self.fixed_loc = fixed_loc + if fixed_loc is not None: + if tied_input_index is not None: + raise ValueError("operand can't be both tied and fixed") + if not ty.can_instantiate_to(fixed_loc.ty): + raise ValueError( + f"fixed_loc has incompatible type for given generic " + f"type: fixed_loc={fixed_loc} generic ty={ty}") + if len(self.sub_kinds) != 1: + raise ValueError( + "multiple sub_kinds not allowed for fixed operand") + for sub_kind in self.sub_kinds: + if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty): + raise ValueError( + f"fixed_loc not in given sub_kind: " + f"fixed_loc={fixed_loc} sub_kind={sub_kind}") + for sub_kind in self.sub_kinds: + if sub_kind.base_ty != ty.base_ty: + raise ValueError(f"sub_kind is incompatible with type: " + f"sub_kind={sub_kind} ty={ty}") + if tied_input_index is not None and tied_input_index < 0: + raise ValueError("invalid tied_input_index") + self.tied_input_index = tied_input_index + + def tied_to_input(self, tied_input_index): + # type: (int) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, + tied_input_index=tied_input_index) + + def with_fixed_loc(self, fixed_loc): + # type: (Loc) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc) + + def instantiate(self, maxvl): + # type: (int) -> OperandDesc + ty = self.ty.instantiate(maxvl=maxvl) + + def locs(): + # type: () -> Iterable[Loc] + if self.fixed_loc is not None: + if ty != self.fixed_loc.ty: + raise ValueError( + f"instantiation failed: type mismatch with fixed_loc: " + f"instantiated type: {ty} fixed_loc: {self.fixed_loc}") + yield self.fixed_loc + return + for sub_kind in self.sub_kinds: + yield from sub_kind.allocatable_locs(ty) + return OperandDesc(loc_set=LocSet(locs()), + tied_input_index=self.tied_input_index) @plain_data(frozen=True, unsafe_hash=True) -class TiedOutput: - __slots__ = "input_index", "output_index" - - def __init__(self, input_index, output_index): - # type: (int, int) -> None - self.input_index = input_index - self.output_index = output_index - - -Constraint = Union[TiedOutput, NoReturn] +@final +class OperandDesc: + """Op operand descriptor""" + __slots__ = "loc_set", "tied_input_index" + + def __init__(self, loc_set, tied_input_index): + # type: (LocSet, int | None) -> None + if len(loc_set) == 0: + raise ValueError("loc_set must not be empty") + self.loc_set = loc_set + self.tied_input_index = tied_input_index + + +OD_BASE_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.BASE_GPR]) +OD_EXTRA3_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR]) +OD_EXTRA3_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR]) +OD_EXTRA2_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA2_SGPR]) +OD_EXTRA2_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA2_VGPR]) +OD_CA = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.CA, is_vec=False), + sub_kinds=[LocSubKind.CA]) +OD_VL = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False), + sub_kinds=[LocSubKind.VL_MAXVL]) @plain_data(frozen=True, unsafe_hash=True) @final -class OpProperties: - __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints", +class GenericOpProperties: + __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "is_copy", "is_load_immediate", "has_side_effects") def __init__(self, demo_asm, # type: str - inputs, # type: Iterable[OperandType] - outputs, # type: Iterable[OperandType] + inputs, # type: Iterable[GenericOperandDesc] + outputs, # type: Iterable[GenericOperandDesc] immediates, # type: Iterable[range] - constraints, # type: Iterable[Constraint] is_copy=False, # type: bool is_load_immediate=False, # type: bool has_side_effects=False, # type: bool @@ -313,17 +502,22 @@ class OpProperties: self.inputs = tuple(inputs) self.outputs = tuple(outputs) self.immediates = tuple(immediates) - self.constraints = tuple(constraints) self.is_copy = is_copy self.is_load_immediate = is_load_immediate self.has_side_effects = has_side_effects + def instantiate(self, maxvl): + # type: (int) -> OpProperties + raise NotImplementedError # FIXME: finish + + +# FIXME: add OpProperties @unique @final class OpKind(Enum): def __init__(self, properties): - # type: (OpProperties) -> None + # type: (GenericOpProperties) -> None super().__init__() self.properties = properties @@ -451,13 +645,13 @@ class SSAVal: self.sliced_op_outputs = tuple(processed) def __add__(self, other): - # type: (SSAVal) -> SSAVal + # type: (SSAVal | Any) -> SSAVal if not isinstance(other, SSAVal): return NotImplemented return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs) def __radd__(self, other): - # type: (SSAVal) -> SSAVal + # type: (SSAVal | Any) -> SSAVal if isinstance(other, SSAVal): return other.__add__(self) return NotImplemented @@ -465,7 +659,7 @@ class SSAVal: @cached_property def expanded_sliced_op_outputs(self): # type: () -> tuple[tuple[Op, int, int], ...] - retval = [] + retval = [] # type: list[tuple[Op, int, int]] for op, output_index, range_ in self.sliced_op_outputs: for i in range_: retval.append((op, output_index, i)) @@ -490,7 +684,7 @@ class SSAVal: # type: () -> str if len(self.sliced_op_outputs) == 0: return "SSAVal([])" - parts = [] + parts = [] # type: list[str] for op, output_index, range_ in self.sliced_op_outputs: out_len = op.properties.outputs[output_index].get_length(op.maxvl) parts.append(f"<{op.name}#{output_index}>") @@ -513,13 +707,14 @@ class Op: self.maxvl = maxvl outputs_len = len(self.properties.outputs) self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len)) - self.name = fn._add_op_with_unused_name(self, name) + self.name = fn._add_op_with_unused_name(self, name) # type: ignore @property def properties(self): return self.kind.properties def __eq__(self, other): + # type: (Op | Any) -> bool if isinstance(other, Op): return self is other return NotImplemented diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py index 89c3ea2..0674c02 100644 --- a/src/bigint_presentation_code/matrix.py +++ b/src/bigint_presentation_code/matrix.py @@ -1,10 +1,9 @@ -import operator from enum import Enum, unique from fractions import Fraction -from numbers import Rational +import operator from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar -from bigint_presentation_code.util import final +from bigint_presentation_code.type_util import final _T = TypeVar("_T") _T2 = TypeVar("_T2") @@ -103,7 +102,7 @@ class Matrix(Generic[_T]): return retval def __truediv__(self, rhs): - # type: (Rational | int) -> Matrix + # type: (_T | int) -> Matrix[_T] retval = self.copy() for i in self.indexes(): retval[i] /= rhs # type: ignore @@ -128,7 +127,7 @@ class Matrix(Generic[_T]): return lhs.__matmul__(self) def __elementwise_bin_op(self, rhs, op): - # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T] + # type: (Matrix[_T], Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T] if self.height != rhs.height or self.width != rhs.width: raise ValueError( "matrix dimensions must match for element-wise operations") @@ -172,8 +171,8 @@ class Matrix(Generic[_T]): # type: () -> str if self.height == 0 or self.width == 0: return f"Matrix(height={self.height}, width={self.width})" - lines = [] - line = [] + lines = [] # type: list[str] + line = [] # type: list[str] for row in range(self.height): line.clear() for col in range(self.width): @@ -183,16 +182,16 @@ class Matrix(Generic[_T]): else: line.append(repr(el)) lines.append(", ".join(line)) - lines = ",\n ".join(lines) + lines_str = ",\n ".join(lines) element_type = "" if self.element_type is not Fraction: element_type = f"element_type={self.element_type}, " return (f"Matrix(height={self.height}, width={self.width}, " f"{element_type}data=[\n" - f" {lines},\n])") + f" {lines_str},\n])") def __eq__(self, rhs): - # type: (object) -> bool + # type: (Matrix[Any] | Any) -> bool if not isinstance(rhs, Matrix): return NotImplemented return (self.height == rhs.height diff --git a/src/bigint_presentation_code/py.typed b/src/bigint_presentation_code/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index b8269e4..cc794e9 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -12,7 +12,8 @@ from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, RegLoc, RegType, SSAVal) -from bigint_presentation_code.util import OFSet, OSet, final +from bigint_presentation_code.type_util import final +from bigint_presentation_code.util import OFSet, OSet _RegType = TypeVar("_RegType", bound=RegType) diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py deleted file mode 100644 index 820c305..0000000 --- a/src/bigint_presentation_code/test_compiler_ir.py +++ /dev/null @@ -1,120 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, GPRType, - OpBigIntAddSub, OpConcat, - OpCopy, OpFuncArg, - OpInputMem, OpLI, OpLoad, - OpSetCA, OpSetVLImm, OpStore, - RegLoc, SSAVal, XERBit, - generate_assembly, - op_set_to_list) - - -class TestCompilerIR(unittest.TestCase): - maxDiff = None - - def test_op_set_to_list(self): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 1) - b_0 = op5.out - op6 = OpSetVLImm(fn, 31) - vl = op6.out - op7 = OpLI(fn, 0, vl=vl) - b_rest = op7.out - op8 = OpConcat(fn, [b_0, b_rest]) - b = op8.dest - op9 = OpSetVLImm(fn, 32) - vl = op9.out - op10 = OpSetCA(fn, False) - ca = op10.out - op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op11.out - op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op12.mem_out - - expected_ops = [ - op10, # OpSetCA(fn, False) - op9, # OpSetVLImm(fn, 32) - op6, # OpSetVLImm(fn, 31) - op5, # OpLI(fn, 1) - op3, # OpSetVLImm(fn, 32) - op2, # OpInputMem(fn) - op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op7, # OpLI(fn, 0, vl=vl) - op1, # OpCopy(fn, op0.out, GPRType()) - op8, # OpConcat(fn, [b_0, b_rest]) - op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - ] - - ops = op_set_to_list(fn.ops[::-1]) - if ops != expected_ops: - self.assertEqual(repr(ops), repr(expected_ops)) - - def tst_generate_assembly(self, use_reg_alloc=False): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 0, vl=vl) - b = op5.out - op6 = OpSetCA(fn, True) - ca = op6.out - op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op7.out - op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op8.mem_out - - assigned_registers = { - op0.out: GPRRange(start=3, length=1), - op1.dest: GPRRange(start=3, length=1), - op2.out: GlobalMem.GlobalMem, - op3.out: VL.VL_MAXVL, - op4.RT: GPRRange(start=78, length=32), - op5.out: GPRRange(start=46, length=32), - op6.out: XERBit.CA, - op7.out: GPRRange(start=14, length=32), - op7.CA_out: XERBit.CA, - op8.mem_out: GlobalMem.GlobalMem, - } # type: dict[SSAVal, RegLoc] | None - - if use_reg_alloc: - assigned_registers = None - - asm = generate_assembly(fn.ops, assigned_registers) - self.assertEqual(asm, [ - "setvl 0, 0, 32, 0, 1, 1", - "sv.ld *78, 0(3)", - "sv.addi *46, 0, 0", - "subfic 0, 0, -1", - "sv.adde *14, *78, *46", - "sv.std *14, 0(3)", - "bclr 20, 0, 0", - ]) - - def test_generate_assembly(self): - self.tst_generate_assembly() - - def test_generate_assembly_with_register_allocator(self): - self.tst_generate_assembly(use_reg_alloc=True) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/test_matrix.py b/src/bigint_presentation_code/test_matrix.py deleted file mode 100644 index 1a56df0..0000000 --- a/src/bigint_presentation_code/test_matrix.py +++ /dev/null @@ -1,113 +0,0 @@ -import unittest -from fractions import Fraction - -from bigint_presentation_code.matrix import Matrix, SpecialMatrix - - -class TestMatrix(unittest.TestCase): - def test_repr(self): - self.assertEqual(repr(Matrix(2, 3, [0, 1, 2, - 3, 4, 5])), - 'Matrix(height=2, width=3, data=[\n' - ' 0, 1, 2,\n' - ' 3, 4, 5,\n' - '])') - self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3, - 3, 4, 5])), - 'Matrix(height=2, width=3, data=[\n' - ' 0, 1, Fraction(2, 3),\n' - ' 3, 4, 5,\n' - '])') - self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)') - self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)') - - def test_eq(self): - self.assertFalse(Matrix(1, 1) == 5) - self.assertFalse(5 == Matrix(1, 1)) - self.assertFalse(Matrix(2, 1) == Matrix(1, 1)) - self.assertFalse(Matrix(1, 2) == Matrix(1, 1)) - self.assertTrue(Matrix(1, 1) == Matrix(1, 1)) - self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1])) - self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1])) - - def test_add(self): - self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) - + Matrix(2, 2, [40, 30, 20, 10]), - Matrix(2, 2, [41, 32, 23, 14])) - - def test_identity(self): - self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity), - Matrix(2, 2, [1, 0, - 0, 1])) - self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity), - Matrix(1, 3, [1, 0, 0])) - self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity), - Matrix(2, 3, [1, 0, 0, - 0, 1, 0])) - self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity), - Matrix(3, 3, [1, 0, 0, - 0, 1, 0, - 0, 0, 1])) - - def test_sub(self): - self.assertEqual(Matrix(2, 2, [40, 30, 20, 10]) - - Matrix(2, 2, [-1, -2, -3, -4]), - Matrix(2, 2, [41, 32, 23, 14])) - - def test_neg(self): - self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]), - Matrix(2, 2, [-40, -30, -20, -10])) - - def test_mul(self): - self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2), - Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6])) - self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]), - Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6])) - - def test_matmul(self): - self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) - @ Matrix(2, 2, [4, 3, 2, 1]), - Matrix(2, 2, [8, 5, 20, 13])) - self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1]) - @ Matrix(2, 1, [1, 2]), - Matrix(3, 1, [16, 10, 4])) - - def test_inverse(self): - self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0)) - self.assertEqual(Matrix(1, 1, [2]).inverse(), - Matrix(1, 1, [Fraction(1, 2)])) - self.assertEqual(Matrix(1, 1, [1]).inverse(), - Matrix(1, 1, [1])) - self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(), - Matrix(2, 2, [1, 0, -1, 1])) - self.assertEqual(Matrix(3, 3, [0, 1, 0, - 1, 0, 0, - 0, 0, 1]).inverse(), - Matrix(3, 3, [0, 1, 0, - 1, 0, 0, - 0, 0, 1])) - _1_2 = Fraction(1, 2) - _1_3 = Fraction(1, 3) - _1_6 = Fraction(1, 6) - self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0, - 1, 1, 1, 1, 1, - 1, -1, 1, -1, 1, - 1, -2, 4, -8, 16, - 0, 0, 0, 0, 1]).inverse(), - Matrix(5, 5, [1, 0, 0, 0, 0, - _1_2, _1_3, -1, _1_6, -2, - -1, _1_2, _1_2, 0, -1, - -_1_2, _1_6, _1_2, -_1_6, 2, - 0, 0, 0, 0, 1])) - with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(1, 1, [0]).inverse() - with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [0, 0, 1, 1]).inverse() - with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [1, 0, 1, 0]).inverse() - with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [1, 1, 1, 1]).inverse() - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py deleted file mode 100644 index 1eff254..0000000 --- a/src/bigint_presentation_code/test_register_allocator.py +++ /dev/null @@ -1,201 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, GPRType, - OpBigIntAddSub, OpConcat, - OpCopy, OpFuncArg, - OpInputMem, OpLI, OpLoad, - OpSetCA, OpSetVLImm, OpStore, - XERBit) -from bigint_presentation_code.register_allocator import ( - AllocationFailed, MergedRegSet, allocate_registers, - try_allocate_registers_without_spilling) - - -class TestMergedRegSet(unittest.TestCase): - maxDiff = None - - def test_from_equality_constraint(self): - fn = Fn() - li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out) - li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out) - li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out) - self.assertEqual(MergedRegSet.from_equality_constraint([ - li0x1.out, - li0x2.out, - li0x3.out, - ]), MergedRegSet({ - li0x1.out: 0, - li0x2.out: 1, - li0x3.out: 3, - }.items())) - self.assertEqual(MergedRegSet.from_equality_constraint([ - li0x2.out, - li0x1.out, - li0x3.out, - ]), MergedRegSet({ - li0x2.out: 0, - li0x1.out: 2, - li0x3.out: 3, - }.items())) - - -class TestRegisterAllocator(unittest.TestCase): - maxDiff = None - - def test_try_alloc_fail(self): - fn = Fn() - op0 = OpSetVLImm(fn, 52) - op1 = OpLI(fn, 0, vl=op0.out) - op2 = OpSetVLImm(fn, 64) - op3 = OpLI(fn, 0, vl=op2.out) - op4 = OpConcat(fn, [op1.out, op3.out]) - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - self.assertEqual( - repr(reg_assignments), - "AllocationFailed(" - "node=IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "edges={}, reg=None), " - "live_intervals=LiveIntervals(live_intervals={" - "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): " - "LiveInterval(first_write=0, last_use=1), " - "MergedRegSet([(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]): " - "LiveInterval(first_write=1, last_use=4), " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): " - "LiveInterval(first_write=2, last_use=3)}, " - "merged_reg_sets=MergedRegSets(data={" - "<#0.out: KnownVLType(length=52)>: " - "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), " - "<#1.out: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "<#2.out: KnownVLType(length=64)>: " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), " - "<#3.out: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "<#4.dest: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])}), " - "reg_sets_live_after={" - "0: OFSet([MergedRegSet([" - "(<#0.out: KnownVLType(length=52)>, 0)])]), " - "1: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])]), " - "2: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), " - "3: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])]), " - "4: OFSet()}), " - "interference_graph=InterferenceGraph(nodes={" - "...: IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), " - "...: IGNode(#1, merged_reg_set=MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), edges={}, reg=None), " - "...: IGNode(#2, merged_reg_set=MergedRegSet([" - "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))" - ) - - def test_try_alloc_bigint_inc(self): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 0, vl=vl) - b = op5.out - op6 = OpSetCA(fn, True) - ca = op6.out - op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op7.out - op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op8.mem_out - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - - expected_reg_assignments = { - op0.out: GPRRange(start=3, length=1), - op1.dest: GPRRange(start=3, length=1), - op2.out: GlobalMem.GlobalMem, - op3.out: VL.VL_MAXVL, - op4.RT: GPRRange(start=78, length=32), - op5.out: GPRRange(start=46, length=32), - op6.out: XERBit.CA, - op7.out: GPRRange(start=14, length=32), - op7.CA_out: XERBit.CA, - op8.mem_out: GlobalMem.GlobalMem, - } - - self.assertEqual(reg_assignments, expected_reg_assignments) - - def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): - # type: (list[GPRRange], GPRRange) -> None - fn = Fn() - inputs = [] - expected_reg_assignments = {} - for i, r in enumerate(expected_regs): - vl = OpSetVLImm(fn, r.length).out - expected_reg_assignments[vl] = VL.VL_MAXVL - inp = OpLI(fn, i, vl=vl).out - inputs.append(inp) - expected_reg_assignments[inp] = r - concat = OpConcat(fn, inputs) - expected_reg_assignments[concat.dest] = expected_dest_reg - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - - for inp, reg in zip(inputs, expected_regs): - expected_reg_assignments[inp] = reg - - self.assertEqual(reg_assignments, expected_reg_assignments) - - def test_try_alloc_concat_1(self): - self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3)) - - def test_try_alloc_concat_3(self): - self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3)) - - def test_try_alloc_concat_3_5(self): - self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)], - GPRRange(3, 8)) - - def test_try_alloc_concat_5_3(self): - self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)], - GPRRange(3, 8)) - - def test_try_alloc_concat_1_2_3_4_5_6(self): - self.tst_try_alloc_concat([ - GPRRange(14, 1), - GPRRange(15, 2), - GPRRange(17, 3), - GPRRange(20, 4), - GPRRange(24, 5), - GPRRange(29, 6), - ], GPRRange(14, 21)) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py deleted file mode 100644 index 6fff570..0000000 --- a/src/bigint_presentation_code/test_toom_cook.py +++ /dev/null @@ -1,378 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, - GPRRangeType, OpCopy, - OpFuncArg, OpInputMem, - OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit, - generate_assembly) -from bigint_presentation_code.register_allocator import allocate_registers -from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul -from bigint_presentation_code.util import FMap - - -class SimpleMul192x192: - def __init__(self): - self.fn = fn = Fn() - self.mem_in = mem = OpInputMem(fn).out - self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out - self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out - self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out - dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest - vl = OpSetVLImm(fn, 3).out - lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest - rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest - retval = simple_mul(fn, lhs, rhs) - vl = OpSetVLImm(fn, 6).out - self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0, - mem_in=mem, vl=vl).mem_out - - -class TestToomCook(unittest.TestCase): - maxDiff = None - - def test_toom_2_repr(self): - TOOM_2 = ToomCookInstance.make_toom_2() - # print(repr(repr(TOOM_2))) - self.assertEqual( - repr(TOOM_2), - "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, " - "eval_points=(0, 1, POINT_AT_INFINITY), " - "lhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," - " rhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," - " prod_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpSub(lhs=" - "EvalOpSub(lhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({" - "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), " - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))" - ) - - def test_toom_2_5_repr(self): - TOOM_2_5 = ToomCookInstance.make_toom_2_5() - # print(repr(repr(TOOM_2_5))) - self.assertEqual( - repr(TOOM_2_5), - "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, " - "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=" - "EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=2, rhs=0, " - "poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=1, rhs=0, " - "poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({" - "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), " - "EvalOpSub(lhs=" - "EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=2, rhs=0, " - "poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=1, rhs=0, " - "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly(" - "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), " - "EvalOpInput(lhs=2, rhs=0, " - "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, " - "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " - "EvalOpSub(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=1, rhs=0, " - "poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), " - "EvalOpInput(lhs=1, rhs=0, " - "poly=EvalOpPoly({1: Fraction(1, 1)}))), " - "prod_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "rhs=EvalOpInput(lhs=2, rhs=0, " - "poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), " - "rhs=2, " - "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs=" - "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), " - "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, " - "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), " - "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" - ) - - def test_reversed_toom_2_5_repr(self): - TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed() - # print(repr(repr(TOOM_2_5))) - self.assertEqual( - repr(TOOM_2_5), - "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, " - "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " - "EvalOpSub(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), " - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," - " rhs_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpAdd(lhs=EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), " - "EvalOpSub(lhs=EvalOpAdd(lhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), " - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})))," - " prod_eval_ops=(" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), " - "rhs=2, " - "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs=" - "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), " - "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs=" - "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " - "rhs=" - "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " - "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, " - "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs=" - "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " - "poly=EvalOpPoly(" - "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), " - "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))" - ) - - def test_simple_mul_192x192_pre_ra_sim(self): - # test multiplying: - # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 - # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 - # == - # int("0x00074736574206e_6f69746163696c70" - # "_69746c756d207469_622d3438333e2d32" - # "_3931783239312079_7261727469627261", base=0) - # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", - # 'little') - code = SimpleMul192x192() - dest_ptr = 0x100 - state = PreRASimState( - gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()}, - stack_slots={}, fixed_gprs={ - code.dest_ptr_in: (dest_ptr,), - code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e, - 0x000191acb262e15b), - code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6, - 0x4a37c0567bcbab53) - }) - code.fn.pre_ra_sim(state) - expected_bytes = b"arbitrary 192x192->384-bit multiplication test" - OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES - expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0') - mem_out = state.global_mems[code.mem_out] - out_bytes = bytes( - mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT)) - self.assertEqual(out_bytes, expected_bytes) - - def test_simple_mul_192x192_ops(self): - code = SimpleMul192x192() - fn = code.fn - self.assertEqual([repr(v) for v in fn.ops], [ - 'OpInputMem(#0, <#0.out: GlobalMemType()>)', - 'OpFuncArg(#1, <#1.out: )>>)', - 'OpFuncArg(#2, <#2.out: )>>)', - 'OpFuncArg(#3, <#3.out: )>>)', - 'OpCopy(#4, <#4.dest: >, src=<#1.out: )>>, ' - 'vl=None)', - 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)', - 'OpCopy(#6, <#6.dest: >, ' - 'src=<#2.out: )>>, ' - 'vl=<#5.out: KnownVLType(length=3)>)', - 'OpCopy(#7, <#7.dest: >, ' - 'src=<#3.out: )>>, ' - 'vl=<#5.out: KnownVLType(length=3)>)', - 'OpSplit(#8, results=(<#8.results[0]: >, ' - '<#8.results[1]: >, <#8.results[2]: >), ' - 'src=<#7.dest: >)', - 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)', - 'OpLI(#10, <#10.out: >, value=0, vl=None)', - 'OpBigIntMulDiv(#11, <#11.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[0]: >, ' - 'RC=<#10.out: >, <#11.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpConcat(#12, <#12.dest: >, sources=(' - '<#11.RT: >, <#11.RS: >))', - 'OpBigIntMulDiv(#13, <#13.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[1]: >, ' - 'RC=<#10.out: >, <#13.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpSplit(#14, results=(<#14.results[0]: >, ' - '<#14.results[1]: >), src=<#12.dest: >)', - 'OpSetCA(#15, <#15.out: CAType()>, value=False)', - 'OpBigIntAddSub(#16, <#16.out: >, ' - 'lhs=<#13.RT: >, rhs=<#14.results[1]: >, ' - 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpBigIntAddSub(#17, <#17.out: >, ' - 'lhs=<#13.RS: >, rhs=<#10.out: >, ' - 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, ' - 'is_sub=False, vl=None)', - 'OpConcat(#18, <#18.dest: >, sources=(' - '<#14.results[0]: >, <#16.out: >, ' - '<#17.out: >))', - 'OpBigIntMulDiv(#19, <#19.RT: >, ' - 'RA=<#6.dest: >, RB=<#8.results[2]: >, ' - 'RC=<#10.out: >, <#19.RS: >, is_div=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpSplit(#20, results=(<#20.results[0]: >, ' - '<#20.results[1]: >), src=<#18.dest: >)', - 'OpSetCA(#21, <#21.out: CAType()>, value=False)', - 'OpBigIntAddSub(#22, <#22.out: >, ' - 'lhs=<#19.RT: >, rhs=<#20.results[1]: >, ' - 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, ' - 'vl=<#9.out: KnownVLType(length=3)>)', - 'OpBigIntAddSub(#23, <#23.out: >, ' - 'lhs=<#19.RS: >, rhs=<#10.out: >, ' - 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, ' - 'is_sub=False, vl=None)', - 'OpConcat(#24, <#24.dest: >, sources=(' - '<#20.results[0]: >, <#22.out: >, ' - '<#23.out: >))', - 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)', - 'OpStore(#26, RS=<#24.dest: >, ' - 'RA=<#4.dest: >, offset=0, ' - 'mem_in=<#0.out: GlobalMemType()>, ' - '<#26.mem_out: GlobalMemType()>, ' - 'vl=<#25.out: KnownVLType(length=6)>)' - ]) - - # FIXME: register allocator currently allocates wrong registers - @unittest.expectedFailure - def test_simple_mul_192x192_reg_alloc(self): - code = SimpleMul192x192() - fn = code.fn - assigned_registers = allocate_registers(fn.ops) - self.assertEqual(assigned_registers, { - fn.ops[13].RS: GPRRange(9), # type: ignore - fn.ops[14].results[0]: GPRRange(6), # type: ignore - fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore - fn.ops[15].out: XERBit.CA, # type: ignore - fn.ops[16].out: GPRRange(7, length=3), # type: ignore - fn.ops[16].CA_out: XERBit.CA, # type: ignore - fn.ops[17].out: GPRRange(10), # type: ignore - fn.ops[17].CA_out: XERBit.CA, # type: ignore - fn.ops[18].dest: GPRRange(6, length=5), # type: ignore - fn.ops[19].RT: GPRRange(3, length=3), # type: ignore - fn.ops[19].RS: GPRRange(9), # type: ignore - fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore - fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore - fn.ops[21].out: XERBit.CA, # type: ignore - fn.ops[22].out: GPRRange(8, length=3), # type: ignore - fn.ops[22].CA_out: XERBit.CA, # type: ignore - fn.ops[23].out: GPRRange(11), # type: ignore - fn.ops[23].CA_out: XERBit.CA, # type: ignore - fn.ops[24].dest: GPRRange(6, length=6), # type: ignore - fn.ops[25].out: VL.VL_MAXVL, # type: ignore - fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore - fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore - fn.ops[1].out: GPRRange(3), # type: ignore - fn.ops[2].out: GPRRange(4, length=3), # type: ignore - fn.ops[3].out: GPRRange(7, length=3), # type: ignore - fn.ops[4].dest: GPRRange(12), # type: ignore - fn.ops[5].out: VL.VL_MAXVL, # type: ignore - fn.ops[6].dest: GPRRange(17, length=3), # type: ignore - fn.ops[7].dest: GPRRange(14, length=3), # type: ignore - fn.ops[8].results[0]: GPRRange(14), # type: ignore - fn.ops[8].results[1]: GPRRange(15), # type: ignore - fn.ops[8].results[2]: GPRRange(16), # type: ignore - fn.ops[9].out: VL.VL_MAXVL, # type: ignore - fn.ops[10].out: GPRRange(9), # type: ignore - fn.ops[11].RT: GPRRange(6, length=3), # type: ignore - fn.ops[11].RS: GPRRange(9), # type: ignore - fn.ops[12].dest: GPRRange(6, length=4), # type: ignore - fn.ops[13].RT: GPRRange(3, length=3) # type: ignore - }) - self.fail("register allocator currently allocates wrong registers") - - # FIXME: register allocator currently allocates wrong registers - @unittest.expectedFailure - def test_simple_mul_192x192_asm(self): - code = SimpleMul192x192() - asm = generate_assembly(code.fn.ops) - self.assertEqual(asm, [ - 'or 12, 3, 3', - 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *17, *4, *4', - 'sv.or *14, *7, *7', - 'setvl 0, 0, 3, 0, 1, 1', - 'addi 9, 0, 0', - 'sv.maddedu *6, *17, 14, 9', - 'sv.maddedu *3, *17, 15, 9', - 'addic 0, 0, 0', - 'sv.adde *7, *3, *7', - 'adde 10, 9, 9', - 'sv.maddedu *3, *17, 16, 9', - 'addic 0, 0, 0', - 'sv.adde *8, *3, *8', - 'adde 11, 9, 9', - 'setvl 0, 0, 6, 0, 1, 1', - 'sv.std *6, 0(12)', - 'bclr 20, 0, 0' - ]) - self.fail("register allocator currently allocates wrong registers") - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 9e3ec74..246f654 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -4,13 +4,16 @@ Toom-Cook multiplication algorithm generator for SVP64 from abc import abstractmethod from enum import Enum from fractions import Fraction -from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union +from typing import Any, Generic, Iterable, Mapping, TypeVar, Union from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange +from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub, + OpBigIntMulDiv, OpConcat, + OpLI, OpSetCA, OpSetVLImm, + OpSplit, SSAGPRRange) from bigint_presentation_code.matrix import Matrix -from bigint_presentation_code.util import Literal, OSet, final +from bigint_presentation_code.type_util import Literal, final @final @@ -158,8 +161,8 @@ class EvalOpPoly: return f"EvalOpPoly({self.coefficients})" -_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp") -_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp") +_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]") +_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]") @plain_data(frozen=True, unsafe_hash=True) @@ -238,7 +241,7 @@ class EvalOpInput(EvalOp[int, Literal[0]]): __slots__ = () def __init__(self, lhs, rhs=0): - # type: (...) -> None + # type: (int, int) -> None if lhs < 0: raise ValueError("Input part_index (lhs) must be >= 0") if rhs != 0: diff --git a/src/bigint_presentation_code/type_util.py b/src/bigint_presentation_code/type_util.py new file mode 100644 index 0000000..ed7296f --- /dev/null +++ b/src/bigint_presentation_code/type_util.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING, Any, NoReturn, Union + +if TYPE_CHECKING: + from typing_extensions import Literal, Self, final +else: + def final(v): + return v + + class _Literal: + def __getitem__(self, v): + if isinstance(v, tuple): + return Union[tuple(type(i) for i in v)] + return type(v) + + Literal = _Literal() + + Self = Any + + +# pyright currently doesn't like typing_extensions' definition +# -- added to typing in python 3.11 +def assert_never(arg): + # type: (NoReturn) -> NoReturn + raise AssertionError("got to code that's supposed to be unreachable") + + +__all__ = [ + "assert_never", + "final", + "Literal", + "Self", +] diff --git a/src/bigint_presentation_code/type_util.pyi b/src/bigint_presentation_code/type_util.pyi new file mode 100644 index 0000000..630ca20 --- /dev/null +++ b/src/bigint_presentation_code/type_util.pyi @@ -0,0 +1,19 @@ +from typing import NoReturn, TypeVar + +from typing_extensions import Literal, Self, final + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") + + +# pyright currently doesn't like typing_extensions' definition +# -- added to typing in python 3.11 +def assert_never(arg: NoReturn) -> NoReturn: ... + + +__all__ = [ + "assert_never", + "final", + "Literal", + "Self", +] diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index aeea240..4b39787 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -1,50 +1,25 @@ from abc import abstractmethod -from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator, - Mapping, MutableSet, NoReturn, TypeVar, Union) +from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, + TypeVar, overload) -if TYPE_CHECKING: - from typing_extensions import Literal, Self, final -else: - def final(v): - return v - - class _Literal: - def __getitem__(self, v): - if isinstance(v, tuple): - return Union[tuple(type(i) for i in v)] - return type(v) - - Literal = _Literal() - - Self = Any +from bigint_presentation_code.type_util import Self, final _T_co = TypeVar("_T_co", covariant=True) _T = TypeVar("_T") __all__ = [ - "assert_never", "BaseBitSet", "bit_count", "BitSet", "FBitSet", - "final", "FMap", - "Literal", "OFSet", "OSet", - "Self", "top_set_bit_index", "trailing_zero_count", ] -# pyright currently doesn't like typing_extensions' definition -# -- added to typing in python 3.11 -def assert_never(arg): - # type: (NoReturn) -> NoReturn - raise AssertionError("got to code that's supposed to be unreachable") - - class OFSet(AbstractSet[_T_co]): """ ordered frozen set """ __slots__ = "__items", @@ -54,18 +29,23 @@ class OFSet(AbstractSet[_T_co]): self.__items = {v: None for v in items} def __contains__(self, x): + # type: (Any) -> bool return x in self.__items def __iter__(self): + # type: () -> Iterator[_T_co] return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def __hash__(self): + # type: () -> int return self._hash() def __repr__(self): + # type: () -> str if len(self) == 0: return "OFSet()" return f"OFSet({list(self)})" @@ -80,12 +60,15 @@ class OSet(MutableSet[_T]): self.__items = {v: None for v in items} def __contains__(self, x): + # type: (Any) -> bool return x in self.__items def __iter__(self): + # type: () -> Iterator[_T] return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def add(self, value): @@ -97,6 +80,7 @@ class OSet(MutableSet[_T]): self.__items.pop(value, None) def __repr__(self): + # type: () -> str if len(self) == 0: return "OSet()" return f"OSet({list(self)})" @@ -106,6 +90,21 @@ class FMap(Mapping[_T, _T_co]): """ordered frozen hashable mapping""" __slots__ = "__items", "__hash" + @overload + def __init__(self, items): + # type: (Mapping[_T, _T_co]) -> None + ... + + @overload + def __init__(self, items): + # type: (Iterable[tuple[_T, _T_co]]) -> None + ... + + @overload + def __init__(self): + # type: () -> None + ... + def __init__(self, items=()): # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None self.__items = dict(items) # type: dict[_T, _T_co] @@ -120,20 +119,23 @@ class FMap(Mapping[_T, _T_co]): return iter(self.__items) def __len__(self): + # type: () -> int return len(self.__items) def __eq__(self, other): - # type: (object) -> bool + # type: (FMap[Any, Any] | Any) -> bool if isinstance(other, FMap): return self.__items == other.__items return super().__eq__(other) def __hash__(self): + # type: () -> int if self.__hash is None: self.__hash = hash(frozenset(self.items())) return self.__hash def __repr__(self): + # type: () -> str return f"FMap({self.__items})" @@ -153,7 +155,7 @@ def top_set_bit_index(v, default=-1): try: # added in cpython 3.10 - bit_count = int.bit_count # type: ignore[attr] + bit_count = int.bit_count # type: ignore except AttributeError: def bit_count(v): # type: (int) -> int @@ -177,16 +179,20 @@ class BaseBitSet(AbstractSet[int]): def __init__(self, items=(), bits=0): # type: (Iterable[int], int) -> None - for item in items: - if item < 0: - raise ValueError("can't store negative integers") - bits |= 1 << item + if isinstance(items, BaseBitSet): + bits |= items.bits + else: + for item in items: + if item < 0: + raise ValueError("can't store negative integers") + bits |= 1 << item if bits < 0: raise ValueError("can't store an infinite set") self.__bits = bits @property def bits(self): + # type: () -> int return self.__bits @bits.setter @@ -199,6 +205,7 @@ class BaseBitSet(AbstractSet[int]): self.__bits = bits def __contains__(self, x): + # type: (Any) -> bool if isinstance(x, int) and x >= 0: return (1 << x) & self.bits != 0 return False @@ -220,9 +227,11 @@ class BaseBitSet(AbstractSet[int]): bits -= 1 << index def __len__(self): + # type: () -> int return bit_count(self.bits) def __repr__(self): + # type: () -> str if self.bits == 0: return f"{self.__class__.__name__}()" if self.bits > 0xFFFFFFFF and len(self) < 10: @@ -231,7 +240,7 @@ class BaseBitSet(AbstractSet[int]): return f"{self.__class__.__name__}(bits={hex(self.bits)})" def __eq__(self, other): - # type: (object) -> bool + # type: (Any) -> bool if not isinstance(other, BaseBitSet): return super().__eq__(other) return self.bits == other.bits @@ -320,6 +329,7 @@ class BitSet(BaseBitSet, MutableSet[int]): self.bits &= ~(1 << value) def clear(self): + # type: () -> None self.bits = 0 def __ior__(self, it): @@ -361,4 +371,5 @@ class FBitSet(BaseBitSet): return True def __hash__(self): + # type: () -> int return super()._hash() diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi deleted file mode 100644 index 6315823..0000000 --- a/src/bigint_presentation_code/util.pyi +++ /dev/null @@ -1,190 +0,0 @@ -from abc import abstractmethod -from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet, - NoReturn, TypeVar, overload) - -from typing_extensions import Literal, Self, final - -_T_co = TypeVar("_T_co", covariant=True) -_T = TypeVar("_T") - -__all__ = [ - "assert_never", - "BaseBitSet", - "bit_count", - "BitSet", - "FBitSet", - "final", - "FMap", - "Literal", - "OFSet", - "OSet", - "Self", - "top_set_bit_index", - "trailing_zero_count", -] - - -# pyright currently doesn't like typing_extensions' definition -# -- added to typing in python 3.11 -def assert_never(arg): - # type: (NoReturn) -> NoReturn - raise AssertionError("got to code that's supposed to be unreachable") - - -class OFSet(AbstractSet[_T_co]): - """ ordered frozen set """ - - def __init__(self, items: Iterable[_T_co] = ()): - ... - - def __contains__(self, x: object) -> bool: - ... - - def __iter__(self) -> Iterator[_T_co]: - ... - - def __len__(self) -> int: - ... - - def __hash__(self) -> int: - ... - - def __repr__(self) -> str: - ... - - -class OSet(MutableSet[_T]): - """ ordered mutable set """ - - def __init__(self, items: Iterable[_T] = ()): - ... - - def __contains__(self, x: object) -> bool: - ... - - def __iter__(self) -> Iterator[_T]: - ... - - def __len__(self) -> int: - ... - - def add(self, value: _T) -> None: - ... - - def discard(self, value: _T) -> None: - ... - - def __repr__(self) -> str: - ... - - -class FMap(Mapping[_T, _T_co]): - """ordered frozen hashable mapping""" - @overload - def __init__(self, items: Mapping[_T, _T_co]): ... - @overload - def __init__(self, items: Iterable[tuple[_T, _T_co]]): ... - @overload - def __init__(self): ... - - def __getitem__(self, item: _T) -> _T_co: - ... - - def __iter__(self) -> Iterator[_T]: - ... - - def __len__(self) -> int: - ... - - def __eq__(self, other: object) -> bool: - ... - - def __hash__(self) -> int: - ... - - def __repr__(self) -> str: - ... - - -def trailing_zero_count(v: int, default: int = -1) -> int: ... -def top_set_bit_index(v: int, default: int = -1) -> int: ... -def bit_count(v: int) -> int: ... - - -class BaseBitSet(AbstractSet[int]): - @classmethod - @abstractmethod - def _frozen(cls) -> bool: ... - - @classmethod - def _from_bits(cls, bits: int) -> Self: ... - - def __init__(self, items: Iterable[int] = (), bits: int = 0): ... - - @property - def bits(self) -> int: - ... - - @bits.setter - def bits(self, bits: int) -> None: ... - - def __contains__(self, x: object) -> bool: ... - - def __iter__(self) -> Iterator[int]: ... - - def __reversed__(self) -> Iterator[int]: ... - - def __len__(self) -> int: ... - - def __repr__(self) -> str: ... - - def __eq__(self, other: object) -> bool: ... - - def __and__(self, other: Iterable[Any]) -> Self: ... - - __rand__ = __and__ - - def __or__(self, other: Iterable[Any]) -> Self: ... - - __ror__ = __or__ - - def __xor__(self, other: Iterable[Any]) -> Self: ... - - __rxor__ = __xor__ - - def __sub__(self, other: Iterable[Any]) -> Self: ... - - def __rsub__(self, other: Iterable[Any]) -> Self: ... - - def isdisjoint(self, other: Iterable[Any]) -> bool: ... - - -class BitSet(BaseBitSet, MutableSet[int]): - @final - @classmethod - def _frozen(cls) -> Literal[False]: ... - - def add(self, value: int) -> None: ... - - def discard(self, value: int) -> None: ... - - def clear(self) -> None: ... - - def __ior__(self, it: AbstractSet[Any]) -> Self: ... - - def __iand__(self, it: AbstractSet[Any]) -> Self: ... - - def __ixor__(self, it: AbstractSet[Any]) -> Self: ... - - def __isub__(self, it: AbstractSet[Any]) -> Self: ... - - -class FBitSet(BaseBitSet): - @property - def bits(self) -> int: ... - - @final - @classmethod - def _frozen(cls) -> Literal[True]: ... - - def __hash__(self) -> int: ... diff --git a/typings/cached_property.pyi b/typings/cached_property.pyi index b8b1f30..5ec7085 100644 --- a/typings/cached_property.pyi +++ b/typings/cached_property.pyi @@ -1,15 +1 @@ -from typing import Any, Callable, Generic, TypeVar, overload - -_T = TypeVar("_T") - - -class cached_property(Generic[_T]): - def __init__(self, func: Callable[[Any], _T]) -> None: ... - - @overload - def __get__(self, instance: None, - owner: type[Any] | None = ...) -> cached_property[_T]: ... - - @overload - def __get__(self, instance: object, - owner: type[Any] | None = ...) -> _T: ... +cached_property = property