enowars5-service-stldoctor

STL-Analyzing A/D Service for ENOWARS5 in 2021
git clone https://git.sinitax.com/sinitax/enowars5-service-stldoctor
Log | Files | Refs | README | LICENSE | sfeed.txt

commit ecf4de6db67ce19d90a0b55ad8c1544087398a4c
parent 7501a1d6c20581312eb37de883ab52e83e27c8fa
Author: Louis Burda <quent.burda@gmail.com>
Date:   Fri,  2 Jul 2021 00:08:54 +0200

refactored code to be style compliant and mostly statically typed with mypy

Diffstat:
Achecker/.flake8 | 4++++
Achecker/.mypy.ini | 26++++++++++++++++++++++++++
Achecker/Makefile | 17+++++++++++++++++
Achecker/dev-requirements.txt | 7+++++++
Dchecker/enoreq.py | 116-------------------------------------------------------------------------------
Mchecker/local.sh | 1+
Mchecker/src/checker.py | 804++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
Dchecker/test.sh | 115-------------------------------------------------------------------------------
8 files changed, 567 insertions(+), 523 deletions(-)

diff --git a/checker/.flake8 b/checker/.flake8 @@ -0,0 +1,4 @@ +[flake8] +select = F +per-file-ignores = __init__.py:F401 + diff --git a/checker/.mypy.ini b/checker/.mypy.ini @@ -0,0 +1,26 @@ +[mypy] +warn_redundant_casts = True +warn_unused_ignores = True +warn_no_return = True +warn_unreachable = True + +# Untyped Definitions and Calls +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True + +[mypy-tests.*] +disallow_untyped_calls = False +disallow_untyped_defs = False +disallow_incomplete_defs = False +disallow_untyped_decorators = False + +[mypy-enochecker3] +ignore_missing_imports = True + +[mypy-faker] +ignore_missing_imports = True + + diff --git a/checker/Makefile b/checker/Makefile @@ -0,0 +1,17 @@ +.PHONY: all lint diff format test + +all: format test + +lint: + python3 -m black --check src + python3 -m flake8 src + python3 -m mypy src + +diff: + python3 -m isort --diff src + python3 -m black --diff src + +format: + python3 -m isort src + python3 -m black src + diff --git a/checker/dev-requirements.txt b/checker/dev-requirements.txt @@ -0,0 +1,7 @@ +mypy==0.910 +black==21.6b0 +isort==5.9.1 +flake8==3.9.2 +coverage==5.5 +pytest==6.2.4 +pytest-asyncio==0.15.1 diff --git a/checker/enoreq.py b/checker/enoreq.py @@ -1,116 +0,0 @@ -import argparse -import hashlib -import sys - -import jsons -import requests -from enochecker_core import CheckerMethod, CheckerResultMessage, CheckerTaskMessage - -TASK_TYPES = [str(i) for i in CheckerMethod] - - -def add_arguments(parser: argparse.ArgumentParser) -> None: - _add_arguments(parser, hide_checker_address=True) - - -def _add_arguments(parser: argparse.ArgumentParser, hide_checker_address=False) -> None: - parser.add_argument("method", choices=TASK_TYPES, help="One of {} ".format(TASK_TYPES)) - if not hide_checker_address: - parser.add_argument("-A", "--checker_address", type=str, default="http://localhost", help="The URL of the checker") - parser.add_argument("-i", "--task_id", type=int, default=1, help="An id for this task. Must be unique in a CTF.") - parser.add_argument("-a", "--address", type=str, default="localhost", help="The ip or address of the remote team to check") - parser.add_argument("-j", "--json", type=bool, default=False, help="Raw JSON output") - parser.add_argument("-T", "--team_id", type=int, default=1, help="The Team_id belonging to the specified Team") - parser.add_argument("-t", "--team_name", type=str, default="team1", help="The name of the target team to check") - parser.add_argument("-r", "--current_round_id", type=int, default=1, help="The round we are in right now") - parser.add_argument( - "-R", - "--related_round_id", - type=int, - default=1, - help="The round in which the flag or noise was stored when method is getflag/getnoise. Equal to current_round_id otherwise.", - ) - parser.add_argument("-f", "--flag", type=str, default="ENOFLAGENOFLAG=", help="The flag for putflag/getflag or the flag to find in exploit mode") - parser.add_argument("-v", "--variant_id", type=int, default=0, help="The variantId for the method being called") - parser.add_argument( - "-x", "--timeout", type=int, default=30000, help="The maximum amount of time the script has to execute in milliseconds (default 30 000)" - ) - parser.add_argument("-l", "--round_length", type=int, default=300000, help="The round length in milliseconds (default 300 000)") - parser.add_argument( - "-I", - "--task_chain_id", - type=str, - default=None, - help="A unique Id which must be identical for all related putflag/getflag calls and putnoise/getnoise calls", - ) - parser.add_argument("--flag_regex", type=str, default=None, help="A regular expression matched by the flag, used only when running the exploit method") - parser.add_argument( - "--attack_info", type=str, default=None, help="The attack info returned by the corresponding putflag, used only when running the exploit method" - ) - - -def task_message_from_namespace(ns: argparse.Namespace) -> CheckerTaskMessage: - task_chain_id = ns.task_chain_id - method = CheckerMethod(ns.method) - if not task_chain_id: - option = None - if method in (CheckerMethod.PUTFLAG, CheckerMethod.GETFLAG): - option = "flag" - elif method in (CheckerMethod.PUTNOISE, CheckerMethod.GETNOISE): - option = "noise" - elif method == CheckerMethod.HAVOC: - option = "havoc" - elif method == CheckerMethod.EXPLOIT: - option = "exploit" - else: - raise ValueError(f"Unexpected CheckerMethod: {method}") - task_chain_id = f"{option}_s0_r{ns.related_round_id}_t{ns.team_id}_i{ns.variant_id}" - - flag_hash = None - if method == CheckerMethod.EXPLOIT: - flag_hash = hashlib.sha256(ns.flag.encode()).hexdigest() - - msg = CheckerTaskMessage( - task_id=ns.task_id, - method=method, - address=ns.address, - team_id=ns.team_id, - team_name=ns.team_name, - current_round_id=ns.current_round_id, - related_round_id=ns.related_round_id, - flag=ns.flag if method != CheckerMethod.EXPLOIT else None, - variant_id=ns.variant_id, - timeout=ns.timeout, - round_length=ns.round_length, - task_chain_id=task_chain_id, - flag_regex=ns.flag_regex, - flag_hash=flag_hash, - attack_info=ns.attack_info, - ) - - return msg - - -def json_task_message_from_namespace(ns: argparse.Namespace) -> str: - return jsons.dumps(task_message_from_namespace(ns), use_enum_name=False, key_transformer=jsons.KEY_TRANSFORMER_CAMELCASE, strict=True) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Your friendly checker script") - _add_arguments(parser) - ns = parser.parse_args(sys.argv[1:]) - msg = json_task_message_from_namespace(ns) - - result = requests.post(ns.checker_address, data=msg, - headers={"content-type": "application/json"},) - if ns.json: - print(result.text) - else: - if result.ok: - result_msg = jsons.loads(result.content, CheckerResultMessage) - print(result_msg.result) - else: - print(result.status_code) - print(result.text) - -main() diff --git a/checker/local.sh b/checker/local.sh @@ -1,6 +1,7 @@ #!/bin/sh if [ -z "$(docker ps | grep stldoctor-mongo)" ]; then + docker-compose down -v docker-compose up -d stldoctor-mongo fi diff --git a/checker/src/checker.py b/checker/src/checker.py @@ -1,33 +1,47 @@ #!/usr/bin/env python3 -import logging, math, os, random, re, socket, string, struct, subprocess, selectors, time +import logging +import math +import os +import random +import re +import struct +import subprocess + import numpy as np logging.getLogger("faker").setLevel(logging.WARNING) logging.getLogger("_curses").setLevel(logging.CRITICAL) -from enochecker3 import * -from enochecker3.utils import * -from faker import Faker +from asyncio import StreamReader, StreamWriter from io import BytesIO -from stl import mesh - -from typing import ( - Any, - Optional, - Tuple, - Union -) - from logging import LoggerAdapter - -from asyncio import StreamReader, StreamWriter +from typing import Any, Optional, Union, cast + +from enochecker3 import ( + AsyncSocket, + ChainDB, + DependencyInjector, + Enochecker, + GetflagCheckerTaskMessage, + GetnoiseCheckerTaskMessage, + HavocCheckerTaskMessage, + InternalErrorException, + MumbleException, + PutflagCheckerTaskMessage, + PutnoiseCheckerTaskMessage, +) +from enochecker3.utils import FlagSearcher, assert_in +from faker import Faker +from stl import mesh rand = random.SystemRandom() -generic_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmopqrstuvwxyz0123456789-+.!" +generic_alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmopqrstuvwxyz0123456789-+.!" script_path = os.path.dirname(os.path.realpath(__file__)) models_path = f"{script_path}/models" -extra_models = [f"{models_path}/{path}" for path in os.listdir(models_path) if path.endswith(".stl")] +extra_models = [ + f"{models_path}/{path}" for path in os.listdir(models_path) if path.endswith(".stl") +] prompt = b"\r$ " search_truncation_payload = b""" @@ -45,41 +59,33 @@ endsolid test\xff checker = Enochecker("STLDoctor", 9090) app = lambda: checker.app + class Session: def __init__(self, socket: AsyncSocket) -> None: - self.reader: StreamReader = socket[0] - self.writer: StreamWriter = socket[1] - - self.writer._write = self.writer.write - self.writer.write = Session.write.__get__(self.writer) - - self.reader._readuntil = self.reader.readuntil - self.reader.readuntil = Session.readuntil.__get__(self.reader) - - def write(self: StreamWriter, data: Union[str, bytes]) -> None: - self._write(ensure_bytes(data)) - - def readuntil(self: StreamReader, data: Union[str, bytes]) -> bytes: - return self._readuntil(ensure_bytes(data)) + socket_tuple = cast(tuple[StreamReader, StreamWriter], socket) + self.reader = socket_tuple[0] + self.writer = socket_tuple[1] async def __atexit__(self) -> None: await self.close() async def prepare(self) -> None: - await self.reader.readuntil(prompt) + await self.reader.readuntil(prompt) # skip welcome banner async def close(self) -> None: - self.writer.write("exit\n") + self.writer.write(b"exit\n") await self.writer.drain() - await self.reader.readuntil("bye!") # ensure clean exit + await self.reader.readuntil(b"bye!") # ensure clean exit self.writer.close() await self.writer.wait_closed() + @checker.register_dependency def _get_session(socket: AsyncSocket) -> Session: return Session(socket) -def ensure_bytes(v: Union[str,bytes]) -> bytes: + +def ensure_bytes(v: Union[str, bytes]) -> bytes: if type(v) == bytes: return v elif type(v) == str: @@ -87,68 +93,80 @@ def ensure_bytes(v: Union[str,bytes]) -> bytes: else: raise InternalErrorException("Tried to pass non str/bytes to bytes arg") -def includes_all(resp: bytes, targets: Tuple[bytes, ...]) -> bool: + +def includes_all(resp: bytes, targets: list[bytes]) -> bool: for m in targets: if ensure_bytes(m) not in resp: return False return True -def includes_any(resp: bytes, targets: Tuple[bytes, ...]) -> bool: + +def includes_any(resp: bytes, targets: list[bytes]) -> bool: for m in targets: if ensure_bytes(m) in resp: return True return False -def fakeid(havoc = False) -> bytes: + +def fakeid(havoc: bool = False) -> bytes: if havoc: idlen = rand.randint(10, 40) return bytes([rand.randint(32, 127) for i in range(idlen)]) else: fake = Faker(["en_US"]) - idstr = bytes([ord(c) for c in fake.name().replace(" ","") if c in generic_alphabet][:12]).ljust(10, b".") - idstr += bytes([ord(rand.choice(generic_alphabet)) for i in range(8)]) + idstr = bytes( + [c for c in fake.name().replace(" ", "").encode() if c in generic_alphabet][ + :12 + ] + ).ljust(10, b".") + idstr += bytes([rand.choice(generic_alphabet) for i in range(8)]) return idstr -def fakeids(n: int, **kwargs) -> Tuple[bytes, ...]: - return [fakeid(**kwargs) for i in range(n)] + +def fakeids(n: int, havoc: bool = False) -> list[bytes]: + return [fakeid(havoc) for i in range(n)] + def approx_equal(f1: float, f2: float, precision: int = 2) -> bool: return round(f1, precision) == round(f2, precision) -def reverse_hash(hashstr: Union[str, bytes]): - if type(hashstr) is bytes: - hashstr = hashstr.decode() + +def reverse_hash(hashstr: str) -> bytes: data = subprocess.check_output([f"{script_path}/revhash/revhash", hashstr])[:-1] if data == b"": - raise InternalErrorException(f"Failed to find hash preimage of {hashstr}") + raise InternalErrorException(f"Failed to find hash preimage of {hashstr!r}") return data + def parse_int(intstr: Union[str, bytes]) -> Optional[int]: try: return int(intstr) except: return None + def parse_float(floatstr: Union[str, bytes]) -> Optional[float]: try: return float(floatstr) except: return None + def has_alph(data: Union[str, bytes], alph: Union[str, bytes]) -> bool: return len([v for v in data if v not in alph]) == 0 -def assert_match(data: bytes, pattern: bytes, exception: Exception) -> bytes: + +def assert_match(data: bytes, pattern: bytes, raiser: Any) -> bytes: rem = re.search(pattern, data) if rem is None: - raise exception(f"Expected pattern {pattern} to match {data}") + raise raiser(f"Expected pattern {pattern!r} to match {data!r}") if len(rem.groups()) > 0: return rem.group(1) return rem.group(0) -def genfile_ascii(solidname: str, malformed: bool = None) -> bytes: + +def genfile_ascii(solidname: bytes, malformed: bool = None) -> bytes: indent = bytes([rand.choice(b"\t ") for i in range(rand.randint(1, 4))]) - solidname = ensure_bytes(solidname) facet_count = rand.randint(4, 30) if len(solidname) != 0: @@ -164,7 +182,7 @@ def genfile_ascii(solidname: str, malformed: bool = None) -> bytes: content += indent * 1 + b"facet normal " vs = [[rand.random() for i in range(3)] for k in range(3)] - norm = np.cross(np.subtract(vs[1], vs[0]), np.subtract(vs[2],vs[0])) + norm = np.cross(np.subtract(vs[1], vs[0]), np.subtract(vs[2], vs[0])) norm = norm / np.linalg.norm(norm) content += " ".join([f"{v:.2f}" for v in norm]).encode() + b"\n" @@ -176,7 +194,12 @@ def genfile_ascii(solidname: str, malformed: bool = None) -> bytes: content += indent * 2 + b"outer loop\n" for i in range(3): - content += indent * 3 + b"vertex " + " ".join([f"{v:.2f}" for v in vs[i]]).encode() + b"\n" + content += ( + indent * 3 + + b"vertex " + + " ".join([f"{v:.2f}" for v in vs[i]]).encode() + + b"\n" + ) content += indent * 2 + b"endloop\n" content += indent + b"endfacet\n" @@ -190,12 +213,14 @@ def genfile_ascii(solidname: str, malformed: bool = None) -> bytes: return content + def genfile_bin(solidname: bytes, malformed: bool = None) -> bytes: - solidname = ensure_bytes(solidname) facet_count = rand.randint(4, 30) if len(solidname) > 78: - raise InternalErrorException("Solidname to embed in header is larger than header itself") + raise InternalErrorException( + "Solidname to embed in header is larger than header itself" + ) if solidname != "": content = b"#" + solidname.ljust(78, b"\x00") + b"\x00" else: @@ -209,12 +234,12 @@ def genfile_bin(solidname: bytes, malformed: bool = None) -> bytes: for fi in range(facet_count): vs = [[rand.random() for i in range(3)] for k in range(3)] - norm = np.cross(np.subtract(vs[1], vs[0]), np.subtract(vs[2],vs[0])) + norm = np.cross(np.subtract(vs[1], vs[0]), np.subtract(vs[2], vs[0])) # MALFORM 2: invalid float for norm / vec if malformed == 2: - norm[rand.randint(0,2)] = math.nan - vs[rand.randint(0,2)][rand.randint(0,2)] = math.inf + norm[rand.randint(0, 2)] = math.nan + vs[rand.randint(0, 2)][rand.randint(0, 2)] = math.inf for i in range(3): content += struct.pack("<f", norm[i]) for k in range(3): @@ -228,25 +253,29 @@ def genfile_bin(solidname: bytes, malformed: bool = None) -> bytes: return content -def genfile(solidname: bytes, filetype: str, malformed: Optional[bool] = None) -> bytes: + +def genfile(solidname: bytes, filetype: str, malformed: Optional[Any] = None) -> bytes: if filetype == "ascii": - return genfile_ascii(solidname, malformed = malformed) + return genfile_ascii(solidname, malformed=malformed) elif filetype == "bin": - return genfile_bin(solidname, malformed = malformed) + return genfile_bin(solidname, malformed=malformed) elif filetype == "garbage-tiny": - return bytes([ord(rand.choice(generic_alphabet)) for i in range(rand.randint(3, 8))]) + return bytes([rand.choice(generic_alphabet) for i in range(rand.randint(3, 8))]) elif filetype == "garbage": - return bytes([ord(rand.choice(generic_alphabet)) for i in range(rand.randint(100, 300))]) + return bytes( + [rand.choice(generic_alphabet) for i in range(rand.randint(100, 300))] + ) else: raise InternalErrorException("Invalid file type supplied") + def parse_stlinfo(stlfile: bytes) -> Any: fakefile = BytesIO() fakefile.write(stlfile) fakefile.seek(0) try: name, data = mesh.Mesh.load(fakefile) - meshinfo = mesh.Mesh(data, True, name=name, speedups=True) + meshinfo = mesh.Mesh(data, True, name=name, speedups=True) # type: ignore except Exception as e: raise InternalErrorException(f"Unable to parse generated STL file: {e}") bmin = [math.inf for i in range(3)] @@ -256,29 +285,35 @@ def parse_stlinfo(stlfile: bytes) -> Any: for p in meshinfo.points: for k in range(3): for i in range(3): - bmin[k] = min(bmin[k], float(p[3*i+k])) - bmax[k] = max(bmax[k], float(p[3*i+k])) + bmin[k] = min(bmin[k], float(p[3 * i + k])) + bmax[k] = max(bmax[k], float(p[3 * i + k])) info = { "points": meshinfo.points, "bb_origin": bmin, "bb_size": [bmax[i] - bmin[i] for i in range(3)], "size": len(stlfile), - "triangle_count": len(meshinfo.points) + "triangle_count": len(meshinfo.points), } return info -async def getdb(db: ChainDB, key: str) -> Tuple[Any, ...]: + +async def getdb(db: ChainDB, key: str) -> tuple[Any, ...]: try: return await db.get(key) except KeyError: - raise MumbleException("Could not retrieve necessary info for service interaction") + raise MumbleException( + "Could not retrieve necessary info for service interaction" + ) + # SERVICE FUNCTIONS # -async def do_auth(session: Session, logger: LoggerAdapter, authstr: bytes, check: bool = True) -> Optional[bytes]: - authstr = ensure_bytes(authstr) - logger.debug(f"Logging in with {authstr}") - session.writer.write("auth\n") + +async def do_auth( + session: Session, logger: LoggerAdapter, authstr: bytes, check: bool = True +) -> Optional[bool]: + logger.debug(f"Logging in with {authstr!r}") + session.writer.write(b"auth\n") session.writer.write(authstr + b"\n") await session.writer.drain() @@ -286,40 +321,48 @@ async def do_auth(session: Session, logger: LoggerAdapter, authstr: bytes, check resp = await session.reader.readline() if b"ERR:" in resp: if check: - logger.critical(f"Failed to login with {authstr}:\n{resp}") + logger.critical(f"Failed to login with {authstr!r}:\n{resp!r}") raise MumbleException("Authentication not working properly") return None # Also check success message resp += await session.reader.readuntil(prompt) if b"Success!" not in resp: - logger.critical(f"Login with pass {authstr} failed") + logger.critical(f"Login with pass {authstr!r} failed") raise MumbleException("Authentication not working properly") return b"Welcome back" in resp -async def do_list(session: Session, logger: LoggerAdapter, check: bool = True) -> Optional[bytes]: - session.writer.write("list\n") + +async def do_list( + session: Session, logger: LoggerAdapter, check: bool = True +) -> Optional[bytes]: + session.writer.write(b"list\n") await session.writer.drain() resp = await session.reader.readuntil(prompt) # Check for errors if b"ERR:" in resp and b">> " not in resp: if check: - logger.critical(f"Failed to list private files:\n{resp}") + logger.critical(f"Failed to list private files:\n{resp!r}") raise MumbleException("File listing not working properly") return None return resp -async def do_upload(session: Session, logger: LoggerAdapter, modelname: str, stlfile: str, check: bool = True) -> Optional[bytes]: - modelname = ensure_bytes(modelname) +async def do_upload( + session: Session, + logger: LoggerAdapter, + modelname: bytes, + stlfile: bytes, + check: bool = True, +) -> Optional[bytes]: # Upload file - logger.debug(f"Uploading model with name {modelname}") - session.writer.write("upload\n") + logger.debug(f"Uploading model with name {modelname!r}") + session.writer.write(b"upload\n") session.writer.write(modelname + b"\n") - session.writer.write(f"{len(stlfile)}\n") + session.writer.write(f"{len(stlfile)}\n".encode()) session.writer.write(stlfile) await session.writer.drain() @@ -329,7 +372,7 @@ async def do_upload(session: Session, logger: LoggerAdapter, modelname: str, stl resp += await session.reader.readline() if b"ERR:" in resp: if check: - logger.critical(f"Failed to upload model {modelname}:\n{resp}") + logger.critical(f"Failed to upload model {modelname!r}:\n{resp!r}") raise MumbleException("File upload not working properly") await session.reader.readuntil(prompt) return None @@ -337,30 +380,38 @@ async def do_upload(session: Session, logger: LoggerAdapter, modelname: str, stl # Parse ID try: modelid = resp.rsplit(b"!", 1)[0].split(b"with ID ", 1)[1] - if modelid == b"": raise Exception + if modelid == b"": + raise Exception except: - logger.critical(f"Invalid response during upload of {modelname}:\n{resp}") + logger.critical(f"Invalid response during upload of {modelname!r}:\n{resp!r}") raise MumbleException("File upload not working properly") await session.reader.readuntil(prompt) return modelid -async def do_search(session, logger, modelname, download = False, check = True) -> Optional[Tuple[bytes, bytes]]: + +async def do_search( + session: Session, + logger: LoggerAdapter, + modelname: bytes, + download: bool = False, + check: bool = True, +) -> Optional[tuple[bytes, bytes]]: modelname = ensure_bytes(modelname) # Initiate download - logger.debug(f"Retrieving model with name {modelname}") + logger.debug(f"Retrieving model with name {modelname!r}") session.writer.write(b"search " + modelname + b"\n") - session.writer.write("0\n") # first result - session.writer.write("y\n" if download else "n\n") - session.writer.write("q\n") # quit + session.writer.write(b"0\n") # first result + session.writer.write(b"y\n" if download else b"n\n") + session.writer.write(b"q\n") # quit await session.writer.drain() # Check if an error occured line = await session.reader.readline() if b"ERR:" in line: if check: - logger.critical(f"Failed to retrieve model {modelname}:\n{line}") + logger.critical(f"Failed to retrieve model {modelname!r}:\n{line!r}") raise MumbleException("File search not working properly") if b"Couldn't find a matching scan result" in line: # collect all the invalid commands sent after (hacky) @@ -372,16 +423,18 @@ async def do_search(session, logger, modelname, download = False, check = True) return None # read until end of info box - fileinfo = line + await session.reader.readuntil("================== \n") + fileinfo = line + await session.reader.readuntil(b"================== \n") stlfile = b"" - if download: # Parse file contents + if download: # Parse file contents await session.reader.readuntil(b"Here you go.. (") resp = await session.reader.readuntil(b"B)\n") resp = resp[:-3] size = parse_int(resp) if size is None: - raise MumbleException(f"Received invalid download size, response:\n{resp}") + raise MumbleException( + f"Received invalid download size, response:\n{resp!r}" + ) logger.debug(f"Download size: {size}") stlfile = await session.reader.readexactly(size) @@ -392,112 +445,179 @@ async def do_search(session, logger, modelname, download = False, check = True) # CHECK WRAPPERS # -async def check_line(session: Session, logger: LoggerAdapter, context: str): - line = session.reader.readline() + +async def check_line(session: Session, logger: LoggerAdapter, context: str) -> bytes: + line = await session.reader.readline() if b"ERR:" in line: logger.critical(f"{context}: Unexpected error message\n") raise MumbleException("Service returned error during valid interaction") return line -async def check_listed(session: Session, logger: LoggerAdapter, includes: Tuple[bytes, ...]) -> bytes: - resp = await do_list(session, logger, check = True) + +async def check_listed( + session: Session, logger: LoggerAdapter, includes: list[bytes] +) -> bytes: + resp = await do_list(session, logger, check=True) + assert resp is not None if not includes_all(resp, includes): - logger.critical(f"Failed to find {includes} in listing:\n{resp}") + logger.critical(f"Failed to find {includes} in listing:\n{resp!r}") raise MumbleException("File listing not working properly") return resp -async def check_not_listed(session: Session, logger: LoggerAdapter, excludes: Tuple[bytes, ...], fail: bool = False) -> bytes: - resp = await do_list(session, logger, check = False) - if fail and resp: - logger.critical(f"Expected list to fail, but returned:\n{resp}") - raise MumbleException("File listing not working properly") - if not fail and not resp: - logger.critical(f"List failed unexpectedly:\n{resp}") - raise MumbleException("File listing not working properly") - if resp and includes_any(resp, excludes): - logger.critical(f"Unexpectedly found one of {excludes} in listing:\n{resp}") + +async def check_not_listed( + session: Session, + logger: LoggerAdapter, + excludes: list[bytes], + fail: bool = False, +) -> Optional[bytes]: + resp = await do_list(session, logger, check=False) + if resp is not None: + if fail: + logger.critical(f"Expected list to fail, but returned:\n{resp!r}") + raise MumbleException("File listing not working properly") + if includes_any(resp, excludes): + logger.critical( + f"Unexpectedly found one of {excludes} in listing:\n{resp!r}" + ) + raise MumbleException("File listing not working properly") + elif not fail: + logger.critical(f"list failed unexpectedly:\n{resp!r}") raise MumbleException("File listing not working properly") return resp -async def check_in_search(session: Session, logger: LoggerAdapter, modelname: bytes, includes: Tuple[bytes], download: bool = False) -> Tuple[bytes, bytes]: - info, stlfile = await do_search(session, logger, modelname, download, check = True) - if not includes_all(info + stlfile, includes): - logger.critical(f"Retrieved info for {modelname} is missing {includes}: {resp}") + +async def check_in_search( + session: Session, + logger: LoggerAdapter, + modelname: bytes, + includes: list[bytes], + download: bool = False, +) -> tuple[bytes, bytes]: + resp = await do_search(session, logger, modelname, download, check=True) + assert resp is not None + if not includes_all(resp[0] + resp[1], includes): + logger.critical( + f"Retrieved info for {modelname!r} is missing {includes}: {resp[0]+resp[1]!r}" + ) raise MumbleException("File search not working properly") - return info, stlfile + return resp -async def check_not_in_search(session: Session, logger: LoggerAdapter, modelname: bytes, excludes: Tuple[bytes], download: bool = False, fail: bool = False) -> Tuple[bytes, bytes]: - resp = await do_search(session, logger, modelname, download, check = False) - if resp: + +async def check_not_in_search( + session: Session, + logger: LoggerAdapter, + modelname: bytes, + excludes: list[bytes], + download: bool = False, + fail: bool = False, +) -> Optional[tuple[bytes, bytes]]: + resp = await do_search(session, logger, modelname, download, check=False) + if resp is not None: combined = resp[0] + resp[1] - if fail and resp: - logger.critical("Search for {modelname} succeeded unexpectedly:\n{combined}") - raise MumbleException("File search not working properly") - if not fail and not resp: - logger.critical(f"Search for {modelname} failed unexpectedly") - raise MumbleException("File search not working properly") - if resp and includes_any(combined, excludes): - logger.critical(f"Unexpectedly {modelname} info contains one of {includes}: {combined}") + if fail: + logger.critical( + "Search for {modelname!r} succeeded unexpectedly:\n{combined!r}" + ) + raise MumbleException("File search not working properly") + if includes_any(combined, excludes): + logger.critical( + f"Unexpectedly {modelname!r} info contains one of {excludes}: {combined!r}" + ) + raise MumbleException("File search not working properly") + elif not fail: + logger.critical(f"Search for {modelname!r} failed unexpectedly") raise MumbleException("File search not working properly") return resp + def check_hash(hashstr: bytes) -> None: if not has_alph(hashstr, b"0123456789abcdef"): raise MumbleException("Invalid model hash format returned") -def check_stlinfo(logger: LoggerAdapter, resp: bytes, ref_info: Any, ref_modelid: Optional[bytes] = None, - ref_modelname: Optional[bytes] = None, ref_solidname: Optional[bytes] = None) -> None: - def logthrow(msg): + +def check_stlinfo( + logger: LoggerAdapter, + resp: bytes, + ref_info: Any, + ref_modelid: Optional[bytes] = None, + ref_modelname: Optional[bytes] = None, + ref_solidname: Optional[bytes] = None, +) -> None: + def logthrow(msg: str) -> None: logger.critical(msg) raise MumbleException("STL parsing not working properly") size = parse_int(assert_match(resp, b"File Size: (.*)\n", MumbleException)) if not size or size != ref_info["size"]: - logthrow(f"STL info returned no / invalid file size: {size} != {ref_info['size']}") + logthrow( + f"STL info returned no / invalid file size: {size} != {ref_info['size']}" + ) - triangle_count = parse_int(assert_match(resp, b"Triangle Count: (.*)\n", MumbleException)) + triangle_count = parse_int( + assert_match(resp, b"Triangle Count: (.*)\n", MumbleException) + ) if not triangle_count or triangle_count != ref_info["triangle_count"]: - logthrow(f"STL info returned no / invalid triangle count: {triangle_count} != {ref_info['triangle_count']}") + logthrow( + f"STL info returned no / invalid triangle count: {triangle_count} != {ref_info['triangle_count']}" + ) bb_size_str = assert_match(resp, b"Bounding Box Size: (.*)\n", MumbleException) bb_size = [parse_float(v) for v in bb_size_str.split(b" x ")] - if None in bb_size: - logthrow(f"STL info returned invalid bounding box size: {bb_size_str}") - if False in [approx_equal(bb_size[i], ref_info["bb_size"][i]) for i in range(3)]: - logthrow(f"Bounding box size doesnt match: (REF) {ref_info['bb_size']} {bb_size}") + for i in range(3): + val = bb_size[i] + if val is None: + logthrow(f"STL info returned invalid bounding box size: {bb_size_str!r}") + elif not approx_equal(val, ref_info["bb_size"][i]): + logthrow( + f"Bounding box size doesnt match: (REF) {ref_info['bb_size']} {bb_size}" + ) bb_origin_str = assert_match(resp, b"Bounding Box Origin: (.*)\n", MumbleException) bb_origin = [parse_float(v) for v in bb_origin_str.split(b" x ")] - if None in bb_origin: - logthrow(f"STL info returned invalid bounding box origin: {bb_origin_str}") - if False in [approx_equal(bb_origin[i], ref_info["bb_origin"][i]) for i in range(3)]: - logthrow(f"Bounding box origin doesnt match: (REF) {ref_info['bb_origin']} {bb_origin}") - - triangle_count = parse_float(assert_match(resp, b"Triangle Count: (.*)\n", MumbleException)) + for i in range(3): + val = bb_origin[i] + if val is None: + logthrow( + f"STL info returned invalid bounding box origin: {bb_origin_str!r}" + ) + elif not approx_equal(val, ref_info["bb_origin"][i]): + logthrow( + f"Bounding box origin doesnt match: (REF) {ref_info['bb_origin']} {bb_origin}" + ) + + triangle_count = parse_int( + assert_match(resp, b"Triangle Count: (.*)\n", MumbleException) + ) if triangle_count is None or triangle_count != ref_info["triangle_count"]: - logthrow(f"Triangle count {triangle_count} doesnt match expected: {ref_info['triangle_count']}") + logthrow( + f"Triangle count {triangle_count} doesnt match expected: {ref_info['triangle_count']}" + ) if ref_modelname: modelname = assert_match(resp, b"Model Name: (.*)\n", MumbleException) if modelname != ref_modelname: - logthrow(f"Got modelname {modelname}, expected {ref_modelname}") + logthrow(f"Got modelname {modelname!r}, expected {ref_modelname!r}") if ref_modelid: modelid = assert_match(resp, b"Model ID: (.*)\n", MumbleException) if modelid != ref_modelid: - logthrow(f"Got modelid {modelid}, expected {ref_modelid}") + logthrow(f"Got modelid {modelid!r}, expected {ref_modelid!r}") if ref_solidname: solidname = assert_match(resp, b"Solid Name: (.*)\n", MumbleException) if solidname != ref_solidname: - logthrow(f"Got solidname {solidname}, expected {ref_solidname}") + logthrow(f"Got solidname {solidname!r}, expected {ref_solidname!r}") # TEST METHODS # -async def test_good_upload(di: DependencyInjector, filetype: str, register: str) -> None: - solidname = fakeid(havoc = (filetype == "bin")) # ascii stl cant handle havoc - modelname, authstr = fakeids(2, havoc = True) + +async def test_good_upload( + di: DependencyInjector, filetype: str, register: bool +) -> None: + solidname = fakeid(havoc=(filetype == "bin")) # ascii stl cant handle havoc + modelname, authstr = fakeids(2, havoc=True) stlfile = genfile(solidname, filetype) ref_info = parse_stlinfo(stlfile) logger = await di.get(LoggerAdapter) @@ -506,240 +626,326 @@ async def test_good_upload(di: DependencyInjector, filetype: str, register: str) session = await di.get(Session) await session.prepare() if register: - await do_auth(session, logger, authstr) - modelid = await do_upload(session, logger, modelname, stlfile) + await do_auth(session, logger, authstr, check=True) + modelid = await do_upload(session, logger, modelname, stlfile, check=True) + assert modelid is not None check_hash(modelid) expected = [modelname, solidname, stlfile, modelid] - info, stlfile = await check_in_search(session, logger, modelname, expected, download = True) - check_stlinfo(logger, info, ref_info, ref_modelname = modelname, - ref_modelid = modelid, ref_solidname = solidname) + info, stlfile = await check_in_search( + session, logger, modelname, expected, download=True + ) + check_stlinfo( + logger, + info, + ref_info, + ref_modelname=modelname, + ref_modelid=modelid, + ref_solidname=solidname, + ) if register: - resp = await check_listed(session, logger, [modelname, modelid + b"-"]) + await check_listed(session, logger, [modelname, modelid + b"-"]) await session.close() # Try getting file from a new session session = await di.get(Session) await session.prepare() if register: - await check_not_in_search(session, logger, modelname, expected, download = True, fail = True) - await do_auth(session, logger, authstr) - info, stlfile = await check_in_search(session, logger, modelname, expected, download = True) - check_stlinfo(logger, info, ref_info, ref_modelname = modelname, - ref_modelid = modelid, ref_solidname = solidname) + await check_not_in_search( + session, logger, modelname, expected, download=True, fail=True + ) + await do_auth(session, logger, authstr, check=True) + info, stlfile = await check_in_search( + session, logger, modelname, expected, download=True + ) + check_stlinfo( + logger, + info, + ref_info, + ref_modelname=modelname, + ref_modelid=modelid, + ref_solidname=solidname, + ) await check_listed(session, logger, [modelname, modelid + b"-"]) else: - info, stlfile = await check_in_search(session, logger, modelname, expected, download = True) - check_stlinfo(logger, info, ref_info, ref_modelname = modelname, - ref_modelid = modelid, ref_solidname = solidname) + info, stlfile = await check_in_search( + session, logger, modelname, expected, download=True + ) + check_stlinfo( + logger, + info, + ref_info, + ref_modelname=modelname, + ref_modelid=modelid, + ref_solidname=solidname, + ) await session.close() + async def test_bad_upload(di: DependencyInjector, filetype: str, variant: int) -> None: modelname, solidname = fakeids(2) - stlfile = genfile(solidname, filetype, malformed = variant) + stlfile = genfile(solidname, filetype, malformed=variant) logger = await di.get(LoggerAdapter) # Ensure a malformed file causes an error session = await di.get(Session) await session.prepare() - if await do_upload(session, logger, modelname, stlfile, check = False): - logger.critical(f"Able to upload malformed file:\n{stlfile}") + if await do_upload(session, logger, modelname, stlfile, check=False): + logger.critical(f"Able to upload malformed file:\n{stlfile!r}") raise MumbleException("Upload validation not working properly") await session.close() -async def test_search(di: DependencyInjector, registered = False) -> None: + +async def test_search(di: DependencyInjector, registered: bool = False) -> None: solidname, modelname, authstr = fakeids(3) - stlfile = genfile(solidname, "bin") logger = await di.get(LoggerAdapter) # Ensure searching for a file that doesnt exist causes an error session = await di.get(Session) await session.prepare() if registered: - await do_auth(session, logger, authstr) - if await do_search(session, logger, modelname, download = False, check = None): - logger.critical(f"Search for file that shouldn't exist returned a file:\n{resp}") + await do_auth(session, logger, authstr, check=True) + if resp := await do_search(session, logger, modelname, download=False, check=False): + assert resp is not None + logger.critical( + f"Search for file that shouldn't exist succeeded:\n{resp[0]+resp[1]!r}" + ) raise MumbleException("File search not working properly") await session.close() -async def test_list(di: DependencyInjector, registered = False) -> None: - solidname, modelname, authstr, authstr2 = fakeids(4) - stlfile = genfile(solidname, "bin") - logger = await di.get(LoggerAdapter) - - if registered: - # Create a session and upload a file - session = await di.get(Session) - await session.prepare() - await do_auth(session, logger, authstr) - modelid = await do_upload(session, logger, modelname, stlfile) - await check_listed(session, logger, [modelname, modelid + b"-"]) - await session.close() - - # Ensure that list for another user does not return first users files - session = await di.get(Session) - await session.prepare() - if await do_auth(session, logger, authstr2): - logger.critical("New authstr {authstr2} already has a user dir! Hash collision?!") - raise MumbleException("User authentication not working properly") - await check_not_listed(session, logger, [modelid, modelname]) - await session.close() - else: - # Ensure that list does not work for unregistered users - session = await di.get(Session) - await session.prepare() - if await do_list(session, logger, check = False): - logger.critical("Unregistered user can run list without ERR!") - raise MumbleException("User authentication not working properly") - await session.close() # CHECKER METHODS # + @checker.putflag(0) -async def putflag_guest(task: PutflagCheckerTaskMessage, di: DependencyInjector) -> None: - modelname: bytes = fakeid() - logger: LoggerAdapter = await di.get(LoggerAdapter) - db: ChainDB = await di.get(ChainDB) +async def putflag_guest( + task: PutflagCheckerTaskMessage, di: DependencyInjector +) -> None: + modelname = fakeid() + logger = await di.get(LoggerAdapter) + db = await di.get(ChainDB) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() - stlfile: bytes = genfile(task.flag, "ascii") - modelid: bytes = await do_upload(session, logger, modelname, stlfile) + stlfile = genfile(task.flag.encode(), "ascii") + modelid = await do_upload(session, logger, modelname, stlfile, check=True) + assert modelid is not None await session.close() await db.set("flag-0-info", (modelname, modelid)) + @checker.putflag(1) -async def putflag_private(task: PutflagCheckerTaskMessage, di: DependencyInjector) -> None: +async def putflag_private( + task: PutflagCheckerTaskMessage, di: DependencyInjector +) -> None: modelname, authstr = fakeids(2) - stlfile: bytes = genfile(task.flag, "bin") - logger: LoggerAdapter = await di.get(LoggerAdapter) - db: ChainDB = await di.get(ChainDB) + logger = await di.get(LoggerAdapter) + stlfile = genfile(task.flag.encode(), "bin") + db = await di.get(ChainDB) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() - await do_auth(session, logger, authstr) - modelid: bytes = await do_upload(session, logger, modelname, stlfile) + await do_auth(session, logger, authstr, check=True) + modelid = await do_upload(session, logger, modelname, stlfile, check=True) + assert modelid is not None await session.close() await db.set("flag-1-info", (modelname, modelid, authstr)) + @checker.getflag(0) -async def getflag_guest(task: GetflagCheckerTaskMessage, di: DependencyInjector) -> None: - db: ChainDB = await di.get(ChainDB) +async def getflag_guest( + task: GetflagCheckerTaskMessage, di: DependencyInjector +) -> None: + db = await di.get(ChainDB) modelname, modelid = await getdb(db, "flag-0-info") - logger: LoggerAdapter = await di.get(LoggerAdapter) + logger = await di.get(LoggerAdapter) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() - stlinfo, stlfile = await do_search(session, logger, modelname, download = True) - assert_in(task.flag.encode(), stlinfo, "Flag is missing from stl info") - assert_in(task.flag.encode(), stlfile, "Flag is missing from stl file") + resp = await do_search(session, logger, modelname, download=True, check=True) + assert resp is not None + assert_in(task.flag.encode(), resp[0], "Flag is missing from stl info") + assert_in(task.flag.encode(), resp[1], "Flag is missing from stl file") await session.close() + @checker.getflag(1) -async def getflag_private(task: GetflagCheckerTaskMessage, di: DependencyInjector) -> None: - db: ChainDB = await di.get(ChainDB) +async def getflag_private( + task: GetflagCheckerTaskMessage, di: DependencyInjector +) -> None: + db = await di.get(ChainDB) modelname, modelid, authstr = await getdb(db, "flag-1-info") logger = await di.get(LoggerAdapter) session = await di.get(Session) await session.prepare() - await do_auth(session, logger, authstr) - stlinfo, stlfile = await do_search(session, logger, modelname, download = True) - assert_in(task.flag.encode(), stlinfo, "Flag is missing from stl info") - assert_in(task.flag.encode(), stlfile, "Flag is missing from stl file") - resp = await do_list(session, logger) - assert_in(task.flag.encode(), resp, "Flag is missing from list") + await do_auth(session, logger, authstr, check=True) + search_resp = await do_search(session, logger, modelname, download=True, check=True) + assert search_resp is not None + assert_in(task.flag.encode(), search_resp[0], "Flag is missing from stl info") + assert_in(task.flag.encode(), search_resp[1], "Flag is missing from stl file") + list_resp = await do_list(session, logger, check=True) + assert list_resp is not None + assert_in(task.flag.encode(), list_resp, "Flag is missing from list") await session.close() + @checker.putnoise(0, 1) -async def putnoise_guest_ascii(task: PutnoiseCheckerTaskMessage, di: DependencyInjector) -> None: +async def putnoise_guest_ascii( + task: PutnoiseCheckerTaskMessage, di: DependencyInjector +) -> None: modelname, solidname = fakeids(2) - logger: LoggerAdapter = await di.get(LoggerAdapter) - db: ChainDB = await di.get(ChainDB) + logger = await di.get(LoggerAdapter) + db = await di.get(ChainDB) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() stlfile = genfile(solidname, "ascii" if task.variant_id == 0 else "bin") - modelid = await do_upload(session, logger, modelname, stlfile) + modelid = await do_upload(session, logger, modelname, stlfile, check=True) await session.close() - await db.set(f"noise-{task.variant_id}-info", (modelid, modelname, solidname, stlfile)) + await db.set( + f"noise-{task.variant_id}-info", (modelid, modelname, solidname, stlfile) + ) + @checker.putnoise(2, 3) -async def putnoise_priv_ascii(task: PutnoiseCheckerTaskMessage, di: DependencyInjector) -> None: +async def putnoise_priv_ascii( + task: PutnoiseCheckerTaskMessage, di: DependencyInjector +) -> None: modelname, solidname, authstr = fakeids(3) - logger: LoggerAdapter = await di.get(LoggerAdapter) - db: ChainDB = await di.get(ChainDB) + logger = await di.get(LoggerAdapter) + db = await di.get(ChainDB) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() stlfile = genfile(solidname, "ascii" if task.variant_id == 0 else "bin") - await do_auth(session, logger, authstr) - modelid = await do_upload(session, logger, modelname, stlfile) + await do_auth(session, logger, authstr, check=True) + modelid = await do_upload(session, logger, modelname, stlfile, check=True) await session.close() - await db.set(f"noise-{task.variant_id}-info", (modelid, modelname, solidname, stlfile, authstr)) + await db.set( + f"noise-{task.variant_id}-info", + (modelid, modelname, solidname, stlfile, authstr), + ) + @checker.getnoise(0, 1) -async def getnoise_guest_ascii(task: GetnoiseCheckerTaskMessage, di: DependencyInjector) -> None: - db: ChainDB = await di.get(ChainDB) - modelid, modelname, solidname, stlfile = await getdb(db, f"noise-{task.variant_id}-info") - logger: LoggerAdapter = await di.get(LoggerAdapter) +async def getnoise_guest_ascii( + task: GetnoiseCheckerTaskMessage, di: DependencyInjector +) -> None: + db = await di.get(ChainDB) + modelid, modelname, solidname, stlfile = await getdb( + db, f"noise-{task.variant_id}-info" + ) + logger = await di.get(LoggerAdapter) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() - await check_in_search(session, logger, modelname, [modelname, solidname, stlfile, modelid], download = True) + await check_in_search( + session, + logger, + modelname, + [modelname, solidname, stlfile, modelid], + download=True, + ) await session.close() + @checker.getnoise(2, 3) -async def getnoise_priv_ascii(task: GetnoiseCheckerTaskMessage, di: DependencyInjector) -> None: - db: ChainDB = await di.get(ChainDB) - modelid, modelname, solidname, stlfile, authstr = await getdb(db, f"noise-{task.variant_id}-info") - logger: LoggerAdapter = await di.get(LoggerAdapter) +async def getnoise_priv_ascii( + task: GetnoiseCheckerTaskMessage, di: DependencyInjector +) -> None: + db = await di.get(ChainDB) + modelid, modelname, solidname, stlfile, authstr = await getdb( + db, f"noise-{task.variant_id}-info" + ) + logger = await di.get(LoggerAdapter) - session: Session = await di.get(Session) + session = await di.get(Session) await session.prepare() - await do_auth(session, logger, authstr) - await check_in_search(session, logger, modelname, [modelname, solidname, stlfile, modelid], download = True) + await do_auth(session, logger, authstr, check=True) + await check_in_search( + session, + logger, + modelname, + [modelname, solidname, stlfile, modelid], + download=True, + ) await session.close() + @checker.havoc(*range(0, 4)) -async def havoc_good_upload(task: HavocCheckerTaskMessage, di: DependencyInjector) -> None: +async def havoc_good_upload( + task: HavocCheckerTaskMessage, di: DependencyInjector +) -> None: filetype = ["ascii", "bin", "ascii", "bin"] registered = [False, False, True, True] await test_good_upload(di, filetype[task.variant_id], registered[task.variant_id]) + @checker.havoc(*range(4, 12)) -async def havoc_bad_upload(task: HavocCheckerTaskMessage, di: DependencyInjector) -> None: - filetype = ["ascii", "ascii", "ascii", "bin", "bin", "bin", "garbage", "garbage-tiny"] +async def havoc_bad_upload( + task: HavocCheckerTaskMessage, di: DependencyInjector +) -> None: + filetype = [ + "ascii", + "ascii", + "ascii", + "bin", + "bin", + "bin", + "garbage", + "garbage-tiny", + ] upload_variant = [1, 2, 3, 1, 2, 3, 1, 1] - await test_bad_upload(di, filetype[task.variant_id - 4], upload_variant[task.variant_id - 4]) + index = task.variant_id - 4 + await test_bad_upload(di, filetype[index], upload_variant[index]) + @checker.havoc(12, 13) -async def havoc_test_search(task: HavocCheckerTaskMessage, di: DependencyInjector) -> None: +async def havoc_test_search( + task: HavocCheckerTaskMessage, di: DependencyInjector +) -> None: await test_search(di, task.variant_id == 12) -@checker.havoc(14, 15) -async def havoc_test_list(task: HavocCheckerTaskMessage, di: DependencyInjector) -> None: - await test_list(di, task.variant_id == 14) -@checker.havoc(16) -async def havoc_fluff_upload(task: HavocCheckerTaskMessage, di: DependencyInjector) -> None: - if len(extra_models) == 0: return +@checker.havoc(14) +async def havoc_test_list_unregistered(di: DependencyInjector) -> None: + logger = await di.get(LoggerAdapter) + + # Ensure that list does not work for unregistered users + session = await di.get(Session) + await session.prepare() + resp = await do_list(session, logger, check=False) + if resp is not None: + logger.critical("Unregistered user can run list without ERR!") + raise MumbleException("User authentication not working properly") + await session.close() + + +@checker.havoc(15) +async def havoc_fluff_upload(di: DependencyInjector) -> None: + if len(extra_models) == 0: + return model = rand.choice(extra_models) - modelname = os.path.basename(model).split("-")[0] - modelname += "".join([rand.choice("0123456789") for i in range(5)]) + modelname = os.path.basename(model).split("-")[0].encode() + modelname += bytes([rand.choice(b"0123456789") for i in range(5)]) stlfile = open(model, "rb").read() logger = await di.get(LoggerAdapter) session = await di.get(Session) await session.prepare() - modelid = await do_upload(session, logger, modelname, stlfile) - await check_in_search(session, logger, modelname, [modelname, modelid, stlfile], download = True) + modelid = await do_upload(session, logger, modelname, stlfile, check=True) + assert modelid is not None + await check_in_search( + session, logger, modelname, [modelname, modelid, stlfile], download=True + ) + @checker.exploit(0) -async def exploit_prefix_truncation(di: DependencyInjector): +async def exploit_prefix_truncation(di: DependencyInjector) -> bytes: modelname = fakeid() logger = await di.get(LoggerAdapter) searcher = await di.get(FlagSearcher) @@ -747,35 +953,46 @@ async def exploit_prefix_truncation(di: DependencyInjector): # Upload evil file for parse via search for hash truncation session = await di.get(Session) await session.prepare() - logger.debug(f"Uploading evil file for hash truncation") - modelid = await do_upload(session, logger, modelname, stlfile = search_truncation_payload) - info, contents = await do_search(session, logger, modelname, download = False) - session.writer.write("search last\n") + logger.debug("Uploading evil file for hash truncation") + await do_upload( + session, logger, modelname, stlfile=search_truncation_payload, check=True + ) + search_resp = await do_search( + session, logger, modelname, download=False, check=True + ) + assert search_resp is not None + info, contents = search_resp + session.writer.write(b"search last\n") await session.writer.drain() - resp = await session.reader.readuntil("? ") - filelist = [l.strip().split(b" : ")[1] for l in resp.split(b"\n") if b" : " in l] + filelist_resp = await session.reader.readuntil(b"? ") + filelist = [ + l.strip().split(b" : ")[1] for l in filelist_resp.split(b"\n") if b" : " in l + ] if len(filelist) == 0: raise MumbleException("Failed to list files through search") # Use it to enumerate other files and grab contents - logger.debug("Targets:\n" + "\n".join([" - " + l.decode("latin1") for l in filelist])) + logger.debug( + "Targets:\n" + "\n".join([" - " + l.decode("latin1") for l in filelist]) + ) for i, fhash in enumerate(filelist): logger.debug(f"Retrieving file {fhash} at index {i}") - session.writer.write(f"{i}\nn\n") + session.writer.write(f"{i}\nn\n".encode()) await session.writer.drain() - resp = await session.reader.readuntil("==================") - resp += await session.reader.readuntil("? ") - if flag := searcher.search_flag(resp.decode("latin1")): + filelist_resp = await session.reader.readuntil(b"==================") + filelist_resp += await session.reader.readuntil(b"? ") + if flag := searcher.search_flag(filelist_resp.decode("latin1")): return flag # Done! - session.writer.write("q\n") + session.writer.write(b"q\n") await session.writer.drain() await session.reader.readuntil(prompt) await session.close() raise MumbleException("Exploit for flagstore 1 failed") + @checker.exploit(1) async def exploit_hash_overflow(di: DependencyInjector) -> None: logger = await di.get(LoggerAdapter) @@ -787,7 +1004,7 @@ async def exploit_hash_overflow(di: DependencyInjector) -> None: session.writer.write(b"search \xff\xff\xff\xff\xff0000000000000000\n") await session.writer.drain() await session.reader.readuntil(prompt) - session.writer.write("auth\n") + session.writer.write(b"auth\n") await session.writer.drain() resp = await session.reader.readuntil(prompt) if b"Already logged in!" not in resp: @@ -797,20 +1014,20 @@ async def exploit_hash_overflow(di: DependencyInjector) -> None: resp = await do_list(session, logger, check=False) if not resp: raise MumbleException("") - logger.debug("List response: " + str(resp)) + logger.debug("list response: " + str(resp)) users = [l.split(b" .")[1] for l in resp.split(b"\n") if b">> ." in l] await session.close() # Login as each private user for userhash in users: # Find preimage of user hash - logger.debug(f"Logging in as user with id {userhash}") - authstr = reverse_hash(userhash) + logger.debug(f"Logging in as user with id {userhash!r}") + authstr = reverse_hash(userhash.decode()) # Authenticate and check if the user is new session = await di.get(Session) await session.prepare() - if not await do_auth(session, logger, authstr): + if not await do_auth(session, logger, authstr, check=True): await session.close() # We dont raise an exception, because it could be that user dir was cleaned # up just before we logged in, not necessarily because of an invalid prehash. @@ -818,16 +1035,19 @@ async def exploit_hash_overflow(di: DependencyInjector) -> None: # an exception will be raised later anways... continue - # List all private files of user - resp = await do_list(session, logger) + # list all private files of user + resp = await do_list(session, logger, check=True) await session.close() # Search for flag in solid names - solidnames = b"\n".join([l.split(b": ", 1)[1] for l in resp.split(b"\n") if b"Solid Name: " in l]) + solidnames = b"\n".join( + [l.split(b": ", 1)[1] for l in resp.split(b"\n") if b"Solid Name: " in l] + ) if flag := searcher.search_flag(solidnames.decode("latin1")): return flag raise MumbleException("Exploit for flagstore 2 failed") + if __name__ == "__main__": - checker.run(port = 9091) + checker.run(port=9091) diff --git a/checker/test.sh b/checker/test.sh @@ -1,115 +0,0 @@ -#!/bin/bash - -ipstr="$1" - -SCRIPTPATH="$(dirname $(readlink -f "$0"))" -cd "$SCRIPTPATH" -export REVHASH_PATH="$SCRIPTPATH/src/revhash/revhash" - -nop() { :; } - -splitmsg() { - python3 -c " -import json,sys - -try: - instr = sys.stdin.read().strip() - jres = json.loads(instr) - print(jres['result']) - print(jres['message']) -except: - print('INVALID') - print('INVALID') - print('FAIL:', instr, file=sys.stderr) - " || nop -} - -taskid="" -try() { - cmd="$1" - pid=$BASHPID - tmpfile="/tmp/checker-log-$pid" - [ -e "$tmpfile" ] && rm "$tmpfile" - if [ $# -lt 2 ]; then - variant=0 - else - variant=$2 - fi - taskid="$pid" - if [ ! -z "$REMOTE" ]; then - python3 enoreq.py -j True -A http://localhost:9091 -a $REMOTE \ - --flag ENOTESTFLAG123= --flag_regex 'ENO.*=' -i $taskid \ - -v "$variant" -x 4000 ${@:3} "$cmd" > "$tmpfile" - res="$(cat $tmpfile | splitmsg | head -n1)" - else - python3 src/checker.py -j run -v "$variant" -x 4000 \ - --flag ENOTESTFLAG123= --flag_regex 'ENO.*=' -i $taskid \ - ${@:3} "$cmd" > "$tmpfile" - res="$(cat $tmpfile | grep -a -o 'Result: .*' | tail -n1 | cut -d' ' -f2)" - fi - if [ "$res" != "OK" ]; then - newfile="fails/err-$pid" - ( - echo "METHOD $@" - echo "RESULT $res" - echo "TASK $taskid" - cat "$tmpfile" - if [ ! -z "$REMOTE" -a -e "$ENOLOGMESSAGE_PARSER" ]; then - docker-compose logs --tail=2000 | grep '"taskId": '$taskid | $ENOLOGMESSAGE_PARSER - fi - ) > "$newfile" - echo -ne "Executing $cmd with variant $variant.. $res (TASK: $taskid)\n" - return 1 - else - echo -ne "Executing $cmd with variant $variant.. $res (TASK: $taskid)\n" - return 0 - fi -} - -try-all() { - set -e - - try putflag 0 - try getflag 0 - - try putflag 1 - try getflag 1 - - try putnoise 0 - try getnoise 0 - - try putflag 1 - try getflag 1 - - for i in $(seq 0 15); do - try havoc $i - done - - try exploit 0 - try exploit 1 -} - -one-of() { - for arg in ${@:2}; do - [ "$1" == "$arg" ] && return 0 - done - return 1 -} - -if one-of "$1" putflag getflag putnoise getnoise havoc exploit; then - try $@ - [ $? -ne 0 -a -e "$EDITOR" ] && "$EDITOR" "/tmp/checker-log-$taskid" -elif [ "$1" == "test-exploits" ]; then - try exploit 0 - try exploit 1 -elif [ "$1" == "stress-test" ]; then - mkdir -p fails - count=${2:-100} - for i in $(seq $count); do - try ${3:-exploit} ${4:-0} & - done -else - try-all -fi - -exit 0