speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / util.py
index ceb9a710df8ac7a0f42f319bae351dcb155fe0cd..5864d5edb92b8db47c752280c1fb3ae83145c698 100644 (file)
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: LGPL-3-or-later
 """
     This work is funded through NLnet under Grant 2019-02-012
 
@@ -8,6 +9,7 @@
 from collections.abc import Iterable
 from nmigen import Mux, Signal, Cat
 
+
 # XXX this already exists in nmigen._utils
 # see https://bugs.libre-soc.org/show_bug.cgi?id=297
 def flatten(v):
@@ -17,25 +19,36 @@ def flatten(v):
     else:
         yield v
 
+
 # tree reduction function.  operates recursively.
-def treereduce(tree, op, fn):
-    """treereduce: apply a map-reduce to a list.
+def treereduce(tree, op, fn=None):
+    """treereduce: apply a map-reduce to a list, reducing to a single item
+
+    this is *not* the same as "x = Signal(64) reduce(x, operator.add)",
+    which is a bit-wise reduction down to a single bit
+
+    it is "l = [Signal(w), ..., Signal(w)] reduce(l, operator.add)"
+    i.e. l[0] + l[1] ...
+
     examples: OR-reduction of one member of a list of Records down to a
-              single data point:
-              treereduce(tree, operator.or_, lambda x: getattr(x, "data_o"))
+              single value:
+              treereduce(tree, operator.or_, lambda x: getattr(x, "o_data"))
     """
-    #print ("treereduce", tree)
+    if fn is None:
+        def fn(x): return x
     if not isinstance(tree, list):
         return tree
     if len(tree) == 1:
         return fn(tree[0])
     if len(tree) == 2:
         return op(fn(tree[0]), fn(tree[1]))
-    s = len(tree) // 2 # splitpoint
+    s = len(tree) // 2  # splitpoint
     return op(treereduce(tree[:s], op, fn),
               treereduce(tree[s:], op, fn))
 
 # chooses assignment of 32 bit or full 64 bit depending on is_32bit
+
+
 def eq32(is_32bit, dest, src):
     return [dest[0:32].eq(src[0:32]),
             dest[32:64].eq(Mux(is_32bit, 0, src[32:64]))]
@@ -61,8 +74,8 @@ def rising_edge(m, sig):
     rising = Signal.like(sig)
     delay.name = "%s_dly" % sig.name
     rising.name = "%s_rise" % sig.name
-    m.d.sync += delay.eq(sig) # 1 clock delay
-    m.d.comb += rising.eq(sig & ~delay) # sig is hi but delay-sig is lo
+    m.d.sync += delay.eq(sig)  # 1 clock delay
+    m.d.comb += rising.eq(sig & ~delay)  # sig is hi but delay-sig is lo
     return rising