From: Jacob Lifshay Date: Thu, 19 Jan 2023 02:51:28 +0000 (-0800) Subject: working on code X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=eb07536d02bc039390e0fc6e6028a6add1caa310;p=bigint-presentation-code.git working on code --- diff --git a/register_allocator/src/error.rs b/register_allocator/src/error.rs index a07fb75..104795e 100644 --- a/register_allocator/src/error.rs +++ b/register_allocator/src/error.rs @@ -1,4 +1,7 @@ -use crate::loc::{BaseTy, Ty}; +use crate::{ + index::{BlockIdx, InstIdx}, + loc::{BaseTy, Ty}, +}; use thiserror::Error; #[derive(Debug, Error)] @@ -20,6 +23,29 @@ pub enum Error { ty: Option, expected_ty: Option, }, + #[error("function doesn't have entry block")] + MissingEntryBlock, + #[error("instruction index is too big")] + InstIdxTooBig, + #[error("block has invalid start {start}, expected {expected_start}")] + BlockHasInvalidStart { + start: InstIdx, + expected_start: InstIdx, + }, + #[error("block {block} doesn't contain any instructions")] + BlockIsEmpty { block: BlockIdx }, + #[error("entry block must not have any block parameters")] + EntryBlockCantHaveParams, + #[error("entry block must not have any predecessors")] + EntryBlockCantHavePreds, + #[error("block end is out of range: {end}")] + BlockEndOutOfRange { end: InstIdx }, + #[error("block's last instruction must be a block terminator: {term_idx}")] + BlocksLastInstMustBeTerm { term_idx: InstIdx }, + #[error( + "block terminator instructions are only allowed as a block's last instruction: {inst_idx}" + )] + TermInstOnlyAllowedAtBlockEnd { inst_idx: InstIdx }, } pub type Result = std::result::Result; diff --git a/register_allocator/src/function.rs b/register_allocator/src/function.rs new file mode 100644 index 0000000..f9eacec --- /dev/null +++ b/register_allocator/src/function.rs @@ -0,0 +1,269 @@ +use crate::{ + error::{Error, Result}, + index::{BlockIdx, InstIdx, InstRange, SSAValIdx}, + interned::Interned, + loc::{Loc, Ty}, + loc_set::LocSet, +}; +use core::fmt; +use serde::{Deserialize, Serialize}; +use std::ops::Index; + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct SSAVal { + pub ty: Ty, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +#[repr(u8)] +pub enum InstStage { + Early = 0, + Late = 1, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")] +pub struct ProgPoint(usize); + +impl ProgPoint { + pub const fn new(inst: InstIdx, stage: InstStage) -> Self { + const_unwrap_res!(Self::try_new(inst, stage)) + } + pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result { + let Some(inst) = inst.get().checked_shl(1) else { + return Err(Error::InstIdxTooBig); + }; + Ok(Self(inst | stage as usize)) + } + pub const fn inst(self) -> InstIdx { + InstIdx::new(self.0 >> 1) + } + pub const fn stage(self) -> InstStage { + if self.0 & 1 != 0 { + InstStage::Late + } else { + InstStage::Early + } + } + pub const fn next(self) -> Self { + Self(self.0 + 1) + } + pub const fn prev(self) -> Self { + Self(self.0 - 1) + } +} + +impl fmt::Debug for ProgPoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProgPoint") + .field("inst", &self.inst()) + .field("stage", &self.stage()) + .finish() + } +} + +#[derive(Serialize, Deserialize)] +struct SerializedProgPoint { + inst: InstIdx, + stage: InstStage, +} + +impl From for SerializedProgPoint { + fn from(value: ProgPoint) -> Self { + Self { + inst: value.inst(), + stage: value.stage(), + } + } +} + +impl TryFrom for ProgPoint { + type Error = Error; + + fn try_from(value: SerializedProgPoint) -> Result { + ProgPoint::try_new(value.inst, value.stage) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +#[repr(u8)] +pub enum OperandKind { + Use = 0, + Def = 1, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)] +pub enum Constraint { + /// any register or stack location + Any, + /// r1-r32 + BaseGpr, + /// r2,r4,r6,r8,...r126 + SVExtra2VGpr, + /// r1-63 + SVExtra2SGpr, + /// r1-127 + SVExtra3Gpr, + /// any stack location + Stack, + FixedLoc(Loc), + Reuse(usize), +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct Operand { + pub ssa_val: SSAValIdx, + pub constraint: Constraint, + pub kind: OperandKind, + pub stage: InstStage, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct BranchSucc { + pub block: BlockIdx, + pub params: Vec, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub enum InstKind { + Normal, + /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`. + Copy { + srcs: Vec, + dests: Vec, + }, + Return, + Branch { + succs: Vec, + }, +} + +impl InstKind { + pub fn is_normal(&self) -> bool { + matches!(self, Self::Normal) + } + pub fn is_block_term(&self) -> bool { + matches!(self, Self::Return | Self::Branch { .. }) + } + pub fn succs(&self) -> Option<&[BranchSucc]> { + match self { + InstKind::Normal | InstKind::Copy { .. } => None, + InstKind::Return => Some(&[]), + InstKind::Branch { succs } => Some(succs), + } + } +} + +impl Default for InstKind { + fn default() -> Self { + InstKind::Normal + } +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct Inst { + #[serde(default, skip_serializing_if = "InstKind::is_normal")] + pub kind: InstKind, + pub operands: Vec, + pub clobbers: Interned, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub struct Block { + pub params: Vec, + pub insts: InstRange, + pub preds: Vec, +} + +validated_fields! { + #[fields_ty = FnFields] + #[derive(Clone, PartialEq, Eq, Debug, Hash)] + pub struct Function { + pub ssa_vals: Vec, + pub insts: Vec, + pub blocks: Vec, + } +} + +impl Function { + pub fn new(fields: FnFields) -> Result { + let FnFields { + ssa_vals, + insts: insts_vec, + blocks, + } = &fields; + let entry_block = blocks + .get(BlockIdx::ENTRY_BLOCK.get()) + .ok_or(Error::MissingEntryBlock)?; + if !entry_block.params.is_empty() { + return Err(Error::EntryBlockCantHaveParams); + } + if !entry_block.preds.is_empty() { + return Err(Error::EntryBlockCantHavePreds); + } + let mut expected_start = InstIdx::new(0); + for (block_idx, block) in fields.blocks.iter().enumerate() { + let block_idx = BlockIdx::new(block_idx); + let Block { + params, + insts: inst_range, + preds, + } = block; + if inst_range.start != expected_start { + return Err(Error::BlockHasInvalidStart { + start: inst_range.start, + expected_start, + }); + } + let Some((term_idx, non_term_inst_range)) = inst_range.split_last() else { + return Err(Error::BlockIsEmpty { block: block_idx }); + }; + expected_start = inst_range.end; + let Some(Inst { kind: term_kind, .. }) = insts_vec.get(term_idx.get()) else { + return Err(Error::BlockEndOutOfRange { end: inst_range.end }); + }; + if !term_kind.is_block_term() { + return Err(Error::BlocksLastInstMustBeTerm { term_idx }); + } + for inst_idx in non_term_inst_range { + if insts_vec[inst_idx].kind.is_block_term() { + return Err(Error::TermInstOnlyAllowedAtBlockEnd { inst_idx }); + } + } + } + todo!() + } + pub fn entry_block(&self) -> &Block { + &self.blocks[0] + } + pub fn block_succs(&self, block: BlockIdx) -> &[BranchSucc] { + self.insts[self.blocks[block].insts.last().unwrap()] + .kind + .succs() + .unwrap() + } +} + +impl Index for Vec { + type Output = SSAVal; + + fn index(&self, index: SSAValIdx) -> &Self::Output { + &self[index.get()] + } +} + +impl Index for Vec { + type Output = Inst; + + fn index(&self, index: InstIdx) -> &Self::Output { + &self[index.get()] + } +} + +impl Index for Vec { + type Output = Block; + + fn index(&self, index: BlockIdx) -> &Self::Output { + &self[index.get()] + } +} diff --git a/register_allocator/src/index.rs b/register_allocator/src/index.rs new file mode 100644 index 0000000..d1423f3 --- /dev/null +++ b/register_allocator/src/index.rs @@ -0,0 +1,189 @@ +use serde::{Deserialize, Serialize}; +use std::{fmt, iter::FusedIterator, ops::Range}; + +macro_rules! define_index { + ($name:ident) => { + #[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, + )] + #[serde(transparent)] + pub struct $name { + value: usize, + } + + impl $name { + pub const fn new(value: usize) -> Self { + Self { value } + } + pub const fn get(self) -> usize { + self.value + } + pub const fn next(self) -> Self { + Self { + value: self.value + 1, + } + } + pub const fn prev(self) -> Self { + Self { + value: self.value - 1, + } + } + } + }; +} + +define_index!(SSAValIdx); +define_index!(InstIdx); +define_index!(BlockIdx); + +impl fmt::Display for SSAValIdx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}", self.get()) + } +} + +impl fmt::Display for InstIdx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "inst{}", self.get()) + } +} + +impl fmt::Display for BlockIdx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "blk{}", self.get()) + } +} + +impl BlockIdx { + pub const ENTRY_BLOCK: BlockIdx = BlockIdx::new(0); +} + +/// range of instruction indexes from `start` inclusive to `end` exclusive. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct InstRange { + pub start: InstIdx, + pub end: InstIdx, +} + +impl fmt::Display for InstRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "inst{}..{}", self.start.get(), self.end.get()) + } +} + +impl IntoIterator for InstRange { + type Item = InstIdx; + type IntoIter = InstRangeIter; + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl InstRange { + pub const fn iter(self) -> InstRangeIter { + InstRangeIter(self.start.get()..self.end.get()) + } + pub const fn is_empty(self) -> bool { + self.start.get() >= self.end.get() + } + pub const fn first(self) -> Option { + Some(const_try_opt!(self.split_first()).0) + } + pub const fn last(self) -> Option { + Some(const_try_opt!(self.split_last()).0) + } + pub const fn split_first(self) -> Option<(InstIdx, InstRange)> { + if self.is_empty() { + None + } else { + Some(( + self.start, + Self { + start: self.start.next(), + end: self.end, + }, + )) + } + } + pub const fn split_last(self) -> Option<(InstIdx, InstRange)> { + if self.is_empty() { + None + } else { + Some(( + self.end.prev(), + Self { + start: self.start, + end: self.end.prev(), + }, + )) + } + } + pub const fn len(self) -> usize { + if self.is_empty() { + 0 + } else { + self.end.get() - self.start.get() + } + } +} + +#[derive(Clone, Debug)] +pub struct InstRangeIter(Range); + +impl InstRangeIter { + pub const fn range(self) -> InstRange { + InstRange { + start: InstIdx::new(self.0.start), + end: InstIdx::new(self.0.end), + } + } +} + +impl Iterator for InstRangeIter { + type Item = InstIdx; + + fn next(&mut self) -> Option { + self.0.next().map(InstIdx::new) + } + + fn size_hint(&self) -> (usize, Option) { + let v = self.0.len(); + (v, Some(v)) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } + + fn last(self) -> Option + where + Self: Sized, + { + self.0.last().map(InstIdx::new) + } + + fn nth(&mut self, n: usize) -> Option { + self.0.nth(n).map(InstIdx::new) + } +} + +impl FusedIterator for InstRangeIter {} + +impl DoubleEndedIterator for InstRangeIter { + fn next_back(&mut self) -> Option { + self.0.next_back().map(InstIdx::new) + } + + fn nth_back(&mut self, n: usize) -> Option { + self.0.nth_back(n).map(InstIdx::new) + } +} + +impl ExactSizeIterator for InstRangeIter { + fn len(&self) -> usize { + self.0.len() + } +} diff --git a/register_allocator/src/interned.rs b/register_allocator/src/interned.rs index f99839f..95d49a0 100644 --- a/register_allocator/src/interned.rs +++ b/register_allocator/src/interned.rs @@ -1,5 +1,12 @@ -use hashbrown::{hash_map::RawEntryMut, HashMap}; -use serde::Serialize; +use crate::{ + loc::Loc, + loc_set::{LocSet, LocSetMaxConflictsWith}, +}; +use hashbrown::{ + hash_map::{Entry, RawEntryMut}, + HashMap, +}; +use serde::{de, Deserialize, Serialize}; use std::{ cell::RefCell, cmp::Ordering, @@ -9,12 +16,6 @@ use std::{ rc::Rc, }; -use crate::{ - loc::Loc, - loc_set::{LocSet, LocSetMaxConflictsWith}, -}; - -#[derive(Clone)] pub struct Interned { ptr: Rc, } @@ -27,6 +28,14 @@ impl Deref for Interned { } } +impl Clone for Interned { + fn clone(&self) -> Self { + Self { + ptr: self.ptr.clone(), + } + } +} + impl Hash for Interned { fn hash(&self, state: &mut H) { Rc::as_ptr(&self.ptr).hash(state); @@ -47,12 +56,99 @@ impl fmt::Display for Interned { } } -impl Serialize for Interned { +pub struct SerdeState { + global_state: Rc, + inner: SerdeStateInner, +} + +scoped_tls::scoped_thread_local!(static SERDE_STATE: SerdeState); + +impl SerdeState { + pub fn global_state(&self) -> &Rc { + &self.global_state + } + #[cold] + pub fn scope(global_state: &Rc, f: impl FnOnce() -> R) -> R { + SERDE_STATE.set( + &SerdeState { + global_state: global_state.clone(), + inner: SerdeStateInner::default(), + }, + f, + ) + } + pub fn get_or_scope(f: impl for<'a> FnOnce(&'a SerdeState) -> R) -> R { + if SERDE_STATE.is_set() { + SERDE_STATE.with(f) + } else { + GlobalState::get(|global_state| Self::scope(global_state, || SERDE_STATE.with(f))) + } + } +} + +pub struct SerdeStateFor { + de: RefCell>>, + ser: RefCell, usize>>, +} + +impl Default for SerdeStateFor { + fn default() -> Self { + Self { + de: Default::default(), + ser: Default::default(), + } + } +} + +#[derive(Serialize, Deserialize)] +enum SerializedInterned { + Old(usize), + New(T), +} + +impl Serialize for Interned { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { - self.ptr.serialize(serializer) + SerdeState::get_or_scope(|serde_state| { + let mut state = T::get_serde_state_for(serde_state).ser.borrow_mut(); + let next_index = state.len(); + match state.entry(self.clone()) { + Entry::Occupied(entry) => SerializedInterned::Old(*entry.get()), + Entry::Vacant(entry) => { + entry.insert(next_index); + SerializedInterned::<&T>::New(self) + } + } + .serialize(serializer) + }) + } +} + +impl<'de, T, Owned> Deserialize<'de> for Interned +where + T: ?Sized + InternTarget + ToOwned, + Owned: Deserialize<'de> + Intern, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + SerdeState::get_or_scope(|serde_state| { + let mut state = T::get_serde_state_for(serde_state).de.borrow_mut(); + match SerializedInterned::::deserialize(deserializer)? { + SerializedInterned::Old(index) => state + .get(index) + .cloned() + .ok_or_else(|| ::custom("index out of range")), + SerializedInterned::New(value) => { + let retval = value.into_interned(&serde_state.global_state); + state.push(retval.clone()); + Ok(retval) + } + } + }) } } @@ -74,30 +170,64 @@ impl Ord for Interned { } } -#[derive(Default)] -struct Interners { - str: Interner, - loc_set: Interner, - loc_set_max_conflicts_with_loc_set: Interner>>, - loc_set_max_conflicts_with_loc: Interner>, +macro_rules! make_interners { + { + $( + $(#[$impl_intern_target:ident intern_target])? + $name:ident: $ty:ty, + )* + } => { + #[derive(Default)] + struct Interners { + $($name: Interner<$ty>,)* + } + + #[derive(Default)] + struct SerdeStateInner { + $($name: SerdeStateFor<$ty>,)* + } + + $($( + $impl_intern_target InternTarget for $ty { + fn get_interner(global_state: &GlobalState) -> &Interner { + &global_state.interners.$name + } + + fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor { + &serde_state.inner.$name + } + } + )?)* + }; +} + +make_interners! { + str: str, + #[impl intern_target] + loc_set: LocSet, + #[impl intern_target] + loc_set_max_conflicts_with_loc_set: LocSetMaxConflictsWith>, + #[impl intern_target] + loc_set_max_conflicts_with_loc: LocSetMaxConflictsWith, } pub struct GlobalState { interners: Interners, } -scoped_tls::scoped_thread_local!(static GLOBAL_STATE: GlobalState); +scoped_tls::scoped_thread_local!(static GLOBAL_STATE: Rc); impl GlobalState { + #[cold] pub fn scope(f: impl FnOnce() -> R) -> R { GLOBAL_STATE.set( - &GlobalState { + &Rc::new(GlobalState { interners: Interners::default(), - }, + }), f, ) } - pub fn get(f: impl for<'a> FnOnce(&'a GlobalState) -> R) -> R { + pub fn get(f: impl for<'a> FnOnce(&'a Rc) -> R) -> R { GLOBAL_STATE.with(f) } } @@ -144,6 +274,7 @@ impl Interner { pub trait InternTarget: Intern + Hash + Eq { fn get_interner(global_state: &GlobalState) -> &Interner; + fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor; fn into_interned(input: Self, global_state: &GlobalState) -> Interned where Self: Sized, @@ -174,6 +305,9 @@ impl InternTarget for str { fn get_interner(global_state: &GlobalState) -> &Interner { &global_state.interners.str } + fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor { + &serde_state.inner.str + } } impl Intern for str { @@ -323,21 +457,3 @@ impl Intern for T { }) } } - -impl InternTarget for LocSet { - fn get_interner(global_state: &GlobalState) -> &Interner { - &global_state.interners.loc_set - } -} - -impl InternTarget for LocSetMaxConflictsWith> { - fn get_interner(global_state: &GlobalState) -> &Interner { - &global_state.interners.loc_set_max_conflicts_with_loc_set - } -} - -impl InternTarget for LocSetMaxConflictsWith { - fn get_interner(global_state: &GlobalState) -> &Interner { - &global_state.interners.loc_set_max_conflicts_with_loc - } -} diff --git a/register_allocator/src/lib.rs b/register_allocator/src/lib.rs index d1b478f..4a96fe7 100644 --- a/register_allocator/src/lib.rs +++ b/register_allocator/src/lib.rs @@ -1,6 +1,8 @@ #[macro_use] mod macros; pub mod error; +pub mod function; +pub mod index; pub mod interned; pub mod loc; pub mod loc_set; diff --git a/register_allocator/src/macros.rs b/register_allocator/src/macros.rs index b7a2510..3f7a089 100644 --- a/register_allocator/src/macros.rs +++ b/register_allocator/src/macros.rs @@ -94,6 +94,15 @@ macro_rules! const_try { }; } +macro_rules! const_try_opt { + ($v:expr $(,)?) => { + match $v { + Some(v) => v, + None => return None, + } + }; +} + macro_rules! nzu32_lit { ($v:literal) => {{ const V: ::std::num::NonZeroU32 = match ::std::num::NonZeroU32::new($v) {