root/kernel/bpf/tnum.c

/* [<][>][^][v][top][bottom][index][help] */

DEFINITIONS

This source file includes the following definitions.
  1. tnum_const
  2. tnum_range
  3. tnum_lshift
  4. tnum_rshift
  5. tnum_arshift
  6. tnum_add
  7. tnum_sub
  8. tnum_and
  9. tnum_or
  10. tnum_xor
  11. hma
  12. tnum_mul
  13. tnum_intersect
  14. tnum_cast
  15. tnum_is_aligned
  16. tnum_in
  17. tnum_strn
  18. tnum_sbin

   1 // SPDX-License-Identifier: GPL-2.0-only
   2 /* tnum: tracked (or tristate) numbers
   3  *
   4  * A tnum tracks knowledge about the bits of a value.  Each bit can be either
   5  * known (0 or 1), or unknown (x).  Arithmetic operations on tnums will
   6  * propagate the unknown bits such that the tnum result represents all the
   7  * possible results for possible values of the operands.
   8  */
   9 #include <linux/kernel.h>
  10 #include <linux/tnum.h>
  11 
/* Build a struct tnum compound literal: _v becomes .value, _m becomes .mask. */
#define TNUM(_v, _m)    (struct tnum){.value = _v, .mask = _m}
/* A completely unknown value: .value is 0 and every bit of .mask is set
 * (-1 converted to the mask's unsigned type).
 */
