cachepc-linux

Fork of AMDESE/linux with modifications for CachePC side-channel attack
git clone https://git.sinitax.com/sinitax/cachepc-linux
Log | Files | Refs | README | LICENSE | sfeed.txt

tnum.c (5225B)


      1// SPDX-License-Identifier: GPL-2.0-only
      2/* tnum: tracked (or tristate) numbers
      3 *
      4 * A tnum tracks knowledge about the bits of a value.  Each bit can be either
      5 * known (0 or 1), or unknown (x).  Arithmetic operations on tnums will
      6 * propagate the unknown bits such that the tnum result represents all the
      7 * possible results for possible values of the operands.
      8 */
      9#include <linux/kernel.h>
     10#include <linux/tnum.h>
     11
     12#define TNUM(_v, _m)	(struct tnum){.value = _v, .mask = _m}
     13/* A completely unknown value */
     14const struct tnum tnum_unknown = { .value = 0, .mask = -1 };
     15
     16struct tnum tnum_const(u64 value)
     17{
     18	return TNUM(value, 0);
     19}
     20
     21struct tnum tnum_range(u64 min, u64 max)
     22{
     23	u64 chi = min ^ max, delta;
     24	u8 bits = fls64(chi);
     25
     26	/* special case, needed because 1ULL << 64 is undefined */
     27	if (bits > 63)
     28		return tnum_unknown;
     29	/* e.g. if chi = 4, bits = 3, delta = (1<<3) - 1 = 7.
     30	 * if chi = 0, bits = 0, delta = (1<<0) - 1 = 0, so we return
     31	 *  constant min (since min == max).
     32	 */
     33	delta = (1ULL << bits) - 1;
     34	return TNUM(min & ~delta, delta);
     35}
     36
     37struct tnum tnum_lshift(struct tnum a, u8 shift)
     38{
     39	return TNUM(a.value << shift, a.mask << shift);
     40}
     41
     42struct tnum tnum_rshift(struct tnum a, u8 shift)
     43{
     44	return TNUM(a.value >> shift, a.mask >> shift);
     45}
     46
     47struct tnum tnum_arshift(struct tnum a, u8 min_shift, u8 insn_bitness)
     48{
     49	/* if a.value is negative, arithmetic shifting by minimum shift
     50	 * will have larger negative offset compared to more shifting.
     51	 * If a.value is nonnegative, arithmetic shifting by minimum shift
     52	 * will have larger positive offset compare to more shifting.
     53	 */
     54	if (insn_bitness == 32)
     55		return TNUM((u32)(((s32)a.value) >> min_shift),
     56			    (u32)(((s32)a.mask)  >> min_shift));
     57	else
     58		return TNUM((s64)a.value >> min_shift,
     59			    (s64)a.mask  >> min_shift);
     60}
     61
     62struct tnum tnum_add(struct tnum a, struct tnum b)
     63{
     64	u64 sm, sv, sigma, chi, mu;
     65
     66	sm = a.mask + b.mask;
     67	sv = a.value + b.value;
     68	sigma = sm + sv;
     69	chi = sigma ^ sv;
     70	mu = chi | a.mask | b.mask;
     71	return TNUM(sv & ~mu, mu);
     72}
     73
     74struct tnum tnum_sub(struct tnum a, struct tnum b)
     75{
     76	u64 dv, alpha, beta, chi, mu;
     77
     78	dv = a.value - b.value;
     79	alpha = dv + a.mask;
     80	beta = dv - b.mask;
     81	chi = alpha ^ beta;
     82	mu = chi | a.mask | b.mask;
     83	return TNUM(dv & ~mu, mu);
     84}
     85
     86struct tnum tnum_and(struct tnum a, struct tnum b)
     87{
     88	u64 alpha, beta, v;
     89
     90	alpha = a.value | a.mask;
     91	beta = b.value | b.mask;
     92	v = a.value & b.value;
     93	return TNUM(v, alpha & beta & ~v);
     94}
     95
     96struct tnum tnum_or(struct tnum a, struct tnum b)
     97{
     98	u64 v, mu;
     99
    100	v = a.value | b.value;
    101	mu = a.mask | b.mask;
    102	return TNUM(v, mu & ~v);
    103}
    104
    105struct tnum tnum_xor(struct tnum a, struct tnum b)
    106{
    107	u64 v, mu;
    108
    109	v = a.value ^ b.value;
    110	mu = a.mask | b.mask;
    111	return TNUM(v & ~mu, mu);
    112}
    113
    114/* Generate partial products by multiplying each bit in the multiplier (tnum a)
    115 * with the multiplicand (tnum b), and add the partial products after
    116 * appropriately bit-shifting them. Instead of directly performing tnum addition
    117 * on the generated partial products, equivalenty, decompose each partial
    118 * product into two tnums, consisting of the value-sum (acc_v) and the
    119 * mask-sum (acc_m) and then perform tnum addition on them. The following paper
    120 * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398.
    121 */
    122struct tnum tnum_mul(struct tnum a, struct tnum b)
    123{
    124	u64 acc_v = a.value * b.value;
    125	struct tnum acc_m = TNUM(0, 0);
    126
    127	while (a.value || a.mask) {
    128		/* LSB of tnum a is a certain 1 */
    129		if (a.value & 1)
    130			acc_m = tnum_add(acc_m, TNUM(0, b.mask));
    131		/* LSB of tnum a is uncertain */
    132		else if (a.mask & 1)
    133			acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask));
    134		/* Note: no case for LSB is certain 0 */
    135		a = tnum_rshift(a, 1);
    136		b = tnum_lshift(b, 1);
    137	}
    138	return tnum_add(TNUM(acc_v, 0), acc_m);
    139}
    140
    141/* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
    142 * a 'known 0' - this will return a 'known 1' for that bit.
    143 */
    144struct tnum tnum_intersect(struct tnum a, struct tnum b)
    145{
    146	u64 v, mu;
    147
    148	v = a.value | b.value;
    149	mu = a.mask & b.mask;
    150	return TNUM(v & ~mu, mu);
    151}
    152
    153struct tnum tnum_cast(struct tnum a, u8 size)
    154{
    155	a.value &= (1ULL << (size * 8)) - 1;
    156	a.mask &= (1ULL << (size * 8)) - 1;
    157	return a;
    158}
    159
    160bool tnum_is_aligned(struct tnum a, u64 size)
    161{
    162	if (!size)
    163		return true;
    164	return !((a.value | a.mask) & (size - 1));
    165}
    166
    167bool tnum_in(struct tnum a, struct tnum b)
    168{
    169	if (b.mask & ~a.mask)
    170		return false;
    171	b.value &= ~a.mask;
    172	return a.value == b.value;
    173}
    174
    175int tnum_strn(char *str, size_t size, struct tnum a)
    176{
    177	return snprintf(str, size, "(%#llx; %#llx)", a.value, a.mask);
    178}
    179EXPORT_SYMBOL_GPL(tnum_strn);
    180
    181int tnum_sbin(char *str, size_t size, struct tnum a)
    182{
    183	size_t n;
    184
    185	for (n = 64; n; n--) {
    186		if (n < size) {
    187			if (a.mask & 1)
    188				str[n - 1] = 'x';
    189			else if (a.value & 1)
    190				str[n - 1] = '1';
    191			else
    192				str[n - 1] = '0';
    193		}
    194		a.mask >>= 1;
    195		a.value >>= 1;
    196	}
    197	str[min(size - 1, (size_t)64)] = 0;
    198	return 64;
    199}
    200
    201struct tnum tnum_subreg(struct tnum a)
    202{
    203	return tnum_cast(a, 4);
    204}
    205
    206struct tnum tnum_clear_subreg(struct tnum a)
    207{
    208	return tnum_lshift(tnum_rshift(a, 32), 32);
    209}
    210
    211struct tnum tnum_const_subreg(struct tnum a, u32 value)
    212{
    213	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
    214}