fe_mir/lower/pattern_match/
decision_tree.rs1use std::io;
5
6use fe_analyzer::{
7 pattern_analysis::{
8 ConstructorKind, PatternMatrix, PatternRowVec, SigmaSet, SimplifiedPattern,
9 SimplifiedPatternKind,
10 },
11 AnalyzerDb,
12};
13use indexmap::IndexMap;
14use smol_str::SmolStr;
15
16use super::tree_vis::TreeRenderer;
17
18pub fn build_decision_tree(
19 db: &dyn AnalyzerDb,
20 pattern_matrix: &PatternMatrix,
21 policy: ColumnSelectionPolicy,
22) -> DecisionTree {
23 let builder = DecisionTreeBuilder::new(policy);
24 let simplified_arms = SimplifiedArmMatrix::new(pattern_matrix);
25
26 builder.build(db, simplified_arms)
27}
28
/// A decision tree produced by `build_decision_tree`.
///
/// A node is either a leaf that names the selected arm or a switch that
/// branches on one occurrence.
#[derive(Debug)]
pub enum DecisionTree {
    /// Terminal node: carries the selected arm index and its bindings.
    Leaf(LeafNode),
    /// Inner node: tests an occurrence and dispatches to a subtree per case.
    Switch(SwitchNode),
}
34
35impl DecisionTree {
36 #[allow(unused)]
37 pub fn dump_dot<W>(&self, db: &dyn AnalyzerDb, w: &mut W) -> io::Result<()>
38 where
39 W: io::Write,
40 {
41 let renderer = TreeRenderer::new(db, self);
42 dot2::render(&renderer, w).map_err(|err| match err {
43 dot2::Error::Io(err) => err,
44 _ => panic!("invalid graphviz id"),
45 })
46 }
47}
48
/// Leaf of a [`DecisionTree`]: the arm that matched and the bindings it
/// introduces.
#[derive(Debug)]
pub struct LeafNode {
    /// Index of the selected arm (the arm's `body`, i.e. its row index in the
    /// pattern matrix the tree was built from).
    pub arm_idx: usize,
    /// Maps each binding key carried by a named wildcard pattern to the
    /// occurrence it is bound to.
    // NOTE(review): the meaning of the `usize` half of the key is defined by
    // the analyzer's `SimplifiedPatternKind::WildCard` and is not visible here.
    pub binds: IndexMap<(SmolStr, usize), Occurrence>,
}
54
55impl LeafNode {
56 fn new(arm: SimplifiedArm, occurrences: &[Occurrence]) -> Self {
57 let arm_idx = arm.body;
58 let binds = arm.finalize_binds(occurrences);
59 Self { arm_idx, binds }
60 }
61}
62
/// Inner node of a [`DecisionTree`] that branches on a single occurrence.
#[derive(Debug)]
pub struct SwitchNode {
    /// The occurrence whose constructor is inspected at this node.
    pub occurrence: Occurrence,
    /// `(case, subtree)` pairs: one entry per constructor found in the
    /// switched column, followed by a `Case::Default` entry when those
    /// constructors are not complete.
    pub arms: Vec<(Case, DecisionTree)>,
}
68
/// Label of one branch of a [`SwitchNode`].
#[derive(Debug, Clone, Copy)]
pub enum Case {
    /// Branch taken when the tested occurrence has this constructor.
    Ctor(ConstructorKind),
    /// Fallback branch, emitted only when the constructors appearing in the
    /// switched column do not form a complete set.
    Default,
}
74
/// Ordered list of scoring heuristics used to pick the column to switch on.
///
/// Heuristics are applied in the order they were added; each one narrows the
/// candidate columns to those with the highest score. The `Default` value
/// holds no heuristics.
#[derive(Debug, Clone, Default)]
pub struct ColumnSelectionPolicy(Vec<ColumnScoringFunction>);
77
78impl ColumnSelectionPolicy {
79 pub fn arity(&mut self) -> &mut Self {
82 self.add_heuristic(ColumnScoringFunction::Arity)
83 }
84
85 pub fn small_branching(&mut self) -> &mut Self {
88 self.add_heuristic(ColumnScoringFunction::SmallBranching)
89 }
90
91 #[allow(unused)]
94 pub fn needed_column(&mut self) -> &mut Self {
95 self.add_heuristic(ColumnScoringFunction::NeededColumn)
96 }
97
98 pub fn needed_prefix(&mut self) -> &mut Self {
101 self.add_heuristic(ColumnScoringFunction::NeededPrefix)
102 }
103
104 fn select_column(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> usize {
105 let mut candidates: Vec<_> = (0..mat.ncols()).collect();
106
107 for scoring_fn in &self.0 {
108 let mut max_score = i32::MIN;
109 for col in std::mem::take(&mut candidates) {
110 let score = scoring_fn.score(db, mat, col);
111 match score.cmp(&max_score) {
112 std::cmp::Ordering::Less => {}
113 std::cmp::Ordering::Equal => {
114 candidates.push(col);
115 }
116 std::cmp::Ordering::Greater => {
117 candidates = vec![col];
118 max_score = score;
119 }
120 }
121 }
122
123 if candidates.len() == 1 {
124 return candidates.pop().unwrap();
125 }
126 }
127
128 let mut shortest_occurrences = usize::MAX;
132 for col in std::mem::take(&mut candidates) {
133 let occurrences = mat.occurrences[col].len();
134 match occurrences.cmp(&shortest_occurrences) {
135 std::cmp::Ordering::Less => {
136 candidates = vec![col];
137 shortest_occurrences = occurrences;
138 }
139 std::cmp::Ordering::Equal => {
140 candidates.push(col);
141 }
142 std::cmp::Ordering::Greater => {}
143 }
144 }
145
146 candidates.pop().unwrap()
147 }
148
149 fn add_heuristic(&mut self, heuristic: ColumnScoringFunction) -> &mut Self {
150 debug_assert!(!self.0.contains(&heuristic));
151 self.0.push(heuristic);
152 self
153 }
154}
155
/// A path of indices identifying a position inside the matched value.
///
/// The empty path is the starting occurrence of every column; specializing
/// on a constructor appends the index of each of its fields.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Occurrence(Vec<usize>);
158
159impl Occurrence {
160 pub fn new() -> Self {
161 Self(vec![])
162 }
163
164 pub fn iter(&self) -> impl Iterator<Item = &usize> {
165 self.0.iter()
166 }
167
168 pub fn parent(&self) -> Option<Occurrence> {
169 let mut inner = self.0.clone();
170 inner.pop().map(|_| Occurrence(inner))
171 }
172
173 pub fn last_index(&self) -> Option<usize> {
174 self.0.last().cloned()
175 }
176
177 fn phi_specialize(&self, db: &dyn AnalyzerDb, ctor: ConstructorKind) -> Vec<Self> {
178 let arity = ctor.arity(db);
179 (0..arity)
180 .map(|i| {
181 let mut inner = self.0.clone();
182 inner.push(i);
183 Self(inner)
184 })
185 .collect()
186 }
187
188 fn len(&self) -> usize {
189 self.0.len()
190 }
191}
192
/// Recursive builder that turns a `SimplifiedArmMatrix` into a
/// [`DecisionTree`].
struct DecisionTreeBuilder {
    /// Heuristics used to choose the column to switch on at every step.
    policy: ColumnSelectionPolicy,
}
196
197impl DecisionTreeBuilder {
198 fn new(policy: ColumnSelectionPolicy) -> Self {
199 DecisionTreeBuilder { policy }
200 }
201
202 fn build(&self, db: &dyn AnalyzerDb, mut mat: SimplifiedArmMatrix) -> DecisionTree {
203 debug_assert!(mat.nrows() > 0, "unexhausted pattern matrix");
204
205 if mat.is_first_arm_satisfied() {
206 mat.arms.truncate(1);
207 return DecisionTree::Leaf(LeafNode::new(mat.arms.pop().unwrap(), &mat.occurrences));
208 }
209
210 let col = self.policy.select_column(db, &mat);
211 mat.swap(col);
212
213 let mut switch_arms = vec![];
214 let occurrence = &mat.occurrences[0];
215 let sigma_set = mat.sigma_set(0);
216 for &ctor in sigma_set.iter() {
217 let destructured_mat = mat.phi_specialize(db, ctor, occurrence);
218 let subtree = self.build(db, destructured_mat);
219 switch_arms.push((Case::Ctor(ctor), subtree));
220 }
221
222 if !sigma_set.is_complete(db) {
223 let destructured_mat = mat.d_specialize(db, occurrence);
224 let subtree = self.build(db, destructured_mat);
225 switch_arms.push((Case::Default, subtree));
226 }
227
228 DecisionTree::Switch(SwitchNode {
229 occurrence: occurrence.clone(),
230 arms: switch_arms,
231 })
232 }
233}
234
/// Working matrix of the decision-tree builder: one `SimplifiedArm` per row
/// and one `Occurrence` per column.
#[derive(Clone, Debug)]
struct SimplifiedArmMatrix {
    /// Rows of the matrix, in arm order.
    arms: Vec<SimplifiedArm>,
    /// `occurrences[i]` is the occurrence matched by column `i` of every row.
    occurrences: Vec<Occurrence>,
}
240
241impl SimplifiedArmMatrix {
242 fn new(mat: &PatternMatrix) -> Self {
243 let cols = mat.ncols();
244 let arms: Vec<_> = mat
245 .rows()
246 .iter()
247 .enumerate()
248 .map(|(body, pat)| SimplifiedArm::new(pat, body))
249 .collect();
250 let occurrences = vec![Occurrence::new(); cols];
251
252 SimplifiedArmMatrix { arms, occurrences }
253 }
254
255 fn nrows(&self) -> usize {
256 self.arms.len()
257 }
258
259 fn ncols(&self) -> usize {
260 self.arms[0].pat_vec.len()
261 }
262
263 fn pat(&self, row: usize, col: usize) -> &SimplifiedPattern {
264 self.arms[row].pat(col)
265 }
266
267 fn necessity_matrix(&self, db: &dyn AnalyzerDb) -> NecessityMatrix {
268 NecessityMatrix::from_mat(db, self)
269 }
270
271 fn reduced_pat_mat(&self, col: usize) -> PatternMatrix {
272 let mut rows = Vec::with_capacity(self.nrows());
273 for arm in self.arms.iter() {
274 let reduced_pat_vec = arm
275 .pat_vec
276 .pats()
277 .iter()
278 .enumerate()
279 .filter(|(i, _)| (*i != col))
280 .map(|(_, pat)| pat.clone())
281 .collect();
282 rows.push(PatternRowVec::new(reduced_pat_vec));
283 }
284
285 PatternMatrix::new(rows)
286 }
287
288 fn sigma_set(&self, col: usize) -> SigmaSet {
290 SigmaSet::from_rows(self.arms.iter().map(|arm| &arm.pat_vec), col)
291 }
292
293 fn is_first_arm_satisfied(&self) -> bool {
294 self.arms[0]
295 .pat_vec
296 .pats()
297 .iter()
298 .all(SimplifiedPattern::is_wildcard)
299 }
300
301 fn phi_specialize(
302 &self,
303 db: &dyn AnalyzerDb,
304 ctor: ConstructorKind,
305 occurrence: &Occurrence,
306 ) -> Self {
307 let mut new_arms = Vec::new();
308 for arm in &self.arms {
309 new_arms.extend_from_slice(&arm.phi_specialize(db, ctor, occurrence));
310 }
311
312 let mut new_occurrences = self.occurrences[0].phi_specialize(db, ctor);
313 new_occurrences.extend_from_slice(&self.occurrences.as_slice()[1..]);
314
315 Self {
316 arms: new_arms,
317 occurrences: new_occurrences,
318 }
319 }
320
321 fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Self {
322 let mut new_arms = Vec::new();
323 for arm in &self.arms {
324 new_arms.extend_from_slice(&arm.d_specialize(db, occurrence));
325 }
326
327 Self {
328 arms: new_arms,
329 occurrences: self.occurrences.as_slice()[1..].to_vec(),
330 }
331 }
332
333 fn swap(&mut self, i: usize) {
334 for arm in &mut self.arms {
335 arm.swap(0, i)
336 }
337 self.occurrences.swap(0, i);
338 }
339}
340
/// One row of a `SimplifiedArmMatrix`.
#[derive(Clone, Debug)]
struct SimplifiedArm {
    /// Patterns that still have to be matched, one per matrix column.
    pat_vec: PatternRowVec,
    /// Index of the arm this row originates from.
    body: usize,
    /// Bindings collected so far while specializing this row.
    binds: IndexMap<(SmolStr, usize), Occurrence>,
}
347
348impl SimplifiedArm {
349 fn new(pat: &PatternRowVec, body: usize) -> Self {
350 let pat = PatternRowVec::new(pat.inner.iter().map(generalize_pattern).collect());
351 Self {
352 pat_vec: pat,
353 body,
354 binds: IndexMap::new(),
355 }
356 }
357
358 fn len(&self) -> usize {
359 self.pat_vec.len()
360 }
361
362 fn pat(&self, col: usize) -> &SimplifiedPattern {
363 &self.pat_vec.inner[col]
364 }
365
366 fn phi_specialize(
367 &self,
368 db: &dyn AnalyzerDb,
369 ctor: ConstructorKind,
370 occurrence: &Occurrence,
371 ) -> Vec<Self> {
372 let body = self.body;
373 let binds = self.new_binds(occurrence);
374
375 self.pat_vec
376 .phi_specialize(db, ctor)
377 .into_iter()
378 .map(|pat| SimplifiedArm {
379 pat_vec: pat,
380 body,
381 binds: binds.clone(),
382 })
383 .collect()
384 }
385
386 fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Vec<Self> {
387 let body = self.body;
388 let binds = self.new_binds(occurrence);
389
390 self.pat_vec
391 .d_specialize(db)
392 .into_iter()
393 .map(|pat| SimplifiedArm {
394 pat_vec: pat,
395 body,
396 binds: binds.clone(),
397 })
398 .collect()
399 }
400
401 fn new_binds(&self, occurrence: &Occurrence) -> IndexMap<(SmolStr, usize), Occurrence> {
402 let mut binds = self.binds.clone();
403 if let Some(SimplifiedPatternKind::WildCard(Some(bind))) =
404 self.pat_vec.head().map(|pat| &pat.kind)
405 {
406 binds.entry(bind.clone()).or_insert(occurrence.clone());
407 }
408 binds
409 }
410
411 fn finalize_binds(self, occurrences: &[Occurrence]) -> IndexMap<(SmolStr, usize), Occurrence> {
412 debug_assert!(self.len() == occurrences.len());
413
414 let mut binds = self.binds;
415 for (pat, occurrence) in self.pat_vec.pats().iter().zip(occurrences.iter()) {
416 debug_assert!(pat.is_wildcard());
417
418 if let SimplifiedPatternKind::WildCard(Some(bind)) = &pat.kind {
419 binds.entry(bind.clone()).or_insert(occurrence.clone());
420 }
421 }
422
423 binds
424 }
425
426 fn swap(&mut self, i: usize, j: usize) {
427 self.pat_vec.swap(i, j);
428 }
429}
430
/// Boolean matrix with one cell per (row, column) of a `SimplifiedArmMatrix`,
/// recording whether the row needs the column.
struct NecessityMatrix {
    /// Cells in row-major order: cell (`row`, `col`) lives at
    /// `ncol * row + col`.
    data: Vec<bool>,
    /// Number of columns.
    ncol: usize,
    /// Number of rows.
    nrow: usize,
}
436
437impl NecessityMatrix {
438 fn new(ncol: usize, nrow: usize) -> Self {
439 let data = vec![false; ncol * nrow];
440 Self { data, ncol, nrow }
441 }
442
443 fn from_mat(db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> Self {
444 let nrow = mat.nrows();
445 let ncol = mat.ncols();
446 let mut necessity_mat = Self::new(ncol, nrow);
447
448 necessity_mat.compute(db, mat);
449 necessity_mat
450 }
451
452 fn compute(&mut self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) {
453 for row in 0..self.nrow {
454 for col in 0..self.ncol {
455 let pat = mat.pat(row, col);
456 let pos = self.pos(row, col);
457
458 if !pat.is_wildcard() {
459 self.data[pos] = true;
460 } else {
461 let reduced_pat_mat = mat.reduced_pat_mat(col);
462 self.data[pos] = !reduced_pat_mat.is_row_useful(db, row);
463 }
464 }
465 }
466 }
467
468 fn compute_needed_column_score(&self, col: usize) -> i32 {
469 let mut num = 0;
470 for i in 0..self.nrow {
471 if self.data[self.pos(i, col)] {
472 num += 1;
473 }
474 }
475
476 num
477 }
478
479 fn compute_needed_prefix_score(&self, col: usize) -> i32 {
480 let mut current_row = 0;
481 for i in 0..self.nrow {
482 if self.data[self.pos(i, col)] {
483 current_row += 1;
484 } else {
485 return current_row;
486 }
487 }
488
489 current_row
490 }
491
492 fn pos(&self, row: usize, col: usize) -> usize {
493 self.ncol * row + col
494 }
495}
496
/// Heuristics available to a `ColumnSelectionPolicy`. Each one assigns a
/// score to a column; higher scores are preferred.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ColumnScoringFunction {
    /// Negated sum of the arities of the constructors appearing in the
    /// column.
    Arity,

    /// Negated number of constructors appearing in the column, minus one more
    /// when those constructors are not complete.
    SmallBranching,

    /// Number of rows that need the column.
    NeededColumn,

    /// Length of the leading run of rows that need the column.
    NeededPrefix,
}
513
514impl ColumnScoringFunction {
515 fn score(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix, col: usize) -> i32 {
516 match self {
517 ColumnScoringFunction::Arity => mat
518 .sigma_set(col)
519 .iter()
520 .map(|c| -(c.arity(db) as i32))
521 .sum(),
522
523 ColumnScoringFunction::SmallBranching => {
524 let sigma_set = mat.sigma_set(col);
525 let score = -(mat.sigma_set(col).len() as i32);
526 if sigma_set.is_complete(db) {
527 score
528 } else {
529 score - 1
530 }
531 }
532
533 ColumnScoringFunction::NeededColumn => {
534 mat.necessity_matrix(db).compute_needed_column_score(col)
535 }
536
537 ColumnScoringFunction::NeededPrefix => {
538 mat.necessity_matrix(db).compute_needed_prefix_score(col)
539 }
540 }
541 }
542}
543
544fn generalize_pattern(pat: &SimplifiedPattern) -> SimplifiedPattern {
545 match &pat.kind {
546 SimplifiedPatternKind::WildCard(_) => pat.clone(),
547
548 SimplifiedPatternKind::Constructor { kind, fields } => {
549 let fields = fields.iter().map(generalize_pattern).collect();
550 let kind = SimplifiedPatternKind::Constructor {
551 kind: *kind,
552 fields,
553 };
554 SimplifiedPattern::new(kind, pat.ty)
555 }
556
557 SimplifiedPatternKind::Or(pats) => {
558 let mut gen_pats = vec![];
559 for pat in pats {
560 let gen_pad = generalize_pattern(pat);
561 if gen_pad.is_wildcard() {
562 gen_pats.push(gen_pad);
563 break;
564 } else {
565 gen_pats.push(gen_pad);
566 }
567 }
568
569 if gen_pats.len() == 1 {
570 gen_pats.pop().unwrap()
571 } else {
572 SimplifiedPattern::new(SimplifiedPatternKind::Or(gen_pats), pat.ty)
573 }
574 }
575 }
576}