unix_bpf.c (4721B)
1// SPDX-License-Identifier: GPL-2.0 2/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */ 3 4#include <linux/skmsg.h> 5#include <linux/bpf.h> 6#include <net/sock.h> 7#include <net/af_unix.h> 8 9#define unix_sk_has_data(__sk, __psock) \ 10 ({ !skb_queue_empty(&__sk->sk_receive_queue) || \ 11 !skb_queue_empty(&__psock->ingress_skb) || \ 12 !list_empty(&__psock->ingress_msg); \ 13 }) 14 15static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock, 16 long timeo) 17{ 18 DEFINE_WAIT_FUNC(wait, woken_wake_function); 19 struct unix_sock *u = unix_sk(sk); 20 int ret = 0; 21 22 if (sk->sk_shutdown & RCV_SHUTDOWN) 23 return 1; 24 25 if (!timeo) 26 return ret; 27 28 add_wait_queue(sk_sleep(sk), &wait); 29 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 30 if (!unix_sk_has_data(sk, psock)) { 31 mutex_unlock(&u->iolock); 32 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 33 mutex_lock(&u->iolock); 34 ret = unix_sk_has_data(sk, psock); 35 } 36 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 37 remove_wait_queue(sk_sleep(sk), &wait); 38 return ret; 39} 40 41static int __unix_recvmsg(struct sock *sk, struct msghdr *msg, 42 size_t len, int flags) 43{ 44 if (sk->sk_type == SOCK_DGRAM) 45 return __unix_dgram_recvmsg(sk, msg, len, flags); 46 else 47 return __unix_stream_recvmsg(sk, msg, len, flags); 48} 49 50static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 51 size_t len, int flags, int *addr_len) 52{ 53 struct unix_sock *u = unix_sk(sk); 54 struct sk_psock *psock; 55 int copied; 56 57 psock = sk_psock_get(sk); 58 if (unlikely(!psock)) 59 return __unix_recvmsg(sk, msg, len, flags); 60 61 mutex_lock(&u->iolock); 62 if (!skb_queue_empty(&sk->sk_receive_queue) && 63 sk_psock_queue_empty(psock)) { 64 mutex_unlock(&u->iolock); 65 sk_psock_put(sk, psock); 66 return __unix_recvmsg(sk, msg, len, flags); 67 } 68 69msg_bytes_ready: 70 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 71 if (!copied) { 72 long timeo; 73 int data; 74 75 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 76 data = unix_msg_wait_data(sk, psock, timeo); 77 if (data) { 78 if (!sk_psock_queue_empty(psock)) 79 goto msg_bytes_ready; 80 mutex_unlock(&u->iolock); 81 sk_psock_put(sk, psock); 82 return __unix_recvmsg(sk, msg, len, flags); 83 } 84 copied = -EAGAIN; 85 } 86 mutex_unlock(&u->iolock); 87 sk_psock_put(sk, psock); 88 return copied; 89} 90 91static struct proto *unix_dgram_prot_saved __read_mostly; 92static DEFINE_SPINLOCK(unix_dgram_prot_lock); 93static struct proto unix_dgram_bpf_prot; 94 95static struct proto *unix_stream_prot_saved __read_mostly; 96static DEFINE_SPINLOCK(unix_stream_prot_lock); 97static struct proto unix_stream_bpf_prot; 98 99static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 100{ 101 *prot = *base; 102 prot->close = sock_map_close; 103 prot->recvmsg = unix_bpf_recvmsg; 104 prot->sock_is_readable = sk_msg_is_readable; 105} 106 107static void unix_stream_bpf_rebuild_protos(struct proto *prot, 108 const struct proto *base) 109{ 110 *prot = *base; 111 prot->close = sock_map_close; 112 prot->recvmsg = unix_bpf_recvmsg; 113 prot->sock_is_readable = sk_msg_is_readable; 114 prot->unhash = sock_map_unhash; 115} 116 117static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops) 118{ 119 if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) { 120 spin_lock_bh(&unix_dgram_prot_lock); 121 if (likely(ops != unix_dgram_prot_saved)) { 122 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops); 123 smp_store_release(&unix_dgram_prot_saved, ops); 124 } 125 spin_unlock_bh(&unix_dgram_prot_lock); 126 } 127} 128 129static void unix_stream_bpf_check_needs_rebuild(struct proto *ops) 130{ 131 if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) { 132 spin_lock_bh(&unix_stream_prot_lock); 133 if (likely(ops != unix_stream_prot_saved)) { 134 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops); 135 smp_store_release(&unix_stream_prot_saved, ops); 136 } 137 spin_unlock_bh(&unix_stream_prot_lock); 138 } 139} 140 141int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 142{ 143 if (sk->sk_type != SOCK_DGRAM) 144 return -EOPNOTSUPP; 145 146 if (restore) { 147 sk->sk_write_space = psock->saved_write_space; 148 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 149 return 0; 150 } 151 152 unix_dgram_bpf_check_needs_rebuild(psock->sk_proto); 153 WRITE_ONCE(sk->sk_prot, &unix_dgram_bpf_prot); 154 return 0; 155} 156 157int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 158{ 159 if (restore) { 160 sk->sk_write_space = psock->saved_write_space; 161 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 162 return 0; 163 } 164 165 unix_stream_bpf_check_needs_rebuild(psock->sk_proto); 166 WRITE_ONCE(sk->sk_prot, &unix_stream_bpf_prot); 167 return 0; 168} 169 170void __init unix_bpf_build_proto(void) 171{ 172 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto); 173 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto); 174 175}