From: Jacob Lifshay Date: Fri, 27 Jan 2023 06:11:49 +0000 (-0800) Subject: LocSet now allows multiple reg_lens simultaneously X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=99a453e743c00b1c896fadaa0be892694e6293c6;p=bigint-presentation-code.git LocSet now allows multiple reg_lens simultaneously --- diff --git a/Cargo.lock b/Cargo.lock index 5c9cfba..dd2eb38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,7 @@ dependencies = [ "libfuzzer-sys", "num-bigint", "num-traits", + "once_cell", "petgraph", "scoped-tls", "serde", diff --git a/register_allocator/Cargo.toml b/register_allocator/Cargo.toml index 05e792d..b19356f 100644 --- a/register_allocator/Cargo.toml +++ b/register_allocator/Cargo.toml @@ -20,6 +20,7 @@ num-traits = "0.2.15" petgraph = "0.6.2" libfuzzer-sys = { version = "0.4.5", optional = true } arbitrary = { version = "1.2.2", features = ["derive"] } +once_cell = "1.17.0" [features] fuzzing = ["libfuzzer-sys"] diff --git a/register_allocator/fuzz/Cargo.lock b/register_allocator/fuzz/Cargo.lock index 009ac03..8b6e14e 100644 --- a/register_allocator/fuzz/Cargo.lock +++ b/register_allocator/fuzz/Cargo.lock @@ -39,6 +39,7 @@ dependencies = [ "libfuzzer-sys", "num-bigint", "num-traits", + "once_cell", "petgraph", "scoped-tls", "serde", diff --git a/register_allocator/fuzz/Cargo.toml b/register_allocator/fuzz/Cargo.toml index 6b0672c..f51036c 100644 --- a/register_allocator/fuzz/Cargo.toml +++ b/register_allocator/fuzz/Cargo.toml @@ -25,3 +25,15 @@ name = "fn_new" path = "fuzz_targets/fn_new.rs" test = false doc = false + +[[bin]] +name = "loc_set_ops" +path = "fuzz_targets/loc_set_ops.rs" +test = false +doc = false + +[[bin]] +name = "loc_set_max_conflicts_with" +path = "fuzz_targets/loc_set_max_conflicts_with.rs" +test = false +doc = false diff --git a/register_allocator/fuzz/fuzz_targets/loc_set_max_conflicts_with.rs b/register_allocator/fuzz/fuzz_targets/loc_set_max_conflicts_with.rs new file mode 100644 index 0000000..c2ee30f --- /dev/null +++ b/register_allocator/fuzz/fuzz_targets/loc_set_max_conflicts_with.rs @@ -0,0 +1,36 @@ +#![no_main] +use bigint_presentation_code_register_allocator::{ + interned::{GlobalState, Intern}, + loc::Loc, + loc_set::LocSet, +}; +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: (LocSet, LocSet)| { + GlobalState::scope(|| { + GlobalState::get(|fast_global_state| { + GlobalState::scope(|| { + GlobalState::get(|reference_global_state| { + let (a, b) = data; + let a_fast = a.to_interned(fast_global_state); + let a_reference = a.into_interned(reference_global_state); + let b_fast = b.to_interned(fast_global_state); + let b_reference = b.into_interned(reference_global_state); + if let Some(loc) = b_fast.iter().next() { + let fast = a_fast.clone().max_conflicts_with(loc, fast_global_state); + let reference = a_reference + .clone() + .max_conflicts_with(loc, reference_global_state); + assert_eq!(fast, reference, "a={a_fast:?} loc={loc:?}"); + } + let fast = a_fast + .clone() + .max_conflicts_with(b_fast.clone(), fast_global_state); + let reference = + a_reference.max_conflicts_with(b_reference, reference_global_state); + assert_eq!(fast, reference, "a={a_fast:?} b={b_fast:?}"); + }) + }) + }) + }) +}); diff --git a/register_allocator/fuzz/fuzz_targets/loc_set_ops.rs b/register_allocator/fuzz/fuzz_targets/loc_set_ops.rs new file mode 100644 index 0000000..c7be296 --- /dev/null +++ b/register_allocator/fuzz/fuzz_targets/loc_set_ops.rs @@ -0,0 +1,120 @@ +#![no_main] +use bigint_presentation_code_register_allocator::{loc::Loc, loc_set::LocSet}; +use libfuzzer_sys::fuzz_target; +use std::{ + collections::HashSet, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Sub, SubAssign}, +}; + +fn check_op( + lhs_hash_set: &HashSet, + rhs_hash_set: &HashSet, + op: impl Fn(LocSet, LocSet) -> LocSet, + expected: impl IntoIterator, +) { + let lhs = LocSet::from_iter(lhs_hash_set.iter().copied()); + let rhs = LocSet::from_iter(rhs_hash_set.iter().copied()); + let result = op(lhs, rhs); + let result: Vec = result.iter().collect(); + let mut expected = Vec::from_iter(expected); + expected.sort(); + assert_eq!(result, expected); +} + +macro_rules! check_all_op_combos { + ( + $lhs_hash_set:expr, $rhs_hash_set:expr, $expected:expr, + $bin_op:ident::$bin_op_fn:ident(), + $bin_assign_op:ident::$bin_assign_op_fn:ident(), + $( + #[rev] + $bin_assign_rev_op_fn:ident(), + )? + ) => { + check_op( + $lhs_hash_set, + $rhs_hash_set, + |lhs, rhs| $bin_op::$bin_op_fn(lhs, rhs), + $expected, + ); + check_op( + $lhs_hash_set, + $rhs_hash_set, + |lhs, rhs| $bin_op::$bin_op_fn(&lhs, rhs), + $expected, + ); + check_op( + $lhs_hash_set, + $rhs_hash_set, + |lhs, rhs| $bin_op::$bin_op_fn(lhs, &rhs), + $expected, + ); + check_op( + $lhs_hash_set, + $rhs_hash_set, + |lhs, rhs| $bin_op::$bin_op_fn(&lhs, &rhs), + $expected, + ); + check_op( + $lhs_hash_set, + $rhs_hash_set, + |mut lhs, rhs| { + $bin_assign_op::$bin_assign_op_fn(&mut lhs, &rhs); + lhs + }, + $expected, + ); + check_op( + $lhs_hash_set, + $rhs_hash_set, + |mut lhs, rhs| { + $bin_assign_op::$bin_assign_op_fn(&mut lhs, rhs); + lhs + }, + $expected, + ); + $(check_op( + $lhs_hash_set, + $rhs_hash_set, + |lhs, mut rhs| { + rhs.$bin_assign_rev_op_fn(&lhs); + lhs + }, + $expected, + );)? + }; +} + +fuzz_target!(|data: (HashSet, HashSet)| { + let (lhs_hash_set, rhs_hash_set) = data; + let lhs = LocSet::from_iter(lhs_hash_set.iter().copied()); + let rhs = LocSet::from_iter(rhs_hash_set.iter().copied()); + check_all_op_combos!( + &lhs_hash_set, + &rhs_hash_set, + lhs_hash_set.intersection(&rhs_hash_set).copied(), + BitAnd::bitand(), + BitAndAssign::bitand_assign(), + ); + check_all_op_combos!( + &lhs_hash_set, + &rhs_hash_set, + lhs_hash_set.union(&rhs_hash_set).copied(), + BitOr::bitor(), + BitOrAssign::bitor_assign(), + ); + check_all_op_combos!( + &lhs_hash_set, + &rhs_hash_set, + lhs_hash_set.symmetric_difference(&rhs_hash_set).copied(), + BitXor::bitxor(), + BitXorAssign::bitxor_assign(), + ); + check_all_op_combos!( + &lhs_hash_set, + &rhs_hash_set, + lhs_hash_set.difference(&rhs_hash_set).copied(), + Sub::sub(), + SubAssign::sub_assign(), + ); +}); diff --git a/register_allocator/src/fuzzing.rs b/register_allocator/src/fuzzing.rs index da43401..8338fe4 100644 --- a/register_allocator/src/fuzzing.rs +++ b/register_allocator/src/fuzzing.rs @@ -1,5 +1,3 @@ -use std::{collections::BTreeMap, num::NonZeroUsize}; - use crate::{ function::{Block, BlockTermInstKind, FnFields, Inst, InstKind, Operand, SSAVal, SSAValDef}, index::{BlockIdx, InstIdx, InstRange, SSAValIdx}, @@ -9,6 +7,7 @@ use crate::{ }; use arbitrary::{Arbitrary, Error, Unstructured}; use petgraph::algo::dominators; +use std::collections::BTreeMap; struct FnBuilder<'a, 'b, 'g> { global_state: &'g GlobalState, @@ -70,7 +69,8 @@ impl FnBuilder<'_, '_, '_> { immediate_dominator: Default::default(), }); for _ in 0..self.u.int_in_range(0..=10)? { - self.new_inst_in_last_block(InstKind::Normal, vec![], self.u.arbitrary()?); + let clobbers = self.u.arbitrary()?; + self.new_inst_in_last_block(InstKind::Normal, vec![], clobbers); } let mut succs_and_params = BTreeMap::default(); let succ_range = BlockIdx::ENTRY_BLOCK.get() as u16..=(block_count - 1); @@ -79,7 +79,7 @@ impl FnBuilder<'_, '_, '_> { if i > succ_range.len() { break; } - let succ = BlockIdx::new(self.u.int_in_range(succ_range)?.into()); + let succ = BlockIdx::new(self.u.int_in_range(succ_range.clone())?.into()); succs_and_params.insert(succ, vec![]); } } @@ -91,10 +91,11 @@ impl FnBuilder<'_, '_, '_> { stage: todo!(), }); } + let clobbers = self.u.arbitrary()?; self.new_inst_in_last_block( InstKind::BlockTerm(BlockTermInstKind { succs_and_params }), operands, - self.u.arbitrary()?, + clobbers, ); } let dominators = dominators::simple_fast(&self.func, BlockIdx::ENTRY_BLOCK); @@ -120,12 +121,14 @@ impl FnBuilder<'_, '_, '_> { let inst = &mut self.func.insts[inst_idx]; match &mut inst.kind { InstKind::Normal => { - let _; + let _: (); todo!() } InstKind::Copy(_) => unreachable!(), InstKind::BlockTerm(block_term_inst_kind) => { - for (&succ, params) in &mut block_term_inst_kind.succs_and_params {} + for (&succ, params) in &mut block_term_inst_kind.succs_and_params { + todo!(); + } } } } diff --git a/register_allocator/src/interned.rs b/register_allocator/src/interned.rs index 95d49a0..2eb61d7 100644 --- a/register_allocator/src/interned.rs +++ b/register_allocator/src/interned.rs @@ -301,6 +301,15 @@ pub trait InternTarget: Intern + Hash + Eq { } } +impl<'a, T: arbitrary::Arbitrary<'a> + InternTarget> arbitrary::Arbitrary<'a> for Interned { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let retval: T = u.arbitrary()?; + Ok(GlobalState::get(|global_state| { + retval.into_interned(global_state) + })) + } +} + impl InternTarget for str { fn get_interner(global_state: &GlobalState) -> &Interner { &global_state.interners.str diff --git a/register_allocator/src/loc.rs b/register_allocator/src/loc.rs index e6473ef..08e81fc 100644 --- a/register_allocator/src/loc.rs +++ b/register_allocator/src/loc.rs @@ -185,9 +185,9 @@ validated_fields! { #[fields_ty = LocFields] #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] pub struct Loc { + pub reg_len: NonZeroU32, pub kind: LocKind, pub start: u32, - pub reg_len: NonZeroU32, } } @@ -258,6 +258,9 @@ impl Loc { pub const fn ty(self) -> Ty { const_unwrap_res!(self.0.ty(), "Loc can only be constructed with valid fields") } + /// does all `Loc` validation except checking `start`, returns the maximum + /// value `start` can have, so a `Loc` is valid if + /// `start < Loc::max_start(kind, reg_len)?` pub const fn max_start(kind: LocKind, reg_len: NonZeroU32) -> Result { // validate Ty const_try!(Ty::new(TyFields { diff --git a/register_allocator/src/loc_set.rs b/register_allocator/src/loc_set.rs index a9d8565..4427569 100644 --- a/register_allocator/src/loc_set.rs +++ b/register_allocator/src/loc_set.rs @@ -1,43 +1,60 @@ use crate::{ error::{Error, Result}, interned::{GlobalState, Intern, InternTarget, Interned}, - loc::{BaseTy, Loc, LocFields, LocKind, Ty, TyFields}, + loc::{Loc, LocFields, LocKind, Ty}, }; use enum_map::{enum_map, EnumMap}; use num_bigint::BigUint; use num_traits::Zero; +use once_cell::race::OnceBox; use serde::{Deserialize, Serialize}; use std::{ - borrow::{Borrow, Cow}, + borrow::Borrow, cell::Cell, - collections::BTreeMap, + collections::{ + btree_map::{self, Entry}, + BTreeMap, + }, fmt, hash::Hash, - iter::{FusedIterator, Peekable}, + iter::FusedIterator, + mem, num::NonZeroU32, - ops::{ - BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, ControlFlow, Range, Sub, - SubAssign, - }, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Range, Sub, SubAssign}, }; +#[inline] +fn zero_biguint<'a>() -> &'a BigUint { + static ZERO: OnceBox = OnceBox::new(); + ZERO.get_or_init( + #[cold] + || BigUint::zero().into(), + ) +} + #[derive(Deserialize)] struct LocSetSerialized { - reg_len_to_starts_map: BTreeMap>, + starts_map: BTreeMap>, } impl TryFrom for LocSet { type Error = Error; fn try_from(value: LocSetSerialized) -> Result { - Self::from_reg_len_to_starts_map(value.reg_len_to_starts_map) + Self::from_starts_map(value.starts_map) } } #[derive(Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(try_from = "LocSetSerialized")] pub struct LocSet { - reg_len_to_starts_map: BTreeMap>, + starts_map: BTreeMap>, +} + +impl<'a> arbitrary::Arbitrary<'a> for LocSet { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + u.arbitrary_iter()?.collect() + } } /// computes same value as `a & !b`, but more efficiently @@ -58,12 +75,9 @@ impl From for LocSet { impl LocSet { pub fn arbitrary_with_ty( - ty: Option, + ty: Ty, u: &mut arbitrary::Unstructured<'_>, ) -> arbitrary::Result { - let Some(ty) = ty else { - return Ok(Self::new()); - }; let kinds = ty.base_ty.loc_kinds(); type Mask = u128; let kinds: Vec<_> = if kinds.len() > Mask::BITS as usize { @@ -103,103 +117,135 @@ impl LocSet { let byte_count = (bit_count + u8::BITS - 1) / u8::BITS; let bytes = u.bytes(byte_count as usize)?; starts[kind] = BigUint::from_bytes_le(bytes); + starts[kind] &= (BigUint::from(1u8) << bit_count) - 1u8; all_zero &= starts[kind].is_zero(); } if all_zero { Ok(Loc::arbitrary_with_ty(ty, u)?.into()) } else { - Ok(Self::from_parts(starts, Some(ty))?) + Ok(Self::from_starts_map_iter_unchecked([(ty.reg_len, starts)])) } } - pub fn starts(&self) -> &EnumMap { - &self.starts - } - pub fn stops(&self) -> EnumMap { - let Some(ty) = self.ty else { - return EnumMap::default(); - }; - enum_map! {kind => &self.starts[kind] << ty.reg_len.get()} + pub fn starts(&self, reg_len: NonZeroU32, kind: LocKind) -> &BigUint { + self.starts_map + .get(®_len) + .map(|v| &v[kind]) + .unwrap_or_else(zero_biguint) } - pub fn ty(&self) -> Option { - self.ty + pub fn stops(&self, reg_len: NonZeroU32, kind: LocKind) -> BigUint { + self.starts(reg_len, kind) << reg_len.get() } - pub fn kinds(&self) -> impl Iterator + '_ { - self.starts - .iter() - .filter_map(|(kind, starts)| if starts.is_zero() { None } else { Some(kind) }) + pub fn starts_map(&self) -> &BTreeMap> { + &self.starts_map } - pub fn reg_len(&self) -> Option { - self.ty.map(|v| v.reg_len) + pub const fn new() -> Self { + Self { + starts_map: BTreeMap::new(), + } } - pub fn base_ty(&self) -> Option { - self.ty.map(|v| v.base_ty) + /// filters out empty entries, but doesn't do any other checks + fn from_starts_map_iter_unchecked( + starts_map: impl IntoIterator)>, + ) -> Self { + Self { + starts_map: starts_map + .into_iter() + .filter(|(_, starts)| !starts.iter().all(|(_, starts)| starts.is_zero())) + .collect(), + } } - pub fn new() -> Self { - Self::default() + fn for_each_reg_len_filtering_out_empty_entries( + &mut self, + mut f: impl FnMut(NonZeroU32, &mut EnumMap), + ) { + self.starts_map.retain(|®_len, starts| { + f(reg_len, starts); + !Self::is_entry_empty(starts) + }); } - pub fn from_parts(starts: EnumMap, ty: Option) -> Result { - let mut empty = true; - for (kind, starts) in &starts { - if !starts.is_zero() { - empty = false; - let expected_ty = Ty::new_or_scalar(TyFields { - base_ty: kind.base_ty(), - reg_len: ty.map(|v| v.reg_len).unwrap_or(nzu32_lit!(1)), - }); - if ty != Some(expected_ty) { - return Err(Error::TyMismatch { - ty, - expected_ty: Some(expected_ty), - }); + /// helper for binary operations that keeps Locs not present in rhs + fn bin_op_keep_helper( + &mut self, + rhs: &Self, + mut f: impl FnMut(NonZeroU32, &mut EnumMap, &EnumMap), + ) { + rhs.starts_map.iter().for_each(|(®_len, rhs_starts)| { + match self.starts_map.entry(reg_len) { + Entry::Vacant(entry) => { + let mut lhs_starts = EnumMap::default(); + f(reg_len, &mut lhs_starts, rhs_starts); + if !Self::is_entry_empty(&lhs_starts) { + entry.insert(lhs_starts); + } } - // bits() is one past max bit set, so use >= rather than > - if starts.bits() >= Loc::max_start(kind, expected_ty.reg_len)? as u64 { - return Err(Error::StartNotInValidRange); + Entry::Occupied(mut entry) => { + f(reg_len, entry.get_mut(), rhs_starts); + if Self::is_entry_empty(entry.get()) { + entry.remove(); + } } } - } - if empty && ty.is_some() { - Err(Error::TyMismatch { - ty, - expected_ty: None, - }) - } else { - Ok(Self { starts, ty }) - } + }); + } + fn is_entry_empty(starts: &EnumMap) -> bool { + starts.iter().all(|(_, starts)| starts.is_zero()) + } + pub fn from_starts_map( + mut starts_map: BTreeMap>, + ) -> Result { + let mut error = Ok(()); + starts_map.retain(|®_len, starts| { + if error.is_err() { + return false; + } + let mut any_locs = false; + for (kind, starts) in starts { + if !starts.is_zero() { + any_locs = true; + error = (|| { + // bits() is one past max bit set, so use >= rather than > + if starts.bits() >= Loc::max_start(kind, reg_len)? as u64 { + return Err(Error::StartNotInValidRange); + } + Ok(()) + })(); + if error.is_err() { + return false; + } + } + } + any_locs + }); + Ok(Self { starts_map }) } pub fn clear(&mut self) { - for v in self.starts.values_mut() { - v.assign_from_slice(&[]); - } + self.starts_map.clear(); } pub fn contains_exact(&self, value: Loc) -> bool { - Some(value.ty()) == self.ty && self.starts[value.kind].bit(value.start as _) - } - pub fn try_insert(&mut self, value: Loc) -> Result { - if self.is_empty() { - self.ty = Some(value.ty()); - self.starts[value.kind].set_bit(value.start as u64, true); - return Ok(true); - }; - let ty = Some(value.ty()); - if ty != self.ty { - return Err(Error::TyMismatch { - ty, - expected_ty: self.ty, - }); - } - let retval = !self.starts[value.kind].bit(value.start as u64); - self.starts[value.kind].set_bit(value.start as u64, true); - Ok(retval) + self.starts(value.reg_len, value.kind).bit(value.start as _) } pub fn insert(&mut self, value: Loc) -> bool { - self.try_insert(value).unwrap() + let starts = match self.starts_map.entry(value.reg_len) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + entry.insert(Default::default())[value.kind].set_bit(value.start as u64, true); + return true; + } + }; + let starts = &mut starts[value.kind]; + let retval = !starts.bit(value.start as u64); + starts.set_bit(value.start as u64, true); + retval } pub fn remove(&mut self, value: Loc) -> bool { - if self.contains_exact(value) { - self.starts[value.kind].set_bit(value.start as u64, false); - if self.starts.values().all(BigUint::is_zero) { - self.ty = None; + let Entry::Occupied(mut entry) = self.starts_map.entry(value.reg_len) else { + return false; + }; + let starts = entry.get_mut(); + if starts[value.kind].bit(value.start as u64) { + starts[value.kind].set_bit(value.start as u64, false); + if starts.values().all(BigUint::is_zero) { + entry.remove(); } true } else { @@ -207,99 +253,173 @@ impl LocSet { } } pub fn is_empty(&self) -> bool { - self.ty.is_none() + self.starts_map.is_empty() } pub fn iter(&self) -> Iter<'_> { - if let Some(ty) = self.ty { - let mut starts = self.starts.iter().peekable(); - Iter { - internals: Some(IterInternals { - ty, - start_range: get_start_range(starts.peek()), - starts, - }), - } - } else { - Iter { internals: None } + Iter { + internals: IterInternals::new(self.starts_map.iter()), } } pub fn len(&self) -> usize { - let retval: u64 = self.starts.values().map(BigUint::count_ones).sum(); + let retval: u64 = self + .starts_map + .values() + .map(|starts| starts.values().map(BigUint::count_ones).sum::()) + .sum(); retval as usize } + /// computes `self = &other - &self` + pub fn sub_reverse_assign(&mut self, other: impl Borrow) { + // TODO: make more efficient + let other: &Self = other.borrow(); + *self = other - &*self; + } } #[derive(Clone, Debug)] -struct IterInternals +struct IterInternalsRest where - I: Iterator, - T: Clone + Borrow, + StartsMapValueIter: Iterator, + Starts: Borrow, { - ty: Ty, - starts: Peekable, + reg_len: NonZeroU32, + starts_map_value_iter: StartsMapValueIter, + kind: LocKind, + starts: Starts, start_range: Range, } -impl IterInternals +impl IterInternalsRest where - I: Iterator, - T: Clone + Borrow, + StartsMapValueIter: Iterator, + Starts: Borrow, { - fn next(&mut self) -> Option { - let IterInternals { - ty, - ref mut starts, - ref mut start_range, - } = *self; + fn new(reg_len: NonZeroU32, mut starts_map_value_iter: StartsMapValueIter) -> Option { loop { - let (kind, ref v) = *starts.peek()?; + let (kind, starts) = starts_map_value_iter.next()?; + let starts_ref: &BigUint = starts.borrow(); + let Some(start) = starts_ref.trailing_zeros() else { + continue; + }; + let start = start.try_into().expect("checked by LocSet constructors"); + let end = starts_ref + .bits() + .try_into() + .expect("checked by LocSet constructors"); + return Some(Self { + reg_len, + starts_map_value_iter, + kind, + starts, + start_range: start..end, + }); + } + } + fn next(this: &mut Option) -> Option { + while let Some(Self { + reg_len, + starts_map_value_iter: _, + kind, + ref starts, + ref mut start_range, + }) = *this + { let Some(start) = start_range.next() else { - starts.next(); - *start_range = get_start_range(starts.peek()); + *this = Self::new(reg_len, this.take().expect("known to be Some").starts_map_value_iter); continue; }; - if v.borrow().bit(start as u64) { + if starts.borrow().bit(start as u64) { return Some( Loc::new(LocFields { kind, start, - reg_len: ty.reg_len, + reg_len, }) .expect("known to be valid"), ); } } + None } } -fn get_start_range(v: Option<&(LocKind, impl Borrow)>) -> Range { - 0..v.map(|(_, v)| v.borrow().bits() as u32).unwrap_or(0) +#[derive(Clone, Debug)] +struct IterInternals +where + StartsMapIter: Iterator, + RegLen: Borrow, + StartsMapValue: IntoIterator, + StartsMapValueIter: Iterator, + Starts: Borrow, +{ + starts_map_iter: StartsMapIter, + rest: Option>, +} + +impl + IterInternals +where + StartsMapIter: Iterator, + RegLen: Borrow, + StartsMapValue: IntoIterator, + StartsMapValueIter: Iterator, + Starts: Borrow, +{ + fn new(starts_map_iter: StartsMapIter) -> Self { + Self { + starts_map_iter, + rest: None, + } + } + fn next(&mut self) -> Option { + loop { + while self.rest.is_none() { + let (reg_len, starts_map_value) = self.starts_map_iter.next()?; + self.rest = IterInternalsRest::new(*reg_len.borrow(), starts_map_value.into_iter()); + } + if let Some(loc) = IterInternalsRest::next(&mut self.rest) { + return Some(loc); + } + } + } } #[derive(Clone, Debug)] pub struct Iter<'a> { - internals: Option, &'a BigUint>>, + internals: IterInternals< + btree_map::Iter<'a, NonZeroU32, EnumMap>, + &'a NonZeroU32, + &'a EnumMap, + enum_map::Iter<'a, LocKind, BigUint>, + &'a BigUint, + >, } impl Iterator for Iter<'_> { type Item = Loc; fn next(&mut self) -> Option { - self.internals.as_mut()?.next() + self.internals.next() } } impl FusedIterator for Iter<'_> {} pub struct IntoIter { - internals: Option, BigUint>>, + internals: IterInternals< + btree_map::IntoIter>, + NonZeroU32, + EnumMap, + enum_map::IntoIter, + BigUint, + >, } impl Iterator for IntoIter { type Item = Loc; fn next(&mut self) -> Option { - self.internals.as_mut()?.next() + self.internals.next() } } @@ -310,17 +430,8 @@ impl IntoIterator for LocSet { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - if let Some(ty) = self.ty { - let mut starts = self.starts.into_iter().peekable(); - IntoIter { - internals: Some(IterInternals { - ty, - start_range: get_start_range(starts.peek()), - starts, - }), - } - } else { - IntoIter { internals: None } + IntoIter { + internals: IterInternals::new(self.starts_map.into_iter()), } } } @@ -342,23 +453,6 @@ impl Extend for LocSet { } } -impl> Extend for Result { - fn extend>(&mut self, iter: T) { - iter.into_iter().try_for_each(|item| { - let Ok(loc_set) = self else { - return ControlFlow::Break(()); - }; - match loc_set.try_insert(item) { - Ok(_) => ControlFlow::Continue(()), - Err(e) => { - *self = Err(e.into()); - ControlFlow::Break(()) - } - } - }); - } -} - impl FromIterator for LocSet { fn from_iter>(iter: T) -> Self { let mut retval = LocSet::new(); @@ -367,14 +461,6 @@ impl FromIterator for LocSet { } } -impl> FromIterator for Result { - fn from_iter>(iter: T) -> Self { - let mut retval = Ok(LocSet::new()); - retval.extend(iter); - retval - } -} - struct HexBigUint<'a>(&'a BigUint); impl fmt::Debug for HexBigUint<'_> { @@ -393,50 +479,30 @@ impl fmt::Debug for LocSetStarts<'_> { } } +struct LocSetStartsMap<'a>(&'a BTreeMap>); + +impl fmt::Debug for LocSetStartsMap<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.0.iter().map(|(k, v)| (k, LocSetStarts(v)))) + .finish() + } +} + impl fmt::Debug for LocSet { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LocSet") - .field("starts", &self.starts) + .field("starts_map", &LocSetStartsMap(&self.starts_map)) .finish() } } -macro_rules! impl_bin_op { +macro_rules! forward_bin_op { ( $bin_op:ident::$bin_op_fn:ident(), $bin_assign_op:ident::$bin_assign_op_fn:ident(), - $starts_op:expr, - $handle_unequal_types:expr, - $update_unequal_types:expr, + $bin_assign_rev_op_fn:ident(), ) => { - impl $bin_op<&'_ LocSet> for &'_ LocSet { - type Output = LocSet; - - fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output { - if self.ty != rhs.ty { - $handle_unequal_types(self, Cow::::Borrowed(rhs)) - } else { - LocSet { - starts: enum_map! {kind => $starts_op(&self.starts[kind], &rhs.starts[kind])}, - ty: self.ty, - } - } - } - } - - impl $bin_assign_op<&'_ LocSet> for LocSet { - fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) { - if self.ty != rhs.ty { - $update_unequal_types(self, rhs); - } else { - for (kind, starts) in &mut self.starts { - let v: BigUint = std::mem::take(starts); - *starts = $starts_op(v, &rhs.starts[kind]); - } - } - } - } - impl $bin_assign_op for LocSet { fn $bin_assign_op_fn(&mut self, rhs: LocSet) { self.$bin_assign_op_fn(&rhs); @@ -465,49 +531,133 @@ macro_rules! impl_bin_op { type Output = LocSet; fn $bin_op_fn(self, mut rhs: LocSet) -> Self::Output { - if self.ty != rhs.ty { - $handle_unequal_types(self, Cow::::Owned(rhs)) - } else { - for (kind, starts) in &mut rhs.starts { - *starts = $starts_op(&self.starts[kind], std::mem::take(starts)); + rhs.$bin_assign_rev_op_fn(self); + rhs + } + } + + const _: fn() = { + fn _check() + where + for<'a> T: $bin_op + $bin_op<&'a T> + $bin_assign_op + $bin_assign_op<&'a T>, + for<'a, 'b> &'a T: $bin_op + $bin_op<&'b T>, + { + } + _check:: + }; + }; +} + +impl BitAnd<&'_ LocSet> for &'_ LocSet { + type Output = LocSet; + + fn bitand(self, rhs: &'_ LocSet) -> Self::Output { + LocSet::from_starts_map_iter_unchecked(self.starts_map.iter().map(|(®_len, starts)| { + ( + reg_len, + enum_map! {kind => (&starts[kind]).bitand(rhs.starts(reg_len, kind))}, + ) + })) + } +} + +impl BitAndAssign<&'_ LocSet> for LocSet { + fn bitand_assign(&mut self, rhs: &'_ LocSet) { + self.for_each_reg_len_filtering_out_empty_entries(|reg_len, starts| { + for (kind, starts) in starts { + starts.bitand_assign(rhs.starts(reg_len, kind)); + } + }); + } +} + +/// helper for binary operations that keeps Locs not present in rhs +macro_rules! impl_bin_op_keep { + ( + $bin_op:ident::$bin_op_fn:ident(), + $bin_assign_op:ident::$bin_assign_op_fn:ident(), + ) => { + impl $bin_op<&'_ LocSet> for &'_ LocSet { + type Output = LocSet; + + fn $bin_op_fn(self, rhs: &'_ LocSet) -> Self::Output { + let mut retval: LocSet = self.clone(); + retval.$bin_assign_op_fn(rhs); + retval + } + } + + impl $bin_assign_op<&'_ LocSet> for LocSet { + fn $bin_assign_op_fn(&mut self, rhs: &'_ LocSet) { + self.bin_op_keep_helper(rhs, |_reg_len, lhs_starts, rhs_starts| { + for (kind, rhs_starts) in rhs_starts { + lhs_starts[kind].$bin_assign_op_fn(rhs_starts); } - rhs - } + }); } } }; } -impl_bin_op! { +forward_bin_op! { BitAnd::bitand(), BitAndAssign::bitand_assign(), - BitAnd::bitand, - |_, _| LocSet::new(), - |lhs, _| LocSet::clear(lhs), + bitand_assign(), } -impl_bin_op! { +impl_bin_op_keep! { BitOr::bitor(), BitOrAssign::bitor_assign(), - BitOr::bitor, - |lhs: &LocSet, rhs: Cow| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }), - |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }), } -impl_bin_op! { +forward_bin_op! { + BitOr::bitor(), + BitOrAssign::bitor_assign(), + bitor_assign(), +} + +impl_bin_op_keep! { + BitXor::bitxor(), + BitXorAssign::bitxor_assign(), +} + +forward_bin_op! { BitXor::bitxor(), BitXorAssign::bitxor_assign(), - BitXor::bitxor, - |lhs: &LocSet, rhs: Cow| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }), - |lhs: &mut LocSet, rhs: &LocSet| panic!("{}", Error::TyMismatch { ty: rhs.ty, expected_ty: lhs.ty }), + bitxor_assign(), +} + +impl Sub<&'_ LocSet> for &'_ LocSet { + type Output = LocSet; + + fn sub(self, rhs: &'_ LocSet) -> Self::Output { + LocSet::from_starts_map_iter_unchecked(self.starts_map.iter().map(|(®_len, starts)| { + ( + reg_len, + enum_map! {kind => and_not(&starts[kind], rhs.starts(reg_len, kind))}, + ) + })) + } +} + +impl SubAssign<&'_ LocSet> for LocSet { + fn sub_assign(&mut self, rhs: &'_ LocSet) { + self.bin_op_keep_helper(rhs, |_reg_len, lhs_starts, rhs_starts| { + for (kind, lhs_starts) in lhs_starts { + let rhs_starts = &rhs_starts[kind]; + if rhs_starts.is_zero() { + continue; + } + *lhs_starts = and_not(mem::take(lhs_starts), rhs_starts); + } + }); + } } -impl_bin_op! { +forward_bin_op! { Sub::sub(), SubAssign::sub_assign(), - and_not, - |lhs: &LocSet, _| lhs.clone(), - |_, _| {}, + sub_reverse_assign(), } /// the largest number of Locs in `lhs` that a single Loc @@ -541,35 +691,52 @@ pub trait LocSetMaxConflictsWithTrait: Clone { global_state: &GlobalState, ) -> Interned>; fn compute_result(lhs: &Interned, rhs: &Self, global_state: &GlobalState) -> u32; + #[cfg(feature = "fuzzing")] + fn reference_compute_result( + lhs: &Interned, + rhs: &Self, + global_state: &GlobalState, + ) -> u32; } impl LocSetMaxConflictsWithTrait for Loc { fn compute_result(lhs: &Interned, rhs: &Self, _global_state: &GlobalState) -> u32 { // now we do the equivalent of: - // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum().unwrap_or(0) - let Some(reg_len) = lhs.reg_len() else { - return 0; - }; - let starts = &lhs.starts[rhs.kind]; - if starts.is_zero() { - return 0; + // return lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum() + let mut retval = 0; + for (&lhs_reg_len, lhs_starts) in lhs.starts_map() { + let lhs_starts = &lhs_starts[rhs.kind]; + if lhs_starts.is_zero() { + continue; + } + // now we do the equivalent of: + // retval += sum(rhs.start < lhs_start + lhs_reg_len + // and lhs_start < rhs.start + rhs.reg_len + // for lhs_start in lhs_starts) + let lhs_stops = lhs_starts << lhs_reg_len.get(); + + // find all the bit indexes `i` where `i < rhs.start + 1` + let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32; + + // find all the bit indexes `i` where + // `i < rhs.start + rhs.reg_len + lhs_reg_len` + let lt_rhs_start_plus_rhs_reg_len_plus_reg_len = + (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + lhs_reg_len.get())) - 1u32; + let lhs_stops_and_lt_rhs_start_plus_1 = &lhs_stops & lt_rhs_start_plus_1; + let mut included = and_not(lhs_stops, lhs_stops_and_lt_rhs_start_plus_1); + included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len; + retval += included.count_ones() as u32; } - // now we do the equivalent of: - // return sum(rhs.start < start + reg_len - // and start < rhs.start + rhs.reg_len - // for start in starts) - let stops = starts << reg_len.get(); - - // find all the bit indexes `i` where `i < rhs.start + 1` - let lt_rhs_start_plus_1 = (BigUint::from(1u32) << (rhs.start + 1)) - 1u32; + retval + } - // find all the bit indexes `i` where - // `i < rhs.start + rhs.reg_len + reg_len` - let lt_rhs_start_plus_rhs_reg_len_plus_reg_len = - (BigUint::from(1u32) << (rhs.start + rhs.reg_len.get() + reg_len.get())) - 1u32; - let mut included = and_not(&stops, &stops & lt_rhs_start_plus_1); - included &= lt_rhs_start_plus_rhs_reg_len_plus_reg_len; - included.count_ones() as u32 + #[cfg(feature = "fuzzing")] + fn reference_compute_result( + lhs: &Interned, + rhs: &Self, + global_state: &GlobalState, + ) -> u32 { + lhs.iter().map(|loc| rhs.conflicts(loc) as u32).sum::() } fn intern( @@ -588,6 +755,18 @@ impl LocSetMaxConflictsWithTrait for Interned { .unwrap_or(0) } + #[cfg(feature = "fuzzing")] + fn reference_compute_result( + lhs: &Interned, + rhs: &Self, + global_state: &GlobalState, + ) -> u32 { + rhs.iter() + .map(|loc| lhs.clone().reference_max_conflicts_with(loc, global_state)) + .max() + .unwrap_or(0) + } + fn intern( v: LocSetMaxConflictsWith, global_state: &GlobalState, @@ -613,6 +792,17 @@ impl LocSetMaxConflictsWith { } } } + #[cfg(feature = "fuzzing")] + pub fn reference_result(&self, global_state: &GlobalState) -> u32 { + match self.result.get() { + Some(v) => v, + None => { + let retval = Rhs::reference_compute_result(&self.lhs, &self.rhs, global_state); + self.result.set(Some(retval)); + retval + } + } + } } impl Interned { @@ -631,6 +821,22 @@ impl Interned { ) .result(global_state) } + #[cfg(feature = "fuzzing")] + pub fn reference_max_conflicts_with(self, rhs: Rhs, global_state: &GlobalState) -> u32 + where + Rhs: LocSetMaxConflictsWithTrait, + LocSetMaxConflictsWith: InternTarget, + { + LocSetMaxConflictsWithTrait::intern( + LocSetMaxConflictsWith { + lhs: self, + rhs, + result: Cell::default(), + }, + global_state, + ) + .reference_result(global_state) + } pub fn conflicts_with(self, rhs: Rhs, global_state: &GlobalState) -> bool where Rhs: LocSetMaxConflictsWithTrait,