diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c index 7fde1f265c90..c21777565c58 100644 --- a/net/ipv6/ip6_output.c +++ b/net/ipv6/ip6_output.c @@ -886,6 +886,38 @@ static int ip6_dst_lookup_tail(struct sock *sk, #endif int err; + /* The correct way to handle this would be to do + * ip6_route_get_saddr, and then ip6_route_output; however, + * the route-specific preferred source forces the + * ip6_route_output call _before_ ip6_route_get_saddr. + * + * In source specific routing (no src=any default route), + * ip6_route_output will fail given src=any saddr, though, so + * that's why we try it again later. + */ + if (ipv6_addr_any(&fl6->saddr) && (!*dst || !(*dst)->error)) { + struct rt6_info *rt; + bool had_dst = *dst != NULL; + + if (!had_dst) + *dst = ip6_route_output(net, sk, fl6); + rt = (*dst)->error ? NULL : (struct rt6_info *)*dst; + err = ip6_route_get_saddr(net, rt, &fl6->daddr, + sk ? inet6_sk(sk)->srcprefs : 0, + &fl6->saddr); + if (err) + goto out_err_release; + + /* If we had an erroneous initial result, pretend it + * never existed and let the SA-enabled version take + * over. + */ + if (!had_dst && (*dst)->error) { + dst_release(*dst); + *dst = NULL; + } + } + if (!*dst) *dst = ip6_route_output(net, sk, fl6); @@ -893,15 +925,6 @@ static int ip6_dst_lookup_tail(struct sock *sk, if (err) goto out_err_release; - if (ipv6_addr_any(&fl6->saddr)) { - struct rt6_info *rt = (struct rt6_info *) *dst; - err = ip6_route_get_saddr(net, rt, &fl6->daddr, - sk ? inet6_sk(sk)->srcprefs : 0, - &fl6->saddr); - if (err) - goto out_err_release; - } - #ifdef CONFIG_IPV6_OPTIMISTIC_DAD /* * Here if the dst entry we've looked up diff --git a/net/ipv6/route.c b/net/ipv6/route.c index 5c48293ff062..d3588885f097 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -2245,9 +2245,10 @@ int ip6_route_get_saddr(struct net *net, unsigned int prefs, struct in6_addr *saddr) { - struct inet6_dev *idev = ip6_dst_idev((struct dst_entry *)rt); + struct inet6_dev *idev = + rt ? ip6_dst_idev((struct dst_entry *)rt) : NULL; int err = 0; - if (rt->rt6i_prefsrc.plen) + if (rt && rt->rt6i_prefsrc.plen) *saddr = rt->rt6i_prefsrc.addr; else err = ipv6_dev_get_saddr(net, idev ? idev->dev : NULL,