bug 1151: got python version of curve25519_mul working
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 26 Feb 2024 15:02:31 +0000 (15:02 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 26 Feb 2024 15:02:31 +0000 (15:02 +0000)
src/openpower/decoder/isa/ed25519/curve25519_mul.py
src/openpower/decoder/isa/ed25519/ed25519util.py

index 9da698f1ed772e464fdf7394f1402969feb673f2..4d0e483c4af42f18012977df16b1badb72d59649 100644 (file)
@@ -34,18 +34,22 @@ r1 +=   c;
 """
 
 import random
-from ed25519util import add128_64, lo128, shr128, reduce_mask_51, MASK64
+from ed25519util import (add128_64, lo128, shr128, 
+                         reduce_mask_51, MASK64, MASK128)
+from copy import deepcopy
 
 def curve25519_mul(r, s):
 
-    t = [0] * 5
+    t = [0] * 5 # all 128-bit
+    r = deepcopy(r)
+    s = deepcopy(s)
 
     for i in range(5):
         print("t%d += " % i, end='')
         for j in range(i+1):
             sidx = i-j
             print("r%d*s%d + " % (j, sidx), end='')
-            t[i] += (r[j] * s[sidx]) & MASK64
+            t[i] += (r[j] * s[sidx]) & MASK128
         print()
 
     for i in range(1,5):
@@ -58,7 +62,7 @@ def curve25519_mul(r, s):
         for j in range(i):
             jidx, sidx = 4-j, 5-(i-j)
             print("r%d*s%d + " % (jidx, sidx), end='')
-            t[tidx] += (r[jidx] * s[sidx]) & MASK64
+            t[tidx] += (r[jidx] * s[sidx]) & MASK128
         print()
 
     # this is the one where i *think* it possible to do some sort
@@ -66,9 +70,11 @@ def curve25519_mul(r, s):
 
     c = 0
     for i in range(5):
+        print("carry %d" % i, hex(c), hex(t[i]), end='')
         t[i] = add128_64(t[i], c)
         r[i] = lo128(t[i]) & reduce_mask_51
-        c = shr128(t[i], 51);
+        c = shr128(t[i], 51)
+        print()
 
     r[0] +=   c * 19; c = r[0] >> 51; r[0] = r[0] & reduce_mask_51;
     r[1] +=   c;
@@ -90,17 +96,7 @@ def expand(a): # put bignum into an array
         a >>= 51
     return res
 
-
-if __name__ == '__main__':
-    random.seed(2) # set the same seed (consistent test)
-    r, s = [0]*5, [0]*5
-    # dummy/obvious test
-    r[0] = (2<<53)-1
-    s[0] = 2<<60
-    for j in range(5):
-        #r[j] = random.randint(0, 1<<50)
-        #s[j] = random.randint(0, 1<<50)
-        pass
+def check(r, s):
     rb, sb = contract(r), contract(s)
     print ("r", list(map(hex,r)), hex(rb))
     print ("s", list(map(hex,s)), hex(sb))
@@ -109,5 +105,22 @@ if __name__ == '__main__':
     print ("t", list(map(hex,t)))
     print ("     ", hex(tb))
 
-    check = rb * sb % ((1<<255)-19)
+    check = (rb * sb) % ((1<<255)-19)
     print ("check", hex(check))
+    assert check == tb
+
+if __name__ == '__main__':
+    random.seed(2) # set the same seed (consistent test)
+    # dummy/obvious test
+    r, s = [0]*5, [0]*5
+    r[0] = (1<<51)-1
+    s[0] = 1<<51
+    s[1] = 1<<51
+    r[1] = 1<<51
+    check(r, s)
+    for i in range(100000):
+        r, s = [0]*5, [0]*5
+        for j in range(5):
+            r[j] = random.randint(0, 1<<50)
+            s[j] = random.randint(0, 1<<50)
+        check(r, s)
index d328758d7450bddf1830070de9165847b84f7487..3863f4f022ade2a70590aa16132ab55e34a0490a 100644 (file)
@@ -17,12 +17,12 @@ MASK128 = (1<<128)-1
 reduce_mask_51 = (1<<51)-1
 reduce_mask_40 = (1<<40)-1
 reduce_mask_56 = (1<<56)-1
-def mul64x64_128(a,b): return (a * b) & MASK128
+def mul64x64_128(a,b)       : return (a * b) & MASK128
 def shr128_pair(hi,lo,shift): return shr128((hi<<64)|lo, shift)
 def shl128_pair(hi,lo,shift): return shl128((hi<<64)|lo, shift)
-def shr128(a,shift): return lo128(a>>shift)
-def shl128(a,shift): return lo128((a<<shift)>>64)
-def add128(a,b): return (a + b) & MASK128
-def add128_64(a,b): return a + lo128(b)
-def lo128(a): return a & MASK64
-def hi128(a): return lo128(a>>64)
+def shr128(a,shift)         : return lo128(a>>shift)
+def shl128(a,shift)         : return lo128((a<<shift)>>64)
+def add128(a,b)             : return (a + b) & MASK128
+def add128_64(a,b)          : return (a + lo128(b)) & MASK128
+def lo128(a)                : return a & MASK64
+def hi128(a)                : return lo128(a>>64)