fe_mir/analysis/
loop_tree.rs

1use id_arena::{Arena, Id};
2
3use fxhash::FxHashMap;
4
5use super::{cfg::ControlFlowGraph, domtree::DomTree};
6
7use crate::ir::BasicBlockId;
8
9#[derive(Debug, Default, Clone)]
10pub struct LoopTree {
11    /// Stores loops.
12    /// The index of an outer loops is guaranteed to be lower than its inner
13    /// loops because loops are found in RPO.
14    loops: Arena<Loop>,
15
16    /// Maps blocks to its contained loop.
17    /// If the block is contained by multiple nested loops, then the block is
18    /// mapped to the innermost loop.
19    block_to_loop: FxHashMap<BasicBlockId, LoopId>,
20}
21
22pub type LoopId = Id<Loop>;
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct Loop {
26    /// A header of the loop.
27    pub header: BasicBlockId,
28
29    /// A parent loop that includes the loop.
30    pub parent: Option<LoopId>,
31
32    /// Child loops that the loop includes.
33    pub children: Vec<LoopId>,
34}
35
36impl LoopTree {
37    pub fn compute(cfg: &ControlFlowGraph, domtree: &DomTree) -> Self {
38        let mut tree = LoopTree::default();
39
40        // Find loop headers in RPO, this means outer loops are guaranteed to be
41        // inserted first, then its inner loops are inserted.
42        for &block in domtree.rpo() {
43            for &pred in cfg.preds(block) {
44                if domtree.dominates(block, pred) {
45                    let loop_data = Loop {
46                        header: block,
47                        parent: None,
48                        children: Vec::new(),
49                    };
50
51                    tree.loops.alloc(loop_data);
52                    break;
53                }
54            }
55        }
56
57        tree.analyze_loops(cfg, domtree);
58
59        tree
60    }
61
62    /// Returns all blocks in the loop.
63    pub fn iter_blocks_post_order<'a, 'b>(
64        &'a self,
65        cfg: &'b ControlFlowGraph,
66        lp: LoopId,
67    ) -> BlocksInLoopPostOrder<'a, 'b> {
68        BlocksInLoopPostOrder::new(self, cfg, lp)
69    }
70
71    /// Returns all loops in a function body.
72    /// An outer loop is guaranteed to be iterated before its inner loops.
73    pub fn loops(&self) -> impl Iterator<Item = LoopId> + '_ {
74        self.loops.iter().map(|(id, _)| id)
75    }
76
77    /// Returns number of loops found.
78    pub fn loop_num(&self) -> usize {
79        self.loops.len()
80    }
81
82    /// Returns `true` if the `block` is in the `lp`.
83    pub fn is_block_in_loop(&self, block: BasicBlockId, lp: LoopId) -> bool {
84        let mut loop_of_block = self.loop_of_block(block);
85        while let Some(cur_lp) = loop_of_block {
86            if lp == cur_lp {
87                return true;
88            }
89            loop_of_block = self.parent_loop(cur_lp);
90        }
91        false
92    }
93
94    /// Returns header block of the `lp`.
95    pub fn loop_header(&self, lp: LoopId) -> BasicBlockId {
96        self.loops[lp].header
97    }
98
99    /// Get parent loop of the `lp` if exists.
100    pub fn parent_loop(&self, lp: LoopId) -> Option<LoopId> {
101        self.loops[lp].parent
102    }
103
104    /// Returns the loop that the `block` belongs to.
105    /// If the `block` belongs to multiple loops, then returns the innermost
106    /// loop.
107    pub fn loop_of_block(&self, block: BasicBlockId) -> Option<LoopId> {
108        self.block_to_loop.get(&block).copied()
109    }
110
111    /// Analyze loops. This method does
112    /// 1. Mapping each blocks to its contained loop.
113    /// 2. Setting parent and child of the loops.
114    fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
115        let mut worklist = vec![];
116
117        // Iterate loops reversely to ensure analyze inner loops first.
118        let loops_rev: Vec<_> = self.loops.iter().rev().map(|(id, _)| id).collect();
119        for cur_lp in loops_rev {
120            let cur_lp_header = self.loop_header(cur_lp);
121
122            // Add predecessors of the loop header to worklist.
123            for &block in cfg.preds(cur_lp_header) {
124                if domtree.dominates(cur_lp_header, block) {
125                    worklist.push(block);
126                }
127            }
128
129            while let Some(block) = worklist.pop() {
130                match self.block_to_loop.get(&block).copied() {
131                    Some(lp_of_block) => {
132                        let outermost_parent = self.outermost_parent(lp_of_block);
133
134                        // If outermost parent is current loop, then the block is already visited.
135                        if outermost_parent == cur_lp {
136                            continue;
137                        } else {
138                            self.loops[cur_lp].children.push(outermost_parent);
139                            self.loops[outermost_parent].parent = cur_lp.into();
140
141                            let lp_header_of_block = self.loop_header(lp_of_block);
142                            worklist.extend(cfg.preds(lp_header_of_block));
143                        }
144                    }
145
146                    // If the block is not mapped to any loops, then map it to the loop.
147                    None => {
148                        self.map_block(block, cur_lp);
149                        // If block is not loop header, then add its predecessors to the worklist.
150                        if block != cur_lp_header {
151                            worklist.extend(cfg.preds(block));
152                        }
153                    }
154                }
155            }
156        }
157    }
158
159    /// Returns the outermost parent loop of `lp`. If `lp` doesn't have any
160    /// parent, then returns `lp` itself.
161    fn outermost_parent(&self, mut lp: LoopId) -> LoopId {
162        while let Some(parent) = self.parent_loop(lp) {
163            lp = parent;
164        }
165        lp
166    }
167
168    /// Map `block` to `lp`.
169    fn map_block(&mut self, block: BasicBlockId, lp: LoopId) {
170        self.block_to_loop.insert(block, lp);
171    }
172}
173
174pub struct BlocksInLoopPostOrder<'a, 'b> {
175    lpt: &'a LoopTree,
176    cfg: &'b ControlFlowGraph,
177    lp: LoopId,
178    stack: Vec<BasicBlockId>,
179    block_state: FxHashMap<BasicBlockId, BlockState>,
180}
181
182impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> {
183    fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: LoopId) -> Self {
184        let loop_header = lpt.loop_header(lp);
185
186        Self {
187            lpt,
188            cfg,
189            lp,
190            stack: vec![loop_header],
191            block_state: FxHashMap::default(),
192        }
193    }
194}
195
196impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> {
197    type Item = BasicBlockId;
198
199    fn next(&mut self) -> Option<Self::Item> {
200        while let Some(&block) = self.stack.last() {
201            match self.block_state.get(&block) {
202                // The block is already visited, but not returned from the iterator,
203                // so mark the block as `Finished` and return the block.
204                Some(BlockState::Visited) => {
205                    let block = self.stack.pop().unwrap();
206                    self.block_state.insert(block, BlockState::Finished);
207                    return Some(block);
208                }
209
210                // The block is already returned, so just remove the block from the stack.
211                Some(BlockState::Finished) => {
212                    self.stack.pop().unwrap();
213                }
214
215                // The block is not visited yet, so push its unvisited in-loop successors to the
216                // stack and mark the block as `Visited`.
217                None => {
218                    self.block_state.insert(block, BlockState::Visited);
219                    for &succ in self.cfg.succs(block) {
220                        if !self.block_state.contains_key(&succ)
221                            && self.lpt.is_block_in_loop(succ, self.lp)
222                        {
223                            self.stack.push(succ);
224                        }
225                    }
226                }
227            }
228        }
229
230        None
231    }
232}
233
/// Traversal state of a block in `BlocksInLoopPostOrder`.
enum BlockState {
    /// The block's successors were pushed to the stack, but the block itself
    /// has not been returned from the iterator yet.
    Visited,
    /// The block has been returned from the iterator.
    Finished,
}
238
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId};

    /// Builds the CFG and the dominator tree of `func` and computes its loop
    /// tree.
    fn compute_loop(func: &FunctionBody) -> LoopTree {
        let cfg = ControlFlowGraph::compute(func);
        let domtree = DomTree::compute(&cfg);
        LoopTree::compute(&cfg, &domtree)
    }

    /// Creates a builder for a dummy function body.
    fn body_builder() -> BodyBuilder {
        BodyBuilder::new(FunctionId(0), SourceInfo::dummy())
    }

    /// CFG under test:
    ///
    /// ```text
    /// entry  -> block1 | block2
    /// block1 -> entry
    /// block2 -> return
    /// ```
    #[test]
    fn simple_loop() {
        let mut builder = body_builder();

        let entry = builder.current_block();
        let block1 = builder.make_block();
        let block2 = builder.make_block();

        let dummy_ty = TypeId(0);
        let cond = builder.make_imm_from_bool(false, dummy_ty);
        builder.branch(cond, block1, block2, SourceInfo::dummy());

        // Back edge to the entry block.
        builder.move_to_block(block1);
        builder.jump(entry, SourceInfo::dummy());

        // Loop exit.
        builder.move_to_block(block2);
        let unit = builder.make_unit(dummy_ty);
        builder.ret(unit, SourceInfo::dummy());

        let func = builder.build();
        let lpt = compute_loop(&func);

        assert_eq!(lpt.loop_num(), 1);
        let only_lp = lpt.loops().next().unwrap();

        assert_eq!(lpt.loop_header(only_lp), entry);

        // `entry` and `block1` form the loop.
        assert!(lpt.is_block_in_loop(entry, only_lp));
        assert_eq!(lpt.loop_of_block(entry), Some(only_lp));

        assert!(lpt.is_block_in_loop(block1, only_lp));
        assert_eq!(lpt.loop_of_block(block1), Some(only_lp));

        // `block2` is outside of the loop.
        assert!(!lpt.is_block_in_loop(block2, only_lp));
        assert!(lpt.loop_of_block(block2).is_none());
    }

    /// CFG under test:
    ///
    /// ```text
    /// entry  -> block1 | block3
    /// block1 -> entry  | block2
    /// block2 -> block1
    /// block3 -> return
    /// ```
    #[test]
    fn nested_loop() {
        let mut builder = body_builder();

        let entry = builder.current_block();
        let block1 = builder.make_block();
        let block2 = builder.make_block();
        let block3 = builder.make_block();

        let dummy_ty = TypeId(0);
        let cond = builder.make_imm_from_bool(false, dummy_ty);
        builder.branch(cond, block1, block3, SourceInfo::dummy());

        // Header of the inner loop, and source of the outer back edge.
        builder.move_to_block(block1);
        builder.branch(cond, entry, block2, SourceInfo::dummy());

        // Source of the inner back edge.
        builder.move_to_block(block2);
        builder.jump(block1, SourceInfo::dummy());

        // Exit of the outer loop.
        builder.move_to_block(block3);
        let unit = builder.make_unit(dummy_ty);
        builder.ret(unit, SourceInfo::dummy());

        let func = builder.build();
        let lpt = compute_loop(&func);

        assert_eq!(lpt.loop_num(), 2);
        // The outer loop is guaranteed to be iterated first.
        let mut all_loops = lpt.loops();
        let outer_lp = all_loops.next().unwrap();
        let inner_lp = all_loops.next().unwrap();

        assert_eq!(lpt.loop_header(outer_lp), entry);
        assert_eq!(lpt.loop_header(inner_lp), block1);

        assert!(lpt.parent_loop(outer_lp).is_none());
        assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp));

        // `entry` is only in the outer loop.
        assert!(lpt.is_block_in_loop(entry, outer_lp));
        assert!(!lpt.is_block_in_loop(entry, inner_lp));
        assert_eq!(lpt.loop_of_block(entry), Some(outer_lp));

        // `block1` and `block2` are in both loops and map to the inner one.
        assert!(lpt.is_block_in_loop(block1, outer_lp));
        assert!(lpt.is_block_in_loop(block1, inner_lp));
        assert_eq!(lpt.loop_of_block(block1), Some(inner_lp));

        assert!(lpt.is_block_in_loop(block2, outer_lp));
        assert!(lpt.is_block_in_loop(block2, inner_lp));
        assert_eq!(lpt.loop_of_block(block2), Some(inner_lp));

        // `block3` is in no loop at all.
        assert!(!lpt.is_block_in_loop(block3, outer_lp));
        assert!(!lpt.is_block_in_loop(block3, inner_lp));
        assert!(lpt.loop_of_block(block3).is_none());
    }
}