From: Jacob Lifshay Date: Fri, 16 Dec 2022 10:03:51 +0000 (-0800) Subject: optimize LocSet.max_conflicts_with X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a2ae6c06df330f726c90fa06411feef8e79056f8;p=bigint-presentation-code.git optimize LocSet.max_conflicts_with --- diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 9e4d13d..724841c 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -15,7 +15,7 @@ from nmutil import plain_data # type: ignore from bigint_presentation_code.type_util import (Literal, Self, assert_never, final) from bigint_presentation_code.util import (BitSet, FBitSet, FMap, Interned, - OFSet, OSet) + OFSet, OSet, bit_count) GPR_SIZE_IN_BYTES = 8 BITS_IN_BYTE = 8 @@ -762,7 +762,7 @@ class Loc(Interned): # type: () -> Ty return self.make_ty(kind=self.kind, reg_len=self.reg_len) - @property + @cached_property def stop(self): # type: () -> int return self.start + self.reg_len @@ -912,7 +912,31 @@ class LocSet(OFSet[Loc], Interned): if isinstance(other, LocSet): return max(self.max_conflicts_with(i) for i in other) else: - return sum(other.conflicts(i) for i in self) + reg_len = self.reg_len + if reg_len is None: + return 0 + starts = self.starts.get(other.kind) + if starts is None: + return 0 + # now we do the equivalent of: + # return sum(other.conflicts(i) for i in self) + # which is the equivalent of: + # return sum(other.start < start + reg_len + # and start < other.start + other.reg_len + # for start in starts) + stops = starts.bits << reg_len + + # find all the bit indexes `i` where `i < other.start + 1` + lt_other_start_plus_1 = ~(~0 << (other.start + 1)) + + # find all the bit indexes `i` where + # `i < other.start + other.reg_len + reg_len` + lt_other_start_plus_other_reg_len_plus_reg_len = ( + ~(~0 << (other.start + other.reg_len + reg_len))) + included = ~(stops & lt_other_start_plus_1) + included &= stops + included &= lt_other_start_plus_other_reg_len_plus_reg_len + return bit_count(included) def __repr__(self): return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"