e48bb7fb56222d5de82fdf79b77278aa6bb4f1ca
[nmutil.git] / src / nmutil / nmoperator.py
1 """ nmigen operator functions / utils
2
3 This work is funded through NLnet under Grant 2019-02-012
4
5 License: LGPLv3+
6
7
8 eq:
9 --
10
11 a strategically very important function that is identical in function
12 to nmigen's Signal.eq function, except it may take objects, or a list
13 of objects, or a tuple of objects, and where objects may also be
14 Records.
15 """
16
17 from nmigen import Signal, Cat, Value
18 from nmigen.hdl.ast import ArrayProxy
19 from nmigen.hdl.rec import Record, Layout
20
21 from abc import ABCMeta, abstractmethod
22 from collections.abc import Sequence, Iterable
23 import inspect
24
25
class Visitor2:
    """ a helper class for iterating twin-argument compound data structures.

    Walks an output structure ``o`` and an input structure ``i`` in
    lock-step, yielding ``(output, input)`` leaf pairs suitable for
    ``output.eq(input)``.

    Record is a special (unusual, recursive) case, where the input may be
    specified as a dictionary (which may contain further dictionaries,
    recursively), where the field names of the dictionary must match
    the Record's field spec. Alternatively, an object with the same
    member names as the Record may be assigned: it does not have to
    *be* a Record.

    ArrayProxy is also special-cased, and it's a bit messy: whilst
    ArrayProxy has an eq function, the object being assigned to it
    (e.g. a python object) might not. The input having an eq function
    would not help anyway, because it is the *ArrayProxy* that is being
    assigned to. So we cheat: use the ports() function of the python
    object, enumerate them, find the list of Signals that way, and
    assign them one by one.
    """

    def iterator2(self, o, i):
        """ yield ``(output, input)`` leaf pairs from twin structures o, i.

        * dict ``o``: each value is paired with ``i[key]`` (see dict_iter2)
          and nothing else is done with it.
        * Sequence ``o``: elements are paired with those of ``i`` via zip,
          so the shorter of the two determines how many pairs are made.
        * anything else is treated as a one-element sequence.

        Each pair is expanded further when the output is a Record, or is an
        ArrayProxy assigned from a non-Value object; otherwise the pair is
        yielded unchanged.

        Raises AssertionError when an ArrayProxy *input* is assigned to a
        non-Value output (not supported yet).
        """
        if isinstance(o, dict):
            yield from self.dict_iter2(o, i)
            # the dict is fully handled: without this return it would fall
            # through below and be yielded as a (dict, i) leaf pair.
            return

        if not isinstance(o, Sequence):
            o, i = [o], [i]
        for (ao, ai) in zip(o, i):
            if isinstance(ao, Record):
                yield from self.record_iter2(ao, ai)
            elif isinstance(ao, ArrayProxy) and not isinstance(ai, Value):
                yield from self.arrayproxy_iter2(ao, ai)
            elif isinstance(ai, ArrayProxy) and not isinstance(ao, Value):
                # explicit raise instead of ``assert False``: asserts are
                # stripped under ``python -O``, which would silently carry
                # on into an unsupported conversion.
                raise AssertionError(
                    "whoops, input ArrayProxy not supported yet")
            else:
                yield (ao, ai)

    def dict_iter2(self, o, i):
        """ yield ``(value, i[key])`` for every ``key, value`` in dict o.

        ``i`` must support subscripting by every key of ``o``. Values are
        not expanded any further.
        """
        for (k, v) in o.items():
            yield (v, i[k])

    def _not_quite_working_with_all_unit_tests_record_iter2(self, ao, ai):
        # Experimental alternative to record_iter2. It is not called from
        # this class and, per its name, does not pass all unit tests.
        # NOTE(review): the inner check wraps ``ao`` only when it already
        # *is* a Sequence, the opposite of iterator2 -- possibly inverted.
        if isinstance(ai, Value):
            if isinstance(ao, Sequence):
                ao, ai = [ao], [ai]
            for o, i in zip(ao, ai):
                yield (o, i)
            return
        for idx, (field_name, field_shape, _) in enumerate(ao.layout):
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name):  # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name]  # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def record_iter2(self, ao, ai):
        """ expand Record ``ao`` field by field against the input ``ai``.

        For every field in ``ao.layout`` the matching input is read from
        ``ai`` by attribute if such an attribute exists, otherwise by
        subscript (dictionary-style). For nested-Layout fields the lookup
        is made on ``ai.fields`` instead of ``ai``. Each resulting
        (field, value) pair is recursively expanded via iterator2.

        NOTE(review): the nested-Layout path requires ``ai`` to have a
        ``.fields`` attribute, so a plain-dict input for a nested field
        would raise AttributeError -- confirm whether that is intended.
        """
        for field_name, field_shape, _ in ao.layout:
            if isinstance(field_shape, Layout):
                val = ai.fields
            else:
                val = ai
            if hasattr(val, field_name):  # check for attribute
                val = getattr(val, field_name)
            else:
                val = val[field_name]  # dictionary-style specification
            yield from self.iterator2(ao.fields[field_name], val)

    def arrayproxy_iter2(self, ao, ai):
        """ assign ArrayProxy ``ao`` from the non-Value python object ``ai``.

        Every port ``p`` in ``ai.ports()`` is paired with the attribute of
        ``ao`` named ``p.name``, and the pair is recursively expanded.
        """
        for p in ai.ports():
            op = getattr(ao, p.name)
            yield from self.iterator2(op, p)

    def arrayproxy_iter3(self, ao, ai):
        """ mirror image of arrayproxy_iter2: assign the non-Value python
        object ``ao`` from the ArrayProxy ``ai``.

        Every port ``p`` in ``ao.ports()`` is paired with the attribute of
        ``ai`` named ``p.name``, and the pair is recursively expanded.
        (Previously both halves of each pair were taken from ``ao`` and
        ``ai`` was ignored.)

        Not currently reachable from iterator2, which raises for this case.
        """
        for p in ao.ports():
            ip = getattr(ai, p.name)
            yield from self.iterator2(p, ip)
