cachepc-qemu

Fork of AMDESE/qemu with changes for the cachepc side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-qemu
Log | Files | Refs | Submodules | LICENSE | sfeed.txt

decodetree.py (41623B)


      1#!/usr/bin/env python3
      2# Copyright (c) 2018 Linaro Limited
      3#
      4# This library is free software; you can redistribute it and/or
      5# modify it under the terms of the GNU Lesser General Public
      6# License as published by the Free Software Foundation; either
      7# version 2.1 of the License, or (at your option) any later version.
      8#
      9# This library is distributed in the hope that it will be useful,
     10# but WITHOUT ANY WARRANTY; without even the implied warranty of
     11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     12# Lesser General Public License for more details.
     13#
     14# You should have received a copy of the GNU Lesser General Public
     15# License along with this library; if not, see <http://www.gnu.org/licenses/>.
     16#
     17
     18#
     19# Generate a decoding tree from a specification file.
     20# See the syntax and semantics in docs/devel/decodetree.rst.
     21#
     22
     23import io
     24import os
     25import re
     26import sys
     27import getopt
     28
     29insnwidth = 32
     30bitop_width = 32
     31insnmask = 0xffffffff
     32variablewidth = False
     33fields = {}
     34arguments = {}
     35formats = {}
     36allpatterns = []
     37anyextern = False
     38
     39translate_prefix = 'trans'
     40translate_scope = 'static '
     41input_file = ''
     42output_file = None
     43output_fd = None
     44insntype = 'uint32_t'
     45decode_function = 'decode'
     46
     47# An identifier for C.
     48re_C_ident = '[a-zA-Z][a-zA-Z0-9_]*'
     49
     50# Identifiers for Arguments, Fields, Formats and Patterns.
     51re_arg_ident = '&[a-zA-Z0-9_]*'
     52re_fld_ident = '%[a-zA-Z0-9_]*'
     53re_fmt_ident = '@[a-zA-Z0-9_]*'
     54re_pat_ident = '[a-zA-Z0-9_]*'
     55
     56def error_with_file(file, lineno, *args):
     57    """Print an error message from file:line and args and exit."""
     58    global output_file
     59    global output_fd
     60
     61    prefix = ''
     62    if file:
     63        prefix += f'{file}:'
     64    if lineno:
     65        prefix += f'{lineno}:'
     66    if prefix:
     67        prefix += ' '
     68    print(prefix, end='error: ', file=sys.stderr)
     69    print(*args, file=sys.stderr)
     70
     71    if output_file and output_fd:
     72        output_fd.close()
     73        os.remove(output_file)
     74    exit(1)
     75# end error_with_file
     76
     77
     78def error(lineno, *args):
     79    error_with_file(input_file, lineno, *args)
     80# end error
     81
     82
     83def output(*args):
     84    global output_fd
     85    for a in args:
     86        output_fd.write(a)
     87
     88
     89def output_autogen():
     90    output('/* This file is autogenerated by scripts/decodetree.py.  */\n\n')
     91
     92
     93def str_indent(c):
     94    """Return a string with C spaces"""
     95    return ' ' * c
     96
     97
     98def str_fields(fields):
     99    """Return a string uniquely identifying FIELDS"""
    100    r = ''
    101    for n in sorted(fields.keys()):
    102        r += '_' + n
    103    return r[1:]
    104
    105
    106def whex(val):
    107    """Return a hex string for val padded for insnwidth"""
    108    global insnwidth
    109    return f'0x{val:0{insnwidth // 4}x}'
    110
    111
    112def whexC(val):
    113    """Return a hex string for val padded for insnwidth,
    114       and with the proper suffix for a C constant."""
    115    suffix = ''
    116    if val >= 0x100000000:
    117        suffix = 'ull'
    118    elif val >= 0x80000000:
    119        suffix = 'u'
    120    return whex(val) + suffix
    121
    122
    123def str_match_bits(bits, mask):
    124    """Return a string pretty-printing BITS/MASK"""
    125    global insnwidth
    126
    127    i = 1 << (insnwidth - 1)
    128    space = 0x01010100
    129    r = ''
    130    while i != 0:
    131        if i & mask:
    132            if i & bits:
    133                r += '1'
    134            else:
    135                r += '0'
    136        else:
    137            r += '.'
    138        if i & space:
    139            r += ' '
    140        i >>= 1
    141    return r
    142
    143
    144def is_pow2(x):
    145    """Return true iff X is equal to a power of 2."""
    146    return (x & (x - 1)) == 0
    147
    148
    149def ctz(x):
    150    """Return the number of times 2 factors into X."""
    151    assert x != 0
    152    r = 0
    153    while ((x >> r) & 1) == 0:
    154        r += 1
    155    return r
    156
    157
    158def is_contiguous(bits):
    159    if bits == 0:
    160        return -1
    161    shift = ctz(bits)
    162    if is_pow2((bits >> shift) + 1):
    163        return shift
    164    else:
    165        return -1
    166
    167
    168def eq_fields_for_args(flds_a, arg):
    169    if len(flds_a) != len(arg.fields):
    170        return False
    171    # Only allow inference on default types
    172    for t in arg.types:
    173        if t != 'int':
    174            return False
    175    for k, a in flds_a.items():
    176        if k not in arg.fields:
    177            return False
    178    return True
    179
    180
    181def eq_fields_for_fmts(flds_a, flds_b):
    182    if len(flds_a) != len(flds_b):
    183        return False
    184    for k, a in flds_a.items():
    185        if k not in flds_b:
    186            return False
    187        b = flds_b[k]
    188        if a.__class__ != b.__class__ or a != b:
    189            return False
    190    return True
    191
    192
    193class Field:
    194    """Class representing a simple instruction field"""
    195    def __init__(self, sign, pos, len):
    196        self.sign = sign
    197        self.pos = pos
    198        self.len = len
    199        self.mask = ((1 << len) - 1) << pos
    200
    201    def __str__(self):
    202        if self.sign:
    203            s = 's'
    204        else:
    205            s = ''
    206        return str(self.pos) + ':' + s + str(self.len)
    207
    208    def str_extract(self):
    209        global bitop_width
    210        s = 's' if self.sign else ''
    211        return f'{s}extract{bitop_width}(insn, {self.pos}, {self.len})'
    212
    213    def __eq__(self, other):
    214        return self.sign == other.sign and self.mask == other.mask
    215
    216    def __ne__(self, other):
    217        return not self.__eq__(other)
    218# end Field
    219
    220
    221class MultiField:
    222    """Class representing a compound instruction field"""
    223    def __init__(self, subs, mask):
    224        self.subs = subs
    225        self.sign = subs[0].sign
    226        self.mask = mask
    227
    228    def __str__(self):
    229        return str(self.subs)
    230
    231    def str_extract(self):
    232        global bitop_width
    233        ret = '0'
    234        pos = 0
    235        for f in reversed(self.subs):
    236            ext = f.str_extract()
    237            if pos == 0:
    238                ret = ext
    239            else:
    240                ret = f'deposit{bitop_width}({ret}, {pos}, {bitop_width - pos}, {ext})'
    241            pos += f.len
    242        return ret
    243
    244    def __ne__(self, other):
    245        if len(self.subs) != len(other.subs):
    246            return True
    247        for a, b in zip(self.subs, other.subs):
    248            if a.__class__ != b.__class__ or a != b:
    249                return True
    250        return False
    251
    252    def __eq__(self, other):
    253        return not self.__ne__(other)
    254# end MultiField
    255
    256
    257class ConstField:
    258    """Class representing an argument field with constant value"""
    259    def __init__(self, value):
    260        self.value = value
    261        self.mask = 0
    262        self.sign = value < 0
    263
    264    def __str__(self):
    265        return str(self.value)
    266
    267    def str_extract(self):
    268        return str(self.value)
    269
    270    def __cmp__(self, other):
    271        return self.value - other.value
    272# end ConstField
    273
    274
    275class FunctionField:
    276    """Class representing a field passed through a function"""
    277    def __init__(self, func, base):
    278        self.mask = base.mask
    279        self.sign = base.sign
    280        self.base = base
    281        self.func = func
    282
    283    def __str__(self):
    284        return self.func + '(' + str(self.base) + ')'
    285
    286    def str_extract(self):
    287        return self.func + '(ctx, ' + self.base.str_extract() + ')'
    288
    289    def __eq__(self, other):
    290        return self.func == other.func and self.base == other.base
    291
    292    def __ne__(self, other):
    293        return not self.__eq__(other)
    294# end FunctionField
    295
    296
    297class ParameterField:
    298    """Class representing a pseudo-field read from a function"""
    299    def __init__(self, func):
    300        self.mask = 0
    301        self.sign = 0
    302        self.func = func
    303
    304    def __str__(self):
    305        return self.func
    306
    307    def str_extract(self):
    308        return self.func + '(ctx)'
    309
    310    def __eq__(self, other):
    311        return self.func == other.func
    312
    313    def __ne__(self, other):
    314        return not self.__eq__(other)
    315# end ParameterField
    316
    317
    318class Arguments:
    319    """Class representing the extracted fields of a format"""
    320    def __init__(self, nm, flds, types, extern):
    321        self.name = nm
    322        self.extern = extern
    323        self.fields = flds
    324        self.types = types
    325
    326    def __str__(self):
    327        return self.name + ' ' + str(self.fields)
    328
    329    def struct_name(self):
    330        return 'arg_' + self.name
    331
    332    def output_def(self):
    333        if not self.extern:
    334            output('typedef struct {\n')
    335            for (n, t) in zip(self.fields, self.types):
    336                output(f'    {t} {n};\n')
    337            output('} ', self.struct_name(), ';\n\n')
    338# end Arguments
    339
    340
    341class General:
    342    """Common code between instruction formats and instruction patterns"""
    343    def __init__(self, name, lineno, base, fixb, fixm, udfm, fldm, flds, w):
    344        self.name = name
    345        self.file = input_file
    346        self.lineno = lineno
    347        self.base = base
    348        self.fixedbits = fixb
    349        self.fixedmask = fixm
    350        self.undefmask = udfm
    351        self.fieldmask = fldm
    352        self.fields = flds
    353        self.width = w
    354
    355    def __str__(self):
    356        return self.name + ' ' + str_match_bits(self.fixedbits, self.fixedmask)
    357
    358    def str1(self, i):
    359        return str_indent(i) + self.__str__()
    360# end General
    361
    362
    363class Format(General):
    364    """Class representing an instruction format"""
    365
    366    def extract_name(self):
    367        global decode_function
    368        return decode_function + '_extract_' + self.name
    369
    370    def output_extract(self):
    371        output('static void ', self.extract_name(), '(DisasContext *ctx, ',
    372               self.base.struct_name(), ' *a, ', insntype, ' insn)\n{\n')
    373        for n, f in self.fields.items():
    374            output('    a->', n, ' = ', f.str_extract(), ';\n')
    375        output('}\n\n')
    376# end Format
    377
    378
    379class Pattern(General):
    380    """Class representing an instruction pattern"""
    381
    382    def output_decl(self):
    383        global translate_scope
    384        global translate_prefix
    385        output('typedef ', self.base.base.struct_name(),
    386               ' arg_', self.name, ';\n')
    387        output(translate_scope, 'bool ', translate_prefix, '_', self.name,
    388               '(DisasContext *ctx, arg_', self.name, ' *a);\n')
    389
    390    def output_code(self, i, extracted, outerbits, outermask):
    391        global translate_prefix
    392        ind = str_indent(i)
    393        arg = self.base.base.name
    394        output(ind, '/* ', self.file, ':', str(self.lineno), ' */\n')
    395        if not extracted:
    396            output(ind, self.base.extract_name(),
    397                   '(ctx, &u.f_', arg, ', insn);\n')
    398        for n, f in self.fields.items():
    399            output(ind, 'u.f_', arg, '.', n, ' = ', f.str_extract(), ';\n')
    400        output(ind, 'if (', translate_prefix, '_', self.name,
    401               '(ctx, &u.f_', arg, ')) return true;\n')
    402
    403    # Normal patterns do not have children.
    404    def build_tree(self):
    405        return
    406    def prop_masks(self):
    407        return
    408    def prop_format(self):
    409        return
    410    def prop_width(self):
    411        return
    412
    413# end Pattern
    414
    415
    416class MultiPattern(General):
    417    """Class representing a set of instruction patterns"""
    418
    419    def __init__(self, lineno):
    420        self.file = input_file
    421        self.lineno = lineno
    422        self.pats = []
    423        self.base = None
    424        self.fixedbits = 0
    425        self.fixedmask = 0
    426        self.undefmask = 0
    427        self.width = None
    428
    429    def __str__(self):
    430        r = 'group'
    431        if self.fixedbits is not None:
    432            r += ' ' + str_match_bits(self.fixedbits, self.fixedmask)
    433        return r
    434
    435    def output_decl(self):
    436        for p in self.pats:
    437            p.output_decl()
    438
    439    def prop_masks(self):
    440        global insnmask
    441
    442        fixedmask = insnmask
    443        undefmask = insnmask
    444
    445        # Collect fixedmask/undefmask for all of the children.
    446        for p in self.pats:
    447            p.prop_masks()
    448            fixedmask &= p.fixedmask
    449            undefmask &= p.undefmask
    450
    451        # Widen fixedmask until all fixedbits match
    452        repeat = True
    453        fixedbits = 0
    454        while repeat and fixedmask != 0:
    455            fixedbits = None
    456            for p in self.pats:
    457                thisbits = p.fixedbits & fixedmask
    458                if fixedbits is None:
    459                    fixedbits = thisbits
    460                elif fixedbits != thisbits:
    461                    fixedmask &= ~(fixedbits ^ thisbits)
    462                    break
    463            else:
    464                repeat = False
    465
    466        self.fixedbits = fixedbits
    467        self.fixedmask = fixedmask
    468        self.undefmask = undefmask
    469
    470    def build_tree(self):
    471        for p in self.pats:
    472            p.build_tree()
    473
    474    def prop_format(self):
    475        for p in self.pats:
    476            p.build_tree()
    477
    478    def prop_width(self):
    479        width = None
    480        for p in self.pats:
    481            p.prop_width()
    482            if width is None:
    483                width = p.width
    484            elif width != p.width:
    485                error_with_file(self.file, self.lineno,
    486                                'width mismatch in patterns within braces')
    487        self.width = width
    488
    489# end MultiPattern
    490
    491
    492class IncMultiPattern(MultiPattern):
    493    """Class representing an overlapping set of instruction patterns"""
    494
    495    def output_code(self, i, extracted, outerbits, outermask):
    496        global translate_prefix
    497        ind = str_indent(i)
    498        for p in self.pats:
    499            if outermask != p.fixedmask:
    500                innermask = p.fixedmask & ~outermask
    501                innerbits = p.fixedbits & ~outermask
    502                output(ind, f'if ((insn & {whexC(innermask)}) == {whexC(innerbits)}) {{\n')
    503                output(ind, f'    /* {str_match_bits(p.fixedbits, p.fixedmask)} */\n')
    504                p.output_code(i + 4, extracted, p.fixedbits, p.fixedmask)
    505                output(ind, '}\n')
    506            else:
    507                p.output_code(i, extracted, p.fixedbits, p.fixedmask)
    508#end IncMultiPattern
    509
    510
    511class Tree:
    512    """Class representing a node in a decode tree"""
    513
    514    def __init__(self, fm, tm):
    515        self.fixedmask = fm
    516        self.thismask = tm
    517        self.subs = []
    518        self.base = None
    519
    520    def str1(self, i):
    521        ind = str_indent(i)
    522        r = ind + whex(self.fixedmask)
    523        if self.format:
    524            r += ' ' + self.format.name
    525        r += ' [\n'
    526        for (b, s) in self.subs:
    527            r += ind + f'  {whex(b)}:\n'
    528            r += s.str1(i + 4) + '\n'
    529        r += ind + ']'
    530        return r
    531
    532    def __str__(self):
    533        return self.str1(0)
    534
    535    def output_code(self, i, extracted, outerbits, outermask):
    536        ind = str_indent(i)
    537
    538        # If we identified all nodes below have the same format,
    539        # extract the fields now.
    540        if not extracted and self.base:
    541            output(ind, self.base.extract_name(),
    542                   '(ctx, &u.f_', self.base.base.name, ', insn);\n')
    543            extracted = True
    544
    545        # Attempt to aid the compiler in producing compact switch statements.
    546        # If the bits in the mask are contiguous, extract them.
    547        sh = is_contiguous(self.thismask)
    548        if sh > 0:
    549            # Propagate SH down into the local functions.
    550            def str_switch(b, sh=sh):
    551                return f'(insn >> {sh}) & {b >> sh:#x}'
    552
    553            def str_case(b, sh=sh):
    554                return hex(b >> sh)
    555        else:
    556            def str_switch(b):
    557                return f'insn & {whexC(b)}'
    558
    559            def str_case(b):
    560                return whexC(b)
    561
    562        output(ind, 'switch (', str_switch(self.thismask), ') {\n')
    563        for b, s in sorted(self.subs):
    564            assert (self.thismask & ~s.fixedmask) == 0
    565            innermask = outermask | self.thismask
    566            innerbits = outerbits | b
    567            output(ind, 'case ', str_case(b), ':\n')
    568            output(ind, '    /* ',
    569                   str_match_bits(innerbits, innermask), ' */\n')
    570            s.output_code(i + 4, extracted, innerbits, innermask)
    571            output(ind, '    break;\n')
    572        output(ind, '}\n')
    573# end Tree
    574
    575
    576class ExcMultiPattern(MultiPattern):
    577    """Class representing a non-overlapping set of instruction patterns"""
    578
    579    def output_code(self, i, extracted, outerbits, outermask):
    580        # Defer everything to our decomposed Tree node
    581        self.tree.output_code(i, extracted, outerbits, outermask)
    582
    583    @staticmethod
    584    def __build_tree(pats, outerbits, outermask):
    585        # Find the intersection of all remaining fixedmask.
    586        innermask = ~outermask & insnmask
    587        for i in pats:
    588            innermask &= i.fixedmask
    589
    590        if innermask == 0:
    591            # Edge condition: One pattern covers the entire insnmask
    592            if len(pats) == 1:
    593                t = Tree(outermask, innermask)
    594                t.subs.append((0, pats[0]))
    595                return t
    596
    597            text = 'overlapping patterns:'
    598            for p in pats:
    599                text += '\n' + p.file + ':' + str(p.lineno) + ': ' + str(p)
    600            error_with_file(pats[0].file, pats[0].lineno, text)
    601
    602        fullmask = outermask | innermask
    603
    604        # Sort each element of pats into the bin selected by the mask.
    605        bins = {}
    606        for i in pats:
    607            fb = i.fixedbits & innermask
    608            if fb in bins:
    609                bins[fb].append(i)
    610            else:
    611                bins[fb] = [i]
    612
    613        # We must recurse if any bin has more than one element or if
    614        # the single element in the bin has not been fully matched.
    615        t = Tree(fullmask, innermask)
    616
    617        for b, l in bins.items():
    618            s = l[0]
    619            if len(l) > 1 or s.fixedmask & ~fullmask != 0:
    620                s = ExcMultiPattern.__build_tree(l, b | outerbits, fullmask)
    621            t.subs.append((b, s))
    622
    623        return t
    624
    625    def build_tree(self):
    626        super().prop_format()
    627        self.tree = self.__build_tree(self.pats, self.fixedbits,
    628                                      self.fixedmask)
    629
    630    @staticmethod
    631    def __prop_format(tree):
    632        """Propagate Format objects into the decode tree"""
    633
    634        # Depth first search.
    635        for (b, s) in tree.subs:
    636            if isinstance(s, Tree):
    637                ExcMultiPattern.__prop_format(s)
    638
    639        # If all entries in SUBS have the same format, then
    640        # propagate that into the tree.
    641        f = None
    642        for (b, s) in tree.subs:
    643            if f is None:
    644                f = s.base
    645                if f is None:
    646                    return
    647            if f is not s.base:
    648                return
    649        tree.base = f
    650
    651    def prop_format(self):
    652        super().prop_format()
    653        self.__prop_format(self.tree)
    654
    655# end ExcMultiPattern
    656
    657
    658def parse_field(lineno, name, toks):
    659    """Parse one instruction field from TOKS at LINENO"""
    660    global fields
    661    global insnwidth
    662
    663    # A "simple" field will have only one entry;
    664    # a "multifield" will have several.
    665    subs = []
    666    width = 0
    667    func = None
    668    for t in toks:
    669        if re.match('^!function=', t):
    670            if func:
    671                error(lineno, 'duplicate function')
    672            func = t.split('=')
    673            func = func[1]
    674            continue
    675
    676        if re.fullmatch('[0-9]+:s[0-9]+', t):
    677            # Signed field extract
    678            subtoks = t.split(':s')
    679            sign = True
    680        elif re.fullmatch('[0-9]+:[0-9]+', t):
    681            # Unsigned field extract
    682            subtoks = t.split(':')
    683            sign = False
    684        else:
    685            error(lineno, f'invalid field token "{t}"')
    686        po = int(subtoks[0])
    687        le = int(subtoks[1])
    688        if po + le > insnwidth:
    689            error(lineno, f'field {t} too large')
    690        f = Field(sign, po, le)
    691        subs.append(f)
    692        width += le
    693
    694    if width > insnwidth:
    695        error(lineno, 'field too large')
    696    if len(subs) == 0:
    697        if func:
    698            f = ParameterField(func)
    699        else:
    700            error(lineno, 'field with no value')
    701    else:
    702        if len(subs) == 1:
    703            f = subs[0]
    704        else:
    705            mask = 0
    706            for s in subs:
    707                if mask & s.mask:
    708                    error(lineno, 'field components overlap')
    709                mask |= s.mask
    710            f = MultiField(subs, mask)
    711        if func:
    712            f = FunctionField(func, f)
    713
    714    if name in fields:
    715        error(lineno, 'duplicate field', name)
    716    fields[name] = f
    717# end parse_field
    718
    719
    720def parse_arguments(lineno, name, toks):
    721    """Parse one argument set from TOKS at LINENO"""
    722    global arguments
    723    global re_C_ident
    724    global anyextern
    725
    726    flds = []
    727    types = []
    728    extern = False
    729    for n in toks:
    730        if re.fullmatch('!extern', n):
    731            extern = True
    732            anyextern = True
    733            continue
    734        if re.fullmatch(re_C_ident + ':' + re_C_ident, n):
    735            (n, t) = n.split(':')
    736        elif re.fullmatch(re_C_ident, n):
    737            t = 'int'
    738        else:
    739            error(lineno, f'invalid argument set token "{n}"')
    740        if n in flds:
    741            error(lineno, f'duplicate argument "{n}"')
    742        flds.append(n)
    743        types.append(t)
    744
    745    if name in arguments:
    746        error(lineno, 'duplicate argument set', name)
    747    arguments[name] = Arguments(name, flds, types, extern)
    748# end parse_arguments
    749
    750
    751def lookup_field(lineno, name):
    752    global fields
    753    if name in fields:
    754        return fields[name]
    755    error(lineno, 'undefined field', name)
    756
    757
    758def add_field(lineno, flds, new_name, f):
    759    if new_name in flds:
    760        error(lineno, 'duplicate field', new_name)
    761    flds[new_name] = f
    762    return flds
    763
    764
    765def add_field_byname(lineno, flds, new_name, old_name):
    766    return add_field(lineno, flds, new_name, lookup_field(lineno, old_name))
    767
    768
    769def infer_argument_set(flds):
    770    global arguments
    771    global decode_function
    772
    773    for arg in arguments.values():
    774        if eq_fields_for_args(flds, arg):
    775            return arg
    776
    777    name = decode_function + str(len(arguments))
    778    arg = Arguments(name, flds.keys(), ['int'] * len(flds), False)
    779    arguments[name] = arg
    780    return arg
    781
    782
    783def infer_format(arg, fieldmask, flds, width):
    784    global arguments
    785    global formats
    786    global decode_function
    787
    788    const_flds = {}
    789    var_flds = {}
    790    for n, c in flds.items():
    791        if c is ConstField:
    792            const_flds[n] = c
    793        else:
    794            var_flds[n] = c
    795
    796    # Look for an existing format with the same argument set and fields
    797    for fmt in formats.values():
    798        if arg and fmt.base != arg:
    799            continue
    800        if fieldmask != fmt.fieldmask:
    801            continue
    802        if width != fmt.width:
    803            continue
    804        if not eq_fields_for_fmts(flds, fmt.fields):
    805            continue
    806        return (fmt, const_flds)
    807
    808    name = decode_function + '_Fmt_' + str(len(formats))
    809    if not arg:
    810        arg = infer_argument_set(flds)
    811
    812    fmt = Format(name, 0, arg, 0, 0, 0, fieldmask, var_flds, width)
    813    formats[name] = fmt
    814
    815    return (fmt, const_flds)
    816# end infer_format
    817
    818
    819def parse_generic(lineno, parent_pat, name, toks):
    820    """Parse one instruction format from TOKS at LINENO"""
    821    global fields
    822    global arguments
    823    global formats
    824    global allpatterns
    825    global re_arg_ident
    826    global re_fld_ident
    827    global re_fmt_ident
    828    global re_C_ident
    829    global insnwidth
    830    global insnmask
    831    global variablewidth
    832
    833    is_format = parent_pat is None
    834
    835    fixedmask = 0
    836    fixedbits = 0
    837    undefmask = 0
    838    width = 0
    839    flds = {}
    840    arg = None
    841    fmt = None
    842    for t in toks:
    843        # '&Foo' gives a format an explicit argument set.
    844        if re.fullmatch(re_arg_ident, t):
    845            tt = t[1:]
    846            if arg:
    847                error(lineno, 'multiple argument sets')
    848            if tt in arguments:
    849                arg = arguments[tt]
    850            else:
    851                error(lineno, 'undefined argument set', t)
    852            continue
    853
    854        # '@Foo' gives a pattern an explicit format.
    855        if re.fullmatch(re_fmt_ident, t):
    856            tt = t[1:]
    857            if fmt:
    858                error(lineno, 'multiple formats')
    859            if tt in formats:
    860                fmt = formats[tt]
    861            else:
    862                error(lineno, 'undefined format', t)
    863            continue
    864
    865        # '%Foo' imports a field.
    866        if re.fullmatch(re_fld_ident, t):
    867            tt = t[1:]
    868            flds = add_field_byname(lineno, flds, tt, tt)
    869            continue
    870
    871        # 'Foo=%Bar' imports a field with a different name.
    872        if re.fullmatch(re_C_ident + '=' + re_fld_ident, t):
    873            (fname, iname) = t.split('=%')
    874            flds = add_field_byname(lineno, flds, fname, iname)
    875            continue
    876
    877        # 'Foo=number' sets an argument field to a constant value
    878        if re.fullmatch(re_C_ident + '=[+-]?[0-9]+', t):
    879            (fname, value) = t.split('=')
    880            value = int(value)
    881            flds = add_field(lineno, flds, fname, ConstField(value))
    882            continue
    883
    884        # Pattern of 0s, 1s, dots and dashes indicate required zeros,
    885        # required ones, or dont-cares.
    886        if re.fullmatch('[01.-]+', t):
    887            shift = len(t)
    888            fms = t.replace('0', '1')
    889            fms = fms.replace('.', '0')
    890            fms = fms.replace('-', '0')
    891            fbs = t.replace('.', '0')
    892            fbs = fbs.replace('-', '0')
    893            ubm = t.replace('1', '0')
    894            ubm = ubm.replace('.', '0')
    895            ubm = ubm.replace('-', '1')
    896            fms = int(fms, 2)
    897            fbs = int(fbs, 2)
    898            ubm = int(ubm, 2)
    899            fixedbits = (fixedbits << shift) | fbs
    900            fixedmask = (fixedmask << shift) | fms
    901            undefmask = (undefmask << shift) | ubm
    902        # Otherwise, fieldname:fieldwidth
    903        elif re.fullmatch(re_C_ident + ':s?[0-9]+', t):
    904            (fname, flen) = t.split(':')
    905            sign = False
    906            if flen[0] == 's':
    907                sign = True
    908                flen = flen[1:]
    909            shift = int(flen, 10)
    910            if shift + width > insnwidth:
    911                error(lineno, f'field {fname} exceeds insnwidth')
    912            f = Field(sign, insnwidth - width - shift, shift)
    913            flds = add_field(lineno, flds, fname, f)
    914            fixedbits <<= shift
    915            fixedmask <<= shift
    916            undefmask <<= shift
    917        else:
    918            error(lineno, f'invalid token "{t}"')
    919        width += shift
    920
    921    if variablewidth and width < insnwidth and width % 8 == 0:
    922        shift = insnwidth - width
    923        fixedbits <<= shift
    924        fixedmask <<= shift
    925        undefmask <<= shift
    926        undefmask |= (1 << shift) - 1
    927
    928    # We should have filled in all of the bits of the instruction.
    929    elif not (is_format and width == 0) and width != insnwidth:
    930        error(lineno, f'definition has {width} bits')
    931
    932    # Do not check for fields overlapping fields; one valid usage
    933    # is to be able to duplicate fields via import.
    934    fieldmask = 0
    935    for f in flds.values():
    936        fieldmask |= f.mask
    937
    938    # Fix up what we've parsed to match either a format or a pattern.
    939    if is_format:
    940        # Formats cannot reference formats.
    941        if fmt:
    942            error(lineno, 'format referencing format')
    943        # If an argument set is given, then there should be no fields
    944        # without a place to store it.
    945        if arg:
    946            for f in flds.keys():
    947                if f not in arg.fields:
    948                    error(lineno, f'field {f} not in argument set {arg.name}')
    949        else:
    950            arg = infer_argument_set(flds)
    951        if name in formats:
    952            error(lineno, 'duplicate format name', name)
    953        fmt = Format(name, lineno, arg, fixedbits, fixedmask,
    954                     undefmask, fieldmask, flds, width)
    955        formats[name] = fmt
    956    else:
    957        # Patterns can reference a format ...
    958        if fmt:
    959            # ... but not an argument simultaneously
    960            if arg:
    961                error(lineno, 'pattern specifies both format and argument set')
    962            if fixedmask & fmt.fixedmask:
    963                error(lineno, 'pattern fixed bits overlap format fixed bits')
    964            if width != fmt.width:
    965                error(lineno, 'pattern uses format of different width')
    966            fieldmask |= fmt.fieldmask
    967            fixedbits |= fmt.fixedbits
    968            fixedmask |= fmt.fixedmask
    969            undefmask |= fmt.undefmask
    970        else:
    971            (fmt, flds) = infer_format(arg, fieldmask, flds, width)
    972        arg = fmt.base
    973        for f in flds.keys():
    974            if f not in arg.fields:
    975                error(lineno, f'field {f} not in argument set {arg.name}')
    976            if f in fmt.fields.keys():
    977                error(lineno, f'field {f} set by format and pattern')
    978        for f in arg.fields:
    979            if f not in flds.keys() and f not in fmt.fields.keys():
    980                error(lineno, f'field {f} not initialized')
    981        pat = Pattern(name, lineno, fmt, fixedbits, fixedmask,
    982                      undefmask, fieldmask, flds, width)
    983        parent_pat.pats.append(pat)
    984        allpatterns.append(pat)
    985
    986    # Validate the masks that we have assembled.
    987    if fieldmask & fixedmask:
    988        error(lineno, 'fieldmask overlaps fixedmask ',
    989              f'({whex(fieldmask)} & {whex(fixedmask)})')
    990    if fieldmask & undefmask:
    991        error(lineno, 'fieldmask overlaps undefmask ',
    992              f'({whex(fieldmask)} & {whex(undefmask)})')
    993    if fixedmask & undefmask:
    994        error(lineno, 'fixedmask overlaps undefmask ',
    995              f'({whex(fixedmask)} & {whex(undefmask)})')
    996    if not is_format:
    997        allbits = fieldmask | fixedmask | undefmask
    998        if allbits != insnmask:
    999            error(lineno, 'bits left unspecified ',
   1000                  f'({whex(allbits ^ insnmask)})')
