add some debug output to Visitor2 (commented out)
[nmutil.git] / src / nmutil / nmoperator.py
1 """ nmigen operator functions / utils
2
3 This work is funded through NLnet under Grant 2019-02-012
4
5 License: LGPLv3+
6
7
8 eq:
9 --
10
11 a strategically very important function that is identical in function
12 to nmigen's Signal.eq function, except it may take objects, or a list
13 of objects, or a tuple of objects, and where objects may also be
14 Records.
15 """
16
17 from nmigen import Signal, Cat, Value
18 from nmigen.hdl.ast import ArrayProxy
19 from nmigen.hdl.rec import Record, Layout
20
21 from abc import ABCMeta, abstractmethod
22 from collections.abc import Sequence, Iterable
23 import inspect
24
25
class Visitor2:
    """ a helper class for iterating twin-argument compound data structures.

    iterator2(o, i) yields ``(output, input)`` leaf pairs; the module-level
    eq() helper calls ``output.eq(input)`` on each of them.

    Record is a special (unusual, recursive) case, where the input may be
    specified as a dictionary (which may contain further dictionaries,
    recursively), where the field names of the dictionary must match
    the Record's field spec.  Alternatively, an object with the same
    member names as the Record may be assigned: it does not have to
    *be* a Record.

    ArrayProxy is also special-cased, and it's a bit messy: whilst
    ArrayProxy has an eq function, the object being assigned to it (e.g.
    a python object) might not.  Even if the *input* has an eq function,
    that doesn't help, because it's the *ArrayProxy* that's being
    assigned to.  So... we cheat: use the ports() function of the python
    object, enumerate them, find out the list of Signals that way, and
    assign them one at a time.  An ArrayProxy on the *input* side is not
    supported yet and raises AssertionError.
    """

    def iterator2(self, o, i):
        """ yield ``(output, input)`` leaf pairs from twin structures.

        * ``o`` a dict: each value is paired with ``i[key]`` (see
          dict_iter2) and nothing else is yielded.
        * ``o`` a Sequence: ``o`` and ``i`` are walked in lockstep, one
          level deep, using zip (so the shorter of the two wins).
        * anything else: treated as a one-element sequence.

        Each element pair is then dispatched: a Record output is expanded
        field-by-field, an ArrayProxy output with a non-Value input goes
        through the input's ports(), everything else is yielded as-is.

        Raises AssertionError when the input is an ArrayProxy and the
        output is not a Value (not supported yet).
        """
        if isinstance(o, dict):
            yield from self.dict_iter2(o, i)
            # the dict has been fully handled: falling through would wrap
            # it and yield the dict object itself as a (bogus) leaf pair.
            return

        if not isinstance(o, Sequence):
            o, i = [o], [i]
        for (ao, ai) in zip(o, i):
            # print ("visit", ao, ai)
            # print ("    isinstance Record(ao)", isinstance(ao, Record))
            # print ("    isinstance ArrayProxy(ao)",
            #            isinstance(ao, ArrayProxy))
            # print ("    isinstance Value(ai)",
            #            isinstance(ai, Value))
            if isinstance(ao, Record):
                yield from self.record_iter2(ao, ai)
            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                yield from self.arrayproxy_iter2(ao, ai)
            elif isinstance(ai, ArrayProxy) and not isinstance(ao, Value):
                # explicit raise rather than ``assert False``: asserts are
                # stripped under ``python -O``, which would let execution
                # continue past this unsupported case.
                raise AssertionError(
                    "whoops, input ArrayProxy not supported yet")
            else:
                yield (ao, ai)

    def dict_iter2(self, o, i):
        """ pair each value of dict ``o`` with ``i[key]``, in ``o``'s order.

        ``i`` must support ``[key]`` for every key of ``o``; extra keys in
        ``i`` are ignored.  Values are yielded as-is (no further recursion).
        """
        for (k, v) in o.items():
            # print ("d-iter", v, i[k])
            yield (v, i[k])

    def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
        # experimental alternative to record_iter2 that also accepts a
        # plain Value as the input.  It is not called from this class.
        # NOTE(review): the Sequence test below looks inverted compared
        # with iterator2 (which wraps when the object is *not* a Sequence).
        # print ("record_iter2", ao, ai, type(ao), type(ai))
        if isinstance(ai, Value):
            if isinstance(ao, Sequence):
                ao, ai = [ao], [ai]
            for o, i in zip(ao, ai):
                yield (o, i)
            return
        for idx, (field_name, field_shape, _) in enumerate(ao.layout):
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name): # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name] # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def record_iter2(self, ao, ai):
        """ walk the fields of Record ``ao`` in layout order, pairing each
        with the same-named member of ``ai`` and recursing via iterator2.

        ``ai`` may be a Record, any object with attributes named after the
        Record's fields, or a dict keyed by field name (dicts may nest for
        sub-Layouts).  For a sub-Layout field, the lookup goes through
        ``ai.fields`` when ``ai`` has a ``fields`` attribute.
        """
        for (field_name, field_shape, _) in ao.layout:
            container = ai
            if isinstance(field_shape, Layout) and hasattr(ai, "fields"):
                container = ai.fields
            val = self._lookup_field(container, field_name)
            yield from self.iterator2(ao.fields[field_name], val)

    @staticmethod
    def _lookup_field(val, field_name):
        # fetch ``field_name`` from ``val``: dict entry, else attribute,
        # else item.  Dicts are checked first because ``hasattr`` on a dict
        # would find methods such as ``items``/``keys`` instead of entries.
        if isinstance(val, dict):
            return val[field_name]             # dictionary-style spec
        if hasattr(val, field_name):           # attribute-style spec
            return getattr(val, field_name)
        return val[field_name]

    def arrayproxy_iter2(self, ao, ai):
        """ assign to ArrayProxy ``ao`` from a python object ``ai``.

        Each port returned by ``ai.ports()`` (presumably Signals) is paired
        with the attribute of ``ao`` named after that port's ``name``.
        """
        # print ("arrayproxy_iter2", ai.ports(), ai, ao)
        for p in ai.ports():
            # print ("arrayproxy - p", p, p.name, ao)
            op = getattr(ao, p.name)
            yield from self.iterator2(op, p)

    def arrayproxy_iter3(self, ao, ai):
        """ mirror image of arrayproxy_iter2, for an ArrayProxy *input*.

        Each port returned by ``ao.ports()`` is paired with the attribute
        of ArrayProxy ``ai`` named after that port's ``name``.  Not
        currently reachable from iterator2, which rejects ArrayProxy inputs.
        """
        # print ("arrayproxy_iter3", ao.ports(), ai, ao)
        for p in ao.ports():
            # print ("arrayproxy - p", p, p.name, ai)
            ip = getattr(ai, p.name)
            yield from self.iterator2(p, ip)
