fe_analyzer/traversal/
const_expr.rs

//! This module provides an evaluator for constant expressions, used to resolve
//! const generics.
3
4use num_bigint::BigInt;
5use num_traits::{One, ToPrimitive, Zero};
6
7use crate::{
8    context::{AnalyzerContext, Constant},
9    errors::ConstEvalError,
10    namespace::types::{self, Base, Type},
11};
12
13use fe_common::{numeric, Span};
14use fe_parser::{
15    ast::{self, BinOperator, BoolOperator, CompOperator, UnaryOperator},
16    node::Node,
17};
18
19/// Evaluate expression.
20///
21/// # Panics
22///
23/// 1. Panics if type analysis on an `expr` is not performed beforehand.
24/// 2. Panics if an `expr` has an invalid type.
25pub(crate) fn eval_expr(
26    context: &mut dyn AnalyzerContext,
27    expr: &Node<ast::Expr>,
28) -> Result<Constant, ConstEvalError> {
29    let typ = context.expr_typ(expr);
30
31    match &expr.kind {
32        ast::Expr::Ternary {
33            if_expr,
34            test,
35            else_expr,
36        } => eval_ternary(context, if_expr, test, else_expr),
37        ast::Expr::BoolOperation { left, op, right } => eval_bool_op(context, left, op, right),
38        ast::Expr::BinOperation { left, op, right } => eval_bin_op(context, left, op, right, &typ),
39        ast::Expr::UnaryOperation { op, operand } => eval_unary_op(context, op, operand),
40        ast::Expr::CompOperation { left, op, right } => eval_comp_op(context, left, op, right),
41        ast::Expr::Bool(val) => Ok(Constant::Bool(*val)),
42        ast::Expr::Name(name) => match context.constant_value_by_name(name, expr.span)? {
43            Some(const_value) => Ok(const_value),
44            _ => Err(not_const_error(context, expr.span)),
45        },
46
47        ast::Expr::Num(num) => {
48            // We don't validate the string representing number here,
49            // because we assume the string has been already validate in type analysis.
50            let span = expr.span;
51            Constant::from_num_str(context, num, &typ, span)
52        }
53
54        ast::Expr::Str(s) => Ok(Constant::Str(s.clone())),
55
56        // TODO: Need to evaluate attribute getter, constant constructor and const fn call.
57        ast::Expr::Subscript { .. }
58        | ast::Expr::Path(_)
59        | ast::Expr::Attribute { .. }
60        | ast::Expr::Call { .. }
61        | ast::Expr::List { .. }
62        | ast::Expr::Repeat { .. }
63        | ast::Expr::Tuple { .. }
64        | ast::Expr::Unit => Err(not_const_error(context, expr.span)),
65    }
66}
67
68/// Evaluates ternary expression.
69fn eval_ternary(
70    context: &mut dyn AnalyzerContext,
71    then_expr: &Node<ast::Expr>,
72    cond: &Node<ast::Expr>,
73    else_expr: &Node<ast::Expr>,
74) -> Result<Constant, ConstEvalError> {
75    // In constant evaluation, we don't apply short circuit property for safety.
76    let then = eval_expr(context, then_expr)?;
77    let cond = eval_expr(context, cond)?;
78    let else_ = eval_expr(context, else_expr)?;
79
80    match cond {
81        Constant::Bool(cond) => {
82            if cond {
83                Ok(then)
84            } else {
85                Ok(else_)
86            }
87        }
88        _ => panic!("ternary condition is not a bool type"),
89    }
90}
91
92/// Evaluates logical expressions.
93fn eval_bool_op(
94    context: &mut dyn AnalyzerContext,
95    lhs: &Node<ast::Expr>,
96    op: &Node<ast::BoolOperator>,
97    rhs: &Node<ast::Expr>,
98) -> Result<Constant, ConstEvalError> {
99    // In constant evaluation, we don't apply short circuit property for safety.
100    let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
101    let (lhs, rhs) = match (lhs, rhs) {
102        (Constant::Bool(lhs), Constant::Bool(rhs)) => (lhs, rhs),
103        _ => panic!("an argument of a logical expression is not bool type"),
104    };
105
106    match op.kind {
107        BoolOperator::And => Ok(Constant::Bool(lhs && rhs)),
108        BoolOperator::Or => Ok(Constant::Bool(lhs || rhs)),
109    }
110}
111
112/// Evaluates binary expressions.
113fn eval_bin_op(
114    context: &mut dyn AnalyzerContext,
115    lhs: &Node<ast::Expr>,
116    op: &Node<ast::BinOperator>,
117    rhs: &Node<ast::Expr>,
118    typ: &Type,
119) -> Result<Constant, ConstEvalError> {
120    let span = lhs.span + rhs.span;
121    let lhs_ty = extract_int_typ(&context.expr_typ(lhs));
122
123    let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
124    let (lhs, rhs) = (lhs.extract_numeric(), rhs.extract_numeric());
125
126    let result = match op.kind {
127        BinOperator::Add => lhs + rhs,
128        BinOperator::Sub => lhs - rhs,
129        BinOperator::Mult => lhs * rhs,
130
131        BinOperator::Div => {
132            if rhs.is_zero() {
133                return Err(zero_division_error(context, span));
134            } else if lhs_ty.is_signed() && lhs == &(lhs_ty.min_value()) && rhs == &(-BigInt::one())
135            {
136                return Err(overflow_error(context, span));
137            } else {
138                lhs / rhs
139            }
140        }
141
142        BinOperator::Mod => {
143            if rhs.is_zero() {
144                return Err(zero_division_error(context, span));
145            }
146            lhs % rhs
147        }
148
149        BinOperator::Pow => {
150            // We assume `rhs` type is unsigned numeric.
151            if let Some(exponent) = rhs.to_u32() {
152                lhs.pow(exponent)
153            } else if lhs.is_zero() {
154                BigInt::zero()
155            } else if lhs.is_one() {
156                BigInt::one()
157            } else {
158                // Exponent is larger than u32::MAX and lhs is not zero nor one,
159                // then this trivially causes overflow.
160                return Err(overflow_error(context, span));
161            }
162        }
163
164        BinOperator::LShift => {
165            if let Some(exponent) = rhs.to_usize() {
166                let type_bits = lhs_ty.bits();
167                // If rhs is larger than or equal to lhs type bits, then we emits overflow
168                // error.
169                if exponent >= type_bits {
170                    return Err(overflow_error(context, span));
171                } else {
172                    let mask = make_mask(typ);
173                    (lhs * BigInt::from(2_u8).pow(exponent as u32)) & mask
174                }
175            } else {
176                // If exponent is larger than usize::MAX, it causes trivially overflow.
177                return Err(overflow_error(context, span));
178            }
179        }
180
181        BinOperator::RShift => {
182            if let Some(exponent) = rhs.to_usize() {
183                let type_bits = lhs_ty.bits();
184                // If rhs is larger than or equal to lhs type bits, then we emits overflow
185                // error.
186                if exponent >= type_bits {
187                    return Err(overflow_error(context, span));
188                } else {
189                    let mask = make_mask(typ);
190                    (lhs / BigInt::from(2_u8).pow(exponent as u32)) & mask
191                }
192            } else {
193                // If exponent is larger than usize::MAX, it causes trivially overflow.
194                return Err(overflow_error(context, span));
195            }
196        }
197
198        BinOperator::BitOr => lhs | rhs,
199        BinOperator::BitXor => lhs ^ rhs,
200        BinOperator::BitAnd => lhs & rhs,
201    };
202
203    Constant::make_const_numeric_with_ty(context, result, typ, span)
204}
205
206fn eval_unary_op(
207    context: &mut dyn AnalyzerContext,
208    op: &Node<ast::UnaryOperator>,
209    arg: &Node<ast::Expr>,
210) -> Result<Constant, ConstEvalError> {
211    let arg = eval_expr(context, arg)?;
212
213    match op.kind {
214        UnaryOperator::Invert => Ok(Constant::Int(!arg.extract_numeric())),
215        UnaryOperator::Not => Ok(Constant::Bool(!arg.extract_bool())),
216        UnaryOperator::USub => Ok(Constant::Int(-arg.extract_numeric())),
217    }
218}
219
220/// Evaluates comp operation.
221fn eval_comp_op(
222    context: &mut dyn AnalyzerContext,
223    lhs: &Node<ast::Expr>,
224    op: &Node<ast::CompOperator>,
225    rhs: &Node<ast::Expr>,
226) -> Result<Constant, ConstEvalError> {
227    let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
228
229    let res = match (lhs, rhs) {
230        (Constant::Int(lhs), Constant::Int(rhs)) => match op.kind {
231            CompOperator::Eq => lhs == rhs,
232            CompOperator::NotEq => lhs != rhs,
233            CompOperator::Lt => lhs < rhs,
234            CompOperator::LtE => lhs <= rhs,
235            CompOperator::Gt => lhs > rhs,
236            CompOperator::GtE => lhs >= rhs,
237        },
238
239        (Constant::Bool(lhs), Constant::Bool(rhs)) => match op.kind {
240            CompOperator::Eq => lhs == rhs,
241            CompOperator::NotEq => lhs != rhs,
242            CompOperator::Lt => !lhs & rhs,
243            CompOperator::LtE => lhs <= rhs,
244            CompOperator::Gt => lhs & !rhs,
245            CompOperator::GtE => lhs >= rhs,
246        },
247
248        _ => panic!("arguments of comp op have invalid type"),
249    };
250
251    Ok(Constant::Bool(res))
252}
253
254impl Constant {
255    /// Returns constant from numeric literal represented by string.
256    ///
257    /// # Panics
258    /// Panics if `s` is invalid string for numeric literal.
259    pub fn from_num_str(
260        context: &mut dyn AnalyzerContext,
261        s: &str,
262        typ: &Type,
263        span: Span,
264    ) -> Result<Self, ConstEvalError> {
265        let literal = numeric::Literal::new(s);
266        let num = literal.parse::<BigInt>().unwrap();
267        match typ {
268            Type::Base(Base::Numeric(_)) => {
269                Self::make_const_numeric_with_ty(context, num, typ, span)
270            }
271            Type::Base(Base::Address) => {
272                if num >= BigInt::zero() && num <= types::address_max() {
273                    Ok(Constant::Address(num))
274                } else {
275                    Err(overflow_error(context, span))
276                }
277            }
278            _ => unreachable!(),
279        }
280    }
281
282    /// Returns constant from numeric literal that fits type bits.
283    /// If `val` doesn't fit type bits, then return `Err`.
284    ///
285    /// # Panics
286    /// Panics if `typ` is invalid string for numeric literal.
287    fn make_const_numeric_with_ty(
288        context: &mut dyn AnalyzerContext,
289        val: BigInt,
290        typ: &Type,
291        span: Span,
292    ) -> Result<Self, ConstEvalError> {
293        // Overflowing check.
294        if extract_int_typ(typ).fits(val.clone()) {
295            Ok(Constant::Int(val))
296        } else {
297            Err(overflow_error(context, span))
298        }
299    }
300
301    /// Extracts numeric value from a `Constant`.
302    ///
303    /// # Panics
304    /// Panics if a `self` variant is not a numeric.
305    fn extract_numeric(&self) -> &BigInt {
306        match self {
307            Constant::Int(val) => val,
308            _ => panic!("can't extract numeric value from {self:?}"),
309        }
310    }
311
312    /// Extracts bool value from a `Constant`.
313    ///
314    /// # Panics
315    /// Panics if a `self` variant is not a bool.
316    fn extract_bool(&self) -> bool {
317        match self {
318            Constant::Bool(val) => *val,
319            _ => panic!("can't extract bool value from {self:?}"),
320        }
321    }
322}
323
324fn not_const_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
325    ConstEvalError::new(context.error(
326        "expression is not a constant",
327        span,
328        "expression is required to be constant here",
329    ))
330}
331
332fn overflow_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
333    ConstEvalError::new(context.error(
334        "overflow error",
335        span,
336        "overflow occurred during constant evaluation",
337    ))
338}
339
340fn zero_division_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
341    ConstEvalError::new(context.error(
342        "zero division error",
343        span,
344        "zero division occurred during constant evaluation",
345    ))
346}
347
348/// Returns integer types embedded in `typ`.
349///
350/// # Panic
351/// Panics if `typ` is not a numeric type.
352fn extract_int_typ(typ: &Type) -> types::Integer {
353    match typ {
354        Type::Base(Base::Numeric(int_ty)) => *int_ty,
355        _ => {
356            panic!("invalid binop expression type")
357        }
358    }
359}
360
361/// Returns bit mask corresponding to typ.
362/// e.g. If type is `Type::Base(Base::Numeric(Integer::I32))`, then returns
363/// `0xffff_ffff`.
364///
365/// # Panic
366/// Panics if `typ` is not a numeric type.
367fn make_mask(typ: &Type) -> BigInt {
368    let bits = extract_int_typ(typ).bits();
369    (BigInt::one() << bits) - 1
370}