RDMA/umem: Use umem->owning_mm inside ODP
authorJason Gunthorpe <jgg@mellanox.com>
Sun, 16 Sep 2018 17:48:08 +0000 (20:48 +0300)
committerDoug Ledford <dledford@redhat.com>
Fri, 21 Sep 2018 15:58:36 +0000 (11:58 -0400)
Since ODP had a single struct mmu_notifier located in the ucontext it
could only handle a single MM at a time, and this prevented it from using
the new owning_mm system.

With the prior rework it is now simple to let ODP track multiple MMs per
ucontext, finish the job so that the per_mm is allocated on a mm by mm
basis, and freed when the last umem is dropped from the ucontext.

As a side effect the new saner locking removes the lockdep splat about
nesting the umem_rwsem between mmu_notifier_unregister and
ib_umem_odp_release.

It also makes ODP work with multiple processes, across, fork, etc.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
drivers/infiniband/core/umem_odp.c
drivers/infiniband/core/uverbs_cmd.c
drivers/infiniband/hw/mlx5/main.c
drivers/infiniband/hw/mlx5/odp.c
include/rdma/ib_umem_odp.h
include/rdma/ib_verbs.h

index 6bf3fc0c12a104581b28dcc9ad7e706100ce5639..0577f9ff600f6b57536f8beeefac72e8a304824b 100644 (file)
@@ -278,10 +278,135 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
        .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 };
 
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
-                                     unsigned long addr, size_t size)
+static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
+{
+       struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+       struct ib_umem *umem = &umem_odp->umem;
+
+       down_write(&per_mm->umem_rwsem);
+       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+               rbt_ib_umem_insert(&umem_odp->interval_tree,
+                                  &per_mm->umem_tree);
+
+       if (likely(!atomic_read(&per_mm->notifier_count)))
+               umem_odp->mn_counters_active = true;
+       else
+               list_add(&umem_odp->no_private_counters,
+                        &per_mm->no_private_counters);
+       up_write(&per_mm->umem_rwsem);
+}
+
+static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
+{
+       struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+       struct ib_umem *umem = &umem_odp->umem;
+
+       down_write(&per_mm->umem_rwsem);
+       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
+               rbt_ib_umem_remove(&umem_odp->interval_tree,
+                                  &per_mm->umem_tree);
+       if (!umem_odp->mn_counters_active) {
+               list_del(&umem_odp->no_private_counters);
+               complete_all(&umem_odp->notifier_completion);
+       }
+
+       up_write(&per_mm->umem_rwsem);
+}
+
+static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
+                                              struct mm_struct *mm)
 {
        struct ib_ucontext_per_mm *per_mm;
+       int ret;
+
+       per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
+       if (!per_mm)
+               return ERR_PTR(-ENOMEM);
+
+       per_mm->context = ctx;
+       per_mm->mm = mm;
+       per_mm->umem_tree = RB_ROOT_CACHED;
+       init_rwsem(&per_mm->umem_rwsem);
+       INIT_LIST_HEAD(&per_mm->no_private_counters);
+
+       rcu_read_lock();
+       per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
+       rcu_read_unlock();
+
+       WARN_ON(mm != current->mm);
+
+       per_mm->mn.ops = &ib_umem_notifiers;
+       ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
+       if (ret) {
+               dev_err(&ctx->device->dev,
+                       "Failed to register mmu_notifier %d\n", ret);
+               goto out_pid;
+       }
+
+       list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
+       return per_mm;
+
+out_pid:
+       put_pid(per_mm->tgid);
+       kfree(per_mm);
+       return ERR_PTR(ret);
+}
+
+static int get_per_mm(struct ib_umem_odp *umem_odp)
+{
+       struct ib_ucontext *ctx = umem_odp->umem.context;
+       struct ib_ucontext_per_mm *per_mm;
+
+       /*
+        * Generally speaking we expect only one or two per_mm in this list,
+        * so no reason to optimize this search today.
+        */
+       mutex_lock(&ctx->per_mm_list_lock);
+       list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
+               if (per_mm->mm == umem_odp->umem.owning_mm)
+                       goto found;
+       }
+
+       per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
+       if (IS_ERR(per_mm)) {
+               mutex_unlock(&ctx->per_mm_list_lock);
+               return PTR_ERR(per_mm);
+       }
+
+found:
+       umem_odp->per_mm = per_mm;
+       per_mm->odp_mrs_count++;
+       mutex_unlock(&ctx->per_mm_list_lock);
+
+       return 0;
+}
+
+void put_per_mm(struct ib_umem_odp *umem_odp)
+{
+       struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+       struct ib_ucontext *ctx = umem_odp->umem.context;
+       bool need_free;
+
+       mutex_lock(&ctx->per_mm_list_lock);
+       umem_odp->per_mm = NULL;
+       per_mm->odp_mrs_count--;
+       need_free = per_mm->odp_mrs_count == 0;
+       if (need_free)
+               list_del(&per_mm->ucontext_list);
+       mutex_unlock(&ctx->per_mm_list_lock);
+
+       if (!need_free)
+               return;
+
+       mmu_notifier_unregister(&per_mm->mn, per_mm->mm);
+       put_pid(per_mm->tgid);
+       kfree(per_mm);
+}
+
+struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
+                                     unsigned long addr, size_t size)
+{
+       struct ib_ucontext *ctx = per_mm->context;
        struct ib_umem_odp *odp_data;
        struct ib_umem *umem;
        int pages = size >> PAGE_SHIFT;
@@ -291,13 +416,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
        if (!odp_data)
                return ERR_PTR(-ENOMEM);
        umem = &odp_data->umem;
-       umem->context    = context;
+       umem->context    = ctx;
        umem->length     = size;
        umem->address    = addr;
        umem->page_shift = PAGE_SHIFT;
        umem->writable   = 1;
        umem->is_odp = 1;
-       odp_data->per_mm = per_mm = &context->per_mm;
+       odp_data->per_mm = per_mm;
 
        mutex_init(&odp_data->umem_mutex);
        init_completion(&odp_data->notifier_completion);
@@ -316,15 +441,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
                goto out_page_list;
        }
 
