diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 6e2bb44..38dba93 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -482,7 +482,26 @@ impl<'tcx> AnnotFnTranslator<'tcx> { .next() .is_some() { - let pred = refine::user_defined_pred(self.tcx, def_id); + let param_env = self + .tcx + .param_env(self.local_def_id) + .with_reveal_all_normalized(self.tcx); + let generic_args = self.typeck.node_args(func_expr.hir_id); + let generic_args = mir_ty::EarlyBinder::bind(generic_args) + .instantiate(self.tcx, self.generic_args); + let instance = mir_ty::Instance::resolve( + self.tcx, + param_env, + def_id, + generic_args, + ) + .unwrap(); + let pred_def_id = if let Some(instance) = instance { + instance.def_id() + } else { + def_id + }; + let pred = refine::user_defined_pred(self.tcx, pred_def_id); let arg_terms = args.iter().map(|e| self.to_term(e)).collect(); let atom = chc::Atom::new(pred.into(), arg_terms); return FormulaOrTerm::Formula(chc::Formula::Atom(atom)); diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 29b086d..579150c 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -420,12 +420,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ty, ) => { let func_ty = match operand.const_fn_def() { - Some((def_id, args)) => self - .ctx - .def_ty_with_args(def_id, args) - .expect("unknown def") - .ty - .clone(), + Some((def_id, args)) => self.fn_def_ty(def_id, args), _ => unimplemented!(), }; PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null()) @@ -573,44 +568,68 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }); } + fn resolve_fn_def( + &self, + def_id: DefId, + args: mir_ty::GenericArgsRef<'tcx>, + ) -> (DefId, mir_ty::GenericArgsRef<'tcx>) { + if self.ctx.is_fn_trait_method(def_id) { + // When calling a closure via `Fn`/`FnMut`/`FnOnce` trait, + // we simply replace the def_id with the closure's function def_id. + // This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor + // adjusts the arguments accordingly. + let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else { + panic!("expected closure arg for fn trait"); + }; + tracing::debug!(?closure_def_id, "closure instance"); + (*closure_def_id, args) + } else { + let param_env = self + .tcx + .param_env(self.local_def_id) + .with_reveal_all_normalized(self.tcx); + let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap(); + if let Some(instance) = instance { + (instance.def_id(), instance.args) + } else { + (def_id, args) + } + } + } + + fn fn_def_ty( + &mut self, + def_id: DefId, + args: mir_ty::GenericArgsRef<'tcx>, + ) -> rty::Type { + if let Some(def_ty) = self.ctx.def_ty_with_args(def_id, args) { + return def_ty.ty; + } + + let (resolved_def_id, resolved_args) = self.resolve_fn_def(def_id, args); + if resolved_def_id == def_id { + panic!( + "unknown def (and not resolved): {:?}, args: {:?}", + def_id, args + ); + } + tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved"); + let Some(def_ty) = self.ctx.def_ty_with_args(resolved_def_id, resolved_args) else { + panic!( + "unknown def (resolved): {:?}, args: {:?}", + resolved_def_id, resolved_args + ); + }; + def_ty.ty + } + fn type_call(&mut self, func: Operand<'tcx>, args: I, expected_ret: &rty::RefinedType) where I: IntoIterator>, { // TODO: handle const_fn_def on Env side let func_ty = if let Some((def_id, args)) = func.const_fn_def() { - let (resolved_def_id, resolved_args) = if self.ctx.is_fn_trait_method(def_id) { - // When calling a closure via `Fn`/`FnMut`/`FnOnce` trait, - // we simply replace the def_id with the closure's function def_id. - // This skips shims, and makes self arguments mismatch. visitor::RustCallVisitor - // adjusts the arguments accordingly. - let mir_ty::TyKind::Closure(closure_def_id, _) = args.type_at(0).kind() else { - panic!("expected closure arg for fn trait"); - }; - tracing::debug!(?closure_def_id, "closure instance"); - (*closure_def_id, args) - } else { - let param_env = self - .tcx - .param_env(self.local_def_id) - .with_reveal_all_normalized(self.tcx); - let instance = - mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap(); - if let Some(instance) = instance { - (instance.def_id(), instance.args) - } else { - (def_id, args) - } - }; - if def_id != resolved_def_id { - tracing::info!(?def_id, ?resolved_def_id, ?resolved_args, "resolved"); - } - - self.ctx - .def_ty_with_args(resolved_def_id, resolved_args) - .expect("unknown def") - .ty - .vacuous() + self.fn_def_ty(def_id, args).vacuous() } else { self.operand_type(func.clone()).ty }; diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 69d92c7..0332f4b 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -117,14 +117,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) { // predicate's name - // TODO: simply use refine::user_defined_pred for all functions - // after we dropped old annotation parser impl - let impl_type = self.impl_type(); - let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); - let pred = match impl_type { - Some(t) => chc::UserDefinedPred::new(t.to_string() + "_" + &pred_item_name), - None => refine::user_defined_pred(self.tcx, local_def_id.to_def_id()), - }; + let pred = refine::user_defined_pred(self.tcx, local_def_id.to_def_id()); // function's body use rustc_hir::{Block, Expr, ExprKind}; @@ -276,7 +269,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { || (all_params_annotated && has_ret) } - pub fn trait_item_id(&self) -> Option { + pub fn local_trait_item_id(&self) -> Option { let impl_item_assoc = self .tcx .opt_associated_item(self.local_def_id.to_def_id())?; @@ -284,9 +277,33 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .trait_item_def_id .and_then(|id| id.as_local())?; + if trait_item_id == self.local_def_id { + return None; + } + Some(trait_item_id) } + pub fn trait_item_ty(&mut self) -> Option { + let impl_did = self.tcx.parent(self.local_def_id.to_def_id()); + + if self.tcx.def_kind(impl_did) != (rustc_hir::def::DefKind::Impl { of_trait: true }) { + return None; + } + + let trait_ref = self.tcx.impl_trait_ref(impl_did)?.instantiate_identity(); + let trait_item_did = self + .tcx + .associated_item(self.local_def_id.to_def_id()) + .trait_item_def_id + .unwrap(); + self.ctx.def_ty_with_args(trait_item_did, trait_ref.args) + } + + // Note that we do not expect predicate variables to be generated here + // when type params are still present in the type. Callers should ensure either + // - type params are fully instantiated, or + // - the function is fully annotated pub fn expected_ty(&mut self) -> rty::RefinedType { let sig = self .ctx @@ -324,7 +341,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.generic_args, ); - if let Some(trait_item_id) = self.trait_item_id() { + if let Some(trait_item_id) = self.local_trait_item_id() { tracing::info!("trait item found: {:?}", trait_item_id); let trait_require_annot = self.ctx.extract_require_annot( trait_item_id, @@ -364,6 +381,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); + let trait_item_ty = self.trait_item_ty(); + let is_fully_annotated = self.is_fully_annotated(); + let mut builder = self.type_builder.for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { @@ -387,11 +407,18 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { builder.ret_rty(ret_rty); } - // Note that we do not expect predicate variables to be generated here - // when type params are still present in the type. Callers should ensure either - // - type params are fully instantiated, or - // - the function is fully annotated - rty::RefinedType::unrefined(builder.build().into()) + if is_fully_annotated { + let expected_ty = builder.build().into(); + if let Some(trait_item_ty) = trait_item_ty { + let clauses = rty::relate_sub_closed_type(&expected_ty, &trait_item_ty.ty); + self.ctx.extend_clauses(clauses); + } + rty::RefinedType::unrefined(expected_ty) + } else if let Some(trait_item_ty) = trait_item_ty { + trait_item_ty + } else { + rty::RefinedType::unrefined(builder.build().into()) + } } /// Extract the target DefId from `#[thrust::extern_spec_fn]` function. diff --git a/std.rs b/std.rs index fb0abcd..7cba080 100644 --- a/std.rs +++ b/std.rs @@ -63,7 +63,7 @@ mod thrust_models { } #[thrust::def::mut_model] - pub struct Mut(PhantomData); + pub struct Mut(PhantomData); impl Mut { #[allow(dead_code)] @@ -100,7 +100,7 @@ mod thrust_models { } #[thrust::def::box_model] - pub struct Box(PhantomData); + pub struct Box(PhantomData); impl Box { #[allow(dead_code)] @@ -128,7 +128,7 @@ mod thrust_models { } #[thrust::def::array_model] - pub struct Array(PhantomData, PhantomData); + pub struct Array(PhantomData, PhantomData); impl PartialEq for Array where U: super::Model { #[thrust::ignored] @@ -156,9 +156,9 @@ mod thrust_models { } #[thrust::def::closure_model] - pub struct Closure(PhantomData); + pub struct Closure(PhantomData); - pub struct Vec(pub Array, pub Int); + pub struct Vec(pub Array, pub Int); impl PartialEq for Vec where U: super::Model { #[thrust::ignored] @@ -200,7 +200,7 @@ mod thrust_models { type Ty = bool; } - impl Model for model::Closure { + impl Model for model::Closure { type Ty = model::Closure; } @@ -224,27 +224,27 @@ mod thrust_models { impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8); impl_tuple_model!(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9); - impl<'a, T> Model for &'a mut T where T: Model { + impl<'a, T: ?Sized> Model for &'a mut T where T: Model { type Ty = model::Mut<::Ty>; } - impl Model for model::Mut { + impl Model for model::Mut { type Ty = model::Mut; } - impl<'a, T> Model for &'a T where T: Model { + impl<'a, T: ?Sized> Model for &'a T where T: Model { type Ty = &'a ::Ty; } - impl Model for Box where T: Model { + impl Model for Box where T: Model { type Ty = model::Box<::Ty>; } - impl Model for model::Box { + impl Model for model::Box { type Ty = model::Box; } - impl Model for model::Array { + impl Model for model::Array { type Ty = model::Array; } @@ -252,7 +252,7 @@ mod thrust_models { type Ty = model::Vec<::Ty>; } - impl Model for model::Vec { + impl Model for model::Vec { type Ty = model::Vec; } diff --git a/tests/ui/fail/annot_preds_trait.rs b/tests/ui/fail/annot_preds_trait.rs index f3b0a2c..d6e9a28 100644 --- a/tests/ui/fail/annot_preds_trait.rs +++ b/tests/ui/fail/annot_preds_trait.rs @@ -2,6 +2,7 @@ //@compile-flags: -Adead_code -C debug-assertions=off // A is represented as Tuple in SMT-LIB2 format. +#[derive(PartialEq)] struct A { x: i64, } @@ -10,25 +11,27 @@ impl thrust_models::Model for A { type Ty = Self; } +#[thrust_macros::context] trait Double { // Support annotations in trait definitions - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool; // This annotations are applied to all implementors of the `Double` trait. - #[thrust::requires(true)] - #[thrust::ensures(Self::is_double(*self, ^self))] + #[thrust_macros::requires(true)] + #[thrust_macros::ensures(Self::is_double(*self, !self))] fn double(&mut self); } +#[thrust_macros::context] impl Double for A { // Write concrete definitions for predicates in `impl` blocks - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // (tuple_proj.0 self) is equivalent to self.x // self.x * 3 == doubled.x (this isn't actually doubled!) is written as following: "(= - (* (tuple_proj.0 self) 3) + (* (tuple_proj.0 self_) 3) (tuple_proj.0 doubled) )"; true // This definition does not comply with annotations in trait! } diff --git a/tests/ui/fail/annot_preds_trait_multi.rs b/tests/ui/fail/annot_preds_trait_multi.rs index 713f3e2..b1156e5 100644 --- a/tests/ui/fail/annot_preds_trait_multi.rs +++ b/tests/ui/fail/annot_preds_trait_multi.rs @@ -1,14 +1,15 @@ //@error-in-other-file: Unsat //@compile-flags: -Adead_code -C debug-assertions=off +#[thrust_macros::context] trait Double { // Support annotations in trait definitions - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool; // This annotations are applied to all implementors of the `Double` trait. - #[thrust::requires(true)] - #[thrust::ensures(Self::is_double(*self, ^self))] + #[thrust_macros::requires(true)] + #[thrust_macros::ensures(Self::is_double(*self, !self))] fn double(&mut self); } @@ -21,12 +22,13 @@ impl thrust_models::Model for A { type Ty = Self; } +#[thrust_macros::context] impl Double for A { - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // self.x * 2 == doubled.x "(= - (* (tuple_proj.0 self) 2) + (* (tuple_proj.0 self_) 2) (tuple_proj.0 doubled) )"; true } @@ -46,17 +48,18 @@ impl thrust_models::Model for B { type Ty = Self; } +#[thrust_macros::context] impl Double for B { - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // self.x * 3 == doubled.x && self.y * 2 == doubled.y (this isn't actually doubled!) "(and (= - (* (tuple_proj.0 self) 3) + (* (tuple_proj.0 self_) 3) (tuple_proj.0 doubled) ) (= - (* (tuple_proj.1 self) 2) + (* (tuple_proj.1 self_) 2) (tuple_proj.1 doubled) ) )"; true // This definition does not comply with annotations in trait! diff --git a/tests/ui/fail/iterators/annot_range_loop.rs b/tests/ui/fail/iterators/annot_range_loop.rs index 356b43c..fa8f544 100644 --- a/tests/ui/fail/iterators/annot_range_loop.rs +++ b/tests/ui/fail/iterators/annot_range_loop.rs @@ -2,19 +2,20 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper THRUST_SOLVER_TIMEOUT_SECS=60 +#[thrust_macros::context] trait Iterator { type Item; - #[thrust::requires(true)] - #[thrust::ensures( - (Self::completed(*self) || (exists i:int. (result == std::option::Option::::Some(i)) && Self::step(*self, i, ^self))) - && (!Self::completed(*self) || (result == std::option::Option::::None() && *self == ^self)) + #[thrust_macros::ensures( + Self::completed(*self) + || thrust_models::exists(|i| (result == Some(i)) && Self::step(*self, i, !self)) )] + #[thrust_macros::ensures(!Self::completed(*self) || (result == None && *self == !self))] fn next(&mut self) -> Option; - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool; - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool; } @@ -27,6 +28,7 @@ impl thrust_models::Model for Range { type Ty = Range; } +#[thrust_macros::context] impl Iterator for Range { type Item = i64; @@ -40,25 +42,25 @@ impl Iterator for Range { } } - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool { // (tuple_proj.0 self) is equivalent to self.start // !(self.start < self.end) is written as following: "(not (< - (tuple_proj.0 self) - (tuple_proj.1 self) + (tuple_proj.0 self_) + (tuple_proj.1 self_) ))"; true } - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool { // self.end == dist.end && self.start == item && self.start + 1 == dist.start // is written as following: "(and - (= (tuple_proj.1 self) (tuple_proj.1 dist)) - (= (tuple_proj.0 self) item) - (= (+ (tuple_proj.0 self) 1) (tuple_proj.0 dist)) + (= (tuple_proj.1 self_) (tuple_proj.1 dist)) + (= (tuple_proj.0 self_) item) + (= (+ (tuple_proj.0 self_) 1) (tuple_proj.0 dist)) )"; true } diff --git a/tests/ui/fail/iterators/annot_range_next.rs b/tests/ui/fail/iterators/annot_range_next.rs index 1143c1b..a7cdfc0 100644 --- a/tests/ui/fail/iterators/annot_range_next.rs +++ b/tests/ui/fail/iterators/annot_range_next.rs @@ -2,19 +2,20 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper +#[thrust_macros::context] trait Iterator { type Item; - #[thrust::requires(true)] - #[thrust::ensures( - (Self::completed(*self) || (exists i:int. (result == std::option::Option::::Some(i)) && Self::step(*self, i, ^self))) - && (!Self::completed(*self) || (result == std::option::Option::::None() && *self == ^self)) + #[thrust_macros::ensures( + Self::completed(*self) + || thrust_models::exists(|i| (result == Some(i)) && Self::step(*self, i, !self)) )] + #[thrust_macros::ensures(!Self::completed(*self) || (result == None && *self == !self))] fn next(&mut self) -> Option; - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool; - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool; } @@ -27,6 +28,7 @@ impl thrust_models::Model for Range { type Ty = Range; } +#[thrust_macros::context] impl Iterator for Range { type Item = i64; @@ -40,25 +42,25 @@ impl Iterator for Range { } } - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool { // (tuple_proj.0 self) is equivalent to self.start // !(self.start < self.end) is written as following: "(not (< - (tuple_proj.0 self) - (tuple_proj.1 self) + (tuple_proj.0 self_) + (tuple_proj.1 self_) ))"; true } - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool { // self.end == dist.end && self.start == item && self.start + 1 == dist.start // is written as following: "(and - (= (tuple_proj.1 self) (tuple_proj.1 dist)) - (= (tuple_proj.0 self) item) - (= (+ (tuple_proj.0 self) 1) (tuple_proj.0 dist)) + (= (tuple_proj.1 self_) (tuple_proj.1 dist)) + (= (tuple_proj.0 self_) item) + (= (+ (tuple_proj.0 self_) 1) (tuple_proj.0 dist)) )"; true } diff --git a/tests/ui/pass/annot_preds_trait.rs b/tests/ui/pass/annot_preds_trait.rs index b16b5e8..45f26ab 100644 --- a/tests/ui/pass/annot_preds_trait.rs +++ b/tests/ui/pass/annot_preds_trait.rs @@ -2,6 +2,7 @@ //@compile-flags: -Adead_code -C debug-assertions=off // A is represented as Tuple in SMT-LIB2 format. +#[derive(PartialEq)] struct A { x: i64, } @@ -10,25 +11,27 @@ impl thrust_models::Model for A { type Ty = Self; } +#[thrust_macros::context] trait Double { // Support annotations in trait definitions - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool; // This annotations are applied to all implementors of the `Double` trait. - #[thrust::requires(true)] - #[thrust::ensures(Self::is_double(*self, ^self))] + #[thrust_macros::requires(true)] + #[thrust_macros::ensures(Self::is_double(*self, !self))] fn double(&mut self); } +#[thrust_macros::context] impl Double for A { // Write concrete definitions for predicates in `impl` blocks - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // (tuple_proj.0 self) is equivalent to self.x // self.x * 2 == doubled.x is written as following: "(= - (* (tuple_proj.0 self) 2) + (* (tuple_proj.0 self_) 2) (tuple_proj.0 doubled) )"; true } diff --git a/tests/ui/pass/annot_preds_trait_multi.rs b/tests/ui/pass/annot_preds_trait_multi.rs index 9467b16..326e5f0 100644 --- a/tests/ui/pass/annot_preds_trait_multi.rs +++ b/tests/ui/pass/annot_preds_trait_multi.rs @@ -1,14 +1,15 @@ //@check-pass //@compile-flags: -Adead_code -C debug-assertions=off +#[thrust_macros::context] trait Double { // Support annotations in trait definitions - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool; // This annotations are applied to all implementors of the `Double` trait. - #[thrust::requires(true)] - #[thrust::ensures(Self::is_double(*self, ^self))] + #[thrust_macros::requires(true)] + #[thrust_macros::ensures(Self::is_double(*self, !self))] fn double(&mut self); } @@ -21,12 +22,13 @@ impl thrust_models::Model for A { type Ty = Self; } +#[thrust_macros::context] impl Double for A { - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // self.x * 2 == doubled.x "(= - (* (tuple_proj.0 self) 2) + (* (tuple_proj.0 self_) 2) (tuple_proj.0 doubled) )"; true } @@ -46,17 +48,18 @@ impl thrust_models::Model for B { type Ty = Self; } +#[thrust_macros::context] impl Double for B { - #[thrust::predicate] + #[thrust_macros::predicate] fn is_double(self, doubled: Self) -> bool { // self.x * 2 == doubled.x && self.y * 2 == doubled.y "(and (= - (* (tuple_proj.0 self) 2) + (* (tuple_proj.0 self_) 2) (tuple_proj.0 doubled) ) (= - (* (tuple_proj.1 self) 2) + (* (tuple_proj.1 self_) 2) (tuple_proj.1 doubled) ) )"; true diff --git a/tests/ui/pass/iterators/annot_range_loop.rs b/tests/ui/pass/iterators/annot_range_loop.rs index a0cfd28..2c09bff 100644 --- a/tests/ui/pass/iterators/annot_range_loop.rs +++ b/tests/ui/pass/iterators/annot_range_loop.rs @@ -2,19 +2,20 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper THRUST_SOLVER_TIMEOUT_SECS=60 +#[thrust_macros::context] trait Iterator { type Item; - #[thrust::requires(true)] - #[thrust::ensures( - (Self::completed(*self) || (exists i:int. (result == std::option::Option::::Some(i)) && Self::step(*self, i, ^self))) - && (!Self::completed(*self) || (result == std::option::Option::::None() && *self == ^self)) + #[thrust_macros::ensures( + Self::completed(*self) + || thrust_models::exists(|i| (result == Some(i)) && Self::step(*self, i, !self)) )] + #[thrust_macros::ensures(!Self::completed(*self) || (result == None && *self == !self))] fn next(&mut self) -> Option; - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool; - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool; } @@ -27,6 +28,7 @@ impl thrust_models::Model for Range { type Ty = Range; } +#[thrust_macros::context] impl Iterator for Range { type Item = i64; @@ -40,25 +42,25 @@ impl Iterator for Range { } } - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool { // (tuple_proj.0 self) is equivalent to self.start // !(self.start < self.end) is written as following: "(not (< - (tuple_proj.0 self) - (tuple_proj.1 self) + (tuple_proj.0 self_) + (tuple_proj.1 self_) ))"; true } - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool { // self.end == dist.end && self.start == item && self.start + 1 == dist.start // is written as following: "(and - (= (tuple_proj.1 self) (tuple_proj.1 dist)) - (= (tuple_proj.0 self) item) - (= (+ (tuple_proj.0 self) 1) (tuple_proj.0 dist)) + (= (tuple_proj.1 self_) (tuple_proj.1 dist)) + (= (tuple_proj.0 self_) item) + (= (+ (tuple_proj.0 self_) 1) (tuple_proj.0 dist)) )"; true } diff --git a/tests/ui/pass/iterators/annot_range_next.rs b/tests/ui/pass/iterators/annot_range_next.rs index 6cafdd3..d0ecc0a 100644 --- a/tests/ui/pass/iterators/annot_range_next.rs +++ b/tests/ui/pass/iterators/annot_range_next.rs @@ -2,19 +2,20 @@ //@compile-flags: -C debug-assertions=off //@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper +#[thrust_macros::context] trait Iterator { type Item; - #[thrust::requires(true)] - #[thrust::ensures( - (Self::completed(*self) || (exists i:int. (result == std::option::Option::::Some(i)) && Self::step(*self, i, ^self))) - && (!Self::completed(*self) || (result == std::option::Option::::None() && *self == ^self)) + #[thrust_macros::ensures( + Self::completed(*self) + || thrust_models::exists(|i| (result == Some(i)) && Self::step(*self, i, !self)) )] + #[thrust_macros::ensures(!Self::completed(*self) || (result == None && *self == !self))] fn next(&mut self) -> Option; - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool; - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool; } @@ -27,6 +28,7 @@ impl thrust_models::Model for Range { type Ty = Range; } +#[thrust_macros::context] impl Iterator for Range { type Item = i64; @@ -40,25 +42,25 @@ impl Iterator for Range { } } - #[thrust::predicate] + #[thrust_macros::predicate] fn completed(self) -> bool { // (tuple_proj.0 self) is equivalent to self.start // !(self.start < self.end) is written as following: "(not (< - (tuple_proj.0 self) - (tuple_proj.1 self) + (tuple_proj.0 self_) + (tuple_proj.1 self_) ))"; true } - #[thrust::predicate] + #[thrust_macros::predicate] fn step(self, item: Self::Item, dist: Self) -> bool { // self.end == dist.end && self.start == item && self.start + 1 == dist.start // is written as following: "(and - (= (tuple_proj.1 self) (tuple_proj.1 dist)) - (= (tuple_proj.0 self) item) - (= (+ (tuple_proj.0 self) 1) (tuple_proj.0 dist)) + (= (tuple_proj.1 self_) (tuple_proj.1 dist)) + (= (tuple_proj.0 self_) item) + (= (+ (tuple_proj.0 self_) 1) (tuple_proj.0 dist)) )"; true } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index c5ccf9d..2505d36 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -2,39 +2,221 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ - parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, ItemFn, ReturnType, - Type, TypeParamBound, WherePredicate, + parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, + WherePredicate, }; +#[derive(Debug, Clone)] +enum FnOuterItem { + ItemImpl(syn::ItemImpl), + ItemTrait(syn::ItemTrait), +} + +impl syn::parse::Parse for FnOuterItem { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + use syn::parse::discouraged::Speculative as _; + + let fork = input.fork(); + if let Ok(item_impl) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemImpl(item_impl)); + } + + let fork = input.fork(); + if let Ok(item_trait) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemTrait(item_trait)); + } + + Err(input.error("expected an impl block or a trait definition")) + } +} + +impl quote::ToTokens for FnOuterItem { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + FnOuterItem::ItemImpl(item_impl) => item_impl.to_tokens(tokens), + FnOuterItem::ItemTrait(item_trait) => item_trait.to_tokens(tokens), + } + } +} + +impl FnOuterItem { + fn into_header_only(mut self) -> Self { + match &mut self { + FnOuterItem::ItemImpl(item_impl) => { + item_impl.items.clear(); + } + FnOuterItem::ItemTrait(item_trait) => { + item_trait.items.clear(); + } + } + self + } + + fn generics(&self) -> &Generics { + match self { + FnOuterItem::ItemImpl(item_impl) => &item_impl.generics, + FnOuterItem::ItemTrait(item_trait) => &item_trait.generics, + } + } +} + #[proc_macro_attribute] pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { - let mut impl_item = syn::parse_macro_input!(item as syn::ItemImpl); - let impl_header = { - let mut header = impl_item.clone(); - header.items.clear(); - header + let mut outer_item = syn::parse_macro_input!(item as FnOuterItem); + let outer_header = outer_item.clone().into_header_only(); + match &mut outer_item { + FnOuterItem::ItemImpl(item_impl) => { + for item in &mut item_impl.items { + let syn::ImplItem::Fn(item) = item else { + continue; + }; + item.attrs + .push(syn::parse_quote!(#[thrust::_outer_context(#outer_header)])); + } + } + FnOuterItem::ItemTrait(item_trait) => { + for item in &mut item_trait.items { + let syn::TraitItem::Fn(item) = item else { + continue; + }; + item.attrs + .push(syn::parse_quote!(#[thrust::_outer_context(#outer_header)])); + } + } + } + + outer_item.into_token_stream().into() +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +enum FnItemWithSignature { + ItemFn(syn::ItemFn), + ImplItemFn(syn::ImplItemFn), + TraitItemFn(syn::TraitItemFn), +} + +impl syn::parse::Parse for FnItemWithSignature { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + use syn::parse::discouraged::Speculative as _; + + let fork = input.fork(); + if let Ok(item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemFn(item_fn)); + } + + let fork = input.fork(); + if let Ok(impl_item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ImplItemFn(impl_item_fn)); + } + + let fork = input.fork(); + if let Ok(trait_item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::TraitItemFn(trait_item_fn)); + } + + Err(input.error("expected a free function, an impl method, or a trait method")) + } +} + +impl quote::ToTokens for FnItemWithSignature { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + FnItemWithSignature::ItemFn(item_fn) => item_fn.to_tokens(tokens), + FnItemWithSignature::ImplItemFn(impl_item_fn) => impl_item_fn.to_tokens(tokens), + FnItemWithSignature::TraitItemFn(trait_item_fn) => trait_item_fn.to_tokens(tokens), + } + } +} + +impl FnItemWithSignature { + fn block(&self) -> Option<&syn::Block> { + match self { + FnItemWithSignature::ItemFn(item_fn) => Some(&item_fn.block), + FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&impl_item_fn.block), + FnItemWithSignature::TraitItemFn(_) => None, + } + } + + fn block_mut(&mut self) -> Option<&mut syn::Block> { + match self { + FnItemWithSignature::ItemFn(item_fn) => Some(&mut item_fn.block), + FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&mut impl_item_fn.block), + FnItemWithSignature::TraitItemFn(_) => None, + } + } + + fn attrs(&self) -> &[syn::Attribute] { + match self { + FnItemWithSignature::ItemFn(item_fn) => &item_fn.attrs, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.attrs, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.attrs, + } + } + + fn attrs_mut(&mut self) -> &mut Vec { + match self { + FnItemWithSignature::ItemFn(item_fn) => &mut item_fn.attrs, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &mut impl_item_fn.attrs, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &mut trait_item_fn.attrs, + } + } + + fn sig(&self) -> &syn::Signature { + match self { + FnItemWithSignature::ItemFn(item_fn) => &item_fn.sig, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.sig, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.sig, + } + } +} + +#[proc_macro_attribute] +pub fn predicate(_attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; - for item in &mut impl_item.items { - let syn::ImplItem::Fn(item) = item else { - continue; - }; - // TODO: why ::thrust_macros doesn't work here? - item.attrs - .push(syn::parse_quote!(#[thrust::_impl_context(#impl_header)])); + + let name = &func.sig().ident; + let def_generics = generic_params_tokens(&func.sig().generics); + let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); + let model_ret = fn_return_ty_with_model_ty(&func.sig().output); + + let model_preds = model_where_predicates(&func, outer_context.as_ref()); + let extended_where = extended_where_clause(&func, &model_preds); + + let sig = quote! { + #[allow(dead_code)] + #[thrust::predicate] + fn #name #def_generics(#model_ty_params) -> #model_ret #extended_where + }; + if let Some(block) = func.block() { + quote! { #sig #block }.into() + } else { + quote! { #sig; }.into() } - impl_item.into_token_stream().into() } #[proc_macro_attribute] pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { let expr = TokenStream2::from(attr); - let mut func = parse_macro_input!(item as ItemFn); + let mut func = parse_macro_input!(item as FnItemWithSignature); let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { Ok((req, ens)) => (req, ens), Err(e) => return e.to_compile_error().into(), }; - func.attrs.push(syn::parse_quote!( + func.attrs_mut().push(syn::parse_quote!( #[::thrust_macros::_requires_ensures((#req_expr) && (#expr), #ens_expr)] )); @@ -44,25 +226,25 @@ pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream { let expr = TokenStream2::from(attr); - let mut func = parse_macro_input!(item as ItemFn); + let mut func = parse_macro_input!(item as FnItemWithSignature); let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { Ok((req, ens)) => (req, ens), Err(e) => return e.to_compile_error().into(), }; - func.attrs.push(syn::parse_quote!( + func.attrs_mut().push(syn::parse_quote!( #[::thrust_macros::_requires_ensures(#req_expr, (#ens_expr) && (#expr))] )); func.into_token_stream().into() } -fn extract_requires_ensures(func: &mut ItemFn) -> syn::Result<(syn::Expr, syn::Expr)> { +fn extract_requires_ensures(func: &mut FnItemWithSignature) -> syn::Result<(syn::Expr, syn::Expr)> { let mut result = None; let requires_ensures_path: syn::Path = syn::parse_quote!(::thrust_macros::_requires_ensures); - for attr in &func.attrs { + for attr in func.attrs() { if attr.path() == &requires_ensures_path { if result.is_some() { return Err(syn::Error::new_spanned( @@ -85,7 +267,7 @@ fn extract_requires_ensures(func: &mut ItemFn) -> syn::Result<(syn::Expr, syn::E } } - func.attrs + func.attrs_mut() .retain(|attr| attr.path() != &requires_ensures_path); if let Some((req_expr, ens_expr)) = result { @@ -115,48 +297,49 @@ pub fn _requires_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { let ens_expr = exprs.pop().unwrap().into_value(); let req_expr = exprs.pop().unwrap().into_value(); - let func = parse_macro_input!(item as ItemFn); - let impl_context = match extract_impl_context(&func) { + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, - Err(e) => return e.to_compile_error().into(), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; - if mentions_self(&func.sig) && impl_context.is_none() { - let err = syn::Error::new_spanned( - func.sig.ident.clone(), - "Wrap impl block with #[thrust_macros::context] to use requires/ensures on methods", - ) - .to_compile_error(); - return quote! { #err #func }.into(); - } let mut tokens = ExpandedTokens::new(func, req_expr, ens_expr); - if let Some(ctx) = impl_context { - tokens = tokens.with_impl_context(ctx); + if let Some(ctx) = outer_context { + tokens = tokens.with_outer_context(ctx); } tokens.into_token_stream().into() } -fn extract_impl_context(func: &syn::ItemFn) -> syn::Result> { - let impl_context_path: syn::Path = syn::parse_quote!(thrust::_impl_context); - let mut impl_context = None; - for attr in &func.attrs { - if attr.path() != &impl_context_path { +fn extract_outer_context(func: &FnItemWithSignature) -> syn::Result> { + let outer_context_path: syn::Path = syn::parse_quote!(thrust::_outer_context); + let mut outer_context = None; + for attr in func.attrs() { + if attr.path() != &outer_context_path { continue; } let item = attr.parse_args()?; - if impl_context.is_some() { + if outer_context.is_some() { return Err(syn::Error::new_spanned( attr, - "multiple _impl_context attributes found; expected at most one", + "multiple _outer_context attributes found; expected at most one", )); } - impl_context = Some(item); + outer_context = Some(item); + } + if mentions_self(func.sig()) && outer_context.is_none() { + return Err(syn::Error::new_spanned( + func.sig().ident.clone(), + "Wrap impl block with #[thrust_macros::context] to annotate methods", + )); } - Ok(impl_context) + Ok(outer_context) } struct ExpandedTokens { - func: ItemFn, + func: FnItemWithSignature, requires_name: syn::Ident, ensures_name: syn::Ident, @@ -167,9 +350,9 @@ struct ExpandedTokens { turbofish: TokenStream2, model_ty_params: TokenStream2, - ret_model_ty: Type, + ret_model_ty: syn::Type, - impl_context: Option, + outer_context: Option, } impl quote::ToTokens for ExpandedTokens { @@ -183,22 +366,23 @@ impl quote::ToTokens for ExpandedTokens { } impl ExpandedTokens { - pub fn new(func: ItemFn, mut req_expr: syn::Expr, mut ens_expr: syn::Expr) -> Self { - let name = &func.sig.ident; + pub fn new( + func: FnItemWithSignature, + mut req_expr: syn::Expr, + mut ens_expr: syn::Expr, + ) -> Self { + let name = &func.sig().ident; let requires_name = format_ident!("_thrust_requires_{}", name); let ensures_name = format_ident!("_thrust_ensures_{}", name); - let generics = &func.sig.generics; + let generics = &func.sig().generics; let def_generics = generic_params_tokens(generics); let turbofish = generic_turbofish(generics); - let model_ty_params = fn_params_with_model_ty(&func.sig.inputs); - let ret_model_ty: Type = match &func.sig.output { - ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), - ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), - }; + let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); + let ret_model_ty = fn_return_ty_with_model_ty(&func.sig().output); - if func.sig.receiver().is_some() { + if func.sig().receiver().is_some() { rewrite_self_in_expr(&mut req_expr); rewrite_self_in_expr(&mut ens_expr); } @@ -213,78 +397,24 @@ impl ExpandedTokens { turbofish, model_ty_params, ret_model_ty, - impl_context: None, + outer_context: None, } } - pub fn with_impl_context(mut self, impl_item: syn::ItemImpl) -> Self { - self.impl_context = Some(impl_item); + pub fn with_outer_context(mut self, outer_item: FnOuterItem) -> Self { + self.outer_context = Some(outer_item); self } - /// Returns `T: thrust_models::Model` predicates for every type param that does not - /// already carry an `Fn`, `FnOnce`, or `FnMut` bound. - fn model_where_predicates(&self) -> Vec { - let mut generic_type_params: Vec<&syn::TypeParam> = Vec::new(); - for param in &self.func.sig.generics.params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp); - } - if let Some(impl_item) = &self.impl_context { - for param in &impl_item.generics.params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp); - } - } - - let mut predicates: Vec = Vec::new(); - for param in generic_type_params { - let has_fn_bound = param.bounds.iter().any(|b| { - let TypeParamBound::Trait(tb) = b else { - return false; - }; - tb.path.segments.last().map_or(false, |s| { - matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") - }) - }); - if has_fn_bound { - continue; - } - let ident = ¶m.ident; - predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); - predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); - } - predicates - } - - /// Builds `where , `. - /// Returns an empty token stream when both sets are empty. fn extended_where_clause(&self) -> TokenStream2 { - let model_preds = self.model_where_predicates(); - let existing: Vec<&WherePredicate> = self - .func - .sig - .generics - .where_clause - .as_ref() - .map(|wc| wc.predicates.iter().collect()) - .unwrap_or_default(); - - if existing.is_empty() && model_preds.is_empty() { - return quote!(); - } - - quote! { where #(#existing,)* #(#model_preds),* } + let model_preds = model_where_predicates(&self.func, self.outer_context.as_ref()); + extended_where_clause(&self.func, &model_preds) } fn is_extern_spec_fn(&self) -> bool { let extern_spec_fn_path: syn::Path = syn::parse_quote!(thrust::extern_spec_fn); self.func - .attrs + .attrs() .iter() .any(|a| a.path() == &extern_spec_fn_path) } @@ -325,14 +455,14 @@ impl ExpandedTokens { } fn path_prefix(&self) -> Option { - self.impl_context.as_ref()?; + self.outer_context.as_ref()?; Some(quote!(Self::)) } fn expand(&self) -> TokenStream2 { let mut func = self.func.clone(); let trusted_path: syn::Path = syn::parse_quote!(thrust::trusted); - for attr in &mut func.attrs { + for attr in func.attrs_mut() { if attr.path() == &trusted_path { *attr = syn::parse_quote!(#[thrust::ignored]); } @@ -341,9 +471,9 @@ impl ExpandedTokens { let requires_fn = self.requires_fn(); let ensures_fn = self.ensures_fn(); - let extern_spec_name = format_ident!("_thrust_extern_spec_{}", self.func.sig.ident); + let extern_spec_name = format_ident!("_thrust_extern_spec_{}", self.func.sig().ident); let def_generics = &self.def_generics; - let orig_output = &self.func.sig.output; + let orig_output = &self.func.sig().output; let extended_where = self.extended_where_clause(); let requires_name = &self.requires_name; @@ -351,8 +481,8 @@ impl ExpandedTokens { let turbofish = &self.turbofish; let path_prefix = self.path_prefix(); - let name = &self.func.sig.ident; - let (extern_spec_inputs, call_args) = rewrite_inputs_for_call(&self.func.sig.inputs); + let name = &self.func.sig().ident; + let (extern_spec_inputs, call_args) = rewrite_inputs_for_call(&self.func.sig().inputs); quote! { #func @@ -381,16 +511,32 @@ impl ExpandedTokens { let path_prefix = self.path_prefix(); let mut func = self.func.clone(); - let orig_stmts = func.block.stmts.clone(); - func.block = syn::parse_quote!({ - #[thrust::requires_path] - #path_prefix #requires_name #turbofish; + let func_tokens = if let Some(block) = func.block_mut() { + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #[thrust::requires_path] + #path_prefix #requires_name #turbofish; - #[thrust::ensures_path] - #path_prefix #ensures_name #turbofish; + #[thrust::ensures_path] + #path_prefix #ensures_name #turbofish; - #(#orig_stmts)* - }); + #(#orig_stmts)* + }); + quote! { + #[allow(path_statements)] + #func + } + } else { + let error = syn::Error::new_spanned( + func.sig().ident.clone(), + "extern_spec_fn must have a function body", + ) + .into_compile_error(); + quote! { + #error + #func + } + }; let requires_fn = self.requires_fn(); let ensures_fn = self.ensures_fn(); @@ -399,8 +545,7 @@ impl ExpandedTokens { #requires_fn #ensures_fn - #[allow(path_statements)] - #func + #func_tokens } } } @@ -515,3 +660,137 @@ fn rewrite_inputs_for_call( (quote!(#(#rewritten),*), quote!(#(#call_args),*)) } + +/// Returns `T: thrust_models::Model` predicates for every type param that does not +/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. +fn model_where_predicates( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, +) -> Vec { + struct GenericTypeParam { + ident: syn::Ident, + bounds: Vec, + } + + impl From for GenericTypeParam { + fn from(tp: syn::TypeParam) -> Self { + Self { + ident: tp.ident, + bounds: tp.bounds.into_iter().collect(), + } + } + } + + impl GenericTypeParam { + fn has_fn_bound(&self) -> bool { + self.bounds.iter().any(|b| { + let TypeParamBound::Trait(tb) = b else { + return false; + }; + tb.path.segments.last().map_or(false, |s| { + matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") + }) + }) + } + } + + let mut generic_type_params: Vec = Vec::new(); + for param in &func.sig().generics.params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let Some(outer_item) = outer_context { + for param in &outer_item.generics().params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let FnOuterItem::ItemTrait(outer_item) = &outer_item { + generic_type_params.push(GenericTypeParam { + ident: format_ident!("Self"), + bounds: outer_item.supertraits.iter().cloned().collect(), + }); + } + } + generic_type_params.retain(|p| !p.has_fn_bound()); + + let mut predicates: Vec = Vec::new(); + for param in &generic_type_params { + let ident = ¶m.ident; + predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); + predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); + } + + struct Visitor { + generic_type_params: Vec, + generic_paths: Vec, + } + + impl syn::visit::Visit<'_> for Visitor { + fn visit_type_path(&mut self, tp: &syn::TypePath) { + for param in &self.generic_type_params { + if let Some(qself) = &tp.qself { + let param = ¶m.ident; + let param_ty: syn::Type = syn::parse_quote!(#param); + if *qself.ty == param_ty { + self.generic_paths.push(tp.clone()); + } + } + if tp.path.segments.len() > 1 + && tp.path.segments.first().unwrap().ident == param.ident + && tp.qself.is_none() + { + self.generic_paths.push(tp.clone()); + } + } + syn::visit::visit_type_path(self, tp); + } + } + + let mut visitor = Visitor { + generic_type_params, + generic_paths: Vec::new(), + }; + use syn::visit::Visit as _; + for arg in &func.sig().inputs { + visitor.visit_fn_arg(arg); + } + visitor.visit_return_type(&func.sig().output); + for tp in visitor.generic_paths { + predicates.push(syn::parse_quote!(#tp: thrust_models::Model)); + predicates.push(syn::parse_quote!(<#tp as thrust_models::Model>::Ty: PartialEq)); + } + + predicates +} + +/// Builds `where , `. +/// Returns an empty token stream when both sets are empty. +fn extended_where_clause( + func: &FnItemWithSignature, + model_preds: &Vec, +) -> TokenStream2 { + let existing: Vec<&WherePredicate> = func + .sig() + .generics + .where_clause + .as_ref() + .map(|wc| wc.predicates.iter().collect()) + .unwrap_or_default(); + + if existing.is_empty() && model_preds.is_empty() { + return quote!(); + } + + quote! { where #(#existing,)* #(#model_preds),* } +} + +fn fn_return_ty_with_model_ty(ret: &syn::ReturnType) -> syn::Type { + match ret { + syn::ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), + syn::ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), + } +}