summaryrefslogtreecommitdiffstats
path: root/meta/solve.py
blob: ef5842a564f6e20edca78c6262561e5eab5e11eb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from pwn import *

import random
import subprocess
import sys
import time

# MT19937 (32-bit Mersenne Twister) parameters:
#   W - word size in bits          N - number of words of state
#   M - offset of the "middle" word used by the twist recurrence
#   R - split point: a twisted word takes its top W - R bits from one state
#       word and its low R bits from the next
W, N, M, R = 32, 624, 397, 31
A = 0x9908B0DF  # MT19937 twist-matrix coefficient "a"

w_full = (1 << W) - 1        # every bit of a W-bit word
w_lower = (1 << R) - 1       # the low R bits
w_upper = w_full ^ w_lower   # the remaining high W - R bits

def _mask_lower(n):
    return (1 << n) - (1 << 0)

def mask_lower(bits, n, shl):
    """Keep only the low *n* bits of *bits*, then shift them left by *shl*."""
    kept = bits & _mask_lower(n)
    return kept << shl

def _mask_upper(n):
    """Return a W-bit mask with only the top *n* bits set."""
    low_zero_bits = W - n
    return (1 << W) - (1 << low_zero_bits)

def mask_upper(bits, n, shr):
    """Keep only the top *n* bits of the W-bit word *bits*, then shift them right by *shr*."""
    kept = bits & _mask_upper(n)
    return kept >> shr

def undo_selfxor(bits, mask, shr, shl):
    """Invert ``y = x ^ (((x & mask) << shl) >> shr)`` and return ``x``.

    The inverse is built up iteratively. Output bits that received no xor
    already equal the corresponding input bits. Each pass of the loop uses
    the input bits recovered so far to strip the xor off the output
    positions those bits were shifted onto. Only ``rec`` depends on the
    value of ``bits``; the loop's control flow depends solely on ``mask``,
    ``shr`` and ``shl``.

    Args:
        bits: the transformed word ``y``.
        mask: the input bits that were shifted and xored back in.
        shr: right-shift amount of the forward transform.
        shl: left-shift amount of the forward transform.

    Returns:
        The recovered input word ``x``.

    Raises:
        AssertionError: (via ``assert``, so only without ``-O``) if the
            affected positions do not fit in ``w_full``, or if a pass can
            recover no new bits.
    """
    # Output positions that had a shifted input bit xored into them.
    dirty = (mask << shl) >> shr
    # Positions whose input value is known; initially the untouched ones,
    # where the output bit still equals the input bit.
    clean = w_full ^ dirty
    # The affected positions must all lie inside one W-bit word.
    assert(dirty == (dirty & w_full))
    rec = bits & clean
    while dirty != 0:
        # Known input bits that feed a still-unrecovered output position...
        pre = clean & ((dirty << shr) >> shl)
        # ...and the output positions they land on after the forward shift.
        post = ((pre << shl) >> shr) & w_full
        assert(pre != 0) # each pass must recover new bits, or the loop never ends
        # x[post] = y[post] ^ (x[pre] shifted exactly as the forward transform did)
        rec |= (((rec & pre) << shl) >> shr) ^ (bits & post)
        clean |= post
        dirty &= w_full ^ clean
    return rec

def harden(bits):
    """Temper a raw MT19937 state word into the value the generator outputs.

    The four steps are the standard MT19937 tempering
    (``y ^= y >> 11``, ``y ^= (y << 7) & B``, ``y ^= (y << 15) & C``,
    ``y ^= y >> 18``). Each step first masks the bits that take part, so
    every shift stays within a 32-bit word.
    """
    y = bits
    y ^= (y & 0xFFFFF800) >> 11                  # top 21 bits, shifted right 11
    y ^= ((y & 0x01FFFFFF) << 7) & 0x9D2C5680    # low 25 bits, shifted left 7
    y ^= ((y & 0x0001FFFF) << 15) & 0xEFC60000   # low 17 bits, shifted left 15
    y ^= (y & 0xFFFC0000) >> 18                  # top 14 bits, shifted right 18
    return y

