-use crate::loc::{BaseTy, Ty};
+use crate::{
+ index::{BlockIdx, InstIdx},
+ loc::{BaseTy, Ty},
+};
use thiserror::Error;
#[derive(Debug, Error)]
ty: Option<Ty>,
expected_ty: Option<Ty>,
},
+ #[error("function doesn't have entry block")]
+ MissingEntryBlock,
+ #[error("instruction index is too big")]
+ InstIdxTooBig,
+ #[error("block has invalid start {start}, expected {expected_start}")]
+ BlockHasInvalidStart {
+ start: InstIdx,
+ expected_start: InstIdx,
+ },
+ #[error("block {block} doesn't contain any instructions")]
+ BlockIsEmpty { block: BlockIdx },
+ #[error("entry block must not have any block parameters")]
+ EntryBlockCantHaveParams,
+ #[error("entry block must not have any predecessors")]
+ EntryBlockCantHavePreds,
+ #[error("block end is out of range: {end}")]
+ BlockEndOutOfRange { end: InstIdx },
+ #[error("block's last instruction must be a block terminator: {term_idx}")]
+ BlocksLastInstMustBeTerm { term_idx: InstIdx },
+ #[error(
+ "block terminator instructions are only allowed as a block's last instruction: {inst_idx}"
+ )]
+ TermInstOnlyAllowedAtBlockEnd { inst_idx: InstIdx },
}
pub type Result<T, E = Error> = std::result::Result<T, E>;
--- /dev/null
+use crate::{
+ error::{Error, Result},
+ index::{BlockIdx, InstIdx, InstRange, SSAValIdx},
+ interned::Interned,
+ loc::{Loc, Ty},
+ loc_set::LocSet,
+};
+use core::fmt;
+use serde::{Deserialize, Serialize};
+use std::ops::Index;
+
+/// Register-allocator-level description of an SSA value; currently just
+/// its type. Indexed by `SSAValIdx` (see the `Index` impl in this file).
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub struct SSAVal {
+ pub ty: Ty,
+}
+
+/// Which half of an instruction a `ProgPoint` refers to. Stored as the
+/// low bit of the packed `ProgPoint`, so `Early < Late` within one
+/// instruction. NOTE(review): presumably `Early` covers operand reads
+/// and `Late` result writes -- confirm against `Operand::stage` users.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
+#[repr(u8)]
+pub enum InstStage {
+ Early = 0,
+ Late = 1,
+}
+
+/// A point in the program: an [`InstIdx`] plus an [`InstStage`], packed
+/// into one `usize` as `inst << 1 | stage` so that points order
+/// correctly and `next`/`prev` step through stages.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+#[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
+pub struct ProgPoint(usize);
+
+impl ProgPoint {
+ /// Like [`Self::try_new`], but panics if `inst` is too big to pack.
+ pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
+ const_unwrap_res!(Self::try_new(inst, stage))
+ }
+ /// Packs `inst` and `stage`, returning `Err(Error::InstIdxTooBig)`
+ /// when the index doesn't fit in the bits left after the stage bit.
+ pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
+ // FIX: was `checked_shl(1)`, which only returns `None` when the
+ // shift *count* >= usize::BITS -- never for a count of 1 -- so the
+ // overflow guard was dead and an oversized index silently lost its
+ // top bit. `checked_mul(2)` detects the actual value overflow and
+ // is identical to `<< 1` for every non-overflowing input.
+ let Some(inst) = inst.get().checked_mul(2) else {
+ return Err(Error::InstIdxTooBig);
+ };
+ Ok(Self(inst | stage as usize))
+ }
+ /// The instruction index (drops the stage bit).
+ pub const fn inst(self) -> InstIdx {
+ InstIdx::new(self.0 >> 1)
+ }
+ /// The stage, stored in the low bit.
+ pub const fn stage(self) -> InstStage {
+ if self.0 & 1 != 0 {
+ InstStage::Late
+ } else {
+ InstStage::Early
+ }
+ }
+ /// The following point: `Early -> Late -> next instruction's Early`.
+ pub const fn next(self) -> Self {
+ Self(self.0 + 1)
+ }
+ /// The preceding point; underflows (panics in debug/const) at the
+ /// very first point.
+ pub const fn prev(self) -> Self {
+ Self(self.0 - 1)
+ }
+}
+
+// Debug shows the decoded `inst`/`stage` pair instead of the packed word.
+impl fmt::Debug for ProgPoint {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ProgPoint")
+ .field("inst", &self.inst())
+ .field("stage", &self.stage())
+ .finish()
+ }
+}
+
+/// Serde proxy for [`ProgPoint`]: serialized as explicit fields and
+/// re-validated through `ProgPoint::try_new` on deserialization, wired
+/// up by the `#[serde(try_from, into)]` attributes on `ProgPoint`.
+#[derive(Serialize, Deserialize)]
+struct SerializedProgPoint {
+ inst: InstIdx,
+ stage: InstStage,
+}
+
+impl From<ProgPoint> for SerializedProgPoint {
+ fn from(value: ProgPoint) -> Self {
+ Self {
+ inst: value.inst(),
+ stage: value.stage(),
+ }
+ }
+}
+
+// Fallible direction: an on-disk `inst` may be too big to pack.
+impl TryFrom<SerializedProgPoint> for ProgPoint {
+ type Error = Error;
+
+ fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
+ ProgPoint::try_new(value.inst, value.stage)
+ }
+}
+
+/// Whether an operand reads (`Use`) or writes (`Def`) its SSA value.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
+#[repr(u8)]
+pub enum OperandKind {
+ Use = 0,
+ Def = 1,
+}
+
+/// Where the register allocator is allowed to place an operand.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
+pub enum Constraint {
+ /// any register or stack location
+ Any,
+ /// r1-r32
+ BaseGpr,
+ /// r2,r4,r6,r8,...r126
+ SVExtra2VGpr,
+ /// r1-63
+ SVExtra2SGpr,
+ /// r1-127
+ SVExtra3Gpr,
+ /// any stack location
+ Stack,
+ /// exactly this location
+ FixedLoc(Loc),
+ /// NOTE(review): presumably reuses the location of the operand at
+ /// this index within the same instruction -- confirm against the
+ /// allocator's handling of `Reuse`.
+ Reuse(usize),
+}
+
+/// One operand of an instruction: which SSA value it touches, where it
+/// may live, whether it is a use or a def, and at which stage of the
+/// instruction it takes effect.
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub struct Operand {
+ pub ssa_val: SSAValIdx,
+ pub constraint: Constraint,
+ pub kind: OperandKind,
+ pub stage: InstStage,
+}
+
+/// One successor edge of a branch: the target block plus the SSA values
+/// passed for that block's parameters.
+#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub struct BranchSucc {
+ pub block: BlockIdx,
+ pub params: Vec<SSAValIdx>,
+}
+
+/// The kinds of instruction the allocator distinguishes.
+#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub enum InstKind {
+ Normal,
+ /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`.
+ Copy {
+ srcs: Vec<Operand>,
+ dests: Vec<Operand>,
+ },
+ Return,
+ Branch {
+ succs: Vec<BranchSucc>,
+ },
+}
+
+impl InstKind {
+ /// `true` for `Normal`; also used by serde to skip the default kind.
+ pub fn is_normal(&self) -> bool {
+ matches!(self, Self::Normal)
+ }
+ /// `true` if this kind may only appear as a block's last instruction
+ /// (enforced in `Function::new`).
+ pub fn is_block_term(&self) -> bool {
+ matches!(self, Self::Return | Self::Branch { .. })
+ }
+ /// Successor edges: `None` for non-terminators, an empty slice for
+ /// `Return`, the targets for `Branch`.
+ pub fn succs(&self) -> Option<&[BranchSucc]> {
+ match self {
+ InstKind::Normal | InstKind::Copy { .. } => None,
+ InstKind::Return => Some(&[]),
+ InstKind::Branch { succs } => Some(succs),
+ }
+ }
+}
+
+// Default is `Normal`, matching `#[serde(default, ...)]` on `Inst::kind`.
+impl Default for InstKind {
+ fn default() -> Self {
+ InstKind::Normal
+ }
+}
+
+/// One instruction.
+#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub struct Inst {
+ // `Normal` is the default kind, so it is omitted from serialized form.
+ #[serde(default, skip_serializing_if = "InstKind::is_normal")]
+ pub kind: InstKind,
+ pub operands: Vec<Operand>,
+ // interned set of `Loc`s this instruction clobbers
+ pub clobbers: Interned<LocSet>,
+}
+
+/// A basic block: parameter list, the contiguous range of instructions
+/// it owns, and its predecessor blocks.
+#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
+pub struct Block {
+ // SSA values defined as this block's parameters; must be empty for
+ // the entry block (checked in `Function::new`)
+ pub params: Vec<SSAValIdx>,
+ // half-open range into `Function::insts`; blocks tile that list
+ // contiguously (checked in `Function::new`)
+ pub insts: InstRange,
+ pub preds: Vec<BlockIdx>,
+}
+
+// NOTE(review): `validated_fields!` is defined elsewhere; it appears to
+// generate `Function` plus a raw `FnFields` struct that is turned into a
+// `Function` via the validating `Function::new` below -- confirm against
+// the macro definition.
+validated_fields! {
+ #[fields_ty = FnFields]
+ #[derive(Clone, PartialEq, Eq, Debug, Hash)]
+ pub struct Function {
+ pub ssa_vals: Vec<SSAVal>,
+ pub insts: Vec<Inst>,
+ pub blocks: Vec<Block>,
+ }
+}
+
+impl Function {
+ /// Validates `fields` and assembles a `Function`.
+ ///
+ /// Checks performed so far:
+ /// * the entry block exists and has no params and no preds;
+ /// * the blocks' `insts` ranges tile the instruction list
+ /// contiguously starting at index 0;
+ /// * every block is non-empty, its last instruction is a block
+ /// terminator, and no earlier instruction is one.
+ ///
+ /// Remaining validation (e.g. of `ssa_vals`) is still `todo!()`.
+ pub fn new(fields: FnFields) -> Result<Self> {
+ let FnFields {
+ ssa_vals,
+ insts: insts_vec,
+ blocks,
+ } = &fields;
+ let entry_block = blocks
+ .get(BlockIdx::ENTRY_BLOCK.get())
+ .ok_or(Error::MissingEntryBlock)?;
+ if !entry_block.params.is_empty() {
+ return Err(Error::EntryBlockCantHaveParams);
+ }
+ if !entry_block.preds.is_empty() {
+ return Err(Error::EntryBlockCantHavePreds);
+ }
+ // each block must start exactly where the previous one ended
+ let mut expected_start = InstIdx::new(0);
+ for (block_idx, block) in fields.blocks.iter().enumerate() {
+ let block_idx = BlockIdx::new(block_idx);
+ let Block {
+ params,
+ insts: inst_range,
+ preds,
+ } = block;
+ if inst_range.start != expected_start {
+ return Err(Error::BlockHasInvalidStart {
+ start: inst_range.start,
+ expected_start,
+ });
+ }
+ // split off the terminator; `None` means the block is empty
+ let Some((term_idx, non_term_inst_range)) = inst_range.split_last() else {
+ return Err(Error::BlockIsEmpty { block: block_idx });
+ };
+ expected_start = inst_range.end;
+ // looking up `term_idx` also bounds-checks the whole range,
+ // since it is the range's largest index
+ let Some(Inst { kind: term_kind, .. }) = insts_vec.get(term_idx.get()) else {
+ return Err(Error::BlockEndOutOfRange { end: inst_range.end });
+ };
+ if !term_kind.is_block_term() {
+ return Err(Error::BlocksLastInstMustBeTerm { term_idx });
+ }
+ // terminators may only appear in the last position
+ for inst_idx in non_term_inst_range {
+ if insts_vec[inst_idx].kind.is_block_term() {
+ return Err(Error::TermInstOnlyAllowedAtBlockEnd { inst_idx });
+ }
+ }
+ }
+ todo!()
+ }
+ /// The entry block; always present once `new` has succeeded.
+ pub fn entry_block(&self) -> &Block {
+ &self.blocks[0]
+ }
+ /// Successor edges of `block`. Both `unwrap`s are justified by the
+ /// checks in `new`: every block is non-empty and ends in a
+ /// terminator, and `succs()` is `Some` for every terminator kind.
+ pub fn block_succs(&self, block: BlockIdx) -> &[BranchSucc] {
+ self.insts[self.blocks[block].insts.last().unwrap()]
+ .kind
+ .succs()
+ .unwrap()
+ }
+}
+
+// Allow indexing the plain `Vec`s stored in `Function` directly with the
+// strongly-typed index newtypes (the local index type in the trait's
+// parameter satisfies the orphan rule).
+impl Index<SSAValIdx> for Vec<SSAVal> {
+ type Output = SSAVal;
+
+ fn index(&self, index: SSAValIdx) -> &Self::Output {
+ &self[index.get()]
+ }
+}
+
+impl Index<InstIdx> for Vec<Inst> {
+ type Output = Inst;
+
+ fn index(&self, index: InstIdx) -> &Self::Output {
+ &self[index.get()]
+ }
+}
+
+impl Index<BlockIdx> for Vec<Block> {
+ type Output = Block;
+
+ fn index(&self, index: BlockIdx) -> &Self::Output {
+ &self[index.get()]
+ }
+}
--- /dev/null
+use serde::{Deserialize, Serialize};
+use std::{fmt, iter::FusedIterator, ops::Range};
+
+/// Defines a strongly-typed wrapper around a `usize` index with
+/// `new`/`get` accessors and unchecked `next`/`prev` steppers (callers
+/// must keep results in range; `prev` underflows at 0).
+macro_rules! define_index {
+ ($name:ident) => {
+ #[derive(
+ Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
+ )]
+ #[serde(transparent)]
+ pub struct $name {
+ value: usize,
+ }
+
+ impl $name {
+ pub const fn new(value: usize) -> Self {
+ Self { value }
+ }
+ pub const fn get(self) -> usize {
+ self.value
+ }
+ pub const fn next(self) -> Self {
+ Self {
+ value: self.value + 1,
+ }
+ }
+ pub const fn prev(self) -> Self {
+ Self {
+ value: self.value - 1,
+ }
+ }
+ }
+ };
+}
+
+define_index!(SSAValIdx);
+define_index!(InstIdx);
+define_index!(BlockIdx);
+
+// Compact human-readable forms: `v{n}`, `inst{n}`, `blk{n}`.
+impl fmt::Display for SSAValIdx {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "v{}", self.get())
+ }
+}
+
+impl fmt::Display for InstIdx {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "inst{}", self.get())
+ }
+}
+
+impl fmt::Display for BlockIdx {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "blk{}", self.get())
+ }
+}
+
+impl BlockIdx {
+ /// The entry block is always block 0 (enforced by `Function::new`).
+ pub const ENTRY_BLOCK: BlockIdx = BlockIdx::new(0);
+}
+
+/// range of instruction indexes from `start` inclusive to `end` exclusive.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub struct InstRange {
+ pub start: InstIdx,
+ pub end: InstIdx,
+}
+
+impl fmt::Display for InstRange {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "inst{}..{}", self.start.get(), self.end.get())
+ }
+}
+
+impl IntoIterator for InstRange {
+ type Item = InstIdx;
+ type IntoIter = InstRangeIter;
+ fn into_iter(self) -> Self::IntoIter {
+ self.iter()
+ }
+}
+
+impl InstRange {
+ pub const fn iter(self) -> InstRangeIter {
+ InstRangeIter(self.start.get()..self.end.get())
+ }
+ // `>=` (not `==`) so a malformed range with `start > end` counts as
+ // empty rather than producing a bogus length below
+ pub const fn is_empty(self) -> bool {
+ self.start.get() >= self.end.get()
+ }
+ pub const fn first(self) -> Option<InstIdx> {
+ Some(const_try_opt!(self.split_first()).0)
+ }
+ pub const fn last(self) -> Option<InstIdx> {
+ Some(const_try_opt!(self.split_last()).0)
+ }
+ /// The first index plus the rest of the range, or `None` if empty.
+ pub const fn split_first(self) -> Option<(InstIdx, InstRange)> {
+ if self.is_empty() {
+ None
+ } else {
+ Some((
+ self.start,
+ Self {
+ start: self.start.next(),
+ end: self.end,
+ },
+ ))
+ }
+ }
+ /// The last index plus the rest of the range, or `None` if empty.
+ pub const fn split_last(self) -> Option<(InstIdx, InstRange)> {
+ if self.is_empty() {
+ None
+ } else {
+ Some((
+ self.end.prev(),
+ Self {
+ start: self.start,
+ end: self.end.prev(),
+ },
+ ))
+ }
+ }
+ // guarded subtraction: `is_empty` covers `start > end`, where a bare
+ // `end - start` would underflow
+ pub const fn len(self) -> usize {
+ if self.is_empty() {
+ 0
+ } else {
+ self.end.get() - self.start.get()
+ }
+ }
+}
+
+/// Iterator over an [`InstRange`]'s indexes; a thin wrapper around
+/// `Range<usize>` that forwards the fast paths (`nth`, `count`, `last`,
+/// double-ended stepping, exact `len`) to the underlying range.
+#[derive(Clone, Debug)]
+pub struct InstRangeIter(Range<usize>);
+
+impl InstRangeIter {
+ /// The not-yet-yielded part of the iteration, as an `InstRange`.
+ pub const fn range(self) -> InstRange {
+ InstRange {
+ start: InstIdx::new(self.0.start),
+ end: InstIdx::new(self.0.end),
+ }
+ }
+}
+
+impl Iterator for InstRangeIter {
+ type Item = InstIdx;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.0.next().map(InstIdx::new)
+ }
+
+ // exact: `Range<usize>` is `ExactSizeIterator`
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ let v = self.0.len();
+ (v, Some(v))
+ }
+
+ fn count(self) -> usize
+ where
+ Self: Sized,
+ {
+ self.len()
+ }
+
+ fn last(self) -> Option<Self::Item>
+ where
+ Self: Sized,
+ {
+ self.0.last().map(InstIdx::new)
+ }
+
+ fn nth(&mut self, n: usize) -> Option<Self::Item> {
+ self.0.nth(n).map(InstIdx::new)
+ }
+}
+
+impl FusedIterator for InstRangeIter {}
+
+impl DoubleEndedIterator for InstRangeIter {
+ fn next_back(&mut self) -> Option<Self::Item> {
+ self.0.next_back().map(InstIdx::new)
+ }
+
+ fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
+ self.0.nth_back(n).map(InstIdx::new)
+ }
+}
+
+impl ExactSizeIterator for InstRangeIter {
+ fn len(&self) -> usize {
+ self.0.len()
+ }
+}
-use hashbrown::{hash_map::RawEntryMut, HashMap};
-use serde::Serialize;
+use crate::{
+ loc::Loc,
+ loc_set::{LocSet, LocSetMaxConflictsWith},
+};
+use hashbrown::{
+ hash_map::{Entry, RawEntryMut},
+ HashMap,
+};
+use serde::{de, Deserialize, Serialize};
use std::{
cell::RefCell,
cmp::Ordering,
rc::Rc,
};
-use crate::{
- loc::Loc,
- loc_set::{LocSet, LocSetMaxConflictsWith},
-};
-
-#[derive(Clone)]
pub struct Interned<T: ?Sized> {
ptr: Rc<T>,
}
}
}
+// Manual `Clone` (replacing the removed `#[derive(Clone)]`) so that
+// `Interned<T>: Clone` holds even when `T` itself isn't `Clone` --
+// cloning is only an `Rc` refcount bump.
+impl<T: ?Sized> Clone for Interned<T> {
+ fn clone(&self) -> Self {
+ Self {
+ ptr: self.ptr.clone(),
+ }
+ }
+}
+
impl<T: ?Sized> Hash for Interned<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
Rc::as_ptr(&self.ptr).hash(state);
}
}
-impl<T: ?Sized + Serialize> Serialize for Interned<T> {
+/// Per-(de)serialization scope: the sharing tables for each interned
+/// type plus a handle on the `GlobalState` to intern into.
+pub struct SerdeState {
+ global_state: Rc<GlobalState>,
+ inner: SerdeStateInner,
+}
+
+// Thread-local scope so `Serialize`/`Deserialize` impls (which can't
+// take extra arguments) can reach the current state.
+scoped_tls::scoped_thread_local!(static SERDE_STATE: SerdeState);
+
+impl SerdeState {
+ pub fn global_state(&self) -> &Rc<GlobalState> {
+ &self.global_state
+ }
+ /// Runs `f` with a fresh `SerdeState` installed for this thread.
+ #[cold]
+ pub fn scope<R>(global_state: &Rc<GlobalState>, f: impl FnOnce() -> R) -> R {
+ SERDE_STATE.set(
+ &SerdeState {
+ global_state: global_state.clone(),
+ inner: SerdeStateInner::default(),
+ },
+ f,
+ )
+ }
+ /// Uses the already-installed state if any, otherwise installs a
+ /// fresh one backed by the current `GlobalState`.
+ pub fn get_or_scope<R>(f: impl for<'a> FnOnce(&'a SerdeState) -> R) -> R {
+ if SERDE_STATE.is_set() {
+ SERDE_STATE.with(f)
+ } else {
+ GlobalState::get(|global_state| Self::scope(global_state, || SERDE_STATE.with(f)))
+ }
+ }
+}
+
+/// Sharing tables for one interned type: `de` maps serialized index ->
+/// interned value (deserialization), `ser` maps interned value -> index
+/// already emitted (serialization).
+pub struct SerdeStateFor<T: ?Sized> {
+ de: RefCell<Vec<Interned<T>>>,
+ ser: RefCell<HashMap<Interned<T>, usize>>,
+}
+
+// Manual impl: `#[derive(Default)]` would add a spurious `T: Default`
+// bound, which the fields don't need.
+impl<T: ?Sized> Default for SerdeStateFor<T> {
+ fn default() -> Self {
+ Self {
+ de: Default::default(),
+ ser: Default::default(),
+ }
+ }
+}
+
+/// Wire format for an interned value: `New` on first occurrence,
+/// `Old(index)` back-reference thereafter.
+#[derive(Serialize, Deserialize)]
+enum SerializedInterned<T> {
+ Old(usize),
+ New(T),
+}
+
+// Serialize with sharing preserved: the first occurrence of a value is
+// written out as `New` and assigned the next index; later occurrences
+// become `Old(index)` back-references.
+impl<T: ?Sized + Serialize + InternTarget> Serialize for Interned<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
- self.ptr.serialize(serializer)
+ SerdeState::get_or_scope(|serde_state| {
+ let mut state = T::get_serde_state_for(serde_state).ser.borrow_mut();
+ // computed before `entry` takes the mutable borrow of the map
+ let next_index = state.len();
+ match state.entry(self.clone()) {
+ Entry::Occupied(entry) => SerializedInterned::Old(*entry.get()),
+ Entry::Vacant(entry) => {
+ entry.insert(next_index);
+ SerializedInterned::<&T>::New(self)
+ }
+ }
+ .serialize(serializer)
+ })
+ }
+}
+
+// Mirror of the `Serialize` impl: each `New` value is re-interned in the
+// current `GlobalState` and pushed onto the table so later `Old(index)`
+// back-references resolve to the same `Interned`.
+impl<'de, T, Owned> Deserialize<'de> for Interned<T>
+where
+ T: ?Sized + InternTarget + ToOwned<Owned = Owned>,
+ Owned: Deserialize<'de> + Intern<Target = T>,
+{
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ SerdeState::get_or_scope(|serde_state| {
+ let mut state = T::get_serde_state_for(serde_state).de.borrow_mut();
+ match SerializedInterned::<Owned>::deserialize(deserializer)? {
+ SerializedInterned::Old(index) => state
+ .get(index)
+ .cloned()
+ .ok_or_else(|| <D::Error as de::Error>::custom("index out of range")),
+ SerializedInterned::New(value) => {
+ let retval = value.into_interned(&serde_state.global_state);
+ state.push(retval.clone());
+ Ok(retval)
+ }
+ }
+ })
}
}
}
}
-#[derive(Default)]
-struct Interners {
- str: Interner<str>,
- loc_set: Interner<LocSet>,
- loc_set_max_conflicts_with_loc_set: Interner<LocSetMaxConflictsWith<Interned<LocSet>>>,
- loc_set_max_conflicts_with_loc: Interner<LocSetMaxConflictsWith<Loc>>,
+/// Generates the `Interners` and `SerdeStateInner` structs with one
+/// field per listed type, plus -- for entries marked
+/// `#[impl intern_target]` -- an `InternTarget` impl wiring the type to
+/// its field. The marker's `impl` token is captured as
+/// `$impl_intern_target:ident` and substituted back to form the `impl`
+/// item, so unmarked entries (here `str`, which keeps its hand-written
+/// impl) get no generated one.
+macro_rules! make_interners {
+ {
+ $(
+ $(#[$impl_intern_target:ident intern_target])?
+ $name:ident: $ty:ty,
+ )*
+ } => {
+ #[derive(Default)]
+ struct Interners {
+ $($name: Interner<$ty>,)*
+ }
+
+ #[derive(Default)]
+ struct SerdeStateInner {
+ $($name: SerdeStateFor<$ty>,)*
+ }
+
+ $($(
+ $impl_intern_target InternTarget for $ty {
+ fn get_interner(global_state: &GlobalState) -> &Interner<Self> {
+ &global_state.interners.$name
+ }
+
+ fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor<Self> {
+ &serde_state.inner.$name
+ }
+ }
+ )?)*
+ };
+}
+
+make_interners! {
+ str: str,
+ #[impl intern_target]
+ loc_set: LocSet,
+ #[impl intern_target]
+ loc_set_max_conflicts_with_loc_set: LocSetMaxConflictsWith<Interned<LocSet>>,
+ #[impl intern_target]
+ loc_set_max_conflicts_with_loc: LocSetMaxConflictsWith<Loc>,
}
pub struct GlobalState {
interners: Interners,
}
-scoped_tls::scoped_thread_local!(static GLOBAL_STATE: GlobalState);
+// Now stored behind `Rc` so `SerdeState` can keep a clone of the global
+// state alive across a (de)serialization scope.
+scoped_tls::scoped_thread_local!(static GLOBAL_STATE: Rc<GlobalState>);
impl GlobalState {
+ // cold: scope setup runs once, keep it out of hot code paths
+ #[cold]
pub fn scope<R>(f: impl FnOnce() -> R) -> R {
GLOBAL_STATE.set(
- &GlobalState {
+ &Rc::new(GlobalState {
interners: Interners::default(),
- },
+ }),
f,
)
}
- pub fn get<R>(f: impl for<'a> FnOnce(&'a GlobalState) -> R) -> R {
+ pub fn get<R>(f: impl for<'a> FnOnce(&'a Rc<GlobalState>) -> R) -> R {
GLOBAL_STATE.with(f)
}
}
pub trait InternTarget: Intern<Target = Self> + Hash + Eq {
fn get_interner(global_state: &GlobalState) -> &Interner<Self>;
+ fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor<Self>;
fn into_interned(input: Self, global_state: &GlobalState) -> Interned<Self>
where
Self: Sized,
fn get_interner(global_state: &GlobalState) -> &Interner<Self> {
&global_state.interners.str
}
+ fn get_serde_state_for(serde_state: &SerdeState) -> &SerdeStateFor<Self> {
+ &serde_state.inner.str
+ }
}
impl Intern for str {
})
}
}
-
-impl InternTarget for LocSet {
- fn get_interner(global_state: &GlobalState) -> &Interner<Self> {
- &global_state.interners.loc_set
- }
-}
-
-impl InternTarget for LocSetMaxConflictsWith<Interned<LocSet>> {
- fn get_interner(global_state: &GlobalState) -> &Interner<Self> {
- &global_state.interners.loc_set_max_conflicts_with_loc_set
- }
-}
-
-impl InternTarget for LocSetMaxConflictsWith<Loc> {
- fn get_interner(global_state: &GlobalState) -> &Interner<Self> {
- &global_state.interners.loc_set_max_conflicts_with_loc
- }
-}
#[macro_use]
mod macros;
pub mod error;
+pub mod function;
+pub mod index;
pub mod interned;
pub mod loc;
pub mod loc_set;
};
}
+/// `?`-for-`Option` usable in `const fn`s (the `?` operator isn't
+/// allowed in const contexts): yields the `Some` payload or
+/// early-returns `None` from the enclosing function.
+macro_rules! const_try_opt {
+ ($v:expr $(,)?) => {
+ match $v {
+ Some(v) => v,
+ None => return None,
+ }
+ };
+}
+
macro_rules! nzu32_lit {
($v:literal) => {{
const V: ::std::num::NonZeroU32 = match ::std::num::NonZeroU32::new($v) {