/* SPDX-License-Identifier: GPL-2.0-or-later */ /* * Copyright (c) 2025 Stefan Metzmacher */ #ifndef __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_SOCKET_H__ #define __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_SOCKET_H__ #include #include #include #include #include #include #include #include enum smbdirect_socket_status { SMBDIRECT_SOCKET_CREATED, SMBDIRECT_SOCKET_LISTENING, SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED, SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING, SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED, SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED, SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING, SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED, SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED, SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING, SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED, SMBDIRECT_SOCKET_NEGOTIATE_NEEDED, SMBDIRECT_SOCKET_NEGOTIATE_RUNNING, SMBDIRECT_SOCKET_NEGOTIATE_FAILED, SMBDIRECT_SOCKET_CONNECTED, SMBDIRECT_SOCKET_ERROR, SMBDIRECT_SOCKET_DISCONNECTING, SMBDIRECT_SOCKET_DISCONNECTED, SMBDIRECT_SOCKET_DESTROYED }; static __always_inline const char *smbdirect_socket_status_string(enum smbdirect_socket_status status) { switch (status) { case SMBDIRECT_SOCKET_CREATED: return "CREATED"; case SMBDIRECT_SOCKET_LISTENING: return "LISTENING"; case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED: return "RESOLVE_ADDR_NEEDED"; case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING: return "RESOLVE_ADDR_RUNNING"; case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED: return "RESOLVE_ADDR_FAILED"; case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED: return "RESOLVE_ROUTE_NEEDED"; case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING: return "RESOLVE_ROUTE_RUNNING"; case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED: return "RESOLVE_ROUTE_FAILED"; case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED: return "RDMA_CONNECT_NEEDED"; case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING: return "RDMA_CONNECT_RUNNING"; case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED: return "RDMA_CONNECT_FAILED"; case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED: return "NEGOTIATE_NEEDED"; case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING: return "NEGOTIATE_RUNNING"; case SMBDIRECT_SOCKET_NEGOTIATE_FAILED: return "NEGOTIATE_FAILED"; case SMBDIRECT_SOCKET_CONNECTED: return "CONNECTED"; case SMBDIRECT_SOCKET_ERROR: return "ERROR"; case SMBDIRECT_SOCKET_DISCONNECTING: return "DISCONNECTING"; case SMBDIRECT_SOCKET_DISCONNECTED: return "DISCONNECTED"; case SMBDIRECT_SOCKET_DESTROYED: return "DESTROYED"; } return ""; } /* * This can be used with %1pe to print errors as strings or '0' * And it avoids warnings like: warn: passing zero to 'ERR_PTR' * from smatch -p=kernel --pedantic */ static __always_inline const void * __must_check SMBDIRECT_DEBUG_ERR_PTR(long error) { if (error == 0) return NULL; return ERR_PTR(error); } enum smbdirect_keepalive_status { SMBDIRECT_KEEPALIVE_NONE, SMBDIRECT_KEEPALIVE_PENDING, SMBDIRECT_KEEPALIVE_SENT }; struct smbdirect_socket { enum smbdirect_socket_status status; wait_queue_head_t status_wait; int first_error; /* * This points to the workqueues to * be used for this socket. */ struct { struct workqueue_struct *accept; struct workqueue_struct *connect; struct workqueue_struct *idle; struct workqueue_struct *refill; struct workqueue_struct *immediate; struct workqueue_struct *cleanup; } workqueues; struct work_struct disconnect_work; /* * The reference counts. */ struct { /* * This holds the references by the * frontend, typically the smb layer. * * It is typically 1 and a disconnect * will happen if it reaches 0. */ struct kref disconnect; /* * This holds the reference by the * backend, the code that manages * the lifetime of the whole * struct smbdirect_socket, * if this reaches 0 it can will * be freed. * * Can be REFCOUNT_MAX is part * of another structure. * * This is equal or higher than * the disconnect refcount. */ struct kref destroy; } refs; /* RDMA related */ struct { struct rdma_cm_id *cm_id; /* * The expected event in our current * cm_id->event_handler, all other events * are treated as an error. */ enum rdma_cm_event_type expected_event; /* * This is for iWarp MPA v1 */ bool legacy_iwarp; } rdma; /* IB verbs related */ struct { struct ib_pd *pd; enum ib_poll_context poll_ctx; struct ib_cq *send_cq; struct ib_cq *recv_cq; /* * shortcuts for rdma.cm_id->{qp,device}; */ struct ib_qp *qp; struct ib_device *dev; } ib; struct smbdirect_socket_parameters parameters; /* * The state for connect/negotiation */ struct { spinlock_t lock; struct work_struct work; } connect; /* * The state for keepalive and timeout handling */ struct { enum smbdirect_keepalive_status keepalive; struct work_struct immediate_work; struct delayed_work timer_work; } idle; /* * The state for listen sockets */ struct { spinlock_t lock; struct list_head pending; struct list_head ready; wait_queue_head_t wait_queue; /* * This starts as -1 and a value != -1 * means this socket was in LISTENING state * before. Note the valid backlog can * only be > 0. */ int backlog; } listen; /* * The state for sockets waiting * for accept, either still waiting * for the negotiation to finish * or already ready with a usable * connection. */ struct { struct smbdirect_socket *listener; struct list_head list; } accept; /* * The state for posted send buffers */ struct { /* * Memory pools for preallocating * smbdirect_send_io buffers */ struct { struct kmem_cache *cache; mempool_t *pool; gfp_t gfp_mask; } mem; /* * This is a coordination for smbdirect_send_batch. * * There's only one possible credit, which means * only one instance is running at a time. */ struct { atomic_t count; wait_queue_head_t wait_queue; } bcredits; /* * The local credit state for ib_post_send() */ struct { atomic_t count; wait_queue_head_t wait_queue; } lcredits; /* * The remote credit state for the send side */ struct { atomic_t count; wait_queue_head_t wait_queue; } credits; /* * The state about posted/pending sends */ struct { atomic_t count; /* * woken when count reached zero */ wait_queue_head_t zero_wait_queue; } pending; } send_io; /* * The state for posted receive buffers */ struct { /* * The type of PDU we are expecting */ enum { SMBDIRECT_EXPECT_NEGOTIATE_REQ = 1, SMBDIRECT_EXPECT_NEGOTIATE_REP = 2, SMBDIRECT_EXPECT_DATA_TRANSFER = 3, } expected; /* * Memory pools for preallocating * smbdirect_recv_io buffers */ struct { struct kmem_cache *cache; mempool_t *pool; gfp_t gfp_mask; } mem; /* * The list of free smbdirect_recv_io * structures */ struct { struct list_head list; spinlock_t lock; } free; /* * The state for posted recv_io messages * and the refill work struct. */ struct { atomic_t count; struct work_struct refill_work; } posted; /* * The credit state for the recv side */ struct { u16 target; atomic_t available; atomic_t count; } credits; /* * The list of arrived non-empty smbdirect_recv_io * structures * * This represents the reassembly queue. */ struct { struct list_head list; spinlock_t lock; wait_queue_head_t wait_queue; /* total data length of reassembly queue */ int data_length; int queue_length; /* the offset to first buffer in reassembly queue */ int first_entry_offset; /* * Indicate if we have received a full packet on the * connection This is used to identify the first SMBD * packet of a assembled payload (SMB packet) in * reassembly queue so we can return a RFC1002 length to * upper layer to indicate the length of the SMB packet * received */ bool full_packet_received; } reassembly; } recv_io; /* * The state for Memory registrations on the client */ struct { enum ib_mr_type type; /* * The list of free smbdirect_mr_io * structures */ struct { struct list_head list; spinlock_t lock; } all; /* * The number of available MRs ready for memory registration */ struct { atomic_t count; wait_queue_head_t wait_queue; } ready; /* * The number of used MRs */ struct { atomic_t count; } used; } mr_io; /* * The state for RDMA read/write requests on the server */ struct { /* * Memory hints for * smbdirect_rw_io structs */ struct { gfp_t gfp_mask; } mem; /* * The credit state for the send side */ struct { /* * The maximum number of rw credits */ size_t max; /* * The number of pages per credit */ size_t num_pages; atomic_t count; wait_queue_head_t wait_queue; } credits; } rw_io; /* * For debug purposes */ struct { u64 get_receive_buffer; u64 put_receive_buffer; u64 enqueue_reassembly_queue; u64 dequeue_reassembly_queue; u64 send_empty; } statistics; struct { void *private_ptr; bool (*needed)(struct smbdirect_socket *sc, void *private_ptr, unsigned int lvl, unsigned int cls); void (*vaprintf)(struct smbdirect_socket *sc, const char *func, unsigned int line, void *private_ptr, unsigned int lvl, unsigned int cls, struct va_format *vaf); } logging; }; static void __smbdirect_socket_disabled_work(struct work_struct *work) { /* * Should never be called as disable_[delayed_]work_sync() was used. */ WARN_ON_ONCE(1); } static bool __smbdirect_log_needed(struct smbdirect_socket *sc, void *private_ptr, unsigned int lvl, unsigned int cls) { /* * Should never be called, the caller should * set it's own functions. */ WARN_ON_ONCE(1); return false; } static void __smbdirect_log_vaprintf(struct smbdirect_socket *sc, const char *func, unsigned int line, void *private_ptr, unsigned int lvl, unsigned int cls, struct va_format *vaf) { /* * Should never be called, the caller should * set it's own functions. */ WARN_ON_ONCE(1); } __printf(6, 7) static void __smbdirect_log_printf(struct smbdirect_socket *sc, const char *func, unsigned int line, unsigned int lvl, unsigned int cls, const char *fmt, ...); __maybe_unused static void __smbdirect_log_printf(struct smbdirect_socket *sc, const char *func, unsigned int line, unsigned int lvl, unsigned int cls, const char *fmt, ...) { struct va_format vaf; va_list args; va_start(args, fmt); vaf.fmt = fmt; vaf.va = &args; sc->logging.vaprintf(sc, func, line, sc->logging.private_ptr, lvl, cls, &vaf); va_end(args); } #define ___smbdirect_log_generic(sc, func, line, lvl, cls, fmt, args...) do { \ if (sc->logging.needed(sc, sc->logging.private_ptr, lvl, cls)) { \ __smbdirect_log_printf(sc, func, line, lvl, cls, fmt, ##args); \ } \ } while (0) #define __smbdirect_log_generic(sc, lvl, cls, fmt, args...) \ ___smbdirect_log_generic(sc, __func__, __LINE__, lvl, cls, fmt, ##args) #define smbdirect_log_outgoing(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_OUTGOING, fmt, ##args) #define smbdirect_log_incoming(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_INCOMING, fmt, ##args) #define smbdirect_log_read(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_READ, fmt, ##args) #define smbdirect_log_write(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_WRITE, fmt, ##args) #define smbdirect_log_rdma_send(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_SEND, fmt, ##args) #define smbdirect_log_rdma_recv(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_RECV, fmt, ##args) #define smbdirect_log_keep_alive(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_KEEP_ALIVE, fmt, ##args) #define smbdirect_log_rdma_event(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_EVENT, fmt, ##args) #define smbdirect_log_rdma_mr(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_MR, fmt, ##args) #define smbdirect_log_rdma_rw(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_RDMA_RW, fmt, ##args) #define smbdirect_log_negotiate(sc, lvl, fmt, args...) \ __smbdirect_log_generic(sc, lvl, SMBDIRECT_LOG_NEGOTIATE, fmt, ##args) static __always_inline void smbdirect_socket_init(struct smbdirect_socket *sc) { /* * This also sets status = SMBDIRECT_SOCKET_CREATED */ BUILD_BUG_ON(SMBDIRECT_SOCKET_CREATED != 0); memset(sc, 0, sizeof(*sc)); init_waitqueue_head(&sc->status_wait); sc->workqueues.accept = smbdirect_globals.workqueues.accept; sc->workqueues.connect = smbdirect_globals.workqueues.connect; sc->workqueues.idle = smbdirect_globals.workqueues.idle; sc->workqueues.refill = smbdirect_globals.workqueues.refill; sc->workqueues.immediate = smbdirect_globals.workqueues.immediate; sc->workqueues.cleanup = smbdirect_globals.workqueues.cleanup; INIT_WORK(&sc->disconnect_work, __smbdirect_socket_disabled_work); disable_work_sync(&sc->disconnect_work); kref_init(&sc->refs.disconnect); sc->refs.destroy = (struct kref) KREF_INIT(REFCOUNT_MAX); sc->rdma.expected_event = RDMA_CM_EVENT_INTERNAL; sc->ib.poll_ctx = IB_POLL_UNBOUND_WORKQUEUE; spin_lock_init(&sc->connect.lock); INIT_WORK(&sc->connect.work, __smbdirect_socket_disabled_work); disable_work_sync(&sc->connect.work); INIT_WORK(&sc->idle.immediate_work, __smbdirect_socket_disabled_work); disable_work_sync(&sc->idle.immediate_work); INIT_DELAYED_WORK(&sc->idle.timer_work, __smbdirect_socket_disabled_work); disable_delayed_work_sync(&sc->idle.timer_work); spin_lock_init(&sc->listen.lock); INIT_LIST_HEAD(&sc->listen.pending); INIT_LIST_HEAD(&sc->listen.ready); sc->listen.backlog = -1; /* not a listener */ init_waitqueue_head(&sc->listen.wait_queue); INIT_LIST_HEAD(&sc->accept.list); sc->send_io.mem.gfp_mask = GFP_KERNEL; atomic_set(&sc->send_io.bcredits.count, 0); init_waitqueue_head(&sc->send_io.bcredits.wait_queue); atomic_set(&sc->send_io.lcredits.count, 0); init_waitqueue_head(&sc->send_io.lcredits.wait_queue); atomic_set(&sc->send_io.credits.count, 0); init_waitqueue_head(&sc->send_io.credits.wait_queue); atomic_set(&sc->send_io.pending.count, 0); init_waitqueue_head(&sc->send_io.pending.zero_wait_queue); sc->recv_io.mem.gfp_mask = GFP_KERNEL; INIT_LIST_HEAD(&sc->recv_io.free.list); spin_lock_init(&sc->recv_io.free.lock); atomic_set(&sc->recv_io.posted.count, 0); INIT_WORK(&sc->recv_io.posted.refill_work, __smbdirect_socket_disabled_work); disable_work_sync(&sc->recv_io.posted.refill_work); atomic_set(&sc->recv_io.credits.available, 0); atomic_set(&sc->recv_io.credits.count, 0); INIT_LIST_HEAD(&sc->recv_io.reassembly.list); spin_lock_init(&sc->recv_io.reassembly.lock); init_waitqueue_head(&sc->recv_io.reassembly.wait_queue); sc->rw_io.mem.gfp_mask = GFP_KERNEL; atomic_set(&sc->rw_io.credits.count, 0); init_waitqueue_head(&sc->rw_io.credits.wait_queue); spin_lock_init(&sc->mr_io.all.lock); INIT_LIST_HEAD(&sc->mr_io.all.list); atomic_set(&sc->mr_io.ready.count, 0); init_waitqueue_head(&sc->mr_io.ready.wait_queue); atomic_set(&sc->mr_io.used.count, 0); sc->logging.private_ptr = NULL; sc->logging.needed = __smbdirect_log_needed; sc->logging.vaprintf = __smbdirect_log_vaprintf; } #define __SMBDIRECT_CHECK_STATUS_FAILED(__sc, __expected_status, __error_cmd, __unexpected_cmd) ({ \ bool __failed = false; \ if (unlikely((__sc)->first_error)) { \ __failed = true; \ __error_cmd \ } else if (unlikely((__sc)->status != (__expected_status))) { \ __failed = true; \ __unexpected_cmd \ } \ __failed; \ }) #define __SMBDIRECT_CHECK_STATUS_WARN(__sc, __expected_status, __unexpected_cmd) \ __SMBDIRECT_CHECK_STATUS_FAILED(__sc, __expected_status, \ { \ const struct sockaddr_storage *__src = NULL; \ const struct sockaddr_storage *__dst = NULL; \ if ((__sc)->rdma.cm_id) { \ __src = &(__sc)->rdma.cm_id->route.addr.src_addr; \ __dst = &(__sc)->rdma.cm_id->route.addr.dst_addr; \ } \ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO, \ "expected[%s] != %s first_error=%1pe local=%pISpsfc remote=%pISpsfc\n", \ smbdirect_socket_status_string(__expected_status), \ smbdirect_socket_status_string((__sc)->status), \ SMBDIRECT_DEBUG_ERR_PTR((__sc)->first_error), \ __src, __dst); \ }, \ { \ const struct sockaddr_storage *__src = NULL; \ const struct sockaddr_storage *__dst = NULL; \ if ((__sc)->rdma.cm_id) { \ __src = &(__sc)->rdma.cm_id->route.addr.src_addr; \ __dst = &(__sc)->rdma.cm_id->route.addr.dst_addr; \ } \ smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR, \ "expected[%s] != %s first_error=%1pe local=%pISpsfc remote=%pISpsfc\n", \ smbdirect_socket_status_string(__expected_status), \ smbdirect_socket_status_string((__sc)->status), \ SMBDIRECT_DEBUG_ERR_PTR((__sc)->first_error), \ __src, __dst); \ WARN_ONCE(1, \ "expected[%s] != %s first_error=%1pe local=%pISpsfc remote=%pISpsfc\n", \ smbdirect_socket_status_string(__expected_status), \ smbdirect_socket_status_string((__sc)->status), \ SMBDIRECT_DEBUG_ERR_PTR((__sc)->first_error), \ __src, __dst); \ __unexpected_cmd \ }) #define SMBDIRECT_CHECK_STATUS_WARN(__sc, __expected_status) \ __SMBDIRECT_CHECK_STATUS_WARN(__sc, __expected_status, /* nothing */) #ifndef __SMBDIRECT_SOCKET_DISCONNECT #define __SMBDIRECT_SOCKET_DISCONNECT(__sc) \ smbdirect_socket_schedule_cleanup(__sc, -ECONNABORTED) #endif /* ! __SMBDIRECT_SOCKET_DISCONNECT */ #define SMBDIRECT_CHECK_STATUS_DISCONNECT(__sc, __expected_status) \ __SMBDIRECT_CHECK_STATUS_WARN(__sc, __expected_status, \ __SMBDIRECT_SOCKET_DISCONNECT(__sc);) struct smbdirect_send_io { struct smbdirect_socket *socket; struct ib_cqe cqe; /* * The SGE entries for this work request * * The first points to the packet header */ #define SMBDIRECT_SEND_IO_MAX_SGE 6 size_t num_sge; struct ib_sge sge[SMBDIRECT_SEND_IO_MAX_SGE]; /* * Link to the list of sibling smbdirect_send_io * messages. */ struct list_head sibling_list; struct ib_send_wr wr; /* SMBD packet header follows this structure */ u8 packet[]; }; struct smbdirect_send_batch { /* * List of smbdirect_send_io messages */ struct list_head msg_list; /* * Number of list entries */ size_t wr_cnt; /* * Possible remote key invalidation state */ bool need_invalidate_rkey; u32 remote_key; int credit; }; struct smbdirect_recv_io { struct smbdirect_socket *socket; struct ib_cqe cqe; /* * For now we only use a single SGE * as we have just one large buffer * per posted recv. */ #define SMBDIRECT_RECV_IO_MAX_SGE 1 struct ib_sge sge; /* Link to free or reassembly list */ struct list_head list; /* Indicate if this is the 1st packet of a payload */ bool first_segment; /* SMBD packet header and payload follows this structure */ u8 packet[]; }; enum smbdirect_mr_state { SMBDIRECT_MR_READY, SMBDIRECT_MR_REGISTERED, SMBDIRECT_MR_INVALIDATED, SMBDIRECT_MR_ERROR, SMBDIRECT_MR_DISABLED }; struct smbdirect_mr_io { struct smbdirect_socket *socket; struct ib_cqe cqe; /* * We can have up to two references: * 1. by the connection * 2. by the registration */ struct kref kref; struct mutex mutex; struct list_head list; enum smbdirect_mr_state state; struct ib_mr *mr; struct sg_table sgt; enum dma_data_direction dir; union { struct ib_reg_wr wr; struct ib_send_wr inv_wr; }; bool need_invalidate; struct completion invalidate_done; }; struct smbdirect_rw_io { struct smbdirect_socket *socket; struct ib_cqe cqe; struct list_head list; int error; struct completion *completion; struct rdma_rw_ctx rdma_ctx; struct sg_table sgt; struct scatterlist sg_list[]; }; static inline size_t smbdirect_get_buf_page_count(const void *buf, size_t size) { return DIV_ROUND_UP((uintptr_t)buf + size, PAGE_SIZE) - (uintptr_t)buf / PAGE_SIZE; } /* * Maximum number of retries on data transfer operations */ #define SMBDIRECT_RDMA_CM_RETRY 6 /* * No need to retry on Receiver Not Ready since SMB_DIRECT manages credits */ #define SMBDIRECT_RDMA_CM_RNR_RETRY 0 #endif /* __FS_SMB_COMMON_SMBDIRECT_SMBDIRECT_SOCKET_H__ */