diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 4d3d75d63066..2be51b7a5800 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -347,11 +347,23 @@ static inline void sk_psock_update_proto(struct sock *sk, struct sk_psock *psock, struct proto *ops) { - psock->saved_unhash = sk->sk_prot->unhash; - psock->saved_close = sk->sk_prot->close; - psock->saved_write_space = sk->sk_write_space; + /* Initialize saved callbacks and original proto only once, since this + * function may be called multiple times for a psock, e.g. when + * psock->progs.msg_parser is updated. + * + * Since we've not installed the new proto, psock is not yet in use and + * we can initialize it without synchronization. + */ + if (!psock->sk_proto) { + struct proto *orig = READ_ONCE(sk->sk_prot); + + psock->saved_unhash = orig->unhash; + psock->saved_close = orig->close; + psock->saved_write_space = sk->sk_write_space; + + psock->sk_proto = orig; + } - psock->sk_proto = sk->sk_prot; /* Pairs with lockless read in sk_clone_lock() */ WRITE_ONCE(sk->sk_prot, ops); } diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 7d6e1b75d4d4..3327afa05c3d 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -637,20 +637,6 @@ static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock) sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]); } -static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock) -{ - int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; - int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; - - /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed - * or added requiring sk_prot hook updates. We keep original saved - * hooks in this case. - * - * Pairs with lockless read in sk_clone_lock(). - */ - WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]); -} - static int tcp_bpf_assert_proto_ops(struct proto *ops) { /* In order to avoid retpoline, we make assumptions when we call @@ -670,7 +656,7 @@ void tcp_bpf_reinit(struct sock *sk) rcu_read_lock(); psock = sk_psock(sk); - tcp_bpf_reinit_sk_prot(sk, psock); + tcp_bpf_update_sk_prot(sk, psock); rcu_read_unlock(); }