From 3224f08298e1d6371649783c33b3a77174d1de70 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 3 Nov 2022 00:38:58 -0700 Subject: [PATCH] working on code --- .../_tests/test_compiler_ir.py | 1 - .../_tests/test_compiler_ir2.py | 332 ++++++++++++++---- .../_tests/test_matrix.py | 10 +- .../_tests/test_util.py | 2 +- src/bigint_presentation_code/compiler_ir2.py | 329 ++++++++++++++--- .../register_allocator2.py | 200 ++++------- src/bigint_presentation_code/util.py | 4 + 7 files changed, 629 insertions(+), 249 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py index 68df120..820c305 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -9,7 +9,6 @@ from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, RegLoc, SSAVal, XERBit, generate_assembly, op_set_to_list) -import bigint_presentation_code.compiler_ir2 class TestCompilerIR(unittest.TestCase): diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir2.py b/src/bigint_presentation_code/_tests/test_compiler_ir2.py index 40f02fa..116326a 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir2.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir2.py @@ -1,13 +1,27 @@ import unittest from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn, - OpKind, PreRASimState, + FnAnalysis, OpKind, OpStage, + PreRASimState, ProgramPoint, SSAVal) class TestCompilerIR(unittest.TestCase): maxDiff = None + def test_program_point(self): + # type: () -> None + expected = [] # type: list[ProgramPoint] + for op_index in range(5): + for stage in OpStage: + expected.append(ProgramPoint(op_index=op_index, stage=stage)) + + for idx, pp in enumerate(expected): + if idx + 1 < len(expected): + self.assertEqual(pp.next(), expected[idx + 1]) + + self.assertEqual(sorted(expected), expected) + def make_add_fn(self): # type: () -> tuple[Fn, SSAVal] fn = Fn() @@ -28,10 +42,128 @@ class TestCompilerIR(unittest.TestCase): op5 = fn.append_new_op( OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") s = op5.outputs[0] - fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], - immediates=[0], maxvl=MAXVL, name="st") + _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], + immediates=[0], maxvl=MAXVL, name="st") return fn, arg + def test_fn_analysis(self): + fn, _arg = self.make_add_fn() + fn_analysis = FnAnalysis(fn) + print(repr(fn_analysis)) + self.assertEqual( + repr(fn_analysis), + "FnAnalysis(fn=, uses=FMap({" + ">: OFSet([" + ">, >]), " + ">: OFSet([" + ">, >, " + ">, " + ">]), " + ">: OFSet([" + ">]), " + ">: OFSet([" + ">]), " + ">: OFSet([>]), " + ">: OFSet([" + ">]), " + ">: OFSet()}), " + "op_indexes=FMap({" + "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), " + "immediates=[], outputs=(>,), " + "name='arg'): 0, " + "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), " + "immediates=[32], outputs=(>,), " + "name='vl'): 1, " + "Op(kind=OpKind.SvLd, input_vals=[" + ">, >], " + "input_uses=(>, " + ">), immediates=[0], " + "outputs=(>,), name='ld'): 2, " + "Op(kind=OpKind.SvLI, input_vals=[>], " + "input_uses=(>,), immediates=[0], " + "outputs=(>,), name='li'): 3, " + "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), " + "immediates=[], outputs=(>,), name='ca'): 4, " + "Op(kind=OpKind.SvAddE, input_vals=[" + ">, >, " + ">, >], " + "input_uses=(>, " + ">, >, " + ">), immediates=[], outputs=(" + ">, >), " + "name='add'): 5, " + "Op(kind=OpKind.SvStd, input_vals=[" + ">, >, " + ">], " + "input_uses=(>, " + ">, >), " + "immediates=[0], outputs=(), name='st'): 6}), " + "live_ranges=FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: }), " + "live_at=FMap({" + ": OFSet([>]), " + ": OFSet([>]), " + ": OFSet([>]), " + ": OFSet([" + ">, >]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >, " + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet()}), " + "def_program_ranges=FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: }), " + "use_program_points=FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: }), " + "all_program_points=)") + def test_repr(self): fn, _arg = self.make_add_fn() self.assertEqual([repr(i) for i in fn.ops], [ @@ -86,74 +218,91 @@ class TestCompilerIR(unittest.TestCase): "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([3])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SvLd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", "OpProperties(kind=OpKind.SvLI, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", "OpProperties(kind=OpKind.SetCA, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SvAddE, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), maxvl=32)", "OpProperties(kind=OpKind.SvStd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(), maxvl=32)", ]) @@ -313,221 +462,268 @@ class TestCompilerIR(unittest.TestCase): "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([3])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=1)", "OpProperties(kind=OpKind.CopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None),), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.CopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=), " - "tied_input_index=None, spread_index=None),), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SvLd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.SvLI, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.SetCA, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.SvAddE, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), maxvl=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyFromReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.SetVLI, " "inputs=(), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.VecCopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97)), " "LocKind.StackI64: FBitSet(range(0, 993))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None),), maxvl=32)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", "OpProperties(kind=OpKind.CopyToReg, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=), " - "tied_input_index=None, spread_index=None),), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " "outputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None),), maxvl=1)", + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", "OpProperties(kind=OpKind.SvStd, " "inputs=(" "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " "ty=), " - "tied_input_index=None, spread_index=None), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None)), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " "outputs=(), maxvl=32)", ]) @@ -616,4 +812,4 @@ class TestCompilerIR(unittest.TestCase): if __name__ == "__main__": - unittest.main() + _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_matrix.py b/src/bigint_presentation_code/_tests/test_matrix.py index 1a56df0..78bd990 100644 --- a/src/bigint_presentation_code/_tests/test_matrix.py +++ b/src/bigint_presentation_code/_tests/test_matrix.py @@ -100,14 +100,14 @@ class TestMatrix(unittest.TestCase): -_1_2, _1_6, _1_2, -_1_6, 2, 0, 0, 0, 0, 1])) with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(1, 1, [0]).inverse() + _ = Matrix(1, 1, [0]).inverse() with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [0, 0, 1, 1]).inverse() + _ = Matrix(2, 2, [0, 0, 1, 1]).inverse() with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [1, 0, 1, 0]).inverse() + _ = Matrix(2, 2, [1, 0, 1, 0]).inverse() with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"): - Matrix(2, 2, [1, 1, 1, 1]).inverse() + _ = Matrix(2, 2, [1, 1, 1, 1]).inverse() if __name__ == "__main__": - unittest.main() + _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_util.py b/src/bigint_presentation_code/_tests/test_util.py index 0bfe365..d409df2 100644 --- a/src/bigint_presentation_code/_tests/test_util.py +++ b/src/bigint_presentation_code/_tests/test_util.py @@ -27,4 +27,4 @@ class TestBitSet(unittest.TestCase): if __name__ == "__main__": - unittest.main() + _ = unittest.main() diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 03b623d..05dc6f4 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -1,7 +1,8 @@ +from collections import defaultdict import enum from abc import ABCMeta, abstractmethod from enum import Enum, unique -from functools import lru_cache +from functools import lru_cache, total_ordering from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload) from weakref import WeakValueDictionary as _WeakVDict @@ -9,7 +10,7 @@ from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.type_util import Self, assert_never, final +from bigint_presentation_code.type_util import Self, assert_never, final, Literal from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet @@ -111,25 +112,221 @@ class Fn: assert_never(out.ty.base_ty) +@final +@unique +@total_ordering +class OpStage(Enum): + value: Literal[0, 1] # type: ignore + + def __new__(cls, value): + # type: (int) -> OpStage + value = int(value) + if value not in (0, 1): + raise ValueError("invalid value") + retval = object.__new__(cls) + retval._value_ = value + return retval + + Early = 0 + """ early stage of Op execution, where all input reads occur. + all output writes with `write_stage == Early` occur here too, and therefore + conflict with input reads, telling the compiler that it that can't share + that output's register with any inputs that the output isn't tied to. + + All outputs, even unused outputs, can't share registers with any other + outputs, independent of `write_stage` settings. + """ + Late = 1 + """ late stage of Op execution, where all output writes with + `write_stage == Late` occur, and therefore don't conflict with input reads, + telling the compiler that any inputs can safely use the same register as + those outputs. + + All outputs, even unused outputs, can't share registers with any other + outputs, independent of `write_stage` settings. + """ + + def __repr__(self): + # type: () -> str + return f"OpStage.{self._name_}" + + def __lt__(self, other): + # type: (OpStage | object) -> bool + if isinstance(other, OpStage): + return self.value < other.value + return NotImplemented + + +assert OpStage.Early < OpStage.Late, "early must be less than late" + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +@total_ordering +class ProgramPoint: + __slots__ = "op_index", "stage" + + def __init__(self, op_index, stage): + # type: (int, OpStage) -> None + self.op_index = op_index + self.stage = stage + + @property + def int_value(self): + # type: () -> int + """ an integer representation of `self` such that it keeps ordering and + successor/predecessor relations. + """ + return self.op_index * 2 + self.stage.value + + @staticmethod + def from_int_value(int_value): + # type: (int) -> ProgramPoint + op_index, stage = divmod(int_value, 2) + return ProgramPoint(op_index=op_index, stage=OpStage(stage)) + + def next(self, steps=1): + # type: (int) -> ProgramPoint + return ProgramPoint.from_int_value(self.int_value + steps) + + def prev(self, steps=1): + # type: (int) -> ProgramPoint + return self.next(steps=-steps) + + def __lt__(self, other): + # type: (ProgramPoint | Any) -> bool + if not isinstance(other, ProgramPoint): + return NotImplemented + if self.op_index != other.op_index: + return self.op_index < other.op_index + return self.stage < other.stage + + def __repr__(self): + # type: () -> str + return f"" + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class ProgramRange(Sequence[ProgramPoint]): + __slots__ = "start", "stop" + + def __init__(self, start, stop): + # type: (ProgramPoint, ProgramPoint) -> None + self.start = start + self.stop = stop + + @cached_property + def int_value_range(self): + # type: () -> range + return range(self.start.int_value, self.stop.int_value) + + @staticmethod + def from_int_value_range(int_value_range): + # type: (range) -> ProgramRange + if int_value_range.step != 1: + raise ValueError("int_value_range must have step == 1") + return ProgramRange( + start=ProgramPoint.from_int_value(int_value_range.start), + stop=ProgramPoint.from_int_value(int_value_range.stop)) + + @overload + def __getitem__(self, __idx): + # type: (int) -> ProgramPoint + ... + + @overload + def __getitem__(self, __idx): + # type: (slice) -> ProgramRange + ... + + def __getitem__(self, __idx): + # type: (int | slice) -> ProgramPoint | ProgramRange + v = range(self.start.int_value, self.stop.int_value)[__idx] + if isinstance(v, int): + return ProgramPoint.from_int_value(v) + return ProgramRange.from_int_value_range(v) + + def __len__(self): + # type: () -> int + return len(self.int_value_range) + + def __iter__(self): + # type: () -> Iterator[ProgramPoint] + return map(ProgramPoint.from_int_value, self.int_value_range) + + def __repr__(self): + # type: () -> str + start = repr(self.start).lstrip("<").rstrip(">") + stop = repr(self.stop).lstrip("<").rstrip(">") + return f"" + + @plain_data(frozen=True, eq=False) @final -class FnWithUses: - __slots__ = "fn", "uses" +class FnAnalysis: + __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at", + "def_program_ranges", "use_program_points", + "all_program_points") def __init__(self, fn): # type: (Fn) -> None self.fn = fn - retval = {} # type: dict[SSAVal, OSet[SSAUse]] + self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops)) + self.all_program_points = ProgramRange( + start=ProgramPoint(op_index=0, stage=OpStage.Early), + stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)) + def_program_ranges = {} # type: dict[SSAVal, ProgramRange] + use_program_points = {} # type: dict[SSAUse, ProgramPoint] + uses = {} # type: dict[SSAVal, OSet[SSAUse]] + live_range_stops = {} # type: dict[SSAVal, ProgramPoint] for op in fn.ops: - for idx, inp in enumerate(op.input_vals): - retval[inp].add(SSAUse(op, idx)) + for use in op.input_uses: + uses[use.ssa_val].add(use) + use_program_point = self.__get_use_program_point(use) + use_program_points[use] = use_program_point + live_range_stops[use.ssa_val] = max( + live_range_stops[use.ssa_val], use_program_point.next()) for out in op.outputs: - retval[out] = OSet() - self.uses = FMap((k, OFSet(v)) for k, v in retval.items()) + uses[out] = OSet() + def_program_range = self.__get_def_program_range(out) + def_program_ranges[out] = def_program_range + live_range_stops[out] = def_program_range.stop + self.uses = FMap((k, OFSet(v)) for k, v in uses.items()) + self.def_program_ranges = FMap(def_program_ranges) + self.use_program_points = FMap(use_program_points) + live_ranges = {} # type: dict[SSAVal, ProgramRange] + live_at = {i: OSet[SSAVal]() for i in self.all_program_points} + for ssa_val in uses.keys(): + live_ranges[ssa_val] = live_range = ProgramRange( + start=self.def_program_ranges[ssa_val].start, + stop=live_range_stops[ssa_val]) + for program_point in live_range: + live_at[program_point].add(ssa_val) + self.live_ranges = FMap(live_ranges) + self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items()) + + def __get_def_program_range(self, ssa_val): + # type: (SSAVal) -> ProgramRange + write_stage = ssa_val.defining_descriptor.write_stage + start = ProgramPoint( + op_index=self.op_indexes[ssa_val.op], stage=write_stage) + # always include late stage of ssa_val.op, to ensure outputs always + # overlap all other outputs. + # stop is exclusive, so we need the next program point. + stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next() + return ProgramRange(start=start, stop=stop) + + def __get_use_program_point(self, ssa_use): + # type: (SSAUse) -> ProgramPoint + assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \ + "assumed here, ensured by GenericOpProperties.__init__" + return ProgramPoint( + op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early) def __eq__(self, other): - # type: (FnWithUses | Any) -> bool - if isinstance(other, FnWithUses): + # type: (FnAnalysis | Any) -> bool + if isinstance(other, FnAnalysis): return self.fn == other.fn return NotImplemented @@ -255,10 +452,9 @@ class LocSubKind(Enum): # type: () -> LocKind # pyright fails typechecking when using `in` here: # reported: https://github.com/microsoft/pyright/issues/4102 - if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \ - or self is LocSubKind.SV_EXTRA2_SGPR \ - or self is LocSubKind.SV_EXTRA3_VGPR \ - or self is LocSubKind.SV_EXTRA3_SGPR: + if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR, + LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR, + LocSubKind.SV_EXTRA3_SGPR): return LocKind.GPR if self is LocSubKind.StackI64: return LocKind.StackI64 @@ -526,12 +722,24 @@ class LocSet(AbstractSet[Loc]): def __hash__(self): return self.__hash + @lru_cache(maxsize=None, typed=True) + def max_conflicts_with(self, other): + # type: (LocSet | Loc) -> int + """the largest number of Locs in `self` that a single Loc + from `other` can conflict with + """ + if isinstance(other, LocSet): + return max(self.max_conflicts_with(i) for i in other) + else: + return sum(other.conflicts(i) for i in self) + @plain_data(frozen=True, unsafe_hash=True) @final class GenericOperandDesc: """generic Op operand descriptor""" - __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread" + __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread", + "write_stage") def __init__( self, ty, # type: GenericTy @@ -540,6 +748,7 @@ class GenericOperandDesc: fixed_loc=None, # type: Loc | None tied_input_index=None, # type: int | None spread=False, # type: bool + write_stage=OpStage.Early, # type: OpStage ): # type: (...) -> None self.ty = ty @@ -577,15 +786,26 @@ class GenericOperandDesc: raise ValueError("operand can't be both spread and fixed") if self.ty.is_vec: raise ValueError("operand can't be both spread and vector") + self.write_stage = write_stage def tied_to_input(self, tied_input_index): # type: (int) -> Self return GenericOperandDesc(self.ty, self.sub_kinds, - tied_input_index=tied_input_index) + tied_input_index=tied_input_index, + write_stage=self.write_stage) def with_fixed_loc(self, fixed_loc): # type: (Loc) -> Self - return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc) + return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc, + write_stage=self.write_stage) + + def with_write_stage(self, write_stage): + # type: (OpStage) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, + fixed_loc=self.fixed_loc, + tied_input_index=self.tied_input_index, + spread=self.spread, + write_stage=write_stage) def instantiate(self, maxvl): # type: (int) -> Iterable[OperandDesc] @@ -613,17 +833,19 @@ class GenericOperandDesc: idx = None yield OperandDesc(loc_set_before_spread=loc_set_before_spread, tied_input_index=self.tied_input_index, - spread_index=idx) + spread_index=idx, write_stage=self.write_stage) @plain_data(frozen=True, unsafe_hash=True) @final class OperandDesc: """Op operand descriptor""" - __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index" + __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index", + "write_stage") - def __init__(self, loc_set_before_spread, tied_input_index, spread_index): - # type: (LocSet, int | None, int | None) -> None + def __init__(self, loc_set_before_spread, tied_input_index, spread_index, + write_stage): + # type: (LocSet, int | None, int | None, OpStage) -> None if len(loc_set_before_spread) == 0: raise ValueError("loc_set_before_spread must not be empty") self.loc_set_before_spread = loc_set_before_spread @@ -631,6 +853,7 @@ class OperandDesc: if self.tied_input_index is not None and self.spread_index is not None: raise ValueError("operand can't be both spread and tied") self.spread_index = spread_index + self.write_stage = write_stage @cached_property def ty_before_spread(self): @@ -702,13 +925,16 @@ class GenericOpProperties: has_side_effects=False, # type: bool ): # type: (...) -> None - self.demo_asm = demo_asm - self.inputs = tuple(inputs) + self.demo_asm = demo_asm # type: str + self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...] for inp in self.inputs: if inp.tied_input_index is not None: raise ValueError( f"tied_input_index is not allowed on inputs: {inp}") - self.outputs = tuple(outputs) + if inp.write_stage is not OpStage.Early: + raise ValueError( + f"write_stage is not allowed on inputs: {inp}") + self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...] fixed_locs = [] # type: list[tuple[Loc, int]] for idx, out in enumerate(self.outputs): if out.tied_input_index is not None: @@ -727,10 +953,10 @@ class GenericOpProperties: f"outputs[{other_idx}]: {out.fixed_loc} conflicts " f"with {other_fixed_loc}") fixed_locs.append((out.fixed_loc, idx)) - self.immediates = tuple(immediates) - self.is_copy = is_copy - self.is_load_immediate = is_load_immediate - self.has_side_effects = has_side_effects + self.immediates = tuple(immediates) # type: tuple[range, ...] + self.is_copy = is_copy # type: bool + self.is_load_immediate = is_load_immediate # type: bool + self.has_side_effects = has_side_effects # type: bool @plain_data(frozen=True, unsafe_hash=True) @@ -740,16 +966,16 @@ class OpProperties: def __init__(self, kind, maxvl): # type: (OpKind, int) -> None - self.kind = kind + self.kind = kind # type: OpKind inputs = [] # type: list[OperandDesc] for inp in self.generic.inputs: inputs.extend(inp.instantiate(maxvl=maxvl)) - self.inputs = tuple(inputs) + self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...] outputs = [] # type: list[OperandDesc] for out in self.generic.outputs: outputs.extend(out.instantiate(maxvl=maxvl)) - self.outputs = tuple(outputs) - self.maxvl = maxvl + self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...] + self.maxvl = maxvl # type: int @property def generic(self): @@ -807,6 +1033,7 @@ class OpKind(Enum): return OpProperties(self, maxvl=maxvl) def __repr__(self): + # type: () -> str return "OpKind." + self._name_ @cached_property @@ -821,7 +1048,7 @@ class OpKind(Enum): ClearCA = GenericOpProperties( demo_asm="addic 0, 0, 0", inputs=[], - outputs=[OD_CA], + outputs=[OD_CA.with_write_stage(OpStage.Late)], ) _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim @@ -832,7 +1059,7 @@ class OpKind(Enum): SetCA = GenericOpProperties( demo_asm="subfc 0, 0, 0", inputs=[], - outputs=[OD_CA], + outputs=[OD_CA.with_write_stage(OpStage.Late)], ) _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim @@ -906,7 +1133,7 @@ class OpKind(Enum): SetVLI = GenericOpProperties( demo_asm="setvl 0, 0, imm, 0, 1, 1", inputs=(), - outputs=[OD_VL], + outputs=[OD_VL.with_write_stage(OpStage.Late)], immediates=[range(1, 65)], is_load_immediate=True, ) @@ -935,7 +1162,7 @@ class OpKind(Enum): LI = GenericOpProperties( demo_asm="addi RT, 0, imm", inputs=(), - outputs=[OD_BASE_SGPR], + outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], immediates=[IMM_S16], is_load_immediate=True, ) @@ -951,7 +1178,7 @@ class OpKind(Enum): ty=GenericTy(BaseTy.I64, is_vec=True), sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], ), OD_VL], - outputs=[OD_EXTRA3_VGPR], + outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], is_copy=True, ) _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim @@ -966,6 +1193,7 @@ class OpKind(Enum): outputs=[GenericOperandDesc( ty=GenericTy(BaseTy.I64, is_vec=True), sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], + write_stage=OpStage.Late, )], is_copy=True, ) @@ -985,6 +1213,7 @@ class OpKind(Enum): outputs=[GenericOperandDesc( ty=GenericTy(BaseTy.I64, is_vec=False), sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR], + write_stage=OpStage.Late, )], is_copy=True, ) @@ -1004,6 +1233,7 @@ class OpKind(Enum): ty=GenericTy(BaseTy.I64, is_vec=False), sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR, LocSubKind.StackI64], + write_stage=OpStage.Late, )], is_copy=True, ) @@ -1021,7 +1251,7 @@ class OpKind(Enum): sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], spread=True, ), OD_VL], - outputs=[OD_EXTRA3_VGPR], + outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], is_copy=True, ) _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim @@ -1038,6 +1268,7 @@ class OpKind(Enum): ty=GenericTy(BaseTy.I64, is_vec=False), sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], spread=True, + write_stage=OpStage.Late, )], is_copy=True, ) @@ -1072,7 +1303,7 @@ class OpKind(Enum): Ld = GenericOpProperties( demo_asm="ld RT, imm(RA)", inputs=[OD_BASE_SGPR], - outputs=[OD_BASE_SGPR], + outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], immediates=[IMM_S16], ) _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim @@ -1130,6 +1361,7 @@ class SSAValOrUse(metaclass=ABCMeta): def __init__(self, op, operand_idx): # type: (Op, int) -> None + super().__init__() self.op = op if operand_idx < 0 or operand_idx >= len(self.descriptor_array): raise ValueError("invalid operand_idx") @@ -1146,7 +1378,7 @@ class SSAValOrUse(metaclass=ABCMeta): # type: () -> tuple[OperandDesc, ...] ... - @property + @cached_property def defining_descriptor(self): # type: () -> OperandDesc return self.descriptor_array[self.operand_idx] @@ -1215,6 +1447,11 @@ class SSAVal(SSAValOrUse): return SSAUse(op=self.op, operand_idx=self.defining_descriptor.tied_input_index) + @property + def write_stage(self): + # type: () -> OpStage + return self.defining_descriptor.write_stage + @plain_data(frozen=True, unsafe_hash=True, repr=False) @final @@ -1292,12 +1529,13 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): def __init__(self, items, op): # type: (Iterable[_T], Op) -> None + super().__init__() self.__op = op self.__items = [] # type: list[_T] for idx, item in enumerate(items): if idx >= len(self.descriptors): raise ValueError("too many items") - self._verify_write(idx, item) + _ = self._verify_write(idx, item) self.__items.append(item) if len(self.__items) < len(self.descriptors): raise ValueError("not enough items") @@ -1334,6 +1572,7 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): return len(self.__items) def __repr__(self): + # type: () -> str return f"{self.__class__.__name__}({self.__items}, op=...)" @@ -1402,6 +1641,7 @@ class Op: @property def kind(self): + # type: () -> OpKind return self.properties.kind def __eq__(self, other): @@ -1411,6 +1651,7 @@ class Op: return NotImplemented def __hash__(self): + # type: () -> int return object.__hash__(self) def __repr__(self): @@ -1473,8 +1714,8 @@ class PreRASimState: def __init__(self, ssa_vals, memory): # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None - self.ssa_vals = ssa_vals - self.memory = memory + self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]] + self.memory = memory # type: dict[int, int] def load_byte(self, addr): # type: (int) -> int diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py index 68443d9..c7d9a88 100644 --- a/src/bigint_presentation_code/register_allocator2.py +++ b/src/bigint_presentation_code/register_allocator2.py @@ -6,13 +6,14 @@ this uses an algorithm based on: """ from itertools import combinations -from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet +from typing import Any, Iterable, Iterator, Mapping, MutableSet from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc, - LocSet, Op, SSAVal, Ty) +from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc, + LocSet, Op, ProgramRange, + SSAVal, Ty) from bigint_presentation_code.type_util import final from bigint_presentation_code.util import FMap, OFSet, OSet @@ -23,6 +24,7 @@ class LiveInterval: def __init__(self, first_write, last_use=None): # type: (int, int | None) -> None + super().__init__() if last_use is None: last_use = first_write if last_use < first_write: @@ -86,11 +88,11 @@ class MergedSSAVal: * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)` * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)` """ - __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set" + __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set" - def __init__(self, fn_with_uses, ssa_val_offsets): - # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None - self.fn_with_uses = fn_with_uses + def __init__(self, fn_analysis, ssa_val_offsets): + # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None + self.fn_analysis = fn_analysis if isinstance(ssa_val_offsets, SSAVal): ssa_val_offsets = {ssa_val_offsets: 0} self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int] @@ -111,7 +113,7 @@ class MergedSSAVal: # type: () -> Iterable[Loc] for loc in ssa_val.def_loc_set_before_spread: disallowed_by_use = False - for use in fn_with_uses.uses[ssa_val]: + for use in fn_analysis.uses[ssa_val]: use_spread_idx = \ use.defining_descriptor.spread_index or 0 # calculate the start for the use's Loc before spread @@ -145,7 +147,7 @@ class MergedSSAVal: @cached_property def __hash(self): # type: () -> int - return hash((self.fn_with_uses, self.ssa_val_offsets)) + return hash((self.fn_analysis, self.ssa_val_offsets)) def __hash__(self): # type: () -> int @@ -190,7 +192,7 @@ class MergedSSAVal: def offset_by(self, amount): # type: (int) -> MergedSSAVal v = {k: v + amount for k, v in self.ssa_val_offsets.items()} - return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v) + return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v) def normalized(self): # type: () -> MergedSSAVal @@ -212,16 +214,28 @@ class MergedSSAVal: # type: (*MergedSSAVal) -> MergedSSAVal retval = dict(self.ssa_val_offsets) for other in others: - if other.fn_with_uses != self.fn_with_uses: - raise ValueError("fn_with_uses mismatch") + if other.fn_analysis != self.fn_analysis: + raise ValueError("fn_analysis mismatch") for ssa_val, offset in other.ssa_val_offsets.items(): if ssa_val in retval and retval[ssa_val] != offset: raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: " f"{retval[ssa_val]} != {offset}") retval[ssa_val] = offset - return MergedSSAVal(fn_with_uses=self.fn_with_uses, + return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=retval) + @cached_property + def live_interval(self): + # type: () -> ProgramRange + live_range = self.fn_analysis.live_ranges[self.first_ssa_val] + start = live_range.start + stop = live_range.stop + for ssa_val in self.ssa_vals: + live_range = self.fn_analysis.live_ranges[ssa_val] + start = min(start, live_range.start) + stop = max(stop, live_range.stop) + return ProgramRange(start=start, stop=stop) + @final class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]): @@ -322,11 +336,11 @@ class MergedSSAValsSet(MutableSet[MergedSSAVal]): @plain_data(frozen=True) @final class MergedSSAVals: - __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals" + __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals" - def __init__(self, fn_with_uses, merged_ssa_vals): - # type: (FnWithUses, Iterable[MergedSSAVal]) -> None - self.fn_with_uses = fn_with_uses + def __init__(self, fn_analysis, merged_ssa_vals): + # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None + self.fn_analysis = fn_analysis self.merge_map = MergedSSAValsMap() self.merged_ssa_vals = self.merge_map.values_set for i in merged_ssa_vals: @@ -345,10 +359,10 @@ class MergedSSAVals: return merged @staticmethod - def minimally_merged(fn_with_uses): - # type: (FnWithUses) -> MergedSSAVals - retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=()) - for op in fn_with_uses.fn.ops: + def minimally_merged(fn_analysis): + # type: (FnAnalysis) -> MergedSSAVals + retval = MergedSSAVals(fn_analysis=fn_analysis, merged_ssa_vals=()) + for op in fn_analysis.fn.ops: for inp in op.input_uses: if inp.unspread_start != inp: retval.merge(inp.unspread_start.ssa_val, inp.ssa_val, @@ -362,67 +376,16 @@ class MergedSSAVals: return retval -# FIXME: work on code from here - - -@final -class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]): - def __init__(self, merged_ssa_vals): - # type: (list[Op]) -> None - self.__merged_reg_sets = MergedRegSets(ops) - live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval] - for op_idx, op in enumerate(ops): - for val in op.inputs().values(): - live_intervals[self.__merged_reg_sets[val]] += op_idx - for val in op.outputs().values(): - reg_set = self.__merged_reg_sets[val] - if reg_set not in live_intervals: - live_intervals[reg_set] = LiveInterval(op_idx) - else: - live_intervals[reg_set] += op_idx - self.__live_intervals = live_intervals - live_after = [] # type: list[OSet[MergedRegSet[_RegType]]] - live_after += (OSet() for _ in ops) - for reg_set, live_interval in self.__live_intervals.items(): - for i in live_interval.live_after_op_range: - live_after[i].add(reg_set) - self.__live_after = [OFSet(i) for i in live_after] - - @property - def merged_reg_sets(self): - return self.__merged_reg_sets - - def __getitem__(self, key): - # type: (MergedRegSet[_RegType]) -> LiveInterval - return self.__live_intervals[key] - - def __iter__(self): - return iter(self.__live_intervals) - - def __len__(self): - return len(self.__live_intervals) - - def reg_sets_live_after(self, op_index): - # type: (int) -> OFSet[MergedRegSet[_RegType]] - return self.__live_after[op_index] - - def __repr__(self): - reg_sets_live_after = dict(enumerate(self.__live_after)) - return (f"LiveIntervals(live_intervals={self.__live_intervals}, " - f"merged_reg_sets={self.merged_reg_sets}, " - f"reg_sets_live_after={reg_sets_live_after})") - - @final -class IGNode(Generic[_RegType]): +class IGNode: """ interference graph node """ - __slots__ = "merged_reg_set", "edges", "reg" + __slots__ = "merged_ssa_val", "edges", "loc" - def __init__(self, merged_reg_set, edges=(), reg=None): - # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None - self.merged_reg_set = merged_reg_set + def __init__(self, merged_ssa_val, edges=(), loc=None): + # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None + self.merged_ssa_val = merged_ssa_val self.edges = OSet(edges) - self.reg = reg + self.loc = loc def add_edge(self, other): # type: (IGNode) -> None @@ -432,11 +395,11 @@ class IGNode(Generic[_RegType]): def __eq__(self, other): # type: (object) -> bool if isinstance(other, IGNode): - return self.merged_reg_set == other.merged_reg_set + return self.merged_ssa_val == other.merged_ssa_val return NotImplemented def __hash__(self): - return hash(self.merged_reg_set) + return hash(self.merged_ssa_val) def __repr__(self, nodes=None): # type: (None | dict[IGNode, int]) -> str @@ -447,54 +410,32 @@ class IGNode(Generic[_RegType]): nodes[self] = len(nodes) edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" return (f"IGNode(#{nodes[self]}, " - f"merged_reg_set={self.merged_reg_set}, " + f"merged_ssa_val={self.merged_ssa_val}, " f"edges={edges}, " - f"reg={self.reg})") + f"loc={self.loc})") @property - def reg_class(self): - # type: () -> RegClass - return self.merged_reg_set.ty.reg_class + def loc_set(self): + # type: () -> LocSet + return self.merged_ssa_val.loc_set - def reg_conflicts_with_neighbors(self, reg): - # type: (RegLoc) -> bool + def loc_conflicts_with_neighbors(self, loc): + # type: (Loc) -> bool for neighbor in self.edges: - if neighbor.reg is not None and neighbor.reg.conflicts(reg): + if neighbor.loc is not None and neighbor.loc.conflicts(loc): return True return False -@final -class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]): - def __init__(self, merged_reg_sets): - # type: (Iterable[MergedRegSet[_RegType]]) -> None - self.__nodes = {i: IGNode(i) for i in merged_reg_sets} - - def __getitem__(self, key): - # type: (MergedRegSet[_RegType]) -> IGNode - return self.__nodes[key] - - def __iter__(self): - return iter(self.__nodes) - - def __len__(self): - return len(self.__nodes) - - def __repr__(self): - nodes = {} - nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] - nodes_text = ", ".join(nodes_text) - return f"InterferenceGraph(nodes={{{nodes_text}}})" - - @plain_data() class AllocationFailed: - __slots__ = "node", "live_intervals", "interference_graph" + __slots__ = "node", "merged_ssa_vals", "interference_graph" - def __init__(self, node, live_intervals, interference_graph): - # type: (IGNode, LiveIntervals, InterferenceGraph) -> None + def __init__(self, node, merged_ssa_vals, interference_graph): + # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None + super().__init__() self.node = node - self.live_intervals = live_intervals + self.merged_ssa_vals = merged_ssa_vals self.interference_graph = interference_graph @@ -505,25 +446,24 @@ class AllocationFailedError(Exception): self.allocation_failed = allocation_failed -def try_allocate_registers_without_spilling(ops): - # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed +def try_allocate_registers_without_spilling(merged_ssa_vals): + # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed - live_intervals = LiveIntervals(ops) - merged_reg_sets = live_intervals.merged_reg_sets - interference_graph = InterferenceGraph(merged_reg_sets.values()) - for op_idx, op in enumerate(ops): - reg_sets = live_intervals.reg_sets_live_after(op_idx) - for i, j in combinations(reg_sets, 2): - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: - interference_graph[i].add_edge(interference_graph[j]) - for i, j in op.get_extra_interferences(): - i = merged_reg_sets[i] - j = merged_reg_sets[j] - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph = { + i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals} + fn_analysis = merged_ssa_vals.fn_analysis + for ssa_vals in fn_analysis.live_at.values(): + live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal] + for ssa_val in ssa_vals: + live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val]) + for i, j in combinations(live_merged_ssa_vals, 2): + if i.loc_set.max_conflicts_with(j.loc_set) != 0: interference_graph[i].add_edge(interference_graph[j]) nodes_remaining = OSet(interference_graph.values()) +# FIXME: work on code from here + def local_colorability_score(node): # type: (IGNode) -> int """ returns a positive integer if node is locally colorable, returns @@ -532,7 +472,7 @@ def try_allocate_registers_without_spilling(ops): """ if node not in nodes_remaining: raise ValueError() - retval = len(node.reg_class) + retval = len(node.loc_set) for neighbor in node.edges: if neighbor in nodes_remaining: retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index b85b3ac..757f267 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -26,6 +26,7 @@ class OFSet(AbstractSet[_T_co]): def __init__(self, items=()): # type: (Iterable[_T_co]) -> None + super().__init__() self.__items = {v: None for v in items} def __contains__(self, x): @@ -57,6 +58,7 @@ class OSet(MutableSet[_T]): def __init__(self, items=()): # type: (Iterable[_T]) -> None + super().__init__() self.__items = {v: None for v in items} def __contains__(self, x): @@ -107,6 +109,7 @@ class FMap(Mapping[_T, _T_co]): def __init__(self, items=()): # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None + super().__init__() self.__items = dict(items) # type: dict[_T, _T_co] self.__hash = None # type: None | int @@ -179,6 +182,7 @@ class BaseBitSet(AbstractSet[int]): def __init__(self, items=(), bits=0): # type: (Iterable[int], int) -> None + super().__init__() if isinstance(items, BaseBitSet): bits |= items.bits else: -- 2.30.2