1use num_bigint::BigInt;
5use num_traits::{One, ToPrimitive, Zero};
6
7use crate::{
8 context::{AnalyzerContext, Constant},
9 errors::ConstEvalError,
10 namespace::types::{self, Base, Type},
11};
12
13use fe_common::{numeric, Span};
14use fe_parser::{
15 ast::{self, BinOperator, BoolOperator, CompOperator, UnaryOperator},
16 node::Node,
17};
18
19pub(crate) fn eval_expr(
26 context: &mut dyn AnalyzerContext,
27 expr: &Node<ast::Expr>,
28) -> Result<Constant, ConstEvalError> {
29 let typ = context.expr_typ(expr);
30
31 match &expr.kind {
32 ast::Expr::Ternary {
33 if_expr,
34 test,
35 else_expr,
36 } => eval_ternary(context, if_expr, test, else_expr),
37 ast::Expr::BoolOperation { left, op, right } => eval_bool_op(context, left, op, right),
38 ast::Expr::BinOperation { left, op, right } => eval_bin_op(context, left, op, right, &typ),
39 ast::Expr::UnaryOperation { op, operand } => eval_unary_op(context, op, operand),
40 ast::Expr::CompOperation { left, op, right } => eval_comp_op(context, left, op, right),
41 ast::Expr::Bool(val) => Ok(Constant::Bool(*val)),
42 ast::Expr::Name(name) => match context.constant_value_by_name(name, expr.span)? {
43 Some(const_value) => Ok(const_value),
44 _ => Err(not_const_error(context, expr.span)),
45 },
46
47 ast::Expr::Num(num) => {
48 let span = expr.span;
51 Constant::from_num_str(context, num, &typ, span)
52 }
53
54 ast::Expr::Str(s) => Ok(Constant::Str(s.clone())),
55
56 ast::Expr::Subscript { .. }
58 | ast::Expr::Path(_)
59 | ast::Expr::Attribute { .. }
60 | ast::Expr::Call { .. }
61 | ast::Expr::List { .. }
62 | ast::Expr::Repeat { .. }
63 | ast::Expr::Tuple { .. }
64 | ast::Expr::Unit => Err(not_const_error(context, expr.span)),
65 }
66}
67
68fn eval_ternary(
70 context: &mut dyn AnalyzerContext,
71 then_expr: &Node<ast::Expr>,
72 cond: &Node<ast::Expr>,
73 else_expr: &Node<ast::Expr>,
74) -> Result<Constant, ConstEvalError> {
75 let then = eval_expr(context, then_expr)?;
77 let cond = eval_expr(context, cond)?;
78 let else_ = eval_expr(context, else_expr)?;
79
80 match cond {
81 Constant::Bool(cond) => {
82 if cond {
83 Ok(then)
84 } else {
85 Ok(else_)
86 }
87 }
88 _ => panic!("ternary condition is not a bool type"),
89 }
90}
91
92fn eval_bool_op(
94 context: &mut dyn AnalyzerContext,
95 lhs: &Node<ast::Expr>,
96 op: &Node<ast::BoolOperator>,
97 rhs: &Node<ast::Expr>,
98) -> Result<Constant, ConstEvalError> {
99 let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
101 let (lhs, rhs) = match (lhs, rhs) {
102 (Constant::Bool(lhs), Constant::Bool(rhs)) => (lhs, rhs),
103 _ => panic!("an argument of a logical expression is not bool type"),
104 };
105
106 match op.kind {
107 BoolOperator::And => Ok(Constant::Bool(lhs && rhs)),
108 BoolOperator::Or => Ok(Constant::Bool(lhs || rhs)),
109 }
110}
111
112fn eval_bin_op(
114 context: &mut dyn AnalyzerContext,
115 lhs: &Node<ast::Expr>,
116 op: &Node<ast::BinOperator>,
117 rhs: &Node<ast::Expr>,
118 typ: &Type,
119) -> Result<Constant, ConstEvalError> {
120 let span = lhs.span + rhs.span;
121 let lhs_ty = extract_int_typ(&context.expr_typ(lhs));
122
123 let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
124 let (lhs, rhs) = (lhs.extract_numeric(), rhs.extract_numeric());
125
126 let result = match op.kind {
127 BinOperator::Add => lhs + rhs,
128 BinOperator::Sub => lhs - rhs,
129 BinOperator::Mult => lhs * rhs,
130
131 BinOperator::Div => {
132 if rhs.is_zero() {
133 return Err(zero_division_error(context, span));
134 } else if lhs_ty.is_signed() && lhs == &(lhs_ty.min_value()) && rhs == &(-BigInt::one())
135 {
136 return Err(overflow_error(context, span));
137 } else {
138 lhs / rhs
139 }
140 }
141
142 BinOperator::Mod => {
143 if rhs.is_zero() {
144 return Err(zero_division_error(context, span));
145 }
146 lhs % rhs
147 }
148
149 BinOperator::Pow => {
150 if let Some(exponent) = rhs.to_u32() {
152 lhs.pow(exponent)
153 } else if lhs.is_zero() {
154 BigInt::zero()
155 } else if lhs.is_one() {
156 BigInt::one()
157 } else {
158 return Err(overflow_error(context, span));
161 }
162 }
163
164 BinOperator::LShift => {
165 if let Some(exponent) = rhs.to_usize() {
166 let type_bits = lhs_ty.bits();
167 if exponent >= type_bits {
170 return Err(overflow_error(context, span));
171 } else {
172 let mask = make_mask(typ);
173 (lhs * BigInt::from(2_u8).pow(exponent as u32)) & mask
174 }
175 } else {
176 return Err(overflow_error(context, span));
178 }
179 }
180
181 BinOperator::RShift => {
182 if let Some(exponent) = rhs.to_usize() {
183 let type_bits = lhs_ty.bits();
184 if exponent >= type_bits {
187 return Err(overflow_error(context, span));
188 } else {
189 let mask = make_mask(typ);
190 (lhs / BigInt::from(2_u8).pow(exponent as u32)) & mask
191 }
192 } else {
193 return Err(overflow_error(context, span));
195 }
196 }
197
198 BinOperator::BitOr => lhs | rhs,
199 BinOperator::BitXor => lhs ^ rhs,
200 BinOperator::BitAnd => lhs & rhs,
201 };
202
203 Constant::make_const_numeric_with_ty(context, result, typ, span)
204}
205
206fn eval_unary_op(
207 context: &mut dyn AnalyzerContext,
208 op: &Node<ast::UnaryOperator>,
209 arg: &Node<ast::Expr>,
210) -> Result<Constant, ConstEvalError> {
211 let arg = eval_expr(context, arg)?;
212
213 match op.kind {
214 UnaryOperator::Invert => Ok(Constant::Int(!arg.extract_numeric())),
215 UnaryOperator::Not => Ok(Constant::Bool(!arg.extract_bool())),
216 UnaryOperator::USub => Ok(Constant::Int(-arg.extract_numeric())),
217 }
218}
219
220fn eval_comp_op(
222 context: &mut dyn AnalyzerContext,
223 lhs: &Node<ast::Expr>,
224 op: &Node<ast::CompOperator>,
225 rhs: &Node<ast::Expr>,
226) -> Result<Constant, ConstEvalError> {
227 let (lhs, rhs) = (eval_expr(context, lhs)?, eval_expr(context, rhs)?);
228
229 let res = match (lhs, rhs) {
230 (Constant::Int(lhs), Constant::Int(rhs)) => match op.kind {
231 CompOperator::Eq => lhs == rhs,
232 CompOperator::NotEq => lhs != rhs,
233 CompOperator::Lt => lhs < rhs,
234 CompOperator::LtE => lhs <= rhs,
235 CompOperator::Gt => lhs > rhs,
236 CompOperator::GtE => lhs >= rhs,
237 },
238
239 (Constant::Bool(lhs), Constant::Bool(rhs)) => match op.kind {
240 CompOperator::Eq => lhs == rhs,
241 CompOperator::NotEq => lhs != rhs,
242 CompOperator::Lt => !lhs & rhs,
243 CompOperator::LtE => lhs <= rhs,
244 CompOperator::Gt => lhs & !rhs,
245 CompOperator::GtE => lhs >= rhs,
246 },
247
248 _ => panic!("arguments of comp op have invalid type"),
249 };
250
251 Ok(Constant::Bool(res))
252}
253
254impl Constant {
255 pub fn from_num_str(
260 context: &mut dyn AnalyzerContext,
261 s: &str,
262 typ: &Type,
263 span: Span,
264 ) -> Result<Self, ConstEvalError> {
265 let literal = numeric::Literal::new(s);
266 let num = literal.parse::<BigInt>().unwrap();
267 match typ {
268 Type::Base(Base::Numeric(_)) => {
269 Self::make_const_numeric_with_ty(context, num, typ, span)
270 }
271 Type::Base(Base::Address) => {
272 if num >= BigInt::zero() && num <= types::address_max() {
273 Ok(Constant::Address(num))
274 } else {
275 Err(overflow_error(context, span))
276 }
277 }
278 _ => unreachable!(),
279 }
280 }
281
282 fn make_const_numeric_with_ty(
288 context: &mut dyn AnalyzerContext,
289 val: BigInt,
290 typ: &Type,
291 span: Span,
292 ) -> Result<Self, ConstEvalError> {
293 if extract_int_typ(typ).fits(val.clone()) {
295 Ok(Constant::Int(val))
296 } else {
297 Err(overflow_error(context, span))
298 }
299 }
300
301 fn extract_numeric(&self) -> &BigInt {
306 match self {
307 Constant::Int(val) => val,
308 _ => panic!("can't extract numeric value from {self:?}"),
309 }
310 }
311
312 fn extract_bool(&self) -> bool {
317 match self {
318 Constant::Bool(val) => *val,
319 _ => panic!("can't extract bool value from {self:?}"),
320 }
321 }
322}
323
324fn not_const_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
325 ConstEvalError::new(context.error(
326 "expression is not a constant",
327 span,
328 "expression is required to be constant here",
329 ))
330}
331
332fn overflow_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
333 ConstEvalError::new(context.error(
334 "overflow error",
335 span,
336 "overflow occurred during constant evaluation",
337 ))
338}
339
340fn zero_division_error(context: &mut dyn AnalyzerContext, span: Span) -> ConstEvalError {
341 ConstEvalError::new(context.error(
342 "zero division error",
343 span,
344 "zero division occurred during constant evaluation",
345 ))
346}
347
348fn extract_int_typ(typ: &Type) -> types::Integer {
353 match typ {
354 Type::Base(Base::Numeric(int_ty)) => *int_ty,
355 _ => {
356 panic!("invalid binop expression type")
357 }
358 }
359}
360
361fn make_mask(typ: &Type) -> BigInt {
368 let bits = extract_int_typ(typ).bits();
369 (BigInt::one() << bits) - 1
370}