diff options
| author | Linus Torvalds <torvalds@linux-foundation.org> | 2026-06-15 12:41:17 +0530 |
|---|---|---|
| committer | Linus Torvalds <torvalds@linux-foundation.org> | 2026-06-15 12:41:17 +0530 |
| commit | 36808d5e983985bbda87e01059cccc071fe3ec8d (patch) | |
| tree | fe89e03319a2eeacabced24f1e6b96123ad527f9 /rust | |
| parent | 5504ce0317f777dcf102751c3e518284226fc2e1 (diff) | |
| parent | fe221742e388bea3f5856b5d9b2cb0a037020ea4 (diff) | |
Merge tag 'driver-core-7.2-rc1' of gitolite.kernel.org:pub/scm/linux/kernel/git/driver-core/driver-core
Pull driver core updates from Danilo Krummrich:
"Deferred probe:
- Fix race where deferred probe timeout work could be permanently
canceled by using mod_delayed_work()
- Fix missing jiffies conversion in deferred_probe_extend_timeout()
- Guard timeout extension with delayed_work_pending() to prevent
premature firing
- Use system_percpu_wq instead of the deprecated system_wq
- Update deferred_probe_timeout documentation
device:
- Replace direct struct device bitfield access (can_match, dma_iommu,
dma_skip_sync, dma_ops_bypass, state_synced, dma_coherent,
of_node_reused, offline, offline_disabled) with flag-based
accessors using bit operations
- Reject devices with unregistered buses
- Delete unused DEVICE_ATTR_PREALLOC()
- Add low-level device attribute macros with const show/store
callbacks, allowing device attributes to reside in read-only memory
- Move core device attributes to read-only memory
- Constify group array pointers in driver_add_groups() /
driver_remove_groups(), struct bus_type, and struct device_driver
device property:
- Fix fwnode reference leak in fwnode_graph_get_endpoint_by_id()
- Initialize all fields of fwnode_handle in fwnode_init()
- Provide swnode_get()/swnode_put() wrappers around kobject_get/put()
- Allow passing struct software_node_ref_args pointers directly to
PROPERTY_ENTRY_REF()
driver_override:
- Migrate amba, cdx, vmbus, and rpmsg to the generic driver_override
infrastructure, fixing a UAF from unsynchronized access to
driver_override in bus match() callbacks
- Remove the now-unused driver_set_override()
firmware loader:
- Fix recursive lock deadlock in device_cache_fw_images() when async
work falls back to synchronous execution
- Fix device reference leak in firmware_upload_register()
platform:
- Pass KBUILD_MODNAME through the platform driver registration macro
to create module symlinks in sysfs for built-in drivers; move
module_kset initialization to a pure_initcall and tegra cbb
registration to core_initcall to ensure correct ordering
- Pass THIS_MODULE implicitly through a coresight_init_driver() macro
sysfs:
- Upgrade OOB write detection in sysfs_kf_seq_show() from printk to
WARN
- Add return value clamping to sysfs_kf_read()
Rust:
- ACPI:
Fix missing match data for PRP0001 by exporting
acpi_of_match_device()
- Auxiliary:
Replace drvdata() with dedicated registration data on
auxiliary_device. drvdata() exposed the driver's bus device private
data beyond the driver's own scope, creating ordering constraints
and forcing the data to outlive all registrations that access it.
Registration data is instead scoped structurally to the
Registration object, making lifecycle ordering enforced by
construction rather than convention.
- Rust-native device driver lifetimes (HRT):
Allow Rust device drivers to carry a lifetime parameter on their
bus device private data, tied to the device binding scope -- the
interval during which a bus device is bound to a driver. Device
resources like pci::Bar<'a> and IoMem<'a> can be stored directly in
the driver's bus device private data with a lifetime bounded by the
binding scope, so the compiler enforces at build time that they do
not outlive the binding. This removes Devres indirection from every
access site and eliminates try_access() failure paths in
destructors.
Bus driver traits use a Generic Associated Type (GAT) Data<'bound>
to introduce the lifetime on the private data, rather than
parameterizing the Driver trait itself. Auxiliary registration
data, where the lifetime is not introduced by a trait callback but
must be threaded through Registration, uses the ForLt trait (a
type-level abstraction for types generic over a lifetime).
Misc:
- Fix DT overlayed devices not probing by reverting the broken
treewide overlay fix and re-running fw_devlink consumer pickup when
an overlay is applied to a bound device
- Use root_device_register() for faux bus root device; add sanity
check for failed bus init
- Fix dev_has_sync_state() data race with READ_ONCE() and move it to
base.h
- Avoid spurious device_links warning when removing a device while
its supplier is unbinding
- Switch ISA bus to dynamic root device
- Fix suspicious RCU usage in kernfs_put()
- Remove devcoredump exit callback
- Constify devfreq_event_class"
* tag 'driver-core-7.2-rc1' of gitolite.kernel.org:pub/scm/linux/kernel/git/driver-core/driver-core: (81 commits)
software node: allow passing reference args to PROPERTY_ENTRY_REF()
driver core: platform: set mod_name in driver registration
coresight: pass THIS_MODULE implicitly through a macro
kernel: param: initialize module_kset in a pure_initcall
soc/tegra: cbb: Move driver registration from pure_initcall to core_initcall
firmware_loader: Fix recursive lock in device_cache_fw_images()
driver core: Use system_percpu_wq instead of system_wq
driver core: remove driver_set_override()
rpmsg: use generic driver_override infrastructure
Drivers: hv: vmbus: use generic driver_override infrastructure
cdx: use generic driver_override infrastructure
amba: use generic driver_override infrastructure
rust: devres: add 'static bound to Devres<T>
samples: rust: rust_driver_auxiliary: showcase lifetime-bound registration data
rust: auxiliary: generalize Registration over ForLt
rust: types: add `ForLt` trait for higher-ranked lifetime support
gpu: nova-core: separate driver type from driver data
samples: rust: rust_driver_pci: use HRT lifetime for Bar
rust: io: make IoMem and ExclusiveIoMem lifetime-parameterized
rust: pci: make Bar lifetime-parameterized
...
Diffstat (limited to 'rust')
| -rw-r--r-- | rust/Makefile | 1 | ||||
| -rw-r--r-- | rust/helpers/acpi.c | 16 | ||||
| -rw-r--r-- | rust/helpers/helpers.c | 1 | ||||
| -rw-r--r-- | rust/kernel/alloc/kbox.rs | 45 | ||||
| -rw-r--r-- | rust/kernel/auxiliary.rs | 285 | ||||
| -rw-r--r-- | rust/kernel/cpufreq.rs | 9 | ||||
| -rw-r--r-- | rust/kernel/device.rs | 121 | ||||
| -rw-r--r-- | rust/kernel/devres.rs | 8 | ||||
| -rw-r--r-- | rust/kernel/dma.rs | 2 | ||||
| -rw-r--r-- | rust/kernel/driver.rs | 115 | ||||
| -rw-r--r-- | rust/kernel/i2c.rs | 61 | ||||
| -rw-r--r-- | rust/kernel/io/mem.rs | 121 | ||||
| -rw-r--r-- | rust/kernel/pci.rs | 51 | ||||
| -rw-r--r-- | rust/kernel/pci/id.rs | 2 | ||||
| -rw-r--r-- | rust/kernel/pci/io.rs | 54 | ||||
| -rw-r--r-- | rust/kernel/platform.rs | 56 | ||||
| -rw-r--r-- | rust/kernel/types.rs | 12 | ||||
| -rw-r--r-- | rust/kernel/types/for_lt.rs | 122 | ||||
| -rw-r--r-- | rust/kernel/usb.rs | 57 | ||||
| -rw-r--r-- | rust/macros/for_lt.rs | 248 | ||||
| -rw-r--r-- | rust/macros/lib.rs | 13 |
21 files changed, 1035 insertions, 365 deletions
diff --git a/rust/Makefile b/rust/Makefile index 2fbdebb93bf2..63b1e355321d 100644 --- a/rust/Makefile +++ b/rust/Makefile @@ -119,6 +119,7 @@ syn-cfgs := \ feature="parsing" \ feature="printing" \ feature="proc-macro" \ + feature="visit" \ feature="visit-mut" syn-flags := \ diff --git a/rust/helpers/acpi.c b/rust/helpers/acpi.c new file mode 100644 index 000000000000..e75c9807bbad --- /dev/null +++ b/rust/helpers/acpi.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/acpi.h> +#include <acpi/acpi_bus.h> + +__rust_helper bool rust_helper_acpi_of_match_device(const struct acpi_device *adev, + const struct of_device_id *of_match_table, + const struct of_device_id **of_id) +{ + return acpi_of_match_device(adev, of_match_table, of_id); +} + +__rust_helper struct acpi_device *rust_helper_to_acpi_device_node(struct fwnode_handle *fwnode) +{ + return to_acpi_device_node(fwnode); +} diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c index 625921e27dfb..38b34518eff1 100644 --- a/rust/helpers/helpers.c +++ b/rust/helpers/helpers.c @@ -38,6 +38,7 @@ #define __rust_helper __always_inline #endif +#include "acpi.c" #include "atomic.c" #include "atomic_ext.c" #include "auxiliary.c" diff --git a/rust/kernel/alloc/kbox.rs b/rust/kernel/alloc/kbox.rs index 80eb39364e86..35d1e015848d 100644 --- a/rust/kernel/alloc/kbox.rs +++ b/rust/kernel/alloc/kbox.rs @@ -279,6 +279,27 @@ where Ok(Box(ptr.cast(), PhantomData)) } + /// Creates a new zero-initialized `Box<T, A>`. + /// + /// New memory is allocated with `A` and the [`__GFP_ZERO`] flag. The allocation may fail, in + /// which case an error is returned. For ZSTs no memory is allocated. + /// + /// # Examples + /// + /// ``` + /// let b = KBox::<[u8; 128]>::zeroed(GFP_KERNEL)?; + /// assert_eq!(*b, [0; 128]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn zeroed(flags: Flags) -> Result<Self, AllocError> + where + T: Zeroable, + { + // SAFETY: `__GFP_ZERO` guarantees the memory is zeroed; `T: Zeroable` guarantees that + // all-zeroes is a valid bit pattern for `T`. + Ok(unsafe { Self::new_uninit(flags | __GFP_ZERO)?.assume_init() }) + } + /// Constructs a new `Pin<Box<T, A>>`. If `T` does not implement [`Unpin`], then `x` will be /// pinned in memory and can't be moved. #[inline] @@ -483,7 +504,7 @@ where // SAFETY: The pointer returned by `into_foreign` comes from a well aligned // pointer to `T` allocated by `A`. -unsafe impl<T: 'static, A> ForeignOwnable for Box<T, A> +unsafe impl<T, A> ForeignOwnable for Box<T, A> where A: Allocator, { @@ -493,8 +514,14 @@ where core::mem::align_of::<T>() }; - type Borrowed<'a> = &'a T; - type BorrowedMut<'a> = &'a mut T; + type Borrowed<'a> + = &'a T + where + Self: 'a; + type BorrowedMut<'a> + = &'a mut T + where + Self: 'a; fn into_foreign(self) -> *mut c_void { Box::into_raw(self).cast() @@ -522,13 +549,19 @@ where // SAFETY: The pointer returned by `into_foreign` comes from a well aligned // pointer to `T` allocated by `A`. -unsafe impl<T: 'static, A> ForeignOwnable for Pin<Box<T, A>> +unsafe impl<T, A> ForeignOwnable for Pin<Box<T, A>> where A: Allocator, { const FOREIGN_ALIGN: usize = <Box<T, A> as ForeignOwnable>::FOREIGN_ALIGN; - type Borrowed<'a> = Pin<&'a T>; - type BorrowedMut<'a> = Pin<&'a mut T>; + type Borrowed<'a> + = Pin<&'a T> + where + Self: 'a; + type BorrowedMut<'a> + = Pin<&'a mut T> + where + Self: 'a; fn into_foreign(self) -> *mut c_void { // SAFETY: We are still treating the box as pinned. diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 93c0db1f6655..c42928d5a239 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -12,19 +12,25 @@ use crate::{ RawDeviceId, RawDeviceIdIndex, // }, - devres::Devres, + driver, error::{ from_result, to_result, // }, prelude::*, - types::Opaque, + types::{ + ForLt, + ForeignOwnable, + Opaque, // + }, ThisModule, // }; use core::{ + any::TypeId, marker::PhantomData, mem::offset_of, + pin::Pin, ptr::{ addr_of_mut, NonNull, // @@ -36,18 +42,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::auxiliary_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct auxiliary_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::auxiliary_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( adrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -73,7 +79,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( adev: *mut bindings::auxiliary_device, id: *const bindings::auxiliary_device_id, @@ -82,7 +88,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct auxiliary_device`. // // INVARIANT: `adev` is valid for the duration of `probe_callback()`. - let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `DeviceId` is a `#[repr(transparent)`] wrapper of `struct auxiliary_device_id` // and does not add additional invariants, so it's safe to transmute. @@ -102,12 +108,12 @@ impl<T: Driver + 'static> Adapter<T> { // `struct auxiliary_device`. // // INVARIANT: `adev` is valid for the duration of `remove_callback()`. - let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { adev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { adev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(adev, data); } @@ -197,13 +203,19 @@ pub trait Driver { /// type IdInfo: 'static = (); type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; /// Auxiliary driver probe. /// /// Called when an auxiliary device is matches a corresponding driver. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// Auxiliary driver unbind. /// @@ -214,8 +226,8 @@ pub trait Driver { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -257,6 +269,49 @@ impl Device<device::Bound> { // SAFETY: A bound auxiliary device always has a bound parent device. unsafe { parent.as_bound() } } + + /// Returns a pinned reference to the registration data set by the registering (parent) driver. + /// + /// `F` is the [`ForLt`](trait@ForLt) encoding of the data type. The returned + /// reference has its lifetime shortened from `'static` to `&self`'s borrow lifetime via + /// [`ForLt::cast_ref`]. + /// + /// Returns [`EINVAL`] if `F` does not match the type used by the parent driver when calling + /// [`Registration::new()`]. + /// + /// Returns [`ENOENT`] if no registration data has been set, e.g. when the device was + /// registered by a C driver. + pub fn registration_data<F: ForLt + 'static>(&self) -> Result<Pin<&F::Of<'_>>> { + // SAFETY: By the type invariant, `self.as_raw()` is a valid `struct auxiliary_device`. + let ptr = unsafe { (*self.as_raw()).registration_data_rust }; + if ptr.is_null() { + dev_warn!( + self.as_ref(), + "No registration data set; parent is not a Rust driver.\n" + ); + return Err(ENOENT); + } + + // SAFETY: `ptr` is non-null and was set via `into_foreign()` in `Registration::new()`; + // `RegistrationData` is `#[repr(C)]` with `type_id` at offset 0, so reading a `TypeId` + // at the start of the allocation is valid regardless of `F`. + let type_id = unsafe { ptr.cast::<TypeId>().read() }; + if type_id != TypeId::of::<F>() { + return Err(EINVAL); + } + + // SAFETY: The `TypeId` check above confirms that the stored type matches + // `F::Of<'static>`; `ptr` remains valid until `Registration::drop()` calls + // `from_foreign()`. + let wrapper = unsafe { Pin::<KBox<RegistrationData<F::Of<'static>>>>::borrow(ptr) }; + + // SAFETY: `data` is a structurally pinned field of `RegistrationData`. + let pinned: Pin<&F::Of<'_>> = unsafe { wrapper.map_unchecked(|w| &w.data) }; + + // SAFETY: The data was pinned when stored; `cast_ref` only shortens + // the lifetime, so the pinning guarantee is preserved. + Ok(unsafe { Pin::new_unchecked(F::cast_ref(pinned.get_ref())) }) + } } impl Device { @@ -326,87 +381,173 @@ unsafe impl Send for Device {} // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct auxiliary_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} + +/// Wrapper that stores a [`TypeId`] alongside the registration data for runtime type checking. +#[repr(C)] +#[pin_data] +struct RegistrationData<T> { + type_id: TypeId, + #[pin] + data: T, +} + /// The registration of an auxiliary device. /// /// This type represents the registration of a [`struct auxiliary_device`]. When its parent device /// is unbound, the corresponding auxiliary device will be unregistered from the system. /// +/// The type parameter `F` is a [`ForLt`](trait@ForLt) encoding of the registration +/// data type. For non-lifetime-parameterized types, use [`ForLt!(T)`](macro@ForLt). +/// The data can be accessed by the auxiliary driver through [`Device::registration_data()`]. +/// /// # Invariants /// -/// `self.0` always holds a valid pointer to an initialized and registered -/// [`struct auxiliary_device`]. -pub struct Registration(NonNull<bindings::auxiliary_device>); +/// `self.adev` always holds a valid pointer to an initialized and registered +/// [`struct auxiliary_device`] whose `registration_data_rust` field points to a +/// valid `Pin<KBox<RegistrationData<F::Of<'static>>>>`. +pub struct Registration<'a, F: ForLt + 'static> { + adev: NonNull<bindings::auxiliary_device>, + _phantom: PhantomData<F::Of<'a>>, +} -impl Registration { - /// Create and register a new auxiliary device. - pub fn new<'a>( +impl<'a, F: ForLt> Registration<'a, F> +where + for<'b> F::Of<'b>: Send + Sync, +{ + /// Create and register a new auxiliary device with the given registration data. + /// + /// The `data` is owned by the registration and can be accessed through the auxiliary device + /// via [`Device::registration_data()`]. + /// + /// # Safety + /// + /// The caller must not `mem::forget()` the returned [`Registration`] or otherwise prevent its + /// [`Drop`] implementation from running, since the registration data may contain borrowed + /// references that become invalid after `'a` ends. + /// + /// If the registration data is `'static`, use the safe [`Registration::new()`] instead. + pub unsafe fn new_with_lt<E>( parent: &'a device::Device<device::Bound>, - name: &'a CStr, + name: &CStr, id: u32, - modname: &'a CStr, - ) -> impl PinInit<Devres<Self>, Error> + 'a { - pin_init::pin_init_scope(move || { - let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?; - let adev = boxed.get(); - - // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. - unsafe { - (*adev).dev.parent = parent.as_raw(); - (*adev).dev.release = Some(Device::release); - (*adev).name = name.as_char_ptr(); - (*adev).id = id; - } - - // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, - // which has not been initialized yet. - unsafe { bindings::auxiliary_device_init(adev) }; - - // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be - // freed by `Device::release` when the last reference to the `struct auxiliary_device` - // is dropped. - let _ = KBox::into_raw(boxed); - - // SAFETY: - // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which - // has been initialized, - // - `modname.as_char_ptr()` is a NULL terminated string. - let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; - if ret != 0 { - // SAFETY: `adev` is guaranteed to be a valid pointer to a - // `struct auxiliary_device`, which has been initialized. - unsafe { bindings::auxiliary_device_uninit(adev) }; - - return Err(Error::from_errno(ret)); - } - - // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is - // called, which happens in `Self::drop()`. - Ok(Devres::new( - parent, - // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated - // successfully. - Self(unsafe { NonNull::new_unchecked(adev) }), - )) + modname: &CStr, + data: impl PinInit<F::Of<'a>, E>, + ) -> Result<Self> + where + Error: From<E>, + { + let data = KBox::pin_init::<Error>( + try_pin_init!(RegistrationData { + type_id: TypeId::of::<F>(), + data <- data, + }), + GFP_KERNEL, + )?; + + // SAFETY: `'a` is invariant (via `Registration`'s `PhantomData`). Lifetimes do not + // affect layout, so RegistrationData<F::Of<'a>> and RegistrationData<F::Of<'static>> + // have identical representation. + let data: Pin<KBox<RegistrationData<F::Of<'static>>>> = + unsafe { core::mem::transmute(data) }; + + let boxed: KBox<Opaque<bindings::auxiliary_device>> = KBox::zeroed(GFP_KERNEL)?; + let adev = boxed.get(); + + // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. + unsafe { + (*adev).dev.parent = parent.as_raw(); + (*adev).dev.release = Some(Device::release); + (*adev).name = name.as_char_ptr(); + (*adev).id = id; + (*adev).registration_data_rust = data.into_foreign(); + } + + // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, + // which has not been initialized yet. + unsafe { bindings::auxiliary_device_init(adev) }; + + // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be + // freed by `Device::release` when the last reference to the `struct auxiliary_device` + // is dropped. + let _ = KBox::into_raw(boxed); + + // SAFETY: + // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which + // has been initialized, + // - `modname.as_char_ptr()` is a NULL terminated string. + let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; + if ret != 0 { + // SAFETY: `registration_data` was set above via `into_foreign()`. + drop(unsafe { + Pin::<KBox<RegistrationData<F::Of<'static>>>>::from_foreign( + (*adev).registration_data_rust, + ) + }); + + // SAFETY: `adev` is guaranteed to be a valid pointer to a + // `struct auxiliary_device`, which has been initialized. + unsafe { bindings::auxiliary_device_uninit(adev) }; + + return Err(Error::from_errno(ret)); + } + + // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is + // called, which happens in `Self::drop()`. + Ok(Self { + // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated + // successfully. + adev: unsafe { NonNull::new_unchecked(adev) }, + _phantom: PhantomData, }) } + + /// Create and register a new auxiliary device with `'static` registration data. + /// + /// Safe variant of [`Registration::new_with_lt()`] for registration data that does not contain + /// borrowed references. + pub fn new<E>( + parent: &'a device::Device<device::Bound>, + name: &CStr, + id: u32, + modname: &CStr, + data: impl PinInit<F::Of<'a>, E>, + ) -> Result<Self> + where + F::Of<'a>: 'static, + Error: From<E>, + { + // SAFETY: `F::Of<'a>: 'static` guarantees the data contains no borrowed references, + // so forgetting the `Registration` cannot cause use-after-free. + unsafe { Self::new_with_lt(parent, name, id, modname, data) } + } } -impl Drop for Registration { +impl<F: ForLt> Drop for Registration<'_, F> { fn drop(&mut self) { - // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // SAFETY: By the type invariant of `Self`, `self.adev.as_ptr()` is a valid registered // `struct auxiliary_device`. - unsafe { bindings::auxiliary_device_delete(self.0.as_ptr()) }; + unsafe { bindings::auxiliary_device_delete(self.adev.as_ptr()) }; + + // SAFETY: `registration_data` was set in `new()` via `into_foreign()`. + drop(unsafe { + Pin::<KBox<RegistrationData<F::Of<'static>>>>::from_foreign( + (*self.adev.as_ptr()).registration_data_rust, + ) + }); // This drops the reference we acquired through `auxiliary_device_init()`. // - // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // SAFETY: By the type invariant of `Self`, `self.adev.as_ptr()` is a valid registered // `struct auxiliary_device`. - unsafe { bindings::auxiliary_device_uninit(self.0.as_ptr()) }; + unsafe { bindings::auxiliary_device_uninit(self.adev.as_ptr()) }; } } // SAFETY: A `Registration` of a `struct auxiliary_device` can be released from any thread. -unsafe impl Send for Registration {} +unsafe impl<F: ForLt> Send for Registration<'_, F> where for<'a> F::Of<'a>: Send {} // SAFETY: `Registration` does not expose any methods or fields that need synchronization. -unsafe impl Sync for Registration {} +unsafe impl<F: ForLt> Sync for Registration<'_, F> where for<'a> F::Of<'a>: Send {} diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs index a20bd5006f38..58ac04c650a1 100644 --- a/rust/kernel/cpufreq.rs +++ b/rust/kernel/cpufreq.rs @@ -888,12 +888,13 @@ pub trait Driver { /// /// impl platform::Driver for SampleDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; /// -/// fn probe( -/// pdev: &platform::Device<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// pdev: &'bound platform::Device<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self, Error> + 'bound { /// cpufreq::Registration::<SampleDriver>::new_foreign_owned(pdev.as_ref())?; /// Ok(Self {}) /// } diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index 6d5396a43ebe..645afc49a27d 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -15,16 +15,12 @@ use crate::{ }, // }; use core::{ - any::TypeId, marker::PhantomData, ptr, // }; pub mod property; -// Assert that we can `read()` / `write()` a `TypeId` instance from / into `struct driver_type`. -static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_of::<TypeId>()); - /// The core representation of a device in the kernel's driver model. /// /// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either @@ -205,30 +201,13 @@ impl Device { } } -impl Device<CoreInternal> { - fn set_type_id<T: 'static>(&self) { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - let private = unsafe { (*self.as_raw()).p }; - - // SAFETY: For a bound device (implied by the `CoreInternal` device context), `private` is - // guaranteed to be a valid pointer to a `struct device_private`. - let driver_type = unsafe { &raw mut (*private).driver_type }; - - // SAFETY: `driver_type` is valid for (unaligned) writes of a `TypeId`. - unsafe { - driver_type - .cast::<TypeId>() - .write_unaligned(TypeId::of::<T>()) - }; - } - +impl<'a> Device<CoreInternal<'a>> { /// Store a pointer to the bound driver's private data. - pub fn set_drvdata<T: 'static>(&self, data: impl PinInit<T, Error>) -> Result { + pub fn set_drvdata<T>(&self, data: impl PinInit<T, Error>) -> Result { let data = KBox::pin_init(data, GFP_KERNEL)?; // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) }; - self.set_type_id::<T>(); Ok(()) } @@ -239,7 +218,7 @@ impl Device<CoreInternal> { /// /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> { + pub(crate) unsafe fn drvdata_obtain<T>(&self) -> Option<Pin<KBox<T>>> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; @@ -265,7 +244,7 @@ impl Device<CoreInternal> { /// device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { + pub unsafe fn drvdata_borrow<T>(&self) -> Pin<&T> { // SAFETY: `drvdata_unchecked()` has the exact same safety requirements as the ones // required by this method. unsafe { self.drvdata_unchecked() } @@ -281,7 +260,7 @@ impl Device<Bound> { /// the device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { + unsafe fn drvdata_unchecked<T>(&self) -> Pin<&T> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; @@ -292,45 +271,6 @@ impl Device<Bound> { // in `into_foreign()`. unsafe { Pin::<KBox<T>>::borrow(ptr.cast()) } } - - fn match_type_id<T: 'static>(&self) -> Result { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - let private = unsafe { (*self.as_raw()).p }; - - // SAFETY: For a bound device, `private` is guaranteed to be a valid pointer to a - // `struct device_private`. - let driver_type = unsafe { &raw mut (*private).driver_type }; - - // SAFETY: - // - `driver_type` is valid for (unaligned) reads of a `TypeId`. - // - A bound device guarantees that `driver_type` contains a valid `TypeId` value. - let type_id = unsafe { driver_type.cast::<TypeId>().read_unaligned() }; - - if type_id != TypeId::of::<T>() { - return Err(EINVAL); - } - - Ok(()) - } - - /// Access a driver's private data. - /// - /// Returns a pinned reference to the driver's private data or [`EINVAL`] if it doesn't match - /// the asserted type `T`. - pub fn drvdata<T: 'static>(&self) -> Result<Pin<&T>> { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - if unsafe { bindings::dev_get_drvdata(self.as_raw()) }.is_null() { - return Err(ENOENT); - } - - self.match_type_id::<T>()?; - - // SAFETY: - // - The above check of `dev_get_drvdata()` guarantees that we are called after - // `set_drvdata()`. - // - We've just checked that the type of the driver's private data is in fact `T`. - Ok(unsafe { self.drvdata_unchecked() }) - } } impl<Ctx: DeviceContext> Device<Ctx> { @@ -527,6 +467,10 @@ unsafe impl Send for Device {} // synchronization in `struct device`. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct device` is the same; `Bound` is a +// zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<Bound> {} + /// Marker trait for the context or scope of a bus specific device. /// /// [`DeviceContext`] is a marker trait for types representing the context of a bus specific @@ -567,7 +511,7 @@ pub struct Normal; /// callback it appears in. It is intended to be used for synchronization purposes. Bus device /// implementations can implement methods for [`Device<Core>`], such that they can only be called /// from bus callbacks. -pub struct Core; +pub struct Core<'a>(PhantomData<&'a ()>); /// Semantically the same as [`Core`], but reserved for internal usage of the corresponding bus /// abstraction. @@ -578,7 +522,7 @@ pub struct Core; /// /// This context mainly exists to share generic [`Device`] infrastructure that should only be called /// from bus callbacks with bus abstractions, but without making them accessible for drivers. -pub struct CoreInternal; +pub struct CoreInternal<'a>(PhantomData<&'a ()>); /// The [`Bound`] context is the [`DeviceContext`] of a bus specific device when it is guaranteed to /// be bound to a driver. @@ -602,14 +546,14 @@ mod private { pub trait Sealed {} impl Sealed for super::Bound {} - impl Sealed for super::Core {} - impl Sealed for super::CoreInternal {} + impl<'a> Sealed for super::Core<'a> {} + impl<'a> Sealed for super::CoreInternal<'a> {} impl Sealed for super::Normal {} } impl DeviceContext for Bound {} -impl DeviceContext for Core {} -impl DeviceContext for CoreInternal {} +impl<'a> DeviceContext for Core<'a> {} +impl<'a> DeviceContext for CoreInternal<'a> {} impl DeviceContext for Normal {} impl<Ctx: DeviceContext> AsRef<Device<Ctx>> for Device<Ctx> { @@ -659,6 +603,22 @@ pub unsafe trait AsBusDevice<Ctx: DeviceContext>: AsRef<Device<Ctx>> { #[doc(hidden)] #[macro_export] macro_rules! __impl_device_context_deref { + (unsafe { $device:ident, <$lt:lifetime> $src:ty => $dst:ty }) => { + impl<$lt> ::core::ops::Deref for $device<$src> { + type Target = $device<$dst>; + + fn deref(&self) -> &Self::Target { + let ptr: *const Self = self; + + // CAST: `$device<$src>` and `$device<$dst>` transparently wrap the same type by the + // safety requirement of the macro. + let ptr = ptr.cast::<Self::Target>(); + + // SAFETY: `ptr` was derived from `&self`. + unsafe { &*ptr } + } + } + }; (unsafe { $device:ident, $src:ty => $dst:ty }) => { impl ::core::ops::Deref for $device<$src> { type Target = $device<$dst>; @@ -691,14 +651,14 @@ macro_rules! impl_device_context_deref { // `__impl_device_context_deref!`. ::kernel::__impl_device_context_deref!(unsafe { $device, - $crate::device::CoreInternal => $crate::device::Core + <'a> $crate::device::CoreInternal<'a> => $crate::device::Core<'a> }); // SAFETY: This macro has the exact same safety requirement as // `__impl_device_context_deref!`. ::kernel::__impl_device_context_deref!(unsafe { $device, - $crate::device::Core => $crate::device::Bound + <'a> $crate::device::Core<'a> => $crate::device::Bound }); // SAFETY: This macro has the exact same safety requirement as @@ -713,6 +673,13 @@ macro_rules! impl_device_context_deref { #[doc(hidden)] #[macro_export] macro_rules! __impl_device_context_into_aref { + (<$lt:lifetime> $src:ty, $device:tt) => { + impl<$lt> ::core::convert::From<&$device<$src>> for $crate::sync::aref::ARef<$device> { + fn from(dev: &$device<$src>) -> Self { + (&**dev).into() + } + } + }; ($src:ty, $device:tt) => { impl ::core::convert::From<&$device<$src>> for $crate::sync::aref::ARef<$device> { fn from(dev: &$device<$src>) -> Self { @@ -727,8 +694,12 @@ macro_rules! __impl_device_context_into_aref { #[macro_export] macro_rules! impl_device_context_into_aref { ($device:tt) => { - ::kernel::__impl_device_context_into_aref!($crate::device::CoreInternal, $device); - ::kernel::__impl_device_context_into_aref!($crate::device::Core, $device); + ::kernel::__impl_device_context_into_aref!( + <'a> $crate::device::CoreInternal<'a>, $device + ); + ::kernel::__impl_device_context_into_aref!( + <'a> $crate::device::Core<'a>, $device + ); ::kernel::__impl_device_context_into_aref!($crate::device::Bound, $device); }; } diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs index 9e5f93aed20c..11ce500e9b76 100644 --- a/rust/kernel/devres.rs +++ b/rust/kernel/devres.rs @@ -122,7 +122,7 @@ struct Inner<T> { /// # Ok(()) /// # } /// ``` -pub struct Devres<T: Send> { +pub struct Devres<T: Send + 'static> { dev: ARef<Device>, inner: Arc<Inner<T>>, } @@ -184,7 +184,7 @@ mod base { } } -impl<T: Send> Devres<T> { +impl<T: Send + 'static> Devres<T> { /// Creates a new [`Devres`] instance of the given `data`. /// /// The `data` encapsulated within the returned `Devres` instance' `data` will be @@ -304,7 +304,7 @@ impl<T: Send> Devres<T> { /// pci, // /// }; /// - /// fn from_core(dev: &pci::Device<Core>, devres: Devres<pci::Bar<0x4>>) -> Result { + /// fn from_core(dev: &pci::Device<Core<'_>>, devres: Devres<pci::Bar<'_, 0x4>>) -> Result { /// let bar = devres.access(dev.as_ref())?; /// /// let _ = bar.read32(0x0); @@ -349,7 +349,7 @@ unsafe impl<T: Send> Send for Devres<T> {} // SAFETY: `Devres` can be shared with any task, if `T: Sync`. unsafe impl<T: Send + Sync> Sync for Devres<T> {} -impl<T: Send> Drop for Devres<T> { +impl<T: Send + 'static> Drop for Devres<T> { fn drop(&mut self) { // SAFETY: When `drop` runs, it is guaranteed that nobody is accessing the revocable data // anymore, hence it is safe not to wait for the grace period to finish. diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs index 642ccff465c8..200def84fb69 100644 --- a/rust/kernel/dma.rs +++ b/rust/kernel/dma.rs @@ -47,7 +47,7 @@ pub type DmaAddress = bindings::dma_addr_t; /// where the underlying bus is DMA capable, such as: #[cfg_attr(CONFIG_PCI, doc = "* [`pci::Device`](kernel::pci::Device)")] /// * [`platform::Device`](::kernel::platform::Device) -pub trait Device: AsRef<device::Device<Core>> { +pub trait Device<'a>: AsRef<device::Device<Core<'a>>> { /// Set up the device's DMA streaming addressing capabilities. /// /// This method is usually called once from `probe()` as soon as the device capabilities are diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index 36de8098754d..bf5ba0d27553 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -13,10 +13,13 @@ //! The main driver interface is defined by a bus specific driver trait. For instance: //! //! ```ignore -//! pub trait Driver: Send { +//! pub trait Driver { //! /// The type holding information about each device ID supported by the driver. //! type IdInfo: 'static; //! +//! /// The type of the driver's bus device private data. +//! type Data<'bound>: Send + 'bound; +//! //! /// The table of OF device ids supported by the driver. //! const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; //! @@ -24,10 +27,16 @@ //! const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; //! //! /// Driver probe. -//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; +//! fn probe<'bound>( +//! dev: &'bound Device<device::Core<'_>>, +//! id_info: &'bound Self::IdInfo, +//! ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; //! //! /// Driver unbind (optional). -//! fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { +//! fn unbind<'bound>( +//! dev: &'bound Device<device::Core<'_>>, +//! this: Pin<&Self::Data<'bound>>, +//! ) { //! let _ = (dev, this); //! } //! } @@ -42,8 +51,9 @@ )] #")] //! -//! The `probe()` callback should return a `impl PinInit<Self, Error>`, i.e. the driver's private -//! data. The bus abstraction should store the pointer in the corresponding bus device. The generic +//! The `probe()` callback should return a +//! `impl PinInit<Self::Data<'bound>, Error>`, i.e. the driver's private data. The bus +//! abstraction should store the pointer in the corresponding bus device. The generic //! [`Device`] infrastructure provides common helpers for this purpose on its //! [`Device<CoreInternal>`] implementation. //! @@ -118,8 +128,8 @@ pub unsafe trait DriverLayout { /// The specific driver type embedding a `struct device_driver`. type DriverType: Default; - /// The type of the driver's device private data. - type DriverData; + /// The type of the driver's bus device private data. + type DriverData<'bound>; /// Byte offset of the embedded `struct device_driver` within `DriverType`. /// @@ -181,20 +191,20 @@ unsafe impl<T: RegistrationOps> Sync for Registration<T> {} // any thread, so `Registration` is `Send`. unsafe impl<T: RegistrationOps> Send for Registration<T> {} -impl<T: RegistrationOps + 'static> Registration<T> { +impl<T: RegistrationOps> Registration<T> { extern "C" fn post_unbind_callback(dev: *mut bindings::device) { // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to // a `struct device`. // // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`. - let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() }; + let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal<'_>>>() }; - // `remove()` and all devres callbacks have been completed at this point, hence drop the - // driver's device private data. + // `remove()` has been completed at this point; devres resources are still valid and will + // be released after the driver's bus device private data is dropped. // // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the - // driver's device private data type. - drop(unsafe { dev.drvdata_obtain::<T::DriverData>() }); + // driver's bus device private data type. + drop(unsafe { dev.drvdata_obtain::<T::DriverData<'_>>() }); } /// Attach generic `struct device_driver` callbacks. @@ -215,7 +225,10 @@ impl<T: RegistrationOps + 'static> Registration<T> { } /// Creates a new instance of the registration object. - pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { + pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> + where + T: 'static, + { try_pin_init!(Self { reg <- Opaque::try_ffi_init(|ptr: *mut T::DriverType| { // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write. @@ -278,6 +291,26 @@ macro_rules! module_driver { } } +// Calling the FFI function directly from the `Adapter` impl may result in it being called +// directly from driver modules. This happens since the Rust compiler will use monomorphisation, so +// it might happen that functions are instantiated within the calling driver module. For now, work +// around this with `#[inline(never)]` helpers. +// +// TODO: Remove once a more generic solution has been implemented. For instance, we may be able to +// leverage `bindgen` to take care of this depending on whether a symbol is (already) exported. +#[inline(never)] +#[allow(clippy::missing_safety_doc)] +#[allow(dead_code)] +#[must_use] +unsafe fn acpi_of_match_device( + adev: *const bindings::acpi_device, + of_match_table: *const bindings::of_device_id, + of_id: *mut *const bindings::of_device_id, +) -> bool { + // SAFETY: Safety requirements are the same as `bindings::acpi_of_match_device`. + unsafe { bindings::acpi_of_match_device(adev, of_match_table, of_id) } +} + /// The bus independent adapter to match a drivers and a devices. /// /// This trait should be implemented by the bus specific adapter, which represents the connection @@ -329,35 +362,63 @@ pub trait Adapter { /// /// If this returns `None`, it means there is no match with an entry in the [`of::IdTable`]. fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { - #[cfg(not(CONFIG_OF))] + let table = Self::of_id_table()?; + + #[cfg(not(any(CONFIG_OF, CONFIG_ACPI)))] { - let _ = dev; - None + let _ = (dev, table); } #[cfg(CONFIG_OF)] { - let table = Self::of_id_table()?; - // SAFETY: // - `table` has static lifetime, hence it's valid for read, // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. let raw_id = unsafe { bindings::of_match_device(table.as_ptr(), dev.as_raw()) }; - if raw_id.is_null() { - None - } else { + if !raw_id.is_null() { // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` // and does not add additional invariants, so it's safe to transmute. let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; - Some( - table.info(<of::DeviceId as crate::device_id::RawDeviceIdIndex>::index( - id, - )), - ) + return Some(table.info( + <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), + )); + } + } + + #[cfg(CONFIG_ACPI)] + { + use core::ptr; + use device::property::FwNode; + + let mut raw_id = ptr::null(); + + let fwnode = dev.fwnode().map_or(ptr::null_mut(), FwNode::as_raw); + + // SAFETY: `fwnode` is a pointer to a valid `fwnode_handle`. A null pointer will be + // passed through the function. + let adev = unsafe { bindings::to_acpi_device_node(fwnode) }; + + // SAFETY: + // - `adev` is a valid pointer to `acpi_device` or is null. It is guaranteed to be + // valid as long as `dev` is alive. + // - `table` has static lifetime, hence it's valid for read. + if unsafe { acpi_of_match_device(adev, table.as_ptr(), &raw mut raw_id) } { + // SAFETY: + // - the function returns true, therefore `raw_id` has been set to a pointer to a + // valid `of_device_id`. + // - `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; + + return Some(table.info( + <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), + )); } } + + None } /// Returns the driver's private data from the matching entry of any of the ID tables, if any. diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index c084a45b1916..624b971ca8b0 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -93,18 +93,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::i2c_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct i2c_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::i2c_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( idrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -151,13 +151,13 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback(idev: *mut bindings::i2c_client) -> kernel::ffi::c_int { // SAFETY: The I2C bus only ever calls the probe callback with a valid pointer to a // `struct i2c_client`. // // INVARIANT: `idev` is valid for the duration of `probe_callback()`. - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; let info = Self::i2c_id_info(idev).or_else(|| <Self as driver::Adapter>::id_info(idev.as_ref())); @@ -172,24 +172,24 @@ impl<T: Driver + 'static> Adapter<T> { extern "C" fn remove_callback(idev: *mut bindings::i2c_client) { // SAFETY: `idev` is a valid pointer to a `struct i2c_client`. - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { idev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(idev, data); } extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) { // SAFETY: `shutdown_callback` is only ever called for a valid `idev` - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; // SAFETY: `shutdown_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { idev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::shutdown(idev, data); } @@ -222,7 +222,7 @@ impl<T: Driver + 'static> Adapter<T> { } } -impl<T: Driver + 'static> driver::Adapter for Adapter<T> { +impl<T: Driver> driver::Adapter for Adapter<T> { type IdInfo = T::IdInfo; fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { @@ -294,22 +294,26 @@ macro_rules! module_i2c_driver { /// /// impl i2c::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const I2C_ID_TABLE: Option<i2c::IdTable<Self::IdInfo>> = Some(&I2C_TABLE); /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); /// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); /// -/// fn probe( -/// _idev: &i2c::I2cClient<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _idev: &'bound i2c::I2cClient<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// -/// fn shutdown(_idev: &i2c::I2cClient<Core>, this: Pin<&Self>) { +/// fn shutdown<'bound>( +/// _idev: &'bound i2c::I2cClient<Core<'_>>, +/// this: Pin<&Self::Data<'bound>>, +/// ) { /// } /// } ///``` -pub trait Driver: Send { +pub trait Driver { /// The type holding information about each device id supported by the driver. // TODO: Use `associated_type_defaults` once stabilized: // @@ -318,6 +322,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const I2C_ID_TABLE: Option<IdTable<Self::IdInfo>> = None; @@ -331,10 +338,10 @@ pub trait Driver: Send { /// /// Called when a new i2c client is added or discovered. /// Implementers should attempt to initialize the client here. - fn probe( - dev: &I2cClient<device::Core>, - id_info: Option<&Self::IdInfo>, - ) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound I2cClient<device::Core<'_>>, + id_info: Option<&'bound Self::IdInfo>, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// I2C driver shutdown. /// @@ -346,8 +353,8 @@ pub trait Driver: Send { /// /// This callback is distinct from final resource cleanup, as the driver instance remains valid /// after it returns. Any deallocation or teardown of driver-owned resources should instead be - /// handled in `Self::drop`. - fn shutdown(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + /// handled in `Drop`. + fn shutdown<'bound>(dev: &'bound I2cClient<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } @@ -360,8 +367,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound I2cClient<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs index 7dc78d547f7a..fc2a3e24f8d5 100644 --- a/rust/kernel/io/mem.rs +++ b/rust/kernel/io/mem.rs @@ -62,33 +62,31 @@ impl<'a> IoRequest<'a> { /// /// impl platform::Driver for SampleDriver { /// # type IdInfo = (); + /// # type Data<'bound> = Self; /// - /// fn probe( - /// pdev: &platform::Device<Core>, - /// info: Option<&Self::IdInfo>, - /// ) -> impl PinInit<Self, Error> { + /// fn probe<'bound>( + /// pdev: &'bound platform::Device<Core<'_>>, + /// info: Option<&'bound Self::IdInfo>, + /// ) -> impl PinInit<Self, Error> + 'bound { /// let offset = 0; // Some offset. /// /// // If the size is known at compile time, use [`Self::iomap_sized`]. /// // /// // No runtime checks will apply when reading and writing. /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; - /// let iomem = request.iomap_sized::<42>(); - /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; - /// - /// let io = iomem.access(pdev.as_ref())?; + /// let iomem = request.iomap_sized::<42>()?; /// /// // Read and write a 32-bit value at `offset`. - /// let data = io.read32(offset); + /// let data = iomem.read32(offset); /// - /// io.write32(data, offset); + /// iomem.write32(data, offset); /// /// # Ok(SampleDriver) /// } /// } /// ``` - pub fn iomap_sized<const SIZE: usize>(self) -> impl PinInit<Devres<IoMem<SIZE>>, Error> + 'a { - IoMem::new(self) + pub fn iomap_sized<const SIZE: usize>(self) -> Result<IoMem<'a, SIZE>> { + IoMem::ioremap(self.device, self.resource) } /// Same as [`Self::iomap_sized`] but with exclusive access to the @@ -97,10 +95,8 @@ impl<'a> IoRequest<'a> { /// This uses the [`ioremap()`] C API. /// /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device - pub fn iomap_exclusive_sized<const SIZE: usize>( - self, - ) -> impl PinInit<Devres<ExclusiveIoMem<SIZE>>, Error> + 'a { - ExclusiveIoMem::new(self) + pub fn iomap_exclusive_sized<const SIZE: usize>(self) -> Result<ExclusiveIoMem<'a, SIZE>> { + ExclusiveIoMem::ioremap(self.device, self.resource) } /// Maps an [`IoRequest`] where the size is not known at compile time, @@ -126,11 +122,12 @@ impl<'a> IoRequest<'a> { /// /// impl platform::Driver for SampleDriver { /// # type IdInfo = (); + /// # type Data<'bound> = Self; /// - /// fn probe( - /// pdev: &platform::Device<Core>, - /// info: Option<&Self::IdInfo>, - /// ) -> impl PinInit<Self, Error> { + /// fn probe<'bound>( + /// pdev: &'bound platform::Device<Core<'_>>, + /// info: Option<&'bound Self::IdInfo>, + /// ) -> impl PinInit<Self, Error> + 'bound { /// let offset = 0; // Some offset. /// /// // Unlike [`Self::iomap_sized`], here the size of the memory region @@ -138,27 +135,24 @@ impl<'a> IoRequest<'a> { /// // family of functions should be used, leading to runtime checks on every /// // access. /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; - /// let iomem = request.iomap(); - /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; - /// - /// let io = iomem.access(pdev.as_ref())?; + /// let iomem = request.iomap()?; /// - /// let data = io.try_read32(offset)?; + /// let data = iomem.try_read32(offset)?; /// - /// io.try_write32(data, offset)?; + /// iomem.try_write32(data, offset)?; /// /// # Ok(SampleDriver) /// } /// } /// ``` - pub fn iomap(self) -> impl PinInit<Devres<IoMem<0>>, Error> + 'a { - Self::iomap_sized::<0>(self) + pub fn iomap(self) -> Result<IoMem<'a>> { + self.iomap_sized::<0>() } /// Same as [`Self::iomap`] but with exclusive access to the underlying /// region. - pub fn iomap_exclusive(self) -> impl PinInit<Devres<ExclusiveIoMem<0>>, Error> + 'a { - Self::iomap_exclusive_sized::<0>(self) + pub fn iomap_exclusive(self) -> Result<ExclusiveIoMem<'a, 0>> { + self.iomap_exclusive_sized::<0>() } } @@ -167,9 +161,9 @@ impl<'a> IoRequest<'a> { /// # Invariants /// /// - [`ExclusiveIoMem`] has exclusive access to the underlying [`IoMem`]. -pub struct ExclusiveIoMem<const SIZE: usize> { +pub struct ExclusiveIoMem<'a, const SIZE: usize> { /// The underlying `IoMem` instance. - iomem: IoMem<SIZE>, + iomem: IoMem<'a, SIZE>, /// The region abstraction. This represents exclusive access to the /// range represented by the underlying `iomem`. @@ -178,9 +172,9 @@ pub struct ExclusiveIoMem<const SIZE: usize> { _region: Region, } -impl<const SIZE: usize> ExclusiveIoMem<SIZE> { +impl<'a, const SIZE: usize> ExclusiveIoMem<'a, SIZE> { /// Creates a new `ExclusiveIoMem` instance. - fn ioremap(resource: &Resource) -> Result<Self> { + fn ioremap(dev: &'a Device<Bound>, resource: &Resource) -> Result<Self> { let start = resource.start(); let size = resource.size(); let name = resource.name().unwrap_or_default(); @@ -194,26 +188,29 @@ impl<const SIZE: usize> ExclusiveIoMem<SIZE> { ) .ok_or(EBUSY)?; - let iomem = IoMem::ioremap(resource)?; + let iomem = IoMem::ioremap(dev, resource)?; - let iomem = ExclusiveIoMem { + Ok(ExclusiveIoMem { iomem, _region: region, - }; - - Ok(iomem) + }) } - /// Creates a new `ExclusiveIoMem` instance from a previously acquired [`IoRequest`]. - pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { - let dev = io_request.device; - let res = io_request.resource; - - Devres::new(dev, Self::ioremap(res)) + /// Consume the `ExclusiveIoMem` and register it as a device-managed resource. + /// + /// The returned `Devres<ExclusiveIoMem<'static, SIZE>>` can outlive the original lifetime + /// `'a`. Access to the I/O memory is revoked when the device is unbound. + pub fn into_devres(self) -> Result<Devres<ExclusiveIoMem<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the + // `ExclusiveIoMem` does not actually outlive the device -- access is revoked and the + // resource is released when the device is unbound. + let iomem: ExclusiveIoMem<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let dev = iomem.iomem.dev; + Devres::new(dev, iomem) } } -impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { +impl<const SIZE: usize> Deref for ExclusiveIoMem<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { @@ -230,12 +227,13 @@ impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { /// /// [`IoMem`] always holds an [`MmioRaw`] instance that holds a valid pointer to the /// start of the I/O memory mapped region. -pub struct IoMem<const SIZE: usize = 0> { +pub struct IoMem<'a, const SIZE: usize = 0> { + dev: &'a Device<Bound>, io: MmioRaw<SIZE>, } -impl<const SIZE: usize> IoMem<SIZE> { - fn ioremap(resource: &Resource) -> Result<Self> { +impl<'a, const SIZE: usize> IoMem<'a, SIZE> { + fn ioremap(dev: &'a Device<Bound>, resource: &Resource) -> Result<Self> { // Note: Some ioremap() implementations use types that depend on the CPU // word width rather than the bus address width. // @@ -267,28 +265,33 @@ impl<const SIZE: usize> IoMem<SIZE> { } let io = MmioRaw::new(addr as usize, size)?; - let io = IoMem { io }; - Ok(io) + Ok(IoMem { dev, io }) } - /// Creates a new `IoMem` instance from a previously acquired [`IoRequest`]. - pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { - let dev = io_request.device; - let res = io_request.resource; - - Devres::new(dev, Self::ioremap(res)) + /// Consume the `IoMem` and register it as a device-managed resource. + /// + /// The returned `Devres<IoMem<'static, SIZE>>` can outlive the original + /// lifetime `'a`. Access to the I/O memory is revoked when the device + /// is unbound. + pub fn into_devres(self) -> Result<Devres<IoMem<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the `IoMem` does not + // actually outlive the device -- access is revoked and the resource is released when the + // device is unbound. + let iomem: IoMem<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let dev = iomem.dev; + Devres::new(dev, iomem) } } -impl<const SIZE: usize> Drop for IoMem<SIZE> { +impl<const SIZE: usize> Drop for IoMem<'_, SIZE> { fn drop(&mut self) { // SAFETY: Safe as by the invariant of `Io`. unsafe { bindings::iounmap(self.io.addr() as *mut c_void) } } } -impl<const SIZE: usize> Deref for IoMem<SIZE> { +impl<const SIZE: usize> Deref for IoMem<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index af74ddff6114..5071cae6543f 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -59,18 +59,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::pci_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct pci_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::pci_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( pdrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -96,7 +96,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( pdev: *mut bindings::pci_dev, id: *const bindings::pci_device_id, @@ -105,7 +105,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct pci_dev`. // // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and // does not add additional invariants, so it's safe to transmute. @@ -125,12 +125,12 @@ impl<T: Driver + 'static> Adapter<T> { // `struct pci_dev`. // // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { pdev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(pdev, data); } @@ -279,19 +279,20 @@ macro_rules! pci_device_table { /// /// impl pci::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE; /// -/// fn probe( -/// _pdev: &pci::Device<Core>, -/// _id_info: &Self::IdInfo, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _pdev: &'bound pci::Device<Core<'_>>, +/// _id_info: &'bound Self::IdInfo, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// } ///``` /// Drivers must implement this trait in order to get a PCI driver registered. Please refer to the /// `Adapter` documentation for an example. -pub trait Driver: Send { +pub trait Driver { /// The type holding information about each device id supported by the driver. // TODO: Use `associated_type_defaults` once stabilized: // @@ -300,6 +301,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; @@ -307,7 +311,10 @@ pub trait Driver: Send { /// /// Called when a new pci device is added or discovered. Implementers should /// attempt to initialize the device here. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// PCI driver unbind. /// @@ -318,8 +325,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -354,7 +361,7 @@ impl Device { /// /// ``` /// # use kernel::{device::Core, pci::{self, Vendor}, prelude::*}; - /// fn log_device_info(pdev: &pci::Device<Core>) -> Result { + /// fn log_device_info(pdev: &pci::Device<Core<'_>>) -> Result { /// // Get an instance of `Vendor`. /// let vendor = pdev.vendor_id(); /// dev_info!( @@ -445,7 +452,7 @@ impl Device { } } -impl Device<device::Core> { +impl<'a> Device<device::Core<'a>> { /// Enable memory resources for this device. pub fn enable_device_mem(&self) -> Result { // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. @@ -471,7 +478,7 @@ unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> kernel::impl_device_context_deref!(unsafe { Device }); kernel::impl_device_context_into_aref!(Device); -impl crate::dma::Device for Device<device::Core> {} +impl<'a> crate::dma::Device<'a> for Device<device::Core<'a>> {} // SAFETY: Instances of `Device` are always reference-counted. unsafe impl crate::sync::aref::AlwaysRefCounted for Device { @@ -523,3 +530,7 @@ unsafe impl Send for Device {} // SAFETY: `Device` can be shared among threads because all methods of `Device` // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} + +// SAFETY: Same as `Device<Normal>` -- the underlying `struct pci_dev` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} diff --git a/rust/kernel/pci/id.rs b/rust/kernel/pci/id.rs index 50005d176561..dbaf301666e7 100644 --- a/rust/kernel/pci/id.rs +++ b/rust/kernel/pci/id.rs @@ -19,7 +19,7 @@ use crate::{ /// /// ``` /// # use kernel::{device::Core, pci::{self, Class}, prelude::*}; -/// fn probe_device(pdev: &pci::Device<Core>) -> Result { +/// fn probe_device(pdev: &pci::Device<Core<'_>>) -> Result { /// let pci_class = pdev.pci_class(); /// dev_info!( /// pdev, diff --git a/rust/kernel/pci/io.rs b/rust/kernel/pci/io.rs index ae78676c927f..0461e01aaa20 100644 --- a/rust/kernel/pci/io.rs +++ b/rust/kernel/pci/io.rs @@ -14,8 +14,7 @@ use crate::{ Mmio, MmioRaw, // }, - prelude::*, - sync::aref::ARef, // + prelude::*, // }; use core::{ marker::PhantomData, @@ -146,14 +145,18 @@ impl<'a, S: ConfigSpaceKind> IoKnownSize for ConfigSpace<'a, S> { /// /// `Bar` always holds an `IoRaw` instance that holds a valid pointer to the start of the I/O /// memory mapped PCI BAR and its size. -pub struct Bar<const SIZE: usize = 0> { - pdev: ARef<Device>, +pub struct Bar<'a, const SIZE: usize = 0> { + pdev: &'a Device<device::Bound>, io: MmioRaw<SIZE>, num: i32, } -impl<const SIZE: usize> Bar<SIZE> { - pub(super) fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> { +impl<'a, const SIZE: usize> Bar<'a, SIZE> { + pub(super) fn new( + pdev: &'a Device<device::Bound>, + num: u32, + name: &'static CStr, + ) -> Result<Self> { let len = pdev.resource_len(num)?; if len == 0 { return Err(ENOMEM); @@ -196,11 +199,7 @@ impl<const SIZE: usize> Bar<SIZE> { } }; - Ok(Bar { - pdev: pdev.into(), - io, - num, - }) + Ok(Bar { pdev, io, num }) } /// # Safety @@ -219,11 +218,24 @@ impl<const SIZE: usize> Bar<SIZE> { fn release(&self) { // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`. - unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) }; + unsafe { Self::do_release(self.pdev, self.io.addr(), self.num) }; + } + + /// Consume the `Bar` and register it as a device-managed resource. + /// + /// The returned `Devres<Bar<'static, SIZE>>` can outlive the original lifetime `'a`. Access + /// to the BAR is revoked when the device is unbound. + pub fn into_devres(self) -> Result<Devres<Bar<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the `Bar` does not + // actually outlive the device -- access is revoked and the resource is released when the + // device is unbound. + let bar: Bar<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let pdev = bar.pdev; + Devres::new(pdev.as_ref(), bar) } } -impl Bar { +impl Bar<'_> { #[inline] pub(super) fn index_is_valid(index: u32) -> bool { // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries. @@ -231,13 +243,13 @@ impl Bar { } } -impl<const SIZE: usize> Drop for Bar<SIZE> { +impl<const SIZE: usize> Drop for Bar<'_, SIZE> { fn drop(&mut self) { self.release(); } } -impl<const SIZE: usize> Deref for Bar<SIZE> { +impl<const SIZE: usize> Deref for Bar<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { @@ -252,17 +264,13 @@ impl Device<device::Bound> { pub fn iomap_region_sized<'a, const SIZE: usize>( &'a self, bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a { - Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name)) + name: &'static CStr, + ) -> Result<Bar<'a, SIZE>> { + Bar::new(self, bar, name) } /// Maps an entire PCI BAR after performing a region-request on it. - pub fn iomap_region<'a>( - &'a self, - bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar>, Error> + 'a { + pub fn iomap_region<'a>(&'a self, bar: u32, name: &'static CStr) -> Result<Bar<'a>> { self.iomap_region_sized::<0>(bar, name) } diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs index 8917d4ee499f..9b362e0495d3 100644 --- a/rust/kernel/platform.rs +++ b/rust/kernel/platform.rs @@ -45,18 +45,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::platform_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct platform_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::platform_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( pdrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -82,7 +82,9 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. - to_result(unsafe { bindings::__platform_driver_register(pdrv.get(), module.0) }) + to_result(unsafe { + bindings::__platform_driver_register(pdrv.get(), module.0, name.as_char_ptr()) + }) } unsafe fn unregister(pdrv: &Opaque<Self::DriverType>) { @@ -91,13 +93,13 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback(pdev: *mut bindings::platform_device) -> kernel::ffi::c_int { // SAFETY: The platform bus only ever calls the probe callback with a valid pointer to a // `struct platform_device`. // // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; let info = <Self as driver::Adapter>::id_info(pdev.as_ref()); from_result(|| { @@ -113,18 +115,18 @@ impl<T: Driver + 'static> Adapter<T> { // `struct platform_device`. // // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { pdev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(pdev, data); } } -impl<T: Driver + 'static> driver::Adapter for Adapter<T> { +impl<T: Driver> driver::Adapter for Adapter<T> { type IdInfo = T::IdInfo; fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { @@ -192,18 +194,19 @@ macro_rules! module_platform_driver { /// /// impl platform::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); /// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); /// -/// fn probe( -/// _pdev: &platform::Device<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _pdev: &'bound platform::Device<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// } ///``` -pub trait Driver: Send { +pub trait Driver { /// The type holding driver private data about each device id supported by the driver. // TODO: Use associated_type_defaults once stabilized: // @@ -212,6 +215,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of OF device ids supported by the driver. const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; @@ -222,10 +228,10 @@ pub trait Driver: Send { /// /// Called when a new platform device is added or discovered. /// Implementers should attempt to initialize the device here. - fn probe( - dev: &Device<device::Core>, - id_info: Option<&Self::IdInfo>, - ) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: Option<&'bound Self::IdInfo>, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// Platform driver unbind. /// @@ -236,8 +242,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -509,7 +515,7 @@ impl Device<Bound> { kernel::impl_device_context_deref!(unsafe { Device }); kernel::impl_device_context_into_aref!(Device); -impl crate::dma::Device for Device<device::Core> {} +impl<'a> crate::dma::Device<'a> for Device<device::Core<'a>> {} // SAFETY: Instances of `Device` are always reference-counted. unsafe impl crate::sync::aref::AlwaysRefCounted for Device { @@ -561,3 +567,7 @@ unsafe impl Send for Device {} // SAFETY: `Device` can be shared among threads because all methods of `Device` // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} + +// SAFETY: Same as `Device<Normal>` -- the underlying `struct platform_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index 4329d3c2c2e5..ac316fd7b538 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -11,6 +11,10 @@ use core::{ }; use pin_init::{PinInit, Wrapper, Zeroable}; +#[doc(hidden)] +pub mod for_lt; +pub use for_lt::ForLt; + /// Used to transfer ownership to and from foreign (non-Rust) languages. /// /// Ownership is transferred from Rust to a foreign language by calling [`Self::into_foreign`] and @@ -27,10 +31,14 @@ pub unsafe trait ForeignOwnable: Sized { const FOREIGN_ALIGN: usize; /// Type used to immutably borrow a value that is currently foreign-owned. - type Borrowed<'a>; + type Borrowed<'a> + where + Self: 'a; /// Type used to mutably borrow a value that is currently foreign-owned. - type BorrowedMut<'a>; + type BorrowedMut<'a> + where + Self: 'a; /// Converts a Rust-owned object to a foreign-owned one. /// diff --git a/rust/kernel/types/for_lt.rs b/rust/kernel/types/for_lt.rs new file mode 100644 index 000000000000..d44323c28e8d --- /dev/null +++ b/rust/kernel/types/for_lt.rs @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Provide implementation and test of the `ForLt` trait and macro. +//! +//! This module is hidden and user should just use `ForLt!` directly. + +use core::marker::PhantomData; + +/// Representation of types generic over a lifetime. +/// +/// The type must be covariant over the generic lifetime, i.e. the lifetime parameter +/// can be soundly shortened. +/// +/// The lifetime involved must be covariant. +/// +/// # Macro +/// +/// It is not recommended to implement this trait directly. `ForLt!` macro is provided to obtain a +/// type that implements this trait. +/// +/// The full syntax is +/// +/// ``` +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # struct TypeThatUse<'a>(&'a ()); +/// # expect_lt::< +/// ForLt!(for<'a> TypeThatUse<'a>) +/// # >(); +/// ``` +/// +/// which gives a type so that `<ForLt!(for<'a> TypeThatUse<'a>) as ForLt>::Of<'b>` +/// is `TypeThatUse<'b>`. +/// +/// You may also use a short-hand syntax which works similar to lifetime elision. +/// The macro also accepts types that do not involve a lifetime at all. +/// +/// ``` +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # struct TypeThatUse<'a>(&'a ()); +/// # expect_lt::< +/// ForLt!(TypeThatUse<'_>) // Equivalent to `ForLt!(for<'a> TypeThatUse<'a>)`. +/// # >(); +/// # expect_lt::< +/// ForLt!(&u32) // Equivalent to `ForLt!(for<'a> &'a u32)`. +/// # >(); +/// # expect_lt::< +/// ForLt!(u32) // Equivalent to `ForLt!(for<'a> u32)`. +/// # >(); +/// ``` +/// +/// The macro will attempt to prove that the type is indeed covariant over the lifetime supplied. +/// When it cannot be syntactically proven, it will emit checks to ask the Rust compiler to prove +/// it. +/// +/// ```ignore,compile_fail +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # expect_lt::< +/// ForLt!(fn(&u32)) // Contravariant, will fail compilation. +/// # >(); +/// ``` +/// +/// There is a limitation if the type refers to generic parameters; if the macro cannot prove the +/// covariance syntactically, the emitted checks will fail the compilation as it needs to refer to +/// the generic parameter but is in a separate item. +/// +/// ``` +/// # use kernel::types::ForLt; +/// fn expect_lt<F: ForLt>() {} +/// # #[allow(clippy::unnecessary_safety_comment, reason = "false positive")] +/// fn generic_fn<T: 'static>() { +/// // Syntactically proven by the macro +/// expect_lt::<ForLt!(&T)>(); +/// // Syntactically proven by the macro +/// expect_lt::<ForLt!(&KBox<T>)>(); +/// // Cannot be syntactically proven, need to check covariance of `KBox` +/// // expect_lt::<ForLt!(&KBox<&T>)>(); +/// } +/// ``` +/// +/// # Safety +/// +/// `Self::Of<'a>` must be covariant over the lifetime `'a`. +pub unsafe trait ForLt { + /// The type parameterized by the lifetime. + type Of<'a>: 'a; + + /// Cast a reference to a shorter lifetime. + #[inline(always)] + fn cast_ref<'r, 'short: 'r, 'long: 'short>(long: &'r Self::Of<'long>) -> &'r Self::Of<'short> { + // SAFETY: This is sound as this trait guarantees covariance. + unsafe { core::mem::transmute(long) } + } +} +pub use macros::ForLt; + +/// This is intended to be an "unsafe-to-refer-to" type. +/// +/// Must only be used by the `ForLt!` macro. +/// +/// `T` is the magic `dyn for<'a> WithLt<'a, TypeThatUse<'a>>` generated by macro. +/// +/// `WF` is a type that the macro can use to assert some specific type is well-formed. +/// +/// `N` is to provide the macro a place to emit arbitrary items, in case it needs to prove +/// additional properties. +#[doc(hidden)] +pub struct UnsafeForLtImpl<T: ?Sized, WF, const N: usize>(PhantomData<(WF, T)>); + +// This is a helper trait for implementation `ForLt` to be able to use HRTB. +#[doc(hidden)] +pub trait WithLt<'a> { + type Of: 'a; +} + +// SAFETY: In `ForLt!` macro, a covariance proof is generated when naming `UnsafeForLtImpl` +// and it will fail to evaluate if the type is not covariant. +unsafe impl<T: ?Sized + for<'a> WithLt<'a>, WF> ForLt for UnsafeForLtImpl<T, WF, 0> { + type Of<'a> = <T as WithLt<'a>>::Of; +} diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index 9c17a672cd27..7aff0c82d0af 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -36,18 +36,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::usb_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct usb_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::usb_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( udrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -73,7 +73,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( intf: *mut bindings::usb_interface, id: *const bindings::usb_device_id, @@ -82,7 +82,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct usb_interface` and `struct usb_device_id`. // // INVARIANT: `intf` is valid for the duration of `probe_callback()`. - let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal<'_>>>() }; from_result(|| { // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct usb_device_id` and @@ -92,7 +92,7 @@ impl<T: Driver + 'static> Adapter<T> { let info = T::ID_TABLE.info(id.index()); let data = T::probe(intf, id, info); - let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + let dev: &device::Device<device::CoreInternal<'_>> = intf.as_ref(); dev.set_drvdata(data)?; Ok(0) }) @@ -103,14 +103,14 @@ impl<T: Driver + 'static> Adapter<T> { // `struct usb_interface`. // // INVARIANT: `intf` is valid for the duration of `disconnect_callback()`. - let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal<'_>>>() }; - let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + let dev: &device::Device<device::CoreInternal<'_>> = intf.as_ref(); // SAFETY: `disconnect_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { dev.drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { dev.drvdata_borrow::<T::Data<'_>>() }; T::disconnect(intf, data); } @@ -287,23 +287,31 @@ macro_rules! usb_device_table { /// /// impl usb::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const ID_TABLE: usb::IdTable<Self::IdInfo> = &USB_TABLE; /// -/// fn probe( -/// _interface: &usb::Interface<Core>, +/// fn probe<'bound>( +/// _interface: &'bound usb::Interface<Core<'_>>, /// _id: &usb::DeviceId, -/// _info: &Self::IdInfo, -/// ) -> impl PinInit<Self, Error> { +/// _info: &'bound Self::IdInfo, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// -/// fn disconnect(_interface: &usb::Interface<Core>, _data: Pin<&Self>) {} +/// fn disconnect<'bound>( +/// _interface: &'bound usb::Interface<Core<'_>>, +/// _data: Pin<&Self::Data<'bound>>, +/// ) { +/// } /// } ///``` pub trait Driver { /// The type holding information about each one of the device ids supported by the driver. type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; @@ -311,16 +319,19 @@ pub trait Driver { /// /// Called when a new USB interface is bound to this driver. /// Implementers should attempt to initialize the interface here. - fn probe( - interface: &Interface<device::Core>, + fn probe<'bound>( + interface: &'bound Interface<device::Core<'_>>, id: &DeviceId, - id_info: &Self::IdInfo, - ) -> impl PinInit<Self, Error>; + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// USB driver disconnect. /// /// Called when the USB interface is about to be unbound from this driver. - fn disconnect(interface: &Interface<device::Core>, data: Pin<&Self>); + fn disconnect<'bound>( + interface: &'bound Interface<device::Core<'_>>, + data: Pin<&Self::Data<'bound>>, + ); } /// A USB interface. @@ -464,6 +475,10 @@ unsafe impl Send for Device {} // allow any mutation through a shared reference. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct usb_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} + /// Declares a kernel module that exposes a single USB driver. /// /// # Examples diff --git a/rust/macros/for_lt.rs b/rust/macros/for_lt.rs new file mode 100644 index 000000000000..364d4113cd10 --- /dev/null +++ b/rust/macros/for_lt.rs @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use proc_macro2::{ + Span, + TokenStream, // +}; +use quote::{ + format_ident, + quote, // +}; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + visit::Visit, + visit_mut::VisitMut, + Lifetime, + Result, + Token, + Type, // +}; + +pub(crate) enum HigherRankedType { + Explicit { + _for_token: Token![for], + _lt_token: Token![<], + lifetime: Lifetime, + _gt_token: Token![>], + ty: Type, + }, + Implicit { + ty: Type, + }, +} + +impl Parse for HigherRankedType { + fn parse(input: ParseStream<'_>) -> Result<Self> { + if input.peek(Token![for]) { + Ok(Self::Explicit { + _for_token: input.parse()?, + _lt_token: input.parse()?, + lifetime: input.parse()?, + _gt_token: input.parse()?, + ty: input.parse()?, + }) + } else { + Ok(Self::Implicit { ty: input.parse()? }) + } + } +} + +trait TypeExt { + fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type; + fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type; + fn has_lifetime(&self, lt: &Lifetime) -> bool; +} + +impl TypeExt for Type { + fn expand_elided_lifetime(&self, explicit_lt: &Lifetime) -> Type { + struct ElidedLifetimeExpander<'a>(&'a Lifetime); + + impl VisitMut for ElidedLifetimeExpander<'_> { + fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) { + // Expand explicit `'_` + if lifetime.ident == "_" { + *lifetime = self.0.clone(); + } + } + + fn visit_type_reference_mut(&mut self, reference: &mut syn::TypeReference) { + syn::visit_mut::visit_type_reference_mut(self, reference); + + if reference.lifetime.is_none() { + reference.lifetime = Some(self.0.clone()); + } + } + } + + let mut ret = self.clone(); + ElidedLifetimeExpander(explicit_lt).visit_type_mut(&mut ret); + ret + } + + fn replace_lifetime(&self, src: &Lifetime, dst: &Lifetime) -> Type { + struct LifetimeReplacer<'a>(&'a Lifetime, &'a Lifetime); + + impl VisitMut for LifetimeReplacer<'_> { + fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) { + if lifetime.ident == self.0.ident { + *lifetime = self.1.clone(); + } + } + } + + let mut ret = self.clone(); + LifetimeReplacer(src, dst).visit_type_mut(&mut ret); + ret + } + + fn has_lifetime(&self, lt: &Lifetime) -> bool { + struct HasLifetime<'a>(&'a Lifetime, bool); + + impl Visit<'_> for HasLifetime<'_> { + fn visit_lifetime(&mut self, lifetime: &Lifetime) { + if lifetime.ident == self.0.ident { + self.1 = true; + } + } + + // Macro invocations are opaque; conservatively assume they may + // reference the lifetime. + fn visit_macro(&mut self, _: &syn::Macro) { + self.1 = true; + } + } + + let mut visitor = HasLifetime(lt, false); + visitor.visit_type(self); + visitor.1 + } +} + +struct Prover<'a>(&'a Lifetime, Vec<&'a Type>); + +impl<'a> Prover<'a> { + /// Prove that `ty` is covariant over `'lt`. + /// + /// This also needs to prove that it'll be wellformed for any instance of `'lt`. + /// It can be assumed that `ty` will be wellformed if `'lt` is substituted to `'static`. + fn prove(&mut self, ty: &'a Type) { + match ty { + Type::Paren(ty) => self.prove(&ty.elem), + Type::Group(ty) => self.prove(&ty.elem), + + // No lifetime involved + Type::Never(_) => {} + + // `[T; N]` and `[T]` is covariant over `T`. + Type::Array(ty) => self.prove(&ty.elem), + Type::Slice(ty) => self.prove(&ty.elem), + + Type::Tuple(ty) => { + for elem in &ty.elems { + self.prove(elem); + } + } + + // `*const T` is covariant over `T` + Type::Ptr(ty) if ty.const_token.is_some() => self.prove(&ty.elem), + + // `&T` is covariant over `T` and lifetime. + // + // Note that if we encounter `&'other_lt T`, then we still need to make sure the type + // is wellformed if `T` involves `&'lt`, so we defer to the compiler. + // + // This is to block cases like `ForLt!(for<'a> &'static &'a u32)`, as the presence of + // the type implies `'a: 'static` but this is unsound. + Type::Reference(ty) + if ty.mutability.is_none() && ty.lifetime.as_ref() == Some(self.0) => + { + self.prove(&ty.elem) + } + + // `&[mut] T` is covariant over lifetime. + // In case we have `&[mut] NoLifetime`, we don't need to do additional checks. + Type::Reference(ty) if !ty.elem.has_lifetime(self.0) => (), + + // No mention of lifetime at all, no need to perform compiler check. + ty if !ty.has_lifetime(self.0) => (), + + // Otherwise, we need to emit checks so that compiler can determine if the types are + // actually covariant. + ty => self.1.push(ty), + } + } +} + +pub(crate) fn for_lt(input: HigherRankedType) -> TokenStream { + let (ty, lifetime) = match input { + HigherRankedType::Explicit { lifetime, ty, .. } => (ty, lifetime), + HigherRankedType::Implicit { ty } => { + // If there's no explicit `for<'a>` binder, inject a synthetic `'__elided` lifetime + // and expand elided sites. + let lifetime = Lifetime { + apostrophe: Span::mixed_site(), + ident: format_ident!("__elided", span = Span::mixed_site()), + }; + (ty.expand_elided_lifetime(&lifetime), lifetime) + } + }; + + let mut prover = Prover(&lifetime, Vec::new()); + prover.prove(&ty); + + let mut proof = Vec::new(); + + // Emit proofs for every type that requires additional compiler help in proving covariance. + for (idx, required_proof) in prover.1.into_iter().enumerate() { + // Insert a proof that the type is well-formed. + // + // This is intended to workaround a Rust compiler soundness bug related to HRTB. + // https://github.com/rust-lang/rust/issues/152489 + // + // This needs to be a struct instead of fn to avoid the implied WF bounds. + let wf_proof_name = format_ident!("ProveWf{idx}"); + proof.push(quote!( + struct #wf_proof_name<#lifetime>( + ::core::marker::PhantomData<&#lifetime ()>, #required_proof + ); + )); + + // Insert a proof that the type is covariant. + let cov_proof_name = format_ident!("prove_covariant_{idx}"); + proof.push(quote!( + fn #cov_proof_name<'__short, '__long: '__short>( + long: #wf_proof_name<'__long> + ) -> #wf_proof_name<'__short> { + long + } + )); + } + + // Make sure that the type is wellformed when substituting lifetime with `'static`. + // + // Currently the Rust compiler doesn't check this, see the above `ProveWf` documentation. + // + // We prefer to use this way of proving WF-ness as it can work when generics are involved. + let ty_static = ty.replace_lifetime( + &lifetime, + &Lifetime { + apostrophe: Span::mixed_site(), + ident: format_ident!("static"), + }, + ); + + quote!( + ::kernel::types::for_lt::UnsafeForLtImpl::< + dyn for<#lifetime> ::kernel::types::for_lt::WithLt<#lifetime, Of = #ty>, + #ty_static, + { + #(#proof)* + + 0 + } + > + ) +} diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index 2cfd59e0f9e7..4a48fabbc268 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -17,6 +17,7 @@ mod concat_idents; mod export; mod fmt; +mod for_lt; mod helpers; mod kunit; mod module; @@ -489,3 +490,15 @@ pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { .unwrap_or_else(|e| e.into_compile_error()) .into() } + +/// Obtain a type that implements [`ForLt`] for the given higher-ranked type. +/// +/// Please refer to the documentation of the [`ForLt`] trait. +/// +/// [`ForLt`]: trait.ForLt.html +#[proc_macro] +// The macro shares the name with the trait. +#[allow(non_snake_case)] +pub fn ForLt(input: TokenStream) -> TokenStream { + for_lt::for_lt(parse_macro_input!(input)).into() +} |
