Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ impl<'tcx> Analyzer<'tcx> {
self.def_ids.clone()
}

pub fn mark_uses_seq_concat(&self) {
self.system.borrow_mut().uses_seq_concat = true;
}

pub fn mark_uses_seq_subseq(&self) {
self.system.borrow_mut().uses_seq_subseq = true;
}

pub fn add_clause(&mut self, clause: chc::Clause) {
self.system.borrow_mut().push_clause(clause);
}
Expand Down Expand Up @@ -454,6 +462,22 @@ impl<'tcx> Analyzer<'tcx> {
Some(formula_fn)
}

/// Companion of [`Self::formula_fn_with_args`] for `#[thrust::formula_fn]`
/// bodies whose body yields a model term (rather than a bool formula) — used
/// by `snapshot!{}`. Not cached.
pub fn term_fn_with_args(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<annot_fn::TermFn<'tcx>> {
// Reuse the registration set used for `formula_fn_with_args`: any
// `#[thrust::formula_fn]` may serve as either.
self.formula_fns.get(&local_def_id)?;
let translator = annot_fn::AnnotFnTranslator::new(self, local_def_id, generic_args)
.with_def_id_cache(self.def_ids());
Some(translator.to_term_fn())
}

pub fn def_ty_with_args(
&mut self,
def_id: DefId,
Expand Down
72 changes: 72 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,62 @@ pub fn array_model_store_path() -> [Symbol; 3] {
]
}

pub fn seq_model_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_model"),
]
}

pub fn seq_empty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_empty"),
]
}

pub fn seq_singleton_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_singleton"),
]
}

pub fn seq_len_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_len"),
]
}

pub fn seq_push_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_push"),
]
}

pub fn seq_concat_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_concat"),
]
}

pub fn seq_subsequence_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_subsequence"),
]
}

pub fn exists_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -169,6 +225,22 @@ pub fn invariant_marker_path() -> [Symbol; 3] {
]
}

pub fn snapshot_marker_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("snapshot_marker"),
]
}

pub fn proof_assert_marker_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("proof_assert_marker"),
]
}

pub fn fn_param_wrapper_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down
127 changes: 126 additions & 1 deletion src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@ where
}
}

/// The term analogue of [`FormulaFn`]: a `#[thrust::formula_fn]` whose body
/// evaluates to a non-`bool` model term (used by `snapshot!{}`).
#[derive(Debug, Clone)]
pub struct TermFn<'tcx> {
params: IndexVec<rty::FunctionParamIdx, mir_ty::Ty<'tcx>>,
term: chc::Term<rty::FunctionParamIdx>,
}

impl<'a, D> Pretty<'a, D, termcolor::ColorSpec> for &TermFn<'_>
where
D: pretty::DocAllocator<'a, termcolor::ColorSpec>,
D::Doc: Clone,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> {
allocator
.intersperse(
self.params.iter_enumerated().map(|(idx, ty)| {
idx.pretty(allocator)
.append(": ")
.append(allocator.as_string(ty))
}),
", ",
)
.enclose("|", "|")
.group()
.append(self.term.pretty(allocator))
.group()
}
}

impl<'tcx> TermFn<'tcx> {
pub fn term(&self) -> &chc::Term<rty::FunctionParamIdx> {
&self.term
}
}

impl<'tcx> FormulaFn<'tcx> {
pub fn formula(&self) -> &chc::Formula<rty::FunctionParamIdx> {
&self.formula
Expand Down Expand Up @@ -314,6 +350,25 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
}
}

/// Same shape as [`Self::to_formula_fn`] but for `#[thrust::formula_fn]`
/// bodies whose value is a non-`bool` model term (i.e. `snapshot!{}`
/// closures). The body is interpreted via [`Self::to_term`] rather than
/// [`Self::to_formula`].
pub fn to_term_fn(&self) -> TermFn<'tcx> {
let term = self.to_term(self.body.value);
let params = self
.tcx
.fn_sig(self.local_def_id.to_def_id())
.instantiate(self.tcx, self.generic_args)
.skip_binder()
.inputs()
.to_vec();
TermFn {
params: IndexVec::from_raw(params),
term,
}
}