const struct tnum tnum_unknown = { .value = 0, .mask = -1 };
  15 
  16 struct tnum tnum_const(u64 value)
  17 {
  18         return TNUM(value, 0);
  19 }
  20 
  21 struct tnum tnum_range(u64 min, u64 max)
  22 {
  23         u64 chi = min ^ max, delta;
  24         u8 bits = fls64(chi);
  25 
  26         /* special case, needed because 1ULL << 64 is undefined */
  27         if (bits > 63)
  28                 return tnum_unknown;
  29         /* e.g. if chi = 4, bits = 3, delta = (1<<3) - 1 = 7.
  30          * if chi = 0, bits = 0, delta = (1<<0) - 1 = 0, so we return
  31          *  constant min (since min == max).
  32          */
  33         delta = (1ULL << bits) - 1;
  34         return TNUM(min & ~delta, delta);
  35 }
  36 
  37 struct tnum tnum_lshift(struct tnum a, u8 shift)
  38 {
  39         return TNUM(a.value << shift, a.mask << shift);
  40 }
  41 
  42 struct tnum tnum_rshift(struct tnum a, u8 shift)
  43 {
  44         return TNUM(a.value >> shift, a.mask >> shift);
  45 }
  46 
  47 struct tnum tnum_arshift(struct tnum a, u8 min_shift, u8 insn_bitness)
  48 {
  49         /* if a.value is negative, arithmetic shifting by minimum shift
  50          * will have larger negative offset compared to more shifting.
  51          * If a.value is nonnegative, arithmetic shifting by minimum shift
  52          * will have larger positive offset compare to more shifting.
  53          */
  54         if (insn_bitness == 32)
  55                 return TNUM((u32)(((s32)a.value) >> min_shift),
  56                             (u32)(((s32)a.mask)  >> min_shift));
  57         else
  58                 return TNUM((s64)a.value >> min_shift,
  59                             (s64)a.mask  >> min_shift);
  60 }
  61 
  62 struct tnum tnum_add(struct tnum a, struct tnum b)
  63 {
  64         u64 sm, sv, sigma, chi, mu;
  65 
  66         sm = a.mask + b.mask;
  67         sv = a.value + b.value;
  68         sigma = sm + sv;
  69         chi = sigma ^ sv;
  70         mu = chi | a.mask | b.mask;
  71         return TNUM(sv & ~mu, mu);
  72 }
  73 
  74 struct tnum tnum_sub(struct tnum a, struct tnum b)
  75 {
  76         u64 dv, alpha, beta, chi, mu;
  77 
  78         dv = a.value - b.value;
  79         alpha = dv + a.mask;
  80         beta = dv - b.mask;
  81         chi = alpha ^ beta;
  82         mu = chi | a.mask | b.mask;
  83         return TNUM(dv & ~mu, mu);
  84 }
  85 
  86 struct tnum tnum_and(struct tnum a, struct tnum b)
  87 {
  88         u64 alpha, beta, v;
  89 
  90         alpha = a.value | a.mask;
  91         beta = b.value | b.mask;
  92         v = a.value & b.value;
  93         return TNUM(v, alpha & beta & ~v);
  94 }
  95 
  96 struct tnum tnum_or(struct tnum a, struct tnum b)
  97 {
  98         u64 v, mu;
  99 
 100         v = a.value | b.value;
 101         mu = a.mask | b.mask;
 102         return TNUM(v, mu & ~v);
 103 }
 104 
 105 struct tnum tnum_xor(struct tnum a, struct tnum b)
 106 {
 107         u64 v, mu;
 108 
 109         v = a.value ^ b.value;
 110         mu = a.mask | b.mask;
 111         return TNUM(v & ~mu, mu);
 112 }
 113 
 114 /* half-multiply add: acc += (unknown * mask * value).
 115  * An intermediate step in the multiply algorithm.
 116  */
 117 static struct tnum hma(struct tnum acc, u64 value, u64 mask)
 118 {
 119         while (mask) {
 120                 if (mask & 1)
 121                         acc = tnum_add(acc, TNUM(0, value));
 122                 mask >>= 1;
 123                 value <<= 1;
 124         }
 125         return acc;
 126 }
 127 
 128 struct tnum tnum_mul(struct tnum a, struct tnum b)
 129 {
 130         struct tnum acc;
 131         u64 pi;
 132 
 133         pi = a.value * b.value;
 134         acc = hma(TNUM(pi, 0), a.mask, b.mask | b.value);
 135         return hma(acc, b.mask, a.value);
 136 }
 137 
 138 /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
 139  * a 'known 0' - this will return a 'known 1' for that bit.
 140  */
 141 struct tnum tnum_intersect(struct tnum a, struct tnum b)
 142 {
 143         u64 v, mu;
 144 
 145         v = a.value | b.value;
 146         mu = a.mask & b.mask;
 147         return TNUM(v & ~mu, mu);
 148 }
 149 
 150 struct tnum tnum_cast(struct tnum a, u8 size)
 151 {
 152         a.value &= (1ULL << (size * 8)) - 1;
 153         a.mask &= (1ULL << (size * 8)) - 1;
 154         return a;
 155 }
 156 
 157 bool tnum_is_aligned(struct tnum a, u64 size)
 158 {
 159         if (!size)
 160                 return true;
 161         return !((a.value | a.mask) & (size - 1));
 162 }
 163 
 164 bool tnum_in(struct tnum a, struct tnum b)
 165 {
 166         if (b.mask & ~a.mask)
 167                 return false;
 168         b.value &= ~a.mask;
 169         return a.value == b.value;
 170 }
 171 
 172 int tnum_strn(char *str, size_t size, struct tnum a)
 173 {
 174         return snprintf(str, size, "(%#llx; %#llx)", a.value, a.mask);
 175 }
 176 EXPORT_SYMBOL_GPL(tnum_strn);
 177 
 178 int tnum_sbin(char *str, size_t size, struct tnum a)
 179 {
 180         size_t n;
 181 
 182         for (n = 64; n; n--) {
 183                 if (n < size) {
 184                         if (a.mask & 1)
 185                                 str[n - 1] = 'x';
 186                         else if (a.value & 1)
 187                                 str[n - 1] = '1';
 188                         else
 189                                 str[n - 1] = '0';
 190                 }
 191                 a.mask >>= 1;
 192                 a.value >>= 1;
 193         }
 194         str[min(size - 1, (size_t)64)] = 0;
 195         return 64;
 196 }

/* [<][>][^][v][top][bottom][index][help] */