rhashtable: move dereference inside rht_ptr()
authorNeilBrown <neilb@suse.com>
Fri, 12 Apr 2019 01:52:08 +0000 (11:52 +1000)
committerDavid S. Miller <davem@davemloft.net>
Sat, 13 Apr 2019 00:34:45 +0000 (17:34 -0700)
Rather than dereferencing a pointer to a bucket and then passing the
result to rht_ptr(), we now pass in the pointer and do the dereference
in rht_ptr().

This requires that we pass in the tbl and hash as well to support RCU
checks, and means that the various rht_for_each functions can expect a
pointer that can be dereferenced without further care.

There are two places where we dereference a bucket pointer
where there is no testable protection - in each case we know
that we much have exclusive access without having taken a lock.
The previous code used rht_dereference() to pretend that holding
the mutex provided protects, but holding the mutex never provides
protection for accessing buckets.

So instead introduce rht_ptr_exclusive() that can be used when
there is known to be exclusive access without holding any locks.

Signed-off-by: NeilBrown <neilb@suse.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/rhashtable.h
lib/rhashtable.c
lib/test_rhashtable.c

index c504cd820736f11d4bae859fbe6a06346224f0e9..b54e6436547e281c65d0c5bd0e84df08cd6dfaad 100644 (file)
@@ -344,12 +344,28 @@ static inline void rht_unlock(struct bucket_table *tbl,
 }
 
 /*
- * If 'p' is a bucket head and might be locked:
- *   rht_ptr() returns the address without the lock bit.
- *   rht_ptr_locked() returns the address WITH the lock bit.
+ * Where 'bkt' is a bucket and might be locked:
+ *   rht_ptr() dereferences that pointer and clears the lock bit.
+ *   rht_ptr_exclusive() dereferences in a context where exclusive
+ *            access is guaranteed, such as when destroying the table.
  */
-static inline struct rhash_head __rcu *rht_ptr(const struct rhash_lock_head *p)
+static inline struct rhash_head *rht_ptr(
+       struct rhash_lock_head __rcu * const *bkt,
+       struct bucket_table *tbl,
+       unsigned int hash)
 {
+       const struct rhash_lock_head *p =
+               rht_dereference_bucket_rcu(*bkt, tbl, hash);
+
+       return (void *)(((unsigned long)p) & ~BIT(1));
+}
+
+static inline struct rhash_head *rht_ptr_exclusive(
+       struct rhash_lock_head __rcu * const *bkt)
+{
+       const struct rhash_lock_head *p =
+               rcu_dereference_protected(*bkt, 1);
+
        return (void *)(((unsigned long)p) & ~BIT(1));
 }
 
@@ -380,8 +396,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * @hash:      the hash value / bucket index
  */
 #define rht_for_each_from(pos, head, tbl, hash) \