118
119
class Visitor:
    """ walker for single-argument compound data structures; the
    one-sided counterpart of Visitor2.

    Flattens a single object, a Sequence of objects, or a Record
    (recursively, field by field) into a stream of leaf objects.
    """

    def iterate(self, i):
        """ recursively yield the leaves of the compound structure i.

        A non-Sequence argument is treated as a one-element sequence.
        Records are expanded through record_iter; everything else is
        yielded unchanged.
        """
        items = i if isinstance(i, Sequence) else [i]
        for item in items:
            if isinstance(item, Record):
                yield from self.record_iter(item)
                continue
            # NOTE(review): this requires item to be an ArrayProxy that is
            # *not* a Value. If ArrayProxy is a Value subclass in the nmigen
            # version in use, this branch can never fire -- confirm intent.
            if isinstance(item, ArrayProxy) and not isinstance(item, Value):
                yield from self.array_iter(item)
                continue
            yield item

    def record_iter(self, ai):
        """ yield the leaves of every field of Record ai, in layout order.

        Nested-Layout fields are looked up on ``ai.fields``, other fields
        on ``ai`` itself: by attribute when one exists, otherwise by
        subscript.
        """
        for name, field_shape, _direction in ai.layout:
            holder = ai.fields if isinstance(field_shape, Layout) else ai
            if hasattr(holder, name):
                member = getattr(holder, name)  # attribute-style access
            else:
                member = holder[name]  # dictionary-style access
            yield from self.iterate(member)

    def array_iter(self, ai):
        """ yield the leaves of every port returned by ``ai.ports()``. """
        for port in ai.ports():
            yield from self.iterate(port)
156
157
def eq(o, i):
    """ build the statements that make o equal to i.

    o and i may be single objects, lists or tuples of objects, or Records,
    as handled by Visitor2.iterator2. ``dst.eq(src)`` is called for every
    (dst, src) leaf pair. Results that are Sequences are spliced into the
    output; any other result is appended as a single element.

    Returns a flat list of the collected results.
    """
    statements = []
    for dst, src in Visitor2().iterator2(o, i):
        result = dst.eq(src)
        if isinstance(result, Sequence):
            statements.extend(result)
        else:
            statements.append(result)
    return statements
170
171
def shape(i):
    """ return the combined ``(width, False)`` shape of the parts of i.

    Every part must have a ``shape()`` method returning exactly two items.
    The first item is the width; the second is ignored (presumably
    signedness, per nmigen's convention). The widths are summed, and the
    result is always reported with ``False`` as its second element.
    An empty i gives ``(0, False)``.
    """
    widths = [width for width, _signed in (part.shape() for part in list(i))]
    return sum(widths), False
180
181
def cat(i):
    """ flatten a compound structure recursively and concatenate it with Cat.

    i may be a single object, a Record, or a Sequence of these; leaves are
    collected in order by Visitor.iterate. Returns ``Cat(*leaves)``.
    """
    # nmigen's private helper nmigen._utils.flatten() also works here (as
    # of nmigen commit f22106e5). Visitor is used instead because the input
    # may be a sequence, and relying on a private module is best avoided.
    res = list(Visitor().iterate(i))
    return Cat(*res)