1use id_arena::{Arena, Id};
2
3use fxhash::FxHashMap;
4
5use super::{cfg::ControlFlowGraph, domtree::DomTree};
6
7use crate::ir::BasicBlockId;
8
9#[derive(Debug, Default, Clone)]
10pub struct LoopTree {
11 loops: Arena<Loop>,
15
16 block_to_loop: FxHashMap<BasicBlockId, LoopId>,
20}
21
22pub type LoopId = Id<Loop>;
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct Loop {
26 pub header: BasicBlockId,
28
29 pub parent: Option<LoopId>,
31
32 pub children: Vec<LoopId>,
34}
35
36impl LoopTree {
37 pub fn compute(cfg: &ControlFlowGraph, domtree: &DomTree) -> Self {
38 let mut tree = LoopTree::default();
39
40 for &block in domtree.rpo() {
43 for &pred in cfg.preds(block) {
44 if domtree.dominates(block, pred) {
45 let loop_data = Loop {
46 header: block,
47 parent: None,
48 children: Vec::new(),
49 };
50
51 tree.loops.alloc(loop_data);
52 break;
53 }
54 }
55 }
56
57 tree.analyze_loops(cfg, domtree);
58
59 tree
60 }
61
62 pub fn iter_blocks_post_order<'a, 'b>(
64 &'a self,
65 cfg: &'b ControlFlowGraph,
66 lp: LoopId,
67 ) -> BlocksInLoopPostOrder<'a, 'b> {
68 BlocksInLoopPostOrder::new(self, cfg, lp)
69 }
70
71 pub fn loops(&self) -> impl Iterator<Item = LoopId> + '_ {
74 self.loops.iter().map(|(id, _)| id)
75 }
76
77 pub fn loop_num(&self) -> usize {
79 self.loops.len()
80 }
81
82 pub fn is_block_in_loop(&self, block: BasicBlockId, lp: LoopId) -> bool {
84 let mut loop_of_block = self.loop_of_block(block);
85 while let Some(cur_lp) = loop_of_block {
86 if lp == cur_lp {
87 return true;
88 }
89 loop_of_block = self.parent_loop(cur_lp);
90 }
91 false
92 }
93
94 pub fn loop_header(&self, lp: LoopId) -> BasicBlockId {
96 self.loops[lp].header
97 }
98
99 pub fn parent_loop(&self, lp: LoopId) -> Option<LoopId> {
101 self.loops[lp].parent
102 }
103
104 pub fn loop_of_block(&self, block: BasicBlockId) -> Option<LoopId> {
108 self.block_to_loop.get(&block).copied()
109 }
110
111 fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
115 let mut worklist = vec![];
116
117 let loops_rev: Vec<_> = self.loops.iter().rev().map(|(id, _)| id).collect();
119 for cur_lp in loops_rev {
120 let cur_lp_header = self.loop_header(cur_lp);
121
122 for &block in cfg.preds(cur_lp_header) {
124 if domtree.dominates(cur_lp_header, block) {
125 worklist.push(block);
126 }
127 }
128
129 while let Some(block) = worklist.pop() {
130 match self.block_to_loop.get(&block).copied() {
131 Some(lp_of_block) => {
132 let outermost_parent = self.outermost_parent(lp_of_block);
133
134 if outermost_parent == cur_lp {
136 continue;
137 } else {
138 self.loops[cur_lp].children.push(outermost_parent);
139 self.loops[outermost_parent].parent = cur_lp.into();
140
141 let lp_header_of_block = self.loop_header(lp_of_block);
142 worklist.extend(cfg.preds(lp_header_of_block));
143 }
144 }
145
146 None => {
148 self.map_block(block, cur_lp);
149 if block != cur_lp_header {
151 worklist.extend(cfg.preds(block));
152 }
153 }
154 }
155 }
156 }
157 }
158
159 fn outermost_parent(&self, mut lp: LoopId) -> LoopId {
162 while let Some(parent) = self.parent_loop(lp) {
163 lp = parent;
164 }
165 lp
166 }
167
168 fn map_block(&mut self, block: BasicBlockId, lp: LoopId) {
170 self.block_to_loop.insert(block, lp);
171 }
172}
173
174pub struct BlocksInLoopPostOrder<'a, 'b> {
175 lpt: &'a LoopTree,
176 cfg: &'b ControlFlowGraph,
177 lp: LoopId,
178 stack: Vec<BasicBlockId>,
179 block_state: FxHashMap<BasicBlockId, BlockState>,
180}
181
182impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> {
183 fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: LoopId) -> Self {
184 let loop_header = lpt.loop_header(lp);
185
186 Self {
187 lpt,
188 cfg,
189 lp,
190 stack: vec![loop_header],
191 block_state: FxHashMap::default(),
192 }
193 }
194}
195
196impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> {
197 type Item = BasicBlockId;
198
199 fn next(&mut self) -> Option<Self::Item> {
200 while let Some(&block) = self.stack.last() {
201 match self.block_state.get(&block) {
202 Some(BlockState::Visited) => {
205 let block = self.stack.pop().unwrap();
206 self.block_state.insert(block, BlockState::Finished);
207 return Some(block);
208 }
209
210 Some(BlockState::Finished) => {
212 self.stack.pop().unwrap();
213 }
214
215 None => {
218 self.block_state.insert(block, BlockState::Visited);
219 for &succ in self.cfg.succs(block) {
220 if !self.block_state.contains_key(&succ)
221 && self.lpt.is_block_in_loop(succ, self.lp)
222 {
223 self.stack.push(succ);
224 }
225 }
226 }
227 }
228 }
229
230 None
231 }
232}
233
/// Traversal state of a block inside `BlocksInLoopPostOrder`.
enum BlockState {
    /// The block was reached and its successors were pushed; it has not been
    /// yielded yet.
    Visited,
    /// The block was already yielded by the iterator.
    Finished,
}
238
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId};

    /// Runs the CFG, dominator tree and loop analyses on `func`.
    fn compute_loop(func: &FunctionBody) -> LoopTree {
        let cfg = ControlFlowGraph::compute(func);
        let domtree = DomTree::compute(&cfg);
        LoopTree::compute(&cfg, &domtree)
    }

    /// Creates a builder for a throwaway function body.
    fn body_builder() -> BodyBuilder {
        BodyBuilder::new(FunctionId(0), SourceInfo::dummy())
    }

    // CFG under test:
    //
    //   entry -> block1, block2
    //   block1 -> entry            (back edge)
    //   block2 -> return
    #[test]
    fn simple_loop() {
        let mut builder = body_builder();

        let entry = builder.current_block();
        let block1 = builder.make_block();
        let block2 = builder.make_block();

        let dummy_ty = TypeId(0);
        let cond = builder.make_imm_from_bool(false, dummy_ty);
        builder.branch(cond, block1, block2, SourceInfo::dummy());

        builder.move_to_block(block1);
        builder.jump(entry, SourceInfo::dummy());

        builder.move_to_block(block2);
        let unit = builder.make_unit(dummy_ty);
        builder.ret(unit, SourceInfo::dummy());

        let func = builder.build();
        let lpt = compute_loop(&func);

        // Exactly one loop, headed by the entry block.
        assert_eq!(lpt.loop_num(), 1);
        let lp = lpt.loops().next().unwrap();
        assert_eq!(lpt.loop_header(lp), entry);

        // Blocks on the cycle are members of the loop.
        for &member in &[entry, block1] {
            assert!(lpt.is_block_in_loop(member, lp));
            assert_eq!(lpt.loop_of_block(member), Some(lp));
        }

        // The exit block is not.
        assert!(!lpt.is_block_in_loop(block2, lp));
        assert!(lpt.loop_of_block(block2).is_none());
    }

    // CFG under test:
    //
    //   entry  -> block1, block3
    //   block1 -> entry (outer back edge), block2
    //   block2 -> block1           (inner back edge)
    //   block3 -> return
    #[test]
    fn nested_loop() {
        let mut builder = body_builder();

        let entry = builder.current_block();
        let block1 = builder.make_block();
        let block2 = builder.make_block();
        let block3 = builder.make_block();

        let dummy_ty = TypeId(0);
        let cond = builder.make_imm_from_bool(false, dummy_ty);
        builder.branch(cond, block1, block3, SourceInfo::dummy());

        builder.move_to_block(block1);
        builder.branch(cond, entry, block2, SourceInfo::dummy());

        builder.move_to_block(block2);
        builder.jump(block1, SourceInfo::dummy());

        builder.move_to_block(block3);
        let unit = builder.make_unit(dummy_ty);
        builder.ret(unit, SourceInfo::dummy());

        let func = builder.build();
        let lpt = compute_loop(&func);

        assert_eq!(lpt.loop_num(), 2);
        let mut loops = lpt.loops();
        let outer_lp = loops.next().unwrap();
        let inner_lp = loops.next().unwrap();

        // Headers and nesting.
        assert_eq!(lpt.loop_header(outer_lp), entry);
        assert_eq!(lpt.loop_header(inner_lp), block1);
        assert!(lpt.parent_loop(outer_lp).is_none());
        assert_eq!(lpt.parent_loop(inner_lp), Some(outer_lp));

        // `entry` belongs to the outer loop only.
        assert!(lpt.is_block_in_loop(entry, outer_lp));
        assert!(!lpt.is_block_in_loop(entry, inner_lp));
        assert_eq!(lpt.loop_of_block(entry), Some(outer_lp));

        // `block1` and `block2` are in both loops and mapped to the inner one.
        for &member in &[block1, block2] {
            assert!(lpt.is_block_in_loop(member, outer_lp));
            assert!(lpt.is_block_in_loop(member, inner_lp));
            assert_eq!(lpt.loop_of_block(member), Some(inner_lp));
        }

        // `block3` is outside of every loop.
        assert!(!lpt.is_block_in_loop(block3, outer_lp));
        assert!(!lpt.is_block_in_loop(block3, inner_lp));
        assert!(lpt.loop_of_block(block3).is_none());
    }
}