diff --git a/include/net/ip.h b/include/net/ip.h index c3fffaa92d6e..acec504c469a 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -76,6 +76,7 @@ struct ipcm_cookie { __be32 addr; int oif; struct ip_options_rcu *opt; + __u8 protocol; __u8 ttl; __s16 tos; char priority; @@ -96,6 +97,7 @@ static inline void ipcm_init_sk(struct ipcm_cookie *ipcm, ipcm->sockc.tsflags = inet->sk.sk_tsflags; ipcm->oif = READ_ONCE(inet->sk.sk_bound_dev_if); ipcm->addr = inet->inet_saddr; + ipcm->protocol = inet->inet_num; } #define IPCB(skb) ((struct inet_skb_parm*)((skb)->cb)) diff --git a/include/uapi/linux/in.h b/include/uapi/linux/in.h index 4b7f2df66b99..e682ab628dfa 100644 --- a/include/uapi/linux/in.h +++ b/include/uapi/linux/in.h @@ -163,6 +163,7 @@ struct in_addr { #define IP_MULTICAST_ALL 49 #define IP_UNICAST_IF 50 #define IP_LOCAL_PORT_RANGE 51 +#define IP_PROTOCOL 52 #define MCAST_EXCLUDE 0 #define MCAST_INCLUDE 1 diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index b511ff0adc0a..8e97d8d4cc9d 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -317,7 +317,14 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc, ipc->tos = val; ipc->priority = rt_tos2priority(ipc->tos); break; - + case IP_PROTOCOL: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + return -EINVAL; + val = *(int *)CMSG_DATA(cmsg); + if (val < 1 || val > 255) + return -EINVAL; + ipc->protocol = val; + break; default: return -EINVAL; } @@ -1761,6 +1768,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname, case IP_LOCAL_PORT_RANGE: val = inet->local_port_range.hi << 16 | inet->local_port_range.lo; break; + case IP_PROTOCOL: + val = inet_sk(sk)->inet_num; + break; default: sockopt_release_sock(sk); return -ENOPROTOOPT; diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index ff712bf2a98d..eadf1c9ef7e4 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -532,6 +532,9 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) } ipcm_init_sk(&ipc, inet); + /* Keep backward compat */ + if (hdrincl) + ipc.protocol = IPPROTO_RAW; if (msg->msg_controllen) { err = ip_cmsg_send(sk, msg, &ipc, false); @@ -599,7 +602,7 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, RT_SCOPE_UNIVERSE, - hdrincl ? IPPROTO_RAW : sk->sk_protocol, + hdrincl ? ipc.protocol : sk->sk_protocol, inet_sk_flowi_flags(sk) | (hdrincl ? FLOWI_FLAG_KNOWN_NH : 0), daddr, saddr, 0, 0, sk->sk_uid); diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c index 7d0adb612bdd..44ee7a2e72ac 100644 --- a/net/ipv6/raw.c +++ b/net/ipv6/raw.c @@ -793,7 +793,8 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) if (!proto) proto = inet->inet_num; - else if (proto != inet->inet_num) + else if (proto != inet->inet_num && + inet->inet_num != IPPROTO_RAW) return -EINVAL; if (proto > 255)