implement gfb* instructions' pseudo-code
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Mar 2022 04:14:29 +0000 (21:14 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Mar 2022 04:14:29 +0000 (21:14 -0700)
cldivrem.py
decode_reducing_polynomial.py [new file with mode: 0644]
gfbinv.py [new file with mode: 0644]
gfbmadd.py [new file with mode: 0644]
gfbmul.py [new file with mode: 0644]
gfbredpoly.py [new file with mode: 0644]
log2.py [new file with mode: 0644]
state.py [new file with mode: 0644]
test_cl_gfb_gfp.py

index 8b660a52d4b19b5ead66d362f819152a658d5fa0..33a528da422e69243edfaac62f21411c9273ef0b 100644 (file)
@@ -1,3 +1,6 @@
+from .log2 import floor_log2
+
+
 def cldivrem(n, d, width):
     """ Carry-less Division and Remainder.
         `n` and `d` are integers, `width` is the number of bits needed to hold
@@ -22,9 +25,6 @@ def cldivrem(n, d, width):
 def degree(v):
     """the degree of the GF(2) polynomial `v`. `v` is a non-negative integer.
     """
-    assert v >= 0
-    retval = -1
-    while v != 0:
-        retval += 1
-        v >>= 1
-    return retval
+    if v == 0:
+        return -1
+    return floor_log2(v)
diff --git a/decode_reducing_polynomial.py b/decode_reducing_polynomial.py
new file mode 100644 (file)
index 0000000..ac6623b
--- /dev/null
@@ -0,0 +1,19 @@
+from .state import ST
+
+
+def decode_reducing_polynomial():
+    """returns the decoded reducing polynomial as an integer.
+        Note: the returned integer is `XLEN + 1` bits wide.
+    """
+    v = ST.GFBREDPOLY & ((1 << ST.XLEN) - 1)  # mask to XLEN bits
+    if v == 0 or v == 2:  # GF(2)
+        return 0b10  # degree = 1, poly = x
+    if (v & 1) == 0:
+        # all reducing polynomials of degree > 1 must have the LSB set,
+        # because they must be irreducible polynomials (meaning they
+        # can't be factored), if the LSB was clear, then they would
+        # have `x` as a factor. Therefore, we can reuse the LSB clear
+        # to instead mean the polynomial has degree XLEN.
+        v |= 1 << ST.XLEN
+        v |= 1  # LSB must be set
+    return v
diff --git a/gfbinv.py b/gfbinv.py
new file mode 100644 (file)
index 0000000..648fac6
--- /dev/null
+++ b/gfbinv.py
@@ -0,0 +1,40 @@
+from .decode_reducing_polynomial import decode_reducing_polynomial
+from .cldivrem import degree
+
+
+def gfbinv(a):
+    """compute the GF(2^m) inverse of `a`."""
+    # Derived from Algorithm 3, from [7] in:
+    # https://scholar.archive.org/work/ktlygagf6jhslhx42gpuwwzc44/access/wayback/http://acsel-lab.com/arithmetic/arith18/papers/ARITH18_Kobayashi.pdf
+
+    s = decode_reducing_polynomial()
+    m = degree(s)
+    assert a >> m == 0, "`a` is out-of-range"
+    r = a
+    v = 0
+    u = 1
+    delta = 0
+
+    for _ in range(2 * m):
+        # could use count-leading-zeros here to skip ahead
+        if r >> m == 0:  # if the MSB of `r` is zero
+            r <<= 1
+            u <<= 1
+            delta += 1
+        else:
+            if s >> m != 0:  # if the MSB of `s` isn't zero
+                s ^= r
+                v ^= u
+            s <<= 1
+            if delta == 0:
+                r, s = s, r  # swap r and s
+                u, v = v << 1, u  # shift v and swap
+                delta = 1
+            else:
+                u >>= 1
+                delta -= 1
+    if a == 0:
+        # we specifically choose 0 as the result of inverting 0, rather than an
+        # error or undefined, since that's what AES needs.
+        return 0
+    return u
diff --git a/gfbmadd.py b/gfbmadd.py
new file mode 100644 (file)
index 0000000..df826c4
--- /dev/null
@@ -0,0 +1,11 @@
+from .state import ST
+from .decode_reducing_polynomial import decode_reducing_polynomial
+from .clmul import clmul
+from .cldivrem import cldivrem
+
+
+def gfbmadd(a, b, c):
+    v = clmul(a, b) ^ c
+    red_poly = decode_reducing_polynomial()
+    q, r = cldivrem(v, red_poly, width=ST.XLEN + 1)
+    return r
diff --git a/gfbmul.py b/gfbmul.py
new file mode 100644 (file)
index 0000000..6295a29
--- /dev/null
+++ b/gfbmul.py
@@ -0,0 +1,11 @@
+from .state import ST
+from .decode_reducing_polynomial import decode_reducing_polynomial
+from .clmul import clmul
+from .cldivrem import cldivrem
+
+
+def gfbmul(a, b):
+    product = clmul(a, b)
+    red_poly = decode_reducing_polynomial()
+    q, r = cldivrem(product, red_poly, width=ST.XLEN + 1)
+    return r
diff --git a/gfbredpoly.py b/gfbredpoly.py
new file mode 100644 (file)
index 0000000..1812967
--- /dev/null
@@ -0,0 +1,6 @@
+from .state import ST
+
+
+def gfbredpoly(immed):
+    # TODO: figure out how `immed` should be encoded
+    ST.GFBREDPOLY = immed
diff --git a/log2.py b/log2.py
new file mode 100644 (file)
index 0000000..0defcc8
--- /dev/null
+++ b/log2.py
@@ -0,0 +1,12 @@
+def floor_log2(v):
+    """return floor(log2(v))."""
+    assert isinstance(v, int)
+    assert v > 0
+    return v.bit_length() - 1
+
+
+def ceil_log2(v):
+    """return ceil(log2(v))."""
+    assert isinstance(v, int)
+    assert v > 0
+    return (v - 1).bit_length()
diff --git a/state.py b/state.py
new file mode 100644 (file)
index 0000000..91cb542
--- /dev/null
+++ b/state.py
@@ -0,0 +1,19 @@
+from .log2 import floor_log2
+from threading import local
+
+
+class State(local):
+    # thread local so unit tests can be run in parallel without breaking
+    def __init__(self, *, XLEN=64, GFBREDPOLY=0, GFPRIME=31):
+        assert isinstance(XLEN, int) and 2 ** floor_log2(XLEN) == XLEN
+        assert isinstance(GFBREDPOLY, int) and 0 <= GFBREDPOLY < 2 ** 64
+        assert isinstance(GFPRIME, int) and 0 <= GFPRIME < 2 ** 64
+        self.XLEN = XLEN
+        self.GFBREDPOLY = GFBREDPOLY
+        self.GFPRIME = GFPRIME
+
+    def reinit(self, *, XLEN=64, GFBREDPOLY=0, GFPRIME=31):
+        self.__init__(XLEN=XLEN, GFBREDPOLY=GFBREDPOLY, GFPRIME=GFPRIME)
+
+
+ST = State()
index e954eb8c76a399c242cd12d8c866b23f810a304d..787deab2a54ce9579e52791d3bcd446088dad3cd 100644 (file)
@@ -1,5 +1,9 @@
+from .state import ST
 from .cldivrem import cldivrem
 from .clmul import clmul
+from .gfbmul import gfbmul
+from .gfbmadd import gfbmadd
+from .gfbinv import gfbinv
 from .pack_poly import pack_poly, unpack_poly
 import unittest
 
@@ -125,6 +129,24 @@ class GF2Poly:
         q, r = divmod(self, divisor)
         return r
 
+    def __pow__(self, exponent, modulus=None):
+        assert isinstance(exponent, int) and exponent >= 0
+        assert modulus is None or isinstance(modulus, GF2Poly)
+        retval = GF2Poly([1])
+        pow2 = GF2Poly(self)
+        while exponent != 0:
+            if exponent & 1:
+                retval *= pow2
+                if modulus is not None:
+                    retval %= modulus
+                exponent &= ~1
+            else:
+                pow2 *= pow2
+                if modulus is not None:
+                    pow2 %= modulus
+                exponent >>= 1
+        return retval
+
     def __eq__(self, rhs):
         if isinstance(rhs, GF2Poly):
             return self.coefficients == rhs.coefficients
@@ -206,6 +228,152 @@ class TestGF2Poly(unittest.TestCase):
         r = a % b
         self.assertEqual(r, GF2Poly([1, 0, 1, 0, 0, 1, 0, 1]))
 
+    def test_pow(self):
+        b = GF2Poly([0, 1])
+        for e in range(8):
+            expected = GF2Poly([0] * e + [1])
+            with self.subTest(b=str(b), e=e, expected=str(expected)):
+                v = b ** e
+                self.assertEqual(b, GF2Poly([0, 1]))
+                self.assertEqual(v, expected)
+
+        # AES's finite field reducing polynomial
+        m = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])
+        period = 2 ** m.degree - 1
+        b = GF2Poly([1, 1, 0, 0, 1, 0, 1])
+        e = period - 1
+        expected = GF2Poly([0, 1, 0, 1, 0, 0, 1, 1])
+        v = pow(b, e, m)
+        self.assertEqual(m, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(b, GF2Poly([1, 1, 0, 0, 1, 0, 1]))
+        self.assertEqual(v, expected)
+
+        # test that pow doesn't take inordinately long when given a modulus.
+        # adding a multiple of `period` should leave results unchanged.
+        e += period * 10 ** 15
+        v = pow(b, e, m)
+        self.assertEqual(m, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(b, GF2Poly([1, 1, 0, 0, 1, 0, 1]))
+        self.assertEqual(v, expected)
+
+
+class GFB:
+    def __init__(self, value, red_poly=None):
+        if isinstance(value, GFB):
+            # copy value
+            assert red_poly is None
+            self.red_poly = GF2Poly(value.red_poly)
+            self.value = GF2Poly(value.value)
+            return
+        assert isinstance(value, GF2Poly)
+        assert isinstance(red_poly, GF2Poly)
+        assert red_poly.degree > 0
+        self.value = value % red_poly
+        self.red_poly = red_poly
+
+    def __repr__(self):
+        return f"GFB({self.value}, {self.red_poly})"
+
+    def __add__(self, rhs):
+        assert isinstance(rhs, GFB)
+        assert self.red_poly == rhs.red_poly
+        return GFB((self.value + rhs.value) % self.red_poly, self.red_poly)
+
+    def __sub__(self, rhs):
+        return self.__add__(rhs)
+
+    def __eq__(self, rhs):
+        if isinstance(rhs, GFB):
+            return self.value == rhs.value and self.red_poly == rhs.red_poly
+        return NotImplemented
+
+    def __mul__(self, rhs):
+        assert isinstance(rhs, GFB)
+        assert self.red_poly == rhs.red_poly
+        return GFB((self.value * rhs.value) % self.red_poly, self.red_poly)
+
+    def __div__(self, rhs):
+        assert isinstance(rhs, GFB)
+        assert self.red_poly == rhs.red_poly
+        return self * rhs ** -1
+
+    @property
+    def __pow_period(self):
+        period = (1 << self.red_poly.degree) - 1
+        assert period > 0, "internal logic error"
+        return period
+
+    def __pow__(self, exponent):
+        assert isinstance(exponent, int)
+        if len(self.value) == 0:
+            if exponent < 0:
+                raise ZeroDivisionError
+            else:
+                return GFB(self)
+        exponent %= self.__pow_period
+        return GFB(pow(self.value, exponent, self.red_poly), self.red_poly)
+
+
+class TestGFBClass(unittest.TestCase):
+    def test_add(self):
+        # AES's finite field reducing polynomial
+        red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])
+        a = GFB(GF2Poly([0, 1, 0, 1]), red_poly)
+        b = GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly)
+        expected = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly)
+        c = a + b
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly))
+        self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly))
+        self.assertEqual(c, expected)
+        c = a - b
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly))
+        self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly))
+        self.assertEqual(c, expected)
+        c = b - a
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1]), red_poly))
+        self.assertEqual(b, GFB(GF2Poly([0, 0, 0, 0, 0, 0, 1, 1]), red_poly))
+        self.assertEqual(c, expected)
+
+    def test_mul(self):
+        # AES's finite field reducing polynomial
+        red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])
+        a = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly)
+        b = GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)
+        expected = GFB(GF2Poly([1]), red_poly)
+        c = a * b
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(a, GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly))
+        self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly))
+        self.assertEqual(c, expected)
+
+    def test_pow(self):
+        # AES's finite field reducing polynomial
+        red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])
+        period = 2 ** red_poly.degree - 1
+        b = GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly)
+        e = period - 1
+        expected = GFB(GF2Poly([0, 1, 0, 1, 0, 0, 1, 1]), red_poly)
+        v = b ** e
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly))
+        self.assertEqual(v, expected)
+        e = -1
+        v = b ** e
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly))
+        self.assertEqual(v, expected)
+
+        # test that pow doesn't take inordinately long when given a modulus.
+        # adding a multiple of `period` should leave results unchanged.
+        e += period * 10 ** 15
+        v = b ** e
+        self.assertEqual(red_poly, GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1]))
+        self.assertEqual(b, GFB(GF2Poly([1, 1, 0, 0, 1, 0, 1]), red_poly))
+        self.assertEqual(v, expected)
+
 
 class TestCL(unittest.TestCase):
     def test_cldivrem(self):
@@ -238,5 +406,62 @@ class TestCL(unittest.TestCase):
                     self.assertEqual(product, expected)
 
 
+class TestGFBInstructions(unittest.TestCase):
+    @staticmethod
+    def init_aes_red_poly():
+        # AES's finite field reducing polynomial
+        red_poly = GF2Poly([1, 1, 0, 1, 1, 0, 0, 0, 1])
+        ST.reinit(GFBREDPOLY=pack_poly(red_poly.coefficients))
+        return red_poly
+
+    def test_gfbmul(self):
+        # AES's finite field reducing polynomial
+        red_poly = self.init_aes_red_poly()
+        a_width = 8
+        b_width = 4
+        for av in range(2 ** a_width):
+            a = GFB(GF2Poly(unpack_poly(av)), red_poly)
+            for bv in range(2 ** b_width):
+                b = GFB(GF2Poly(unpack_poly(bv)), red_poly)
+                expected = a * b
+                with self.subTest(a=str(a), av=av, b=str(b), bv=bv, expected=str(expected)):
+                    product = gfbmul(av, bv)
+                    expectedv = pack_poly(expected.value.coefficients)
+                    self.assertEqual(product, expectedv)
+
+    def test_gfbmadd(self):
+        # AES's finite field reducing polynomial
+        red_poly = self.init_aes_red_poly()
+        a_width = 5
+        b_width = 4
+        c_width = 4
+        for av in range(2 ** a_width):
+            a = GFB(GF2Poly(unpack_poly(av)), red_poly)
+            for bv in range(2 ** b_width):
+                b = GFB(GF2Poly(unpack_poly(bv)), red_poly)
+                for cv in range(2 ** c_width):
+                    c = GFB(GF2Poly(unpack_poly(cv)), red_poly)
+                    expected = a * b + c
+                    with self.subTest(a=str(a), av=av,
+                                      b=str(b), bv=bv,
+                                      c=str(c), cv=cv,
+                                      expected=str(expected)):
+                        result = gfbmadd(av, bv, cv)
+                        expectedv = pack_poly(expected.value.coefficients)
+                        self.assertEqual(result, expectedv)
+
+    def test_gfbinv(self):
+        # AES's finite field reducing polynomial
+        red_poly = self.init_aes_red_poly()
+        width = 8
+        for av in range(2 ** width):
+            a = GFB(GF2Poly(unpack_poly(av)), red_poly)
+            expected = a ** -1 if av != 0 else GFB(GF2Poly(), red_poly)
+            with self.subTest(a=str(a), av=av, expected=str(expected)):
+                result = gfbinv(av)
+                expectedv = pack_poly(expected.value.coefficients)
+                self.assertEqual(result, expectedv)
+
+
 if __name__ == "__main__":
     unittest.main()