clean up warnings
[bigint-presentation-code.git] / register_allocator / src / function.rs
1 use crate::{
2 error::{Error, Result},
3 index::{BlockIdx, InstIdx, InstRange, SSAValIdx},
4 interned::{GlobalState, Intern, Interned},
5 loc::{BaseTy, Loc, Ty},
6 loc_set::LocSet,
7 };
8 use arbitrary::Arbitrary;
9 use core::fmt;
10 use enum_map::Enum;
11 use hashbrown::HashSet;
12 use petgraph::{
13 algo::dominators,
14 visit::{GraphBase, GraphProp, IntoNeighbors, VisitMap, Visitable},
15 Directed,
16 };
17 use serde::{Deserialize, Serialize};
18 use smallvec::SmallVec;
19 use std::{
20 collections::{btree_map, BTreeMap, BTreeSet},
21 mem,
22 ops::{Index, IndexMut},
23 };
24
25 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
26 pub enum SSAValDef {
27 BlockParam { block: BlockIdx, param_idx: usize },
28 Operand { inst: InstIdx, operand_idx: usize },
29 }
30
31 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
32 pub struct BranchSuccParamUse {
33 pub branch_inst: InstIdx,
34 pub succ: BlockIdx,
35 pub param_idx: usize,
36 }
37
38 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
39 pub struct OperandUse {
40 pub inst: InstIdx,
41 pub operand_idx: usize,
42 }
43
44 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
45 pub struct SSAVal {
46 pub ty: Ty,
47 pub def: SSAValDef,
48 pub operand_uses: BTreeSet<OperandUse>,
49 pub branch_succ_param_uses: BTreeSet<BranchSuccParamUse>,
50 }
51
52 impl SSAVal {
53 fn validate(&self, ssa_val_idx: SSAValIdx, func: &FnFields) -> Result<()> {
54 let Self {
55 ty: _,
56 def,
57 operand_uses,
58 branch_succ_param_uses,
59 } = self;
60 match *def {
61 SSAValDef::BlockParam { block, param_idx } => {
62 let block_param = func.try_get_block_param(block, param_idx)?;
63 if ssa_val_idx != block_param {
64 return Err(Error::MismatchedBlockParamDef {
65 ssa_val_idx,
66 block,
67 param_idx,
68 });
69 }
70 }
71 SSAValDef::Operand { inst, operand_idx } => {
72 let operand = func.try_get_operand(inst, operand_idx)?;
73 if ssa_val_idx != operand.ssa_val {
74 return Err(Error::SSAValDefIsNotOperandsSSAVal {
75 ssa_val_idx,
76 inst,
77 operand_idx,
78 });
79 }
80 }
81 }
82 for &OperandUse { inst, operand_idx } in operand_uses {
83 let operand = func.try_get_operand(inst, operand_idx)?;
84 if ssa_val_idx != operand.ssa_val {
85 return Err(Error::SSAValUseIsNotOperandsSSAVal {
86 ssa_val_idx,
87 inst,
88 operand_idx,
89 });
90 }
91 }
92 for &BranchSuccParamUse {
93 branch_inst,
94 succ,
95 param_idx,
96 } in branch_succ_param_uses
97 {
98 if ssa_val_idx != func.try_get_branch_target_param(branch_inst, succ, param_idx)? {
99 return Err(Error::MismatchedBranchTargetBlockParamUse {
100 ssa_val_idx,
101 branch_inst,
102 tgt_block: succ,
103 param_idx,
104 });
105 }
106 }
107 Ok(())
108 }
109 }
110
111 #[derive(
112 Copy,
113 Clone,
114 PartialEq,
115 Eq,
116 PartialOrd,
117 Ord,
118 Debug,
119 Hash,
120 Serialize,
121 Deserialize,
122 Arbitrary,
123 Enum,
124 )]
125 #[repr(u8)]
126 pub enum InstStage {
127 Early = 0,
128 Late = 1,
129 }
130
131 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
132 #[serde(try_from = "SerializedProgPoint", into = "SerializedProgPoint")]
133 pub struct ProgPoint(usize);
134
135 impl ProgPoint {
136 pub const fn new(inst: InstIdx, stage: InstStage) -> Self {
137 const_unwrap_res!(Self::try_new(inst, stage))
138 }
139 pub const fn try_new(inst: InstIdx, stage: InstStage) -> Result<Self> {
140 let Some(inst) = inst.get().checked_shl(1) else {
141 return Err(Error::InstIdxTooBig);
142 };
143 Ok(Self(inst | stage as usize))
144 }
145 pub const fn inst(self) -> InstIdx {
146 InstIdx::new(self.0 >> 1)
147 }
148 pub const fn stage(self) -> InstStage {
149 if self.0 & 1 != 0 {
150 InstStage::Late
151 } else {
152 InstStage::Early
153 }
154 }
155 pub const fn next(self) -> Self {
156 Self(self.0 + 1)
157 }
158 pub const fn prev(self) -> Self {
159 Self(self.0 - 1)
160 }
161 }
162
163 impl fmt::Debug for ProgPoint {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("ProgPoint")
166 .field("inst", &self.inst())
167 .field("stage", &self.stage())
168 .finish()
169 }
170 }
171
172 #[derive(Serialize, Deserialize)]
173 struct SerializedProgPoint {
174 inst: InstIdx,
175 stage: InstStage,
176 }
177
178 impl From<ProgPoint> for SerializedProgPoint {
179 fn from(value: ProgPoint) -> Self {
180 Self {
181 inst: value.inst(),
182 stage: value.stage(),
183 }
184 }
185 }
186
187 impl TryFrom<SerializedProgPoint> for ProgPoint {
188 type Error = Error;
189
190 fn try_from(value: SerializedProgPoint) -> Result<Self, Self::Error> {
191 ProgPoint::try_new(value.inst, value.stage)
192 }
193 }
194
195 #[derive(
196 Copy,
197 Clone,
198 PartialEq,
199 Eq,
200 PartialOrd,
201 Ord,
202 Debug,
203 Hash,
204 Serialize,
205 Deserialize,
206 Arbitrary,
207 Enum,
208 )]
209 #[repr(u8)]
210 pub enum OperandKind {
211 Use = 0,
212 Def = 1,
213 }
214
215 #[derive(
216 Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Arbitrary,
217 )]
218 pub enum Constraint {
219 /// any register or stack location
220 Any,
221 /// r1-r32
222 BaseGpr,
223 /// r2,r4,r6,r8,...r126
224 SVExtra2VGpr,
225 /// r1-63
226 SVExtra2SGpr,
227 /// r1-127
228 SVExtra3Gpr,
229 /// any stack location
230 Stack,
231 FixedLoc(Loc),
232 }
233
234 impl Constraint {
235 pub fn is_any(&self) -> bool {
236 matches!(self, Self::Any)
237 }
238 pub fn fixed_loc(&self) -> Option<Loc> {
239 match *self {
240 Constraint::Any
241 | Constraint::BaseGpr
242 | Constraint::SVExtra2VGpr
243 | Constraint::SVExtra2SGpr
244 | Constraint::SVExtra3Gpr
245 | Constraint::Stack => None,
246 Constraint::FixedLoc(v) => Some(v),
247 }
248 }
249 pub fn non_fixed_choices_for_ty(ty: Ty) -> &'static [Constraint] {
250 match (ty.base_ty, ty.reg_len.get()) {
251 (BaseTy::Bits64, 1) => &[
252 Constraint::Any,
253 Constraint::BaseGpr,
254 Constraint::SVExtra2SGpr,
255 Constraint::SVExtra2VGpr,
256 Constraint::SVExtra3Gpr,
257 Constraint::Stack,
258 ],
259 (BaseTy::Bits64, _) => &[
260 Constraint::Any,
261 Constraint::SVExtra2VGpr,
262 Constraint::SVExtra3Gpr,
263 Constraint::Stack,
264 ],
265 (BaseTy::Ca, _) | (BaseTy::VlMaxvl, _) => &[Constraint::Any, Constraint::Stack],
266 }
267 }
268 pub fn arbitrary_with_ty(
269 ty: Ty,
270 u: &mut arbitrary::Unstructured<'_>,
271 ) -> arbitrary::Result<Self> {
272 let non_fixed_choices = Self::non_fixed_choices_for_ty(ty);
273 if let Some(&retval) = non_fixed_choices.get(u.choose_index(non_fixed_choices.len() + 1)?) {
274 Ok(retval)
275 } else {
276 Ok(Constraint::FixedLoc(Loc::arbitrary_with_ty(ty, u)?))
277 }
278 }
279 pub fn check_for_ty_mismatch(&self, ty: Ty) -> Result<(), ()> {
280 match self {
281 Constraint::Any | Constraint::Stack => {}
282 Constraint::BaseGpr | Constraint::SVExtra2SGpr => {
283 if ty != Ty::scalar(BaseTy::Bits64) {
284 return Err(());
285 }
286 }
287 Constraint::SVExtra2VGpr | Constraint::SVExtra3Gpr => {
288 if ty.base_ty != BaseTy::Bits64 {
289 return Err(());
290 }
291 }
292 Constraint::FixedLoc(loc) => {
293 if ty != loc.ty() {
294 return Err(());
295 }
296 }
297 }
298 Ok(())
299 }
300 }
301
302 impl Default for Constraint {
303 fn default() -> Self {
304 Self::Any
305 }
306 }
307
308 #[derive(
309 Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize, Default,
310 )]
311 #[serde(try_from = "OperandKind", into = "OperandKind")]
312 pub struct OperandKindDefOnly;
313
314 impl TryFrom<OperandKind> for OperandKindDefOnly {
315 type Error = Error;
316
317 fn try_from(value: OperandKind) -> Result<Self, Self::Error> {
318 match value {
319 OperandKind::Use => Err(Error::OperandKindMustBeDef),
320 OperandKind::Def => Ok(Self),
321 }
322 }
323 }
324
325 impl From<OperandKindDefOnly> for OperandKind {
326 fn from(_value: OperandKindDefOnly) -> Self {
327 Self::Def
328 }
329 }
330
331 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Serialize, Deserialize)]
332 #[serde(untagged)]
333 pub enum KindAndConstraint {
334 Reuse {
335 kind: OperandKindDefOnly,
336 reuse_operand_idx: usize,
337 },
338 Constraint {
339 kind: OperandKind,
340 #[serde(default, skip_serializing_if = "Constraint::is_any")]
341 constraint: Constraint,
342 },
343 }
344
345 impl KindAndConstraint {
346 pub fn kind(self) -> OperandKind {
347 match self {
348 Self::Reuse { .. } => OperandKind::Def,
349 Self::Constraint { kind, .. } => kind,
350 }
351 }
352 pub fn is_reuse(self) -> bool {
353 matches!(self, Self::Reuse { .. })
354 }
355 }
356
357 #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
358 pub struct Operand {
359 pub ssa_val: SSAValIdx,
360 #[serde(flatten)]
361 pub kind_and_constraint: KindAndConstraint,
362 pub stage: InstStage,
363 }
364
365 impl Operand {
366 pub fn try_get_reuse_src<'f>(
367 &self,
368 inst: InstIdx,
369 func: &'f FnFields,
370 ) -> Result<Option<&'f Operand>> {
371 if let KindAndConstraint::Reuse {
372 reuse_operand_idx, ..
373 } = self.kind_and_constraint
374 {
375 Ok(Some(func.try_get_operand(inst, reuse_operand_idx)?))
376 } else {
377 Ok(None)
378 }
379 }
380 pub fn try_constraint(&self, inst: InstIdx, func: &FnFields) -> Result<Constraint> {
381 Ok(match self.kind_and_constraint {
382 KindAndConstraint::Reuse {
383 kind: _,
384 reuse_operand_idx,
385 } => {
386 let operand = func.try_get_operand(inst, reuse_operand_idx)?;
387 match operand.kind_and_constraint {
388 KindAndConstraint::Reuse { .. }
389 | KindAndConstraint::Constraint {
390 kind: OperandKind::Def,
391 ..
392 } => {
393 return Err(Error::ReuseTargetOperandMustBeUse {
394 inst,
395 reuse_target_operand_idx: reuse_operand_idx,
396 })
397 }
398 KindAndConstraint::Constraint {
399 kind: OperandKind::Use,
400 constraint,
401 } => constraint,
402 }
403 }
404 KindAndConstraint::Constraint { constraint, .. } => constraint,
405 })
406 }
407 pub fn constraint(&self, inst: InstIdx, func: &Function) -> Constraint {
408 self.try_constraint(inst, func).unwrap()
409 }
410 fn validate(
411 self,
412 _block: BlockIdx,
413 inst: InstIdx,
414 operand_idx: usize,
415 func: &FnFields,
416 global_state: &GlobalState,
417 ) -> Result<()> {
418 let Self {
419 ssa_val: ssa_val_idx,
420 kind_and_constraint,
421 stage: _,
422 } = self;
423 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
424 match kind_and_constraint.kind() {
425 OperandKind::Use => {
426 if !ssa_val
427 .operand_uses
428 .contains(&OperandUse { inst, operand_idx })
429 {
430 return Err(Error::MissingOperandUse {
431 ssa_val_idx,
432 inst,
433 operand_idx,
434 });
435 }
436 }
437 OperandKind::Def => {
438 let def = SSAValDef::Operand { inst, operand_idx };
439 if ssa_val.def != def {
440 return Err(Error::OperandDefIsNotSSAValDef {
441 ssa_val_idx,
442 inst,
443 operand_idx,
444 });
445 }
446 }
447 }
448 if let KindAndConstraint::Reuse {
449 kind: _,
450 reuse_operand_idx,
451 } = self.kind_and_constraint
452 {
453 let reuse_src = func.try_get_operand(inst, reuse_operand_idx)?;
454 let reuse_src_ssa_val = func.try_get_ssa_val(reuse_src.ssa_val)?;
455 if ssa_val.ty != reuse_src_ssa_val.ty {
456 return Err(Error::ReuseOperandTyMismatch {
457 inst,
458 tgt_operand_idx: operand_idx,
459 src_operand_idx: reuse_operand_idx,
460 src_ty: reuse_src_ssa_val.ty,
461 tgt_ty: ssa_val.ty,
462 });
463 }
464 }
465 let constraint = self.try_constraint(inst, func)?;
466 constraint
467 .check_for_ty_mismatch(ssa_val.ty)
468 .map_err(|()| Error::ConstraintTyMismatch {
469 ssa_val_idx,
470 inst,
471 operand_idx,
472 })?;
473 if let Some(fixed_loc) = constraint.fixed_loc() {
474 if func
475 .try_get_inst(inst)?
476 .clobbers
477 .clone()
478 .conflicts_with(fixed_loc, global_state)
479 {
480 return Err(Error::FixedLocConflictsWithClobbers { inst, operand_idx });
481 }
482 }
483 Ok(())
484 }
485 }
486
487 /// copy concatenates all `srcs` together and de-concatenates the result into all `dests`.
488 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
489 pub struct CopyInstKind {
490 pub src_operand_idxs: Vec<usize>,
491 pub dest_operand_idxs: Vec<usize>,
492 pub copy_ty: Ty,
493 }
494
495 impl CopyInstKind {
496 fn calc_copy_ty(operand_idxs: &[usize], inst: InstIdx, func: &FnFields) -> Result<Option<Ty>> {
497 let mut retval: Option<Ty> = None;
498 for &operand_idx in operand_idxs {
499 let operand = func.try_get_operand(inst, operand_idx)?;
500 let ssa_val = func.try_get_ssa_val(operand.ssa_val)?;
501 retval = Some(match retval {
502 Some(retval) => retval.try_concat(ssa_val.ty)?,
503 None => ssa_val.ty,
504 });
505 }
506 Ok(retval)
507 }
508 }
509
510 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
511 pub struct BlockTermInstKind {
512 pub succs_and_params: BTreeMap<BlockIdx, Vec<SSAValIdx>>,
513 }
514
515 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
516 pub enum InstKind {
517 Normal,
518 Copy(CopyInstKind),
519 BlockTerm(BlockTermInstKind),
520 }
521
522 impl InstKind {
523 pub fn is_normal(&self) -> bool {
524 matches!(self, Self::Normal)
525 }
526 pub fn is_block_term(&self) -> bool {
527 matches!(self, Self::BlockTerm { .. })
528 }
529 pub fn is_copy(&self) -> bool {
530 matches!(self, Self::Copy { .. })
531 }
532 pub fn block_term(&self) -> Option<&BlockTermInstKind> {
533 match self {
534 InstKind::BlockTerm(v) => Some(v),
535 _ => None,
536 }
537 }
538 pub fn block_term_mut(&mut self) -> Option<&mut BlockTermInstKind> {
539 match self {
540 InstKind::BlockTerm(v) => Some(v),
541 _ => None,
542 }
543 }
544 pub fn copy(&self) -> Option<&CopyInstKind> {
545 match self {
546 InstKind::Copy(v) => Some(v),
547 _ => None,
548 }
549 }
550 }
551
552 impl Default for InstKind {
553 fn default() -> Self {
554 InstKind::Normal
555 }
556 }
557
558 fn loc_set_is_empty(clobbers: &Interned<LocSet>) -> bool {
559 clobbers.is_empty()
560 }
561
562 fn empty_loc_set() -> Interned<LocSet> {
563 GlobalState::get(|global_state| LocSet::default().into_interned(global_state))
564 }
565
566 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
567 pub struct Inst {
568 #[serde(default, skip_serializing_if = "InstKind::is_normal")]
569 pub kind: InstKind,
570 pub operands: Vec<Operand>,
571 #[serde(default = "empty_loc_set", skip_serializing_if = "loc_set_is_empty")]
572 pub clobbers: Interned<LocSet>,
573 }
574
575 impl Inst {
576 fn validate(
577 &self,
578 block: BlockIdx,
579 inst: InstIdx,
580 func: &FnFields,
581 global_state: &GlobalState,
582 ) -> Result<()> {
583 let Self {
584 kind,
585 operands,
586 clobbers: _,
587 } = self;
588 let is_at_end_of_block = func.blocks[block].insts.last() == Some(inst);
589 if kind.is_block_term() != is_at_end_of_block {
590 return Err(if is_at_end_of_block {
591 Error::BlocksLastInstMustBeTerm { term_idx: inst }
592 } else {
593 Error::TermInstOnlyAllowedAtBlockEnd { inst_idx: inst }
594 });
595 }
596 for (idx, operand) in operands.iter().enumerate() {
597 operand.validate(block, inst, idx, func, global_state)?;
598 }
599 match kind {
600 InstKind::Normal => {}
601 InstKind::Copy(CopyInstKind {
602 src_operand_idxs,
603 dest_operand_idxs,
604 copy_ty,
605 }) => {
606 let mut seen_dest_operands = SmallVec::<[bool; 16]>::new();
607 seen_dest_operands.resize(operands.len(), false);
608 for &dest_operand_idx in dest_operand_idxs {
609 let seen_dest_operand = seen_dest_operands.get_mut(dest_operand_idx).ok_or(
610 Error::OperandIndexOutOfRange {
611 inst,
612 operand_idx: dest_operand_idx,
613 },
614 )?;
615 if mem::replace(seen_dest_operand, true) {
616 return Err(Error::DupCopyDestOperand {
617 inst,
618 operand_idx: dest_operand_idx,
619 });
620 }
621 }
622 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&src_operand_idxs, inst, func)? {
623 return Err(Error::CopySrcTyMismatch { inst });
624 }
625 if Some(*copy_ty) != CopyInstKind::calc_copy_ty(&dest_operand_idxs, inst, func)? {
626 return Err(Error::CopyDestTyMismatch { inst });
627 }
628 }
629 InstKind::BlockTerm(BlockTermInstKind { succs_and_params }) => {
630 for (&succ_idx, params) in succs_and_params {
631 let succ = func.try_get_block(succ_idx)?;
632 if !succ.preds.contains(&block) {
633 return Err(Error::SrcBlockMissingFromBranchTgtBlocksPreds {
634 src_block: block,
635 branch_inst: inst,
636 tgt_block: succ_idx,
637 });
638 }
639 if succ.params.len() != params.len() {
640 return Err(Error::BranchSuccParamCountMismatch {
641 inst,
642 succ: succ_idx,
643 block_param_count: succ.params.len(),
644 branch_param_count: params.len(),
645 });
646 }
647 for (param_idx, (&branch_ssa_val_idx, &block_ssa_val_idx)) in
648 params.iter().zip(&succ.params).enumerate()
649 {
650 let branch_ssa_val = func.try_get_ssa_val(branch_ssa_val_idx)?;
651 let block_ssa_val = func.try_get_ssa_val(block_ssa_val_idx)?;
652 if !branch_ssa_val
653 .branch_succ_param_uses
654 .contains(&BranchSuccParamUse {
655 branch_inst: inst,
656 succ: succ_idx,
657 param_idx,
658 })
659 {
660 return Err(Error::MissingBranchSuccParamUse {
661 ssa_val_idx: branch_ssa_val_idx,
662 inst,
663 succ: succ_idx,
664 param_idx,
665 });
666 }
667 if block_ssa_val.ty != branch_ssa_val.ty {
668 return Err(Error::BranchSuccParamTyMismatch {
669 inst,
670 succ: succ_idx,
671 param_idx,
672 block_param_ty: block_ssa_val.ty,
673 branch_param_ty: branch_ssa_val.ty,
674 });
675 }
676 }
677 }
678 }
679 }
680 Ok(())
681 }
682 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
683 self.operands
684 .get(operand_idx)
685 .ok_or(Error::OperandIndexOutOfRange { inst, operand_idx })
686 }
687 }
688
689 #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
690 pub struct Block {
691 pub params: Vec<SSAValIdx>,
692 pub insts: InstRange,
693 pub preds: BTreeSet<BlockIdx>,
694 pub immediate_dominator: Option<BlockIdx>,
695 }
696
697 impl Block {
698 fn validate(&self, block: BlockIdx, func: &FnFields, global_state: &GlobalState) -> Result<()> {
699 let Self {
700 params,
701 insts,
702 preds,
703 immediate_dominator: _, // validated by Function::new_with_global_state
704 } = self;
705 const _: () = assert!(BlockIdx::ENTRY_BLOCK.get() == 0);
706 let expected_start = if block == BlockIdx::ENTRY_BLOCK {
707 InstIdx::new(0)
708 } else {
709 func.blocks[block.prev()].insts.end
710 };
711 if insts.start != expected_start {
712 return Err(Error::BlockHasInvalidStart {
713 start: insts.start,
714 expected_start,
715 });
716 }
717 let term_inst_idx = insts.last().ok_or(Error::BlockIsEmpty { block })?;
718 func.insts
719 .get(term_inst_idx.get())
720 .ok_or(Error::BlockEndOutOfRange { end: insts.end })?;
721 if block.get() == func.blocks.len() - 1 && insts.end.get() != func.insts.len() {
722 return Err(Error::InstHasNoBlock { inst: insts.end });
723 }
724 if block == BlockIdx::ENTRY_BLOCK {
725 if !params.is_empty() {
726 return Err(Error::EntryBlockCantHaveParams);
727 }
728 if !preds.is_empty() {
729 return Err(Error::EntryBlockCantHavePreds);
730 }
731 }
732 for inst in *insts {
733 func.insts[inst].validate(block, inst, func, global_state)?;
734 }
735 for (param_idx, &ssa_val_idx) in params.iter().enumerate() {
736 let ssa_val = func.try_get_ssa_val(ssa_val_idx)?;
737 let def = SSAValDef::BlockParam { block, param_idx };
738 if ssa_val.def != def {
739 return Err(Error::MismatchedBlockParamDef {
740 ssa_val_idx,
741 block,
742 param_idx,
743 });
744 }
745 }
746 for &pred in preds {
747 let (term_inst, BlockTermInstKind { succs_and_params }) =
748 func.try_get_block_term_inst_and_kind(pred)?;
749 if !succs_and_params.contains_key(&block) {
750 return Err(Error::PredMissingFromPredsTermBranchsTargets {
751 src_block: pred,
752 branch_inst: term_inst,
753 tgt_block: block,
754 });
755 }
756 if preds.len() > 1 && succs_and_params.len() > 1 {
757 return Err(Error::CriticalEdgeNotAllowed {
758 src_block: pred,
759 branch_inst: term_inst,
760 tgt_block: block,
761 });
762 }
763 }
764 Ok(())
765 }
766 }
767
// Declares `Function` (the validated wrapper) and `FnFields` (its raw, unvalidated fields) using
// the project's `validated_fields!` macro.
// NOTE(review): the macro definition isn't visible here. Judging by the uses in this file,
// `Function` derefs to `FnFields`, `Function::new_with_global_state` builds it as `Self(fields)`
// after validating, and `FnFields` is presumably serde-(de)serializable (see `#[serde(skip)]`).
validated_fields! {
    #[fields_ty = FnFields]
    #[derive(Clone, PartialEq, Eq, Debug, Hash)]
    pub struct Function {
        // all SSA values, indexed by `SSAValIdx`
        pub ssa_vals: Vec<SSAVal>,
        // all instructions, indexed by `InstIdx`; blocks own contiguous ranges of these
        pub insts: Vec<Inst>,
        // all blocks, indexed by `BlockIdx`; the entry block is index 0
        pub blocks: Vec<Block>,
        #[serde(skip)]
        /// map from each block's start instruction index to that block's index; doesn't contain the entry block
        // not serialized: rebuilt from `blocks` by `FnFields::fill_start_inst_to_block_map`
        pub start_inst_to_block_map: BTreeMap<InstIdx, BlockIdx>,
    }
}
780
781 impl Function {
782 pub fn new(fields: FnFields) -> Result<Self> {
783 GlobalState::get(|global_state| Self::new_with_global_state(fields, global_state))
784 }
785 pub fn new_with_global_state(mut fields: FnFields, global_state: &GlobalState) -> Result<Self> {
786 fields.fill_start_inst_to_block_map();
787 let FnFields {
788 ssa_vals,
789 insts: _,
790 blocks,
791 start_inst_to_block_map: _,
792 } = &fields;
793 blocks
794 .get(BlockIdx::ENTRY_BLOCK.get())
795 .ok_or(Error::MissingEntryBlock)?;
796 for (idx, block) in blocks.iter().enumerate() {
797 block.validate(BlockIdx::new(idx), &fields, global_state)?;
798 }
799 let dominators = dominators::simple_fast(&fields, BlockIdx::ENTRY_BLOCK);
800 for (idx, block) in blocks.iter().enumerate() {
801 let block_idx = BlockIdx::new(idx);
802 let expected = dominators.immediate_dominator(block_idx);
803 if block.immediate_dominator != expected {
804 return Err(Error::IncorrectImmediateDominator {
805 block_idx,
806 found: block.immediate_dominator,
807 expected,
808 });
809 }
810 }
811 for (idx, ssa_val) in ssa_vals.iter().enumerate() {
812 ssa_val.validate(SSAValIdx::new(idx), &fields)?;
813 }
814 Ok(Self(fields))
815 }
816 pub fn entry_block(&self) -> &Block {
817 &self.blocks[0]
818 }
819 pub fn block_term_kind(&self, block: BlockIdx) -> &BlockTermInstKind {
820 self.insts[self.blocks[block].insts.last().unwrap()]
821 .kind
822 .block_term()
823 .unwrap()
824 }
825 }
826
827 impl FnFields {
828 pub fn fill_start_inst_to_block_map(&mut self) {
829 self.start_inst_to_block_map.clear();
830 for (idx, block) in self.blocks.iter().enumerate() {
831 let block_idx = BlockIdx::new(idx);
832 if block_idx != BlockIdx::ENTRY_BLOCK {
833 self.start_inst_to_block_map
834 .insert(block.insts.start, block_idx);
835 }
836 }
837 }
838 pub fn try_get_ssa_val(&self, idx: SSAValIdx) -> Result<&SSAVal> {
839 self.ssa_vals
840 .get(idx.get())
841 .ok_or(Error::SSAValIdxOutOfRange { idx })
842 }
843 pub fn try_get_inst(&self, idx: InstIdx) -> Result<&Inst> {
844 self.insts
845 .get(idx.get())
846 .ok_or(Error::InstIdxOutOfRange { idx })
847 }
848 pub fn try_get_inst_mut(&mut self, idx: InstIdx) -> Result<&mut Inst> {
849 self.insts
850 .get_mut(idx.get())
851 .ok_or(Error::InstIdxOutOfRange { idx })
852 }
853 pub fn try_get_operand(&self, inst: InstIdx, operand_idx: usize) -> Result<&Operand> {
854 self.try_get_inst(inst)?.try_get_operand(inst, operand_idx)
855 }
856 pub fn try_get_block(&self, idx: BlockIdx) -> Result<&Block> {
857 self.blocks
858 .get(idx.get())
859 .ok_or(Error::BlockIdxOutOfRange { idx })
860 }
861 pub fn try_get_block_param(&self, block: BlockIdx, param_idx: usize) -> Result<SSAValIdx> {
862 self.try_get_block(block)?
863 .params
864 .get(param_idx)
865 .copied()
866 .ok_or(Error::BlockParamIdxOutOfRange { block, param_idx })
867 }
868 pub fn try_get_block_term_inst_idx(&self, block: BlockIdx) -> Result<InstIdx> {
869 self.try_get_block(block)?
870 .insts
871 .last()
872 .ok_or(Error::BlockIsEmpty { block })
873 }
874 pub fn try_get_block_term_inst_and_kind(
875 &self,
876 block: BlockIdx,
877 ) -> Result<(InstIdx, &BlockTermInstKind)> {
878 let term_idx = self.try_get_block_term_inst_idx(block)?;
879 let term_kind = self
880 .try_get_inst(term_idx)?
881 .kind
882 .block_term()
883 .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
884 Ok((term_idx, term_kind))
885 }
886 pub fn try_get_block_term_inst_and_kind_mut(
887 &mut self,
888 block: BlockIdx,
889 ) -> Result<(InstIdx, &mut BlockTermInstKind)> {
890 let term_idx = self.try_get_block_term_inst_idx(block)?;
891 let term_kind = self
892 .try_get_inst_mut(term_idx)?
893 .kind
894 .block_term_mut()
895 .ok_or(Error::BlocksLastInstMustBeTerm { term_idx })?;
896 Ok((term_idx, term_kind))
897 }
898 pub fn try_get_branch_target_params(
899 &self,
900 branch_inst: InstIdx,
901 succ: BlockIdx,
902 ) -> Result<&[SSAValIdx]> {
903 let inst = self.try_get_inst(branch_inst)?;
904 let BlockTermInstKind { succs_and_params } = inst
905 .kind
906 .block_term()
907 .ok_or(Error::InstIsNotBlockTerm { inst: branch_inst })?;
908 Ok(succs_and_params
909 .get(&succ)
910 .ok_or(Error::BranchTargetNotFound {
911 branch_inst,
912 tgt_block: succ,
913 })?)
914 }
915 pub fn try_get_branch_target_param(
916 &self,
917 branch_inst: InstIdx,
918 succ: BlockIdx,
919 param_idx: usize,
920 ) -> Result<SSAValIdx> {
921 Ok(*self
922 .try_get_branch_target_params(branch_inst, succ)?
923 .get(param_idx)
924 .ok_or(Error::BranchTargetParamIdxOutOfRange {
925 branch_inst,
926 tgt_block: succ,
927 param_idx,
928 })?)
929 }
930 pub fn inst_to_block(&self, inst: InstIdx) -> BlockIdx {
931 self.start_inst_to_block_map
932 .range(..=inst)
933 .next_back()
934 .map(|v| *v.1)
935 .unwrap_or(BlockIdx::ENTRY_BLOCK)
936 }
937 }
938
939 impl Index<SSAValIdx> for Vec<SSAVal> {
940 type Output = SSAVal;
941
942 fn index(&self, index: SSAValIdx) -> &Self::Output {
943 &self[index.get()]
944 }
945 }
946
947 impl IndexMut<SSAValIdx> for Vec<SSAVal> {
948 fn index_mut(&mut self, index: SSAValIdx) -> &mut Self::Output {
949 &mut self[index.get()]
950 }
951 }
952
953 impl Index<InstIdx> for Vec<Inst> {
954 type Output = Inst;
955
956 fn index(&self, index: InstIdx) -> &Self::Output {
957 &self[index.get()]
958 }
959 }
960
961 impl IndexMut<InstIdx> for Vec<Inst> {
962 fn index_mut(&mut self, index: InstIdx) -> &mut Self::Output {
963 &mut self[index.get()]
964 }
965 }
966
967 impl Index<BlockIdx> for Vec<Block> {
968 type Output = Block;
969
970 fn index(&self, index: BlockIdx) -> &Self::Output {
971 &self[index.get()]
972 }
973 }
974
975 impl IndexMut<BlockIdx> for Vec<Block> {
976 fn index_mut(&mut self, index: BlockIdx) -> &mut Self::Output {
977 &mut self[index.get()]
978 }
979 }
980
981 impl GraphBase for FnFields {
982 type EdgeId = (BlockIdx, BlockIdx);
983 type NodeId = BlockIdx;
984 }
985
986 pub struct Neighbors<'a> {
987 iter: Option<btree_map::Keys<'a, BlockIdx, Vec<SSAValIdx>>>,
988 }
989
990 impl Iterator for Neighbors<'_> {
991 type Item = BlockIdx;
992
993 fn next(&mut self) -> Option<Self::Item> {
994 Some(*self.iter.as_mut()?.next()?)
995 }
996 }
997
998 impl<'a> IntoNeighbors for &'a FnFields {
999 type Neighbors = Neighbors<'a>;
1000
1001 fn neighbors(self, block_idx: Self::NodeId) -> Self::Neighbors {
1002 Neighbors {
1003 iter: self
1004 .try_get_block_term_inst_and_kind(block_idx)
1005 .ok()
1006 .map(|(_, BlockTermInstKind { succs_and_params })| succs_and_params.keys()),
1007 }
1008 }
1009 }
1010
1011 pub struct VisitedMap(HashSet<BlockIdx>);
1012
1013 impl VisitMap<BlockIdx> for VisitedMap {
1014 fn visit(&mut self, block: BlockIdx) -> bool {
1015 self.0.insert(block)
1016 }
1017
1018 fn is_visited(&self, block: &BlockIdx) -> bool {
1019 self.0.contains(block)
1020 }
1021 }
1022
1023 impl Visitable for FnFields {
1024 type Map = VisitedMap;
1025
1026 fn visit_map(&self) -> Self::Map {
1027 VisitedMap(HashSet::new())
1028 }
1029
1030 fn reset_map(&self, map: &mut Self::Map) {
1031 map.0.clear();
1032 }
1033 }
1034
1035 impl GraphProp for FnFields {
1036 type EdgeType = Directed;
1037 }
1038
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loc::TyFields;
    use std::num::NonZeroU32;

    /// `non_fixed_choices_for_ty` must be consistent. Every list starts with `Any` and ends with
    /// `Stack`. Every offered constraint is non-fixed and compatible with the type. Across all
    /// types tried, every non-`FixedLoc` variant is offered at least once.
    #[test]
    fn test_constraint_non_fixed_choices_for_ty() {
        // Generates `Seen`, with one flag per listed variant. The `match` in `Seen::add` is
        // exhaustive, so adding a `Constraint` variant without listing it here fails to compile.
        macro_rules! seen {
            (
                enum ConstraintWithoutFixedLoc {
                    $($field:ident,)*
                }
            ) => {
                #[derive(Default)]
                #[allow(non_snake_case)]
                struct Seen {
                    $($field: bool,)*
                }

                impl Seen {
                    fn add(&mut self, constraint: &Constraint) {
                        match constraint {
                            Constraint::FixedLoc(_) => {}
                            $(Constraint::$field => self.$field = true,)*
                        }
                    }
                    fn check(self) {
                        $(assert!(self.$field, "never seen field: {}", stringify!($field));)*
                    }
                }
            };
        }
        seen! {
            enum ConstraintWithoutFixedLoc {
                Any,
                BaseGpr,
                SVExtra2VGpr,
                SVExtra2SGpr,
                SVExtra3Gpr,
                Stack,
            }
        }

        let reg_lens = [1, 2, 100].map(|reg_len| NonZeroU32::new(reg_len).unwrap());
        let mut seen = Seen::default();
        for base_ty in (0..BaseTy::LENGTH).map(BaseTy::from_usize) {
            for reg_len in reg_lens {
                let ty = Ty::new_or_scalar(TyFields { base_ty, reg_len });
                let choices = Constraint::non_fixed_choices_for_ty(ty);
                assert_eq!(choices.first(), Some(&Constraint::Any));
                assert_eq!(choices.last(), Some(&Constraint::Stack));
                for constraint in choices {
                    assert_eq!(constraint.fixed_loc(), None);
                    seen.add(constraint);
                    if constraint.check_for_ty_mismatch(ty).is_err() {
                        panic!("constraint ty mismatch: constraint={constraint:?} ty={ty:?}");
                    }
                }
            }
        }
        seen.check();
    }
}