From 4dbfa766440c6dfe3d10f077cde966a7d11b58f1 Mon Sep 17 00:00:00 2001 From: Vlad Buslov Date: Mon, 11 Feb 2019 10:55:39 +0200 Subject: [PATCH] net: sched: introduce reference counting for tcf_proto In order to remove dependency on rtnl lock and allow concurrent tcf_proto modification, extend tcf_proto with reference counter. Implement helper get/put functions for tcf proto and use them to modify cls API to always take reference to tcf_proto while using it. Only release reference to parent chain after releasing last reference to tp. Signed-off-by: Vlad Buslov Acked-by: Jiri Pirko Signed-off-by: David S. Miller --- include/net/sch_generic.h | 1 + net/sched/cls_api.c | 53 +++++++++++++++++++++++++++++++-------- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/include/net/sch_generic.h b/include/net/sch_generic.h index 85993d7efee6..4372c08fc4d9 100644 --- a/include/net/sch_generic.h +++ b/include/net/sch_generic.h @@ -322,6 +322,7 @@ struct tcf_proto { void *data; const struct tcf_proto_ops *ops; struct tcf_chain *chain; + refcount_t refcnt; struct rcu_head rcu; }; diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index 3fce30ae9a9b..37c05b96898f 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -180,6 +180,7 @@ static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol, tp->protocol = protocol; tp->prio = prio; tp->chain = chain; + refcount_set(&tp->refcnt, 1); err = tp->ops->init(tp); if (err) { @@ -193,14 +194,29 @@ errout: return ERR_PTR(err); } +static void tcf_proto_get(struct tcf_proto *tp) +{ + refcount_inc(&tp->refcnt); +} + +static void tcf_chain_put(struct tcf_chain *chain); + static void tcf_proto_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) { tp->ops->destroy(tp, extack); + tcf_chain_put(tp->chain); module_put(tp->ops->owner); kfree_rcu(tp, rcu); } +static void tcf_proto_put(struct tcf_proto *tp, + struct netlink_ext_ack *extack) +{ + if (refcount_dec_and_test(&tp->refcnt)) + tcf_proto_destroy(tp, extack); +} + #define ASSERT_BLOCK_LOCKED(block) \ lockdep_assert_held(&(block)->lock) @@ -445,18 +461,18 @@ static void tcf_chain_put_explicitly_created(struct tcf_chain *chain) static void tcf_chain_flush(struct tcf_chain *chain) { - struct tcf_proto *tp; + struct tcf_proto *tp, *tp_next; mutex_lock(&chain->filter_chain_lock); tp = tcf_chain_dereference(chain->filter_chain, chain); + RCU_INIT_POINTER(chain->filter_chain, NULL); tcf_chain0_head_change(chain, NULL); mutex_unlock(&chain->filter_chain_lock); while (tp) { - RCU_INIT_POINTER(chain->filter_chain, tp->next); - tcf_proto_destroy(tp, NULL); - tp = rtnl_dereference(chain->filter_chain); - tcf_chain_put(chain); + tp_next = rcu_dereference_protected(tp->next, 1); + tcf_proto_put(tp, NULL); + tp = tp_next; } } @@ -1500,9 +1516,9 @@ static void tcf_chain_tp_insert(struct tcf_chain *chain, { if (*chain_info->pprev == chain->filter_chain) tcf_chain0_head_change(chain, tp); + tcf_proto_get(tp); RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain, chain_info)); rcu_assign_pointer(*chain_info->pprev, tp); - tcf_chain_hold(chain); } static void tcf_chain_tp_remove(struct tcf_chain *chain, @@ -1514,7 +1530,6 @@ static void tcf_chain_tp_remove(struct tcf_chain *chain, if (tp == chain->filter_chain) tcf_chain0_head_change(chain, next); RCU_INIT_POINTER(*chain_info->pprev, next); - tcf_chain_put(chain); } static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, @@ -1541,7 +1556,12 @@ static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, } } chain_info->pprev = pprev; - chain_info->next = tp ? tp->next : NULL; + if (tp) { + chain_info->next = tp->next; + tcf_proto_get(tp); + } else { + chain_info->next = NULL; + } return tp; } @@ -1699,6 +1719,7 @@ replay: prio = TC_H_MAJ(t->tcm_info); prio_allocate = false; parent = t->tcm_parent; + tp = NULL; cl = 0; if (prio == 0) { @@ -1816,6 +1837,12 @@ replay: errout: if (chain) tcf_chain_put(chain); + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, NULL); + if (!tp_created) + tcf_chain_put(chain); + } tcf_block_release(q, block); if (err == -EAGAIN) /* Replay the request. */ @@ -1946,8 +1973,11 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, } errout: - if (chain) + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, NULL); tcf_chain_put(chain); + } tcf_block_release(q, block); return err; @@ -2038,8 +2068,11 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, } errout: - if (chain) + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, NULL); tcf_chain_put(chain); + } tcf_block_release(q, block); return err; } -- 2.30.2