--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+ GlobalMem, GPRRange, GPRType,
+ OpBigIntAddSub, OpConcat,
+ OpCopy, OpFuncArg,
+ OpInputMem, OpLI, OpLoad,
+ OpSetCA, OpSetVLImm, OpStore,
+ RegLoc, SSAVal, XERBit,
+ generate_assembly,
+ op_set_to_list)
+
+
+class TestCompilerIR(unittest.TestCase):
+ maxDiff = None
+
+ def test_op_set_to_list(self):
+ fn = Fn()
+ op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op1 = OpCopy(fn, op0.out, GPRType())
+ arg = op1.dest
+ op2 = OpInputMem(fn)
+ mem = op2.out
+ op3 = OpSetVLImm(fn, 32)
+ vl = op3.out
+ op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+ a = op4.RT
+ op5 = OpLI(fn, 1)
+ b_0 = op5.out
+ op6 = OpSetVLImm(fn, 31)
+ vl = op6.out
+ op7 = OpLI(fn, 0, vl=vl)
+ b_rest = op7.out
+ op8 = OpConcat(fn, [b_0, b_rest])
+ b = op8.dest
+ op9 = OpSetVLImm(fn, 32)
+ vl = op9.out
+ op10 = OpSetCA(fn, False)
+ ca = op10.out
+ op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+ s = op11.out
+ op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+ mem = op12.mem_out
+
+ expected_ops = [
+ op10, # OpSetCA(fn, False)
+ op9, # OpSetVLImm(fn, 32)
+ op6, # OpSetVLImm(fn, 31)
+ op5, # OpLI(fn, 1)
+ op3, # OpSetVLImm(fn, 32)
+ op2, # OpInputMem(fn)
+ op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op7, # OpLI(fn, 0, vl=vl)
+ op1, # OpCopy(fn, op0.out, GPRType())
+ op8, # OpConcat(fn, [b_0, b_rest])
+ op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+ op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+ op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+ ]
+
+ ops = op_set_to_list(fn.ops[::-1])
+ if ops != expected_ops:
+ self.assertEqual(repr(ops), repr(expected_ops))
+
+ def tst_generate_assembly(self, use_reg_alloc=False):
+ fn = Fn()
+ op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op1 = OpCopy(fn, op0.out, GPRType())
+ arg = op1.dest
+ op2 = OpInputMem(fn)
+ mem = op2.out
+ op3 = OpSetVLImm(fn, 32)
+ vl = op3.out
+ op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+ a = op4.RT
+ op5 = OpLI(fn, 0, vl=vl)
+ b = op5.out
+ op6 = OpSetCA(fn, True)
+ ca = op6.out
+ op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+ s = op7.out
+ op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+ mem = op8.mem_out
+
+ assigned_registers = {
+ op0.out: GPRRange(start=3, length=1),
+ op1.dest: GPRRange(start=3, length=1),
+ op2.out: GlobalMem.GlobalMem,
+ op3.out: VL.VL_MAXVL,
+ op4.RT: GPRRange(start=78, length=32),
+ op5.out: GPRRange(start=46, length=32),
+ op6.out: XERBit.CA,
+ op7.out: GPRRange(start=14, length=32),
+ op7.CA_out: XERBit.CA,
+ op8.mem_out: GlobalMem.GlobalMem,
+ } # type: dict[SSAVal, RegLoc] | None
+
+ if use_reg_alloc:
+ assigned_registers = None
+
+ asm = generate_assembly(fn.ops, assigned_registers)
+ self.assertEqual(asm, [
+ "setvl 0, 0, 32, 0, 1, 1",
+ "sv.ld *78, 0(3)",
+ "sv.addi *46, 0, 0",
+ "subfic 0, 0, -1",
+ "sv.adde *14, *78, *46",
+ "sv.std *14, 0(3)",
+ "bclr 20, 0, 0",
+ ])
+
+ def test_generate_assembly(self):
+ self.tst_generate_assembly()
+
+ def test_generate_assembly_with_register_allocator(self):
+ self.tst_generate_assembly(use_reg_alloc=True)
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+import unittest
+from fractions import Fraction
+
+from bigint_presentation_code.matrix import Matrix, SpecialMatrix
+
+
+class TestMatrix(unittest.TestCase):
+ def test_repr(self):
+ self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
+ 3, 4, 5])),
+ 'Matrix(height=2, width=3, data=[\n'
+ ' 0, 1, 2,\n'
+ ' 3, 4, 5,\n'
+ '])')
+ self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
+ 3, 4, 5])),
+ 'Matrix(height=2, width=3, data=[\n'
+ ' 0, 1, Fraction(2, 3),\n'
+ ' 3, 4, 5,\n'
+ '])')
+ self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
+ self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
+
+ def test_eq(self):
+ self.assertFalse(Matrix(1, 1) == 5)
+ self.assertFalse(5 == Matrix(1, 1))
+ self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
+ self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
+ self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
+ self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
+ self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
+
+ def test_add(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+ + Matrix(2, 2, [40, 30, 20, 10]),
+ Matrix(2, 2, [41, 32, 23, 14]))
+
+ def test_identity(self):
+ self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity),
+ Matrix(2, 2, [1, 0,
+ 0, 1]))
+ self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity),
+ Matrix(1, 3, [1, 0, 0]))
+ self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity),
+ Matrix(2, 3, [1, 0, 0,
+ 0, 1, 0]))
+ self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity),
+ Matrix(3, 3, [1, 0, 0,
+ 0, 1, 0,
+ 0, 0, 1]))
+
+ def test_sub(self):
+ self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
+ - Matrix(2, 2, [-1, -2, -3, -4]),
+ Matrix(2, 2, [41, 32, 23, 14]))
+
+ def test_neg(self):
+ self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
+ Matrix(2, 2, [-40, -30, -20, -10]))
+
+ def test_mul(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
+ Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+ self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
+ Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+
+ def test_matmul(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+ @ Matrix(2, 2, [4, 3, 2, 1]),
+ Matrix(2, 2, [8, 5, 20, 13]))
+ self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
+ @ Matrix(2, 1, [1, 2]),
+ Matrix(3, 1, [16, 10, 4]))
+
+ def test_inverse(self):
+ self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
+ self.assertEqual(Matrix(1, 1, [2]).inverse(),
+ Matrix(1, 1, [Fraction(1, 2)]))
+ self.assertEqual(Matrix(1, 1, [1]).inverse(),
+ Matrix(1, 1, [1]))
+ self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
+ Matrix(2, 2, [1, 0, -1, 1]))
+ self.assertEqual(Matrix(3, 3, [0, 1, 0,
+ 1, 0, 0,
+ 0, 0, 1]).inverse(),
+ Matrix(3, 3, [0, 1, 0,
+ 1, 0, 0,
+ 0, 0, 1]))
+ _1_2 = Fraction(1, 2)
+ _1_3 = Fraction(1, 3)
+ _1_6 = Fraction(1, 6)
+ self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1,
+ 1, -1, 1, -1, 1,
+ 1, -2, 4, -8, 16,
+ 0, 0, 0, 0, 1]).inverse(),
+ Matrix(5, 5, [1, 0, 0, 0, 0,
+ _1_2, _1_3, -1, _1_6, -2,
+ -1, _1_2, _1_2, 0, -1,
+ -_1_2, _1_6, _1_2, -_1_6, 2,
+ 0, 0, 0, 0, 1]))
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(1, 1, [0]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [0, 0, 1, 1]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [1, 0, 1, 0]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [1, 1, 1, 1]).inverse()
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+ GlobalMem, GPRRange, GPRType,
+ OpBigIntAddSub, OpConcat,
+ OpCopy, OpFuncArg,
+ OpInputMem, OpLI, OpLoad,
+ OpSetCA, OpSetVLImm, OpStore,
+ XERBit)
+from bigint_presentation_code.register_allocator import (
+ AllocationFailed, MergedRegSet, allocate_registers,
+ try_allocate_registers_without_spilling)
+
+
+class TestMergedRegSet(unittest.TestCase):
+ maxDiff = None
+
+ def test_from_equality_constraint(self):
+ fn = Fn()
+ li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
+ li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
+ li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
+ self.assertEqual(MergedRegSet.from_equality_constraint([
+ li0x1.out,
+ li0x2.out,
+ li0x3.out,
+ ]), MergedRegSet({
+ li0x1.out: 0,
+ li0x2.out: 1,
+ li0x3.out: 3,
+ }.items()))
+ self.assertEqual(MergedRegSet.from_equality_constraint([
+ li0x2.out,
+ li0x1.out,
+ li0x3.out,
+ ]), MergedRegSet({
+ li0x2.out: 0,
+ li0x1.out: 2,
+ li0x3.out: 3,
+ }.items()))
+
+
+class TestRegisterAllocator(unittest.TestCase):
+ maxDiff = None
+
+ def test_try_alloc_fail(self):
+ fn = Fn()
+ op0 = OpSetVLImm(fn, 52)
+ op1 = OpLI(fn, 0, vl=op0.out)
+ op2 = OpSetVLImm(fn, 64)
+ op3 = OpLI(fn, 0, vl=op2.out)
+ op4 = OpConcat(fn, [op1.out, op3.out])
+
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+ self.assertEqual(
+ repr(reg_assignments),
+ "AllocationFailed("
+ "node=IGNode(#0, merged_reg_set=MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "edges={}, reg=None), "
+ "live_intervals=LiveIntervals(live_intervals={"
+ "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
+ "LiveInterval(first_write=0, last_use=1), "
+ "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]): "
+ "LiveInterval(first_write=1, last_use=4), "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
+ "LiveInterval(first_write=2, last_use=3)}, "
+ "merged_reg_sets=MergedRegSets(data={"
+ "<#0.out: KnownVLType(length=52)>: "
+ "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
+ "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "<#2.out: KnownVLType(length=64)>: "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
+ "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])}), "
+ "reg_sets_live_after={"
+ "0: OFSet([MergedRegSet(["
+ "(<#0.out: KnownVLType(length=52)>, 0)])]), "
+ "1: OFSet([MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])]), "
+ "2: OFSet([MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), "
+ "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
+ "3: OFSet([MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)])]), "
+ "4: OFSet()}), "
+ "interference_graph=InterferenceGraph(nodes={"
+ "...: IGNode(#0, merged_reg_set=MergedRegSet(["
+ "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
+ "...: IGNode(#1, merged_reg_set=MergedRegSet(["
+ "(<#4.dest: <gpr_ty[116]>>, 0), "
+ "(<#1.out: <gpr_ty[52]>>, 0), "
+ "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
+ "...: IGNode(#2, merged_reg_set=MergedRegSet(["
+ "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
+ )
+
+ def test_try_alloc_bigint_inc(self):
+ fn = Fn()
+ op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op1 = OpCopy(fn, op0.out, GPRType())
+ arg = op1.dest
+ op2 = OpInputMem(fn)
+ mem = op2.out
+ op3 = OpSetVLImm(fn, 32)
+ vl = op3.out
+ op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+ a = op4.RT
+ op5 = OpLI(fn, 0, vl=vl)
+ b = op5.out
+ op6 = OpSetCA(fn, True)
+ ca = op6.out
+ op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+ s = op7.out
+ op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+ mem = op8.mem_out
+
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+
+ expected_reg_assignments = {
+ op0.out: GPRRange(start=3, length=1),
+ op1.dest: GPRRange(start=3, length=1),
+ op2.out: GlobalMem.GlobalMem,
+ op3.out: VL.VL_MAXVL,
+ op4.RT: GPRRange(start=78, length=32),
+ op5.out: GPRRange(start=46, length=32),
+ op6.out: XERBit.CA,
+ op7.out: GPRRange(start=14, length=32),
+ op7.CA_out: XERBit.CA,
+ op8.mem_out: GlobalMem.GlobalMem,
+ }
+
+ self.assertEqual(reg_assignments, expected_reg_assignments)
+
+ def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
+ # type: (list[GPRRange], GPRRange) -> None
+ fn = Fn()
+ inputs = []
+ expected_reg_assignments = {}
+ for i, r in enumerate(expected_regs):
+ vl = OpSetVLImm(fn, r.length).out
+ expected_reg_assignments[vl] = VL.VL_MAXVL
+ inp = OpLI(fn, i, vl=vl).out
+ inputs.append(inp)
+ expected_reg_assignments[inp] = r
+ concat = OpConcat(fn, inputs)
+ expected_reg_assignments[concat.dest] = expected_dest_reg
+
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
+
+ for inp, reg in zip(inputs, expected_regs):
+ expected_reg_assignments[inp] = reg
+
+ self.assertEqual(reg_assignments, expected_reg_assignments)
+
+ def test_try_alloc_concat_1(self):
+ self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
+
+ def test_try_alloc_concat_3(self):
+ self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
+
+ def test_try_alloc_concat_3_5(self):
+ self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
+ GPRRange(3, 8))
+
+ def test_try_alloc_concat_5_3(self):
+ self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
+ GPRRange(3, 8))
+
+ def test_try_alloc_concat_1_2_3_4_5_6(self):
+ self.tst_try_alloc_concat([
+ GPRRange(14, 1),
+ GPRRange(15, 2),
+ GPRRange(17, 3),
+ GPRRange(20, 4),
+ GPRRange(24, 5),
+ GPRRange(29, 6),
+ ], GPRRange(14, 21))
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
+ GlobalMem, GPRRange,
+ GPRRangeType, OpCopy,
+ OpFuncArg, OpInputMem,
+ OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
+ generate_assembly)
+from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.util import FMap
+
+
+class SimpleMul192x192:
+ def __init__(self):
+ self.fn = fn = Fn()
+ self.mem_in = mem = OpInputMem(fn).out
+ self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
+ self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
+ self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
+ dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
+ vl = OpSetVLImm(fn, 3).out
+ lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
+ rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
+ retval = simple_mul(fn, lhs, rhs)
+ vl = OpSetVLImm(fn, 6).out
+ self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
+ mem_in=mem, vl=vl).mem_out
+
+
+class TestToomCook(unittest.TestCase):
+ maxDiff = None
+
+ def test_toom_2_repr(self):
+ TOOM_2 = ToomCookInstance.make_toom_2()
+ # print(repr(repr(TOOM_2)))
+ self.assertEqual(
+ repr(TOOM_2),
+ "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
+ "eval_points=(0, 1, POINT_AT_INFINITY), "
+ "lhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+ " rhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+ " prod_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpSub(lhs="
+ "EvalOpSub(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({"
+ "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
+ )
+
+ def test_toom_2_5_repr(self):
+ TOOM_2_5 = ToomCookInstance.make_toom_2_5()
+ # print(repr(repr(TOOM_2_5)))
+ self.assertEqual(
+ repr(TOOM_2_5),
+ "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
+ "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=2, rhs=0, "
+ "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=1, rhs=0, "
+ "poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({"
+ "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
+ "EvalOpSub(lhs="
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=2, rhs=0, "
+ "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=1, rhs=0, "
+ "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
+ "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=2, rhs=0, "
+ "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
+ "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpSub(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=1, rhs=0, "
+ "poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, "
+ "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
+ "prod_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs=EvalOpInput(lhs=2, rhs=0, "
+ "poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+ "rhs=2, "
+ "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
+ "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
+ "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
+ "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
+ "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
+ )
+
+ def test_reversed_toom_2_5_repr(self):
+ TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
+ # print(repr(repr(TOOM_2_5)))
+ self.assertEqual(
+ repr(TOOM_2_5),
+ "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
+ "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpSub(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+ " rhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs=EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
+ "EvalOpSub(lhs=EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
+ " prod_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+ "rhs=2, "
+ "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
+ "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
+ "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
+ "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "poly=EvalOpPoly("
+ "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
+ "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
+ )
+
+ def test_simple_mul_192x192_pre_ra_sim(self):
+ # test multiplying:
+ # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
+ # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
+ # ==
+ # int("0x00074736574206e_6f69746163696c70"
+ # "_69746c756d207469_622d3438333e2d32"
+ # "_3931783239312079_7261727469627261", base=0)
+ # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
+ # 'little')
+ code = SimpleMul192x192()
+ dest_ptr = 0x100
+ state = PreRASimState(
+ gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
+ stack_slots={}, fixed_gprs={
+ code.dest_ptr_in: (dest_ptr,),
+ code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
+ 0x000191acb262e15b),
+ code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
+ 0x4a37c0567bcbab53)
+ })
+ code.fn.pre_ra_sim(state)
+ expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
+ OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
+ expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
+ mem_out = state.global_mems[code.mem_out]
+ out_bytes = bytes(
+ mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
+ self.assertEqual(out_bytes, expected_bytes)
+
+ def test_simple_mul_192x192_ops(self):
+ code = SimpleMul192x192()
+ fn = code.fn
+ self.assertEqual([repr(v) for v in fn.ops], [
+ 'OpInputMem(#0, <#0.out: GlobalMemType()>)',
+ 'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
+ 'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
+ 'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
+ 'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
+ 'vl=None)',
+ 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
+ 'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
+ 'src=<#2.out: <fixed(<r4..len=3>)>>, '
+ 'vl=<#5.out: KnownVLType(length=3)>)',
+ 'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
+ 'src=<#3.out: <fixed(<r7..len=3>)>>, '
+ 'vl=<#5.out: KnownVLType(length=3)>)',
+ 'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
+ '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
+ 'src=<#7.dest: <gpr_ty[3]>>)',
+ 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
+ 'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
+ 'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
+ '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
+ 'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
+ '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
+ 'OpSetCA(#15, <#15.out: CAType()>, value=False)',
+ 'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
+ 'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
+ 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
+ 'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+ 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
+ 'is_sub=False, vl=None)',
+ 'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
+ '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
+ '<#17.out: <gpr_ty[1]>>))',
+ 'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
+ 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
+ 'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
+ '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
+ 'OpSetCA(#21, <#21.out: CAType()>, value=False)',
+ 'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
+ 'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
+ 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
+ 'vl=<#9.out: KnownVLType(length=3)>)',
+ 'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
+ 'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
+ 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
+ 'is_sub=False, vl=None)',
+ 'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
+ '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
+ '<#23.out: <gpr_ty[1]>>))',
+ 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
+ 'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
+ 'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
+ 'mem_in=<#0.out: GlobalMemType()>, '
+ '<#26.mem_out: GlobalMemType()>, '
+ 'vl=<#25.out: KnownVLType(length=6)>)'
+ ])
+
+ # FIXME: register allocator currently allocates wrong registers
+ @unittest.expectedFailure
+ def test_simple_mul_192x192_reg_alloc(self):
+ code = SimpleMul192x192()
+ fn = code.fn
+ assigned_registers = allocate_registers(fn.ops)
+ self.assertEqual(assigned_registers, {
+ fn.ops[13].RS: GPRRange(9), # type: ignore
+ fn.ops[14].results[0]: GPRRange(6), # type: ignore
+ fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore
+ fn.ops[15].out: XERBit.CA, # type: ignore
+ fn.ops[16].out: GPRRange(7, length=3), # type: ignore
+ fn.ops[16].CA_out: XERBit.CA, # type: ignore
+ fn.ops[17].out: GPRRange(10), # type: ignore
+ fn.ops[17].CA_out: XERBit.CA, # type: ignore
+ fn.ops[18].dest: GPRRange(6, length=5), # type: ignore
+ fn.ops[19].RT: GPRRange(3, length=3), # type: ignore
+ fn.ops[19].RS: GPRRange(9), # type: ignore
+ fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore
+ fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore
+ fn.ops[21].out: XERBit.CA, # type: ignore
+ fn.ops[22].out: GPRRange(8, length=3), # type: ignore
+ fn.ops[22].CA_out: XERBit.CA, # type: ignore
+ fn.ops[23].out: GPRRange(11), # type: ignore
+ fn.ops[23].CA_out: XERBit.CA, # type: ignore
+ fn.ops[24].dest: GPRRange(6, length=6), # type: ignore
+ fn.ops[25].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore
+ fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore
+ fn.ops[1].out: GPRRange(3), # type: ignore
+ fn.ops[2].out: GPRRange(4, length=3), # type: ignore
+ fn.ops[3].out: GPRRange(7, length=3), # type: ignore
+ fn.ops[4].dest: GPRRange(12), # type: ignore
+ fn.ops[5].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[6].dest: GPRRange(17, length=3), # type: ignore
+ fn.ops[7].dest: GPRRange(14, length=3), # type: ignore
+ fn.ops[8].results[0]: GPRRange(14), # type: ignore
+ fn.ops[8].results[1]: GPRRange(15), # type: ignore
+ fn.ops[8].results[2]: GPRRange(16), # type: ignore
+ fn.ops[9].out: VL.VL_MAXVL, # type: ignore
+ fn.ops[10].out: GPRRange(9), # type: ignore
+ fn.ops[11].RT: GPRRange(6, length=3), # type: ignore
+ fn.ops[11].RS: GPRRange(9), # type: ignore
+ fn.ops[12].dest: GPRRange(6, length=4), # type: ignore
+ fn.ops[13].RT: GPRRange(3, length=3) # type: ignore
+ })
+ self.fail("register allocator currently allocates wrong registers")
+
+ # FIXME: register allocator currently allocates wrong registers
+ @unittest.expectedFailure
+ def test_simple_mul_192x192_asm(self):
+ code = SimpleMul192x192()
+ asm = generate_assembly(code.fn.ops)
+ self.assertEqual(asm, [
+ 'or 12, 3, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *17, *4, *4',
+ 'sv.or *14, *7, *7',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'addi 9, 0, 0',
+ 'sv.maddedu *6, *17, 14, 9',
+ 'sv.maddedu *3, *17, 15, 9',
+ 'addic 0, 0, 0',
+ 'sv.adde *7, *3, *7',
+ 'adde 10, 9, 9',
+ 'sv.maddedu *3, *17, 16, 9',
+ 'addic 0, 0, 0',
+ 'sv.adde *8, *3, *8',
+ 'adde 11, 9, 9',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'sv.std *6, 0(12)',
+ 'bclr 20, 0, 0'
+ ])
+ self.fail("register allocator currently allocates wrong registers")
+
+
+if __name__ == "__main__":
+ unittest.main()
+# type: ignore
"""
Compiler IR for Toom-Cook algorithm generator for SVP64
from nmutil.plain_data import fields, plain_data
-from bigint_presentation_code.util import FMap, OFSet, OSet, final
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, OFSet, OSet
class ABCEnumMeta(EnumMeta, ABCMeta):
+from collections import defaultdict
import enum
from enum import Enum, unique
-from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
from nmutil.plain_data import plain_data
-from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
-from weakref import WeakValueDictionary
+from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
+ OSet, FMap)
+from functools import lru_cache
@final
class Fn:
def __init__(self):
self.ops = [] # type: list[Op]
- op_names = WeakValueDictionary()
- self.__op_names = op_names # type: WeakValueDictionary[str, Op]
+ self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op]
self.__next_name_suffix = 2
def _add_op_with_unused_name(self, op, name=""):
self.__next_name_suffix += 1
def __repr__(self):
+ # type: () -> str
return "<Fn>"
@unique
@final
-class RegKind(Enum):
- GPR = enum.auto()
+class BaseTy(Enum):
+ I64 = enum.auto()
CA = enum.auto()
VL_MAXVL = enum.auto()
@cached_property
def only_scalar(self):
- if self is RegKind.GPR:
+ # type: () -> bool
+ if self is BaseTy.I64:
return False
- elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
return True
else:
assert_never(self)
@cached_property
- def reg_count(self):
- if self is RegKind.GPR:
+ def max_reg_len(self):
+ # type: () -> int
+ if self is BaseTy.I64:
return 128
- elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
return 1
else:
assert_never(self)
def __repr__(self):
- return "RegKind." + self._name_
+ return "BaseTy." + self._name_
@plain_data(frozen=True, unsafe_hash=True)
@final
-class OperandType:
- __slots__ = "kind", "vec"
+class Ty:
+ __slots__ = "base_ty", "reg_len"
- def __init__(self, kind, vec):
- # type: (RegKind, bool) -> None
- self.kind = kind
- if kind.only_scalar and vec:
- raise ValueError(f"kind={kind} must have vec=False")
- self.vec = vec
-
- def get_length(self, maxvl):
- # type: (int) -> int
- # here's where subvl and elwid would be accounted for
- if self.vec:
- return maxvl
- return 1
+ @staticmethod
+ def validate(base_ty, reg_len):
+ # type: (BaseTy, int) -> str | None
+ """ return a string with the error if the combination is invalid,
+ otherwise return None
+ """
+ if base_ty.only_scalar and reg_len != 1:
+ return f"can't create a vector of an only-scalar type: {base_ty}"
+ if reg_len < 1 or reg_len > base_ty.max_reg_len:
+ return "reg_len out of range"
+ return None
+
+ def __init__(self, base_ty, reg_len):
+ # type: (BaseTy, int) -> None
+ msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.base_ty = base_ty
+ self.reg_len = reg_len
-@plain_data(frozen=True, unsafe_hash=True)
+@unique
@final
-class RegShape:
- __slots__ = "kind", "length"
+class LocKind(Enum):
+ GPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
- def __init__(self, kind, length=1):
- # type: (RegKind, int) -> None
- self.kind = kind
- if length < 1 or length > kind.reg_count:
- raise ValueError("invalid length")
- self.length = length
+ @cached_property
+ def base_ty(self):
+ # type: () -> BaseTy
+ if self is LocKind.GPR or self is LocKind.StackI64:
+ return BaseTy.I64
+ if self is LocKind.CA:
+ return BaseTy.CA
+ if self is LocKind.VL_MAXVL:
+ return BaseTy.VL_MAXVL
+ else:
+ assert_never(self)
- def try_concat(self, *others):
- # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
- kind = self.kind
- length = self.length
- for other in others:
- if isinstance(other, (Reg, RegClass)):
- other = other.shape
- if other is None:
- return None
- if other.kind != self.kind:
- return None
- length += other.length
- if length > kind.reg_count:
- return None
- return RegShape(kind=kind, length=length)
+ @cached_property
+ def loc_count(self):
+ # type: () -> int
+ if self is LocKind.StackI64:
+ return 1024
+ if self is LocKind.GPR or self is LocKind.CA \
+ or self is LocKind.VL_MAXVL:
+ return self.base_ty.max_reg_len
+ else:
+ assert_never(self)
+
+ def __repr__(self):
+ return "LocKind." + self._name_
-@plain_data(frozen=True, unsafe_hash=True)
@final
-class Reg:
- __slots__ = "shape", "start"
-
- def __init__(self, shape, start):
- # type: (RegShape, int) -> None
- self.shape = shape
- if start < 0 or start + shape.length > shape.kind.reg_count:
- raise ValueError("start not in valid range")
- self.start = start
+@unique
+class LocSubKind(Enum):
+ BASE_GPR = enum.auto()
+ SV_EXTRA2_VGPR = enum.auto()
+ SV_EXTRA2_SGPR = enum.auto()
+ SV_EXTRA3_VGPR = enum.auto()
+ SV_EXTRA3_SGPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
- @property
+ @cached_property
def kind(self):
- return self.shape.kind
+ # type: () -> LocKind
+ # pyright fails typechecking when using `in` here:
+ # reported: https://github.com/microsoft/pyright/issues/4102
+ if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
+ or self is LocSubKind.SV_EXTRA2_SGPR \
+ or self is LocSubKind.SV_EXTRA3_VGPR \
+ or self is LocSubKind.SV_EXTRA3_SGPR:
+ return LocKind.GPR
+ if self is LocSubKind.StackI64:
+ return LocKind.StackI64
+ if self is LocSubKind.CA:
+ return LocKind.CA
+ if self is LocSubKind.VL_MAXVL:
+ return LocKind.VL_MAXVL
+ assert_never(self)
@property
- def length(self):
- return self.shape.length
+ def base_ty(self):
+ return self.kind.base_ty
+
+ @lru_cache()
+ def allocatable_locs(self, ty):
+ # type: (Ty) -> LocSet
+ if ty.base_ty != self.base_ty:
+ raise ValueError("type mismatch")
+ raise NotImplementedError # FIXME: finish
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericTy:
+ __slots__ = "base_ty", "is_vec"
+
+ def __init__(self, base_ty, is_vec):
+ # type: (BaseTy, bool) -> None
+ self.base_ty = base_ty
+ if base_ty.only_scalar and is_vec:
+ raise ValueError(f"base_ty={base_ty} requires is_vec=False")
+ self.is_vec = is_vec
+
+ def instantiate(self, maxvl):
+ # type: (int) -> Ty
+ # here's where subvl and elwid would be accounted for
+ if self.is_vec:
+ return Ty(self.base_ty, maxvl)
+ return Ty(self.base_ty, 1)
+
+ def can_instantiate_to(self, ty):
+ # type: (Ty) -> bool
+ if self.base_ty != ty.base_ty:
+ return False
+ if self.is_vec:
+ return True
+ return ty.reg_len == 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Loc:
+ __slots__ = "kind", "start", "reg_len"
+
+ @staticmethod
+ def validate(kind, start, reg_len):
+ # type: (LocKind, int, int) -> str | None
+ msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
+ if msg is not None:
+ return msg
+ if reg_len > kind.loc_count:
+ return "invalid reg_len"
+ if start < 0 or start + reg_len > kind.loc_count:
+ return "start not in valid range"
+ return None
+
+ @staticmethod
+ def try_make(kind, start, reg_len):
+ # type: (LocKind, int, int) -> Loc | None
+ msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is None:
+ return None
+ return Loc(kind=kind, start=start, reg_len=reg_len)
+
+ def __init__(self, kind, start, reg_len):
+ # type: (LocKind, int, int) -> None
+ msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.kind = kind
+ self.reg_len = reg_len
+ self.start = start
def conflicts(self, other):
- # type: (Reg) -> bool
- return (self.kind == other.kind
+ # type: (Loc) -> bool
+ return (self.kind != other.kind
and self.start < other.stop and other.start < self.stop)
+ @staticmethod
+ def make_ty(kind, reg_len):
+ # type: (LocKind, int) -> Ty
+ return Ty(base_ty=kind.base_ty, reg_len=reg_len)
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ return self.make_ty(kind=self.kind, reg_len=self.reg_len)
+
@property
def stop(self):
- return self.start + self.length
+ # type: () -> int
+ return self.start + self.reg_len
def try_concat(self, *others):
- # type: (*Reg | None) -> Reg | None
- shape = self.shape.try_concat(*others)
- if shape is None:
- return None
+ # type: (*Loc | None) -> Loc | None
+ reg_len = self.reg_len
stop = self.stop
for other in others:
- assert other is not None, "already caught by RegShape.try_concat"
+ if other is None or other.kind != self.kind:
+ return None
if stop != other.start:
return None
stop = other.stop
- return Reg(shape, self.start)
+ reg_len += other.reg_len
+ return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
+@plain_data(frozen=True, eq=False, repr=False)
@final
-class RegClass(AbstractSet[Reg]):
- def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
- # type: (Iterable[Reg | int], RegShape | None, int) -> None
- for reg_or_start in regs_or_starts:
- if isinstance(reg_or_start, Reg):
- if shape is None:
- shape = reg_or_start.shape
- elif shape != reg_or_start.shape:
- raise ValueError(f"conflicting RegShapes: {shape} and "
- f"{reg_or_start.shape}")
- start = reg_or_start.start
- else:
- start = reg_or_start
- if start < 0:
- raise ValueError("a Reg's start is out of range")
- starts_bitset |= 1 << start
- if starts_bitset == 0:
- shape = None
- self.__shape = shape
- self.__starts_bitset = starts_bitset
- if shape is None:
- if starts_bitset != 0:
- raise ValueError("non-empty RegClass must have non-None shape")
+class LocSet(AbstractSet[Loc]):
+ __slots__ = "starts", "ty"
+
+ def __init__(self, __locs=()):
+ # type: (Iterable[Loc]) -> None
+ if isinstance(__locs, LocSet):
+ self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
+ self.ty = __locs.ty # type: Ty | None
return
- if self.stops_bitset >= 1 << shape.kind.reg_count:
- raise ValueError("a Reg's start is out of range")
-
- @property
- def shape(self):
- # type: () -> RegShape | None
- return self.__shape
-
- @property
- def starts_bitset(self):
- # type: () -> int
- return self.__starts_bitset
-
- @property
- def stops_bitset(self):
- # type: () -> int
- if self.__shape is None:
- return 0
- return self.__starts_bitset << self.__shape.length
-
- @cached_property
- def starts(self):
- # type: () -> OFSet[int]
- if self.length is None:
- return OFSet()
- # TODO: fixme
- # return OFSet(for i in range(self.length))
+ starts = {i: BitSet() for i in LocKind}
+ ty = None
+ for loc in __locs:
+ if ty is None:
+ ty = loc.ty
+ if ty != loc.ty:
+ raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+ starts[loc.kind].add(loc.start)
+ self.starts = FMap(
+ (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
+ self.ty = ty
@cached_property
def stops(self):
- # type: () -> OFSet[int]
- if self.__shape is None:
- return OFSet()
- return OFSet(i + self.__shape.length for i in self.__starts)
+ # type: () -> FMap[LocKind, FBitSet]
+ if self.ty is None:
+ return FMap()
+ sh = self.ty.reg_len
+ return FMap(
+ (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
@property
- def kind(self):
- if self.__shape is None:
+ def kinds(self):
+ # type: () -> AbstractSet[LocKind]
+ return self.starts.keys()
+
+ @property
+ def reg_len(self):
+ # type: () -> int | None
+ if self.ty is None:
return None
- return self.__shape.kind
+ return self.ty.reg_len
@property
- def length(self):
- """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
- if self.__shape is None:
+ def base_ty(self):
+ # type: () -> BaseTy | None
+ if self.ty is None:
return None
- return self.__shape.length
+ return self.ty.base_ty
def concat(self, *others):
- # type: (*RegClass) -> RegClass
- shape = self.__shape
- if shape is None:
- return RegClass()
- shape = shape.try_concat(*others)
- if shape is None:
- return RegClass()
- starts = OSet(self.starts)
- offset = shape.length
+ # type: (*LocSet) -> LocSet
+ if self.ty is None:
+ return LocSet()
+ base_ty = self.ty.base_ty
+ reg_len = self.ty.reg_len
+ starts = {k: BitSet(v) for k, v in self.starts.items()}
for other in others:
- assert other.__shape is not None, \
- "already caught by RegShape.try_concat"
- starts &= OSet(i - offset for i in other.starts)
- offset += other.__shape.length
- return RegClass(starts, shape=shape)
-
- def __contains__(self, reg):
- # type: (Reg) -> bool
- return reg.shape == self.shape and reg.start in self.starts
+ if other.ty is None:
+ return LocSet()
+ if other.ty.base_ty != base_ty:
+ return LocSet()
+ for kind, other_starts in other.starts.items():
+ if kind not in starts:
+ continue
+ starts[kind].bits &= other_starts.bits >> reg_len
+ if starts[kind] == 0:
+ del starts[kind]
+ if len(starts) == 0:
+ return LocSet()
+ reg_len += other.ty.reg_len
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ for kind, v in starts.items():
+ for start in v:
+ loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
+ if loc is not None:
+ yield loc
+ return LocSet(locs())
+
+ def __contains__(self, loc):
+ # type: (Loc | Any) -> bool
+ if not isinstance(loc, Loc) or loc.ty == self.ty:
+ return False
+ if loc.kind not in self.starts:
+ return False
+ return loc.start in self.starts[loc.kind]
def __iter__(self):
- # type: () -> Iterator[Reg]
- if self.shape is None:
+ # type: () -> Iterator[Loc]
+ if self.ty is None:
return
- for start in self.starts:
- yield Reg(shape=self.shape, start=start)
+ for kind, starts in self.starts.items():
+ for start in starts:
+ yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
+
+ @cached_property
+ def __len(self):
+ return sum((len(v) for v in self.starts.values()), 0)
def __len__(self):
- return len(self.starts)
+ return self.__len
- def __hash__(self):
+ @cached_property
+ def __hash(self):
return super()._hash()
+ def __hash__(self):
+ return self.__hash
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class Operand:
- __slots__ = "ty", "regs"
-
- def __init__(self, ty, regs=None):
- # type: (OperandType, OFSet[int] | None) -> None
- pass
-
-
-OT_VGPR = OperandType(RegKind.GPR, vec=True)
-OT_SGPR = OperandType(RegKind.GPR, vec=False)
-OT_CA = OperandType(RegKind.CA, vec=False)
-OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+class GenericOperandDesc:
+ """generic Op operand descriptor"""
+ __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index"
+
+ def __init__(self, ty, sub_kinds, fixed_loc=None, tied_input_index=None):
+ # type: (GenericTy, Iterable[LocSubKind], Loc | None, int | None) -> None
+ self.ty = ty
+ self.sub_kinds = OFSet(sub_kinds)
+ if len(self.sub_kinds) == 0:
+ raise ValueError("sub_kinds can't be empty")
+ self.fixed_loc = fixed_loc
+ if fixed_loc is not None:
+ if tied_input_index is not None:
+ raise ValueError("operand can't be both tied and fixed")
+ if not ty.can_instantiate_to(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc has incompatible type for given generic "
+ f"type: fixed_loc={fixed_loc} generic ty={ty}")
+ if len(self.sub_kinds) != 1:
+ raise ValueError(
+ "multiple sub_kinds not allowed for fixed operand")
+ for sub_kind in self.sub_kinds:
+ if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc not in given sub_kind: "
+ f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
+ for sub_kind in self.sub_kinds:
+ if sub_kind.base_ty != ty.base_ty:
+ raise ValueError(f"sub_kind is incompatible with type: "
+ f"sub_kind={sub_kind} ty={ty}")
+ if tied_input_index is not None and tied_input_index < 0:
+ raise ValueError("invalid tied_input_index")
+ self.tied_input_index = tied_input_index
+
+ def tied_to_input(self, tied_input_index):
+ # type: (int) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds,
+ tied_input_index=tied_input_index)
+
+ def with_fixed_loc(self, fixed_loc):
+ # type: (Loc) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+
+ def instantiate(self, maxvl):
+ # type: (int) -> OperandDesc
+ ty = self.ty.instantiate(maxvl=maxvl)
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ if self.fixed_loc is not None:
+ if ty != self.fixed_loc.ty:
+ raise ValueError(
+ f"instantiation failed: type mismatch with fixed_loc: "
+ f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+ yield self.fixed_loc
+ return
+ for sub_kind in self.sub_kinds:
+ yield from sub_kind.allocatable_locs(ty)
+ return OperandDesc(loc_set=LocSet(locs()),
+ tied_input_index=self.tied_input_index)
@plain_data(frozen=True, unsafe_hash=True)
-class TiedOutput:
- __slots__ = "input_index", "output_index"
-
- def __init__(self, input_index, output_index):
- # type: (int, int) -> None
- self.input_index = input_index
- self.output_index = output_index
-
-
-Constraint = Union[TiedOutput, NoReturn]
+@final
+class OperandDesc:
+ """Op operand descriptor"""
+ __slots__ = "loc_set", "tied_input_index"
+
+ def __init__(self, loc_set, tied_input_index):
+ # type: (LocSet, int | None) -> None
+ if len(loc_set) == 0:
+ raise ValueError("loc_set must not be empty")
+ self.loc_set = loc_set
+ self.tied_input_index = tied_input_index
+
+
+OD_BASE_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.BASE_GPR])
+OD_EXTRA3_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
+OD_EXTRA3_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
+OD_EXTRA2_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
+OD_EXTRA2_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
+OD_CA = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
+ sub_kinds=[LocSubKind.CA])
+OD_VL = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
+ sub_kinds=[LocSubKind.VL_MAXVL])
@plain_data(frozen=True, unsafe_hash=True)
@final
-class OpProperties:
- __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+class GenericOpProperties:
+ __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
"is_copy", "is_load_immediate", "has_side_effects")
def __init__(self, demo_asm, # type: str
- inputs, # type: Iterable[OperandType]
- outputs, # type: Iterable[OperandType]
+ inputs, # type: Iterable[GenericOperandDesc]
+ outputs, # type: Iterable[GenericOperandDesc]
immediates, # type: Iterable[range]
- constraints, # type: Iterable[Constraint]
is_copy=False, # type: bool
is_load_immediate=False, # type: bool
has_side_effects=False, # type: bool
self.inputs = tuple(inputs)
self.outputs = tuple(outputs)
self.immediates = tuple(immediates)
- self.constraints = tuple(constraints)
self.is_copy = is_copy
self.is_load_immediate = is_load_immediate
self.has_side_effects = has_side_effects
+ def instantiate(self, maxvl):
+ # type: (int) -> OpProperties
+ raise NotImplementedError # FIXME: finish
+
+
+# FIXME: add OpProperties
@unique
@final
class OpKind(Enum):
def __init__(self, properties):
- # type: (OpProperties) -> None
+ # type: (GenericOpProperties) -> None
super().__init__()
self.properties = properties
self.sliced_op_outputs = tuple(processed)
def __add__(self, other):
- # type: (SSAVal) -> SSAVal
+ # type: (SSAVal | Any) -> SSAVal
if not isinstance(other, SSAVal):
return NotImplemented
return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
def __radd__(self, other):
- # type: (SSAVal) -> SSAVal
+ # type: (SSAVal | Any) -> SSAVal
if isinstance(other, SSAVal):
return other.__add__(self)
return NotImplemented
@cached_property
def expanded_sliced_op_outputs(self):
# type: () -> tuple[tuple[Op, int, int], ...]
- retval = []
+ retval = [] # type: list[tuple[Op, int, int]]
for op, output_index, range_ in self.sliced_op_outputs:
for i in range_:
retval.append((op, output_index, i))
# type: () -> str
if len(self.sliced_op_outputs) == 0:
return "SSAVal([])"
- parts = []
+ parts = [] # type: list[str]
for op, output_index, range_ in self.sliced_op_outputs:
out_len = op.properties.outputs[output_index].get_length(op.maxvl)
parts.append(f"<{op.name}#{output_index}>")
self.maxvl = maxvl
outputs_len = len(self.properties.outputs)
self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
- self.name = fn._add_op_with_unused_name(self, name)
+ self.name = fn._add_op_with_unused_name(self, name) # type: ignore
@property
def properties(self):
return self.kind.properties
def __eq__(self, other):
+ # type: (Op | Any) -> bool
if isinstance(other, Op):
return self is other
return NotImplemented
-import operator
from enum import Enum, unique
from fractions import Fraction
-from numbers import Rational
+import operator
from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar
-from bigint_presentation_code.util import final
+from bigint_presentation_code.type_util import final
_T = TypeVar("_T")
_T2 = TypeVar("_T2")
return retval
def __truediv__(self, rhs):
- # type: (Rational | int) -> Matrix
+ # type: (_T | int) -> Matrix[_T]
retval = self.copy()
for i in self.indexes():
retval[i] /= rhs # type: ignore
return lhs.__matmul__(self)
def __elementwise_bin_op(self, rhs, op):
- # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
+ # type: (Matrix[_T], Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
if self.height != rhs.height or self.width != rhs.width:
raise ValueError(
"matrix dimensions must match for element-wise operations")
# type: () -> str
if self.height == 0 or self.width == 0:
return f"Matrix(height={self.height}, width={self.width})"
- lines = []
- line = []
+ lines = [] # type: list[str]
+ line = [] # type: list[str]
for row in range(self.height):
line.clear()
for col in range(self.width):
else:
line.append(repr(el))
lines.append(", ".join(line))
- lines = ",\n ".join(lines)
+ lines_str = ",\n ".join(lines)
element_type = ""
if self.element_type is not Fraction:
element_type = f"element_type={self.element_type}, "
return (f"Matrix(height={self.height}, width={self.width}, "
f"{element_type}data=[\n"
- f" {lines},\n])")
+ f" {lines_str},\n])")
def __eq__(self, rhs):
- # type: (object) -> bool
+ # type: (Matrix[Any] | Any) -> bool
if not isinstance(rhs, Matrix):
return NotImplemented
return (self.height == rhs.height
from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
RegLoc, RegType, SSAVal)
-from bigint_presentation_code.util import OFSet, OSet, final
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import OFSet, OSet
_RegType = TypeVar("_RegType", bound=RegType)
+++ /dev/null
-import unittest
-
-from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
- GlobalMem, GPRRange, GPRType,
- OpBigIntAddSub, OpConcat,
- OpCopy, OpFuncArg,
- OpInputMem, OpLI, OpLoad,
- OpSetCA, OpSetVLImm, OpStore,
- RegLoc, SSAVal, XERBit,
- generate_assembly,
- op_set_to_list)
-
-
-class TestCompilerIR(unittest.TestCase):
- maxDiff = None
-
- def test_op_set_to_list(self):
- fn = Fn()
- op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
- op1 = OpCopy(fn, op0.out, GPRType())
- arg = op1.dest
- op2 = OpInputMem(fn)
- mem = op2.out
- op3 = OpSetVLImm(fn, 32)
- vl = op3.out
- op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
- a = op4.RT
- op5 = OpLI(fn, 1)
- b_0 = op5.out
- op6 = OpSetVLImm(fn, 31)
- vl = op6.out
- op7 = OpLI(fn, 0, vl=vl)
- b_rest = op7.out
- op8 = OpConcat(fn, [b_0, b_rest])
- b = op8.dest
- op9 = OpSetVLImm(fn, 32)
- vl = op9.out
- op10 = OpSetCA(fn, False)
- ca = op10.out
- op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
- s = op11.out
- op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
- mem = op12.mem_out
-
- expected_ops = [
- op10, # OpSetCA(fn, False)
- op9, # OpSetVLImm(fn, 32)
- op6, # OpSetVLImm(fn, 31)
- op5, # OpLI(fn, 1)
- op3, # OpSetVLImm(fn, 32)
- op2, # OpInputMem(fn)
- op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
- op7, # OpLI(fn, 0, vl=vl)
- op1, # OpCopy(fn, op0.out, GPRType())
- op8, # OpConcat(fn, [b_0, b_rest])
- op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
- op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
- op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
- ]
-
- ops = op_set_to_list(fn.ops[::-1])
- if ops != expected_ops:
- self.assertEqual(repr(ops), repr(expected_ops))
-
- def tst_generate_assembly(self, use_reg_alloc=False):
- fn = Fn()
- op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
- op1 = OpCopy(fn, op0.out, GPRType())
- arg = op1.dest
- op2 = OpInputMem(fn)
- mem = op2.out
- op3 = OpSetVLImm(fn, 32)
- vl = op3.out
- op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
- a = op4.RT
- op5 = OpLI(fn, 0, vl=vl)
- b = op5.out
- op6 = OpSetCA(fn, True)
- ca = op6.out
- op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
- s = op7.out
- op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
- mem = op8.mem_out
-
- assigned_registers = {
- op0.out: GPRRange(start=3, length=1),
- op1.dest: GPRRange(start=3, length=1),
- op2.out: GlobalMem.GlobalMem,
- op3.out: VL.VL_MAXVL,
- op4.RT: GPRRange(start=78, length=32),
- op5.out: GPRRange(start=46, length=32),
- op6.out: XERBit.CA,
- op7.out: GPRRange(start=14, length=32),
- op7.CA_out: XERBit.CA,
- op8.mem_out: GlobalMem.GlobalMem,
- } # type: dict[SSAVal, RegLoc] | None
-
- if use_reg_alloc:
- assigned_registers = None
-
- asm = generate_assembly(fn.ops, assigned_registers)
- self.assertEqual(asm, [
- "setvl 0, 0, 32, 0, 1, 1",
- "sv.ld *78, 0(3)",
- "sv.addi *46, 0, 0",
- "subfic 0, 0, -1",
- "sv.adde *14, *78, *46",
- "sv.std *14, 0(3)",
- "bclr 20, 0, 0",
- ])
-
- def test_generate_assembly(self):
- self.tst_generate_assembly()
-
- def test_generate_assembly_with_register_allocator(self):
- self.tst_generate_assembly(use_reg_alloc=True)
-
-
-if __name__ == "__main__":
- unittest.main()
+++ /dev/null
-import unittest
-from fractions import Fraction
-
-from bigint_presentation_code.matrix import Matrix, SpecialMatrix
-
-
-class TestMatrix(unittest.TestCase):
- def test_repr(self):
- self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
- 3, 4, 5])),
- 'Matrix(height=2, width=3, data=[\n'
- ' 0, 1, 2,\n'
- ' 3, 4, 5,\n'
- '])')
- self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
- 3, 4, 5])),
- 'Matrix(height=2, width=3, data=[\n'
- ' 0, 1, Fraction(2, 3),\n'
- ' 3, 4, 5,\n'
- '])')
- self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
- self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
-
- def test_eq(self):
- self.assertFalse(Matrix(1, 1) == 5)
- self.assertFalse(5 == Matrix(1, 1))
- self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
- self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
- self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
- self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
- self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
-
- def test_add(self):
- self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
- + Matrix(2, 2, [40, 30, 20, 10]),
- Matrix(2, 2, [41, 32, 23, 14]))
-
- def test_identity(self):
- self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity),
- Matrix(2, 2, [1, 0,
- 0, 1]))
- self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity),
- Matrix(1, 3, [1, 0, 0]))
- self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity),
- Matrix(2, 3, [1, 0, 0,
- 0, 1, 0]))
- self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity),
- Matrix(3, 3, [1, 0, 0,
- 0, 1, 0,
- 0, 0, 1]))
-
- def test_sub(self):
- self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
- - Matrix(2, 2, [-1, -2, -3, -4]),
- Matrix(2, 2, [41, 32, 23, 14]))
-
- def test_neg(self):
- self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
- Matrix(2, 2, [-40, -30, -20, -10]))
-
- def test_mul(self):
- self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
- Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
- self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
- Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
-
- def test_matmul(self):
- self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
- @ Matrix(2, 2, [4, 3, 2, 1]),
- Matrix(2, 2, [8, 5, 20, 13]))
- self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
- @ Matrix(2, 1, [1, 2]),
- Matrix(3, 1, [16, 10, 4]))
-
- def test_inverse(self):
- self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
- self.assertEqual(Matrix(1, 1, [2]).inverse(),
- Matrix(1, 1, [Fraction(1, 2)]))
- self.assertEqual(Matrix(1, 1, [1]).inverse(),
- Matrix(1, 1, [1]))
- self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
- Matrix(2, 2, [1, 0, -1, 1]))
- self.assertEqual(Matrix(3, 3, [0, 1, 0,
- 1, 0, 0,
- 0, 0, 1]).inverse(),
- Matrix(3, 3, [0, 1, 0,
- 1, 0, 0,
- 0, 0, 1]))
- _1_2 = Fraction(1, 2)
- _1_3 = Fraction(1, 3)
- _1_6 = Fraction(1, 6)
- self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
- 1, 1, 1, 1, 1,
- 1, -1, 1, -1, 1,
- 1, -2, 4, -8, 16,
- 0, 0, 0, 0, 1]).inverse(),
- Matrix(5, 5, [1, 0, 0, 0, 0,
- _1_2, _1_3, -1, _1_6, -2,
- -1, _1_2, _1_2, 0, -1,
- -_1_2, _1_6, _1_2, -_1_6, 2,
- 0, 0, 0, 0, 1]))
- with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
- Matrix(1, 1, [0]).inverse()
- with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
- Matrix(2, 2, [0, 0, 1, 1]).inverse()
- with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
- Matrix(2, 2, [1, 0, 1, 0]).inverse()
- with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
- Matrix(2, 2, [1, 1, 1, 1]).inverse()
-
-
-if __name__ == "__main__":
- unittest.main()
+++ /dev/null
-import unittest
-
-from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
- GlobalMem, GPRRange, GPRType,
- OpBigIntAddSub, OpConcat,
- OpCopy, OpFuncArg,
- OpInputMem, OpLI, OpLoad,
- OpSetCA, OpSetVLImm, OpStore,
- XERBit)
-from bigint_presentation_code.register_allocator import (
- AllocationFailed, MergedRegSet, allocate_registers,
- try_allocate_registers_without_spilling)
-
-
-class TestMergedRegSet(unittest.TestCase):
- maxDiff = None
-
- def test_from_equality_constraint(self):
- fn = Fn()
- li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
- li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
- li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
- self.assertEqual(MergedRegSet.from_equality_constraint([
- li0x1.out,
- li0x2.out,
- li0x3.out,
- ]), MergedRegSet({
- li0x1.out: 0,
- li0x2.out: 1,
- li0x3.out: 3,
- }.items()))
- self.assertEqual(MergedRegSet.from_equality_constraint([
- li0x2.out,
- li0x1.out,
- li0x3.out,
- ]), MergedRegSet({
- li0x2.out: 0,
- li0x1.out: 2,
- li0x3.out: 3,
- }.items()))
-
-
-class TestRegisterAllocator(unittest.TestCase):
- maxDiff = None
-
- def test_try_alloc_fail(self):
- fn = Fn()
- op0 = OpSetVLImm(fn, 52)
- op1 = OpLI(fn, 0, vl=op0.out)
- op2 = OpSetVLImm(fn, 64)
- op3 = OpLI(fn, 0, vl=op2.out)
- op4 = OpConcat(fn, [op1.out, op3.out])
-
- reg_assignments = try_allocate_registers_without_spilling(fn.ops)
- self.assertEqual(
- repr(reg_assignments),
- "AllocationFailed("
- "node=IGNode(#0, merged_reg_set=MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]), "
- "edges={}, reg=None), "
- "live_intervals=LiveIntervals(live_intervals={"
- "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): "
- "LiveInterval(first_write=0, last_use=1), "
- "MergedRegSet([(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]): "
- "LiveInterval(first_write=1, last_use=4), "
- "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): "
- "LiveInterval(first_write=2, last_use=3)}, "
- "merged_reg_sets=MergedRegSets(data={"
- "<#0.out: KnownVLType(length=52)>: "
- "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), "
- "<#1.out: <gpr_ty[52]>>: MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]), "
- "<#2.out: KnownVLType(length=64)>: "
- "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), "
- "<#3.out: <gpr_ty[64]>>: MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]), "
- "<#4.dest: <gpr_ty[116]>>: MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)])}), "
- "reg_sets_live_after={"
- "0: OFSet([MergedRegSet(["
- "(<#0.out: KnownVLType(length=52)>, 0)])]), "
- "1: OFSet([MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)])]), "
- "2: OFSet([MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]), "
- "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), "
- "3: OFSet([MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)])]), "
- "4: OFSet()}), "
- "interference_graph=InterferenceGraph(nodes={"
- "...: IGNode(#0, merged_reg_set=MergedRegSet(["
- "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), "
- "...: IGNode(#1, merged_reg_set=MergedRegSet(["
- "(<#4.dest: <gpr_ty[116]>>, 0), "
- "(<#1.out: <gpr_ty[52]>>, 0), "
- "(<#3.out: <gpr_ty[64]>>, 52)]), edges={}, reg=None), "
- "...: IGNode(#2, merged_reg_set=MergedRegSet(["
- "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))"
- )
-
- def test_try_alloc_bigint_inc(self):
- fn = Fn()
- op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
- op1 = OpCopy(fn, op0.out, GPRType())
- arg = op1.dest
- op2 = OpInputMem(fn)
- mem = op2.out
- op3 = OpSetVLImm(fn, 32)
- vl = op3.out
- op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
- a = op4.RT
- op5 = OpLI(fn, 0, vl=vl)
- b = op5.out
- op6 = OpSetCA(fn, True)
- ca = op6.out
- op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
- s = op7.out
- op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
- mem = op8.mem_out
-
- reg_assignments = try_allocate_registers_without_spilling(fn.ops)
-
- expected_reg_assignments = {
- op0.out: GPRRange(start=3, length=1),
- op1.dest: GPRRange(start=3, length=1),
- op2.out: GlobalMem.GlobalMem,
- op3.out: VL.VL_MAXVL,
- op4.RT: GPRRange(start=78, length=32),
- op5.out: GPRRange(start=46, length=32),
- op6.out: XERBit.CA,
- op7.out: GPRRange(start=14, length=32),
- op7.CA_out: XERBit.CA,
- op8.mem_out: GlobalMem.GlobalMem,
- }
-
- self.assertEqual(reg_assignments, expected_reg_assignments)
-
- def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
- # type: (list[GPRRange], GPRRange) -> None
- fn = Fn()
- inputs = []
- expected_reg_assignments = {}
- for i, r in enumerate(expected_regs):
- vl = OpSetVLImm(fn, r.length).out
- expected_reg_assignments[vl] = VL.VL_MAXVL
- inp = OpLI(fn, i, vl=vl).out
- inputs.append(inp)
- expected_reg_assignments[inp] = r
- concat = OpConcat(fn, inputs)
- expected_reg_assignments[concat.dest] = expected_dest_reg
-
- reg_assignments = try_allocate_registers_without_spilling(fn.ops)
-
- for inp, reg in zip(inputs, expected_regs):
- expected_reg_assignments[inp] = reg
-
- self.assertEqual(reg_assignments, expected_reg_assignments)
-
- def test_try_alloc_concat_1(self):
- self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
-
- def test_try_alloc_concat_3(self):
- self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
-
- def test_try_alloc_concat_3_5(self):
- self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
- GPRRange(3, 8))
-
- def test_try_alloc_concat_5_3(self):
- self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
- GPRRange(3, 8))
-
- def test_try_alloc_concat_1_2_3_4_5_6(self):
- self.tst_try_alloc_concat([
- GPRRange(14, 1),
- GPRRange(15, 2),
- GPRRange(17, 3),
- GPRRange(20, 4),
- GPRRange(24, 5),
- GPRRange(29, 6),
- ], GPRRange(14, 21))
-
-
-if __name__ == "__main__":
- unittest.main()
+++ /dev/null
-import unittest
-
-from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, SSAGPR, VL, FixedGPRRangeType, Fn,
- GlobalMem, GPRRange,
- GPRRangeType, OpCopy,
- OpFuncArg, OpInputMem,
- OpSetVLImm, OpStore, PreRASimState, SSAGPRRange, XERBit,
- generate_assembly)
-from bigint_presentation_code.register_allocator import allocate_registers
-from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
-from bigint_presentation_code.util import FMap
-
-
-class SimpleMul192x192:
- def __init__(self):
- self.fn = fn = Fn()
- self.mem_in = mem = OpInputMem(fn).out
- self.dest_ptr_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))).out
- self.lhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(4, 3))).out
- self.rhs_in = OpFuncArg(fn, FixedGPRRangeType(GPRRange(7, 3))).out
- dest_ptr = OpCopy(fn, self.dest_ptr_in, GPRRangeType()).dest
- vl = OpSetVLImm(fn, 3).out
- lhs = OpCopy(fn, self.lhs_in, GPRRangeType(3), vl=vl).dest
- rhs = OpCopy(fn, self.rhs_in, GPRRangeType(3), vl=vl).dest
- retval = simple_mul(fn, lhs, rhs)
- vl = OpSetVLImm(fn, 6).out
- self.mem_out = OpStore(fn, RS=retval, RA=dest_ptr, offset=0,
- mem_in=mem, vl=vl).mem_out
-
-
-class TestToomCook(unittest.TestCase):
- maxDiff = None
-
- def test_toom_2_repr(self):
- TOOM_2 = ToomCookInstance.make_toom_2()
- # print(repr(repr(TOOM_2)))
- self.assertEqual(
- repr(TOOM_2),
- "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
- "eval_points=(0, 1, POINT_AT_INFINITY), "
- "lhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
- " rhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
- " prod_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpSub(lhs="
- "EvalOpSub(lhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({"
- "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
- )
-
- def test_toom_2_5_repr(self):
- TOOM_2_5 = ToomCookInstance.make_toom_2_5()
- # print(repr(repr(TOOM_2_5)))
- self.assertEqual(
- repr(TOOM_2_5),
- "ToomCookInstance(lhs_part_count=3, rhs_part_count=2, "
- "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs="
- "EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=2, rhs=0, "
- "poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=1, rhs=0, "
- "poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({"
- "0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
- "EvalOpSub(lhs="
- "EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=2, rhs=0, "
- "poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=1, rhs=0, "
- "poly=EvalOpPoly({1: Fraction(1, 1)})), poly=EvalOpPoly("
- "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
- "EvalOpInput(lhs=2, rhs=0, "
- "poly=EvalOpPoly({2: Fraction(1, 1)}))), rhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs=EvalOpInput(lhs=0, rhs=0, "
- "poly=EvalOpPoly({0: Fraction(1, 1)})), rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
- "EvalOpSub(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=1, rhs=0, "
- "poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
- "EvalOpInput(lhs=1, rhs=0, "
- "poly=EvalOpPoly({1: Fraction(1, 1)}))), "
- "prod_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "rhs=EvalOpInput(lhs=2, rhs=0, "
- "poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
- "rhs=2, "
- "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
- "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
- "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
- "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
- "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
- )
-
- def test_reversed_toom_2_5_repr(self):
- TOOM_2_5 = ToomCookInstance.make_toom_2_5().reversed()
- # print(repr(repr(TOOM_2_5)))
- self.assertEqual(
- repr(TOOM_2_5),
- "ToomCookInstance(lhs_part_count=2, rhs_part_count=3, "
- "eval_points=(0, 1, -1, POINT_AT_INFINITY), lhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
- "EvalOpSub(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(-1, 1)})), "
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
- " rhs_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpAdd(lhs=EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{0: Fraction(1, 1), 1: Fraction(1, 1), 2: Fraction(1, 1)})), "
- "EvalOpSub(lhs=EvalOpAdd(lhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({0: Fraction(1, 1), 2: Fraction(1, 1)})), rhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{0: Fraction(1, 1), 1: Fraction(-1, 1), 2: Fraction(1, 1)})), "
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))),"
- " prod_eval_ops=("
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpSub(lhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
- "rhs=2, "
- "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(-1, 2)})), rhs="
- "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{1: Fraction(1, 2), 2: Fraction(-1, 2), 3: Fraction(-1, 1)})), "
- "EvalOpSub(lhs=EvalOpExactDiv(lhs=EvalOpAdd(lhs="
- "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
- "rhs="
- "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
- "poly=EvalOpPoly({1: Fraction(1, 1), 2: Fraction(1, 1)})), rhs=2, "
- "poly=EvalOpPoly({1: Fraction(1, 2), 2: Fraction(1, 2)})), rhs="
- "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
- "poly=EvalOpPoly("
- "{0: Fraction(-1, 1), 1: Fraction(1, 2), 2: Fraction(1, 2)})), "
- "EvalOpInput(lhs=3, rhs=0, poly=EvalOpPoly({3: Fraction(1, 1)}))))"
- )
-
- def test_simple_mul_192x192_pre_ra_sim(self):
- # test multiplying:
- # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
- # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
- # ==
- # int("0x00074736574206e_6f69746163696c70"
- # "_69746c756d207469_622d3438333e2d32"
- # "_3931783239312079_7261727469627261", base=0)
- # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
- # 'little')
- code = SimpleMul192x192()
- dest_ptr = 0x100
- state = PreRASimState(
- gprs={}, VLs={}, CAs={}, global_mems={code.mem_in: FMap()},
- stack_slots={}, fixed_gprs={
- code.dest_ptr_in: (dest_ptr,),
- code.lhs_in: (0x821a2342132c5b57, 0x4c6b5f2b19e1a53e,
- 0x000191acb262e15b),
- code.rhs_in: (0x208a49071aeec507, 0xcf1f597598194ae6,
- 0x4a37c0567bcbab53)
- })
- code.fn.pre_ra_sim(state)
- expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
- OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
- expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
- mem_out = state.global_mems[code.mem_out]
- out_bytes = bytes(
- mem_out.get(dest_ptr + i, 0) for i in range(OUT_BYTE_COUNT))
- self.assertEqual(out_bytes, expected_bytes)
-
- def test_simple_mul_192x192_ops(self):
- code = SimpleMul192x192()
- fn = code.fn
- self.assertEqual([repr(v) for v in fn.ops], [
- 'OpInputMem(#0, <#0.out: GlobalMemType()>)',
- 'OpFuncArg(#1, <#1.out: <fixed(<r3>)>>)',
- 'OpFuncArg(#2, <#2.out: <fixed(<r4..len=3>)>>)',
- 'OpFuncArg(#3, <#3.out: <fixed(<r7..len=3>)>>)',
- 'OpCopy(#4, <#4.dest: <gpr_ty[1]>>, src=<#1.out: <fixed(<r3>)>>, '
- 'vl=None)',
- 'OpSetVLImm(#5, <#5.out: KnownVLType(length=3)>)',
- 'OpCopy(#6, <#6.dest: <gpr_ty[3]>>, '
- 'src=<#2.out: <fixed(<r4..len=3>)>>, '
- 'vl=<#5.out: KnownVLType(length=3)>)',
- 'OpCopy(#7, <#7.dest: <gpr_ty[3]>>, '
- 'src=<#3.out: <fixed(<r7..len=3>)>>, '
- 'vl=<#5.out: KnownVLType(length=3)>)',
- 'OpSplit(#8, results=(<#8.results[0]: <gpr_ty[1]>>, '
- '<#8.results[1]: <gpr_ty[1]>>, <#8.results[2]: <gpr_ty[1]>>), '
- 'src=<#7.dest: <gpr_ty[3]>>)',
- 'OpSetVLImm(#9, <#9.out: KnownVLType(length=3)>)',
- 'OpLI(#10, <#10.out: <gpr_ty[1]>>, value=0, vl=None)',
- 'OpBigIntMulDiv(#11, <#11.RT: <gpr_ty[3]>>, '
- 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[0]: <gpr_ty[1]>>, '
- 'RC=<#10.out: <gpr_ty[1]>>, <#11.RS: <gpr_ty[1]>>, is_div=False, '
- 'vl=<#9.out: KnownVLType(length=3)>)',
- 'OpConcat(#12, <#12.dest: <gpr_ty[4]>>, sources=('
- '<#11.RT: <gpr_ty[3]>>, <#11.RS: <gpr_ty[1]>>))',
- 'OpBigIntMulDiv(#13, <#13.RT: <gpr_ty[3]>>, '
- 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[1]: <gpr_ty[1]>>, '
- 'RC=<#10.out: <gpr_ty[1]>>, <#13.RS: <gpr_ty[1]>>, is_div=False, '
- 'vl=<#9.out: KnownVLType(length=3)>)',
- 'OpSplit(#14, results=(<#14.results[0]: <gpr_ty[1]>>, '
- '<#14.results[1]: <gpr_ty[3]>>), src=<#12.dest: <gpr_ty[4]>>)',
- 'OpSetCA(#15, <#15.out: CAType()>, value=False)',
- 'OpBigIntAddSub(#16, <#16.out: <gpr_ty[3]>>, '
- 'lhs=<#13.RT: <gpr_ty[3]>>, rhs=<#14.results[1]: <gpr_ty[3]>>, '
- 'CA_in=<#15.out: CAType()>, <#16.CA_out: CAType()>, is_sub=False, '
- 'vl=<#9.out: KnownVLType(length=3)>)',
- 'OpBigIntAddSub(#17, <#17.out: <gpr_ty[1]>>, '
- 'lhs=<#13.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
- 'CA_in=<#16.CA_out: CAType()>, <#17.CA_out: CAType()>, '
- 'is_sub=False, vl=None)',
- 'OpConcat(#18, <#18.dest: <gpr_ty[5]>>, sources=('
- '<#14.results[0]: <gpr_ty[1]>>, <#16.out: <gpr_ty[3]>>, '
- '<#17.out: <gpr_ty[1]>>))',
- 'OpBigIntMulDiv(#19, <#19.RT: <gpr_ty[3]>>, '
- 'RA=<#6.dest: <gpr_ty[3]>>, RB=<#8.results[2]: <gpr_ty[1]>>, '
- 'RC=<#10.out: <gpr_ty[1]>>, <#19.RS: <gpr_ty[1]>>, is_div=False, '
- 'vl=<#9.out: KnownVLType(length=3)>)',
- 'OpSplit(#20, results=(<#20.results[0]: <gpr_ty[2]>>, '
- '<#20.results[1]: <gpr_ty[3]>>), src=<#18.dest: <gpr_ty[5]>>)',
- 'OpSetCA(#21, <#21.out: CAType()>, value=False)',
- 'OpBigIntAddSub(#22, <#22.out: <gpr_ty[3]>>, '
- 'lhs=<#19.RT: <gpr_ty[3]>>, rhs=<#20.results[1]: <gpr_ty[3]>>, '
- 'CA_in=<#21.out: CAType()>, <#22.CA_out: CAType()>, is_sub=False, '
- 'vl=<#9.out: KnownVLType(length=3)>)',
- 'OpBigIntAddSub(#23, <#23.out: <gpr_ty[1]>>, '
- 'lhs=<#19.RS: <gpr_ty[1]>>, rhs=<#10.out: <gpr_ty[1]>>, '
- 'CA_in=<#22.CA_out: CAType()>, <#23.CA_out: CAType()>, '
- 'is_sub=False, vl=None)',
- 'OpConcat(#24, <#24.dest: <gpr_ty[6]>>, sources=('
- '<#20.results[0]: <gpr_ty[2]>>, <#22.out: <gpr_ty[3]>>, '
- '<#23.out: <gpr_ty[1]>>))',
- 'OpSetVLImm(#25, <#25.out: KnownVLType(length=6)>)',
- 'OpStore(#26, RS=<#24.dest: <gpr_ty[6]>>, '
- 'RA=<#4.dest: <gpr_ty[1]>>, offset=0, '
- 'mem_in=<#0.out: GlobalMemType()>, '
- '<#26.mem_out: GlobalMemType()>, '
- 'vl=<#25.out: KnownVLType(length=6)>)'
- ])
-
- # FIXME: register allocator currently allocates wrong registers
- @unittest.expectedFailure
- def test_simple_mul_192x192_reg_alloc(self):
- code = SimpleMul192x192()
- fn = code.fn
- assigned_registers = allocate_registers(fn.ops)
- self.assertEqual(assigned_registers, {
- fn.ops[13].RS: GPRRange(9), # type: ignore
- fn.ops[14].results[0]: GPRRange(6), # type: ignore
- fn.ops[14].results[1]: GPRRange(7, length=3), # type: ignore
- fn.ops[15].out: XERBit.CA, # type: ignore
- fn.ops[16].out: GPRRange(7, length=3), # type: ignore
- fn.ops[16].CA_out: XERBit.CA, # type: ignore
- fn.ops[17].out: GPRRange(10), # type: ignore
- fn.ops[17].CA_out: XERBit.CA, # type: ignore
- fn.ops[18].dest: GPRRange(6, length=5), # type: ignore
- fn.ops[19].RT: GPRRange(3, length=3), # type: ignore
- fn.ops[19].RS: GPRRange(9), # type: ignore
- fn.ops[20].results[0]: GPRRange(6, length=2), # type: ignore
- fn.ops[20].results[1]: GPRRange(8, length=3), # type: ignore
- fn.ops[21].out: XERBit.CA, # type: ignore
- fn.ops[22].out: GPRRange(8, length=3), # type: ignore
- fn.ops[22].CA_out: XERBit.CA, # type: ignore
- fn.ops[23].out: GPRRange(11), # type: ignore
- fn.ops[23].CA_out: XERBit.CA, # type: ignore
- fn.ops[24].dest: GPRRange(6, length=6), # type: ignore
- fn.ops[25].out: VL.VL_MAXVL, # type: ignore
- fn.ops[26].mem_out: GlobalMem.GlobalMem, # type: ignore
- fn.ops[0].out: GlobalMem.GlobalMem, # type: ignore
- fn.ops[1].out: GPRRange(3), # type: ignore
- fn.ops[2].out: GPRRange(4, length=3), # type: ignore
- fn.ops[3].out: GPRRange(7, length=3), # type: ignore
- fn.ops[4].dest: GPRRange(12), # type: ignore
- fn.ops[5].out: VL.VL_MAXVL, # type: ignore
- fn.ops[6].dest: GPRRange(17, length=3), # type: ignore
- fn.ops[7].dest: GPRRange(14, length=3), # type: ignore
- fn.ops[8].results[0]: GPRRange(14), # type: ignore
- fn.ops[8].results[1]: GPRRange(15), # type: ignore
- fn.ops[8].results[2]: GPRRange(16), # type: ignore
- fn.ops[9].out: VL.VL_MAXVL, # type: ignore
- fn.ops[10].out: GPRRange(9), # type: ignore
- fn.ops[11].RT: GPRRange(6, length=3), # type: ignore
- fn.ops[11].RS: GPRRange(9), # type: ignore
- fn.ops[12].dest: GPRRange(6, length=4), # type: ignore
- fn.ops[13].RT: GPRRange(3, length=3) # type: ignore
- })
- self.fail("register allocator currently allocates wrong registers")
-
- # FIXME: register allocator currently allocates wrong registers
- @unittest.expectedFailure
- def test_simple_mul_192x192_asm(self):
- code = SimpleMul192x192()
- asm = generate_assembly(code.fn.ops)
- self.assertEqual(asm, [
- 'or 12, 3, 3',
- 'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *17, *4, *4',
- 'sv.or *14, *7, *7',
- 'setvl 0, 0, 3, 0, 1, 1',
- 'addi 9, 0, 0',
- 'sv.maddedu *6, *17, 14, 9',
- 'sv.maddedu *3, *17, 15, 9',
- 'addic 0, 0, 0',
- 'sv.adde *7, *3, *7',
- 'adde 10, 9, 9',
- 'sv.maddedu *3, *17, 16, 9',
- 'addic 0, 0, 0',
- 'sv.adde *8, *3, *8',
- 'adde 11, 9, 9',
- 'setvl 0, 0, 6, 0, 1, 1',
- 'sv.std *6, 0(12)',
- 'bclr 20, 0, 0'
- ])
- self.fail("register allocator currently allocates wrong registers")
-
-
-if __name__ == "__main__":
- unittest.main()
from abc import abstractmethod
from enum import Enum
from fractions import Fraction
-from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
+from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir import Fn, Op, OpBigIntAddSub, OpBigIntMulDiv, OpConcat, OpLI, OpSetCA, OpSetVLImm, OpSplit, SSAGPRRange
+from bigint_presentation_code.compiler_ir import (Fn, OpBigIntAddSub,
+ OpBigIntMulDiv, OpConcat,
+ OpLI, OpSetCA, OpSetVLImm,
+ OpSplit, SSAGPRRange)
from bigint_presentation_code.matrix import Matrix
-from bigint_presentation_code.util import Literal, OSet, final
+from bigint_presentation_code.type_util import Literal, final
@final
return f"EvalOpPoly({self.coefficients})"
-_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
-_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")
+_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
+_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
@plain_data(frozen=True, unsafe_hash=True)
__slots__ = ()
def __init__(self, lhs, rhs=0):
- # type: (...) -> None
+ # type: (int, int) -> None
if lhs < 0:
raise ValueError("Input part_index (lhs) must be >= 0")
if rhs != 0:
--- /dev/null
+from typing import TYPE_CHECKING, Any, NoReturn, Union
+
+if TYPE_CHECKING:
+ from typing_extensions import Literal, Self, final
+else:
+ def final(v):
+ return v
+
+ class _Literal:
+ def __getitem__(self, v):
+ if isinstance(v, tuple):
+ return Union[tuple(type(i) for i in v)]
+ return type(v)
+
+ Literal = _Literal()
+
+ Self = Any
+
+
+# pyright currently doesn't like typing_extensions' definition
+# -- added to typing in python 3.11
+def assert_never(arg):
+ # type: (NoReturn) -> NoReturn
+ raise AssertionError("got to code that's supposed to be unreachable")
+
+
+__all__ = [
+ "assert_never",
+ "final",
+ "Literal",
+ "Self",
+]
--- /dev/null
+from typing import NoReturn, TypeVar
+
+from typing_extensions import Literal, Self, final
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+
+# pyright currently doesn't like typing_extensions' definition
+# -- added to typing in python 3.11
+def assert_never(arg: NoReturn) -> NoReturn: ...
+
+
+__all__ = [
+ "assert_never",
+ "final",
+ "Literal",
+ "Self",
+]
from abc import abstractmethod
-from typing import (TYPE_CHECKING, AbstractSet, Any, Iterable, Iterator,
- Mapping, MutableSet, NoReturn, TypeVar, Union)
+from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
+ TypeVar, overload)
-if TYPE_CHECKING:
- from typing_extensions import Literal, Self, final
-else:
- def final(v):
- return v
-
- class _Literal:
- def __getitem__(self, v):
- if isinstance(v, tuple):
- return Union[tuple(type(i) for i in v)]
- return type(v)
-
- Literal = _Literal()
-
- Self = Any
+from bigint_presentation_code.type_util import Self, final
_T_co = TypeVar("_T_co", covariant=True)
_T = TypeVar("_T")
__all__ = [
- "assert_never",
"BaseBitSet",
"bit_count",
"BitSet",
"FBitSet",
- "final",
"FMap",
- "Literal",
"OFSet",
"OSet",
- "Self",
"top_set_bit_index",
"trailing_zero_count",
]
-# pyright currently doesn't like typing_extensions' definition
-# -- added to typing in python 3.11
-def assert_never(arg):
- # type: (NoReturn) -> NoReturn
- raise AssertionError("got to code that's supposed to be unreachable")
-
-
class OFSet(AbstractSet[_T_co]):
""" ordered frozen set """
__slots__ = "__items",
self.__items = {v: None for v in items}
def __contains__(self, x):
+ # type: (Any) -> bool
return x in self.__items
def __iter__(self):
+ # type: () -> Iterator[_T_co]
return iter(self.__items)
def __len__(self):
+ # type: () -> int
return len(self.__items)
def __hash__(self):
+ # type: () -> int
return self._hash()
def __repr__(self):
+ # type: () -> str
if len(self) == 0:
return "OFSet()"
return f"OFSet({list(self)})"
self.__items = {v: None for v in items}
def __contains__(self, x):
+ # type: (Any) -> bool
return x in self.__items
def __iter__(self):
+ # type: () -> Iterator[_T]
return iter(self.__items)
def __len__(self):
+ # type: () -> int
return len(self.__items)
def add(self, value):
self.__items.pop(value, None)
def __repr__(self):
+ # type: () -> str
if len(self) == 0:
return "OSet()"
return f"OSet({list(self)})"
"""ordered frozen hashable mapping"""
__slots__ = "__items", "__hash"
+ @overload
+ def __init__(self, items):
+ # type: (Mapping[_T, _T_co]) -> None
+ ...
+
+ @overload
+ def __init__(self, items):
+ # type: (Iterable[tuple[_T, _T_co]]) -> None
+ ...
+
+ @overload
+ def __init__(self):
+ # type: () -> None
+ ...
+
def __init__(self, items=()):
# type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
self.__items = dict(items) # type: dict[_T, _T_co]
return iter(self.__items)
def __len__(self):
+ # type: () -> int
return len(self.__items)
def __eq__(self, other):
- # type: (object) -> bool
+ # type: (FMap[Any, Any] | Any) -> bool
if isinstance(other, FMap):
return self.__items == other.__items
return super().__eq__(other)
def __hash__(self):
+ # type: () -> int
if self.__hash is None:
self.__hash = hash(frozenset(self.items()))
return self.__hash
def __repr__(self):
+ # type: () -> str
return f"FMap({self.__items})"
try:
# added in cpython 3.10
- bit_count = int.bit_count # type: ignore[attr]
+ bit_count = int.bit_count # type: ignore
except AttributeError:
def bit_count(v):
# type: (int) -> int
def __init__(self, items=(), bits=0):
# type: (Iterable[int], int) -> None
- for item in items:
- if item < 0:
- raise ValueError("can't store negative integers")
- bits |= 1 << item
+ if isinstance(items, BaseBitSet):
+ bits |= items.bits
+ else:
+ for item in items:
+ if item < 0:
+ raise ValueError("can't store negative integers")
+ bits |= 1 << item
if bits < 0:
raise ValueError("can't store an infinite set")
self.__bits = bits
@property
def bits(self):
+ # type: () -> int
return self.__bits
@bits.setter
self.__bits = bits
def __contains__(self, x):
+ # type: (Any) -> bool
if isinstance(x, int) and x >= 0:
return (1 << x) & self.bits != 0
return False
bits -= 1 << index
def __len__(self):
+ # type: () -> int
return bit_count(self.bits)
def __repr__(self):
+ # type: () -> str
if self.bits == 0:
return f"{self.__class__.__name__}()"
if self.bits > 0xFFFFFFFF and len(self) < 10:
return f"{self.__class__.__name__}(bits={hex(self.bits)})"
def __eq__(self, other):
- # type: (object) -> bool
+ # type: (Any) -> bool
if not isinstance(other, BaseBitSet):
return super().__eq__(other)
return self.bits == other.bits
self.bits &= ~(1 << value)
def clear(self):
+ # type: () -> None
self.bits = 0
def __ior__(self, it):
return True
def __hash__(self):
+ # type: () -> int
return super()._hash()
+++ /dev/null
-from abc import abstractmethod
-from typing import (AbstractSet, Any, Iterable, Iterator, Mapping, MutableSet,
- NoReturn, TypeVar, overload)
-
-from typing_extensions import Literal, Self, final
-
-_T_co = TypeVar("_T_co", covariant=True)
-_T = TypeVar("_T")
-
-__all__ = [
- "assert_never",
- "BaseBitSet",
- "bit_count",
- "BitSet",
- "FBitSet",
- "final",
- "FMap",
- "Literal",
- "OFSet",
- "OSet",
- "Self",
- "top_set_bit_index",
- "trailing_zero_count",
-]
-
-
-# pyright currently doesn't like typing_extensions' definition
-# -- added to typing in python 3.11
-def assert_never(arg):
- # type: (NoReturn) -> NoReturn
- raise AssertionError("got to code that's supposed to be unreachable")
-
-
-class OFSet(AbstractSet[_T_co]):
- """ ordered frozen set """
-
- def __init__(self, items: Iterable[_T_co] = ()):
- ...
-
- def __contains__(self, x: object) -> bool:
- ...
-
- def __iter__(self) -> Iterator[_T_co]:
- ...
-
- def __len__(self) -> int:
- ...
-
- def __hash__(self) -> int:
- ...
-
- def __repr__(self) -> str:
- ...
-
-
-class OSet(MutableSet[_T]):
- """ ordered mutable set """
-
- def __init__(self, items: Iterable[_T] = ()):
- ...
-
- def __contains__(self, x: object) -> bool:
- ...
-
- def __iter__(self) -> Iterator[_T]:
- ...
-
- def __len__(self) -> int:
- ...
-
- def add(self, value: _T) -> None:
- ...
-
- def discard(self, value: _T) -> None:
- ...
-
- def __repr__(self) -> str:
- ...
-
-
-class FMap(Mapping[_T, _T_co]):
- """ordered frozen hashable mapping"""
- @overload
- def __init__(self, items: Mapping[_T, _T_co]): ...
- @overload
- def __init__(self, items: Iterable[tuple[_T, _T_co]]): ...
- @overload
- def __init__(self): ...
-
- def __getitem__(self, item: _T) -> _T_co:
- ...
-
- def __iter__(self) -> Iterator[_T]:
- ...
-
- def __len__(self) -> int:
- ...
-
- def __eq__(self, other: object) -> bool:
- ...
-
- def __hash__(self) -> int:
- ...
-
- def __repr__(self) -> str:
- ...
-
-
-def trailing_zero_count(v: int, default: int = -1) -> int: ...
-def top_set_bit_index(v: int, default: int = -1) -> int: ...
-def bit_count(v: int) -> int: ...
-
-
-class BaseBitSet(AbstractSet[int]):
- @classmethod
- @abstractmethod
- def _frozen(cls) -> bool: ...
-
- @classmethod
- def _from_bits(cls, bits: int) -> Self: ...
-
- def __init__(self, items: Iterable[int] = (), bits: int = 0): ...
-
- @property
- def bits(self) -> int:
- ...
-
- @bits.setter
- def bits(self, bits: int) -> None: ...
-
- def __contains__(self, x: object) -> bool: ...
-
- def __iter__(self) -> Iterator[int]: ...
-
- def __reversed__(self) -> Iterator[int]: ...
-
- def __len__(self) -> int: ...
-
- def __repr__(self) -> str: ...
-
- def __eq__(self, other: object) -> bool: ...
-
- def __and__(self, other: Iterable[Any]) -> Self: ...
-
- __rand__ = __and__
-
- def __or__(self, other: Iterable[Any]) -> Self: ...
-
- __ror__ = __or__
-
- def __xor__(self, other: Iterable[Any]) -> Self: ...
-
- __rxor__ = __xor__
-
- def __sub__(self, other: Iterable[Any]) -> Self: ...
-
- def __rsub__(self, other: Iterable[Any]) -> Self: ...
-
- def isdisjoint(self, other: Iterable[Any]) -> bool: ...
-
-
-class BitSet(BaseBitSet, MutableSet[int]):
- @final
- @classmethod
- def _frozen(cls) -> Literal[False]: ...
-
- def add(self, value: int) -> None: ...
-
- def discard(self, value: int) -> None: ...
-
- def clear(self) -> None: ...
-
- def __ior__(self, it: AbstractSet[Any]) -> Self: ...
-
- def __iand__(self, it: AbstractSet[Any]) -> Self: ...
-
- def __ixor__(self, it: AbstractSet[Any]) -> Self: ...
-
- def __isub__(self, it: AbstractSet[Any]) -> Self: ...
-
-
-class FBitSet(BaseBitSet):
- @property
- def bits(self) -> int: ...
-
- @final
- @classmethod
- def _frozen(cls) -> Literal[True]: ...
-
- def __hash__(self) -> int: ...
-from typing import Any, Callable, Generic, TypeVar, overload
-
-_T = TypeVar("_T")
-
-
-class cached_property(Generic[_T]):
- def __init__(self, func: Callable[[Any], _T]) -> None: ...
-
- @overload
- def __get__(self, instance: None,
- owner: type[Any] | None = ...) -> cached_property[_T]: ...
-
- @overload
- def __get__(self, instance: object,
- owner: type[Any] | None = ...) -> _T: ...
+cached_property = property