// SPDX-License-Identifier: GPL-2.0-or-later /* * Copyright (C) 2017, Microsoft Corporation. * Copyright (c) 2025, Stefan Metzmacher */ #include "internal.h" bool smbdirect_frwr_is_supported(const struct ib_device_attr *attrs) { /* * Test if FRWR (Fast Registration Work Requests) is supported on the * device This implementation requires FRWR on RDMA read/write return * value: true if it is supported */ if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) return false; if (attrs->max_fast_reg_page_list_len == 0) return false; return true; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_frwr_is_supported); static void smbdirect_socket_cleanup_work(struct work_struct *work); static int smbdirect_socket_rdma_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) { struct smbdirect_socket *sc = id->context; int ret = -ESTALE; /* * This should be replaced before any real work * starts! So it should never be called! */ if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL) ret = -ENETDOWN; if (IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(event->status))) ret = event->status; pr_err("%s (first_error=%1pe, expected=%s) => event=%s status=%d => ret=%1pe\n", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error), rdma_event_msg(sc->rdma.expected_event), rdma_event_msg(event->event), event->status, SMBDIRECT_DEBUG_ERR_PTR(ret)); WARN_ONCE(1, "%s should not be called!\n", __func__); sc->rdma.cm_id = NULL; return -ESTALE; } int smbdirect_socket_init_new(struct net *net, struct smbdirect_socket *sc) { struct rdma_cm_id *id; int ret; smbdirect_socket_init(sc); id = rdma_create_id(net, smbdirect_socket_rdma_event_handler, sc, RDMA_PS_TCP, IB_QPT_RC); if (IS_ERR(id)) { pr_err("%s: rdma_create_id() failed %1pe\n", __func__, id); return PTR_ERR(id); } ret = rdma_set_afonly(id, 1); if (ret) { rdma_destroy_id(id); pr_err("%s: rdma_set_afonly() failed %1pe\n", __func__, SMBDIRECT_DEBUG_ERR_PTR(ret)); return ret; } sc->rdma.cm_id = id; INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work); return 0; } int smbdirect_socket_create_kern(struct net *net, struct smbdirect_socket **_sc) { struct smbdirect_socket *sc; int ret; ret = -ENOMEM; sc = kzalloc_obj(*sc); if (!sc) goto alloc_failed; ret = smbdirect_socket_init_new(net, sc); if (ret) goto init_failed; kref_init(&sc->refs.destroy); *_sc = sc; return 0; init_failed: kfree(sc); alloc_failed: return ret; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_kern); int smbdirect_socket_init_accepting(struct rdma_cm_id *id, struct smbdirect_socket *sc) { smbdirect_socket_init(sc); sc->rdma.cm_id = id; sc->rdma.cm_id->context = sc; sc->rdma.cm_id->event_handler = smbdirect_socket_rdma_event_handler; sc->ib.dev = sc->rdma.cm_id->device; INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work); return 0; } int smbdirect_socket_create_accepting(struct rdma_cm_id *id, struct smbdirect_socket **_sc) { struct smbdirect_socket *sc; int ret; ret = -ENOMEM; sc = kzalloc_obj(*sc); if (!sc) goto alloc_failed; ret = smbdirect_socket_init_accepting(id, sc); if (ret) goto init_failed; kref_init(&sc->refs.destroy); *_sc = sc; return 0; init_failed: kfree(sc); alloc_failed: return ret; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_accepting); int smbdirect_socket_set_initial_parameters(struct smbdirect_socket *sc, const struct smbdirect_socket_parameters *sp) { /* * This is only allowed before connect or accept */ WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); if (sc->status != SMBDIRECT_SOCKET_CREATED) return -EINVAL; if (sp->flags & ~SMBDIRECT_FLAG_PORT_RANGE_MASK) return -EINVAL; if (sp->initiator_depth > U8_MAX) return -EINVAL; if (sp->responder_resources > U8_MAX) return -EINVAL; if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB && sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW) return -EINVAL; else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB) rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_IB_CA); else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW) rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_RNIC); /* * Make a copy of the callers parameters * from here we only work on the copy * * TODO: do we want consistency checking? */ sc->parameters = *sp; return 0; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_initial_parameters); const struct smbdirect_socket_parameters * smbdirect_socket_get_current_parameters(struct smbdirect_socket *sc) { return &sc->parameters; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_get_current_parameters); int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc, enum ib_poll_context poll_ctx, gfp_t gfp_mask) { /* * This is only allowed before connect or accept */ WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); if (sc->status != SMBDIRECT_SOCKET_CREATED) return -EINVAL; sc->ib.poll_ctx = poll_ctx; sc->send_io.mem.gfp_mask = gfp_mask; sc->recv_io.mem.gfp_mask = gfp_mask; sc->rw_io.mem.gfp_mask = gfp_mask; return 0; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_kernel_settings); void smbdirect_socket_set_logging(struct smbdirect_socket *sc, void *private_ptr, bool (*needed)(struct smbdirect_socket *sc, void *private_ptr, unsigned int lvl, unsigned int cls), void (*vaprintf)(struct smbdirect_socket *sc, const char *func, unsigned int line, void *private_ptr, unsigned int lvl, unsigned int cls, struct va_format *vaf)) { sc->logging.private_ptr = private_ptr; sc->logging.needed = needed; sc->logging.vaprintf = vaprintf; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_logging); static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc) { /* * Wake up all waiters in all wait queues * in order to notice the broken connection. */ wake_up_all(&sc->status_wait); wake_up_all(&sc->listen.wait_queue); wake_up_all(&sc->send_io.bcredits.wait_queue); wake_up_all(&sc->send_io.lcredits.wait_queue); wake_up_all(&sc->send_io.credits.wait_queue); wake_up_all(&sc->send_io.pending.zero_wait_queue); wake_up_all(&sc->recv_io.reassembly.wait_queue); wake_up_all(&sc->rw_io.credits.wait_queue); wake_up_all(&sc->mr_io.ready.wait_queue); } void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc, const char *macro_name, unsigned int lvl, const char *func, unsigned int line, int error, enum smbdirect_socket_status *force_status) { struct smbdirect_socket *psc, *tsc; unsigned long flags; bool was_first = false; if (!sc->first_error) { ___smbdirect_log_generic(sc, func, line, lvl, SMBDIRECT_LOG_RDMA_EVENT, "%s(%1pe%s%s) called from %s in line=%u status=%s\n", macro_name, SMBDIRECT_DEBUG_ERR_PTR(error), force_status ? ", " : "", force_status ? smbdirect_socket_status_string(*force_status) : "", func, line, smbdirect_socket_status_string(sc->status)); if (error) sc->first_error = error; else sc->first_error = -ECONNABORTED; was_first = true; } /* * make sure other work (than disconnect_work) * is not queued again but here we don't block and avoid * disable[_delayed]_work_sync() */ disable_work(&sc->connect.work); disable_work(&sc->recv_io.posted.refill_work); disable_work(&sc->idle.immediate_work); sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE; disable_delayed_work(&sc->idle.timer_work); /* * In case we were a listener we need to * disconnect all pending and ready sockets * * First we move ready sockets to pending again. */ spin_lock_irqsave(&sc->listen.lock, flags); list_splice_init(&sc->listen.ready, &sc->listen.pending); list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list) smbdirect_socket_schedule_cleanup(psc, sc->first_error); spin_unlock_irqrestore(&sc->listen.lock, flags); switch (sc->status) { case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED: case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED: case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED: case SMBDIRECT_SOCKET_NEGOTIATE_FAILED: case SMBDIRECT_SOCKET_ERROR: case SMBDIRECT_SOCKET_DISCONNECTING: case SMBDIRECT_SOCKET_DISCONNECTED: case SMBDIRECT_SOCKET_DESTROYED: /* * Keep the current error status */ break; case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED: case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING: sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED; break; case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED: case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING: sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED; break; case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED: case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING: sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED; break; case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED: case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING: sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED; break; case SMBDIRECT_SOCKET_CREATED: case SMBDIRECT_SOCKET_LISTENING: sc->status = SMBDIRECT_SOCKET_DISCONNECTED; break; case SMBDIRECT_SOCKET_CONNECTED: sc->status = SMBDIRECT_SOCKET_ERROR; break; } if (force_status && (was_first || *force_status > sc->status)) sc->status = *force_status; /* * Wake up all waiters in all wait queues * in order to notice the broken connection. */ smbdirect_socket_wake_up_all(sc); queue_work(sc->workqueues.cleanup, &sc->disconnect_work); } static void smbdirect_socket_cleanup_work(struct work_struct *work) { struct smbdirect_socket *sc = container_of(work, struct smbdirect_socket, disconnect_work); struct smbdirect_socket *psc, *tsc; unsigned long flags; /* * This should not never be called in an interrupt! */ WARN_ON_ONCE(in_interrupt()); if (!sc->first_error) { smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR, "%s called with first_error==0\n", smbdirect_socket_status_string(sc->status)); sc->first_error = -ECONNABORTED; } /* * make sure this and other work is not queued again * but here we don't block and avoid * disable[_delayed]_work_sync() */ disable_work(&sc->disconnect_work); disable_work(&sc->connect.work); disable_work(&sc->recv_io.posted.refill_work); disable_work(&sc->idle.immediate_work); sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE; disable_delayed_work(&sc->idle.timer_work); /* * In case we were a listener we need to * disconnect all pending and ready sockets * * First we move ready sockets to pending again. */ spin_lock_irqsave(&sc->listen.lock, flags); list_splice_init(&sc->listen.ready, &sc->listen.pending); list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list) smbdirect_socket_schedule_cleanup(psc, sc->first_error); spin_unlock_irqrestore(&sc->listen.lock, flags); switch (sc->status) { case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED: case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING: case SMBDIRECT_SOCKET_NEGOTIATE_FAILED: case SMBDIRECT_SOCKET_CONNECTED: case SMBDIRECT_SOCKET_ERROR: sc->status = SMBDIRECT_SOCKET_DISCONNECTING; /* * Make sure we hold the callback lock * im order to coordinate with the * rdma_event handlers, typically * smbdirect_connection_rdma_event_handler(), * and smbdirect_socket_destroy(). * * So that the order of ib_drain_qp() * and rdma_disconnect() is controlled * by the mutex. */ rdma_lock_handler(sc->rdma.cm_id); rdma_disconnect(sc->rdma.cm_id); rdma_unlock_handler(sc->rdma.cm_id); break; case SMBDIRECT_SOCKET_CREATED: case SMBDIRECT_SOCKET_LISTENING: case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED: case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING: case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED: case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED: case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING: case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED: case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED: case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING: case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED: /* * rdma_{accept,connect}() never reached * RDMA_CM_EVENT_ESTABLISHED */ sc->status = SMBDIRECT_SOCKET_DISCONNECTED; break; case SMBDIRECT_SOCKET_DISCONNECTING: case SMBDIRECT_SOCKET_DISCONNECTED: case SMBDIRECT_SOCKET_DESTROYED: break; } /* * Wake up all waiters in all wait queues * in order to notice the broken connection. */ smbdirect_socket_wake_up_all(sc); } static void smbdirect_socket_destroy(struct smbdirect_socket *sc) { struct smbdirect_socket *psc, *tsc; size_t psockets; struct smbdirect_recv_io *recv_io; struct smbdirect_recv_io *recv_tmp; LIST_HEAD(all_list); unsigned long flags; smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); /* * This should not never be called in an interrupt! */ WARN_ON_ONCE(in_interrupt()); if (sc->status == SMBDIRECT_SOCKET_DESTROYED) return; WARN_ONCE(sc->status != SMBDIRECT_SOCKET_DISCONNECTED, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); /* * The listener should clear this before we reach this */ WARN_ONCE(sc->accept.listener, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); /* * Wake up all waiters in all wait queues * in order to notice the broken connection. * * Most likely this was already called via * smbdirect_socket_cleanup_work(), but call it again... */ smbdirect_socket_wake_up_all(sc); disable_work_sync(&sc->disconnect_work); disable_work_sync(&sc->connect.work); disable_work_sync(&sc->recv_io.posted.refill_work); disable_work_sync(&sc->idle.immediate_work); disable_delayed_work_sync(&sc->idle.timer_work); if (sc->rdma.cm_id) rdma_lock_handler(sc->rdma.cm_id); if (sc->ib.qp) { smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "drain qp\n"); ib_drain_qp(sc->ib.qp); } /* * In case we were a listener we need to * disconnect all pending and ready sockets * * We move ready sockets to pending again. */ spin_lock_irqsave(&sc->listen.lock, flags); list_splice_tail_init(&sc->listen.ready, &all_list); list_splice_tail_init(&sc->listen.pending, &all_list); spin_unlock_irqrestore(&sc->listen.lock, flags); psockets = list_count_nodes(&all_list); if (sc->listen.backlog != -1) /* was a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "release %zu pending sockets\n", psockets); list_for_each_entry_safe(psc, tsc, &all_list, accept.list) { list_del_init(&psc->accept.list); psc->accept.listener = NULL; smbdirect_socket_release(psc); } if (sc->listen.backlog != -1) /* was a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "released %zu pending sockets\n", psockets); INIT_LIST_HEAD(&all_list); /* It's not possible for upper layer to get to reassembly */ if (sc->listen.backlog == -1) /* was not a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "drain the reassembly queue\n"); spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags); list_splice_tail_init(&sc->recv_io.reassembly.list, &all_list); spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags); list_for_each_entry_safe(recv_io, recv_tmp, &all_list, list) smbdirect_connection_put_recv_io(recv_io); sc->recv_io.reassembly.data_length = 0; if (sc->listen.backlog == -1) /* was not a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "freeing mr list\n"); smbdirect_connection_destroy_mr_list(sc); if (sc->listen.backlog == -1) /* was not a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "destroying qp\n"); smbdirect_connection_destroy_qp(sc); if (sc->rdma.cm_id) { rdma_unlock_handler(sc->rdma.cm_id); smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "destroying cm_id\n"); rdma_destroy_id(sc->rdma.cm_id); sc->rdma.cm_id = NULL; } if (sc->listen.backlog == -1) /* was not a listener */ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "destroying mem pools\n"); smbdirect_connection_destroy_mem_pools(sc); sc->status = SMBDIRECT_SOCKET_DESTROYED; smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "rdma session destroyed\n"); } void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc) { smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); /* * This should not never be called in an interrupt! */ WARN_ON_ONCE(in_interrupt()); /* * First we try to disable the work * without disable_work_sync() in a * non blocking way, if it's already * running it will be handles by * disable_work_sync() below. * * Here we just want to make sure queue_work() in * smbdirect_socket_schedule_cleanup_lvl() * is a no-op. */ disable_work(&sc->disconnect_work); if (!sc->first_error) /* * SMBDIRECT_LOG_INFO is enough here * as this is the typical case where * we terminate the connection ourself. */ smbdirect_socket_schedule_cleanup_lvl(sc, SMBDIRECT_LOG_INFO, -ESHUTDOWN); smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "cancelling and disable disconnect_work\n"); disable_work_sync(&sc->disconnect_work); smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "destroying rdma session\n"); if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING) smbdirect_socket_cleanup_work(&sc->disconnect_work); if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) { smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "wait for transport being disconnected\n"); wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED); smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "waited for transport being disconnected\n"); } /* * Once we reached SMBDIRECT_SOCKET_DISCONNECTED, * we should call smbdirect_socket_destroy() */ smbdirect_socket_destroy(sc); smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, "status=%s first_error=%1pe", smbdirect_socket_status_string(sc->status), SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)); } int smbdirect_socket_bind(struct smbdirect_socket *sc, struct sockaddr *addr) { int ret; if (sc->status != SMBDIRECT_SOCKET_CREATED) return -EINVAL; ret = rdma_bind_addr(sc->rdma.cm_id, addr); if (ret) return ret; return 0; } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_bind); void smbdirect_socket_shutdown(struct smbdirect_socket *sc) { smbdirect_socket_schedule_cleanup_lvl(sc, SMBDIRECT_LOG_INFO, -ESHUTDOWN); } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_shutdown); static void smbdirect_socket_release_disconnect(struct kref *kref) { struct smbdirect_socket *sc = container_of(kref, struct smbdirect_socket, refs.disconnect); /* * For now do a sync disconnect/destroy */ smbdirect_socket_destroy_sync(sc); } static void smbdirect_socket_release_destroy(struct kref *kref) { struct smbdirect_socket *sc = container_of(kref, struct smbdirect_socket, refs.destroy); /* * Do a sync disconnect/destroy... * hopefully a no-op, as it should be already * in DESTROYED state, before we free the memory. */ smbdirect_socket_destroy_sync(sc); kfree(sc); } void smbdirect_socket_release(struct smbdirect_socket *sc) { /* * We expect only 1 disconnect reference * and if it is already 0, it's a use after free! */ WARN_ON_ONCE(kref_read(&sc->refs.disconnect) != 1); WARN_ON(!kref_put(&sc->refs.disconnect, smbdirect_socket_release_disconnect)); /* * This may not trigger smbdirect_socket_release_destroy(), * if struct smbdirect_socket is embedded in another structure * indicated by REFCOUNT_MAX. */ kref_put(&sc->refs.destroy, smbdirect_socket_release_destroy); } __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_release); int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc, enum smbdirect_socket_status expected_status, int unexpected_errno, wait_queue_head_t *waitq, atomic_t *total_credits, int needed) { int ret; if (WARN_ON_ONCE(needed < 0)) return -EINVAL; do { if (atomic_sub_return(needed, total_credits) >= 0) return 0; atomic_add(needed, total_credits); ret = wait_event_interruptible(*waitq, atomic_read(total_credits) >= needed || sc->status != expected_status); if (sc->status != expected_status) return unexpected_errno; else if (ret < 0) return ret; } while (true); }