117
118
class Visitor:
    """ walks a single compound data structure, yielding its leaves.

    The single-argument counterpart of Visitor2: a Sequence argument is
    walked element by element (one level), Records are expanded field by
    field (recursively), and anything else is yielded unchanged.
    """

    def iterate(self, i):
        """ generator yielding the leaf objects of ``i``.
        """
        items = i if isinstance(i, Sequence) else [i]
        for item in items:
            #print ("iterate", item)
            if isinstance(item, Record):
                #print ("record", list(item.layout))
                yield from self.record_iter(item)
                continue
            # NOTE(review): if ArrayProxy is a Value subclass (as it appears
            # to be in nmigen) this condition can never hold and array_iter
            # is never reached from here -- confirm intent.
            if isinstance(item, ArrayProxy) and not isinstance(item, Value):
                yield from self.array_iter(item)
                continue
            yield item

    def record_iter(self, ai):
        """ yield the leaves of every field of Record ``ai``, in layout order.
        """
        for field_name, field_shape, _ in ai.layout:
            # sub-Layout fields are looked up through the Record's fields
            container = ai.fields if isinstance(field_shape, Layout) else ai
            if hasattr(container, field_name):
                member = getattr(container, field_name)  # attribute access
            else:
                member = container[field_name]  # dictionary-style access
            #print ("recidx", field_name, field_shape, member)
            yield from self.iterate(member)

    def array_iter(self, ai):
        """ yield the leaves of each port returned by ``ai.ports()``.
        """
        for port in ai.ports():
            yield from self.iterate(port)
154
155
def eq(o, i):
    """ assign ``i`` to ``o``, where either may be a compound structure.

    Works like nmigen's Signal.eq, but ``o``/``i`` may also be lists or
    tuples of objects, Records, or dicts; Visitor2 pairs up the leaves and
    ``eq`` is called on each output leaf.

    Returns one flat list of everything those ``eq`` calls produced: a
    result that is itself a Sequence is spliced in element by element,
    anything else is appended as a single item.
    """
    statements = []
    for dest, src in Visitor2().iterator2(o, i):
        produced = dest.eq(src)
        if isinstance(produced, Sequence):
            statements.extend(produced)
        else:
            statements.append(produced)
    return statements
168
169
def shape(i):
    """ total width of the parts of ``i``, returned as ``(width, False)``.

    Each part's ``shape()`` is unpacked as a ``(width, signed)`` pair; the
    widths are summed and the per-part signedness is discarded, so the
    result is always reported as unsigned.
    """
    def _widths():
        for part in list(i):
            width, _signed = part.shape()
            yield width

    return sum(_widths()), False
178
179
def cat(i):
    """ flattens a compound structure recursively using Cat

    ``i`` may be a single object, a Record, or a sequence of them.  Its
    leaves are collected with Visitor().iterate (Records are expanded
    field by field) and concatenated, in that order, with ``Cat``.
    """
    # nmigen's own ``flatten`` helper (nmigen._utils) also works as of
    # nmigen commit f22106e5, but Visitor is used because the input may be
    # a sequence.  The private nmigen._utils module is therefore not
    # imported: depending on it would make cat() fail on nmigen versions
    # that do not ship that module.
    res = list(Visitor().iterate(i))
    return Cat(*res)
187
188