cachepc-qemu

Fork of AMDESE/qemu with changes for the cachepc side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-qemu
Log | Files | Refs | Submodules | LICENSE | sfeed.txt

minimize_qtest_trace.py (11613B)


      1#!/usr/bin/env python3
      2# -*- coding: utf-8 -*-
      3
      4"""
      5This takes a crashing qtest trace and tries to remove superflous operations
      6"""
      7
      8import sys
      9import os
     10import subprocess
     11import time
     12import struct
     13
     14QEMU_ARGS = None
     15QEMU_PATH = None
     16TIMEOUT = 5
     17CRASH_TOKEN = None
     18
     19# Minimization levels
     20M1 = False # try removing IO commands iteratively
     21M2 = False # try setting bits in operand of write/out to zero
     22
     23write_suffix_lookup = {"b": (1, "B"),
     24                       "w": (2, "H"),
     25                       "l": (4, "L"),
     26                       "q": (8, "Q")}
     27
     28def usage():
     29    sys.exit("""\
     30Usage:
     31
     32QEMU_PATH="/path/to/qemu" QEMU_ARGS="args" {} [Options] input_trace output_trace
     33
     34By default, will try to use the second-to-last line in the output to identify
     35whether the crash occred. Optionally, manually set a string that idenitifes the
     36crash by setting CRASH_TOKEN=
     37
     38Options:
     39
     40-M1: enable a loop around the remove minimizer, which may help decrease some
     41     timing dependant instructions. Off by default.
     42-M2: try setting bits in operand of write/out to zero. Off by default.
     43
     44""".format((sys.argv[0])))
     45
     46deduplication_note = """\n\
     47Note: While trimming the input, sometimes the mutated trace triggers a different
     48type crash but indicates the same bug. Under this situation, our minimizer is
     49incapable of recognizing and stopped from removing it. In the future, we may
     50use a more sophisticated crash case deduplication method.
     51\n"""
     52
     53def check_if_trace_crashes(trace, path):
     54    with open(path, "w") as tracefile:
     55        tracefile.write("".join(trace))
     56
     57    rc = subprocess.Popen("timeout -s 9 {timeout}s {qemu_path} {qemu_args} 2>&1\
     58    < {trace_path}".format(timeout=TIMEOUT,
     59                           qemu_path=QEMU_PATH,
     60                           qemu_args=QEMU_ARGS,
     61                           trace_path=path),
     62                          shell=True,
     63                          stdin=subprocess.PIPE,
     64                          stdout=subprocess.PIPE,
     65                          encoding="utf-8")
     66    global CRASH_TOKEN
     67    if CRASH_TOKEN is None:
     68        try:
     69            outs, _ = rc.communicate(timeout=5)
     70            CRASH_TOKEN = " ".join(outs.splitlines()[-2].split()[0:3])
     71        except subprocess.TimeoutExpired:
     72            print("subprocess.TimeoutExpired")
     73            return False
     74        print("Identifying Crashes by this string: {}".format(CRASH_TOKEN))
     75        global deduplication_note
     76        print(deduplication_note)
     77        return True
     78
     79    for line in iter(rc.stdout.readline, ""):
     80        if "CLOSED" in line:
     81            return False
     82        if CRASH_TOKEN in line:
     83            return True
     84
     85    print("\nWarning:")
     86    print("  There is no 'CLOSED'or CRASH_TOKEN in the stdout of subprocess.")
     87    print("  Usually this indicates a different type of crash.\n")
     88    return False
     89
     90
     91# If previous write commands write the same length of data at the same
     92# interval, we view it as a hint.
     93def split_write_hint(newtrace, i):
     94    HINT_LEN = 3 # > 2
     95    if i <=(HINT_LEN-1):
     96        return None
     97
     98    #find previous continuous write traces
     99    k = 0
    100    l = i-1
    101    writes = []
    102    while (k != HINT_LEN and l >= 0):
    103        if newtrace[l].startswith("write "):
    104            writes.append(newtrace[l])
    105            k += 1
    106            l -= 1
    107        elif newtrace[l] == "":
    108            l -= 1
    109        else:
    110            return None
    111    if k != HINT_LEN:
    112        return None
    113
    114    length = int(writes[0].split()[2], 16)
    115    for j in range(1, HINT_LEN):
    116        if length != int(writes[j].split()[2], 16):
    117            return None
    118
    119    step = int(writes[0].split()[1], 16) - int(writes[1].split()[1], 16)
    120    for j in range(1, HINT_LEN-1):
    121        if step != int(writes[j].split()[1], 16) - \
    122            int(writes[j+1].split()[1], 16):
    123            return None
    124
    125    return (int(writes[0].split()[1], 16)+step, length)
    126
    127
    128def remove_lines(newtrace, outpath):
    129    remove_step = 1
    130    i = 0
    131    while i < len(newtrace):
    132        # 1.) Try to remove lines completely and reproduce the crash.
    133        # If it works, we're done.
    134        if (i+remove_step) >= len(newtrace):
    135            remove_step = 1
    136        prior = newtrace[i:i+remove_step]
    137        for j in range(i, i+remove_step):
    138            newtrace[j] = ""
    139        print("Removing {lines} ...\n".format(lines=prior))
    140        if check_if_trace_crashes(newtrace, outpath):
    141            i += remove_step
    142            # Double the number of lines to remove for next round
    143            remove_step *= 2
    144            continue
    145        # Failed to remove multiple IOs, fast recovery
    146        if remove_step > 1:
    147            for j in range(i, i+remove_step):
    148                newtrace[j] = prior[j-i]
    149            remove_step = 1
    150            continue
    151        newtrace[i] = prior[0] # remove_step = 1
    152
    153        # 2.) Try to replace write{bwlq} commands with a write addr, len
    154        # command. Since this can require swapping endianness, try both LE and
    155        # BE options. We do this, so we can "trim" the writes in (3)
    156
    157        if (newtrace[i].startswith("write") and not
    158            newtrace[i].startswith("write ")):
    159            suffix = newtrace[i].split()[0][-1]
    160            assert(suffix in write_suffix_lookup)
    161            addr = int(newtrace[i].split()[1], 16)
    162            value = int(newtrace[i].split()[2], 16)
    163            for endianness in ['<', '>']:
    164                data = struct.pack("{end}{size}".format(end=endianness,
    165                                   size=write_suffix_lookup[suffix][1]),
    166                                   value)
    167                newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    168                    addr=hex(addr),
    169                    size=hex(write_suffix_lookup[suffix][0]),
    170                    data=data.hex())
    171                if(check_if_trace_crashes(newtrace, outpath)):
    172                    break
    173            else:
    174                newtrace[i] = prior[0]
    175
    176        # 3.) If it is a qtest write command: write addr len data, try to split
    177        # it into two separate write commands. If splitting the data operand
    178        # from length/2^n bytes to the left does not work, try to move the pivot
    179        # to the right side, then add one to n, until length/2^n == 0. The idea
    180        # is to prune unneccessary bytes from long writes, while accommodating
    181        # arbitrary MemoryRegion access sizes and alignments.
    182
    183        # This algorithm will fail under some rare situations.
    184        # e.g., xxxxxxxxxuxxxxxx (u is the unnecessary byte)
    185
    186        if newtrace[i].startswith("write "):
    187            addr = int(newtrace[i].split()[1], 16)
    188            length = int(newtrace[i].split()[2], 16)
    189            data = newtrace[i].split()[3][2:]
    190            if length > 1:
    191
    192                # Can we get a hint from previous writes?
    193                hint = split_write_hint(newtrace, i)
    194                if hint is not None:
    195                    hint_addr = hint[0]
    196                    hint_len = hint[1]
    197                    if hint_addr >= addr and hint_addr+hint_len <= addr+length:
    198                        newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    199                            addr=hex(hint_addr),
    200                            size=hex(hint_len),
    201                            data=data[(hint_addr-addr)*2:\
    202                                (hint_addr-addr)*2+hint_len*2])
    203                        if check_if_trace_crashes(newtrace, outpath):
    204                            # next round
    205                            i += 1
    206                            continue
    207                        newtrace[i] = prior[0]
    208
    209                # Try splitting it using a binary approach
    210                leftlength = int(length/2)
    211                rightlength = length - leftlength
    212                newtrace.insert(i+1, "")
    213                power = 1
    214                while leftlength > 0:
    215                    newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    216                            addr=hex(addr),
    217                            size=hex(leftlength),
    218                            data=data[:leftlength*2])
    219                    newtrace[i+1] = "write {addr} {size} 0x{data}\n".format(
    220                            addr=hex(addr+leftlength),
    221                            size=hex(rightlength),
    222                            data=data[leftlength*2:])
    223                    if check_if_trace_crashes(newtrace, outpath):
    224                        break
    225                    # move the pivot to right side
    226                    if leftlength < rightlength:
    227                        rightlength, leftlength = leftlength, rightlength
    228                        continue
    229                    power += 1
    230                    leftlength = int(length/pow(2, power))
    231                    rightlength = length - leftlength
    232                if check_if_trace_crashes(newtrace, outpath):
    233                    i -= 1
    234                else:
    235                    newtrace[i] = prior[0]
    236                    del newtrace[i+1]
    237        i += 1
    238
    239
    240def clear_bits(newtrace, outpath):
    241    # try setting bits in operands of out/write to zero
    242    i = 0
    243    while i < len(newtrace):
    244        if (not newtrace[i].startswith("write ") and not
    245           newtrace[i].startswith("out")):
    246           i += 1
    247           continue
    248        # write ADDR SIZE DATA
    249        # outx ADDR VALUE
    250        print("\nzero setting bits: {}".format(newtrace[i]))
    251
    252        prefix = " ".join(newtrace[i].split()[:-1])
    253        data = newtrace[i].split()[-1]
    254        data_bin = bin(int(data, 16))
    255        data_bin_list = list(data_bin)
    256
    257        for j in range(2, len(data_bin_list)):
    258            prior = newtrace[i]
    259            if (data_bin_list[j] == '1'):
    260                data_bin_list[j] = '0'
    261                data_try = hex(int("".join(data_bin_list), 2))
    262                # It seems qtest only accepts padded hex-values.
    263                if len(data_try) % 2 == 1:
    264                    data_try = data_try[:2] + "0" + data_try[2:]
    265
    266                newtrace[i] = "{prefix} {data_try}\n".format(
    267                        prefix=prefix,
    268                        data_try=data_try)
    269
    270                if not check_if_trace_crashes(newtrace, outpath):
    271                    data_bin_list[j] = '1'
    272                    newtrace[i] = prior
    273        i += 1
    274
    275
    276def minimize_trace(inpath, outpath):
    277    global TIMEOUT
    278    with open(inpath) as f:
    279        trace = f.readlines()
    280    start = time.time()
    281    if not check_if_trace_crashes(trace, outpath):
    282        sys.exit("The input qtest trace didn't cause a crash...")
    283    end = time.time()
    284    print("Crashed in {} seconds".format(end-start))
    285    TIMEOUT = (end-start)*5
    286    print("Setting the timeout for {} seconds".format(TIMEOUT))
    287
    288    newtrace = trace[:]
    289    global M1, M2
    290
    291    # remove lines
    292    old_len = len(newtrace) + 1
    293    while(old_len > len(newtrace)):
    294        old_len = len(newtrace)
    295        print("trace lenth = ", old_len)
    296        remove_lines(newtrace, outpath)
    297        if not M1 and not M2:
    298            break
    299        newtrace = list(filter(lambda s: s != "", newtrace))
    300    assert(check_if_trace_crashes(newtrace, outpath))
    301
    302    # set bits to zero
    303    if M2:
    304        clear_bits(newtrace, outpath)
    305    assert(check_if_trace_crashes(newtrace, outpath))
    306
    307
    308if __name__ == '__main__':
    309    if len(sys.argv) < 3:
    310        usage()
    311    if "-M1" in sys.argv:
    312        M1 = True
    313    if "-M2" in sys.argv:
    314        M2 = True
    315    QEMU_PATH = os.getenv("QEMU_PATH")
    316    QEMU_ARGS = os.getenv("QEMU_ARGS")
    317    if QEMU_PATH is None or QEMU_ARGS is None:
    318        usage()
    319    # if "accel" not in QEMU_ARGS:
    320    #     QEMU_ARGS += " -accel qtest"
    321    CRASH_TOKEN = os.getenv("CRASH_TOKEN")
    322    QEMU_ARGS += " -qtest stdio -monitor none -serial none "
    323    minimize_trace(sys.argv[-2], sys.argv[-1])