fn to_formula(&self, hir: &'tcx rustc_hir::Expr<'tcx>) -> chc::Formula<rty::FunctionParamIdx> {
self.to_formula_or_term(hir)
.into_formula()
Expand Down Expand Up @@ -623,9 +678,18 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
FormulaOrTerm::Term(term.tuple_proj(index))
}
ExprKind::Index(array, index, _) => {
let array_ty = self.expr_ty(array);
let array_term = self.to_term(array);
let index_term = self.to_term(index);
FormulaOrTerm::Term(array_term.select(index_term))
let is_seq = array_ty
.ty_adt_def()
.is_some_and(|adt| Some(adt.did()) == self.def_ids.seq_model());
let array_inner = if is_seq {
array_term.tuple_proj(0)
} else {
array_term
};
FormulaOrTerm::Term(array_inner.select(index_term))
}
ExprKind::MethodCall(method, receiver, args, _) => {
if let Some(def_id) = self.typeck.type_dependent_def_id(hir.hir_id) {
Expand All @@ -644,6 +708,49 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t);
}
if Some(def_id) == self.def_ids.seq_len() {
assert!(args.is_empty(), "Seq::len does not take any arguments");
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t.tuple_proj(1));
}
if Some(def_id) == self.def_ids.seq_push() {
assert_eq!(args.len(), 1, "Seq::push takes exactly 1 argument");
let t = self.to_term(receiver);
let v = self.to_term(&args[0]);
let arr = t.clone().tuple_proj(0);
let len = t.tuple_proj(1);
let new_arr = arr.store(len.clone(), v);
let new_len = len.add(chc::Term::int(1));
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
if Some(def_id) == self.def_ids.seq_concat() {
assert_eq!(args.len(), 1, "Seq::concat takes exactly 1 argument");
self.analyzer.mark_uses_seq_concat();
let t = self.to_term(receiver);
let other = self.to_term(&args[0]);
let a_arr = t.clone().tuple_proj(0);
let a_len = t.tuple_proj(1);
let b_arr = other.clone().tuple_proj(0);
let b_len = other.tuple_proj(1);
let new_arr = chc::Term::App(
chc::Function::SEQ_CONCAT_ARR_INT,
vec![a_arr, a_len.clone(), b_arr],
);
let new_len = a_len.add(b_len);
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
if Some(def_id) == self.def_ids.seq_subsequence() {
assert_eq!(args.len(), 2, "Seq::subsequence takes exactly 2 arguments");
self.analyzer.mark_uses_seq_subseq();
let t = self.to_term(receiver);
let l = self.to_term(&args[0]);
let r = self.to_term(&args[1]);
let arr = t.tuple_proj(0);
let new_arr =
chc::Term::App(chc::Function::SEQ_SUBSEQ_ARR_INT, vec![arr, l.clone()]);
let new_len = r.sub(l);
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
}
unimplemented!("unsupported method call in formula: {:?}", method)
}
Expand Down Expand Up @@ -719,6 +826,24 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(&args[0]);
return FormulaOrTerm::Term(chc::Term::box_(t));
}
if Some(def_id) == self.def_ids.seq_empty() {
assert!(args.is_empty(), "Seq::empty does not take any arguments");
let arr = chc::Term::App(chc::Function::SEQ_DEFAULT_ARR_INT, vec![]);
return FormulaOrTerm::Term(chc::Term::tuple(vec![
arr,
chc::Term::int(0),
]));
}
if Some(def_id) == self.def_ids.seq_singleton() {
assert_eq!(args.len(), 1, "Seq::singleton takes exactly 1 argument");
let v = self.to_term(&args[0]);
let arr = chc::Term::App(chc::Function::SEQ_DEFAULT_ARR_INT, vec![]);
let new_arr = arr.store(chc::Term::int(0), v);
return FormulaOrTerm::Term(chc::Term::tuple(vec![
new_arr,
chc::Term::int(1),
]));
}
if let rustc_hir::def::DefKind::Ctor(ctor_of, _) = def_kind {
let terms = args.iter().map(|e| self.to_term(e)).collect();
match ctor_of {
Expand Down
Loading