cachepc-linux

Fork of AMDESE/linux with modifications for the CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

sharedbuffer_configuration.py (12658B)


      1#!/usr/bin/env python
      2# SPDX-License-Identifier: GPL-2.0
      3
      4import subprocess
      5import json as j
      6import random
      7
      8
      9class SkipTest(Exception):
     10    pass
     11
     12
     13class RandomValuePicker:
     14    """
     15    Class for storing shared buffer configuration. Can handle 3 different
     16    objects, pool, tcbind and portpool. Provide an interface to get random
     17    values for a specific object type as the follow:
     18      1. Pool:
     19         - random size
     20
     21      2. TcBind:
     22         - random pool number
     23         - random threshold
     24
     25      3. PortPool:
     26         - random threshold
     27    """
     28    def __init__(self, pools):
     29        self._pools = []
     30        for pool in pools:
     31            self._pools.append(pool)
     32
     33    def _cell_size(self):
     34        return self._pools[0]["cell_size"]
     35
     36    def _get_static_size(self, th):
     37        # For threshold of 16, this works out to be about 12MB on Spectrum-1,
     38        # and about 17MB on Spectrum-2.
     39        return th * 8000 * self._cell_size()
     40
     41    def _get_size(self):
     42        return self._get_static_size(16)
     43
     44    def _get_thtype(self):
     45        return "static"
     46
     47    def _get_th(self, pool):
     48        # Threshold value could be any integer between 3 to 16
     49        th = random.randint(3, 16)
     50        if pool["thtype"] == "dynamic":
     51            return th
     52        else:
     53            return self._get_static_size(th)
     54
     55    def _get_pool(self, direction):
     56        ing_pools = []
     57        egr_pools = []
     58        for pool in self._pools:
     59            if pool["type"] == "ingress":
     60                ing_pools.append(pool)
     61            else:
     62                egr_pools.append(pool)
     63        if direction == "ingress":
     64            arr = ing_pools
     65        else:
     66            arr = egr_pools
     67        return arr[random.randint(0, len(arr) - 1)]
     68
     69    def get_value(self, objid):
     70        if isinstance(objid, Pool):
     71            if objid["pool"] in [4, 8, 9, 10]:
     72                # The threshold type of pools 4, 8, 9 and 10 cannot be changed
     73                raise SkipTest()
     74            else:
     75                return (self._get_size(), self._get_thtype())
     76        if isinstance(objid, TcBind):
     77            if objid["tc"] >= 8:
     78                # Multicast TCs cannot be changed
     79                raise SkipTest()
     80            else:
     81                pool = self._get_pool(objid["type"])
     82                th = self._get_th(pool)
     83                pool_n = pool["pool"]
     84                return (pool_n, th)
     85        if isinstance(objid, PortPool):
     86            pool_n = objid["pool"]
     87            pool = self._pools[pool_n]
     88            assert pool["pool"] == pool_n
     89            th = self._get_th(pool)
     90            return (th,)
     91
     92
     93class RecordValuePickerException(Exception):
     94    pass
     95
     96
     97class RecordValuePicker:
     98    """
     99    Class for storing shared buffer configuration. Can handle 2 different
    100    objects, pool and tcbind. Provide an interface to get the stored values per
    101    object type.
    102    """
    103    def __init__(self, objlist):
    104        self._recs = []
    105        for item in objlist:
    106            self._recs.append({"objid": item, "value": item.var_tuple()})
    107
    108    def get_value(self, objid):
    109        if isinstance(objid, Pool) and objid["pool"] in [4, 8, 9, 10]:
    110            # The threshold type of pools 4, 8, 9 and 10 cannot be changed
    111            raise SkipTest()
    112        if isinstance(objid, TcBind) and objid["tc"] >= 8:
    113            # Multicast TCs cannot be changed
    114            raise SkipTest()
    115        for rec in self._recs:
    116            if rec["objid"].weak_eq(objid):
    117                return rec["value"]
    118        raise RecordValuePickerException()
    119
    120
    121def run_cmd(cmd, json=False):
    122    out = subprocess.check_output(cmd, shell=True)
    123    if json:
    124        return j.loads(out)
    125    return out
    126
    127
    128def run_json_cmd(cmd):
    129    return run_cmd(cmd, json=True)
    130
    131
    132def log_test(test_name, err_msg=None):
    133    if err_msg:
    134        print("\t%s" % err_msg)
    135        print("TEST: %-80s  [FAIL]" % test_name)
    136    else:
    137        print("TEST: %-80s  [ OK ]" % test_name)
    138
    139
    140class CommonItem(dict):
    141    varitems = []
    142
    143    def var_tuple(self):
    144        ret = []
    145        self.varitems.sort()
    146        for key in self.varitems:
    147            ret.append(self[key])
    148        return tuple(ret)
    149
    150    def weak_eq(self, other):
    151        for key in self:
    152            if key in self.varitems:
    153                continue
    154            if self[key] != other[key]:
    155                return False
    156        return True
    157
    158
    159class CommonList(list):
    160    def get_by(self, by_obj):
    161        for item in self:
    162            if item.weak_eq(by_obj):
    163                return item
    164        return None
    165
    166    def del_by(self, by_obj):
    167        for item in self:
    168            if item.weak_eq(by_obj):
    169                self.remove(item)
    170
    171
    172class Pool(CommonItem):
    173    varitems = ["size", "thtype"]
    174
    175    def dl_set(self, dlname, size, thtype):
    176        run_cmd("devlink sb pool set {} sb {} pool {} size {} thtype {}".format(dlname, self["sb"],
    177                                                                                self["pool"],
    178                                                                                size, thtype))
    179
    180
    181class PoolList(CommonList):
    182    pass
    183
    184
    185def get_pools(dlname, direction=None):
    186    d = run_json_cmd("devlink sb pool show -j")
    187    pools = PoolList()
    188    for pooldict in d["pool"][dlname]:
    189        if not direction or direction == pooldict["type"]:
    190            pools.append(Pool(pooldict))
    191    return pools
    192
    193
    194def do_check_pools(dlname, pools, vp):
    195    for pool in pools:
    196        pre_pools = get_pools(dlname)
    197        try:
    198            (size, thtype) = vp.get_value(pool)
    199        except SkipTest:
    200            continue
    201        pool.dl_set(dlname, size, thtype)
    202        post_pools = get_pools(dlname)
    203        pool = post_pools.get_by(pool)
    204
    205        err_msg = None
    206        if pool["size"] != size:
    207            err_msg = "Incorrect pool size (got {}, expected {})".format(pool["size"], size)
    208        if pool["thtype"] != thtype:
    209            err_msg = "Incorrect pool threshold type (got {}, expected {})".format(pool["thtype"], thtype)
    210
    211        pre_pools.del_by(pool)
    212        post_pools.del_by(pool)
    213        if pre_pools != post_pools:
    214            err_msg = "Other pool setup changed as well"
    215        log_test("pool {} of sb {} set verification".format(pool["pool"],
    216                                                            pool["sb"]), err_msg)
    217
    218
    219def check_pools(dlname, pools):
    220    # Save defaults
    221    record_vp = RecordValuePicker(pools)
    222
    223    # For each pool, set random size and static threshold type
    224    do_check_pools(dlname, pools, RandomValuePicker(pools))
    225
    226    # Restore defaults
    227    do_check_pools(dlname, pools, record_vp)
    228
    229
    230class TcBind(CommonItem):
    231    varitems = ["pool", "threshold"]
    232
    233    def __init__(self, port, d):
    234        super(TcBind, self).__init__(d)
    235        self["dlportname"] = port.name
    236
    237    def dl_set(self, pool, th):
    238        run_cmd("devlink sb tc bind set {} sb {} tc {} type {} pool {} th {}".format(self["dlportname"],
    239                                                                                     self["sb"],
    240                                                                                     self["tc"],
    241                                                                                     self["type"],
    242                                                                                     pool, th))
    243
    244
    245class TcBindList(CommonList):
    246    pass
    247
    248
    249def get_tcbinds(ports, verify_existence=False):
    250    d = run_json_cmd("devlink sb tc bind show -j -n")
    251    tcbinds = TcBindList()
    252    for port in ports:
    253        err_msg = None
    254        if port.name not in d["tc_bind"] or len(d["tc_bind"][port.name]) == 0:
    255            err_msg = "No tc bind for port"
    256        else:
    257            for tcbinddict in d["tc_bind"][port.name]:
    258                tcbinds.append(TcBind(port, tcbinddict))
    259        if verify_existence:
    260            log_test("tc bind existence for port {} verification".format(port.name), err_msg)
    261    return tcbinds
    262
    263
    264def do_check_tcbind(ports, tcbinds, vp):
    265    for tcbind in tcbinds:
    266        pre_tcbinds = get_tcbinds(ports)
    267        try:
    268            (pool, th) = vp.get_value(tcbind)
    269        except SkipTest:
    270            continue
    271        tcbind.dl_set(pool, th)
    272        post_tcbinds = get_tcbinds(ports)
    273        tcbind = post_tcbinds.get_by(tcbind)
    274
    275        err_msg = None
    276        if tcbind["pool"] != pool:
    277            err_msg = "Incorrect pool (got {}, expected {})".format(tcbind["pool"], pool)
    278        if tcbind["threshold"] != th:
    279            err_msg = "Incorrect threshold (got {}, expected {})".format(tcbind["threshold"], th)
    280
    281        pre_tcbinds.del_by(tcbind)
    282        post_tcbinds.del_by(tcbind)
    283        if pre_tcbinds != post_tcbinds:
    284            err_msg = "Other tc bind setup changed as well"
    285        log_test("tc bind {}-{} of sb {} set verification".format(tcbind["dlportname"],
    286                                                                  tcbind["tc"],
    287                                                                  tcbind["sb"]), err_msg)
    288
    289
    290def check_tcbind(dlname, ports, pools):
    291    tcbinds = get_tcbinds(ports, verify_existence=True)
    292
    293    # Save defaults
    294    record_vp = RecordValuePicker(tcbinds)
    295
    296    # Bind each port and unicast TC (TCs < 8) to a random pool and a random
    297    # threshold
    298    do_check_tcbind(ports, tcbinds, RandomValuePicker(pools))
    299
    300    # Restore defaults
    301    do_check_tcbind(ports, tcbinds, record_vp)
    302
    303
    304class PortPool(CommonItem):
    305    varitems = ["threshold"]
    306
    307    def __init__(self, port, d):
    308        super(PortPool, self).__init__(d)
    309        self["dlportname"] = port.name
    310
    311    def dl_set(self, th):
    312        run_cmd("devlink sb port pool set {} sb {} pool {} th {}".format(self["dlportname"],
    313                                                                         self["sb"],
    314                                                                         self["pool"], th))
    315
    316
    317class PortPoolList(CommonList):
    318    pass
    319
    320
    321def get_portpools(ports, verify_existence=False):
    322    d = run_json_cmd("devlink sb port pool -j -n")
    323    portpools = PortPoolList()
    324    for port in ports:
    325        err_msg = None
    326        if port.name not in d["port_pool"] or len(d["port_pool"][port.name]) == 0:
    327            err_msg = "No port pool for port"
    328        else:
    329            for portpooldict in d["port_pool"][port.name]:
    330                portpools.append(PortPool(port, portpooldict))
    331        if verify_existence:
    332            log_test("port pool existence for port {} verification".format(port.name), err_msg)
    333    return portpools
    334
    335
    336def do_check_portpool(ports, portpools, vp):
    337    for portpool in portpools:
    338        pre_portpools = get_portpools(ports)
    339        (th,) = vp.get_value(portpool)
    340        portpool.dl_set(th)
    341        post_portpools = get_portpools(ports)
    342        portpool = post_portpools.get_by(portpool)
    343
    344        err_msg = None
    345        if portpool["threshold"] != th:
    346            err_msg = "Incorrect threshold (got {}, expected {})".format(portpool["threshold"], th)
    347
    348        pre_portpools.del_by(portpool)
    349        post_portpools.del_by(portpool)
    350        if pre_portpools != post_portpools:
    351            err_msg = "Other port pool setup changed as well"
    352        log_test("port pool {}-{} of sb {} set verification".format(portpool["dlportname"],
    353                                                                    portpool["pool"],
    354                                                                    portpool["sb"]), err_msg)
    355
    356
    357def check_portpool(dlname, ports, pools):
    358    portpools = get_portpools(ports, verify_existence=True)
    359
    360    # Save defaults
    361    record_vp = RecordValuePicker(portpools)
    362
    363    # For each port pool, set a random threshold
    364    do_check_portpool(ports, portpools, RandomValuePicker(pools))
    365
    366    # Restore defaults
    367    do_check_portpool(ports, portpools, record_vp)
    368
    369
    370class Port:
    371    def __init__(self, name):
    372        self.name = name
    373
    374
    375class PortList(list):
    376    pass
    377
    378
    379def get_ports(dlname):
    380    d = run_json_cmd("devlink port show -j")
    381    ports = PortList()
    382    for name in d["port"]:
    383        if name.find(dlname) == 0 and d["port"][name]["flavour"] == "physical":
    384            ports.append(Port(name))
    385    return ports
    386
    387
    388def get_device():
    389    devices_info = run_json_cmd("devlink -j dev info")["info"]
    390    for d in devices_info:
    391        if "mlxsw_spectrum" in devices_info[d]["driver"]:
    392            return d
    393    return None
    394
    395
    396class UnavailableDevlinkNameException(Exception):
    397    pass
    398
    399
    400def test_sb_configuration():
    401    # Use static seed
    402    random.seed(0)
    403
    404    dlname = get_device()
    405    if not dlname:
    406        raise UnavailableDevlinkNameException()
    407
    408    ports = get_ports(dlname)
    409    pools = get_pools(dlname)
    410
    411    check_pools(dlname, pools)
    412    check_tcbind(dlname, ports, pools)
    413    check_portpool(dlname, ports, pools)
    414
    415
    416test_sb_configuration()