-       down_write(&per_mm->umem_rwsem);
+       /*
+        * Caller must ensure that the umem_odp that the per_mm came from
+        * cannot be freed during the call to ib_alloc_odp_umem.
+        */
+       mutex_lock(&ctx->per_mm_list_lock);
        per_mm->odp_mrs_count++;
-       rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
-       if (likely(!atomic_read(&per_mm->notifier_count)))
-               odp_data->mn_counters_active = true;
-       else
-               list_add(&odp_data->no_private_counters,
-                        &per_mm->no_private_counters);
-       up_write(&per_mm->umem_rwsem);
+       mutex_unlock(&ctx->per_mm_list_lock);
+       add_umem_to_per_mm(odp_data);
 
        return odp_data;
 
@@ -338,15 +462,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-       struct ib_ucontext *context = umem_odp->umem.context;
        struct ib_umem *umem = &umem_odp->umem;
-       struct ib_ucontext_per_mm *per_mm;
+       /*
+        * NOTE: This must called in a process context where umem->owning_mm
+        * == current->mm
+        */
+       struct mm_struct *mm = umem->owning_mm;
        int ret_val;
-       struct pid *our_pid;
-       struct mm_struct *mm = get_task_mm(current);
-
-       if (!mm)
-               return -EINVAL;
 
        if (access & IB_ACCESS_HUGETLB) {
                struct vm_area_struct *vma;
@@ -366,16 +488,6 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                umem->hugetlb = 0;
        }
 
-       /* Prevent creating ODP MRs in child processes */
-       rcu_read_lock();
-       our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
-       rcu_read_unlock();
-       put_pid(our_pid);
-       if (context->tgid != our_pid) {
-               ret_val = -EINVAL;
-               goto out_mm;
-       }
-
        mutex_init(&umem_odp->umem_mutex);
 
        init_completion(&umem_odp->notifier_completion);
@@ -384,10 +496,8 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                umem_odp->page_list =
                        vzalloc(array_size(sizeof(*umem_odp->page_list),
                                           ib_umem_num_pages(umem)));
-               if (!umem_odp->page_list) {
-                       ret_val = -ENOMEM;
-                       goto out_mm;
-               }
+               if (!umem_odp->page_list)
+                       return -ENOMEM;
 
                umem_odp->dma_list =
                        vzalloc(array_size(sizeof(*umem_odp->dma_list),
@@ -398,67 +508,23 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                }
        }
 
-       /*
-        * When using MMU notifiers, we will get a
-        * notification before the "current" task (and MM) is
-        * destroyed. We use the umem_rwsem semaphore to synchronize.
-        */
-       umem_odp->per_mm = per_mm = &context->per_mm;
-
-       down_write(&per_mm->umem_rwsem);
-       per_mm->odp_mrs_count++;
-       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-               rbt_ib_umem_insert(&umem_odp->interval_tree,
-                                  &per_mm->umem_tree);
-       if (likely(!atomic_read(&per_mm->notifier_count)) ||
-           per_mm->odp_mrs_count == 1)
-               umem_odp->mn_counters_active = true;
-       else
-               list_add(&umem_odp->no_private_counters,
-                        &per_mm->no_private_counters);
-       downgrade_write(&per_mm->umem_rwsem);
+       ret_val = get_per_mm(umem_odp);
+       if (ret_val)
+               goto out_dma_list;
+       add_umem_to_per_mm(umem_odp);
 
