vfio: Use get_user_pages_longterm correctly
authorJason Gunthorpe <jgg@mellanox.com>
Fri, 29 Jun 2018 17:31:50 +0000 (11:31 -0600)
committerAlex Williamson <alex.williamson@redhat.com>
Sat, 30 Jun 2018 19:58:09 +0000 (13:58 -0600)
The patch noted in the fixes below converted get_user_pages_fast() to
get_user_pages_longterm(), however the two calls differ in a few ways.

First _fast() is documented to not require the mmap_sem, while _longterm()
is documented to need it. Hold the mmap sem as required.

Second, _fast accepts an 'int write' while _longterm uses 'unsigned int
gup_flags', so the expression '!!(prot & IOMMU_WRITE)' is only working by
luck as FOLL_WRITE is currently == 0x1. Use the expected FOLL_WRITE
constant instead.

Fixes: 94db151dc892 ("vfio: disable filesystem-dax page pinning")
Cc: <stable@vger.kernel.org>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Acked-by: Dan Williams <dan.j.williams@intel.com>
Signed-off-by: Alex Williamson <alex.williamson@redhat.com>
drivers/vfio/vfio_iommu_type1.c

index 2c75b33db4ac19768ea77415685b4fac700dc4d3..3e5b17710a4f1fa47eb4f1333c17c96e1eae2cdd 100644 (file)
@@ -343,18 +343,16 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
        struct page *page[1];
        struct vm_area_struct *vma;
        struct vm_area_struct *vmas[1];
+       unsigned int flags = 0;
        int ret;
 
+       if (prot & IOMMU_WRITE)
+               flags |= FOLL_WRITE;
+
+       down_read(&mm->mmap_sem);
        if (mm == current->mm) {
-               ret = get_user_pages_longterm(vaddr, 1, !!(prot & IOMMU_WRITE),
-                                             page, vmas);
+               ret = get_user_pages_longterm(vaddr, 1, flags, page, vmas);
        } else {
-               unsigned int flags = 0;
-
-               if (prot & IOMMU_WRITE)
-                       flags |= FOLL_WRITE;
-
-               down_read(&mm->mmap_sem);
                ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
                                            vmas, NULL);
                /*
@@ -368,8 +366,8 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
                        ret = -EOPNOTSUPP;
                        put_page(page[0]);
                }
-               up_read(&mm->mmap_sem);
        }
+       up_read(&mm->mmap_sem);
 
        if (ret == 1) {
                *pfn = page_to_pfn(page[0]);