diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index ef3e96784ce86..c7a5135a24147 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -431,11 +431,9 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { }; let erased_trait_ref = upcast_trait_ref .map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r)); - assert!( - data_b - .principal() - .is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b)) - ); + assert!(data_b.principal().is_some_and(|b| { + self.relate_dyn_predicates_bivariantly(erased_trait_ref, b) + })); } else { // In this case codegen would keep using the old vtable. We don't want to do // that as it has the wrong trait. The reason codegen can do this is that diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs index 95a72d3cbc1d7..395db3343b55e 100644 --- a/compiler/rustc_const_eval/src/interpret/eval_context.rs +++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs @@ -12,6 +12,7 @@ use rustc_middle::query::TyCtxtAt; use rustc_middle::ty::layout::{ self, FnAbiError, FnAbiOfHelpers, FnAbiRequest, LayoutError, LayoutOfHelpers, TyAndLayout, }; +use rustc_middle::ty::relate::Relate; use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeFoldable, TypingEnv, Variance}; use rustc_middle::{mir, span_bug}; use rustc_session::Limit; @@ -323,12 +324,25 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } } - /// Check if the two things are equal in the current param_env, using an infcx to get proper - /// equality checks. + /// Checks whether the predicates of two trait objects are sufficiently equal to allow + /// transmutes between them. + /// + /// This is called on [`ty::PolyExistentialTraitRef`] or [`ty::PolyExistentialProjection`] + /// and checks whether there is a subtyping relation between them in either direction. + /// + /// # Examples + /// + /// - returns `true` for `dyn for<'a> Trait` and `dyn Trait` + /// - returns `false` for `dyn Trait fn(&'a u8)>` and either of the above #[instrument(level = "trace", skip(self), ret)] - pub(super) fn eq_in_param_env(&self, a: T, b: T) -> bool + pub(super) fn relate_dyn_predicates_bivariantly( + &self, + a: ty::Binder<'tcx, T>, + b: ty::Binder<'tcx, T>, + ) -> bool where - T: PartialEq + TypeFoldable> + ToTrace<'tcx>, + T: PartialEq + Copy + TypeFoldable> + Relate>, + ty::Binder<'tcx, T>: ToTrace<'tcx>, { // Fast path: compare directly. if a == b { @@ -338,11 +352,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env); let ocx = ObligationCtxt::new(&infcx); let cause = ObligationCause::dummy_with_span(self.cur_span()); + let trace = ToTrace::to_trace(&cause, a, b); + + // Instantiate the binder with erased instead of fresh vars, because in runtime MIR + // all free regions are erased anyway, so it doesn't make a difference. + let a = self.tcx.instantiate_bound_regions_with_erased(a); + let b = self.tcx.instantiate_bound_regions_with_erased(b); + // equate the two trait refs after normalization let a = ocx.normalize(&cause, param_env, a); let b = ocx.normalize(&cause, param_env, b); - if let Err(terr) = ocx.eq(&cause, param_env, a, b) { + if let Err(terr) = ocx.eq_trace(&cause, param_env, trace, a, b) { trace!(?terr); return false; } @@ -353,6 +374,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { return false; } + // Do a leak check to ensure that e.g. `for<'a> fn(&'a u8)` and `fn(&'static u8)` are not equal. + if let Err(terr) = infcx.leak_check(ty::UniverseIndex::ROOT, None) { + trace!(?terr, "failed leak check"); + return false; + } + // All good. true } diff --git a/compiler/rustc_const_eval/src/interpret/traits.rs b/compiler/rustc_const_eval/src/interpret/traits.rs index af8d618b6b5e5..be0a7e453b3a8 100644 --- a/compiler/rustc_const_eval/src/interpret/traits.rs +++ b/compiler/rustc_const_eval/src/interpret/traits.rs @@ -90,12 +90,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { ( ty::ExistentialPredicate::Trait(a_data), ty::ExistentialPredicate::Trait(b_data), - ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)), + ) => self.relate_dyn_predicates_bivariantly( + a_pred.rebind(a_data), + b_pred.rebind(b_data), + ), ( ty::ExistentialPredicate::Projection(a_data), ty::ExistentialPredicate::Projection(b_data), - ) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)), + ) => self.relate_dyn_predicates_bivariantly( + a_pred.rebind(a_data), + b_pred.rebind(b_data), + ), _ => false, }; diff --git a/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs new file mode 100644 index 0000000000000..f4d6a89122781 --- /dev/null +++ b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.rs @@ -0,0 +1,41 @@ +// Test that transmuting from `&dyn Dyn` to `&dyn Dyn fn(&'a ())>` is UB. +// +// The vtable of `() as Dyn` and `() as Dyn fn(&'a ())>` can have +// different entries and, because in the former the entry for `foo` is vacant, this test will +// segfault at runtime. + +trait Dyn { + fn foo(&self) + where + U: HigherRanked, + { + U::call() + } +} +impl Dyn for T {} + +trait HigherRanked { + fn call(); +} +impl HigherRanked for for<'a> fn(&'a ()) { + fn call() { + println!("higher ranked"); + } +} + +// 2nd candidate is required so that selecting `(): Dyn` will +// evaluate the candidates and fail the leak check instead of returning the +// only applicable candidate. +trait Unsatisfied {} +impl HigherRanked for T { + fn call() { + unreachable!(); + } +} + +fn main() { + let x: &dyn Dyn = &(); + let y: &dyn Dyn fn(&'a ())> = unsafe { std::mem::transmute(x) }; + //~^ ERROR: wrong trait in wide pointer vtable + y.foo(); +} diff --git a/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr new file mode 100644 index 0000000000000..604c36cab915b --- /dev/null +++ b/src/tools/miri/tests/fail/validity/dyn-trait-leak-check.stderr @@ -0,0 +1,15 @@ +error: Undefined Behavior: constructing invalid value: wrong trait in wide pointer vtable: expected `Dyn fn(&'a ())>`, but encountered `Dyn` + --> tests/fail/validity/dyn-trait-leak-check.rs:LL:CC + | +LL | let y: &dyn Dyn fn(&'a ())> = unsafe { std::mem::transmute(x) }; + | ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: wrong trait in wide pointer vtable: expected `Dyn fn(&'a ())>`, but encountered `Dyn` + | + = help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior + = help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information + = note: BACKTRACE: + = note: inside `main` at tests/fail/validity/dyn-trait-leak-check.rs:LL:CC + +note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace + +error: aborting due to 1 previous error + diff --git a/src/tools/miri/tests/pass/validity/dyn-trait-bivariant-transmutes.rs b/src/tools/miri/tests/pass/validity/dyn-trait-bivariant-transmutes.rs new file mode 100644 index 0000000000000..cbd2dd25f2222 --- /dev/null +++ b/src/tools/miri/tests/pass/validity/dyn-trait-bivariant-transmutes.rs @@ -0,0 +1,26 @@ +// Test that transmuting between subtypes of dyn traits is fine, even in the +// "wrong direction", i.e. going from a lower-ranked to a higher-ranked dyn trait. + +trait Dyn {} +impl Dyn for T {} + +struct Wrapper(T); + +fn main() { + let x: &dyn Dyn = &(); + let _y: &dyn for<'a> Dyn = unsafe { std::mem::transmute(x) }; + + let x: &dyn for<'a> Dyn = &(); + let _y: &dyn Dyn = unsafe { std::mem::transmute(x) }; + + let x: &dyn Dyn> = &(); + let _y: &dyn for<'a> Dyn> = unsafe { std::mem::transmute(x) }; + + let x: &dyn for<'a> Dyn> = &(); + let _y: &dyn Dyn> = unsafe { std::mem::transmute(x) }; + + // This lowers to a ptr-to-ptr cast (which behaves like a transmute) + // and not an unsizing coercion: + let x: *const dyn for<'a> Dyn<&'a ()> = &(); + let _y: *const Wrapper> = x as _; +}