diff options
Diffstat (limited to 'mm/memory.c')
-rw-r--r-- | mm/memory.c | 106 |
1 files changed, 77 insertions, 29 deletions
diff --git a/mm/memory.c b/mm/memory.c index e48945ab362b..ce22a250926f 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -1410,6 +1410,13 @@ no_page_table: return page; } +static inline int stack_guard_page(struct vm_area_struct *vma, unsigned long addr) +{ + return (vma->vm_flags & VM_GROWSDOWN) && + (vma->vm_start == addr) && + !vma_stack_continue(vma->vm_prev, addr); +} + /** * __get_user_pages() - pin user pages in memory * @tsk: task_struct of target task @@ -1486,9 +1493,8 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, struct vm_area_struct *vma; vma = find_extend_vma(mm, start); - if (!vma && in_gate_area(tsk, start)) { + if (!vma && in_gate_area(mm, start)) { unsigned long pg = start & PAGE_MASK; - struct vm_area_struct *gate_vma = get_gate_vma(tsk); pgd_t *pgd; pud_t *pud; pmd_t *pmd; @@ -1513,10 +1519,11 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, pte_unmap(pte); return i ? : -EFAULT; } + vma = get_gate_vma(mm); if (pages) { struct page *page; - page = vm_normal_page(gate_vma, start, *pte); + page = vm_normal_page(vma, start, *pte); if (!page) { if (!(gup_flags & FOLL_DUMP) && is_zero_pfn(pte_pfn(*pte))) @@ -1530,12 +1537,7 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, get_page(page); } pte_unmap(pte); - if (vmas) - vmas[i] = gate_vma; - i++; - start += PAGE_SIZE; - nr_pages--; - continue; + goto next_page; } if (!vma || @@ -1549,6 +1551,13 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, continue; } + /* + * If we don't actually want the page itself, + * and it's the stack guard page, just skip it. + */ + if (!pages && stack_guard_page(vma, start)) + goto next_page; + do { struct page *page; unsigned int foll_flags = gup_flags; @@ -1569,6 +1578,8 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, fault_flags |= FAULT_FLAG_WRITE; if (nonblocking) fault_flags |= FAULT_FLAG_ALLOW_RETRY; + if (foll_flags & FOLL_NOWAIT) + fault_flags |= (FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_RETRY_NOWAIT); ret = handle_mm_fault(mm, vma, start, fault_flags); @@ -1589,13 +1600,17 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, return i ? i : -EFAULT; BUG(); } - if (ret & VM_FAULT_MAJOR) - tsk->maj_flt++; - else - tsk->min_flt++; + + if (tsk) { + if (ret & VM_FAULT_MAJOR) + tsk->maj_flt++; + else + tsk->min_flt++; + } if (ret & VM_FAULT_RETRY) { - *nonblocking = 0; + if (nonblocking) + *nonblocking = 0; return i; } @@ -1625,6 +1640,7 @@ int __get_user_pages(struct task_struct *tsk, struct mm_struct *mm, flush_anon_page(vma, page, start); flush_dcache_page(page); } +next_page: if (vmas) vmas[i] = vma; i++; @@ -1638,7 +1654,8 @@ EXPORT_SYMBOL(__get_user_pages); /** * get_user_pages() - pin user pages in memory - * @tsk: task_struct of target task + * @tsk: the task_struct to use for page fault accounting, or + * NULL if faults are not to be recorded. * @mm: mm_struct of target mm * @start: starting user address * @nr_pages: number of pages from start to pin @@ -2764,7 +2781,7 @@ static int do_swap_page(struct mm_struct *mm, struct vm_area_struct *vma, swp_entry_t entry; pte_t pte; int locked; - struct mem_cgroup *ptr = NULL; + struct mem_cgroup *ptr; int exclusive = 0; int ret = 0; @@ -3496,7 +3513,7 @@ static int __init gate_vma_init(void) __initcall(gate_vma_init); #endif -struct vm_area_struct *get_gate_vma(struct task_struct *tsk) +struct vm_area_struct *get_gate_vma(struct mm_struct *mm) { #ifdef AT_SYSINFO_EHDR return &gate_vma; @@ -3505,7 +3522,7 @@ struct vm_area_struct *get_gate_vma(struct task_struct *tsk) #endif } -int in_gate_area_no_task(unsigned long addr) +int in_gate_area_no_mm(unsigned long addr) { #ifdef AT_SYSINFO_EHDR if ((addr >= FIXADDR_USER_START) && (addr < FIXADDR_USER_END)) @@ -3646,20 +3663,15 @@ int generic_access_phys(struct vm_area_struct *vma, unsigned long addr, #endif /* - * Access another process' address space. - * Source/target buffer must be kernel space, - * Do not walk the page table directly, use get_user_pages + * Access another process' address space as given in mm. If non-NULL, use the + * given task for page fault accounting. */ -int access_process_vm(struct task_struct *tsk, unsigned long addr, void *buf, int len, int write) +static int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm, + unsigned long addr, void *buf, int len, int write) { - struct mm_struct *mm; struct vm_area_struct *vma; void *old_buf = buf; - mm = get_task_mm(tsk); - if (!mm) - return 0; - down_read(&mm->mmap_sem); /* ignore errors, just check how much was successfully transferred */ while (len) { @@ -3676,7 +3688,7 @@ int access_process_vm(struct task_struct *tsk, unsigned long addr, void *buf, in */ #ifdef CONFIG_HAVE_IOREMAP_PROT vma = find_vma(mm, addr); - if (!vma) + if (!vma || vma->vm_start > addr) break; if (vma->vm_ops && vma->vm_ops->access) ret = vma->vm_ops->access(vma, addr, buf, @@ -3708,11 +3720,47 @@ int access_process_vm(struct task_struct *tsk, unsigned long addr, void *buf, in addr += bytes; } up_read(&mm->mmap_sem); - mmput(mm); return buf - old_buf; } +/** + * access_remote_vm - access another process' address space + * @mm: the mm_struct of the target address space + * @addr: start address to access + * @buf: source or destination buffer + * @len: number of bytes to transfer + * @write: whether the access is a write + * + * The caller must hold a reference on @mm. + */ +int access_remote_vm(struct mm_struct *mm, unsigned long addr, + void *buf, int len, int write) +{ + return __access_remote_vm(NULL, mm, addr, buf, len, write); +} + +/* + * Access another process' address space. + * Source/target buffer must be kernel space, + * Do not walk the page table directly, use get_user_pages + */ +int access_process_vm(struct task_struct *tsk, unsigned long addr, + void *buf, int len, int write) +{ + struct mm_struct *mm; + int ret; + + mm = get_task_mm(tsk); + if (!mm) + return 0; + + ret = __access_remote_vm(tsk, mm, addr, buf, len, write); + mmput(mm); + + return ret; +} + /* * Print the name of a VMA. */ |