-       if (per_mm->odp_mrs_count == 1) {
-               /*
-                * Note that at this point, no MMU notifier is running
-                * for this per_mm!
-                */
-               atomic_set(&per_mm->notifier_count, 0);
-               INIT_HLIST_NODE(&per_mm->mn.hlist);
-               per_mm->mn.ops = &ib_umem_notifiers;
-               ret_val = mmu_notifier_register(&per_mm->mn, mm);
-               if (ret_val) {
-                       pr_err("Failed to register mmu_notifier %d\n", ret_val);
-                       ret_val = -EBUSY;
-                       goto out_mutex;
-               }
-       }
-
-       up_read(&per_mm->umem_rwsem);
-
-       /*
-        * Note that doing an mmput can cause a notifier for the relevant mm.
-        * If the notifier is called while we hold the umem_rwsem, this will
-        * cause a deadlock. Therefore, we release the reference only after we
-        * released the semaphore.
-        */
-       mmput(mm);
        return 0;
 
-out_mutex:
-       up_read(&per_mm->umem_rwsem);
+out_dma_list:
        vfree(umem_odp->dma_list);
 out_page_list:
        vfree(umem_odp->page_list);
-out_mm:
-       mmput(mm);
        return ret_val;
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
        struct ib_umem *umem = &umem_odp->umem;
-       struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 
        /*
         * Ensure that no more pages are mapped in the umem.
@@ -469,54 +535,8 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
        ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
                                    ib_umem_end(umem));
 
-       down_write(&per_mm->umem_rwsem);
-       if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-               rbt_ib_umem_remove(&umem_odp->interval_tree,
-                                  &per_mm->umem_tree);
-       per_mm->odp_mrs_count--;
-       if (!umem_odp->mn_counters_active) {
-               list_del(&umem_odp->no_private_counters);
-               complete_all(&umem_odp->notifier_completion);
-       }
-
-       /*
-        * Downgrade the lock to a read lock. This ensures that the notifiers
-        * (who lock the mutex for reading) will be able to finish, and we
-        * will be able to enventually obtain the mmu notifiers SRCU. Note
-        * that since we are doing it atomically, no other user could register
-        * and unregister while we do the check.
-        */
-       downgrade_write(&per_mm->umem_rwsem);
-       if (!per_mm->odp_mrs_count) {
-               struct task_struct *owning_process = NULL;
-               struct mm_struct *owning_mm        = NULL;
-
-               owning_process =
-                       get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
-               if (owning_process == NULL)
-                       /*
-                        * The process is already dead, notifier were removed
-                        * already.
-                        */
-                       goto out;
-
-               owning_mm = get_task_mm(owning_process);
-               if (owning_mm == NULL)
-                       /*
-                        * The process' mm is already dead, notifier were
-                        * removed already.
-                        */
-                       goto out_put_task;
-               mmu_notifier_unregister(&per_mm->mn, owning_mm);
-
-               mmput(owning_mm);
-
-out_put_task:
-               put_task_struct(owning_process);
-       }
-out:
-       up_read(&per_mm->umem_rwsem);
-
+       remove_umem_from_per_mm(umem_odp);
+       put_per_mm(umem_odp);
        vfree(umem_odp->dma_list);
        vfree(umem_odp->page_list);
 }
@@ -634,7 +654,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
 {
        struct ib_umem *umem = &umem_odp->umem;
        struct task_struct *owning_process  = NULL;
-       struct mm_struct   *owning_mm       = NULL;
+       struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
        struct page       **local_page_list = NULL;
        u64 page_mask, off;
        int j, k, ret = 0, start_idx, npages = 0, page_shift;
@@ -658,15 +678,14 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
        user_virt = user_virt & page_mask;
        bcnt += off; /* Charge for the first page offset as well. */
 
-       owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
-       if (owning_process == NULL) {
+       /*
+        * owning_process is allowed to be NULL, this means somehow the mm is
+        * existing beyond the lifetime of the originating process.. Presumably
+        * mmget_not_zero will fail in this case.
+        */
+       owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
+       if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
                ret = -EINVAL;
-               goto out_no_task;
-       }
-
-       owning_mm = get_task_mm(owning_process);
-       if (owning_mm == NULL) {
-               ret = -ENOENT;
                goto out_put_task;
        }
 
@@ -738,8 +757,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
 
        mmput(owning_mm);
 out_put_task:
-       put_task_struct(owning_process);
-out_no_task:
+       if (owning_process)
+               put_task_struct(owning_process);
        free_page((unsigned long)local_page_list);
        return ret;
 }
