turn visitor into a class
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 11 Apr 2019 03:37:53 +0000 (04:37 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 11 Apr 2019 03:37:53 +0000 (04:37 +0100)
src/add/singlepipe.py

index ff945b1768cd9d5848807bcbca36244d86a060dc..3d4eefb7c2404ba3a254fb0e10c08538097b4a0c 100644 (file)
@@ -275,7 +275,7 @@ class NextControl:
                ]
 
 
-def visitor(o, i, fn):
+class Visitor:
     """ a helper routine which identifies if it is being passed a list
         (or tuple) of objects, or signals, or Records, and calls
         a visitor function.
@@ -297,51 +297,59 @@ def visitor(o, i, fn):
         python object, enumerate them, find out the list of Signals that way,
         and assign them.
     """
-    res = []
-    if isinstance(o, dict):
-        for (k, v) in o.items():
-            print ("d-eq", v, i[k])
-            res.append(fn(v, i[k]))
+    def visit(self, o, i, fn):
+        res = []
+        if isinstance(o, dict):
+            for (k, v) in o.items():
+                print ("d-eq", v, i[k])
+                res.append(fn(v, i[k]))
+            return res
+
+        if not isinstance(o, Sequence):
+            o, i = [o], [i]
+        for (ao, ai) in zip(o, i):
+            #print ("visit", fn, ao, ai)
+            if isinstance(ao, Record):
+                rres = []
+                for idx, (field_name, field_shape, _) in enumerate(ao.layout):
+                    if isinstance(field_shape, Layout):
+                        val = ai.fields
+                    else:
+                        val = ai
+                    if hasattr(val, field_name): # check for attribute
+                        val = getattr(val, field_name)
+                    else:
+                        val = val[field_name] # dictionary-style specification
+                    rres += self.visit(ao.fields[field_name], val, fn)
+            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
+                rres = []
+                for p in ai.ports():
+                    op = getattr(ao, p.name)
+                    #print (op, p, p.name)
+                    rres.append(fn(op, p))
+            else:
+                rres = fn(ao, ai)
+            if not isinstance(rres, Sequence):
+                rres = [rres]
+            res += rres
         return res
 
-    if not isinstance(o, Sequence):
-        o, i = [o], [i]
-    for (ao, ai) in zip(o, i):
-        #print ("visit", fn, ao, ai)
-        if isinstance(ao, Record):
-            rres = []
-            for idx, (field_name, field_shape, _) in enumerate(ao.layout):
-                if isinstance(field_shape, Layout):
-                    val = ai.fields
-                else:
-                    val = ai
-                if hasattr(val, field_name): # check for attribute
-                    val = getattr(val, field_name)
-                else:
-                    val = val[field_name] # dictionary-style specification
-                rres += visitor(ao.fields[field_name], val, fn)
-        elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
-            rres = []
-            for p in ai.ports():
-                op = getattr(ao, p.name)
-                #print (op, p, p.name)
-                rres.append(fn(op, p))
-        else:
-            rres = fn(ao, ai)
-        if not isinstance(rres, Sequence):
-            rres = [rres]
-        res += rres
-    return res
 
-def _eq_fn(o, i):
-    return o.eq(i)
+class Eq(Visitor):
+    def __init__(self):
+        self.res = []
+    def __call__(self, o, i):
+        def _eq_fn(o, i):
+            return o.eq(i)
+        res = self.visit(o, i, _eq_fn)
+        return res
 
 def eq(o, i):
     """ makes signals equal: a helper routine which identifies if it is being
         passed a list (or tuple) of objects, or signals, or Records, and calls
         the objects' eq function.
     """
-    return visitor(o, i, _eq_fn)
+    return Eq()(o, i)
 
 
 class StageCls(metaclass=ABCMeta):