diff --git a/include/net/pkt_cls.h b/include/net/pkt_cls.h index cb8be396a11f..38bee7dd21d1 100644 --- a/include/net/pkt_cls.h +++ b/include/net/pkt_cls.h @@ -44,6 +44,8 @@ bool tcf_queue_work(struct rcu_work *rwork, work_func_t func); struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block, u32 chain_index); void tcf_chain_put_by_act(struct tcf_chain *chain); +struct tcf_chain *tcf_get_next_chain(struct tcf_block *block, + struct tcf_chain *chain); void tcf_block_netif_keep_dst(struct tcf_block *block); int tcf_block_get(struct tcf_block **p_block, struct tcf_proto __rcu **p_filter_chain, struct Qdisc *q, diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index 869ae44d7631..8e2ac785f6fd 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -883,28 +883,62 @@ static struct tcf_block *tcf_block_refcnt_get(struct net *net, u32 block_index) return block; } +static struct tcf_chain * +__tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) +{ + mutex_lock(&block->lock); + if (chain) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + else + chain = list_first_entry_or_null(&block->chain_list, + struct tcf_chain, list); + + /* skip all action-only chains */ + while (chain && tcf_chain_held_by_acts_only(chain)) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + + if (chain) + tcf_chain_hold(chain); + mutex_unlock(&block->lock); + + return chain; +} + +/* Function to be used by all clients that want to iterate over all chains on + * block. It properly obtains block->lock and takes reference to chain before + * returning it. Users of this function must be tolerant to concurrent chain + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_chain * +tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) +{ + struct tcf_chain *chain_next = __tcf_get_next_chain(block, chain); + + if (chain) + tcf_chain_put(chain); + + return chain_next; +} +EXPORT_SYMBOL(tcf_get_next_chain); + static void tcf_block_flush_all_chains(struct tcf_block *block) { struct tcf_chain *chain; - /* Hold a refcnt for all chains, so that they don't disappear - * while we are iterating. + /* Last reference to block. At this point chains cannot be added or + * removed concurrently. */ - list_for_each_entry(chain, &block->chain_list, list) - tcf_chain_hold(chain); - - list_for_each_entry(chain, &block->chain_list, list) - tcf_chain_flush(chain); -} - -static void tcf_block_put_all_chains(struct tcf_block *block) -{ - struct tcf_chain *chain, *tmp; - - /* At this point, all the chains should have refcnt >= 1. */ - list_for_each_entry_safe(chain, tmp, &block->chain_list, list) { + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { tcf_chain_put_explicitly_created(chain); - tcf_chain_put(chain); + tcf_chain_flush(chain); } } @@ -923,8 +957,6 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, mutex_unlock(&block->lock); if (tcf_block_shared(block)) tcf_block_remove(block, block->net); - if (!free_block) - tcf_block_flush_all_chains(block); if (q) tcf_block_offload_unbind(block, q, ei); @@ -932,7 +964,7 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, if (free_block) tcf_block_destroy(block); else - tcf_block_put_all_chains(block); + tcf_block_flush_all_chains(block); } else if (q) { tcf_block_offload_unbind(block, q, ei); } @@ -1266,11 +1298,15 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, void *cb_priv, bool add, bool offload_in_use, struct netlink_ext_ack *extack) { - struct tcf_chain *chain; + struct tcf_chain *chain, *chain_prev; struct tcf_proto *tp; int err; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { for (tp = rtnl_dereference(chain->filter_chain); tp; tp = rtnl_dereference(tp->next)) { if (tp->ops->reoffload) { @@ -1289,6 +1325,7 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, return 0; err_playback_remove: + tcf_chain_put(chain); tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use, extack); return err; @@ -2023,11 +2060,11 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, /* called with RTNL */ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) { + struct tcf_chain *chain, *chain_prev; struct net *net = sock_net(skb->sk); struct nlattr *tca[TCA_MAX + 1]; struct Qdisc *q = NULL; struct tcf_block *block; - struct tcf_chain *chain; struct tcmsg *tcm = nlmsg_data(cb->nlh); long index_start; long index; @@ -2091,12 +2128,17 @@ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) index_start = cb->args[0]; index = 0; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { if (tca[TCA_CHAIN] && nla_get_u32(tca[TCA_CHAIN]) != chain->index) continue; if (!tcf_chain_dump(chain, q, parent, skb, cb, index_start, &index)) { + tcf_chain_put(chain); err = -EMSGSIZE; break; } @@ -2364,11 +2406,11 @@ errout_block_locked: /* called with RTNL */ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) { + struct tcf_chain *chain, *chain_prev; struct net *net = sock_net(skb->sk); struct nlattr *tca[TCA_MAX + 1]; struct Qdisc *q = NULL; struct tcf_block *block; - struct tcf_chain *chain; struct tcmsg *tcm = nlmsg_data(cb->nlh); long index_start; long index; @@ -2432,7 +2474,11 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) index_start = cb->args[0]; index = 0; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { if ((tca[TCA_CHAIN] && nla_get_u32(tca[TCA_CHAIN]) != chain->index)) continue; @@ -2440,14 +2486,14 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) index++; continue; } - if (tcf_chain_held_by_acts_only(chain)) - continue; err = tc_chain_fill_node(chain, net, skb, block, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, RTM_NEWCHAIN); - if (err <= 0) + if (err <= 0) { + tcf_chain_put(chain); break; + } index++; } diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c index 03e26e8d0ec9..80058abc729f 100644 --- a/net/sched/sch_api.c +++ b/net/sched/sch_api.c @@ -1909,7 +1909,9 @@ static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, block = cops->tcf_block(q, cl, NULL); if (!block) return; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { struct tcf_proto *tp; for (tp = rtnl_dereference(chain->filter_chain);