cachepc-qemu

Fork of AMDESE/qemu with changes for the CachePC side-channel attack.
git clone https://git.sinitax.com/sinitax/cachepc-qemu
Log | Files | Refs | Submodules | LICENSE | sfeed.txt

protocol.py (18086B)


      1import asyncio
      2from contextlib import contextmanager
      3import os
      4import socket
      5from tempfile import TemporaryDirectory
      6
      7import avocado
      8
      9from qemu.aqmp import ConnectError, Runstate
     10from qemu.aqmp.protocol import AsyncProtocol, StateError
     11from qemu.aqmp.util import asyncio_run, create_task
     12
     13
     14class NullProtocol(AsyncProtocol[None]):
     15    """
     16    NullProtocol is a test mockup of an AsyncProtocol implementation.
     17
     18    It adds a fake_session instance variable that enables a code path
     19    that bypasses the actual connection logic, but still allows the
     20    reader/writers to start.
     21
     22    Because the message type is defined as None, an asyncio.Event named
     23    'trigger_input' is created that prohibits the reader from
     24    incessantly being able to yield None; this event can be poked to
     25    simulate an incoming message.
     26
     27    For testing symmetry with do_recv, an interface is added to "send" a
     28    Null message.
     29
     30    For testing purposes, a "simulate_disconnection" method is also
     31    added which allows us to trigger a bottom half disconnect without
     32    injecting any real errors into the reader/writer loops; in essence
     33    it performs exactly half of what disconnect() normally does.
     34    """
     35    def __init__(self, name=None):
     36        self.fake_session = False
     37        self.trigger_input: asyncio.Event
     38        super().__init__(name)
     39
     40    async def _establish_session(self):
     41        self.trigger_input = asyncio.Event()
     42        await super()._establish_session()
     43
     44    async def _do_accept(self, address, ssl=None):
     45        if not self.fake_session:
     46            await super()._do_accept(address, ssl)
     47
     48    async def _do_connect(self, address, ssl=None):
     49        if not self.fake_session:
     50            await super()._do_connect(address, ssl)
     51
     52    async def _do_recv(self) -> None:
     53        await self.trigger_input.wait()
     54        self.trigger_input.clear()
     55
     56    def _do_send(self, msg: None) -> None:
     57        pass
     58
     59    async def send_msg(self) -> None:
     60        await self._outgoing.put(None)
     61
     62    async def simulate_disconnect(self) -> None:
     63        """
     64        Simulates a bottom-half disconnect.
     65
     66        This method schedules a disconnection but does not wait for it
     67        to complete. This is used to put the loop into the DISCONNECTING
     68        state without fully quiescing it back to IDLE. This is normally
     69        something you cannot coax AsyncProtocol to do on purpose, but it
     70        will be similar to what happens with an unhandled Exception in
     71        the reader/writer.
     72
     73        Under normal circumstances, the library design requires you to
     74        await on disconnect(), which awaits the disconnect task and
     75        returns bottom half errors as a pre-condition to allowing the
     76        loop to return back to IDLE.
     77        """
     78        self._schedule_disconnect()
     79
     80
     81class LineProtocol(AsyncProtocol[str]):
     82    def __init__(self, name=None):
     83        super().__init__(name)
     84        self.rx_history = []
     85
     86    async def _do_recv(self) -> str:
     87        raw = await self._readline()
     88        msg = raw.decode()
     89        self.rx_history.append(msg)
     90        return msg
     91
     92    def _do_send(self, msg: str) -> None:
     93        assert self._writer is not None
     94        self._writer.write(msg.encode() + b'\n')
     95
     96    async def send_msg(self, msg: str) -> None:
     97        await self._outgoing.put(msg)
     98
     99
    100def run_as_task(coro, allow_cancellation=False):
    101    """
    102    Run a given coroutine as a task.
    103
    104    Optionally, wrap it in a try..except block that allows this
    105    coroutine to be canceled gracefully.
    106    """
    107    async def _runner():
    108        try:
    109            await coro
    110        except asyncio.CancelledError:
    111            if allow_cancellation:
    112                return
    113            raise
    114    return create_task(_runner())
    115
    116
    117@contextmanager
    118def jammed_socket():
    119    """
    120    Opens up a random unused TCP port on localhost, then jams it.
    121    """
    122    socks = []
    123
    124    try:
    125        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    126        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    127        sock.bind(('127.0.0.1', 0))
    128        sock.listen(1)
    129        address = sock.getsockname()
    130
    131        socks.append(sock)
    132
    133        # I don't *fully* understand why, but it takes *two* un-accepted
    134        # connections to start jamming the socket.
    135        for _ in range(2):
    136            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    137            sock.connect(address)
    138            socks.append(sock)
    139
    140        yield address
    141
    142    finally:
    143        for sock in socks:
    144            sock.close()
    145
    146
    147class Smoke(avocado.Test):
    148
    149    def setUp(self):
    150        self.proto = NullProtocol()
    151
    152    def test__repr__(self):
    153        self.assertEqual(
    154            repr(self.proto),
    155            "<NullProtocol runstate=IDLE>"
    156        )
    157
    158    def testRunstate(self):
    159        self.assertEqual(
    160            self.proto.runstate,
    161            Runstate.IDLE
    162        )
    163
    164    def testDefaultName(self):
    165        self.assertEqual(
    166            self.proto.name,
    167            None
    168        )
    169
    170    def testLogger(self):
    171        self.assertEqual(
    172            self.proto.logger.name,
    173            'qemu.aqmp.protocol'
    174        )
    175
    176    def testName(self):
    177        self.proto = NullProtocol('Steve')
    178
    179        self.assertEqual(
    180            self.proto.name,
    181            'Steve'
    182        )
    183
    184        self.assertEqual(
    185            self.proto.logger.name,
    186            'qemu.aqmp.protocol.Steve'
    187        )
    188
    189        self.assertEqual(
    190            repr(self.proto),
    191            "<NullProtocol name='Steve' runstate=IDLE>"
    192        )
    193
    194
    195class TestBase(avocado.Test):
    196
    197    def setUp(self):
    198        self.proto = NullProtocol(type(self).__name__)
    199        self.assertEqual(self.proto.runstate, Runstate.IDLE)
    200        self.runstate_watcher = None
    201
    202    def tearDown(self):
    203        self.assertEqual(self.proto.runstate, Runstate.IDLE)
    204
    205    async def _asyncSetUp(self):
    206        pass
    207
    208    async def _asyncTearDown(self):
    209        if self.runstate_watcher:
    210            await self.runstate_watcher
    211
    212    @staticmethod
    213    def async_test(async_test_method):
    214        """
    215        Decorator; adds SetUp and TearDown to async tests.
    216        """
    217        async def _wrapper(self, *args, **kwargs):
    218            loop = asyncio.get_event_loop()
    219            loop.set_debug(True)
    220
    221            await self._asyncSetUp()
    222            await async_test_method(self, *args, **kwargs)
    223            await self._asyncTearDown()
    224
    225        return _wrapper
    226
    227    # Definitions
    228
    229    # The states we expect a "bad" connect/accept attempt to transition through
    230    BAD_CONNECTION_STATES = (
    231        Runstate.CONNECTING,
    232        Runstate.DISCONNECTING,
    233        Runstate.IDLE,
    234    )
    235
    236    # The states we expect a "good" session to transition through
    237    GOOD_CONNECTION_STATES = (
    238        Runstate.CONNECTING,
    239        Runstate.RUNNING,
    240        Runstate.DISCONNECTING,
    241        Runstate.IDLE,
    242    )
    243
    244    # Helpers
    245
    246    async def _watch_runstates(self, *states):
    247        """
    248        This launches a task alongside (most) tests below to confirm that
    249        the sequence of runstate changes that occur is exactly as
    250        anticipated.
    251        """
    252        async def _watcher():
    253            for state in states:
    254                new_state = await self.proto.runstate_changed()
    255                self.assertEqual(
    256                    new_state,
    257                    state,
    258                    msg=f"Expected state '{state.name}'",
    259                )
    260
    261        self.runstate_watcher = create_task(_watcher())
    262        # Kick the loop and force the task to block on the event.
    263        await asyncio.sleep(0)
    264
    265
    266class State(TestBase):
    267
    268    @TestBase.async_test
    269    async def testSuperfluousDisconnect(self):
    270        """
    271        Test calling disconnect() while already disconnected.
    272        """
    273        await self._watch_runstates(
    274            Runstate.DISCONNECTING,
    275            Runstate.IDLE,
    276        )
    277        await self.proto.disconnect()
    278
    279
    280class Connect(TestBase):
    281    """
    282    Tests primarily related to calling Connect().
    283    """
    284    async def _bad_connection(self, family: str):
    285        assert family in ('INET', 'UNIX')
    286
    287        if family == 'INET':
    288            await self.proto.connect(('127.0.0.1', 0))
    289        elif family == 'UNIX':
    290            await self.proto.connect('/dev/null')
    291
    292    async def _hanging_connection(self):
    293        with jammed_socket() as addr:
    294            await self.proto.connect(addr)
    295
    296    async def _bad_connection_test(self, family: str):
    297        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    298
    299        with self.assertRaises(ConnectError) as context:
    300            await self._bad_connection(family)
    301
    302        self.assertIsInstance(context.exception.exc, OSError)
    303        self.assertEqual(
    304            context.exception.error_message,
    305            "Failed to establish connection"
    306        )
    307
    308    @TestBase.async_test
    309    async def testBadINET(self):
    310        """
    311        Test an immediately rejected call to an IP target.
    312        """
    313        await self._bad_connection_test('INET')
    314
    315    @TestBase.async_test
    316    async def testBadUNIX(self):
    317        """
    318        Test an immediately rejected call to a UNIX socket target.
    319        """
    320        await self._bad_connection_test('UNIX')
    321
    322    @TestBase.async_test
    323    async def testCancellation(self):
    324        """
    325        Test what happens when a connection attempt is aborted.
    326        """
    327        # Note that accept() cannot be cancelled outright, as it isn't a task.
    328        # However, we can wrap it in a task and cancel *that*.
    329        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    330        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    331
    332        state = await self.proto.runstate_changed()
    333        self.assertEqual(state, Runstate.CONNECTING)
    334
    335        # This is insider baseball, but the connection attempt has
    336        # yielded *just* before the actual connection attempt, so kick
    337        # the loop to make sure it's truly wedged.
    338        await asyncio.sleep(0)
    339
    340        task.cancel()
    341        await task
    342
    343    @TestBase.async_test
    344    async def testTimeout(self):
    345        """
    346        Test what happens when a connection attempt times out.
    347        """
    348        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    349        task = run_as_task(self._hanging_connection())
    350
    351        # More insider baseball: to improve the speed of this test while
    352        # guaranteeing that the connection even gets a chance to start,
    353        # verify that the connection hangs *first*, then await the
    354        # result of the task with a nearly-zero timeout.
    355
    356        state = await self.proto.runstate_changed()
    357        self.assertEqual(state, Runstate.CONNECTING)
    358        await asyncio.sleep(0)
    359
    360        with self.assertRaises(asyncio.TimeoutError):
    361            await asyncio.wait_for(task, timeout=0)
    362
    363    @TestBase.async_test
    364    async def testRequire(self):
    365        """
    366        Test what happens when a connection attempt is made while CONNECTING.
    367        """
    368        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    369        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    370
    371        state = await self.proto.runstate_changed()
    372        self.assertEqual(state, Runstate.CONNECTING)
    373
    374        with self.assertRaises(StateError) as context:
    375            await self._bad_connection('UNIX')
    376
    377        self.assertEqual(
    378            context.exception.error_message,
    379            "NullProtocol is currently connecting."
    380        )
    381        self.assertEqual(context.exception.state, Runstate.CONNECTING)
    382        self.assertEqual(context.exception.required, Runstate.IDLE)
    383
    384        task.cancel()
    385        await task
    386
    387    @TestBase.async_test
    388    async def testImplicitRunstateInit(self):
    389        """
    390        Test what happens if we do not wait on the runstate event until
    391        AFTER a connection is made, i.e., connect()/accept() themselves
    392        initialize the runstate event. All of the above tests force the
    393        initialization by waiting on the runstate *first*.
    394        """
    395        task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    396
    397        # Kick the loop to coerce the state change
    398        await asyncio.sleep(0)
    399        assert self.proto.runstate == Runstate.CONNECTING
    400
    401        # We already missed the transition to CONNECTING
    402        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)
    403
    404        task.cancel()
    405        await task
    406
    407
    408class Accept(Connect):
    409    """
    410    All of the same tests as Connect, but using the accept() interface.
    411    """
    412    async def _bad_connection(self, family: str):
    413        assert family in ('INET', 'UNIX')
    414
    415        if family == 'INET':
    416            await self.proto.accept(('example.com', 1))
    417        elif family == 'UNIX':
    418            await self.proto.accept('/dev/null')
    419
    420    async def _hanging_connection(self):
    421        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
    422            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
    423            await self.proto.accept(sock)
    424
    425
    426class FakeSession(TestBase):
    427
    428    def setUp(self):
    429        super().setUp()
    430        self.proto.fake_session = True
    431
    432    async def _asyncSetUp(self):
    433        await super()._asyncSetUp()
    434        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
    435
    436    async def _asyncTearDown(self):
    437        await self.proto.disconnect()
    438        await super()._asyncTearDown()
    439
    440    ####
    441
    442    @TestBase.async_test
    443    async def testFakeConnect(self):
    444
    445        """Test the full state lifecycle (via connect) with a no-op session."""
    446        await self.proto.connect('/not/a/real/path')
    447        self.assertEqual(self.proto.runstate, Runstate.RUNNING)
    448
    449    @TestBase.async_test
    450    async def testFakeAccept(self):
    451        """Test the full state lifecycle (via accept) with a no-op session."""
    452        await self.proto.accept('/not/a/real/path')
    453        self.assertEqual(self.proto.runstate, Runstate.RUNNING)
    454
    455    @TestBase.async_test
    456    async def testFakeRecv(self):
    457        """Test receiving a fake/null message."""
    458        await self.proto.accept('/not/a/real/path')
    459
    460        logname = self.proto.logger.name
    461        with self.assertLogs(logname, level='DEBUG') as context:
    462            self.proto.trigger_input.set()
    463            self.proto.trigger_input.clear()
    464            await asyncio.sleep(0)  # Kick reader.
    465
    466        self.assertEqual(
    467            context.output,
    468            [f"DEBUG:{logname}:<-- None"],
    469        )
    470
    471    @TestBase.async_test
    472    async def testFakeSend(self):
    473        """Test sending a fake/null message."""
    474        await self.proto.accept('/not/a/real/path')
    475
    476        logname = self.proto.logger.name
    477        with self.assertLogs(logname, level='DEBUG') as context:
    478            # Cheat: Send a Null message to nobody.
    479            await self.proto.send_msg()
    480            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
    481            await asyncio.sleep(0)
    482
    483        self.assertEqual(
    484            context.output,
    485            [f"DEBUG:{logname}:--> None"],
    486        )
    487
    488    async def _prod_session_api(
    489            self,
    490            current_state: Runstate,
    491            error_message: str,
    492            accept: bool = True
    493    ):
    494        with self.assertRaises(StateError) as context:
    495            if accept:
    496                await self.proto.accept('/not/a/real/path')
    497            else:
    498                await self.proto.connect('/not/a/real/path')
    499
    500        self.assertEqual(context.exception.error_message, error_message)
    501        self.assertEqual(context.exception.state, current_state)
    502        self.assertEqual(context.exception.required, Runstate.IDLE)
    503
    504    @TestBase.async_test
    505    async def testAcceptRequireRunning(self):
    506        """Test that accept() cannot be called when Runstate=RUNNING"""
    507        await self.proto.accept('/not/a/real/path')
    508
    509        await self._prod_session_api(
    510            Runstate.RUNNING,
    511            "NullProtocol is already connected and running.",
    512            accept=True,
    513        )
    514
    515    @TestBase.async_test
    516    async def testConnectRequireRunning(self):
    517        """Test that connect() cannot be called when Runstate=RUNNING"""
    518        await self.proto.accept('/not/a/real/path')
    519
    520        await self._prod_session_api(
    521            Runstate.RUNNING,
    522            "NullProtocol is already connected and running.",
    523            accept=False,
    524        )
    525
    526    @TestBase.async_test
    527    async def testAcceptRequireDisconnecting(self):
    528        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
    529        await self.proto.accept('/not/a/real/path')
    530
    531        # Cheat: force a disconnect.
    532        await self.proto.simulate_disconnect()
    533
    534        await self._prod_session_api(
    535            Runstate.DISCONNECTING,
    536            ("NullProtocol is disconnecting."
    537             " Call disconnect() to return to IDLE state."),
    538            accept=True,
    539        )
    540
    541    @TestBase.async_test
    542    async def testConnectRequireDisconnecting(self):
    543        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
    544        await self.proto.accept('/not/a/real/path')
    545
    546        # Cheat: force a disconnect.
    547        await self.proto.simulate_disconnect()
    548
    549        await self._prod_session_api(
    550            Runstate.DISCONNECTING,
    551            ("NullProtocol is disconnecting."
    552             " Call disconnect() to return to IDLE state."),
    553            accept=False,
    554        )
    555
    556
    557class SimpleSession(TestBase):
    558
    def setUp(self):
        super().setUp()
        # A real line-based peer for self.proto, named after the test class
        # with a '-server' suffix.
        server_name = f"{type(self).__name__}-server"
        self.server = LineProtocol(server_name)
    562
    async def _asyncSetUp(self):
        # Run the base hook first, then expect a complete good-session
        # runstate lifecycle on self.proto.
        await super()._asyncSetUp()
        expected = self.GOOD_CONNECTION_STATES
        await self._watch_runstates(*expected)
    566
    async def _asyncTearDown(self):
        """Disconnect the client, then the server, then run base teardown."""
        await self.proto.disconnect()

        try:
            await self.server.disconnect()
        except EOFError:
            # Tolerated: presumably the server sees the client's hang-up as
            # EOF. NOTE(review): confirm this is the only expected cause.
            pass

        await super()._asyncTearDown()
    574
    575    @TestBase.async_test
    576    async def testSmoke(self):
    577        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
    578            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
    579            server_task = create_task(self.server.accept(sock))
    580
    581            # give the server a chance to start listening [...]
    582            await asyncio.sleep(0)
    583            await self.proto.connect(sock)