index ce678e1008a45422e244e1c531324613f87ed39d..d77b0b9793c776ac215b7366e44f90cbf793b141 100644 (file)
@@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
        ucontext->cleanup_retryable = false;
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-       ucontext->per_mm.umem_tree = RB_ROOT_CACHED;
-       init_rwsem(&ucontext->per_mm.umem_rwsem);
-       ucontext->per_mm.odp_mrs_count = 0;
-       INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
-       ucontext->per_mm.context = ucontext;
-
+       mutex_init(&ucontext->per_mm_list_lock);
+       INIT_LIST_HEAD(&ucontext->per_mm_list);
        if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
                ucontext->invalidate_range = NULL;
 
index aeb328100986d5aa2e8bc7fd9189be991d7c4ded..1348a08261a949806ddb9399675bfaf326bb1b97 100644 (file)
@@ -1861,6 +1861,13 @@ static int mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
        struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
        struct mlx5_bfreg_info *bfregi;
 
+#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
+       /* All umem's must be destroyed before destroying the ucontext. */
+       mutex_lock(&ibcontext->per_mm_list_lock);
+       WARN_ON(!list_empty(&ibcontext->per_mm_list));
+       mutex_unlock(&ibcontext->per_mm_list_lock);
+#endif
+
        if (context->devx_uid)
                mlx5_ib_devx_destroy(dev, context);
 
index 9982b5f4e59838916b9c8d88022c0cfd2165cf47..b04eb67753261c71c036bc4269578d350dd40de4 100644 (file)
@@ -393,7 +393,7 @@ next_mr:
                if (nentries)
                        nentries++;
        } else {
-               odp = ib_alloc_odp_umem(odp_mr->umem.context, addr,
+               odp = ib_alloc_odp_umem(odp_mr->per_mm, addr,
                                        MLX5_IMR_MTT_SIZE);
                if (IS_ERR(odp)) {
                        mutex_unlock(&odp_mr->umem_mutex);
index 394ea6b68db705209e2d0a5adb33c5d6f8f702f0..259eb08dfc9e5193d942ba1f358594b721403458 100644 (file)
@@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
+struct ib_ucontext_per_mm {
+       struct ib_ucontext *context;
+       struct mm_struct *mm;
+       struct pid *tgid;
+
+       struct rb_root_cached umem_tree;
+       /* Protects umem_tree */
+       struct rw_semaphore umem_rwsem;
+       atomic_t notifier_count;
+
+       struct mmu_notifier mn;
+       /* A list of umems that don't have private mmu notifier counters yet. */
+       struct list_head no_private_counters;
+       unsigned int odp_mrs_count;
+
+       struct list_head ucontext_list;
+};
+
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
+struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
                                      unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
 
index 2cf2cee5a753529f705fda9ee9bb004ec787658e..6437e6af758d84afa12900ccdd1445f612782f16 100644 (file)
@@ -1488,25 +1488,6 @@ struct ib_rdmacg_object {
 #endif
 };
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-struct ib_ucontext_per_mm {
-       struct ib_ucontext *context;
-
-       struct rb_root_cached umem_tree;
-       /*
-        * Protects .umem_rbroot and tree, as well as odp_mrs_count and
-        * mmu notifiers registration.
-        */
-       struct rw_semaphore umem_rwsem;
-
-       struct mmu_notifier mn;
-       atomic_t notifier_count;
-       /* A list of umems that don't have private mmu notifier counters yet. */
-       struct list_head no_private_counters;
-       unsigned int odp_mrs_count;
-};
-#endif
-
 struct ib_ucontext {
        struct ib_device       *device;
        struct ib_uverbs_file  *ufile;
@@ -1523,7 +1504,8 @@ struct ib_ucontext {
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
        void (*invalidate_range)(struct ib_umem_odp *umem_odp,
                                 unsigned long start, unsigned long end);
-       struct ib_ucontext_per_mm per_mm;
+       struct mutex per_mm_list_lock;
+       struct list_head per_mm_list;
 #endif
 
        struct ib_rdmacg_object cg_obj;