MemMMap: finish implementing brk_syscall
author Jacob Lifshay <programmerjake@gmail.com>
Sun, 3 Dec 2023 03:23:22 +0000 (19:23 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Sun, 3 Dec 2023 09:31:40 +0000 (01:31 -0800)
src/openpower/decoder/isa/mem.py

index bea3c7f135905a65fa6d1a5233114acce212ce72..14a5f4df5fc09fbc25bef6308b646a60347ed46b 100644 (file)
@@ -310,7 +310,8 @@ class MMapPageFlags(enum.IntFlag):
     """this memory block will grow when the address one page before the
     beginning is accessed"""
 
-    RWX = R | W | X
+    RW = R | W
+    RWX = RW | X
     NONE = 0
 
 
@@ -388,12 +389,15 @@ class MMapEmuBlock:
         self.page_indexes  # check that addresses can be mapped to pages
 
     def intersects(self, other):
-        # type: (MMapEmuBlock) -> bool
-        return (other.addrs.start < self.addrs.stop
-                and self.addrs.start < other.addrs.stop)
+        # type: (MMapEmuBlock | range) -> bool
+        if isinstance(other, MMapEmuBlock):
+            other = other.addrs
+        if len_(other) == 0:
+            return False
+        return other.start < self.addrs.stop and self.addrs.start < other.stop
 
     @property
-    def is_private_mem(self):
+    def is_private_anon(self):
         return self.file is None and not self.flags & MMapPageFlags.S
 
     @property
@@ -470,8 +474,7 @@ def len_(r):
 
 class MemMMap(MemCommon):
     def __init__(self, row_bytes=8, initial_mem=None, misaligned_ok=False,
-                 block_addrs=DEFAULT_BLOCK_ADDRS, emulating_mmap=False,
-                 mmap_emu_data_block=None):
+                 block_addrs=DEFAULT_BLOCK_ADDRS, emulating_mmap=False):
         # we can't allocate the entire 2 ** 47 byte address space, so split
         # it into smaller blocks
         self.mem_blocks = {
@@ -480,7 +483,7 @@ class MemMMap(MemCommon):
             "misaligned block address not supported"
         self.__page_flags = {}
         self.modified_pages = set()
-        self.mmap_emu_data_block = mmap_emu_data_block
+        self.__heap_range = None
         self.__mmap_emu_alloc_blocks = set()  # type: set[MMapEmuBlock] | None
 
         for addr, block in self.mem_blocks.items():
@@ -495,22 +498,8 @@ class MemMMap(MemCommon):
             range(a, a + len(b)) for a, b in self.mem_blocks.items()]
         self.__mmap_emu_unbacked_blocks = tuple(self.__gaps_in(addr_ranges))
 
-        if emulating_mmap:
-            if mmap_emu_data_block is not None:
-                if not isinstance(mmap_emu_data_block, MMapEmuBlock):
-                    raise TypeError(
-                        "mmap_emu_data_block must be a MMapEmuBlock")
-                if mmap_emu_data_block.file is not None:
-                    raise ValueError(
-                        "mmap_emu_data_block must be an anonymous mapping")
-                if not self.__mmap_emu_map_fixed(block=mmap_emu_data_block,
-                                                 replace=False, dry_run=False):
-                    raise ValueError("invalid mmap_emu_data_block")
-        else:
+        if not emulating_mmap:
             self.__mmap_emu_alloc_blocks = None
-            if mmap_emu_data_block is not None:
-                raise ValueError("can't set mmap_emu_data_block "
-                                 "without emulating_mmap=True")
             # mark blocks as readable/writable
             for addr, block in self.mem_blocks.items():
                 start_page = self.addr_to_mmap_page_idx(addr)
@@ -520,6 +509,30 @@ class MemMMap(MemCommon):
 
         super().__init__(row_bytes, initial_mem, misaligned_ok)
 
+    @property
+    def heap_range(self):
+        # type: () -> range | None
+        return self.__heap_range
+
+    @heap_range.setter
+    def heap_range(self, value):
+        # type: (range | None) -> None
+        if value is None:
+            self.__heap_range = value
+            return
+        if not self.emulating_mmap:
+            raise ValueError(
+                "can't set heap_range without emulating_mmap=True")
+        if not isinstance(value, range):
+            raise TypeError("heap_range must be a range or None")
+        if value.step != 1 or value.start > value.stop:
+            raise ValueError("heap_range is not a suitable range")
+        if value.start % MMAP_PAGE_SIZE != 0:
+            raise ValueError("heap_range.start must be aligned")
+        if value.stop % MMAP_PAGE_SIZE != 0:
+            raise ValueError("heap_range.stop must be aligned")
+        self.__heap_range = value
+
     @staticmethod
     def __gaps_in(sorted_ranges, start=0, stop=2 ** 64):
         # type: (list[range] | tuple[range], int, int) -> list[range]
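
The heap_range property added above replaces the old mmap_emu_data_block constructor argument: the caller (for example an ELF loader) is expected to assign it the page-aligned span reserved for the program break before issuing any brk_syscall. A minimal sketch of the validation the setter enforces, assuming MemMMap and MMAP_PAGE_SIZE can be imported from openpower.decoder.isa.mem:

    from openpower.decoder.isa.mem import MemMMap, MMAP_PAGE_SIZE

    m = MemMMap(emulating_mmap=True)
    # accepted: step-1 range, start <= stop, both ends page-aligned
    m.heap_range = range(16 * MMAP_PAGE_SIZE, 16 * MMAP_PAGE_SIZE)
    try:
        m.heap_range = range(0, MMAP_PAGE_SIZE + 1)  # unaligned stop
    except ValueError:
        pass  # rejected; heap_range keeps its previous value
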
@@ -559,10 +572,6 @@ class MemMMap(MemCommon):
         for b in intersecting_blocks:
             if not replace:
                 return False
-            if self.mmap_emu_data_block == b:
-                # FIXME: what does linux do here?
-                raise NotImplementedError(
-                    "mmap overlapping the data block isn't implemented")
             if not dry_run:
                 self.__mmap_emu_alloc_blocks.remove(b)
                 for replacement in b.difference(block):
@@ -573,6 +582,52 @@ class MemMMap(MemCommon):
                 self.__page_flags[page_idx] = block.flags
         return True
 
+    def __mmap_emu_unmap(self, block):
+        # type: (MMapEmuBlock) -> int
+        """unmap `block`, return 0 if no error, otherwise return -errno"""
+        assert block in self.__mmap_emu_alloc_blocks, \
+            "can't unmap already unmapped block"
+
+        # replace mapping with zeros
+        retval = self.__mmap_emu_zero_block(block)
+        if retval < 0:
+            return retval
+        # remove block
+        self.__mmap_emu_alloc_blocks.remove(block)
+        # mark pages as empty
+        for page_idx in block.page_indexes:
+            self.__page_flags.pop(page_idx)
+            self.modified_pages.remove(page_idx)
+        return retval
+
+    def __mmap_emu_zero_block(self, block):
+        # type: (MMapEmuBlock) -> int
+        """ mmap zeros over block, return 0 if no error,
+        otherwise return -errno
+        """
+        mblock = self.mem_blocks[block.underlying_block_key]
+        offsets = block.underlying_block_offsets
+        buf = (ctypes.c_ubyte * len(offsets)).from_buffer(mblock, offsets[0])
+        buf_addr = ctypes.addressof(buf)
+        libc = ctypes.CDLL(None)
+        syscall = libc.syscall
+        syscall.restype = ctypes.c_long
+        syscall.argtypes = (ctypes.c_long,) * 6
+        call_no = ctypes.c_long(ppc_flags.host_defines['SYS_mmap'])
+        host_prot = ppc_flags.host_defines['PROT_READ']
+        host_prot |= ppc_flags.host_defines['PROT_WRITE']
+        host_flags = ppc_flags.host_defines['MAP_ANONYMOUS']
+        host_flags |= ppc_flags.host_defines['MAP_FIXED']
+        host_flags |= ppc_flags.host_defines['MAP_PRIVATE']
+        # map a block of zeros over it
+        if -1 == int(syscall(
+                call_no, ctypes.c_long(buf_addr),
+                ctypes.c_long(len(offsets)),
+                ctypes.c_long(host_prot), ctypes.c_long(host_flags),
+                ctypes.c_long(-1), ctypes.c_long(0))):
+            return -ctypes.get_errno()
+        return 0
+
     def __mmap_emu_resize_map_fixed(self, block, new_size):
         # type: (MMapEmuBlock, int) -> MMapEmuBlock | None
         assert block in self.__mmap_emu_alloc_blocks, \
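
The new __mmap_emu_zero_block helper clears a region by asking the host kernel to map fresh anonymous pages MAP_FIXED over the backing buffer instead of memset-ing it (the ctypes.memset previously used in the resize path is removed in the next hunk); besides zeroing, this hands the dirtied physical pages back to the kernel. A standalone sketch of the same trick, using the libc mmap wrapper and hard-coded Linux x86-64 constants where mem.py instead issues the raw SYS_mmap syscall with values from ppc_flags.host_defines:

    import ctypes
    import mmap

    buf = mmap.mmap(-1, 4 * mmap.PAGESIZE)   # anonymous writable mapping
    buf[:] = b"\xff" * len(buf)              # dirty every byte
    addr = ctypes.addressof(
        (ctypes.c_ubyte * len(buf)).from_buffer(buf))

    libc = ctypes.CDLL(None, use_errno=True)
    libc.mmap.restype = ctypes.c_void_p
    libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int,
                          ctypes.c_int, ctypes.c_int, ctypes.c_long]
    PROT_READ, PROT_WRITE = 0x1, 0x2                 # Linux x86-64 values
    MAP_PRIVATE, MAP_FIXED, MAP_ANONYMOUS = 0x02, 0x10, 0x20
    res = libc.mmap(addr, len(buf), PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_FIXED | MAP_ANONYMOUS, -1, 0)
    assert res != ctypes.c_void_p(-1).value, ctypes.get_errno()
    assert buf[:] == bytes(len(buf))         # old contents are now zeros
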
@@ -588,28 +643,23 @@ class MemMMap(MemCommon):
                 return None
         finally:
             self.__mmap_emu_alloc_blocks.add(block)
-        if not block.is_private_mem:
+        if not block.is_private_anon:
             # FIXME: implement resizing underlying mapping
             raise NotImplementedError
         else:
             # clear newly mapped bytes
             clear_addrs = range(block.addrs.stop, new_block.addrs.stop)
-            if len(clear_addrs):
+            if len_(clear_addrs):
                 clear_block = MMapEmuBlock(clear_addrs)
-                mem_block = self.mem_blocks[clear_block.underlying_block_key]
-                assert mem_block is not None
-                clear_size = len(clear_addrs)
-                arr = (ctypes.c_ubyte * clear_size).from_buffer(
-                    mem_block, clear_block.underlying_block_offsets.start)
-                ctypes.memset(arr, 0, clear_size)
-        if self.mmap_emu_data_block == block:
-            self.mmap_emu_data_block = new_block
-        self.__mmap_emu_alloc_blocks.remove(block)
-        self.__mmap_emu_alloc_blocks.add(new_block)
+                if self.__mmap_emu_zero_block(clear_block) < 0:
+                    return None
 
         if new_size < len(block.addrs):
             # shrinking -- unmap pages at end
             r = range(new_block.page_indexes.stop, block.page_indexes.stop)
+            clear_block = MMapEmuBlock(range(new_block.addrs.stop, block.addrs.stop))
+            if self.__mmap_emu_zero_block(clear_block) < 0:
+                return None
             for page_idx in r:
                 self.__page_flags.pop(page_idx)
                 self.modified_pages.remove(page_idx)
@@ -619,6 +669,8 @@ class MemMMap(MemCommon):
             for page_idx in r:
                 self.__page_flags[page_idx] = block.flags
                 self.modified_pages.remove(page_idx)  # cleared page
+        self.__mmap_emu_alloc_blocks.remove(block)
+        self.__mmap_emu_alloc_blocks.add(new_block)
         return new_block
 
     def __mmap_emu_find_free_addr(self, block):
@@ -654,14 +706,71 @@ class MemMMap(MemCommon):
 
     def brk_syscall(self, addr):
         assert self.emulating_mmap, "brk syscall requires emulating_mmap=True"
-        assert self.mmap_emu_data_block is not None, \
-            "brk syscall requires a data block/segment"
+        assert self.heap_range is not None, "brk syscall requires a heap"
+
+        if addr < self.heap_range.start:
+            # can't shrink heap to negative size
+            return self.heap_range.stop  # don't change heap
 
         # round addr up to the nearest page
         addr_div_page_size = -(-addr // MMAP_PAGE_SIZE)  # ceil(addr / size)
         addr = addr_div_page_size * MMAP_PAGE_SIZE
 
-        raise NotImplementedError  # FIXME: finish
+        # something else could be mmap-ped in the middle of the heap,
+        # be careful...
+
+        block = None
+        if len_(self.heap_range) != 0:
+            for b in self.__mmap_emu_alloc_blocks:
+                # we check for the end matching so we get the last heap block
+                # if the heap was split.
+                # the heap must not be a file mapping.
+                # the heap must not be shared, and must be RW
+                if b.addrs.stop == self.heap_range.stop and b.file is None \
+                        and b.flags == MMapPageFlags.RW:
+                    block = b
+                    break
+
+        if block is not None and addr < block.addrs.start:
+            # heap was split by something, we can't shrink beyond
+            # the start of the last heap block
+            return self.heap_range.stop  # don't change heap
+
+        if block is not None and addr == block.addrs.start:
+            # unmap heap block
+            if self.__mmap_emu_unmap(block) < 0:
+                block = None  # can't unmap heap block
+        elif addr > self.heap_range.stop and block is None:
+            # map new heap block
+            try:
+                addrs = range(self.heap_range.stop, addr)
+                block = MMapEmuBlock(addrs, flags=MMapPageFlags.RW)
+                if not self.__mmap_emu_map_fixed(block,
+                                                 replace=False, dry_run=True):
+                    block = None
+                elif 0 != self.__mmap_emu_zero_block(block):
+                    block = None
+                else:
+                    self.__mmap_emu_map_fixed(block,
+                                              replace=False, dry_run=False)
+            except (MemException, ValueError):
+                # caller could pass in invalid size, catch that
+                block = None
+        elif block is not None:  # resize block
+            try:
+                block = self.__mmap_emu_resize_map_fixed(
+                    block, addr - block.addrs.start)
+            except (MemException, ValueError):
+                # caller could pass in invalid size, catch that
+                block = None
+
+        if block is None and addr != self.heap_range.start:
+            # can't resize heap block
+            return self.heap_range.stop  # don't change heap
+
+        # success! assign new heap_range
+        self.heap_range = range(self.heap_range.start, addr)
+        return self.heap_range.stop  # return new brk address
 
     def mmap_syscall(self, addr, length, prot, flags, fd, offset, is_mmap2):
         assert self.emulating_mmap, "mmap syscall requires emulating_mmap=True"
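
The completed brk_syscall above follows the Linux convention of returning the new program break on success and the unchanged break on any failure (request below the heap start, collision with another mapping, or allocation error) rather than an errno value. A short sketch of the expected round-trips, assuming the module-level names are importable from openpower.decoder.isa.mem and that DEFAULT_BLOCK_ADDRS is a sequence whose first entry is a page-aligned address backed by at least one page:

    from openpower.decoder.isa.mem import (
        DEFAULT_BLOCK_ADDRS, MMAP_PAGE_SIZE, MemMMap)

    m = MemMMap(emulating_mmap=True)
    heap_start = DEFAULT_BLOCK_ADDRS[0]
    m.heap_range = range(heap_start, heap_start)   # empty heap to begin with

    # growing: the requested break is rounded up to the next page boundary
    new_brk = m.brk_syscall(heap_start + 100)
    assert new_brk == heap_start + MMAP_PAGE_SIZE
    assert m.heap_range == range(heap_start, new_brk)

    # shrinking below the heap start is refused; the old break is returned
    assert m.brk_syscall(heap_start - 1) == new_brk
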
@@ -753,8 +862,6 @@ class MemMMap(MemCommon):
         buf_addr = ctypes.addressof(buf)
         libc = ctypes.CDLL(None)
         syscall = libc.syscall
-        restype = syscall.restype
-        argtypes = syscall.argtypes
         syscall.restype = ctypes.c_long
         syscall.argtypes = (ctypes.c_long,) * 6
         call_no = ctypes.c_long(ppc_flags.host_defines['SYS_mmap'])
@@ -768,6 +875,11 @@ class MemMMap(MemCommon):
         extra_zeros_start = 0
         if file is None:
             host_flags |= ppc_flags.host_defines['MAP_ANONYMOUS']
+            # don't remove check, since we'll eventually have shared memory
+            if host_flags & ppc_flags.host_defines['MAP_PRIVATE']:
+                # always map private memory read/write,
+                # so we can clear it if needed
+                host_prot |= ppc_flags.host_defines['PROT_WRITE']
         else:
             file_sz = os.fstat(fd).st_size
             # host-page-align file_sz, rounding up