From: Jacob Lifshay Date: Sun, 3 Apr 2022 20:20:30 +0000 (-0700) Subject: Move files from libreriscv.git/openpower/sv/bitmanip/ X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=5e7b23a9237aac9e9ced9ae37fdef7f255fc081b;p=nmigen-gf.git Move files from libreriscv.git/openpower/sv/bitmanip/ https://git.libre-soc.org/?p=libreriscv.git;a=tree;f=openpower/sv/bitmanip;hb=633d57457d98b8c6ab7803fe72050f8918bba87f --- diff --git a/gf_reference/.git-keep b/gf_reference/.git-keep deleted file mode 100644 index e69de29..0000000 diff --git a/gf_reference/__init__.py b/gf_reference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gf_reference/cldivrem.py b/gf_reference/cldivrem.py new file mode 100644 index 0000000..33a528d --- /dev/null +++ b/gf_reference/cldivrem.py @@ -0,0 +1,30 @@ +from .log2 import floor_log2 + + +def cldivrem(n, d, width): + """ Carry-less Division and Remainder. + `n` and `d` are integers, `width` is the number of bits needed to hold + each input/output. + Returns a tuple `q, r` of the quotient and remainder. + """ + assert d != 0, "TODO: decide what happens on division by zero" + assert 0 <= n < 1 << width, f"bad n (doesn't fit in {width}-bit uint)" + assert 0 <= d < 1 << width, f"bad d (doesn't fit in {width}-bit uint)" + r = n + q = 0 + d <<= width + for _ in range(width): + d >>= 1 + q <<= 1 + if degree(d) == degree(r): + r ^= d + q |= 1 + return q, r + + +def degree(v): + """the degree of the GF(2) polynomial `v`. `v` is a non-negative integer. + """ + if v == 0: + return -1 + return floor_log2(v) diff --git a/gf_reference/clmul.py b/gf_reference/clmul.py new file mode 100644 index 0000000..7451c3c --- /dev/null +++ b/gf_reference/clmul.py @@ -0,0 +1,8 @@ +def clmul(a, b): + x = 0 + i = 0 + while b >> i != 0: + if (b >> i) & 1: + x ^= a << i + i += 1 + return x diff --git a/gf_reference/clmulh.py b/gf_reference/clmulh.py new file mode 100644 index 0000000..b170aca --- /dev/null +++ b/gf_reference/clmulh.py @@ -0,0 +1,5 @@ +from .clmul import clmul + + +def clmulh(a, b, XLEN): + return clmul(a, b) >> XLEN diff --git a/gf_reference/clmulr.py b/gf_reference/clmulr.py new file mode 100644 index 0000000..5b155b5 --- /dev/null +++ b/gf_reference/clmulr.py @@ -0,0 +1,5 @@ +from .clmul import clmul + + +def clmulh(a, b, XLEN): + return clmul(a, b) >> (XLEN - 1) diff --git a/gf_reference/decode_reducing_polynomial.py b/gf_reference/decode_reducing_polynomial.py new file mode 100644 index 0000000..ac6623b --- /dev/null +++ b/gf_reference/decode_reducing_polynomial.py @@ -0,0 +1,19 @@ +from .state import ST + + +def decode_reducing_polynomial(): + """returns the decoded reducing polynomial as an integer. + Note: the returned integer is `XLEN + 1` bits wide. + """ + v = ST.GFBREDPOLY & ((1 << ST.XLEN) - 1) # mask to XLEN bits + if v == 0 or v == 2: # GF(2) + return 0b10 # degree = 1, poly = x + if (v & 1) == 0: + # all reducing polynomials of degree > 1 must have the LSB set, + # because they must be irreducible polynomials (meaning they + # can't be factored), if the LSB was clear, then they would + # have `x` as a factor. Therefore, we can reuse the LSB clear + # to instead mean the polynomial has degree XLEN. + v |= 1 << ST.XLEN + v |= 1 # LSB must be set + return v diff --git a/gf_reference/gfbinv.py b/gf_reference/gfbinv.py new file mode 100644 index 0000000..57aa5a8 --- /dev/null +++ b/gf_reference/gfbinv.py @@ -0,0 +1,40 @@ +from .decode_reducing_polynomial import decode_reducing_polynomial +from .cldivrem import degree + + +def gfbinv(a): + """compute the GF(2^m) inverse of `a`.""" + # Derived from Algorithm 3, from [7] in: + # https://ftp.libre-soc.org/ARITH18_Kobayashi.pdf + + s = decode_reducing_polynomial() + m = degree(s) + assert a >> m == 0, "`a` is out-of-range" + r = a + v = 0 + u = 1 + delta = 0 + + for _ in range(2 * m): + # could use count-leading-zeros here to skip ahead + if r >> m == 0: # if the MSB of `r` is zero + r <<= 1 + u <<= 1 + delta += 1 + else: + if s >> m != 0: # if the MSB of `s` isn't zero + s ^= r + v ^= u + s <<= 1 + if delta == 0: + r, s = s, r # swap r and s + u, v = v << 1, u # shift v and swap + delta = 1 + else: + u >>= 1 + delta -= 1 + if a == 0: + # we specifically choose 0 as the result of inverting 0, rather than an + # error or undefined, since that's what Rijndael needs. + return 0 + return u diff --git a/gf_reference/gfbmadd.py b/gf_reference/gfbmadd.py new file mode 100644 index 0000000..df826c4 --- /dev/null +++ b/gf_reference/gfbmadd.py @@ -0,0 +1,11 @@ +from .state import ST +from .decode_reducing_polynomial import decode_reducing_polynomial +from .clmul import clmul +from .cldivrem import cldivrem + + +def gfbmadd(a, b, c): + v = clmul(a, b) ^ c + red_poly = decode_reducing_polynomial() + q, r = cldivrem(v, red_poly, width=ST.XLEN + 1) + return r diff --git a/gf_reference/gfbmul.py b/gf_reference/gfbmul.py new file mode 100644 index 0000000..6295a29 --- /dev/null +++ b/gf_reference/gfbmul.py @@ -0,0 +1,11 @@ +from .state import ST +from .decode_reducing_polynomial import decode_reducing_polynomial +from .clmul import clmul +from .cldivrem import cldivrem + + +def gfbmul(a, b): + product = clmul(a, b) + red_poly = decode_reducing_polynomial() + q, r = cldivrem(product, red_poly, width=ST.XLEN + 1) + return r diff --git a/gf_reference/gfbredpoly.py b/gf_reference/gfbredpoly.py new file mode 100644 index 0000000..1812967 --- /dev/null +++ b/gf_reference/gfbredpoly.py @@ -0,0 +1,6 @@ +from .state import ST + + +def gfbredpoly(immed): + # TODO: figure out how `immed` should be encoded + ST.GFBREDPOLY = immed diff --git a/gf_reference/gfpadd.py b/gf_reference/gfpadd.py new file mode 100644 index 0000000..f00bf8c --- /dev/null +++ b/gf_reference/gfpadd.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpadd(a, b): + return (a + b) % ST.GFPRIME diff --git a/gf_reference/gfpinv.py b/gf_reference/gfpinv.py new file mode 100644 index 0000000..45b6dbb --- /dev/null +++ b/gf_reference/gfpinv.py @@ -0,0 +1,51 @@ +from .state import ST + + +def gfpinv(a): + # based on Algorithm ExtEucdInv from: + # https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.5233&rep=rep1&type=pdf + p = ST.GFPRIME + assert p >= 2, "GFPRIME isn't a prime" + assert a != 0, "TODO: decide what happens for division by zero" + assert isinstance(a, int) and 0 < a < p, "a out of range" + if p == 2: + return 1 # the only value possible + + u = p + v = a + r = 0 + s = 1 + while v > 0: + # implementations could use count-zeros on + # both u and r to save cycles + if u & 1 == 0: + u >>= 1 + if r & 1 == 0: + r >>= 1 + else: + r = (r + p) >> 1 + # implementations could use count-zeros on + # both v and s to save cycles + elif v & 1 == 0: + v >>= 1 + if s & 1 == 0: + s >>= 1 + else: + s = (s + p) >> 1 + else: + x = u - v + if x > 0: + u = x + r -= s + if r < 0: + r += p + else: + v = -x + s -= r + if s < 0: + s += p + if r > p: + r -= p + if r < 0: + r += p + return r diff --git a/gf_reference/gfpmadd.py b/gf_reference/gfpmadd.py new file mode 100644 index 0000000..e7159a2 --- /dev/null +++ b/gf_reference/gfpmadd.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpmadd(a, b, c): + return (a * b + c) % ST.GFPRIME diff --git a/gf_reference/gfpmsub.py b/gf_reference/gfpmsub.py new file mode 100644 index 0000000..fd09124 --- /dev/null +++ b/gf_reference/gfpmsub.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpmsub(a, b, c): + return (a * b - c) % ST.GFPRIME diff --git a/gf_reference/gfpmsubr.py b/gf_reference/gfpmsubr.py new file mode 100644 index 0000000..2d349f8 --- /dev/null +++ b/gf_reference/gfpmsubr.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpmsubr(a, b, c): + return (c - a * b) % ST.GFPRIME diff --git a/gf_reference/gfpmul.py b/gf_reference/gfpmul.py new file mode 100644 index 0000000..42c43a2 --- /dev/null +++ b/gf_reference/gfpmul.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpmul(a, b): + return (a * b) % ST.GFPRIME diff --git a/gf_reference/gfpsub.py b/gf_reference/gfpsub.py new file mode 100644 index 0000000..44e5798 --- /dev/null +++ b/gf_reference/gfpsub.py @@ -0,0 +1,5 @@ +from .state import ST + + +def gfpsub(a, b): + return (a - b) % ST.GFPRIME diff --git a/gf_reference/log2.py b/gf_reference/log2.py new file mode 100644 index 0000000..0defcc8 --- /dev/null +++ b/gf_reference/log2.py @@ -0,0 +1,12 @@ +def floor_log2(v): + """return floor(log2(v)).""" + assert isinstance(v, int) + assert v > 0 + return v.bit_length() - 1 + + +def ceil_log2(v): + """return ceil(log2(v)).""" + assert isinstance(v, int) + assert v > 0 + return (v - 1).bit_length() diff --git a/gf_reference/pack_poly.py b/gf_reference/pack_poly.py new file mode 100644 index 0000000..f5ab644 --- /dev/null +++ b/gf_reference/pack_poly.py @@ -0,0 +1,19 @@ +"""Polynomials with GF(2) coefficients.""" + + +def pack_poly(poly): + """`poly` is a list where `poly[i]` is the coefficient for `x ** i`""" + retval = 0 + for i, v in enumerate(poly): + retval |= v << i + return retval + + +def unpack_poly(v): + """returns a list `poly`, where `poly[i]` is the coefficient for `x ** i`. + """ + poly = [] + while v != 0: + poly.append(v & 1) + v >>= 1 + return poly diff --git a/gf_reference/state.py b/gf_reference/state.py new file mode 100644 index 0000000..91cb542 --- /dev/null +++ b/gf_reference/state.py @@ -0,0 +1,19 @@ +from .log2 import floor_log2 +from threading import local + + +class State(local): + # thread local so unit tests can be run in parallel without breaking + def __init__(self, *, XLEN=64, GFBREDPOLY=0, GFPRIME=31): + assert isinstance(XLEN, int) and 2 ** floor_log2(XLEN) == XLEN + assert isinstance(GFBREDPOLY, int) and 0 <= GFBREDPOLY < 2 ** 64 + assert isinstance(GFPRIME, int) and 0 <= GFPRIME < 2 ** 64 + self.XLEN = XLEN + self.GFBREDPOLY = GFBREDPOLY + self.GFPRIME = GFPRIME + + def reinit(self, *, XLEN=64, GFBREDPOLY=0, GFPRIME=31): + self.__init__(XLEN=XLEN, GFBREDPOLY=GFBREDPOLY, GFPRIME=GFPRIME) + + +ST = State() diff --git a/gf_reference/test_cl_gfb_gfp.py b/gf_reference/test_cl_gfb_gfp.py new file mode 100644 index 0000000..34133cb --- /dev/null +++ b/gf_reference/test_cl_gfb_gfp.py @@ -0,0 +1,701 @@ +from .state import ST +from .cldivrem import cldivrem +from .clmul import clmul +from .gfbmul import gfbmul +from .gfbmadd import gfbmadd +from .gfbinv import gfbinv +from .gfpadd import gfpadd +from .gfpsub import gfpsub +from .gfpmul import gfpmul +from .gfpinv import gfpinv +from .gfpmadd import gfpmadd +from .gfpmsub import gfpmsub +from .gfpmsubr import gfpmsubr +from .pack_poly import pack_poly, unpack_poly +import unittest + + +class GF2Poly: + """Polynomial with GF(2) coefficients. + + `self.coefficients`: a list where `coefficients[-1] != 0`. + `coefficients[i]` is the coefficient for `x ** i`. + """ + + def __init__(self, coefficients=None): + self.coefficients = [] + if coefficients is not None: + if not isinstance(coefficients, (tuple, list)): + coefficients = list(coefficients) + # reversed to resize self.coefficients once + for i in reversed(range(len(coefficients))): + self[i] = coefficients[i] + + def __len__(self): + return len(self.coefficients) + + @property + def degree(self): + return len(self) - 1 + + @property + def lc(self): + """leading coefficient.""" + return 0 if len(self) == 0 else self.coefficients[-1] + + def __getitem__(self, key): + assert key >= 0 + if key < len(self): + return self.coefficients[key] + return 0 + + def __setitem__(self, key, value): + assert key >= 0 + assert value == 0 or value == 1 + if key < len(self): + self.coefficients[key] = value + while len(self) and self.coefficients[-1] == 0: + self.coefficients.pop() + elif value != 0: + self.coefficients += [0] * (key + 1 - len(self)) + self.coefficients[key] = value + + def __repr__(self): + return f"GF2Poly({self.coefficients})" + + def __iadd__(self, rhs): + for i in range(max(len(self), len(rhs))): + self[i] ^= rhs[i] + return self + + def __add__(self, rhs): + return GF2Poly(self).__iadd__(rhs) + + def __isub__(self, rhs): + return self.__iadd__(rhs) + + def __sub__(self, rhs): + return self.__add__(rhs) + + def __iter__(self): + return iter(self.coefficients) + + def __mul__(self, rhs): + retval = GF2Poly() + # reversed to resize retval.coefficients once + for i in reversed(range(len(self))): + if self[i]: + for j in reversed(range(len(rhs))): + retval[i + j] ^= rhs[j] + return retval + + def __ilshift__(self, amount): + """multiplies `self` by the polynomial `x**amount`""" + if len(self) != 0: + self.coefficients[:0] = [0] * amount + return self + + def __lshift__(self, amount): + """returns the polynomial `self * x**amount`""" + return GF2Poly(self).__ilshift__(amount) + + def __irshift__(self, amount): + """divides `self` by the polynomial `x**amount`, discarding the + remainder. + """ + if amount < len(self): + del self.coefficients[:amount] + else: + del self.coefficients[:] + return self + + def __rshift__(self, amount): + """divides `self` by the polynomial `x**amount`, discarding the + remainder. + """ + return GF2Poly(self).__irshift__(amount) + + def __divmod__(self, divisor): + # based on https://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor#Euclidean_division + assert isinstance(divisor, GF2Poly) + if len(divisor) == 0: + raise ZeroDivisionError + q = GF2Poly() + r = GF2Poly(self) + while r.degree >= divisor.degree: + shift = r.degree - divisor.degree + q[shift] ^= 1 + r -= divisor << shift + return q, r + + def __floordiv__(self, divisor): + q, r = divmod(self, divisor) + return q + + def __mod__(self, divisor): + q, r = divmod(self, divisor) + return r + + def __pow__(self, exponent, modulus=None): + assert isinstance(exponent, int) and exponent >= 0 + assert modulus is None or isinstance(modulus, GF2Poly) + retval = GF2Poly([1]) + pow2 = GF2Poly(self) + while exponent != 0: + if exponent & 1: + retval *= pow2 + if modulus is not None: + retval %= modulus + exponent &= ~1 + else: + pow2 *= pow2 + if modulus is not None: + pow2 %= modulus + exponent >>= 1 + return retval + + def __eq__(self, rhs): + if isinstance(rhs, GF2Poly): + return self.coefficients == rhs.coefficients + return NotImplemented + + +class TestGF2Poly(unittest.TestCase): + def test_add(self): + a = GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1]) + b = GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0]) + c = a + b + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(b, GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1])) + self.assertEqual(c, GF2Poly([1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1])) + c = b + a + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(b, GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1])) + self.assertEqual(c, GF2Poly([1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1])) + a = GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1]) + b = GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1]) + c = a + b + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(b, GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1])) + self.assertEqual(c, GF2Poly([1, 0, 1, 1, 1, 0, 0, 1])) + c = a - b + self.assertEqual(c, GF2Poly([1, 0, 1, 1, 1, 0, 0, 1])) + c = b - a + self.assertEqual(c, GF2Poly([1, 0, 1, 1, 1, 0, 0, 1])) + + def test_shift(self): + a = GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1]) + c = a << 0 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + c = a << 5 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly( + [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + c = a << 10 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly( + [0] * 10 + [1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + c = a >> 0 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + c = a >> 5 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([1, 1, 0, 1, 0, 1])) + c = a >> 10 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([1])) + c = a >> 11 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([])) + c = a >> 100 + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(c, GF2Poly([])) + + def test_mul(self): + a = GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1]) + b = GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1]) + c = a * b + expected = GF2Poly([0, 0, 1, 0, 1, 0, 1, 1, 1, 0, + 0, 1, 0, 1, 1, 0, 0, 1, 1]) + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(b, GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1])) + self.assertEqual(c, expected) + + def test_divmod(self): + a = GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1]) + b = GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1]) + q, r = divmod(a, b) + self.assertEqual(a, GF2Poly([1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1])) + self.assertEqual(b, GF2Poly([0, 0, 1, 0, 1, 1, 1, 1, 1])) + self.assertEqual(q, GF2Poly([1, 1, 1])) + self.assertEqual(r, GF2Poly([1, 0, 1, 0, 0, 1, 0, 1])) + q = a // b + self.assertEqual(q, GF2Poly([1, 1, 1])) + r = a % b + self.assertEqual(r, GF2Poly([1, 0, 1, 0, 0, 1, 0, 1])) + + def test_pow(self): + b = GF2Poly([0, 1]) + for e in range(8): + expected = GF2Poly([0] * e + [1]) + with self.subTest(b=str(b), e=e, expected=str(expected)): + v = b ** e + self.assertEqual(b, GF2Poly([0, 1])) + self.assertEqual(v, expected) + + # AES's finite field reducing polynomial + m = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]) + period = 2 ** m.degree - 1 + b = GF2Poly([1, 1, 0, 0, 1, 0, 1]) + e = period - 1 + expected = GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]) + v = pow(b, e, m) + self.assertEqual(m, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(b, GF2Poly([1, 1, 0, 0, 1, 0, 1])) + self.assertEqual(v, expected) + + # test that pow doesn't take inordinately long when given a modulus. + # adding a multiple of `period` should leave results unchanged. + e += period * 10 ** 15 + v = pow(b, e, m) + self.assertEqual(m, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(b, GF2Poly([1, 1, 0, 0, 1, 0, 1])) + self.assertEqual(v, expected) + + +class GFB: + def __init__(self, value, red_poly=None): + if isinstance(value, GFB): + # copy value + assert red_poly is None + self.red_poly = GF2Poly(value.red_poly) + self.value = GF2Poly(value.value) + return + assert isinstance(value, GF2Poly) + assert isinstance(red_poly, GF2Poly) + assert red_poly.degree > 0 + self.value = value % red_poly + self.red_poly = red_poly + + def __repr__(self): + return f"GFB({self.value}, {self.red_poly})" + + def __add__(self, rhs): + assert isinstance(rhs, GFB) + assert self.red_poly == rhs.red_poly + return GFB((self.value + rhs.value) % self.red_poly, self.red_poly) + + def __sub__(self, rhs): + return self.__add__(rhs) + + def __eq__(self, rhs): + if isinstance(rhs, GFB): + return self.value == rhs.value and self.red_poly == rhs.red_poly + return NotImplemented + + def __mul__(self, rhs): + assert isinstance(rhs, GFB) + assert self.red_poly == rhs.red_poly + return GFB((self.value * rhs.value) % self.red_poly, self.red_poly) + + def __div__(self, rhs): + assert isinstance(rhs, GFB) + assert self.red_poly == rhs.red_poly + return self * rhs ** -1 + + @property + def __pow_period(self): + period = (1 << self.red_poly.degree) - 1 + assert period > 0, "internal logic error" + return period + + def __pow__(self, exponent): + assert isinstance(exponent, int) + if len(self.value) == 0: + if exponent < 0: + raise ZeroDivisionError + else: + return GFB(self) + exponent %= self.__pow_period + return GFB(pow(self.value, exponent, self.red_poly), self.red_poly) + + +class TestGFBClass(unittest.TestCase): + def test_add(self): + # AES's finite field reducing polynomial + red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]) + a = GFB(GF2Poly([0, 1, 0, 1]), red_poly) + b = GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly) + expected = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly) + c = a + b + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly)) + self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly)) + self.assertEqual(c, expected) + c = a - b + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly)) + self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly)) + self.assertEqual(c, expected) + c = b - a + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly)) + self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly)) + self.assertEqual(c, expected) + + def test_mul(self): + # AES's finite field reducing polynomial + red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]) + a = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly) + b = GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly) + expected = GFB(GF2Poly([1]), red_poly) + c = a * b + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly)) + self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)) + self.assertEqual(c, expected) + + def test_pow(self): + # AES's finite field reducing polynomial + red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]) + period = 2 ** red_poly.degree - 1 + b = GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly) + e = period - 1 + expected = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly) + v = b ** e + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)) + self.assertEqual(v, expected) + e = -1 + v = b ** e + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)) + self.assertEqual(v, expected) + + # test that pow doesn't take inordinately long when given a modulus. + # adding a multiple of `period` should leave results unchanged. + e += period * 10 ** 15 + v = b ** e + self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])) + self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)) + self.assertEqual(v, expected) + + +class GFP: + def __init__(self, value, size): + assert isinstance(value, int) + assert isinstance(size, int) and size >= 2, "size is not a prime" + self.value = value % size + self.size = size + + def __repr__(self): + return f"GFP({self.value}, {self.size})" + + def __eq__(self, rhs): + if isinstance(rhs, GFP): + return self.value == rhs.value and self.size == rhs.size + return NotImplemented + + def __add__(self, rhs): + assert isinstance(rhs, GFP) + assert self.size == rhs.size + return GFP((self.value + rhs.value) % self.size, self.size) + + def __sub__(self, rhs): + assert isinstance(rhs, GFP) + assert self.size == rhs.size + return GFP((self.value - rhs.value) % self.size, self.size) + + def __mul__(self, rhs): + assert isinstance(rhs, GFP) + assert self.size == rhs.size + return GFP((self.value * rhs.value) % self.size, self.size) + + def __div__(self, rhs): + assert isinstance(rhs, GFP) + assert self.size == rhs.size + return self * rhs ** -1 + + @property + def __pow_period(self): + period = self.size - 1 + assert period > 0, "internal logic error" + return period + + def __pow__(self, exponent): + assert isinstance(exponent, int) + if self.value == 0: + if exponent < 0: + raise ZeroDivisionError + else: + return GFP(self.value, self.size) + exponent %= self.__pow_period + return GFP(pow(self.value, exponent, self.size), self.size) + + +PRIMES = 2, 3, 5, 7, 11, 13, 17, 19 +"""handy list of small primes for testing""" + + +class TestGFPClass(unittest.TestCase): + def test_add_sub(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + with self.subTest(av=av, bv=bv, prime=prime): + a = GFP(av, prime) + b = GFP(bv, prime) + expected = GFP((av + bv) % prime, prime) + c = a + b + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + c = b + a + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + expected = GFP((av - bv) % prime, prime) + c = a - b + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + expected = GFP((bv - av) % prime, prime) + c = b - a + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + + def test_mul(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + with self.subTest(av=av, bv=bv, prime=prime): + a = GFP(av, prime) + b = GFP(bv, prime) + expected = GFP((av * bv) % prime, prime) + c = a * b + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + c = b * a + self.assertEqual(a, GFP(av, prime)) + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(c, expected) + + def test_pow(self): + for prime in PRIMES: + for bv in range(prime): + with self.subTest(bv=bv, prime=prime): + b = GFP(bv, prime) + period = prime - 1 + e = period - 1 + expected = GFP(pow(bv, e, prime) if bv != 0 else 0, prime) + v = b ** e + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(v, expected) + e = -1 + if bv != 0: + v = b ** e + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(v, expected) + + # test that pow doesn't take inordinately long when given + # a modulus. adding a multiple of `period` should leave + # results unchanged. + e += period * 10 ** 15 + v = b ** e + self.assertEqual(b, GFP(bv, prime)) + self.assertEqual(v, expected) + + +class TestCL(unittest.TestCase): + def test_cldivrem(self): + n_width = 8 + d_width = 4 + width = max(n_width, d_width) + for nv in range(2 ** n_width): + n = GF2Poly(unpack_poly(nv)) + for dv in range(1, 2 ** d_width): + d = GF2Poly(unpack_poly(dv)) + with self.subTest(n=str(n), nv=nv, d=str(d), dv=dv): + q_expected, r_expected = divmod(n, d) + self.assertEqual(q_expected * d + r_expected, n) + q, r = cldivrem(nv, dv, width) + q_expected = pack_poly(q_expected.coefficients) + r_expected = pack_poly(r_expected.coefficients) + self.assertEqual((q, r), (q_expected, r_expected)) + + def test_clmul(self): + a_width = 5 + b_width = 5 + for av in range(2 ** a_width): + a = GF2Poly(unpack_poly(av)) + for bv in range(2 ** b_width): + b = GF2Poly(unpack_poly(bv)) + with self.subTest(a=str(a), av=av, b=str(b), bv=bv): + expected = a * b + product = clmul(av, bv) + expected = pack_poly(expected.coefficients) + self.assertEqual(product, expected) + + +class TestGFBInstructions(unittest.TestCase): + @staticmethod + def init_aes_red_poly(): + # AES's finite field reducing polynomial + red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]) + ST.reinit(GFBREDPOLY=pack_poly(red_poly.coefficients)) + return red_poly + + def test_gfbmul(self): + # AES's finite field reducing polynomial + red_poly = self.init_aes_red_poly() + a_width = 8 + b_width = 4 + for av in range(2 ** a_width): + a = GFB(GF2Poly(unpack_poly(av)), red_poly) + for bv in range(2 ** b_width): + b = GFB(GF2Poly(unpack_poly(bv)), red_poly) + expected = a * b + with self.subTest(a=str(a), av=av, b=str(b), bv=bv, expected=str(expected)): + product = gfbmul(av, bv) + expectedv = pack_poly(expected.value.coefficients) + self.assertEqual(product, expectedv) + + def test_gfbmadd(self): + # AES's finite field reducing polynomial + red_poly = self.init_aes_red_poly() + a_width = 5 + b_width = 4 + c_width = 4 + for av in range(2 ** a_width): + a = GFB(GF2Poly(unpack_poly(av)), red_poly) + for bv in range(2 ** b_width): + b = GFB(GF2Poly(unpack_poly(bv)), red_poly) + for cv in range(2 ** c_width): + c = GFB(GF2Poly(unpack_poly(cv)), red_poly) + expected = a * b + c + with self.subTest(a=str(a), av=av, + b=str(b), bv=bv, + c=str(c), cv=cv, + expected=str(expected)): + result = gfbmadd(av, bv, cv) + expectedv = pack_poly(expected.value.coefficients) + self.assertEqual(result, expectedv) + + def test_gfbinv(self): + # AES's finite field reducing polynomial + red_poly = self.init_aes_red_poly() + width = 8 + for av in range(2 ** width): + a = GFB(GF2Poly(unpack_poly(av)), red_poly) + expected = a ** -1 if av != 0 else GFB(GF2Poly(), red_poly) + with self.subTest(a=str(a), av=av, expected=str(expected)): + result = gfbinv(av) + expectedv = pack_poly(expected.value.coefficients) + self.assertEqual(result, expectedv) + + +class TestGFPInstructions(unittest.TestCase): + def test_gfpadd(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + expected = a + b + with self.subTest(a=str(a), b=str(b), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpadd(av, bv) + self.assertEqual(v, expected.value) + + def test_gfpsub(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + expected = a - b + with self.subTest(a=str(a), b=str(b), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpsub(av, bv) + self.assertEqual(v, expected.value) + + def test_gfpmul(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + expected = a * b + with self.subTest(a=str(a), b=str(b), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpmul(av, bv) + self.assertEqual(v, expected.value) + + def test_gfpinv(self): + for prime in PRIMES: + for av in range(prime): + a = GFP(av, prime) + if av == 0: + # TODO: determine what's expected for division by zero + continue + else: + expected = a ** -1 + with self.subTest(a=str(a), expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpinv(av) + self.assertEqual(v, expected.value) + + def test_gfpmadd(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + for cv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + c = GFP(cv, prime) + expected = a * b + c + with self.subTest(a=str(a), b=str(b), c=str(c), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpmadd(av, bv, cv) + self.assertEqual(v, expected.value) + + def test_gfpmsub(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + for cv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + c = GFP(cv, prime) + expected = a * b - c + with self.subTest(a=str(a), b=str(b), c=str(c), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpmsub(av, bv, cv) + self.assertEqual(v, expected.value) + + def test_gfpmsubr(self): + for prime in PRIMES: + for av in range(prime): + for bv in range(prime): + for cv in range(prime): + a = GFP(av, prime) + b = GFP(bv, prime) + c = GFP(cv, prime) + expected = c - a * b + with self.subTest(a=str(a), b=str(b), c=str(c), + expected=str(expected)): + ST.reinit(GFPRIME=prime) + v = gfpmsubr(av, bv, cv) + self.assertEqual(v, expected.value) + + +if __name__ == "__main__": + unittest.main()