1use std::collections::BTreeSet;
7
8use fxhash::FxHashMap;
9
10use crate::ir::BasicBlockId;
11
12use super::cfg::ControlFlowGraph;
13
14#[derive(Debug, Clone)]
15pub struct DomTree {
16 doms: FxHashMap<BasicBlockId, BasicBlockId>,
17 rpo: Vec<BasicBlockId>,
19}
20
21impl DomTree {
22 pub fn compute(cfg: &ControlFlowGraph) -> Self {
23 let mut doms = FxHashMap::default();
24 doms.insert(cfg.entry(), cfg.entry());
25 let mut rpo: Vec<_> = cfg.post_order().collect();
26 rpo.reverse();
27
28 let mut domtree = Self { doms, rpo };
29
30 let block_num = domtree.rpo.len();
31
32 let mut rpo_nums = FxHashMap::default();
33 for (i, &block) in domtree.rpo.iter().enumerate() {
34 rpo_nums.insert(block, (block_num - i) as u32);
35 }
36
37 let mut changed = true;
38 while changed {
39 changed = false;
40 for &block in domtree.rpo.iter().skip(1) {
41 let processed_pred = match cfg
42 .preds(block)
43 .iter()
44 .find(|pred| domtree.doms.contains_key(pred))
45 {
46 Some(pred) => *pred,
47 _ => continue,
48 };
49 let mut new_dom = processed_pred;
50
51 for &pred in cfg.preds(block) {
52 if pred != processed_pred && domtree.doms.contains_key(&pred) {
53 new_dom = domtree.intersect(new_dom, pred, &rpo_nums);
54 }
55 }
56 if Some(new_dom) != domtree.doms.get(&block).copied() {
57 changed = true;
58 domtree.doms.insert(block, new_dom);
59 }
60 }
61 }
62
63 domtree
64 }
65
66 pub fn idom(&self, block: BasicBlockId) -> Option<BasicBlockId> {
70 if self.rpo[0] == block {
71 return None;
72 }
73 self.doms.get(&block).copied()
74 }
75
76 pub fn strictly_dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool {
78 let mut current_block = block2;
79 while let Some(block) = self.idom(current_block) {
80 if block == block1 {
81 return true;
82 }
83 current_block = block;
84 }
85
86 false
87 }
88
89 pub fn dominates(&self, block1: BasicBlockId, block2: BasicBlockId) -> bool {
91 if block1 == block2 {
92 return true;
93 }
94
95 self.strictly_dominates(block1, block2)
96 }
97
98 pub fn is_reachable(&self, block: BasicBlockId) -> bool {
100 self.idom(block).is_some()
101 }
102
103 pub fn rpo(&self) -> &[BasicBlockId] {
105 &self.rpo
106 }
107
108 fn intersect(
109 &self,
110 mut b1: BasicBlockId,
111 mut b2: BasicBlockId,
112 rpo_nums: &FxHashMap<BasicBlockId, u32>,
113 ) -> BasicBlockId {
114 while b1 != b2 {
115 while rpo_nums[&b1] < rpo_nums[&b2] {
116 b1 = self.doms[&b1];
117 }
118 while rpo_nums[&b2] < rpo_nums[&b1] {
119 b2 = self.doms[&b2]
120 }
121 }
122
123 b1
124 }
125
126 pub fn compute_df(&self, cfg: &ControlFlowGraph) -> DFSet {
128 let mut df = DFSet::default();
129
130 for &block in &self.rpo {
131 let preds = cfg.preds(block);
132 if preds.len() < 2 {
133 continue;
134 }
135
136 for pred in preds {
137 let mut runner = *pred;
138 while self.doms.get(&block) != Some(&runner) && self.is_reachable(runner) {
139 df.0.entry(runner).or_default().insert(block);
140 runner = self.doms[&runner];
141 }
142 }
143 }
144
145 df
146 }
147}
148
149#[derive(Default, Debug)]
151pub struct DFSet(FxHashMap<BasicBlockId, BTreeSet<BasicBlockId>>);
152
153impl DFSet {
154 pub fn frontiers(
156 &self,
157 block: BasicBlockId,
158 ) -> Option<impl Iterator<Item = BasicBlockId> + '_> {
159 self.0.get(&block).map(|set| set.iter().copied())
160 }
161
162 pub fn frontier_num(&self, block: BasicBlockId) -> usize {
164 self.0.get(&block).map(BTreeSet::len).unwrap_or(0)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ir::{body_builder::BodyBuilder, FunctionBody, FunctionId, SourceInfo, TypeId};

    /// Builds the CFG of `func` and returns its dominator tree along with the
    /// dominance frontiers derived from it.
    fn analyze(func: &FunctionBody) -> (DomTree, DFSet) {
        let cfg = ControlFlowGraph::compute(func);
        let tree = DomTree::compute(&cfg);
        let frontiers = tree.compute_df(&cfg);
        (tree, frontiers)
    }

    /// Fresh builder for function 0 with dummy source info.
    fn new_builder() -> BodyBuilder {
        BodyBuilder::new(FunctionId(0), SourceInfo::dummy())
    }

    /// Diamond: entry branches to `then`/`else`, and both jump to `merge`.
    #[test]
    fn dom_tree_if_else() {
        let mut b = new_builder();

        let then_bb = b.make_block();
        let else_bb = b.make_block();
        let merge_bb = b.make_block();

        let ty = TypeId(0);
        let cond = b.make_imm_from_bool(true, ty);
        b.branch(cond, then_bb, else_bb, SourceInfo::dummy());

        for &arm in &[then_bb, else_bb] {
            b.move_to_block(arm);
            b.jump(merge_bb, SourceInfo::dummy());
        }

        b.move_to_block(merge_bb);
        let unit = b.make_unit(ty);
        b.ret(unit, SourceInfo::dummy());

        let func = b.build();

        let (tree, frontiers) = analyze(&func);
        let entry = func.order.entry();

        assert_eq!(tree.idom(entry), None);
        for &bb in &[then_bb, else_bb, merge_bb] {
            assert_eq!(tree.idom(bb), Some(entry));
        }

        assert_eq!(frontiers.frontier_num(entry), 0);
        assert_eq!(frontiers.frontier_num(then_bb), 1);
        for &arm in &[then_bb, else_bb] {
            let first = frontiers.frontiers(arm).unwrap().next().unwrap();
            assert_eq!(first, merge_bb);
        }
        assert_eq!(frontiers.frontier_num(merge_bb), 0);
    }

    /// `bb3` has no predecessors, so its edge into `bb4` must not influence the
    /// dominators of the reachable blocks.
    #[test]
    fn unreachable_edge() {
        let mut b = new_builder();

        let bb1 = b.make_block();
        let bb2 = b.make_block();
        let bb3 = b.make_block();
        let bb4 = b.make_block();

        let ty = TypeId(0);
        let cond = b.make_imm_from_bool(true, ty);
        b.branch(cond, bb1, bb2, SourceInfo::dummy());

        for &from in &[bb1, bb2, bb3] {
            b.move_to_block(from);
            b.jump(bb4, SourceInfo::dummy());
        }

        b.move_to_block(bb4);
        let unit = b.make_unit(ty);
        b.ret(unit, SourceInfo::dummy());

        let func = b.build();

        let (tree, _) = analyze(&func);
        let entry = func.order.entry();

        assert_eq!(tree.idom(entry), None);
        assert_eq!(tree.idom(bb1), Some(entry));
        assert_eq!(tree.idom(bb2), Some(entry));
        assert_eq!(tree.idom(bb3), None);
        assert!(!tree.is_reachable(bb3));
        assert_eq!(tree.idom(bb4), Some(entry));
    }

    /// Twelve-block graph with loops back into `bb2`.
    #[test]
    fn dom_tree_complex() {
        let mut b = new_builder();

        let bb1 = b.make_block();
        let bb2 = b.make_block();
        let bb3 = b.make_block();
        let bb4 = b.make_block();
        let bb5 = b.make_block();
        let bb6 = b.make_block();
        let bb7 = b.make_block();
        let bb8 = b.make_block();
        let bb9 = b.make_block();
        let bb10 = b.make_block();
        let bb11 = b.make_block();
        let bb12 = b.make_block();

        let ty = TypeId(0);
        let cond = b.make_imm_from_bool(true, ty);
        b.branch(cond, bb2, bb1, SourceInfo::dummy());

        // Conditional edges: (source, then-target, else-target).
        let branches = [
            (bb1, bb6, bb3),
            (bb2, bb7, bb4),
            (bb3, bb6, bb5),
            (bb4, bb7, bb2),
            (bb5, bb10, bb8),
        ];
        for &(from, on_true, on_false) in &branches {
            b.move_to_block(from);
            b.branch(cond, on_true, on_false, SourceInfo::dummy());
        }

        // Unconditional edges: (source, target).
        let jumps = [
            (bb6, bb9),
            (bb7, bb12),
            (bb8, bb11),
            (bb9, bb8),
            (bb10, bb11),
        ];
        for &(from, to) in &jumps {
            b.move_to_block(from);
            b.jump(to, SourceInfo::dummy());
        }

        b.move_to_block(bb11);
        b.branch(cond, bb12, bb2, SourceInfo::dummy());

        b.move_to_block(bb12);
        let unit = b.make_unit(ty);
        b.ret(unit, SourceInfo::dummy());

        let func = b.build();

        let (tree, _) = analyze(&func);
        let entry = func.order.entry();

        assert_eq!(tree.idom(entry), None);

        // (block, expected immediate dominator)
        let expected = [
            (bb1, entry),
            (bb2, entry),
            (bb3, bb1),
            (bb4, bb2),
            (bb5, bb3),
            (bb6, bb1),
            (bb7, bb2),
            (bb8, bb1),
            (bb9, bb6),
            (bb10, bb5),
            (bb11, bb1),
            (bb12, entry),
        ];
        for &(bb, idom) in &expected {
            assert_eq!(tree.idom(bb), Some(idom));
        }
    }
}