From 209664fb5bf067f98bf38b151b3c917bc0636ae1 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Mon, 16 Dec 2024 17:41:47 +0100 Subject: [PATCH] Rewrite binary expressions in Rust Ref. eng/recordflux/RecordFlux#1774 --- librapidflux/src/expr.rs | 880 +++++++++++++++++++++++++++++- rapidflux/src/expr.rs | 512 ++++++++++++++++- rflx/rapidflux/expr.pyi | 38 +- tests/unit/expr_test.py | 191 ++++++- tests/unit/rapidflux/expr_test.py | 414 +++++++++++++- 5 files changed, 2001 insertions(+), 34 deletions(-) diff --git a/librapidflux/src/expr.rs b/librapidflux/src/expr.rs index 9b28e03c5..0dea345f7 100644 --- a/librapidflux/src/expr.rs +++ b/librapidflux/src/expr.rs @@ -33,6 +33,7 @@ pub enum Expr { Lit(Sym), Num(Num), Neg(Neg), + BinExpr(BinExpr), } impl Expr { @@ -67,34 +68,49 @@ impl Expr { }) } + pub fn bin_expr(op: BinOp, left: Expr, right: Expr, location: Location) -> Self { + Self::BinExpr(BinExpr { + op, + left: Box::new(left), + right: Box::new(right), + location, + }) + } + pub fn location(&self) -> &Location { match self { Self::Var(sym) | Self::Lit(sym) => sym.id.location(), Self::Num(num) => &num.location, Self::Neg(neg) => &neg.location, + Self::BinExpr(bin_expr) => &bin_expr.location, } } pub fn check_type(&self, expected: &[ty::Ty]) -> Error { let mut error = self.check_sub_expr_type(); - error.extend( - ty::check_type(&self.ty(), expected, self.location(), &self.expr_name()).into_entries(), - ); + if error.entries().is_empty() { + error.extend( + ty::check_type(&self.ty(), expected, self.location(), &self.expr_name()) + .into_entries(), + ); + } error } pub fn check_type_instance(&self, expected: &[ty::TyDiscriminants]) -> Error { let mut error = self.check_sub_expr_type(); - error.extend( - ty::check_type_instance( - &self.ty(), - expected, - self.location(), - &self.expr_name(), - &[], - ) - .into_entries(), - ); + if error.entries().is_empty() { + error.extend( + ty::check_type_instance( + &self.ty(), + expected, + self.location(), + &self.expr_name(), + &[], + ) + .into_entries(), + ); + } error } @@ -102,6 +118,7 @@ impl Expr { match self { Self::Num(..) | Self::Var(..) | Self::Lit(..) => Error::default(), Self::Neg(neg) => neg.check_sub_expr_type(), + Self::BinExpr(bin_expr) => bin_expr.check_sub_expr_type(), } } @@ -119,6 +136,7 @@ impl Expr { } } Self::Neg(neg) => neg.find_all(f), + Self::BinExpr(bin_expr) => bin_expr.find_all(f), } } @@ -131,6 +149,7 @@ impl Expr { Self::Lit(..) => panic!("literal cannot be negated"), Self::Num(num) => num.into_negated(), Self::Neg(neg) => neg.into_negated(), + Self::BinExpr(bin_expr) => bin_expr.into_negated(), } } @@ -138,6 +157,7 @@ impl Expr { match self { Self::Var(..) | Self::Lit(..) | Self::Num(..) => self, Self::Neg(neg) => neg.into_simplified(), + Self::BinExpr(bin_expr) => bin_expr.into_simplified(), } } @@ -145,6 +165,7 @@ impl Expr { match self { Self::Var(..) | Self::Lit(..) | Self::Num(..) => f(&self), Self::Neg(neg) => f(&neg.into_substituted(f)), + Self::BinExpr(bin_expr) => f(&bin_expr.into_substituted(f)), } } @@ -153,17 +174,18 @@ impl Expr { Self::Var(sym) | Self::Lit(sym) => sym.ty.clone(), Self::Num(num) => num.ty(), Self::Neg(neg) => neg.ty(), + Self::BinExpr(bin_expr) => bin_expr.ty(), } } fn expr_name(&self) -> String { let expr_type = match self { - Self::Num(..) | Self::Neg(..) => "expression", + Self::Num(..) | Self::Neg(..) | Self::BinExpr(..) => "expression", Self::Var(..) => "variable", Self::Lit(..) => "literal", }; let expr_name = match self { - Self::Num(..) | Self::Neg(..) => self.to_string(), + Self::Num(..) | Self::Neg(..) | Self::BinExpr(..) => self.to_string(), Self::Var(sym) | Self::Lit(sym) => sym.id.to_string(), }; format!("{expr_type} \"{expr_name}\"") @@ -182,6 +204,7 @@ impl Expr { Self::Var(sym) | Self::Lit(sym) => sym.precedence(), Self::Num(num) => num.precedence(), Self::Neg(neg) => neg.precedence(), + Self::BinExpr(bin_expr) => bin_expr.precedence(), } } } @@ -195,6 +218,7 @@ impl Display for Expr { Self::Num(num) => num.to_string(), Self::Var(sym) | Self::Lit(sym) => sym.id.to_string(), Self::Neg(neg) => neg.to_string(), + Self::BinExpr(bin_expr) => bin_expr.to_string(), } ) } @@ -346,6 +370,222 @@ impl Display for Neg { } } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum BinOp { + Sub, + Div, + Pow, + Mod, +} + +impl Display for BinOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Sub => "-", + Self::Div => "/", + Self::Pow => "**", + Self::Mod => "mod", + } + ) + } +} + +#[allow(clippy::module_name_repetitions)] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BinExpr { + pub op: BinOp, + pub left: Box, + pub right: Box, + pub location: Location, +} + +impl BinExpr { + pub fn find_all(&self, f: &impl Fn(&Expr) -> bool) -> Vec { + let mut result = vec![]; + let e = Expr::BinExpr(self.clone()); + if f(&e) { + result.push(e); + } + result.extend(self.left.find_all(f)); + result.extend(self.right.find_all(f)); + result + } + + pub fn into_negated(self) -> Expr { + match self.op { + BinOp::Sub => Expr::BinExpr(BinExpr { + op: self.op, + left: self.right, + right: self.left, + location: self.location.clone(), + }), + BinOp::Div => Expr::BinExpr(BinExpr { + op: self.op, + left: Box::new(self.left.into_negated()), + right: self.right, + location: self.location.clone(), + }), + BinOp::Pow | BinOp::Mod => Expr::neg( + Expr::BinExpr(BinExpr { + op: self.op, + left: self.left, + right: self.right, + location: self.location.clone(), + }), + self.location.clone(), + ), + } + } + + /// # Panics + /// + /// Will panic if the expression contains an invalid operation. + pub fn into_simplified(self) -> Expr { + let left = self.left.clone().into_simplified(); + let right = self.right.clone().into_simplified(); + match (left, right) { + ( + Expr::Num(Num { + value: l, + base: _, + location: _, + }), + Expr::Num(Num { + value: r, + base: _, + location: _, + }), + ) => match self.op { + BinOp::Sub => Expr::Num(Num { + value: l - r, + base: NumBase::Default, + location: self.location, + }), + BinOp::Div => { + if l % r == 0 { + Expr::Num(Num { + value: l / r, + base: NumBase::Default, + location: self.location, + }) + } else { + Expr::BinExpr(self) + } + } + BinOp::Pow => { + if r >= 0 { + Expr::Num(Num { + value: l.pow(u32::try_from(r).expect("too big exponent")), + base: NumBase::Default, + location: self.location, + }) + } else { + panic!("negative exponent") + } + } + BinOp::Mod => { + if r != 0 { + Expr::Num(Num { + value: l.rem_euclid(r) + if r < 0 { r } else { 0 }, + base: NumBase::Default, + location: self.location, + }) + } else { + panic!("modulo by zero") + } + } + }, + ( + Expr::Num(Num { + value: 0, + base: _, + location, + }), + other, + ) => match self.op { + BinOp::Sub => other.into_negated().into_simplified(), + BinOp::Div | BinOp::Pow | BinOp::Mod => Expr::num(0, location), + }, + ( + other, + Expr::Num(Num { + value: 0, + base: _, + location, + }), + ) => match self.op { + BinOp::Sub => other.into_simplified(), + BinOp::Div => panic!("division by zero"), + BinOp::Pow => Expr::num(1, location), + BinOp::Mod => panic!("modulo by zero"), + }, + (Expr::Lit(_), Expr::Lit(_)) => panic!("invalid operation"), + _ => Expr::BinExpr(BinExpr { + op: self.op, + left: Box::new(self.left.into_simplified()), + right: Box::new(self.right.into_simplified()), + location: self.location, + }), + } + } + + pub fn into_substituted(self, f: &impl Fn(&Expr) -> Expr) -> Expr { + f(&Expr::bin_expr( + self.op, + self.left.into_substituted(f), + self.right.into_substituted(f), + self.location, + )) + } + + pub fn precedence(&self) -> Precedence { + match self.op { + BinOp::Sub => Precedence::BinaryAddingOperator, + BinOp::Div | BinOp::Mod => Precedence::MultiplyingOperator, + BinOp::Pow => Precedence::HighestPrecedenceOperator, + } + } + + pub fn ty(&self) -> ty::Ty { + ty::common_type(&[self.left.ty().clone(), self.right.ty().clone()]) + } + + fn check_sub_expr_type(&self) -> Error { + let mut error = Error::default(); + + for e in &[&self.left, &self.right] { + error.extend( + e.check_type_instance(&[ty::TyDiscriminants::AnyInteger]) + .into_entries(), + ); + } + + error + } +} + +impl PartialEq for BinExpr { + fn eq(&self, other: &Self) -> bool { + self.op == other.op && self.left == other.left && self.right == other.right + } +} + +impl Display for BinExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let self_expr = Expr::BinExpr(self.clone()); + write!( + f, + "{} {} {}", + self_expr.parenthesized(&self.left), + self.op, + self_expr.parenthesized(&self.right), + ) + } +} + lazy_static! { static ref TRUE_SYM: Sym = Sym { id: create_id!(["True"], Location::None), @@ -408,6 +648,12 @@ mod tests { #[case::lit(lit("X", location(1)))] #[case::num(Expr::num(42, location(1)))] #[case::neg(Expr::neg(Expr::num(42, location(2)), location(1)))] + #[case::bin_expr(Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(3)), + Expr::num(1, location(2)), + location(1) + ))] fn test_expr_location(#[case] expression: Expr) { assert_eq!(*expression.location(), location(1)); } @@ -429,6 +675,14 @@ mod tests { Expr::neg(Expr::num(42, location(2)), location(1)), vec![] )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, Expr::num(2, location(3)), Expr::num(1, location(2)), location(1)), + vec![] + )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, var("X", location(2)), var("Y", location(3)), location(1)), + vec![var("X", location(2)), var("Y", location(3))] + )] fn test_expr_variables(#[case] expression: Expr, #[case] expected: Vec) { assert_eq!(*expression.variables(), expected); } @@ -479,6 +733,21 @@ mod tests { |e: &Expr| matches!(e, Expr::Neg(_)), vec![Expr::neg(Expr::num(42, location(2)), location(1))] )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, Expr::num(2, location(3)), Expr::num(1, location(2)), location(1)), + |_: &Expr| false, + vec![] + )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, Expr::num(2, location(3)), Expr::num(1, location(2)), location(1)), + |e: &Expr| matches!(e, Expr::Num(_)), + vec![Expr::num(2, location(3)), Expr::num(1, location(2))] + )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, var("X", location(2)), var("Y", location(3)), location(1)), + |e: &Expr| matches!(e, Expr::BinExpr(_)), + vec![Expr::bin_expr(BinOp::Sub, var("X", location(2)), var("Y", location(3)), location(1))] + )] fn test_expr_find_all( #[case] expression: Expr, #[case] f: impl Fn(&Expr) -> bool, @@ -499,6 +768,51 @@ mod tests { Expr::neg(Expr::num(42, location(2)), location(1)), Expr::num(42, location(2)) )] + #[case::sub( + Expr::bin_expr( + BinOp::Sub, + var("X", location(2)), + Expr::num(1, location(3)), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, location(3)), + var("X", location(2)), + location(1) + ) + )] + #[case::div( + Expr::bin_expr( + BinOp::Div, + Expr::num(2, location(3)), + var("X", location(2)), + location(1) + ), + Expr::bin_expr( + BinOp::Div, + Expr::num(-2, location(3)), + var("X", location(2)), + location(1) + ), + )] + #[case::pow( + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, location(3)), + var("X", location(2)), + location(1) + ), + Expr::neg( + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, location(3)), + var("X", location(2)), + location(1) + ), + location(1) + ) + )] fn test_expr_into_negated(#[case] expression: Expr, #[case] expected: Expr) { assert_eq!(expression.into_negated(), expected); } @@ -543,6 +857,317 @@ mod tests { ), Expr::num(-42, Location::None) )] + #[case::sub_ident( + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, Location::None), + var("X", Location::None), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, Location::None), + var("X", Location::None), + location(1) + ) + )] + #[case::sub_num( + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::num(-1, location(1)) + )] + #[case::sub_nested( + Expr::bin_expr( + BinOp::Sub, + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, Location::None), + Expr::num(2, Location::None), + Location::None, + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, Location::None), + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, Location::None), + Expr::num(3, Location::None), + Location::None, + ), + Location::None, + ), + location(1) + ), + Expr::num(-3, location(1)) + )] + #[case::sub_zero( + Expr::bin_expr( + BinOp::Sub, + Expr::bin_expr( + BinOp::Sub, + var("X", Location::None), + Expr::num(0, Location::None), + Location::None, + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(0, Location::None), + var("Y", Location::None), + Location::None, + ), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + var("X", Location::None), + Expr::neg(var("Y", Location::None), Location::None), + location(1) + ) + )] + #[case::div_ident( + Expr::bin_expr( + BinOp::Div, + Expr::num(1, Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::bin_expr( + BinOp::Div, + Expr::num(1, Location::None), + Expr::num(2, Location::None), + location(1) + ) + )] + #[case::div_num( + Expr::bin_expr( + BinOp::Div, + Expr::num(4, Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::num(2, location(1)) + )] + #[case::div_nested( + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Div, + Expr::num(8, Location::None), + Expr::num(2, Location::None), + Location::None, + ), + Expr::bin_expr( + BinOp::Div, + Expr::num(6, Location::None), + Expr::num(3, Location::None), + Location::None, + ), + location(1) + ), + Expr::num(2, location(1)) + )] + #[case::div_zero_dividend( + Expr::bin_expr( + BinOp::Div, + Expr::num(0, location(1)), + var("X", Location::None), + Location::None + ), + Expr::num(0, location(1)) + )] + #[should_panic(expected = "division by zero")] + #[case::div_zero_divisor( + Expr::bin_expr( + BinOp::Div, + var("X", Location::None), + Expr::num(0, location(1)), + Location::None + ), + Expr::num(0, location(1)) + )] + #[case::pow_ident( + Expr::bin_expr( + BinOp::Pow, + var("X", Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::bin_expr( + BinOp::Pow, + var("X", Location::None), + Expr::num(2, Location::None), + location(1) + ) + )] + #[case::pow_num( + Expr::bin_expr( + BinOp::Pow, + Expr::num(4, Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::num(16, location(1)) + )] + #[case::pow_nested( + Expr::bin_expr( + BinOp::Pow, + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(3, Location::None), + Location::None, + ), + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(2, Location::None), + Location::None, + ), + location(1) + ), + Expr::num(4096, location(1)) + )] + #[case::pow_zero_base( + Expr::bin_expr( + BinOp::Pow, + Expr::num(0, Location::None), + var("X", Location::None), + location(1) + ), + Expr::num(0, location(1)) + )] + #[case::pow_zero_exponent( + Expr::bin_expr( + BinOp::Pow, + var("X", Location::None), + Expr::num(0, Location::None), + location(1) + ), + Expr::num(1, location(1)) + )] + #[should_panic(expected = "negative exponent")] + #[case::pow_negative_exponent( + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(-1, Location::None), + Location::None + ), + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(-1, Location::None), + Location::None + ) + )] + #[should_panic(expected = "too big exponent")] + #[case::pow_too_big_exponent( + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(i64::MAX, Location::None), + Location::None + ), + Expr::bin_expr( + BinOp::Pow, + Expr::num(2, Location::None), + Expr::num(i64::MAX, Location::None), + Location::None + ) + )] + #[case::mod_ident( + Expr::bin_expr( + BinOp::Mod, + var("X", Location::None), + Expr::num(2, Location::None), + location(1) + ), + Expr::bin_expr( + BinOp::Mod, + var("X", Location::None), + Expr::num(2, Location::None), + location(1) + ) + )] + #[case::mod_num( + Expr::bin_expr( + BinOp::Mod, + Expr::num(4, Location::None), + Expr::num(3, Location::None), + location(1) + ), + Expr::num(1, location(1)) + )] + #[case::mod_num_neg( + Expr::bin_expr( + BinOp::Mod, + Expr::num(-7, Location::None), + Expr::num(-3, Location::None), + location(1) + ), + Expr::num(-1, location(1)) + )] + #[case::mod_nested( + Expr::bin_expr( + BinOp::Mod, + Expr::bin_expr( + BinOp::Mod, + Expr::num(8, Location::None), + Expr::num(5, Location::None), + Location::None, + ), + Expr::bin_expr( + BinOp::Mod, + Expr::num(5, Location::None), + Expr::num(3, Location::None), + Location::None, + ), + location(1) + ), + Expr::num(1, location(1)) + )] + #[case::mod_zero_dividend( + Expr::bin_expr( + BinOp::Mod, + Expr::num(0, location(1)), + var("X", Location::None), + Location::None + ), + Expr::num(0, location(1)) + )] + #[should_panic(expected = "modulo by zero")] + #[case::mod_zero_divisor_num( + Expr::bin_expr( + BinOp::Mod, + Expr::num(2, location(1)), + Expr::num(0, location(1)), + Location::None + ), + Expr::num(0, location(1)) + )] + #[should_panic(expected = "modulo by zero")] + #[case::mod_zero_divisor_var( + Expr::bin_expr( + BinOp::Mod, + var("X", Location::None), + Expr::num(0, location(1)), + Location::None + ), + Expr::num(0, location(1)) + )] + #[should_panic(expected = "invalid operation")] + #[case::sub_invalid_operation( + Expr::bin_expr( + BinOp::Sub, + TRUE.clone(), + FALSE.clone(), + location(1)), + TRUE.clone() + )] fn test_expr_into_simplified(#[case] expression: Expr, #[case] expected: Expr) { assert_eq!(expression.into_simplified(), expected); } @@ -583,6 +1208,26 @@ mod tests { }, Expr::neg(var("X", location(1)), location(1)), )] + #[case::bin_expr( + Expr::bin_expr(BinOp::Sub, Expr::num(2, location(3)), Expr::num(1, location(2)), location(1)), + &|e: &Expr| + if matches!(e, Expr::BinExpr(_)) { + var("X", e.location().clone()) + } else { + e.clone() + }, + var("X", location(1)) + )] + #[case::bin_expr_nested( + Expr::bin_expr(BinOp::Sub, Expr::num(2, location(3)), Expr::num(1, location(2)), location(1)), + &|e: &Expr| + if matches!(e, Expr::Num(_)) { + var("X", e.location().clone()) + } else { + e.clone() + }, + Expr::bin_expr(BinOp::Sub, var("X", location(3)), var("X", location(2)), location(1)), + )] fn test_expr_into_substituted( #[case] expression: Expr, #[case] f: &impl Fn(&Expr) -> Expr, @@ -602,6 +1247,34 @@ mod tests { Expr::neg(Expr::num(1, Location::None), Location::None), ty::Ty::UniversalInteger(ty::UniversalInteger { bounds: ty::Bounds::new(1, 1) }) )] + #[case::sub( + Expr::bin_expr( + BinOp::Sub, + Expr::bin_expr( + BinOp::Sub, + var("X", Location::None), + Expr::num(1, Location::None), + Location::None + ), + Expr::num(2, Location::None), + Location::None + ), + ty::BASE_INTEGER.clone() + )] + #[case::div( + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Div, + var("X", Location::None), + Expr::num(1, Location::None), + Location::None + ), + TRUE.clone(), + Location::None + ), + ty::Ty::Undefined + )] fn test_expr_ty(#[case] expression: Expr, #[case] expected: ty::Ty) { assert_eq!(expression.ty(), expected); } @@ -627,6 +1300,48 @@ mod tests { bounds: ty::Bounds::new(1, 1), }) )] + #[case::sub( + Expr::bin_expr( + BinOp::Sub, + Expr::bin_expr( + BinOp::Sub, + Expr::var(create_id!(["X"], Location::None), int_ty()), + Expr::num(1, Location::None), + Location::None + ), + Expr::num(2, Location::None), + Location::None + ), + int_ty() + )] + #[case::div( + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Div, + Expr::var(create_id!(["X"], Location::None), int_ty()), + Expr::num(1, Location::None), + Location::None + ), + Expr::num(2, Location::None), + Location::None + ), + int_ty() + )] + #[case::pow( + Expr::bin_expr( + BinOp::Pow, + Expr::bin_expr( + BinOp::Pow, + Expr::var(create_id!(["X"], Location::None), int_ty()), + Expr::num(1, Location::None), + Location::None + ), + Expr::num(2, Location::None), + Location::None + ), + int_ty() + )] fn test_expr_check_type(#[case] expression: Expr, #[case] expected: ty::Ty) { assert_eq!(expression.check_type(&[expected.clone()]).entries(), vec![]); assert_eq!(expression.ty(), expected); @@ -651,6 +1366,50 @@ mod tests { ) ] )] + #[case::bin_expr( + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Sub, + var("X", Location::None), + Expr::neg(Expr::num(1, Location::None), Location::None), + Location::None + ), + Expr::num(2, Location::None), + location(1) + ), + &[ + Entry::new( + "expected enumeration type \"__BUILTINS__::Boolean\"".to_string(), + Severity::Error, + location(1), + vec![Annotation::new(Some("found integer type \"__BUILTINS__::Base_Integer\" (0 .. 2**63 - 1)".to_string()), Severity::Error, location(1) )], + false + ) + ] + )] + #[case::bin_expr_sub_expr( + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Sub, + Expr::var(create_id!(["X"], location(1)), ty::BOOLEAN.clone()), + Expr::neg(Expr::num(1, Location::None), Location::None), + Location::None + ), + Expr::num(2, Location::None), + Location::None + ), + &[ + Entry::new( + "expected integer type".to_string(), + Severity::Error, + location(1), + vec![Annotation::new(Some("found enumeration type \"__BUILTINS__::Boolean\"".to_string()), Severity::Error, location(1) )], + false + ) + ] + )] fn test_expr_check_type_error(#[case] expression: Expr, #[case] expected: &[Entry]) { assert_eq!( expression.check_type(&[ty::BOOLEAN.clone()]), @@ -678,6 +1437,20 @@ mod tests { Expr::neg(Expr::num(42, location(2)), location(1)), Expr::neg(Expr::num(42, location(4)), location(3)) )] + #[case::bin_expr( + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(3)), + Expr::num(1, location(2)), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(6)), + Expr::num(1, location(5)), + location(4) + ) + )] fn test_expr_eq(#[case] left: Expr, #[case] right: Expr) { assert_eq!(left, right); } @@ -690,6 +1463,48 @@ mod tests { Expr::neg(Expr::num(42, Location::None), Location::None), Expr::neg(var("X", Location::None), Location::None) )] + #[case::bin_expr_op( + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(3)), + Expr::num(1, location(2)), + location(1) + ), + Expr::bin_expr( + BinOp::Div, + Expr::num(2, location(6)), + Expr::num(1, location(5)), + location(4) + ) + )] + #[case::bin_expr_left( + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(3)), + Expr::num(1, location(2)), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(1, location(6)), + Expr::num(1, location(5)), + location(4) + ) + )] + #[case::bin_expr_right( + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(3)), + Expr::num(1, location(2)), + location(1) + ), + Expr::bin_expr( + BinOp::Sub, + Expr::num(2, location(6)), + Expr::num(2, location(5)), + location(4) + ) + )] fn test_expr_ne(#[case] left: Expr, #[case] right: Expr) { assert_ne!(left, right); } @@ -708,6 +1523,30 @@ mod tests { Expr::neg(Expr::neg(var("X", Location::None), Location::None), Location::None), "-(-X)" )] + #[case::bin_expr( + Expr::bin_expr( + BinOp::Mod, + Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Sub, + Expr::bin_expr( + BinOp::Pow, + var("X", Location::None), + Expr::num(2, Location::None), + Location::None + ), + Expr::num(1, Location::None), + Location::None + ), + Expr::neg(Expr::num(2, Location::None), Location::None), + Location::None + ), + Expr::num(3, Location::None), + Location::None + ), + "((X ** 2 - 1) / (-2)) mod 3" + )] fn test_expr_display(#[case] expression: Expr, #[case] expected: &str) { assert_eq!(expression.to_string(), expected); } @@ -717,6 +1556,17 @@ mod tests { #[case::lit(lit("X", location(1)))] #[case::num(Expr::num(42, location(1)))] #[case::neg(Expr::neg(Expr::num(42, location(2)), location(1)))] + #[case::bin_expr(Expr::bin_expr( + BinOp::Div, + Expr::bin_expr( + BinOp::Sub, + var("X", Location::None), + Expr::num(1, Location::None), + Location::None + ), + Expr::neg(Expr::num(2, Location::None), Location::None), + Location::None + ))] fn test_expr_serde(#[case] expr: Expr) { let bytes = bincode::serialize(&expr).expect("failed to serialize"); let deserialized_expr: Expr = bincode::deserialize(&bytes).expect("failed to deserialize"); diff --git a/rapidflux/src/expr.rs b/rapidflux/src/expr.rs index 02daabbf3..395056292 100644 --- a/rapidflux/src/expr.rs +++ b/rapidflux/src/expr.rs @@ -658,6 +658,486 @@ impl Neg { } } +#[pyclass(extends = Expr, module = "rflx.rapidflux.expr")] +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub struct Sub(lib::BinExpr); + +#[pymethods] +impl Sub { + #[new] + fn new( + left: &Bound<'_, Expr>, + right: &Bound<'_, Expr>, + location: Option, + ) -> (Self, Expr) { + ( + Self(lib::BinExpr { + op: lib::BinOp::Sub, + left: Box::new(to_expr(left)), + right: Box::new(to_expr(right)), + location: location.unwrap_or(NO_LOCATION).0, + }), + Expr, + ) + } + + fn __getnewargs__(&self, py: Python<'_>) -> (PyObject, PyObject, Location) { + (self.left(py), self.right(py), self.location()) + } + + fn __str__(&self) -> String { + self.0.to_string() + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + Ok(format!( + "Sub({}, {}, {})", + to_py(&self.0.left, py).call_method0(py, "__repr__")?, + to_py(&self.0.right, py).call_method0(py, "__repr__")?, + self.location().__repr__() + )) + } + + fn __richcmp__(&self, other: &Bound<'_, PyAny>, op: CompareOp) -> bool { + if let Ok(other) = other.extract::() { + match op { + CompareOp::Eq => self.0.left == other.0.left && self.0.right == other.0.right, + CompareOp::Ne => self.0.left != other.0.left || self.0.right != other.0.right, + _ => false, + } + } else { + matches!(op, CompareOp::Ne) + } + } + + #[allow(clippy::unused_self)] + fn __hash__(&self) -> usize { + 0 + } + + fn __neg__(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.clone().into_negated(), py) + } + + #[getter] + fn location(&self) -> Location { + Location(self.0.location.clone()) + } + + #[getter] + fn precedence(&self) -> Precedence { + Precedence(self.0.precedence().clone()) + } + + #[getter] + fn type_(&self, py: Python<'_>) -> PyObject { + ty::to_py(&self.0.ty(), py) + } + + #[getter] + fn left(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.left, py) + } + + #[getter] + fn right(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.right, py) + } + + fn check_type(&self, expected: &Bound<'_, PyAny>) -> error::Error { + error::Error(lib::Expr::BinExpr(self.0.clone()).check_type(&to_ty_list(expected))) + } + + fn check_type_instance(&self, expected: &Bound<'_, PyAny>, py: Python<'_>) -> error::Error { + error::Error( + lib::Expr::BinExpr(self.0.clone()) + .check_type_instance(&to_ty_discriminants_list(expected, py)), + ) + } + + fn variables(&self, py: Python<'_>) -> Vec { + to_py_list(&lib::Expr::BinExpr(self.0.clone()).variables(), py) + } + + fn findall(&self, r#match: &Bound<'_, PyAny>, py: Python<'_>) -> Vec { + to_py_list( + &lib::Expr::BinExpr(self.0.clone()).find_all(&to_fn_expr_bool(r#match, py)), + py, + ) + } + + fn substituted(&self, func: &Bound<'_, PyAny>, py: Python<'_>) -> PyObject { + to_py( + &lib::Expr::BinExpr(self.0.clone()).into_substituted(&to_fn_expr_expr(func, py)), + py, + ) + } + + fn simplified(&self, py: Python<'_>) -> PyObject { + to_py(&lib::Expr::BinExpr(self.0.clone()).into_simplified(), py) + } +} + +#[pyclass(extends = Expr, module = "rflx.rapidflux.expr")] +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub struct Div(lib::BinExpr); + +#[pymethods] +impl Div { + #[new] + fn new( + left: &Bound<'_, Expr>, + right: &Bound<'_, Expr>, + location: Option, + ) -> (Self, Expr) { + ( + Self(lib::BinExpr { + op: lib::BinOp::Div, + left: Box::new(to_expr(left)), + right: Box::new(to_expr(right)), + location: location.unwrap_or(NO_LOCATION).0, + }), + Expr, + ) + } + + fn __getnewargs__(&self, py: Python<'_>) -> (PyObject, PyObject, Location) { + (self.left(py), self.right(py), self.location()) + } + + fn __str__(&self) -> String { + self.0.to_string() + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + Ok(format!( + "Div({}, {}, {})", + to_py(&self.0.left, py).call_method0(py, "__repr__")?, + to_py(&self.0.right, py).call_method0(py, "__repr__")?, + self.location().__repr__() + )) + } + + fn __richcmp__(&self, other: &Bound<'_, PyAny>, op: CompareOp) -> bool { + if let Ok(other) = other.extract::() { + match op { + CompareOp::Eq => self.0.left == other.0.left && self.0.right == other.0.right, + CompareOp::Ne => self.0.left != other.0.left || self.0.right != other.0.right, + _ => false, + } + } else { + matches!(op, CompareOp::Ne) + } + } + + #[allow(clippy::unused_self)] + fn __hash__(&self) -> usize { + 0 + } + + fn __neg__(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.clone().into_negated(), py) + } + + #[getter] + fn location(&self) -> Location { + Location(self.0.location.clone()) + } + + #[getter] + fn precedence(&self) -> Precedence { + Precedence(self.0.precedence().clone()) + } + + #[getter] + fn type_(&self, py: Python<'_>) -> PyObject { + ty::to_py(&self.0.ty(), py) + } + + #[getter] + fn left(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.left, py) + } + + #[getter] + fn right(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.right, py) + } + + fn check_type(&self, expected: &Bound<'_, PyAny>) -> error::Error { + error::Error(lib::Expr::BinExpr(self.0.clone()).check_type(&to_ty_list(expected))) + } + + fn check_type_instance(&self, expected: &Bound<'_, PyAny>, py: Python<'_>) -> error::Error { + error::Error( + lib::Expr::BinExpr(self.0.clone()) + .check_type_instance(&to_ty_discriminants_list(expected, py)), + ) + } + + fn variables(&self, py: Python<'_>) -> Vec { + to_py_list(&lib::Expr::BinExpr(self.0.clone()).variables(), py) + } + + fn findall(&self, r#match: &Bound<'_, PyAny>, py: Python<'_>) -> Vec { + to_py_list( + &lib::Expr::BinExpr(self.0.clone()).find_all(&to_fn_expr_bool(r#match, py)), + py, + ) + } + + fn substituted(&self, func: &Bound<'_, PyAny>, py: Python<'_>) -> PyObject { + to_py( + &lib::Expr::BinExpr(self.0.clone()).into_substituted(&to_fn_expr_expr(func, py)), + py, + ) + } + + fn simplified(&self, py: Python<'_>) -> PyObject { + to_py(&lib::Expr::BinExpr(self.0.clone()).into_simplified(), py) + } +} + +#[pyclass(extends = Expr, module = "rflx.rapidflux.expr")] +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub struct Pow(lib::BinExpr); + +#[pymethods] +impl Pow { + #[new] + fn new( + left: &Bound<'_, Expr>, + right: &Bound<'_, Expr>, + location: Option, + ) -> (Self, Expr) { + ( + Self(lib::BinExpr { + op: lib::BinOp::Pow, + left: Box::new(to_expr(left)), + right: Box::new(to_expr(right)), + location: location.unwrap_or(NO_LOCATION).0, + }), + Expr, + ) + } + + fn __getnewargs__(&self, py: Python<'_>) -> (PyObject, PyObject, Location) { + (self.left(py), self.right(py), self.location()) + } + + fn __str__(&self) -> String { + self.0.to_string() + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + Ok(format!( + "Pow({}, {}, {})", + to_py(&self.0.left, py).call_method0(py, "__repr__")?, + to_py(&self.0.right, py).call_method0(py, "__repr__")?, + self.location().__repr__() + )) + } + + fn __richcmp__(&self, other: &Bound<'_, PyAny>, op: CompareOp) -> bool { + if let Ok(other) = other.extract::() { + match op { + CompareOp::Eq => self.0.left == other.0.left && self.0.right == other.0.right, + CompareOp::Ne => self.0.left != other.0.left || self.0.right != other.0.right, + _ => false, + } + } else { + matches!(op, CompareOp::Ne) + } + } + + #[allow(clippy::unused_self)] + fn __hash__(&self) -> usize { + 0 + } + + fn __neg__(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.clone().into_negated(), py) + } + + #[getter] + fn location(&self) -> Location { + Location(self.0.location.clone()) + } + + #[getter] + fn precedence(&self) -> Precedence { + Precedence(self.0.precedence().clone()) + } + + #[getter] + fn type_(&self, py: Python<'_>) -> PyObject { + ty::to_py(&self.0.ty(), py) + } + + #[getter] + fn left(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.left, py) + } + + #[getter] + fn right(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.right, py) + } + + fn check_type(&self, expected: &Bound<'_, PyAny>) -> error::Error { + error::Error(lib::Expr::BinExpr(self.0.clone()).check_type(&to_ty_list(expected))) + } + + fn check_type_instance(&self, expected: &Bound<'_, PyAny>, py: Python<'_>) -> error::Error { + error::Error( + lib::Expr::BinExpr(self.0.clone()) + .check_type_instance(&to_ty_discriminants_list(expected, py)), + ) + } + + fn variables(&self, py: Python<'_>) -> Vec { + to_py_list(&lib::Expr::BinExpr(self.0.clone()).variables(), py) + } + + fn findall(&self, r#match: &Bound<'_, PyAny>, py: Python<'_>) -> Vec { + to_py_list( + &lib::Expr::BinExpr(self.0.clone()).find_all(&to_fn_expr_bool(r#match, py)), + py, + ) + } + + fn substituted(&self, func: &Bound<'_, PyAny>, py: Python<'_>) -> PyObject { + to_py( + &lib::Expr::BinExpr(self.0.clone()).into_substituted(&to_fn_expr_expr(func, py)), + py, + ) + } + + fn simplified(&self, py: Python<'_>) -> PyObject { + to_py(&lib::Expr::BinExpr(self.0.clone()).into_simplified(), py) + } +} + +#[pyclass(extends = Expr, module = "rflx.rapidflux.expr")] +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub struct Mod(lib::BinExpr); + +#[pymethods] +impl Mod { + #[new] + fn new( + left: &Bound<'_, Expr>, + right: &Bound<'_, Expr>, + location: Option, + ) -> (Self, Expr) { + ( + Self(lib::BinExpr { + op: lib::BinOp::Mod, + left: Box::new(to_expr(left)), + right: Box::new(to_expr(right)), + location: location.unwrap_or(NO_LOCATION).0, + }), + Expr, + ) + } + + fn __getnewargs__(&self, py: Python<'_>) -> (PyObject, PyObject, Location) { + (self.left(py), self.right(py), self.location()) + } + + fn __str__(&self) -> String { + self.0.to_string() + } + + fn __repr__(&self, py: Python<'_>) -> PyResult { + Ok(format!( + "Mod({}, {}, {})", + to_py(&self.0.left, py).call_method0(py, "__repr__")?, + to_py(&self.0.right, py).call_method0(py, "__repr__")?, + self.location().__repr__() + )) + } + + fn __richcmp__(&self, other: &Bound<'_, PyAny>, op: CompareOp) -> bool { + if let Ok(other) = other.extract::() { + match op { + CompareOp::Eq => self.0.left == other.0.left && self.0.right == other.0.right, + CompareOp::Ne => self.0.left != other.0.left || self.0.right != other.0.right, + _ => false, + } + } else { + matches!(op, CompareOp::Ne) + } + } + + #[allow(clippy::unused_self)] + fn __hash__(&self) -> usize { + 0 + } + + fn __neg__(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.clone().into_negated(), py) + } + + #[getter] + fn location(&self) -> Location { + Location(self.0.location.clone()) + } + + #[getter] + fn precedence(&self) -> Precedence { + Precedence(self.0.precedence().clone()) + } + + #[getter] + fn type_(&self, py: Python<'_>) -> PyObject { + ty::to_py(&self.0.ty(), py) + } + + #[getter] + fn left(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.left, py) + } + + #[getter] + fn right(&self, py: Python<'_>) -> PyObject { + to_py(&self.0.right, py) + } + + fn check_type(&self, expected: &Bound<'_, PyAny>) -> error::Error { + error::Error(lib::Expr::BinExpr(self.0.clone()).check_type(&to_ty_list(expected))) + } + + fn check_type_instance(&self, expected: &Bound<'_, PyAny>, py: Python<'_>) -> error::Error { + error::Error( + lib::Expr::BinExpr(self.0.clone()) + .check_type_instance(&to_ty_discriminants_list(expected, py)), + ) + } + + fn variables(&self, py: Python<'_>) -> Vec { + to_py_list(&lib::Expr::BinExpr(self.0.clone()).variables(), py) + } + + fn findall(&self, r#match: &Bound<'_, PyAny>, py: Python<'_>) -> Vec { + to_py_list( + &lib::Expr::BinExpr(self.0.clone()).find_all(&to_fn_expr_bool(r#match, py)), + py, + ) + } + + fn substituted(&self, func: &Bound<'_, PyAny>, py: Python<'_>) -> PyObject { + to_py( + &lib::Expr::BinExpr(self.0.clone()).into_substituted(&to_fn_expr_expr(func, py)), + py, + ) + } + + fn simplified(&self, py: Python<'_>) -> PyObject { + to_py(&lib::Expr::BinExpr(self.0.clone()).into_simplified(), py) + } +} + fn to_expr(obj: &Bound<'_, PyAny>) -> lib::Expr { if let Ok(var) = obj.extract::>() { lib::Expr::Var(var.0.clone()) @@ -667,6 +1147,14 @@ fn to_expr(obj: &Bound<'_, PyAny>) -> lib::Expr { lib::Expr::Num(num.0.clone()) } else if let Ok(neg) = obj.extract::>() { lib::Expr::Neg(neg.0.clone()) + } else if let Ok(sub) = obj.extract::>() { + lib::Expr::BinExpr(sub.0.clone()) + } else if let Ok(div) = obj.extract::>() { + lib::Expr::BinExpr(div.0.clone()) + } else if let Ok(pow) = obj.extract::>() { + lib::Expr::BinExpr(pow.0.clone()) + } else if let Ok(r#mod) = obj.extract::>() { + lib::Expr::BinExpr(r#mod.0.clone()) } else { panic!("unknown expression \"{obj:?}\"") } @@ -725,16 +1213,30 @@ fn to_ty_discriminants_list( fn to_py(obj: &lib::Expr, py: Python<'_>) -> PyObject { match obj { - lib::Expr::Var(num) => Py::new(py, (Variable(num.clone()), Expr)) + lib::Expr::Var(var) => Py::new(py, (Variable(var.clone()), Expr)) .unwrap() .into_py(py), - lib::Expr::Lit(num) => Py::new(py, (Literal(num.clone()), Expr)) + lib::Expr::Lit(lit) => Py::new(py, (Literal(lit.clone()), Expr)) .unwrap() .into_py(py), lib::Expr::Num(num) => Py::new(py, (Number(num.clone()), Expr)) .unwrap() .into_py(py), - lib::Expr::Neg(num) => Py::new(py, (Neg(num.clone()), Expr)).unwrap().into_py(py), + lib::Expr::Neg(neg) => Py::new(py, (Neg(neg.clone()), Expr)).unwrap().into_py(py), + lib::Expr::BinExpr(bin_expr) => match bin_expr.op { + lib::BinOp::Sub => Py::new(py, (Sub(bin_expr.clone()), Expr)) + .unwrap() + .into_py(py), + lib::BinOp::Div => Py::new(py, (Div(bin_expr.clone()), Expr)) + .unwrap() + .into_py(py), + lib::BinOp::Pow => Py::new(py, (Pow(bin_expr.clone()), Expr)) + .unwrap() + .into_py(py), + lib::BinOp::Mod => Py::new(py, (Mod(bin_expr.clone()), Expr)) + .unwrap() + .into_py(py), + }, } } @@ -742,10 +1244,10 @@ fn to_py_list(obj: &[lib::Expr], py: Python<'_>) -> Vec { obj.iter().map(|e| to_py(e, py)).collect::>() } -impl_states!(Expr, Precedence, Variable, Literal, Number, Neg); +impl_states!(Expr, Precedence, Variable, Literal, Number, Neg, Sub, Div, Pow, Mod); register_submodule_declarations!( expr, [], - [Expr, Precedence, Variable, Literal, Number, Neg], + [Expr, Precedence, Variable, Literal, Number, Neg, Sub, Div, Pow, Mod], [] ); diff --git a/rflx/rapidflux/expr.pyi b/rflx/rapidflux/expr.pyi index 01a274150..3259032ab 100644 --- a/rflx/rapidflux/expr.pyi +++ b/rflx/rapidflux/expr.pyi @@ -63,11 +63,6 @@ class Literal(Expr): @property def name(self) -> str: ... -class Neg(Expr): - def __init__(self, expr: Expr, location: Location | None = None) -> None: ... - @property - def expr(self) -> Expr: ... - class Number(Expr): def __init__(self, value: int, base: int = 0, location: Location | None = None) -> None: ... def __int__(self) -> int: ... @@ -75,3 +70,36 @@ class Number(Expr): def value(self) -> int: ... @property def base(self) -> int: ... + +class Neg(Expr): + def __init__(self, expr: Expr, location: Location | None = None) -> None: ... + @property + def expr(self) -> Expr: ... + +class Sub(Expr): + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: ... + @property + def left(self) -> Expr: ... + @property + def right(self) -> Expr: ... + +class Div(Expr): + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: ... + @property + def left(self) -> Expr: ... + @property + def right(self) -> Expr: ... + +class Pow(Expr): + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: ... + @property + def left(self) -> Expr: ... + @property + def right(self) -> Expr: ... + +class Mod(Expr): + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: ... + @property + def left(self) -> Expr: ... + @property + def right(self) -> Expr: ... diff --git a/tests/unit/expr_test.py b/tests/unit/expr_test.py index 23ea9b15c..5b4f9c87e 100644 --- a/tests/unit/expr_test.py +++ b/tests/unit/expr_test.py @@ -94,6 +94,10 @@ def assert_type_error(expr: Expr, regex: str) -> None: expr.check_type(ty.Any()).propagate() +def assert_type_instance(expr: Expr, type_: type[ty.Type] | tuple[type[ty.Type], ...]) -> None: + expr.check_type_instance(type_).propagate() + + def test_true_type() -> None: assert_type( TRUE, @@ -858,6 +862,20 @@ def test_mul_simplified() -> None: assert Mul(Number(2), Number(3), Number(5)).simplified() == Number(30) +def test_sub_str() -> None: + assert str(Sub(Variable("X"), Number(1))) == "X - 1" + assert str(Sub(Neg(Variable("X")), Number(-1))) == "-X - (-1)" + + +def test_sub_eq() -> None: + assert Sub(Variable("X"), Number(1), location=Location((1, 2))) == Sub(Variable("X"), Number(1)) + + +def test_sub_ne() -> None: + assert Sub(Variable("X"), Number(1)) != Sub(Variable("Y"), Number(1)) + assert Sub(Variable("X"), Number(1)) != Sub(Variable("X"), Number(2)) + + def test_sub_neg() -> None: assert -Sub(Number(1), Variable("X")) == Sub(Variable("X"), Number(1)) @@ -880,10 +898,44 @@ def test_sub_neg_eval(left: int, right: int) -> None: assert (-Sub(Number(left), Number(right))).simplified() == Number(-(left - right)) +def test_sub_type() -> None: + assert_type( + Sub(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Sub(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_sub_type_error() -> None: + assert_type_error( + Sub(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + def test_sub_variables() -> None: assert Sub(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] +def test_sub_findall() -> None: + assert Sub(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_sub_substituted() -> None: + assert Sub(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Sub(Number(42), Variable("Y")) + + def test_sub_simplified() -> None: assert Sub(Number(1), Variable("X")).simplified() == Add(Number(1), -Variable("X")) assert Sub(Variable("X"), Number(1)).simplified() == Add(Variable("X"), Number(-1)) @@ -915,6 +967,20 @@ def test_sub_simplified_to_add() -> None: assert simplified.location == Location((1, 1), end=(1, 2)) +def test_div_str() -> None: + assert str(Div(Variable("X"), Number(1))) == "X / 1" + assert str(Div(Neg(Variable("X")), Number(-1))) == "(-X) / (-1)" + + +def test_div_eq() -> None: + assert Div(Variable("X"), Number(1), location=Location((1, 2))) == Div(Variable("X"), Number(1)) + + +def test_div_ne() -> None: + assert Div(Variable("X"), Number(1)) != Div(Variable("Y"), Number(1)) + assert Div(Variable("X"), Number(1)) != Div(Variable("X"), Number(2)) + + def test_div_neg() -> None: assert -Div(Variable("X"), Number(5)) == Div(-(Variable("X")), Number(5)) @@ -935,16 +1001,64 @@ def test_div_neg_eval(left: int, right: int) -> None: assert (-Div(Number(left), Number(right))).simplified() == Number(-(left // right)) +def test_div_type() -> None: + assert_type( + Div(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Div(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_div_type_error() -> None: + assert_type_error( + Div(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + def test_div_variables() -> None: assert Div(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] +def test_div_findall() -> None: + assert Div(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_div_substituted() -> None: + assert Div(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Div(Number(42), Variable("Y")) + + def test_div_simplified() -> None: assert Div(Variable("X"), Number(1)).simplified() == Div(Variable("X"), Number(1)) assert Div(Number(6), Number(2)).simplified() == Number(3) assert Div(Number(9), Number(2)).simplified() == Div(Number(9), Number(2)) +def test_pow_str() -> None: + assert str(Pow(Variable("X"), Number(1))) == "X ** 1" + assert str(Pow(Neg(Variable("X")), Number(-1))) == "(-X) ** (-1)" + + +def test_pow_eq() -> None: + assert Pow(Variable("X"), Number(1), location=Location((1, 2))) == Pow(Variable("X"), Number(1)) + + +def test_pow_ne() -> None: + assert Pow(Variable("X"), Number(1)) != Pow(Variable("Y"), Number(1)) + assert Pow(Variable("X"), Number(1)) != Pow(Variable("X"), Number(2)) + + def test_pow_neg() -> None: assert -Pow(Variable("X"), Number(5)) == -Pow(Variable("X"), Number(5)) @@ -970,6 +1084,44 @@ def test_pow_neg_eval(left: int, right: int) -> None: assert (-Pow(Number(left), Number(right))).simplified() == Number(-(left**right)) +def test_pow_type() -> None: + assert_type( + Pow(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Pow(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_pow_type_error() -> None: + assert_type_error( + Pow(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + +def test_pow_variables() -> None: + assert Pow(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_pow_findall() -> None: + assert Pow(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_pow_substituted() -> None: + assert Pow(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Pow(Number(42), Variable("Y")) + + def test_pow_simplified() -> None: assert Pow(Variable("X"), Number(1)).simplified() == Pow(Variable("X"), Number(1)) assert Pow(Variable("X"), Add(Number(1), Number(1))).simplified() == Pow( @@ -979,10 +1131,6 @@ def test_pow_simplified() -> None: assert Pow(Number(6), Number(2)).simplified() == Number(36) -def test_pow_variables() -> None: - assert Pow(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] - - def test_rem_neg() -> None: assert -Rem(Variable("X"), Number(5)) == -Rem(Variable("X"), Number(5)) @@ -1003,6 +1151,20 @@ def test_rem_neg_eval(left: int, right: int, expected: Expr) -> None: assert (-Rem(Number(left), Number(right))).simplified() == expected +def test_mod_str() -> None: + assert str(Mod(Variable("X"), Number(1))) == "X mod 1" + assert str(Mod(Neg(Variable("X")), Number(-1))) == "(-X) mod (-1)" + + +def test_mod_eq() -> None: + assert Mod(Variable("X"), Number(1), location=Location((1, 2))) == Mod(Variable("X"), Number(1)) + + +def test_mod_ne() -> None: + assert Mod(Variable("X"), Number(1)) != Mod(Variable("Y"), Number(1)) + assert Mod(Variable("X"), Number(1)) != Mod(Variable("X"), Number(2)) + + def test_mod_neg() -> None: assert -Mod(Variable("X"), Number(5)) == -Mod(Variable("X"), Number(5)) @@ -1023,6 +1185,23 @@ def test_mod_neg_eval(left: int, right: int) -> None: assert (-Mod(Number(left), Number(right))).simplified() == Number(-(left % right)) +def test_mod_variables() -> None: + assert Mod(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_mod_findall() -> None: + assert Mod(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_mod_substituted() -> None: + assert Mod(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Mod(Number(42), Variable("Y")) + + def test_mod_simplified() -> None: assert Mod(Variable("X"), Number(1)).simplified() == Mod(Variable("X"), Number(1)) assert Mod(Variable("X"), Add(Number(1), Number(1))).simplified() == Mod( @@ -1032,10 +1211,6 @@ def test_mod_simplified() -> None: assert Mod(Number(6), Number(2)).simplified() == Number(0) -def test_mod_variables() -> None: - assert Mod(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] - - def test_term_simplified() -> None: assert_equal( Add( diff --git a/tests/unit/rapidflux/expr_test.py b/tests/unit/rapidflux/expr_test.py index 9404dcacc..dc148914f 100644 --- a/tests/unit/rapidflux/expr_test.py +++ b/tests/unit/rapidflux/expr_test.py @@ -9,7 +9,7 @@ from rflx import ty from rflx.rapidflux import ID, Location, RecordFluxError -from rflx.rapidflux.expr import Expr, Literal, Neg, Number, Variable +from rflx.rapidflux.expr import Div, Expr, Literal, Mod, Neg, Number, Pow, Sub, Variable from tests.utils import check_regex INT_TY = ty.Integer("I", ty.Bounds(10, 100)) @@ -432,6 +432,418 @@ def test_neg_simplified(expr: Expr, expected: Expr) -> None: assert expr.simplified() == expected +def test_sub_str() -> None: + assert str(Sub(Variable("X"), Number(1))) == "X - 1" + assert str(Sub(Neg(Variable("X")), Number(-1))) == "-X - (-1)" + + +def test_sub_repr() -> None: + assert repr(Sub(Variable("X"), Number(1))) == ( + 'Sub(Variable(ID("X", Location((1, 1), "", (1, 1))), Undefined()),' + ' Number(1, Location((1, 1), "", (1, 1))), Location((1, 1), "", (1, 1)))' + ) + + +def test_sub_eq() -> None: + assert Sub(Variable("X"), Number(1), location=Location((1, 2))) == Sub(Variable("X"), Number(1)) + + +def test_sub_ne() -> None: + assert Sub(Variable("X"), Number(1)) != Sub(Variable("Y"), Number(1)) + assert Sub(Variable("X"), Number(1)) != Sub(Variable("X"), Number(2)) + + +def test_sub_neg() -> None: + assert -Sub(Number(1), Variable("X")) == Sub(Variable("X"), Number(1)) + + +@pytest.mark.parametrize( + ("left", "right"), + [ + # 0 in argument + (0, 0), + (0, 5), + (1, 0), + # Argument sign variations + (1, 5), + (-1, 5), + (1, -5), + (-1, -5), + ], +) +def test_sub_neg_eval(left: int, right: int) -> None: + assert (-Sub(Number(left), Number(right))).simplified() == Number(-(left - right)) + + +def test_sub_location() -> None: + assert Sub(Variable("X"), Number(1), location=Location((1, 2))).location == Location((1, 2)) + + +def test_sub_type() -> None: + assert_type( + Sub(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Sub(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_sub_type_error() -> None: + assert_type_error( + Sub(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + +def test_sub_left() -> None: + assert Sub(Variable("X"), Variable("Y")).left == Variable("X") + + +def test_sub_right() -> None: + assert Sub(Variable("X"), Variable("Y")).right == Variable("Y") + + +def test_sub_variables() -> None: + assert Sub(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_sub_findall() -> None: + assert Sub(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_sub_substituted() -> None: + assert Sub(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Sub(Number(42), Variable("Y")) + + +def test_sub_simplified() -> None: + assert Sub(Number(6), Number(2)).simplified() == Number(4) + assert Sub(Number(6), Sub(Number(4), Number(2))).simplified() == Number(4) + + +def test_sub_simplified_location() -> None: + simplified = Sub(Number(5), Number(2), location=Location((1, 1))).simplified() + assert simplified == Number(3) + assert simplified.location == Location((1, 1)) + + +def test_div_str() -> None: + assert str(Div(Variable("X"), Number(1))) == "X / 1" + assert str(Div(Neg(Variable("X")), Number(-1))) == "(-X) / (-1)" + + +def test_div_repr() -> None: + assert repr(Div(Variable("X"), Number(1))) == ( + 'Div(Variable(ID("X", Location((1, 1), "", (1, 1))), Undefined()),' + ' Number(1, Location((1, 1), "", (1, 1))), Location((1, 1), "", (1, 1)))' + ) + + +def test_div_eq() -> None: + assert Div(Variable("X"), Number(1), location=Location((1, 2))) == Div(Variable("X"), Number(1)) + + +def test_div_ne() -> None: + assert Div(Variable("X"), Number(1)) != Div(Variable("Y"), Number(1)) + assert Div(Variable("X"), Number(1)) != Div(Variable("X"), Number(2)) + + +def test_div_neg() -> None: + assert -Div(Variable("X"), Number(5)) == Div(-(Variable("X")), Number(5)) + + +@pytest.mark.parametrize( + ("left", "right"), + [ + # 0 in argument + (0, 5), + # Argument sign variations + (10, 5), + (-10, 5), + (10, -5), + (-10, -5), + ], +) +def test_div_neg_eval(left: int, right: int) -> None: + assert (-Div(Number(left), Number(right))).simplified() == Number(-(left // right)) + + +def test_div_location() -> None: + assert Div(Variable("X"), Number(1), location=Location((1, 2))).location == Location((1, 2)) + + +def test_div_type() -> None: + assert_type( + Div(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Div(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_div_type_error() -> None: + assert_type_error( + Div(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + +def test_div_left() -> None: + assert Div(Variable("X"), Variable("Y")).left == Variable("X") + + +def test_div_right() -> None: + assert Div(Variable("X"), Variable("Y")).right == Variable("Y") + + +def test_div_variables() -> None: + assert Div(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_div_findall() -> None: + assert Div(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_div_substituted() -> None: + assert Div(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Div(Number(42), Variable("Y")) + + +def test_div_simplified() -> None: + assert Div(Variable("X"), Number(1)).simplified() == Div(Variable("X"), Number(1)) + assert Div(Number(6), Number(2)).simplified() == Number(3) + assert Div(Number(9), Number(2)).simplified() == Div(Number(9), Number(2)) + + +def test_div_simplified_location() -> None: + simplified = Div(Number(6), Number(2), location=Location((1, 1))).simplified() + assert simplified == Number(3) + assert simplified.location == Location((1, 1)) + + +def test_pow_str() -> None: + assert str(Pow(Variable("X"), Number(1))) == "X ** 1" + assert str(Pow(Neg(Variable("X")), Number(-1))) == "(-X) ** (-1)" + + +def test_pow_repr() -> None: + assert repr(Pow(Variable("X"), Number(1))) == ( + 'Pow(Variable(ID("X", Location((1, 1), "", (1, 1))), Undefined()),' + ' Number(1, Location((1, 1), "", (1, 1))), Location((1, 1), "", (1, 1)))' + ) + + +def test_pow_eq() -> None: + assert Pow(Variable("X"), Number(1), location=Location((1, 2))) == Pow(Variable("X"), Number(1)) + + +def test_pow_ne() -> None: + assert Pow(Variable("X"), Number(1)) != Pow(Variable("Y"), Number(1)) + assert Pow(Variable("X"), Number(1)) != Pow(Variable("X"), Number(2)) + + +def test_pow_neg() -> None: + assert -Pow(Variable("X"), Number(5)) == -Pow(Variable("X"), Number(5)) + + +@pytest.mark.parametrize( + ("left", "right"), + [ + # 0 in argument + (0, 0), + (-10, 0), + (0, 4), + # Argument sign variations + # Constraints: + # * The second argument cannot be negative. + # * The second argument must be tested with an even and odd value. + (10, 4), + (-10, 4), + (10, 5), + (-10, 5), + ], +) +def test_pow_neg_eval(left: int, right: int) -> None: + assert (-Pow(Number(left), Number(right))).simplified() == Number(-(left**right)) + + +def test_pow_location() -> None: + assert Pow(Variable("X"), Number(1), location=Location((1, 2))).location == Location((1, 2)) + + +def test_pow_type() -> None: + assert_type( + Pow(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Pow(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_pow_type_error() -> None: + assert_type_error( + Pow(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + +def test_pow_left() -> None: + assert Pow(Variable("X"), Variable("Y")).left == Variable("X") + + +def test_pow_right() -> None: + assert Pow(Variable("X"), Variable("Y")).right == Variable("Y") + + +def test_pow_variables() -> None: + assert Pow(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_pow_findall() -> None: + assert Pow(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_pow_substituted() -> None: + assert Pow(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Pow(Number(42), Variable("Y")) + + +def test_pow_simplified() -> None: + assert Pow(Variable("X"), Number(1)).simplified() == Pow(Variable("X"), Number(1)) + assert Pow(Number(6), Number(2)).simplified() == Number(36) + + +def test_pow_simplified_location() -> None: + simplified = Pow(Number(6), Number(2), location=Location((1, 1))).simplified() + assert simplified == Number(36) + assert simplified.location == Location((1, 1)) + + +def test_mod_str() -> None: + assert str(Mod(Variable("X"), Number(1))) == "X mod 1" + assert str(Mod(Neg(Variable("X")), Number(-1))) == "(-X) mod (-1)" + + +def test_mod_repr() -> None: + assert repr(Mod(Variable("X"), Number(1))) == ( + 'Mod(Variable(ID("X", Location((1, 1), "", (1, 1))), Undefined()),' + ' Number(1, Location((1, 1), "", (1, 1))), Location((1, 1), "", (1, 1)))' + ) + + +def test_mod_eq() -> None: + assert Mod(Variable("X"), Number(1), location=Location((1, 2))) == Mod(Variable("X"), Number(1)) + + +def test_mod_ne() -> None: + assert Mod(Variable("X"), Number(1)) != Mod(Variable("Y"), Number(1)) + assert Mod(Variable("X"), Number(1)) != Mod(Variable("X"), Number(2)) + + +def test_mod_neg() -> None: + assert -Mod(Variable("X"), Number(5)) == -Mod(Variable("X"), Number(5)) + + +@pytest.mark.parametrize( + ("left", "right"), + [ + # 0 in argument + (0, 3), + # Argument sign variations + (7, 3), + (-7, 3), + (7, -3), + (-7, -3), + ], +) +def test_mod_neg_eval(left: int, right: int) -> None: + assert (-Mod(Number(left), Number(right))).simplified() == Number(-(left % right)) + + +def test_mod_type() -> None: + assert_type( + Mod(Variable("X", type_=INT_TY), Number(1)), + ty.BASE_INTEGER, + ) + assert_type_instance( + Mod(Variable("X", type_=INT_TY), Number(1)), + ty.Integer, + ) + + +def test_mod_type_error() -> None: + assert_type_error( + Mod(Variable(ID("X", location=Location((1, 2))), type_=ty.BOOLEAN), Number(1)), + r"^" + r":1:2: error: expected integer type\n" + r':1:2: error: found enumeration type "__BUILTINS__::Boolean"' + r"$", + ) + + +def test_mod_left() -> None: + assert Mod(Variable("X"), Variable("Y")).left == Variable("X") + + +def test_mod_right() -> None: + assert Mod(Variable("X"), Variable("Y")).right == Variable("Y") + + +def test_mod_variables() -> None: + assert Mod(Variable("X"), Variable("Y")).variables() == [Variable("X"), Variable("Y")] + + +def test_mod_findall() -> None: + assert Mod(Variable("X"), Variable("Y")).findall(lambda x: isinstance(x, Variable)) == [ + Variable("X"), + Variable("Y"), + ] + + +def test_mod_substituted() -> None: + assert Mod(Variable("X"), Variable("Y")).substituted( + lambda x: Number(42) if x == Variable("X") else x, + ) == Mod(Number(42), Variable("Y")) + + +def test_mod_simplified() -> None: + assert Mod(Variable("X"), Number(1)).simplified() == Mod(Variable("X"), Number(1)) + assert Mod(Number(6), Number(2)).simplified() == Number(0) + + +def test_mod_simplified_location() -> None: + simplified = Mod(Number(5), Number(2), location=Location((1, 1))).simplified() + assert simplified == Number(1) + assert simplified.location == Location((1, 1)) + + @pytest.mark.parametrize( ("left", "right"), [