diff --git a/include/net/udp.h b/include/net/udp.h index 496f89d45c8b..98755ebaf163 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -119,16 +119,9 @@ static inline void udp_lib_close(struct sock *sk, long timeout) } -struct udp_get_port_ops { - int (*saddr_cmp)(const struct sock *sk1, const struct sock *sk2); - int (*saddr_any)(const struct sock *sk); - unsigned int (*hash_port_and_rcv_saddr)(__u16 port, - const struct sock *sk); -}; - /* net/ipv4/udp.c */ extern int udp_get_port(struct sock *sk, unsigned short snum, - const struct udp_get_port_ops *ops); + int (*saddr_cmp)(const struct sock *, const struct sock *)); extern void udp_err(struct sk_buff *, u32); extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk, diff --git a/include/net/udplite.h b/include/net/udplite.h index 50b4b424d1ca..635b0eafca95 100644 --- a/include/net/udplite.h +++ b/include/net/udplite.h @@ -120,5 +120,5 @@ static inline __wsum udplite_csum_outgoing(struct sock *sk, struct sk_buff *skb) extern void udplite4_register(void); extern int udplite_get_port(struct sock *sk, unsigned short snum, - const struct udp_get_port_ops *ops); + int (*scmp)(const struct sock *, const struct sock *)); #endif /* _UDPLITE_H */ diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 5da703e699da..facb7e29304e 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -114,36 +114,14 @@ DEFINE_RWLOCK(udp_hash_lock); static int udp_port_rover; -/* - * Note about this hash function : - * Typical use is probably daddr = 0, only dport is going to vary hash - */ -static inline unsigned int udp_hash_port(__u16 port) -{ - return port; -} - -static inline int __udp_lib_port_inuse(unsigned int hash, int port, - const struct sock *this_sk, - struct hlist_head udptable[], - const struct udp_get_port_ops *ops) +static inline int __udp_lib_lport_inuse(__u16 num, struct hlist_head udptable[]) { struct sock *sk; struct hlist_node *node; - struct inet_sock *inet; - sk_for_each(sk, node, &udptable[hash & (UDP_HTABLE_SIZE - 1)]) { - if (sk->sk_hash != hash) - continue; - inet = inet_sk(sk); - if (inet->num != port) - continue; - if (this_sk) { - if (ops->saddr_cmp(sk, this_sk)) - return 1; - } else if (ops->saddr_any(sk)) + sk_for_each(sk, node, &udptable[num & (UDP_HTABLE_SIZE - 1)]) + if (sk->sk_hash == num) return 1; - } return 0; } @@ -154,16 +132,16 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port, * @snum: port number to look up * @udptable: hash list table, must be of UDP_HTABLE_SIZE * @port_rover: pointer to record of last unallocated port - * @ops: AF-dependent address operations + * @saddr_comp: AF-dependent comparison of bound local IP addresses */ int __udp_lib_get_port(struct sock *sk, unsigned short snum, struct hlist_head udptable[], int *port_rover, - const struct udp_get_port_ops *ops) + int (*saddr_comp)(const struct sock *sk1, + const struct sock *sk2 ) ) { struct hlist_node *node; struct hlist_head *head; struct sock *sk2; - unsigned int hash; int error = 1; write_lock_bh(&udp_hash_lock); @@ -178,8 +156,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { int size; - hash = ops->hash_port_and_rcv_saddr(result, sk); - head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; + head = &udptable[result & (UDP_HTABLE_SIZE - 1)]; if (hlist_empty(head)) { if (result > sysctl_local_port_range[1]) result = sysctl_local_port_range[0] + @@ -204,16 +181,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, result = sysctl_local_port_range[0] + ((result - sysctl_local_port_range[0]) & (UDP_HTABLE_SIZE - 1)); - hash = udp_hash_port(result); - if (__udp_lib_port_inuse(hash, result, - NULL, udptable, ops)) - continue; - if (ops->saddr_any(sk)) - break; - - hash = ops->hash_port_and_rcv_saddr(result, sk); - if (! __udp_lib_port_inuse(hash, result, - sk, udptable, ops)) + if (! __udp_lib_lport_inuse(result, udptable)) break; } if (i >= (1 << 16) / UDP_HTABLE_SIZE) @@ -221,40 +189,21 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, gotit: *port_rover = snum = result; } else { - hash = udp_hash_port(snum); - head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; + head = &udptable[snum & (UDP_HTABLE_SIZE - 1)]; sk_for_each(sk2, node, head) - if (sk2->sk_hash == hash && - sk2 != sk && - inet_sk(sk2)->num == snum && - (!sk2->sk_reuse || !sk->sk_reuse) && - (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || - sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && - ops->saddr_cmp(sk, sk2)) + if (sk2->sk_hash == snum && + sk2 != sk && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if + || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + (*saddr_comp)(sk, sk2) ) goto fail; - - if (!ops->saddr_any(sk)) { - hash = ops->hash_port_and_rcv_saddr(snum, sk); - head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; - - sk_for_each(sk2, node, head) - if (sk2->sk_hash == hash && - sk2 != sk && - inet_sk(sk2)->num == snum && - (!sk2->sk_reuse || !sk->sk_reuse) && - (!sk2->sk_bound_dev_if || - !sk->sk_bound_dev_if || - sk2->sk_bound_dev_if == - sk->sk_bound_dev_if) && - ops->saddr_cmp(sk, sk2)) - goto fail; - } } inet_sk(sk)->num = snum; - sk->sk_hash = hash; + sk->sk_hash = snum; if (sk_unhashed(sk)) { - head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; + head = &udptable[snum & (UDP_HTABLE_SIZE - 1)]; sk_add_node(sk, head); sock_prot_inc_use(sk->sk_prot); } @@ -265,12 +214,12 @@ fail: } int udp_get_port(struct sock *sk, unsigned short snum, - const struct udp_get_port_ops *ops) + int (*scmp)(const struct sock *, const struct sock *)) { - return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops); + return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); } -static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) +int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) { struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); @@ -279,33 +228,9 @@ static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) inet1->rcv_saddr == inet2->rcv_saddr )); } -static int ipv4_rcv_saddr_any(const struct sock *sk) -{ - return !inet_sk(sk)->rcv_saddr; -} - -static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr) -{ - addr ^= addr >> 16; - addr ^= addr >> 8; - return port ^ addr; -} - -static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port, - const struct sock *sk) -{ - return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr); -} - -const struct udp_get_port_ops udp_ipv4_ops = { - .saddr_cmp = ipv4_rcv_saddr_equal, - .saddr_any = ipv4_rcv_saddr_any, - .hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr, -}; - static inline int udp_v4_get_port(struct sock *sk, unsigned short snum) { - return udp_get_port(sk, snum, &udp_ipv4_ops); + return udp_get_port(sk, snum, ipv4_rcv_saddr_equal); } /* UDP is nearly always wildcards out the wazoo, it makes no sense to try @@ -317,77 +242,63 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport, { struct sock *sk, *result = NULL; struct hlist_node *node; - unsigned int hash, hashwild; - int score, best = -1, hport = ntohs(dport); - - hash = ipv4_hash_port_and_addr(hport, daddr); - hashwild = udp_hash_port(hport); + unsigned short hnum = ntohs(dport); + int badness = -1; read_lock(&udp_hash_lock); - -lookup: - - sk_for_each(sk, node, &udptable[hash & (UDP_HTABLE_SIZE - 1)]) { + sk_for_each(sk, node, &udptable[hnum & (UDP_HTABLE_SIZE - 1)]) { struct inet_sock *inet = inet_sk(sk); - if (sk->sk_hash != hash || ipv6_only_sock(sk) || - inet->num != hport) - continue; - - score = (sk->sk_family == PF_INET ? 1 : 0); - if (inet->rcv_saddr) { - if (inet->rcv_saddr != daddr) - continue; - score+=2; - } - if (inet->daddr) { - if (inet->daddr != saddr) - continue; - score+=2; - } - if (inet->dport) { - if (inet->dport != sport) - continue; - score+=2; - } - if (sk->sk_bound_dev_if) { - if (sk->sk_bound_dev_if != dif) - continue; - score+=2; - } - if (score == 9) { - result = sk; - goto found; - } else if (score > best) { - result = sk; - best = score; + if (sk->sk_hash == hnum && !ipv6_only_sock(sk)) { + int score = (sk->sk_family == PF_INET ? 1 : 0); + if (inet->rcv_saddr) { + if (inet->rcv_saddr != daddr) + continue; + score+=2; + } + if (inet->daddr) { + if (inet->daddr != saddr) + continue; + score+=2; + } + if (inet->dport) { + if (inet->dport != sport) + continue; + score+=2; + } + if (sk->sk_bound_dev_if) { + if (sk->sk_bound_dev_if != dif) + continue; + score+=2; + } + if (score == 9) { + result = sk; + break; + } else if (score > badness) { + result = sk; + badness = score; + } } } - - if (hash != hashwild) { - hash = hashwild; - goto lookup; - } -found: if (result) sock_hold(result); read_unlock(&udp_hash_lock); return result; } -static inline struct sock *udp_v4_mcast_next(struct sock *sk, unsigned int hnum, - int hport, __be32 loc_addr, +static inline struct sock *udp_v4_mcast_next(struct sock *sk, + __be16 loc_port, __be32 loc_addr, __be16 rmt_port, __be32 rmt_addr, int dif) { struct hlist_node *node; struct sock *s = sk; + unsigned short hnum = ntohs(loc_port); sk_for_each_from(s, node) { struct inet_sock *inet = inet_sk(s); if (s->sk_hash != hnum || - inet->num != hport || (inet->daddr && inet->daddr != rmt_addr) || (inet->dport != rmt_port && inet->dport) || (inet->rcv_saddr && inet->rcv_saddr != loc_addr) || @@ -1221,45 +1132,29 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb, __be32 saddr, __be32 daddr, struct hlist_head udptable[]) { - struct sock *sk, *skw, *sknext; + struct sock *sk; int dif; - int hport = ntohs(uh->dest); - unsigned int hash = ipv4_hash_port_and_addr(hport, daddr); - unsigned int hashwild = udp_hash_port(hport); - - dif = skb->dev->ifindex; read_lock(&udp_hash_lock); - - sk = sk_head(&udptable[hash & (UDP_HTABLE_SIZE - 1)]); - skw = sk_head(&udptable[hashwild & (UDP_HTABLE_SIZE - 1)]); - - sk = udp_v4_mcast_next(sk, hash, hport, daddr, uh->source, saddr, dif); - if (!sk) { - hash = hashwild; - sk = udp_v4_mcast_next(skw, hash, hport, daddr, uh->source, - saddr, dif); - } + sk = sk_head(&udptable[ntohs(uh->dest) & (UDP_HTABLE_SIZE - 1)]); + dif = skb->dev->ifindex; + sk = udp_v4_mcast_next(sk, uh->dest, daddr, uh->source, saddr, dif); if (sk) { + struct sock *sknext = NULL; + do { struct sk_buff *skb1 = skb; - sknext = udp_v4_mcast_next(sk_next(sk), hash, hport, - daddr, uh->source, saddr, dif); - if (!sknext && hash != hashwild) { - hash = hashwild; - sknext = udp_v4_mcast_next(skw, hash, hport, - daddr, uh->source, saddr, dif); - } + + sknext = udp_v4_mcast_next(sk_next(sk), uh->dest, daddr, + uh->source, saddr, dif); if (sknext) skb1 = skb_clone(skb, GFP_ATOMIC); if (skb1) { int ret = udp_queue_rcv_skb(sk, skb1); if (ret > 0) - /* - * we should probably re-process - * instead of dropping packets here. - */ + /* we should probably re-process instead + * of dropping packets here. */ kfree_skb(skb1); } sk = sknext; @@ -1346,7 +1241,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[], return __udp4_lib_mcast_deliver(skb, uh, saddr, daddr, udptable); sk = __udp4_lib_lookup(saddr, uh->source, daddr, uh->dest, - skb->dev->ifindex, udptable); + skb->dev->ifindex, udptable ); if (sk != NULL) { int ret = udp_queue_rcv_skb(sk, skb); diff --git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h index 06d94195e644..820a477cfaa6 100644 --- a/net/ipv4/udp_impl.h +++ b/net/ipv4/udp_impl.h @@ -5,14 +5,14 @@ #include #include -extern const struct udp_get_port_ops udp_ipv4_ops; - extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int ); extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []); extern int __udp_lib_get_port(struct sock *sk, unsigned short snum, struct hlist_head udptable[], int *port_rover, - const struct udp_get_port_ops *ops); + int (*)(const struct sock*,const struct sock*)); +extern int ipv4_rcv_saddr_equal(const struct sock *, const struct sock *); + extern int udp_setsockopt(struct sock *sk, int level, int optname, char __user *optval, int optlen); diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c index 3653b32dce2d..f34fd686a8f1 100644 --- a/net/ipv4/udplite.c +++ b/net/ipv4/udplite.c @@ -19,15 +19,14 @@ struct hlist_head udplite_hash[UDP_HTABLE_SIZE]; static int udplite_port_rover; int udplite_get_port(struct sock *sk, unsigned short p, - const struct udp_get_port_ops *ops) + int (*c)(const struct sock *, const struct sock *)) { - return __udp_lib_get_port(sk, p, udplite_hash, - &udplite_port_rover, ops); + return __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c); } static int udplite_v4_get_port(struct sock *sk, unsigned short snum) { - return udplite_get_port(sk, snum, &udp_ipv4_ops); + return udplite_get_port(sk, snum, ipv4_rcv_saddr_equal); } static int udplite_rcv(struct sk_buff *skb) diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index d1fbddd172e7..4210951edb6e 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -52,28 +52,9 @@ DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly; -static int ipv6_rcv_saddr_any(const struct sock *sk) -{ - struct ipv6_pinfo *np = inet6_sk(sk); - - return ipv6_addr_any(&np->rcv_saddr); -} - -static unsigned int ipv6_hash_port_and_rcv_saddr(__u16 port, - const struct sock *sk) -{ - return port; -} - -const struct udp_get_port_ops udp_ipv6_ops = { - .saddr_cmp = ipv6_rcv_saddr_equal, - .saddr_any = ipv6_rcv_saddr_any, - .hash_port_and_rcv_saddr = ipv6_hash_port_and_rcv_saddr, -}; - static inline int udp_v6_get_port(struct sock *sk, unsigned short snum) { - return udp_get_port(sk, snum, &udp_ipv6_ops); + return udp_get_port(sk, snum, ipv6_rcv_saddr_equal); } static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport, diff --git a/net/ipv6/udp_impl.h b/net/ipv6/udp_impl.h index 36b0c11a28a3..6e252f318f7c 100644 --- a/net/ipv6/udp_impl.h +++ b/net/ipv6/udp_impl.h @@ -6,8 +6,6 @@ #include #include -extern const struct udp_get_port_ops udp_ipv6_ops; - extern int __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int ); extern void __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *, int , int , int , __be32 , struct hlist_head []); diff --git a/net/ipv6/udplite.c b/net/ipv6/udplite.c index c40a51362f89..f54016a55004 100644 --- a/net/ipv6/udplite.c +++ b/net/ipv6/udplite.c @@ -37,7 +37,7 @@ static struct inet6_protocol udplitev6_protocol = { static int udplite_v6_get_port(struct sock *sk, unsigned short snum) { - return udplite_get_port(sk, snum, &udp_ipv6_ops); + return udplite_get_port(sk, snum, ipv6_rcv_saddr_equal); } struct proto udplitev6_prot = {