-       for (pos = rht_dereference_bucket(head, tbl, hash); \
-            !rht_is_a_nulls(pos); \
+       for (pos = head;                        \
+            !rht_is_a_nulls(pos);              \
             pos = rht_dereference_bucket((pos)->next, tbl, hash))
 
 /**
@@ -391,7 +407,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * @hash:      the hash value / bucket index
  */
 #define rht_for_each(pos, tbl, hash) \
-       rht_for_each_from(pos, rht_ptr(*rht_bucket(tbl, hash)), tbl, hash)
+       rht_for_each_from(pos, rht_ptr(rht_bucket(tbl, hash), tbl, hash),  \
+                         tbl, hash)
 
 /**
  * rht_for_each_entry_from - iterate over hash chain from given head
@@ -403,7 +420,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * @member:    name of the &struct rhash_head within the hashable struct.
  */
 #define rht_for_each_entry_from(tpos, pos, head, tbl, hash, member)    \
-       for (pos = rht_dereference_bucket(head, tbl, hash);             \
+       for (pos = head;                                                \
             (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member);    \
             pos = rht_dereference_bucket((pos)->next, tbl, hash))
 
@@ -416,8 +433,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * @member:    name of the &struct rhash_head within the hashable struct.
  */
 #define rht_for_each_entry(tpos, pos, tbl, hash, member)               \
-       rht_for_each_entry_from(tpos, pos, rht_ptr(*rht_bucket(tbl, hash)), \
-                                   tbl, hash, member)
+       rht_for_each_entry_from(tpos, pos,                              \
+                               rht_ptr(rht_bucket(tbl, hash), tbl, hash), \
+                               tbl, hash, member)
 
 /**
  * rht_for_each_entry_safe - safely iterate over hash chain of given type
@@ -432,8 +450,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * remove the loop cursor from the list.
  */
 #define rht_for_each_entry_safe(tpos, pos, next, tbl, hash, member)          \
-       for (pos = rht_dereference_bucket(rht_ptr(*rht_bucket(tbl, hash)),    \
-                                         tbl, hash),                         \
+       for (pos = rht_ptr(rht_bucket(tbl, hash), tbl, hash),                 \
             next = !rht_is_a_nulls(pos) ?                                    \
                       rht_dereference_bucket(pos->next, tbl, hash) : NULL;   \
             (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member);          \
@@ -454,7 +471,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  */
 #define rht_for_each_rcu_from(pos, head, tbl, hash)                    \
        for (({barrier(); }),                                           \
-            pos = rht_dereference_bucket_rcu(head, tbl, hash);         \
+            pos = head;                                                \
             !rht_is_a_nulls(pos);                                      \
             pos = rcu_dereference_raw(pos->next))
 
@@ -469,10 +486,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  * traversal is guarded by rcu_read_lock().
  */
 #define rht_for_each_rcu(pos, tbl, hash)                       \
-       for (({barrier(); }),                                           \
-            pos = rht_ptr(rht_dereference_bucket_rcu(                  \
-                                  *rht_bucket(tbl, hash), tbl, hash)); \
-            !rht_is_a_nulls(pos);                                      \
+       for (({barrier(); }),                                   \
+            pos = rht_ptr(rht_bucket(tbl, hash), tbl, hash);   \
+            !rht_is_a_nulls(pos);                              \
             pos = rcu_dereference_raw(pos->next))
 
 /**
@@ -490,7 +506,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  */
 #define rht_for_each_entry_rcu_from(tpos, pos, head, tbl, hash, member) \
        for (({barrier(); }),                                               \
-            pos = rht_dereference_bucket_rcu(head, tbl, hash);             \
+            pos = head;                                                    \
             (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member);        \
             pos = rht_dereference_bucket_rcu(pos->next, tbl, hash))
 
@@ -508,8 +524,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
  */
 #define rht_for_each_entry_rcu(tpos, pos, tbl, hash, member)              \
        rht_for_each_entry_rcu_from(tpos, pos,                             \
-                                       rht_ptr(*rht_bucket(tbl, hash)),   \
-                                       tbl, hash, member)
+                                   rht_ptr(rht_bucket(tbl, hash),         \
+                                           tbl, hash),                    \
+                                   tbl, hash, member)
 
 /**
  * rhl_for_each_rcu - iterate over rcu hash table list
@@ -556,7 +573,6 @@ static inline struct rhash_head *__rhashtable_lookup(
        };
        struct rhash_lock_head __rcu * const *bkt;
        struct bucket_table *tbl;
-       struct rhash_head __rcu *head;
        struct rhash_head *he;
        unsigned int hash;
 
@@ -565,8 +581,7 @@ restart:
        hash = rht_key_hashfn(ht, tbl, key, params);
        bkt = rht_bucket(tbl, hash);
        do {
-               head = rht_ptr(rht_dereference_bucket_rcu(*bkt, tbl, hash));
-               rht_for_each_rcu_from(he, head, tbl, hash) {
+               rht_for_each_rcu_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
                        if (params.obj_cmpfn ?
                            params.obj_cmpfn(&arg, rht_obj(ht, he)) :
                            rhashtable_compare(&arg, rht_obj(ht, he)))
@@ -699,7 +714,7 @@ slow_path:
                return rhashtable_insert_slow(ht, key, obj);
        }
 
-       rht_for_each_from(head, rht_ptr(*bkt), tbl, hash) {
+       rht_for_each_from(head, rht_ptr(bkt, tbl, hash), tbl, hash) {
                struct rhlist_head *plist;
                struct rhlist_head *list;
 
@@ -744,7 +759,7 @@ slow_path:
                goto slow_path;
 
        /* Inserting at head of list makes unlocking free. */
-       head = rht_ptr(rht_dereference_bucket(*bkt, tbl, hash));
+       head = rht_ptr(bkt, tbl, hash);
 
        RCU_INIT_POINTER(obj->next, head);
        if (rhlist) {
@@ -971,7 +986,7 @@ static inline int __rhashtable_remove_fast_one(
        pprev = NULL;
        rht_lock(tbl, bkt);
 
-       rht_for_each_from(he, rht_ptr(*bkt), tbl, hash) {
+       rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
                struct rhlist_head *list;
 
                list = container_of(he, struct rhlist_head, rhead);
@@ -1130,7 +1145,7 @@ static inline int __rhashtable_replace_fast(
        pprev = NULL;
        rht_lock(tbl, bkt);
 
-       rht_for_each_from(he, rht_ptr(*bkt), tbl, hash) {
+       rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
                if (he != obj_old) {
                        pprev = &he->next;
                        continue;
index e387ceb00e865327dd18265dc82aaa9122e28177..237368ea98c5a18fe6065e61e1b65ed8cd8d20b3 100644 (file)
@@ -231,7 +231,8 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
 
        err = -ENOENT;
 
-       rht_for_each_from(entry, rht_ptr(*bkt), old_tbl, old_hash) {
+       rht_for_each_from(entry, rht_ptr(bkt, old_tbl, old_hash),
+                         old_tbl, old_hash) {
                err = 0;
                next = rht_dereference_bucket(entry->next, old_tbl, old_hash);
 
@@ -248,8 +249,7 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
 
        rht_lock_nested(new_tbl, &new_tbl->buckets[new_hash], SINGLE_DEPTH_NESTING);
 
-       head = rht_ptr(rht_dereference_bucket(new_tbl->buckets[new_hash],
-                                             new_tbl, new_hash));
+       head = rht_ptr(new_tbl->buckets + new_hash, new_tbl, new_hash);
 
        RCU_INIT_POINTER(entry->next, head);
 
@@ -491,7 +491,7 @@ static void *rhashtable_lookup_one(struct rhashtable *ht,
        int elasticity;
 
        elasticity = RHT_ELASTICITY;
-       rht_for_each_from(head, rht_ptr(*bkt), tbl, hash) {
+       rht_for_each_from(head, rht_ptr(bkt, tbl, hash), tbl, hash) {
                struct rhlist_head *list;
                struct rhlist_head *plist;
 
@@ -557,7 +557,7 @@ static struct bucket_table *rhashtable_insert_one(struct rhashtable *ht,
        if (unlikely(rht_grow_above_100(ht, tbl)))
                return ERR_PTR(-EAGAIN);
 
-       head = rht_ptr(rht_dereference_bucket(*bkt, tbl, hash));
+       head = rht_ptr(bkt, tbl, hash);
 
        RCU_INIT_POINTER(obj->next, head);
        if (ht->rhlist) {
@@ -1139,7 +1139,7 @@ restart:
                        struct rhash_head *pos, *next;
 
                        cond_resched();
-                       for (pos = rht_ptr(rht_dereference(*rht_bucket(tbl, i), ht)),
+                       for (pos = rht_ptr_exclusive(rht_bucket(tbl, i)),
                             next = !rht_is_a_nulls(pos) ?
                                        rht_dereference(pos->next, ht) : NULL;
                             !rht_is_a_nulls(pos);
index 02592c2a249c65796f6082315973110d98ba6b15..084fe5a6ac57cd9bed1101f0e8acf16a4d521ce2 100644 (file)
@@ -500,7 +500,7 @@ static unsigned int __init print_ht(struct rhltable *rhlt)
                struct rhash_head *pos, *next;
                struct test_obj_rhl *p;
 
-               pos = rht_ptr(rht_dereference(tbl->buckets[i], ht));
+               pos = rht_ptr_exclusive(tbl->buckets + i);
                next = !rht_is_a_nulls(pos) ? rht_dereference(pos->next, ht) : NULL;
 
                if (!rht_is_a_nulls(pos)) {