# end parse_generic
   1002
   1003
   1004def parse_file(f, parent_pat):
   1005    """Parse all of the patterns within a file"""
   1006    global re_arg_ident
   1007    global re_fld_ident
   1008    global re_fmt_ident
   1009    global re_pat_ident
   1010
   1011    # Read all of the lines of the file.  Concatenate lines
   1012    # ending in backslash; discard empty lines and comments.
   1013    toks = []
   1014    lineno = 0
   1015    nesting = 0
   1016    nesting_pats = []
   1017
   1018    for line in f:
   1019        lineno += 1
   1020
   1021        # Expand and strip spaces, to find indent.
   1022        line = line.rstrip()
   1023        line = line.expandtabs()
   1024        len1 = len(line)
   1025        line = line.lstrip()
   1026        len2 = len(line)
   1027
   1028        # Discard comments
   1029        end = line.find('#')
   1030        if end >= 0:
   1031            line = line[:end]
   1032
   1033        t = line.split()
   1034        if len(toks) != 0:
   1035            # Next line after continuation
   1036            toks.extend(t)
   1037        else:
   1038            # Allow completely blank lines.
   1039            if len1 == 0:
   1040                continue
   1041            indent = len1 - len2
   1042            # Empty line due to comment.
   1043            if len(t) == 0:
   1044                # Indentation must be correct, even for comment lines.
   1045                if indent != nesting:
   1046                    error(lineno, 'indentation ', indent, ' != ', nesting)
   1047                continue
   1048            start_lineno = lineno
   1049            toks = t
   1050
   1051        # Continuation?
   1052        if toks[-1] == '\\':
   1053            toks.pop()
   1054            continue
   1055
   1056        name = toks[0]
   1057        del toks[0]
   1058
   1059        # End nesting?
   1060        if name == '}' or name == ']':
   1061            if len(toks) != 0:
   1062                error(start_lineno, 'extra tokens after close brace')
   1063
   1064            # Make sure { } and [ ] nest properly.
   1065            if (name == '}') != isinstance(parent_pat, IncMultiPattern):
   1066                error(lineno, 'mismatched close brace')
   1067
   1068            try:
   1069                parent_pat = nesting_pats.pop()
   1070            except:
   1071                error(lineno, 'extra close brace')
   1072
   1073            nesting -= 2
   1074            if indent != nesting:
   1075                error(lineno, 'indentation ', indent, ' != ', nesting)
   1076
   1077            toks = []
   1078            continue
   1079
   1080        # Everything else should have current indentation.
   1081        if indent != nesting:
   1082            error(start_lineno, 'indentation ', indent, ' != ', nesting)
   1083
   1084        # Start nesting?
   1085        if name == '{' or name == '[':
   1086            if len(toks) != 0:
   1087                error(start_lineno, 'extra tokens after open brace')
   1088
   1089            if name == '{':
   1090                nested_pat = IncMultiPattern(start_lineno)
   1091            else:
   1092                nested_pat = ExcMultiPattern(start_lineno)
   1093            parent_pat.pats.append(nested_pat)
   1094            nesting_pats.append(parent_pat)
   1095            parent_pat = nested_pat
   1096
   1097            nesting += 2
   1098            toks = []
   1099            continue
   1100
   1101        # Determine the type of object needing to be parsed.
   1102        if re.fullmatch(re_fld_ident, name):
   1103            parse_field(start_lineno, name[1:], toks)
   1104        elif re.fullmatch(re_arg_ident, name):
   1105            parse_arguments(start_lineno, name[1:], toks)
   1106        elif re.fullmatch(re_fmt_ident, name):
   1107            parse_generic(start_lineno, None, name[1:], toks)
   1108        elif re.fullmatch(re_pat_ident, name):
   1109            parse_generic(start_lineno, parent_pat, name, toks)
   1110        else:
   1111            error(lineno, f'invalid token "{name}"')
   1112        toks = []
   1113
   1114    if nesting != 0:
   1115        error(lineno, 'missing close brace')
   1116# end parse_file
   1117
   1118
   1119class SizeTree:
   1120    """Class representing a node in a size decode tree"""
   1121
   1122    def __init__(self, m, w):
   1123        self.mask = m
   1124        self.subs = []
   1125        self.base = None
   1126        self.width = w
   1127
   1128    def str1(self, i):
   1129        ind = str_indent(i)
   1130        r = ind + whex(self.mask) + ' [\n'
   1131        for (b, s) in self.subs:
   1132            r += ind + f'  {whex(b)}:\n'
   1133            r += s.str1(i + 4) + '\n'
   1134        r += ind + ']'
   1135        return r
   1136
   1137    def __str__(self):
   1138        return self.str1(0)
   1139
   1140    def output_code(self, i, extracted, outerbits, outermask):
   1141        ind = str_indent(i)
   1142
   1143        # If we need to load more bytes to test, do so now.
   1144        if extracted < self.width:
   1145            output(ind, f'insn = {decode_function}_load_bytes',
   1146                   f'(ctx, insn, {extracted // 8}, {self.width // 8});\n')
   1147            extracted = self.width
   1148
   1149        # Attempt to aid the compiler in producing compact switch statements.
   1150        # If the bits in the mask are contiguous, extract them.
   1151        sh = is_contiguous(self.mask)
   1152        if sh > 0:
   1153            # Propagate SH down into the local functions.
   1154            def str_switch(b, sh=sh):
   1155                return f'(insn >> {sh}) & {b >> sh:#x}'
   1156
   1157            def str_case(b, sh=sh):
   1158                return hex(b >> sh)
   1159        else:
   1160            def str_switch(b):
   1161                return f'insn & {whexC(b)}'
   1162
   1163            def str_case(b):
   1164                return whexC(b)
   1165
   1166        output(ind, 'switch (', str_switch(self.mask), ') {\n')
   1167        for b, s in sorted(self.subs):
   1168            innermask = outermask | self.mask
   1169            innerbits = outerbits | b
   1170            output(ind, 'case ', str_case(b), ':\n')
   1171            output(ind, '    /* ',
   1172                   str_match_bits(innerbits, innermask), ' */\n')
   1173            s.output_code(i + 4, extracted, innerbits, innermask)
   1174        output(ind, '}\n')
   1175        output(ind, 'return insn;\n')
   1176# end SizeTree
   1177
   1178class SizeLeaf:
   1179    """Class representing a leaf node in a size decode tree"""
   1180
   1181    def __init__(self, m, w):
   1182        self.mask = m
   1183        self.width = w
   1184
   1185    def str1(self, i):
   1186        return str_indent(i) + whex(self.mask)
   1187
   1188    def __str__(self):
   1189        return self.str1(0)
   1190
   1191    def output_code(self, i, extracted, outerbits, outermask):
   1192        global decode_function
   1193        ind = str_indent(i)
   1194
   1195        # If we need to load more bytes, do so now.
   1196        if extracted < self.width:
   1197            output(ind, f'insn = {decode_function}_load_bytes',
   1198                   f'(ctx, insn, {extracted // 8}, {self.width // 8});\n')
   1199            extracted = self.width
   1200        output(ind, 'return insn;\n')
   1201# end SizeLeaf
   1202
   1203
   1204def build_size_tree(pats, width, outerbits, outermask):
   1205    global insnwidth
   1206
   1207    # Collect the mask of bits that are fixed in this width
   1208    innermask = 0xff << (insnwidth - width)
   1209    innermask &= ~outermask
   1210    minwidth = None
   1211    onewidth = True
   1212    for i in pats:
   1213        innermask &= i.fixedmask
   1214        if minwidth is None:
   1215            minwidth = i.width
   1216        elif minwidth != i.width:
   1217            onewidth = False;
   1218            if minwidth < i.width:
   1219                minwidth = i.width
   1220
   1221    if onewidth:
   1222        return SizeLeaf(innermask, minwidth)
   1223
   1224    if innermask == 0:
   1225        if width < minwidth:
   1226            return build_size_tree(pats, width + 8, outerbits, outermask)
   1227
   1228        pnames = []
   1229        for p in pats:
   1230            pnames.append(p.name + ':' + p.file + ':' + str(p.lineno))
   1231        error_with_file(pats[0].file, pats[0].lineno,
   1232                        f'overlapping patterns size {width}:', pnames)
   1233
   1234    bins = {}
   1235    for i in pats:
   1236        fb = i.fixedbits & innermask
   1237        if fb in bins:
   1238            bins[fb].append(i)
   1239        else:
   1240            bins[fb] = [i]
   1241
   1242    fullmask = outermask | innermask
   1243    lens = sorted(bins.keys())
   1244    if len(lens) == 1:
   1245        b = lens[0]
   1246        return build_size_tree(bins[b], width + 8, b | outerbits, fullmask)
   1247
   1248    r = SizeTree(innermask, width)
   1249    for b, l in bins.items():
   1250        s = build_size_tree(l, width, b | outerbits, fullmask)
   1251        r.subs.append((b, s))
   1252    return r
   1253# end build_size_tree
   1254
   1255
   1256def prop_size(tree):
   1257    """Propagate minimum widths up the decode size tree"""
   1258
   1259    if isinstance(tree, SizeTree):
   1260        min = None
   1261        for (b, s) in tree.subs:
   1262            width = prop_size(s)
   1263            if min is None or min > width:
   1264                min = width
   1265        assert min >= tree.width
   1266        tree.width = min
   1267    else:
   1268        min = tree.width
   1269    return min
   1270# end prop_size
   1271
   1272
   1273def main():
   1274    global arguments
   1275    global formats
   1276    global allpatterns
   1277    global translate_scope
   1278    global translate_prefix
   1279    global output_fd
   1280    global output_file
   1281    global input_file
   1282    global insnwidth
   1283    global insntype
   1284    global insnmask
   1285    global decode_function
   1286    global bitop_width
   1287    global variablewidth
   1288    global anyextern
   1289
   1290    decode_scope = 'static '
   1291
   1292    long_opts = ['decode=', 'translate=', 'output=', 'insnwidth=',
   1293                 'static-decode=', 'varinsnwidth=']
   1294    try:
   1295        (opts, args) = getopt.gnu_getopt(sys.argv[1:], 'o:vw:', long_opts)
   1296    except getopt.GetoptError as err:
   1297        error(0, err)
   1298    for o, a in opts:
   1299        if o in ('-o', '--output'):
   1300            output_file = a
   1301        elif o == '--decode':
   1302            decode_function = a
   1303            decode_scope = ''
   1304        elif o == '--static-decode':
   1305            decode_function = a
   1306        elif o == '--translate':
   1307            translate_prefix = a
   1308            translate_scope = ''
   1309        elif o in ('-w', '--insnwidth', '--varinsnwidth'):
   1310            if o == '--varinsnwidth':
   1311                variablewidth = True
   1312            insnwidth = int(a)
   1313            if insnwidth == 16:
   1314                insntype = 'uint16_t'
   1315                insnmask = 0xffff
   1316            elif insnwidth == 64:
   1317                insntype = 'uint64_t'
   1318                insnmask = 0xffffffffffffffff
   1319                bitop_width = 64
   1320            elif insnwidth != 32:
   1321                error(0, 'cannot handle insns of width', insnwidth)
   1322        else:
   1323            assert False, 'unhandled option'
   1324
   1325    if len(args) < 1:
   1326        error(0, 'missing input file')
   1327
   1328    toppat = ExcMultiPattern(0)
   1329
   1330    for filename in args:
   1331        input_file = filename
   1332        f = open(filename, 'rt', encoding='utf-8')
   1333        parse_file(f, toppat)
   1334        f.close()
   1335
   1336    # We do not want to compute masks for toppat, because those masks
   1337    # are used as a starting point for build_tree.  For toppat, we must
   1338    # insist that decode begins from naught.
   1339    for i in toppat.pats:
   1340        i.prop_masks()
   1341
   1342    toppat.build_tree()
   1343    toppat.prop_format()
   1344
   1345    if variablewidth:
   1346        for i in toppat.pats:
   1347            i.prop_width()
   1348        stree = build_size_tree(toppat.pats, 8, 0, 0)
   1349        prop_size(stree)
   1350
   1351    if output_file:
   1352        output_fd = open(output_file, 'wt', encoding='utf-8')
   1353    else:
   1354        output_fd = io.TextIOWrapper(sys.stdout.buffer,
   1355                                     encoding=sys.stdout.encoding,
   1356                                     errors="ignore")
   1357
   1358    output_autogen()
   1359    for n in sorted(arguments.keys()):
   1360        f = arguments[n]
   1361        f.output_def()
   1362
   1363    # A single translate function can be invoked for different patterns.
   1364    # Make sure that the argument sets are the same, and declare the
   1365    # function only once.
   1366    #
   1367    # If we're sharing formats, we're likely also sharing trans_* functions,
   1368    # but we can't tell which ones.  Prevent issues from the compiler by
   1369    # suppressing redundant declaration warnings.
   1370    if anyextern:
   1371        output("#pragma GCC diagnostic push\n",
   1372               "#pragma GCC diagnostic ignored \"-Wredundant-decls\"\n",
   1373               "#ifdef __clang__\n"
   1374               "#  pragma GCC diagnostic ignored \"-Wtypedef-redefinition\"\n",
   1375               "#endif\n\n")
   1376
   1377    out_pats = {}
   1378    for i in allpatterns:
   1379        if i.name in out_pats:
   1380            p = out_pats[i.name]
   1381            if i.base.base != p.base.base:
   1382                error(0, i.name, ' has conflicting argument sets')
   1383        else:
   1384            i.output_decl()
   1385            out_pats[i.name] = i
   1386    output('\n')
   1387
   1388    if anyextern:
   1389        output("#pragma GCC diagnostic pop\n\n")
   1390
   1391    for n in sorted(formats.keys()):
   1392        f = formats[n]
   1393        f.output_extract()
   1394
   1395    output(decode_scope, 'bool ', decode_function,
   1396           '(DisasContext *ctx, ', insntype, ' insn)\n{\n')
   1397
   1398    i4 = str_indent(4)
   1399
   1400    if len(allpatterns) != 0:
   1401        output(i4, 'union {\n')
   1402        for n in sorted(arguments.keys()):
   1403            f = arguments[n]
   1404            output(i4, i4, f.struct_name(), ' f_', f.name, ';\n')
   1405        output(i4, '} u;\n\n')
   1406        toppat.output_code(4, False, 0, 0)
   1407
   1408    output(i4, 'return false;\n')
   1409    output('}\n')
   1410
   1411    if variablewidth:
   1412        output('\n', decode_scope, insntype, ' ', decode_function,
   1413               '_load(DisasContext *ctx)\n{\n',
   1414               '    ', insntype, ' insn = 0;\n\n')
   1415        stree.output_code(4, 0, 0, 0)
   1416        output('}\n')
   1417
   1418    if output_file:
   1419        output_fd.close()
   1420# end main
   1421
   1422
   1423if __name__ == '__main__':
   1424    main()