fe_mir/lower/pattern_match/
decision_tree.rs

//! This module contains the decision tree definition and its construction
//! function.
//! The algorithm for efficient decision tree construction is mainly based on [Compiling pattern matching to good decision trees](https://dl.acm.org/doi/10.1145/1411304.1411311).
4use std::io;
5
6use fe_analyzer::{
7    pattern_analysis::{
8        ConstructorKind, PatternMatrix, PatternRowVec, SigmaSet, SimplifiedPattern,
9        SimplifiedPatternKind,
10    },
11    AnalyzerDb,
12};
13use indexmap::IndexMap;
14use smol_str::SmolStr;
15
16use super::tree_vis::TreeRenderer;
17
18pub fn build_decision_tree(
19    db: &dyn AnalyzerDb,
20    pattern_matrix: &PatternMatrix,
21    policy: ColumnSelectionPolicy,
22) -> DecisionTree {
23    let builder = DecisionTreeBuilder::new(policy);
24    let simplified_arms = SimplifiedArmMatrix::new(pattern_matrix);
25
26    builder.build(db, simplified_arms)
27}
28
29#[derive(Debug)]
30pub enum DecisionTree {
31    Leaf(LeafNode),
32    Switch(SwitchNode),
33}
34
35impl DecisionTree {
36    #[allow(unused)]
37    pub fn dump_dot<W>(&self, db: &dyn AnalyzerDb, w: &mut W) -> io::Result<()>
38    where
39        W: io::Write,
40    {
41        let renderer = TreeRenderer::new(db, self);
42        dot2::render(&renderer, w).map_err(|err| match err {
43            dot2::Error::Io(err) => err,
44            _ => panic!("invalid graphviz id"),
45        })
46    }
47}
48
49#[derive(Debug)]
50pub struct LeafNode {
51    pub arm_idx: usize,
52    pub binds: IndexMap<(SmolStr, usize), Occurrence>,
53}
54
55impl LeafNode {
56    fn new(arm: SimplifiedArm, occurrences: &[Occurrence]) -> Self {
57        let arm_idx = arm.body;
58        let binds = arm.finalize_binds(occurrences);
59        Self { arm_idx, binds }
60    }
61}
62
63#[derive(Debug)]
64pub struct SwitchNode {
65    pub occurrence: Occurrence,
66    pub arms: Vec<(Case, DecisionTree)>,
67}
68
69#[derive(Debug, Clone, Copy)]
70pub enum Case {
71    Ctor(ConstructorKind),
72    Default,
73}
74
75#[derive(Debug, Clone, Default)]
76pub struct ColumnSelectionPolicy(Vec<ColumnScoringFunction>);
77
78impl ColumnSelectionPolicy {
79    /// The score of column i is the sum of the negation of the arities of
80    /// constructors in sigma(i).
81    pub fn arity(&mut self) -> &mut Self {
82        self.add_heuristic(ColumnScoringFunction::Arity)
83    }
84
85    /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)).
86    /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1.
87    pub fn small_branching(&mut self) -> &mut Self {
88        self.add_heuristic(ColumnScoringFunction::SmallBranching)
89    }
90
91    /// The score is the number of needed rows of column i in the necessity
92    /// matrix.
93    #[allow(unused)]
94    pub fn needed_column(&mut self) -> &mut Self {
95        self.add_heuristic(ColumnScoringFunction::NeededColumn)
96    }
97
98    /// The score is the larger row index j such that column i is needed for all
99    /// rows j′; 1 ≤ j′ ≤ j.
100    pub fn needed_prefix(&mut self) -> &mut Self {
101        self.add_heuristic(ColumnScoringFunction::NeededPrefix)
102    }
103
104    fn select_column(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> usize {
105        let mut candidates: Vec<_> = (0..mat.ncols()).collect();
106
107        for scoring_fn in &self.0 {
108            let mut max_score = i32::MIN;
109            for col in std::mem::take(&mut candidates) {
110                let score = scoring_fn.score(db, mat, col);
111                match score.cmp(&max_score) {
112                    std::cmp::Ordering::Less => {}
113                    std::cmp::Ordering::Equal => {
114                        candidates.push(col);
115                    }
116                    std::cmp::Ordering::Greater => {
117                        candidates = vec![col];
118                        max_score = score;
119                    }
120                }
121            }
122
123            if candidates.len() == 1 {
124                return candidates.pop().unwrap();
125            }
126        }
127
128        // If there are more than one candidates remained, filter the columns with the
129        // shortest occurrences among the candidates, then select the rightmost one.
130        // This heuristics corresponds to the R pseudo heuristic in the paper.
131        let mut shortest_occurrences = usize::MAX;
132        for col in std::mem::take(&mut candidates) {
133            let occurrences = mat.occurrences[col].len();
134            match occurrences.cmp(&shortest_occurrences) {
135                std::cmp::Ordering::Less => {
136                    candidates = vec![col];
137                    shortest_occurrences = occurrences;
138                }
139                std::cmp::Ordering::Equal => {
140                    candidates.push(col);
141                }
142                std::cmp::Ordering::Greater => {}
143            }
144        }
145
146        candidates.pop().unwrap()
147    }
148
149    fn add_heuristic(&mut self, heuristic: ColumnScoringFunction) -> &mut Self {
150        debug_assert!(!self.0.contains(&heuristic));
151        self.0.push(heuristic);
152        self
153    }
154}
155
156#[derive(Clone, Debug, PartialEq, Eq, Hash)]
157pub struct Occurrence(Vec<usize>);
158
159impl Occurrence {
160    pub fn new() -> Self {
161        Self(vec![])
162    }
163
164    pub fn iter(&self) -> impl Iterator<Item = &usize> {
165        self.0.iter()
166    }
167
168    pub fn parent(&self) -> Option<Occurrence> {
169        let mut inner = self.0.clone();
170        inner.pop().map(|_| Occurrence(inner))
171    }
172
173    pub fn last_index(&self) -> Option<usize> {
174        self.0.last().cloned()
175    }
176
177    fn phi_specialize(&self, db: &dyn AnalyzerDb, ctor: ConstructorKind) -> Vec<Self> {
178        let arity = ctor.arity(db);
179        (0..arity)
180            .map(|i| {
181                let mut inner = self.0.clone();
182                inner.push(i);
183                Self(inner)
184            })
185            .collect()
186    }
187
188    fn len(&self) -> usize {
189        self.0.len()
190    }
191}
192
193struct DecisionTreeBuilder {
194    policy: ColumnSelectionPolicy,
195}
196
197impl DecisionTreeBuilder {
198    fn new(policy: ColumnSelectionPolicy) -> Self {
199        DecisionTreeBuilder { policy }
200    }
201
202    fn build(&self, db: &dyn AnalyzerDb, mut mat: SimplifiedArmMatrix) -> DecisionTree {
203        debug_assert!(mat.nrows() > 0, "unexhausted pattern matrix");
204
205        if mat.is_first_arm_satisfied() {
206            mat.arms.truncate(1);
207            return DecisionTree::Leaf(LeafNode::new(mat.arms.pop().unwrap(), &mat.occurrences));
208        }
209
210        let col = self.policy.select_column(db, &mat);
211        mat.swap(col);
212
213        let mut switch_arms = vec![];
214        let occurrence = &mat.occurrences[0];
215        let sigma_set = mat.sigma_set(0);
216        for &ctor in sigma_set.iter() {
217            let destructured_mat = mat.phi_specialize(db, ctor, occurrence);
218            let subtree = self.build(db, destructured_mat);
219            switch_arms.push((Case::Ctor(ctor), subtree));
220        }
221
222        if !sigma_set.is_complete(db) {
223            let destructured_mat = mat.d_specialize(db, occurrence);
224            let subtree = self.build(db, destructured_mat);
225            switch_arms.push((Case::Default, subtree));
226        }
227
228        DecisionTree::Switch(SwitchNode {
229            occurrence: occurrence.clone(),
230            arms: switch_arms,
231        })
232    }
233}
234
235#[derive(Clone, Debug)]
236struct SimplifiedArmMatrix {
237    arms: Vec<SimplifiedArm>,
238    occurrences: Vec<Occurrence>,
239}
240
241impl SimplifiedArmMatrix {
242    fn new(mat: &PatternMatrix) -> Self {
243        let cols = mat.ncols();
244        let arms: Vec<_> = mat
245            .rows()
246            .iter()
247            .enumerate()
248            .map(|(body, pat)| SimplifiedArm::new(pat, body))
249            .collect();
250        let occurrences = vec![Occurrence::new(); cols];
251
252        SimplifiedArmMatrix { arms, occurrences }
253    }
254
255    fn nrows(&self) -> usize {
256        self.arms.len()
257    }
258
259    fn ncols(&self) -> usize {
260        self.arms[0].pat_vec.len()
261    }
262
263    fn pat(&self, row: usize, col: usize) -> &SimplifiedPattern {
264        self.arms[row].pat(col)
265    }
266
267    fn necessity_matrix(&self, db: &dyn AnalyzerDb) -> NecessityMatrix {
268        NecessityMatrix::from_mat(db, self)
269    }
270
271    fn reduced_pat_mat(&self, col: usize) -> PatternMatrix {
272        let mut rows = Vec::with_capacity(self.nrows());
273        for arm in self.arms.iter() {
274            let reduced_pat_vec = arm
275                .pat_vec
276                .pats()
277                .iter()
278                .enumerate()
279                .filter(|(i, _)| (*i != col))
280                .map(|(_, pat)| pat.clone())
281                .collect();
282            rows.push(PatternRowVec::new(reduced_pat_vec));
283        }
284
285        PatternMatrix::new(rows)
286    }
287
288    /// Returns the constructor set in the column i.
289    fn sigma_set(&self, col: usize) -> SigmaSet {
290        SigmaSet::from_rows(self.arms.iter().map(|arm| &arm.pat_vec), col)
291    }
292
293    fn is_first_arm_satisfied(&self) -> bool {
294        self.arms[0]
295            .pat_vec
296            .pats()
297            .iter()
298            .all(SimplifiedPattern::is_wildcard)
299    }
300
301    fn phi_specialize(
302        &self,
303        db: &dyn AnalyzerDb,
304        ctor: ConstructorKind,
305        occurrence: &Occurrence,
306    ) -> Self {
307        let mut new_arms = Vec::new();
308        for arm in &self.arms {
309            new_arms.extend_from_slice(&arm.phi_specialize(db, ctor, occurrence));
310        }
311
312        let mut new_occurrences = self.occurrences[0].phi_specialize(db, ctor);
313        new_occurrences.extend_from_slice(&self.occurrences.as_slice()[1..]);
314
315        Self {
316            arms: new_arms,
317            occurrences: new_occurrences,
318        }
319    }
320
321    fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Self {
322        let mut new_arms = Vec::new();
323        for arm in &self.arms {
324            new_arms.extend_from_slice(&arm.d_specialize(db, occurrence));
325        }
326
327        Self {
328            arms: new_arms,
329            occurrences: self.occurrences.as_slice()[1..].to_vec(),
330        }
331    }
332
333    fn swap(&mut self, i: usize) {
334        for arm in &mut self.arms {
335            arm.swap(0, i)
336        }
337        self.occurrences.swap(0, i);
338    }
339}
340
341#[derive(Clone, Debug)]
342struct SimplifiedArm {
343    pat_vec: PatternRowVec,
344    body: usize,
345    binds: IndexMap<(SmolStr, usize), Occurrence>,
346}
347
348impl SimplifiedArm {
349    fn new(pat: &PatternRowVec, body: usize) -> Self {
350        let pat = PatternRowVec::new(pat.inner.iter().map(generalize_pattern).collect());
351        Self {
352            pat_vec: pat,
353            body,
354            binds: IndexMap::new(),
355        }
356    }
357
358    fn len(&self) -> usize {
359        self.pat_vec.len()
360    }
361
362    fn pat(&self, col: usize) -> &SimplifiedPattern {
363        &self.pat_vec.inner[col]
364    }
365
366    fn phi_specialize(
367        &self,
368        db: &dyn AnalyzerDb,
369        ctor: ConstructorKind,
370        occurrence: &Occurrence,
371    ) -> Vec<Self> {
372        let body = self.body;
373        let binds = self.new_binds(occurrence);
374
375        self.pat_vec
376            .phi_specialize(db, ctor)
377            .into_iter()
378            .map(|pat| SimplifiedArm {
379                pat_vec: pat,
380                body,
381                binds: binds.clone(),
382            })
383            .collect()
384    }
385
386    fn d_specialize(&self, db: &dyn AnalyzerDb, occurrence: &Occurrence) -> Vec<Self> {
387        let body = self.body;
388        let binds = self.new_binds(occurrence);
389
390        self.pat_vec
391            .d_specialize(db)
392            .into_iter()
393            .map(|pat| SimplifiedArm {
394                pat_vec: pat,
395                body,
396                binds: binds.clone(),
397            })
398            .collect()
399    }
400
401    fn new_binds(&self, occurrence: &Occurrence) -> IndexMap<(SmolStr, usize), Occurrence> {
402        let mut binds = self.binds.clone();
403        if let Some(SimplifiedPatternKind::WildCard(Some(bind))) =
404            self.pat_vec.head().map(|pat| &pat.kind)
405        {
406            binds.entry(bind.clone()).or_insert(occurrence.clone());
407        }
408        binds
409    }
410
411    fn finalize_binds(self, occurrences: &[Occurrence]) -> IndexMap<(SmolStr, usize), Occurrence> {
412        debug_assert!(self.len() == occurrences.len());
413
414        let mut binds = self.binds;
415        for (pat, occurrence) in self.pat_vec.pats().iter().zip(occurrences.iter()) {
416            debug_assert!(pat.is_wildcard());
417
418            if let SimplifiedPatternKind::WildCard(Some(bind)) = &pat.kind {
419                binds.entry(bind.clone()).or_insert(occurrence.clone());
420            }
421        }
422
423        binds
424    }
425
426    fn swap(&mut self, i: usize, j: usize) {
427        self.pat_vec.swap(i, j);
428    }
429}
430
431struct NecessityMatrix {
432    data: Vec<bool>,
433    ncol: usize,
434    nrow: usize,
435}
436
437impl NecessityMatrix {
438    fn new(ncol: usize, nrow: usize) -> Self {
439        let data = vec![false; ncol * nrow];
440        Self { data, ncol, nrow }
441    }
442
443    fn from_mat(db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) -> Self {
444        let nrow = mat.nrows();
445        let ncol = mat.ncols();
446        let mut necessity_mat = Self::new(ncol, nrow);
447
448        necessity_mat.compute(db, mat);
449        necessity_mat
450    }
451
452    fn compute(&mut self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix) {
453        for row in 0..self.nrow {
454            for col in 0..self.ncol {
455                let pat = mat.pat(row, col);
456                let pos = self.pos(row, col);
457
458                if !pat.is_wildcard() {
459                    self.data[pos] = true;
460                } else {
461                    let reduced_pat_mat = mat.reduced_pat_mat(col);
462                    self.data[pos] = !reduced_pat_mat.is_row_useful(db, row);
463                }
464            }
465        }
466    }
467
468    fn compute_needed_column_score(&self, col: usize) -> i32 {
469        let mut num = 0;
470        for i in 0..self.nrow {
471            if self.data[self.pos(i, col)] {
472                num += 1;
473            }
474        }
475
476        num
477    }
478
479    fn compute_needed_prefix_score(&self, col: usize) -> i32 {
480        let mut current_row = 0;
481        for i in 0..self.nrow {
482            if self.data[self.pos(i, col)] {
483                current_row += 1;
484            } else {
485                return current_row;
486            }
487        }
488
489        current_row
490    }
491
492    fn pos(&self, row: usize, col: usize) -> usize {
493        self.ncol * row + col
494    }
495}
496
497#[derive(Debug, Clone, Copy, PartialEq, Eq)]
498enum ColumnScoringFunction {
499    /// The score of column i is the sum of the negation of the arities of
500    /// constructors in sigma(i).
501    Arity,
502
503    /// The score is the negation of the cardinal of sigma(i), C(Sigma(i)).
504    /// If sigma(i) is NOT complete, the resulting score is C(Sigma(i)) - 1.
505    SmallBranching,
506
507    /// The score is the number of needed rows of column i in the necessity
508    /// matrix.
509    NeededColumn,
510
511    NeededPrefix,
512}
513
514impl ColumnScoringFunction {
515    fn score(&self, db: &dyn AnalyzerDb, mat: &SimplifiedArmMatrix, col: usize) -> i32 {
516        match self {
517            ColumnScoringFunction::Arity => mat
518                .sigma_set(col)
519                .iter()
520                .map(|c| -(c.arity(db) as i32))
521                .sum(),
522
523            ColumnScoringFunction::SmallBranching => {
524                let sigma_set = mat.sigma_set(col);
525                let score = -(mat.sigma_set(col).len() as i32);
526                if sigma_set.is_complete(db) {
527                    score
528                } else {
529                    score - 1
530                }
531            }
532
533            ColumnScoringFunction::NeededColumn => {
534                mat.necessity_matrix(db).compute_needed_column_score(col)
535            }
536
537            ColumnScoringFunction::NeededPrefix => {
538                mat.necessity_matrix(db).compute_needed_prefix_score(col)
539            }
540        }
541    }
542}
543
544fn generalize_pattern(pat: &SimplifiedPattern) -> SimplifiedPattern {
545    match &pat.kind {
546        SimplifiedPatternKind::WildCard(_) => pat.clone(),
547
548        SimplifiedPatternKind::Constructor { kind, fields } => {
549            let fields = fields.iter().map(generalize_pattern).collect();
550            let kind = SimplifiedPatternKind::Constructor {
551                kind: *kind,
552                fields,
553            };
554            SimplifiedPattern::new(kind, pat.ty)
555        }
556
557        SimplifiedPatternKind::Or(pats) => {
558            let mut gen_pats = vec![];
559            for pat in pats {
560                let gen_pad = generalize_pattern(pat);
561                if gen_pad.is_wildcard() {
562                    gen_pats.push(gen_pad);
563                    break;
564                } else {
565                    gen_pats.push(gen_pad);
566                }
567            }
568
569            if gen_pats.len() == 1 {
570                gen_pats.pop().unwrap()
571            } else {
572                SimplifiedPattern::new(SimplifiedPatternKind::Or(gen_pats), pat.ty)
573            }
574        }
575    }
576}