def unharden(bits):
    """Invert ``harden``: recover the raw state word from a tempered output.

    The tempering steps are undone in reverse order. Each step is described
    as ``(mask, shr, shl)`` and handed to ``undo_selfxor``.
    """
    steps = (
        (_mask_upper(W - 18), 18, 0),
        (_mask_lower(W - 15) & (0xefc60000 >> 15), 0, 15),
        (_mask_lower(W - 7) & (0x9d2c5680 >> 7), 0, 7),
        (_mask_upper(W - 11), 11, 0),
    )
    for mask, shr, shl in steps:
        bits = undo_selfxor(bits, mask, shr, shl)
    return bits

# Import-time sanity check: unharden() must exactly undo harden().
val = random.getrandbits(32)
assert val == unharden(harden(val))

# Twist recurrence helpers: derive the next raw state word from earlier ones.
def mul_a(x):
    """Multiply *x* by the twist matrix: ``x >> 1``, additionally xored with ``A`` when *x* is odd."""
    shifted = x >> 1
    if x & 1:
        return shifted ^ A
    return shifted

def gen_next(states):
    """Compute the next raw (untempered) MT19937 state word.

    With ``k = len(states)`` this evaluates the twist recurrence
    ``states[k - N + M] ^ mul_a((states[k - N] & w_upper) | (states[k - N + 1] & w_lower))``.
    Only those three entries are read, so other positions may hold
    placeholders (e.g. ``None`` for outputs that were never observed).

    Args:
        states: sequence of previously generated raw state words, oldest first.

    Returns:
        The raw state word at index ``len(states)``.

    Raises:
        ValueError: if fewer than ``N`` states are supplied. Otherwise the
            negative indices would silently wrap around to the wrong entries.
    """
    k = len(states)
    if k < N:
        raise ValueError(f"need at least {N} previous states, got {k}")
    # Top W - R bits of the oldest word, low R bits of the one after it.
    x = (states[k - N] & w_upper) | (states[k - N + 1] & w_lower)
    return states[k - N + M] ^ mul_a(x)

def _expect_line(io, expected):
    """Read one line from *io* and raise RuntimeError unless it equals *expected*.

    Used instead of ``assert`` so the protocol checks still run under ``python -O``.
    """
    line = io.readline()
    if line != expected:
        raise RuntimeError(
            f"unexpected server line: expected {expected!r}, got {line!r}"
        )

def main(host="localhost", port="9051"):
    """Collect tempered outputs from the server and submit the predicted next one.

    Each round reads ``Hints:``, then ``good`` integers (untempered and
    stored), then ``Guess:``. Every round except the last is answered with
    ``0``. A ``None`` placeholder is recorded for it, presumably because
    that guess consumes one generator output that is never revealed. After
    the last round, the next state is computed with ``gen_next``, tempered
    and sent, and the server's reply line is printed. The connection is
    always closed on exit.

    Args:
        host: server hostname.
        port: server port, as a string or int (passed through ``int``).

    Raises:
        RuntimeError: if the server deviates from the expected line protocol.
    """
    retries = 100  # number of Hints/Guess rounds
    good = 9       # hints revealed per round

    io = remote(host, int(port))
    try:
        values = []
        for n in range(retries):
            _expect_line(io, b"Hints:\n")
            for _ in range(good):
                values.append(unharden(int(io.readline())))
            _expect_line(io, b"Guess:\n")
            if n == retries - 1:
                break
            # Keep indices aligned with the generator's output sequence.
            values.append(None)
            io.sendline(b"0")

        predict = gen_next(values)
        io.sendline(str(harden(predict)).encode())

        print(io.readline().decode())
    finally:
        io.close()

if __name__ == "__main__":
    # Positional command-line arguments map onto main(host, port).
    cli_args = sys.argv[1:]
    main(*cli_args)