From 78631c839edc2907cfda958b4f701b62ac20562e Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 18 Jun 2026 14:47:02 +0900 Subject: [PATCH] Support forall in annotation --- src/analyze/annot.rs | 8 +++++ src/analyze/annot_fn.rs | 63 +++++++++++++++++++++++------------ src/analyze/did_cache.rs | 8 +++++ src/annot.rs | 2 +- src/chc.rs | 53 ++++++++++++++++++++++++----- src/chc/format_context.rs | 2 +- src/chc/smtlib2.rs | 10 +++++- src/chc/unbox.rs | 8 +++-- src/rty.rs | 8 ++++- std.rs | 7 ++++ tests/ui/fail/annot_forall.rs | 13 ++++++++ tests/ui/pass/annot_forall.rs | 13 ++++++++ 12 files changed, 159 insertions(+), 36 deletions(-) create mode 100644 tests/ui/fail/annot_forall.rs create mode 100644 tests/ui/pass/annot_forall.rs diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 409d5218..76e76e22 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -145,6 +145,14 @@ pub fn exists_path() -> [Symbol; 3] { ] } +pub fn forall_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("forall"), + ] +} + pub fn implies_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index c91a87a4..0249e4aa 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -452,6 +452,32 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> { chc::Term::datatype_ctor(d_sym, sort_args, v_sym, field_terms) } + fn to_formula_with_quantified_vars( + &self, + closure: &rustc_hir::Body<'tcx>, + ) -> ( + Vec<(String, chc::Sort)>, + chc::Formula, + ) { + let mut inner_translator = self.clone(); + let mut vars = Vec::new(); + for param in closure.params { + let rustc_hir::PatKind::Binding(_, hir_id, ident, None) = param.pat.kind else { + panic!( + "exists/forall closure parameter must be a simple binding: {:?}", + param.pat + ); + }; + let param_ty = self.pat_ty(param.pat); + let sort = self.type_builder.build(param_ty).to_sort(); + let var_term = chc::Term::FormulaQuantifiedVar(sort.clone(), ident.name.to_string()); + inner_translator.env.insert(hir_id, var_term); + vars.push((ident.name.to_string(), sort)); + } + let body_formula = inner_translator.to_formula(closure.value); + (vars, body_formula) + } + fn to_formula_or_term( &self, hir: &'tcx rustc_hir::Expr<'tcx>, @@ -648,32 +674,27 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> { }; let closure_body = self.tcx.hir_body(closure.body); - let mut inner_translator = self.clone(); - let mut vars = Vec::new(); - for param in closure_body.params { - let rustc_hir::PatKind::Binding(_, hir_id, ident, None) = - param.pat.kind - else { - panic!( - "exists closure parameter must be a simple binding: {:?}", - param.pat - ); - }; - let param_ty = self.pat_ty(param.pat); - let sort = self.type_builder.build(param_ty).to_sort(); - let var_term = chc::Term::FormulaExistentialVar( - sort.clone(), - ident.name.to_string(), - ); - inner_translator.env.insert(hir_id, var_term); - vars.push((ident.name.to_string(), sort)); - } - let body_formula = inner_translator.to_formula(closure_body.value); + let (vars, body_formula) = + self.to_formula_with_quantified_vars(closure_body); return FormulaOrTerm::Formula(chc::Formula::exists( vars, body_formula, )); } + if Some(def_id) == self.def_ids.forall() { + assert_eq!(args.len(), 1, "forall takes exactly 1 argument"); + let ExprKind::Closure(closure) = args[0].kind else { + panic!("forall argument must be a closure"); + }; + let closure_body = self.tcx.hir_body(closure.body); + + let (vars, body_formula) = + self.to_formula_with_quantified_vars(closure_body); + return FormulaOrTerm::Formula(chc::Formula::forall( + vars, + body_formula, + )); + } if Some(def_id) == self.def_ids.fn_param_at_entry() { assert_eq!(args.len(), 1, "FnParam::at_entry takes exactly 1 argument"); let t = self.to_term(&args[0]); diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index 6c56e95a..15149120 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -25,6 +25,7 @@ struct DefIds { array_model_store: OnceCell>, exists: OnceCell>, + forall: OnceCell>, implies: OnceCell>, invariant_marker: OnceCell>, @@ -185,6 +186,13 @@ impl<'tcx> DefIdCache<'tcx> { .get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path())) } + pub fn forall(&self) -> Option { + *self + .def_ids + .forall + .get_or_init(|| self.annotated_def(&crate::analyze::annot::forall_path())) + } + pub fn implies(&self) -> Option { *self .def_ids diff --git a/src/annot.rs b/src/annot.rs index b1179d2c..e51836e9 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -510,7 +510,7 @@ where ("false", _) => FormulaOrTerm::Literal(false), (_, Some(sort)) => { let var = - chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); + chc::Term::FormulaQuantifiedVar(sort.clone(), ident.name.to_string()); FormulaOrTerm::Term(var, sort.clone()) } _ => { diff --git a/src/chc.rs b/src/chc.rs index 5435a136..2a61f375 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -452,8 +452,8 @@ pub enum Term { TupleProj(Box>, usize), DatatypeCtor(DatatypeSort, DatatypeSymbol, Vec>), DatatypeDiscr(DatatypeSymbol, Box>), - /// Used in [`Formula`] to represent existentially quantified variables appearing in annotations. - FormulaExistentialVar(Sort, String), + /// Used in [`Formula`] to represent quantified variables appearing in annotations. + FormulaQuantifiedVar(Sort, String), } impl<'a, D, V> Pretty<'a, D, termcolor::ColorSpec> for &Term @@ -523,7 +523,7 @@ where Term::DatatypeDiscr(_, t) => allocator .text("discriminant") .append(t.pretty(allocator).parens()), - Term::FormulaExistentialVar(_, name) => allocator.text(name.clone()), + Term::FormulaQuantifiedVar(_, name) => allocator.text(name.clone()), } } } @@ -570,7 +570,7 @@ impl Term { args.into_iter().map(|t| t.subst_var(&mut f)).collect(), ), Term::DatatypeDiscr(d_sym, t) => Term::DatatypeDiscr(d_sym, Box::new(t.subst_var(f))), - Term::FormulaExistentialVar(sort, name) => Term::FormulaExistentialVar(sort, name), + Term::FormulaQuantifiedVar(sort, name) => Term::FormulaQuantifiedVar(sort, name), } } @@ -616,7 +616,7 @@ impl Term { Term::TupleProj(t, i) => t.sort(var_sort).tuple_elem(*i), Term::DatatypeCtor(sort, _, _) => sort.clone().into(), Term::DatatypeDiscr(_, _) => Sort::int(), - Term::FormulaExistentialVar(sort, _) => sort.clone(), + Term::FormulaQuantifiedVar(sort, _) => sort.clone(), } } @@ -627,7 +627,7 @@ impl Term { | Term::Bool(_) | Term::Int(_) | Term::String(_) - | Term::FormulaExistentialVar { .. } => Box::new(std::iter::empty()), + | Term::FormulaQuantifiedVar { .. } => Box::new(std::iter::empty()), Term::Box(t) => t.fv_impl(), Term::Mut(t1, t2) => Box::new(t1.fv_impl().chain(t2.fv_impl())), Term::BoxCurrent(t) => t.fv_impl(), @@ -1253,6 +1253,7 @@ pub enum Formula { Or(Vec>), Implies(Box>, Box>), Exists(Vec<(String, Sort)>, Box>), + Forall(Vec<(String, Sort)>, Box>), } impl Default for Formula { @@ -1320,6 +1321,25 @@ where .append(fo.pretty(allocator).nest(2)) .group() } + Formula::Forall(vars, fo) => { + let vars = allocator.intersperse( + vars.iter().map(|(name, sort)| { + allocator + .text(name.clone()) + .append(allocator.text(":")) + .append(allocator.text(" ")) + .append(sort.pretty(allocator)) + }), + allocator.text(", ").append(allocator.line()), + ); + allocator + .text("∀") + .append(vars) + .append(allocator.text(".")) + .append(allocator.line()) + .append(fo.pretty(allocator).nest(2)) + .group() + } } } } @@ -1335,9 +1355,11 @@ impl Formula { D::Doc: Clone, { match self { - Formula::And(_) | Formula::Or(_) | Formula::Implies(_, _) | Formula::Exists { .. } => { - self.pretty(allocator).parens() - } + Formula::And(_) + | Formula::Or(_) + | Formula::Implies { .. } + | Formula::Exists { .. } + | Formula::Forall { .. } => self.pretty(allocator).parens(), _ => self.pretty(allocator), } } @@ -1358,6 +1380,7 @@ impl Formula { Formula::Or(fs) => fs.iter().any(Formula::is_top), Formula::Implies(lhs, rhs) => lhs.is_bottom() || rhs.is_top(), Formula::Exists(_, fo) => fo.is_top(), + Formula::Forall(_, fo) => fo.is_top(), } } @@ -1369,6 +1392,7 @@ impl Formula { Formula::Or(fs) => fs.iter().all(Formula::is_bottom), Formula::Implies(lhs, rhs) => lhs.is_top() && rhs.is_bottom(), Formula::Exists(_, fo) => fo.is_bottom(), + Formula::Forall(_, fo) => fo.is_bottom(), } } @@ -1407,6 +1431,10 @@ impl Formula { Formula::Exists(vars, Box::new(body)) } + pub fn forall(vars: Vec<(String, Sort)>, body: Self) -> Self { + Formula::Forall(vars, Box::new(body)) + } + pub fn subst_var(self, f: F) -> Formula where F: FnMut(V) -> Term, @@ -1424,6 +1452,7 @@ impl Formula { Formula::Implies(Box::new(lhs.subst_var(&mut f)), Box::new(rhs.subst_var(f))) } Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.subst_var(f))), + Formula::Forall(vars, fo) => Formula::Forall(vars, Box::new(fo.subst_var(f))), } } @@ -1442,6 +1471,7 @@ impl Formula { Formula::Implies(Box::new(lhs.map_var(&mut f)), Box::new(rhs.map_var(f))) } Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.map_var(f))), + Formula::Forall(vars, fo) => Formula::Forall(vars, Box::new(fo.map_var(f))), } } @@ -1457,6 +1487,7 @@ impl Formula { Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::fv)), Formula::Implies(lhs, rhs) => Box::new(lhs.fv().chain(rhs.fv())), Formula::Exists(_, fo) => Box::new(fo.fv()), + Formula::Forall(_, fo) => Box::new(fo.fv()), } } @@ -1472,6 +1503,7 @@ impl Formula { Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), Formula::Implies(lhs, rhs) => Box::new(lhs.iter_atoms().chain(rhs.iter_atoms())), Formula::Exists(_, fo) => Box::new(fo.iter_atoms()), + Formula::Forall(_, fo) => Box::new(fo.iter_atoms()), } } @@ -1527,6 +1559,9 @@ impl Formula { Formula::Exists(_, fo) => { fo.simplify(); } + Formula::Forall(_, fo) => { + fo.simplify(); + } } } } diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index 94548274..f4a79fd4 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -57,7 +57,7 @@ fn term_sorts(clause: &chc::Clause, t: &chc::Term, sorts: &mut BTreeSet term_sorts(clause, t, sorts), - chc::Term::FormulaExistentialVar(_, _) => {} + chc::Term::FormulaQuantifiedVar(_, _) => {} } } diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index c00782c0..7904d58d 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -202,7 +202,7 @@ impl<'ctx, 'a> std::fmt::Display for Term<'ctx, 'a> { Term::new(self.ctx, self.clause, t) ) } - chc::Term::FormulaExistentialVar(_, name) => write!(f, "{}", name), + chc::Term::FormulaQuantifiedVar(_, name) => write!(f, "{}", name), } } } @@ -301,6 +301,14 @@ impl<'ctx, 'a> std::fmt::Display for Formula<'ctx, 'a> { let fo = Formula::new(self.ctx, self.clause, fo); write!(f, "(exists {vars} {fo})") } + chc::Formula::Forall(vars, fo) => { + let vars = + List::closed(vars.iter().map(|(v, s)| { + List::closed([v.to_string(), self.ctx.fmt_sort(s).to_string()]) + })); + let fo = Formula::new(self.ctx, self.clause, fo); + write!(f, "(forall {vars} {fo})") + } } } } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 64f79aa4..bfb782eb 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -19,8 +19,8 @@ fn unbox_term(term: Term) -> Term { args.into_iter().map(unbox_term).collect(), ), Term::DatatypeDiscr(sym, arg) => Term::DatatypeDiscr(sym, Box::new(unbox_term(*arg))), - Term::FormulaExistentialVar(sort, name) => { - Term::FormulaExistentialVar(unbox_sort(sort), name) + Term::FormulaQuantifiedVar(sort, name) => { + Term::FormulaQuantifiedVar(unbox_sort(sort), name) } } } @@ -88,6 +88,10 @@ fn unbox_formula(formula: Formula) -> Formula { let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect(); Formula::Exists(vars, Box::new(unbox_formula(*fo))) } + Formula::Forall(vars, fo) => { + let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect(); + Formula::Forall(vars, Box::new(unbox_formula(*fo))) + } } } diff --git a/src/rty.rs b/src/rty.rs index 6a98e7e5..28209bb4 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1940,6 +1940,12 @@ fn subst_ty_params_in_formula(formula: &mut chc::Formula, subst: &TypeP } subst_ty_params_in_formula(f, subst); } + chc::Formula::Forall(vars, f) => { + for (_, sort) in vars { + subst_ty_params_in_sort(sort, subst); + } + subst_ty_params_in_formula(f, subst); + } } } @@ -1976,7 +1982,7 @@ fn subst_ty_params_in_term(term: &mut chc::Term, subst: &TypeParamSubst subst_ty_params_in_term(arg, subst); } } - chc::Term::FormulaExistentialVar(sort, _) => { + chc::Term::FormulaQuantifiedVar(sort, _) => { subst_ty_params_in_sort(sort, subst); } } diff --git a/std.rs b/std.rs index abbdbde9..87e5bde2 100644 --- a/std.rs +++ b/std.rs @@ -332,6 +332,13 @@ mod thrust_models { unimplemented!() } + #[allow(dead_code)] + #[thrust::def::forall] + #[thrust::ignored] + pub fn forall(_x: T) -> bool { + unimplemented!() + } + #[allow(dead_code)] #[thrust::def::implies] #[thrust::ignored] diff --git a/tests/ui/fail/annot_forall.rs b/tests/ui/fail/annot_forall.rs new file mode 100644 index 00000000..e60993d2 --- /dev/null +++ b/tests/ui/fail/annot_forall.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +use thrust_models::forall; + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result > x && forall(|y: i32| y <= x || result < y))] +fn succ(x: i32) -> i32 { + x + 1 +} + +fn main() {} diff --git a/tests/ui/pass/annot_forall.rs b/tests/ui/pass/annot_forall.rs new file mode 100644 index 00000000..f13b69e9 --- /dev/null +++ b/tests/ui/pass/annot_forall.rs @@ -0,0 +1,13 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +use thrust_models::forall; + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(result > x && forall(|y: i32| y <= x || result <= y))] +fn succ(x: i32) -> i32 { + x + 1 +} + +fn main() {}