// See LICENSE for license details.

#include "mmap.h"
#include "atomic.h"
#include "pk.h"
#include "boot.h"
#include "bits.h"
#include "mtrap.h"
#include <stdint.h>
#include <errno.h>

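// A vmr_t records one mmap'd virtual memory region. Under demand paging,
// an unpopulated page's PTE stores a pointer to its vmr_t (PTE_V stays
// clear, since the pointer is word-aligned); the fault handler uses it to
// populate the page on first access.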
typedef struct {
  uintptr_t addr;
  size_t length;
  file_t* file;
  size_t offset;
  unsigned refcnt;
  int prot;
} vmr_t;

#define MAX_VMR (RISCV_PGSIZE / sizeof(vmr_t))
static spinlock_t vm_lock = SPINLOCK_INIT;
static vmr_t* vmrs;

uintptr_t first_free_paddr;
static uintptr_t first_free_page;
static size_t next_free_page;
static size_t free_pages;

int demand_paging = 1; // unless -p flag is given

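// Bump-allocate one zeroed page from the kernel's reserved page pool.
// The pool never shrinks, and exhausting it is fatal (kassert).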
static uintptr_t __page_alloc()
{
  kassert(next_free_page != free_pages);
  uintptr_t addr = first_free_page + RISCV_PGSIZE * next_free_page++;
  memset((void*)addr, 0, RISCV_PGSIZE);
  return addr;
}

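// Claim a free vmr_t slot, lazily allocating the single page that holds the
// slot table (double-checked locking, with barriers to publish the page).
// refcnt counts the PTEs that still reference the region; NULL means all
// MAX_VMR slots are taken. Callers such as do_mmap already hold vm_lock;
// the inner lock is only ever taken on the very first allocation, which in
// this kernel happens from pk_vm_init's stack mapping, before vm_lock is
// held by anyone.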
static vmr_t* __vmr_alloc(uintptr_t addr, size_t length, file_t* file,
                          size_t offset, unsigned refcnt, int prot)
{
  if (!vmrs) {
    spinlock_lock(&vm_lock);
      if (!vmrs) {
        vmr_t* page = (vmr_t*)__page_alloc();
        mb();
        vmrs = page;
      }
    spinlock_unlock(&vm_lock);
  }
  mb();

  for (vmr_t* v = vmrs; v < vmrs + MAX_VMR; v++) {
    if (v->refcnt == 0) {
      if (file)
        file_incref(file);
      v->addr = addr;
      v->length = length;
      v->file = file;
      v->offset = offset;
      v->refcnt = refcnt;
      v->prot = prot;
      return v;
    }
  }
  return NULL;
}

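// Drop `dec` references from a region; at zero, release the backing file
// (if any) and let the slot be reused.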
static void __vmr_decref(vmr_t* v, unsigned dec)
{
  if ((v->refcnt -= dec) == 0)
  {
    if (v->file)
      file_decref(v->file);
  }
}

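// Page-table arithmetic helpers: pte_ppn extracts the physical page number
// from a PTE, ppn converts an address to a page number, and pt_idx selects
// the page-table index at a given level. For example, assuming
// RISCV_PGSHIFT == 12 and RISCV_PGLEVEL_BITS == 9 (an Sv39-style build),
// pt_idx(addr, 1) returns bits 29:21 of addr.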
static size_t pte_ppn(pte_t pte)
{
  return pte >> PTE_PPN_SHIFT;
}

static uintptr_t ppn(uintptr_t addr)
{
  return addr >> RISCV_PGSHIFT;
}

static size_t pt_idx(uintptr_t addr, int level)
{
  size_t idx = addr >> (RISCV_PGLEVEL_BITS*level + RISCV_PGSHIFT);
  return idx & ((1 << RISCV_PGLEVEL_BITS) - 1);
}

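// __continue_walk_create allocates a missing intermediate page directory,
// installs a pointer to it, and restarts the walk from the root. It is
// noinline so the common, allocation-free walk stays compact.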
static pte_t* __walk_create(uintptr_t addr);

static pte_t* __attribute__((noinline)) __continue_walk_create(uintptr_t addr, pte_t* pte)
{
  *pte = ptd_create(ppn(__page_alloc()));
  return __walk_create(addr);
}

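// Descend from the root page table toward the leaf. At each non-leaf level,
// follow the existing entry or, in create mode, allocate the missing
// directory; returns a pointer to the leaf PTE, or 0 if the path is absent.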
static pte_t* __walk_internal(uintptr_t addr, int create)
{
  pte_t* t = root_page_table;
  for (int i = (VA_BITS - RISCV_PGSHIFT) / RISCV_PGLEVEL_BITS - 1; i > 0; i--) {
    size_t idx = pt_idx(addr, i);
    if (unlikely(!(t[idx] & PTE_V)))
      return create ? __continue_walk_create(addr, &t[idx]) : 0;
    t = (pte_t*)(pte_ppn(t[idx]) << RISCV_PGSHIFT);
  }
  return &t[pt_idx(addr, 0)];
}

static pte_t* __walk(uintptr_t addr)
{
  return __walk_internal(addr, 0);
}

static pte_t* __walk_create(uintptr_t addr)
{
  return __walk_internal(addr, 1);
}

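// A virtual page is available if it has no PTE at all or a zero PTE.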
static int __va_avail(uintptr_t vaddr)
{
  pte_t* pte = __walk(vaddr);
  return pte == 0 || *pte == 0;
}

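// First-fit search of [current.brk, current.mmap_max) for npage consecutive
// available pages. The inner loop scans each candidate range backward so
// that `a` stops on a conflicting page, letting the outer loop resume just
// past it.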
static uintptr_t __vm_alloc(size_t npage)
{
  uintptr_t start = current.brk, end = current.mmap_max - npage*RISCV_PGSIZE;
  for (uintptr_t a = start; a <= end; a += RISCV_PGSIZE)
  {
    if (!__va_avail(a))
      continue;
    uintptr_t first = a, last = a + (npage-1) * RISCV_PGSIZE;
    for (a = last; a > first && __va_avail(a); a -= RISCV_PGSIZE)
      ;
    if (a > first)
      continue;
    return a;
  }
  return 0;
}

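// Translate PROT_* bits into PTE permission bits. PTE_A and PTE_D are set
// eagerly so no later fault is needed to mark pages accessed or dirty;
// e.g. PROT_READ|PROT_WRITE yields PTE_R|PTE_A|PTE_W|PTE_D, plus PTE_U for
// user mappings.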
static inline pte_t prot_to_type(int prot, int user)
{
  pte_t pte = 0;
  if (prot & PROT_READ) pte |= PTE_R | PTE_A;
  if (prot & PROT_WRITE) pte |= PTE_W | PTE_D;
  if (prot & PROT_EXEC) pte |= PTE_X | PTE_A;
  if (pte == 0) pte = PTE_R;
  if (user) pte |= PTE_U;
  return pte;
}

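// A user range is valid if it doesn't wrap around and lies below mmap_max.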
int __valid_user_range(uintptr_t vaddr, size_t len)
{
  if (vaddr + len < vaddr)
    return 0;
  return vaddr + len <= current.mmap_max;
}

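// Demand-paging fault handler; vm_lock must be held. An unpopulated page's
// PTE holds its vmr_t pointer with PTE_V clear. The backing physical page
// sits at a fixed offset from the virtual page (vpn plus
// first_free_paddr/RISCV_PGSIZE), so no allocation is needed: the page is
// first mapped supervisor-read/write so it can be filled from the file (or
// zeroed), then remapped with the region's user permissions. Reading
// v->prot after __vmr_decref is safe because vm_lock keeps the slot from
// being reused.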
static int __handle_page_fault(uintptr_t vaddr, int prot)
{
  uintptr_t vpn = vaddr >> RISCV_PGSHIFT;
  vaddr = vpn << RISCV_PGSHIFT;

  pte_t* pte = __walk(vaddr);

  if (pte == 0 || *pte == 0 || !__valid_user_range(vaddr, 1))
    return -1;
  else if (!(*pte & PTE_V))
  {
    uintptr_t ppn = vpn + (first_free_paddr / RISCV_PGSIZE);

    vmr_t* v = (vmr_t*)*pte;
    *pte = pte_create(ppn, prot_to_type(PROT_READ|PROT_WRITE, 0));
    flush_tlb();
    if (v->file)
    {
      size_t flen = MIN(RISCV_PGSIZE, v->length - (vaddr - v->addr));
      ssize_t ret = file_pread(v->file, (void*)vaddr, flen, vaddr - v->addr + v->offset);
      kassert(ret > 0);
      if (ret < RISCV_PGSIZE)
        memset((void*)vaddr + ret, 0, RISCV_PGSIZE - ret);
    }
    else
      memset((void*)vaddr, 0, RISCV_PGSIZE);
    __vmr_decref(v, 1);
    *pte = pte_create(ppn, prot_to_type(v->prot, 1));
  }

  pte_t perms = pte_create(0, prot_to_type(prot, 1));
  if ((*pte & perms) != perms)
    return -1;

  flush_tlb();
  return 0;
}

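// Page-fault entry point: serializes on vm_lock.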
int handle_page_fault(uintptr_t vaddr, int prot)
{
  spinlock_lock(&vm_lock);
    int ret = __handle_page_fault(vaddr, prot);
  spinlock_unlock(&vm_lock);
  return ret;
}

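// Unmap [addr, addr+len): pages still under demand paging drop their vmr
// reference; every PTE in the range is then cleared.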
static void __do_munmap(uintptr_t addr, size_t len)
{
  for (uintptr_t a = addr; a < addr + len; a += RISCV_PGSIZE)
  {
    pte_t* pte = __walk(a);
    if (pte == 0 || *pte == 0)
      continue;

    if (!(*pte & PTE_V))
      __vmr_decref((vmr_t*)*pte, 1);

    *pte = 0;
  }
  flush_tlb(); // TODO: TLB shootdown on other harts
}

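// Map [addr, addr+length). Each page's PTE initially stores the vmr_t
// pointer itself (PTE_V clear, so it reads as an invalid mapping); the real
// translation is installed on first fault, or immediately below when demand
// paging is off or MAP_POPULATE is requested.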
uintptr_t __do_mmap(uintptr_t addr, size_t length, int prot, int flags, file_t* f, off_t offset)
{
  size_t npage = (length-1)/RISCV_PGSIZE+1;
  if (flags & MAP_FIXED)
  {
    if ((addr & (RISCV_PGSIZE-1)) || !__valid_user_range(addr, length))
      return (uintptr_t)-1;
  }
  else if ((addr = __vm_alloc(npage)) == 0)
    return (uintptr_t)-1;

  vmr_t* v = __vmr_alloc(addr, length, f, offset, npage, prot);
  if (!v)
    return (uintptr_t)-1;

  for (uintptr_t a = addr; a < addr + length; a += RISCV_PGSIZE)
  {
    pte_t* pte = __walk_create(a);
    kassert(pte);

    if (*pte)
      __do_munmap(a, RISCV_PGSIZE);

    *pte = (pte_t)v;
  }

  if (!demand_paging || (flags & MAP_POPULATE))
    for (uintptr_t a = addr; a < addr + length; a += RISCV_PGSIZE)
      kassert(__handle_page_fault(a, prot) == 0);

  return addr;
}

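// munmap syscall entry: reject unaligned or out-of-range requests, then
// unmap under vm_lock.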
int do_munmap(uintptr_t addr, size_t length)
{
  if ((addr & (RISCV_PGSIZE-1)) || !__valid_user_range(addr, length))
    return -EINVAL;

  spinlock_lock(&vm_lock);
    __do_munmap(addr, length);
  spinlock_unlock(&vm_lock);

  return 0;
}

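// mmap syscall entry: only MAP_PRIVATE mappings with page-aligned offsets
// are supported. A successful mapping below brk_max also lowers brk_max so
// the heap cannot grow into it.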
uintptr_t do_mmap(uintptr_t addr, size_t length, int prot, int flags, int fd, off_t offset)
{
  if (!(flags & MAP_PRIVATE) || length == 0 || (offset & (RISCV_PGSIZE-1)))
    return -EINVAL;

  file_t* f = NULL;
  if (!(flags & MAP_ANONYMOUS) && (f = file_get(fd)) == NULL)
    return -EBADF;

  spinlock_lock(&vm_lock);
    addr = __do_mmap(addr, length, prot, flags, f, offset);

    if (addr < current.brk_max)
      current.brk_max = addr;
  spinlock_unlock(&vm_lock);

  if (f) file_decref(f);
  return addr;
}

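// Move the program break, clamped to [brk_min, brk_max]. Growth is backed
// by an anonymous MAP_FIXED mapping with prot = -1 (all PROT_* bits set);
// shrinking unmaps whole pages.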
uintptr_t __do_brk(size_t addr)
{
  uintptr_t newbrk = addr;
  if (addr < current.brk_min)
    newbrk = current.brk_min;
  else if (addr > current.brk_max)
    newbrk = current.brk_max;

  if (current.brk == 0)
    current.brk = ROUNDUP(current.brk_min, RISCV_PGSIZE);

  uintptr_t newbrk_page = ROUNDUP(newbrk, RISCV_PGSIZE);
  if (current.brk > newbrk_page)
    __do_munmap(newbrk_page, current.brk - newbrk_page);
  else if (current.brk < newbrk_page)
    kassert(__do_mmap(current.brk, newbrk_page - current.brk, -1, MAP_FIXED|MAP_PRIVATE|MAP_ANONYMOUS, 0, 0) == current.brk);
  current.brk = newbrk_page;

  return newbrk;
}

uintptr_t do_brk(size_t addr)
{
  spinlock_lock(&vm_lock);
    addr = __do_brk(addr);
  spinlock_unlock(&vm_lock);

  return addr;
}

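// mremap is not supported.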
uintptr_t do_mremap(uintptr_t addr, size_t old_size, size_t new_size, int flags)
{
  return -ENOSYS;
}

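// Change permissions on [addr, addr+length). Pages still under demand
// paging may only drop permissions (no bit beyond v->prot may be added);
// already-mapped pages must not be granted more than their current PTE
// allows.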
uintptr_t do_mprotect(uintptr_t addr, size_t length, int prot)
{
  uintptr_t res = 0;
  if (addr & (RISCV_PGSIZE-1))
    return -EINVAL;

  spinlock_lock(&vm_lock);
    for (uintptr_t a = addr; a < addr + length; a += RISCV_PGSIZE)
    {
      pte_t* pte = __walk(a);
      if (pte == 0 || *pte == 0) {
        res = -ENOMEM;
        break;
      }

      if (!(*pte & PTE_V)) {
        vmr_t* v = (vmr_t*)*pte;
        if ((v->prot ^ prot) & ~v->prot) {
          // TODO: look at file to find perms
          res = -EACCES;
          break;
        }
        v->prot = prot;
      } else {
        if (!(*pte & PTE_U) ||
            ((prot & PROT_READ) && !(*pte & PTE_R)) ||
            ((prot & PROT_WRITE) && !(*pte & PTE_W)) ||
            ((prot & PROT_EXEC) && !(*pte & PTE_X))) {
          // TODO: look at file to find perms
          res = -EACCES;
          break;
        }
        *pte = pte_create(pte_ppn(*pte), prot_to_type(prot, 1));
      }
    }
  spinlock_unlock(&vm_lock);

  flush_tlb();
  return res;
}

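// Establish supervisor (non-PTE_U) mappings of [vaddr, vaddr+len) onto the
// physical range starting at paddr.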
void __map_kernel_range(uintptr_t vaddr, uintptr_t paddr, size_t len, int prot)
{
  uintptr_t n = ROUNDUP(len, RISCV_PGSIZE) / RISCV_PGSIZE;
  uintptr_t offset = paddr - vaddr;
  for (uintptr_t a = vaddr, i = 0; i < n; i++, a += RISCV_PGSIZE)
  {
    pte_t* pte = __walk_create(a);
    kassert(pte);
    *pte = pte_create((a + offset) >> RISCV_PGSHIFT, prot_to_type(prot, 0));
  }
}

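// Touch every page in [start, start+size) so demand paging faults it in:
// a no-op atomic add forces a write fault, a plain read a read fault.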
void populate_mapping(const void* start, size_t size, int prot)
{
  uintptr_t a0 = ROUNDDOWN((uintptr_t)start, RISCV_PGSIZE);
  for (uintptr_t a = a0; a < (uintptr_t)start+size; a += RISCV_PGSIZE)
  {
    if (prot & PROT_WRITE)
      atomic_add((int*)a, 0);
    else
      atomic_read((int*)a);
  }
}

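// Boot-time VM setup: cap memory, reserve the page pool past _end, build
// the root page table, identity-map the kernel, size the user address
// space below DRAM_BASE, map the user stack, and point sptbr at the root
// table. Returns the top of a freshly allocated kernel stack page.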
uintptr_t pk_vm_init()
{
  // HTIF address signedness and va2pa macro both cap memory size to 2 GiB
  mem_size = MIN(mem_size, 1U << 31);
  size_t mem_pages = mem_size >> RISCV_PGSHIFT;
  free_pages = MAX(8, mem_pages >> (RISCV_PGLEVEL_BITS-1));

  extern char _end;
  first_free_page = ROUNDUP((uintptr_t)&_end, RISCV_PGSIZE);
  first_free_paddr = first_free_page + free_pages * RISCV_PGSIZE;

  root_page_table = (void*)__page_alloc();
  __map_kernel_range(DRAM_BASE, DRAM_BASE, first_free_paddr - DRAM_BASE, PROT_READ|PROT_WRITE|PROT_EXEC);

  current.mmap_max = current.brk_max =
    MIN(DRAM_BASE, mem_size - (first_free_paddr - DRAM_BASE));

  size_t stack_size = MIN(mem_pages >> 5, 2048) * RISCV_PGSIZE;
  size_t stack_bottom = __do_mmap(current.mmap_max - stack_size, stack_size, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_PRIVATE|MAP_ANONYMOUS|MAP_FIXED, 0, 0);
  kassert(stack_bottom != (uintptr_t)-1);
  current.stack_top = stack_bottom + stack_size;

  flush_tlb();
  write_csr(sptbr, ((uintptr_t)root_page_table >> RISCV_PGSHIFT) | SATP_MODE_CHOICE);

  uintptr_t kernel_stack_top = __page_alloc() + RISCV_PGSIZE;
  return kernel_stack_top;
}
