b72940070122db877bb389a6908db9c6a1e9fbbb
[nmigen.git] / nmigen / hdl / rec.py
1 from enum import Enum
2 from collections import OrderedDict
3 from functools import reduce, wraps
4
5 from .. import tracer
6 from .._utils import union, deprecated
7 from .ast import *
8
9
__all__ = [
    "Direction",
    "DIR_NONE", "DIR_FANOUT", "DIR_FANIN",
    "Layout", "Record",
]


class Direction(Enum):
    """Direction of a record field.

    ``Record.connect`` drives ``FANOUT`` fields from the record into each
    subordinate and drives ``FANIN`` fields with the OR of the subordinates;
    non-nested ``NONE`` fields cannot be connected.
    """
    NONE = 1
    FANOUT = 2
    FANIN = 3


# Iterating an Enum yields its members in definition order.
DIR_NONE, DIR_FANOUT, DIR_FANIN = Direction
18
19
class Layout:
    """Ordered collection of named record fields.

    Each field is stored in ``self.fields`` as ``name -> (shape, direction)``,
    where ``shape`` is either a nested :class:`Layout` or a value accepted by
    ``Shape.cast``, and ``direction`` is a :class:`Direction` member.
    """

    @staticmethod
    def cast(obj, *, src_loc_at=0):
        """Return ``obj`` if it already is a Layout, otherwise build a Layout from it."""
        if isinstance(obj, Layout):
            return obj
        return Layout(obj, src_loc_at=1 + src_loc_at)

    def __init__(self, fields, *, src_loc_at=0):
        """Build a layout from an iterable of field tuples.

        Each element of ``fields`` must be ``(name, shape)`` or
        ``(name, shape, direction)``. In the 2-tuple form the direction is
        ``DIR_NONE``, and a ``list`` shape is turned into a nested Layout.

        Raises:
            TypeError: a field is not a 2- or 3-tuple, has a direction that is
                not a ``Direction``, has a non-string name, or has a shape that
                is neither a Layout nor castable to ``Shape``.
            NameError: a field name appears more than once.
        """
        self.fields = OrderedDict()
        for field in fields:
            if not isinstance(field, tuple) or len(field) not in (2, 3):
                raise TypeError("Field {!r} has invalid layout: should be either "
                                "(name, shape) or (name, shape, direction)"
                                .format(field))
            if len(field) == 2:
                name, shape = field
                direction = DIR_NONE
                if isinstance(shape, list):
                    # Nested record. Forward src_loc_at so that any warning
                    # raised while casting the nested shapes is attributed to
                    # the same user frame as warnings for top-level shapes.
                    shape = Layout.cast(shape, src_loc_at=1 + src_loc_at)
            else:
                name, shape, direction = field
                if not isinstance(direction, Direction):
                    raise TypeError("Field {!r} has invalid direction: should be a Direction "
                                    "instance like DIR_FANIN"
                                    .format(field))
            if not isinstance(name, str):
                raise TypeError("Field {!r} has invalid name: should be a string"
                                .format(field))
            if not isinstance(shape, Layout):
                try:
                    # Validate the shape by casting it; the result is discarded
                    # and the original shape object is what gets stored.
                    Shape.cast(shape, src_loc_at=1 + src_loc_at)
                except Exception as error:
                    raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
                                    "or a list of fields of a nested record"
                                    .format(field)) from error
            if name in self.fields:
                raise NameError("Field {!r} has a name that is already present in the layout"
                                .format(field))
            self.fields[name] = (shape, direction)

    def __getitem__(self, item):
        """Look up one field, or select several.

        With a tuple of names, return a new Layout holding only those fields,
        in this layout's order (names not present are silently ignored).
        Otherwise return the ``(shape, direction)`` pair stored for ``item``
        (``KeyError`` if absent).
        """
        if isinstance(item, tuple):
            return Layout([
                (name, shape, dir)
                for (name, (shape, dir)) in self.fields.items()
                if name in item
            ])

        return self.fields[item]

    def __iter__(self):
        """Yield ``(name, shape, direction)`` triples in declaration order."""
        for name, (shape, dir) in self.fields.items():
            yield (name, shape, dir)

    def __eq__(self, other):
        # Defer to the other operand for non-Layouts instead of failing on a
        # missing ``.fields`` attribute; ``==`` then falls back to identity.
        if not isinstance(other, Layout):
            return NotImplemented
        return self.fields == other.fields

    def __repr__(self):
        field_reprs = []
        for name, shape, dir in self:
            if dir == DIR_NONE:
                field_reprs.append("({!r}, {!r})".format(name, shape))
            else:
                field_reprs.append("({!r}, {!r}, Direction.{})".format(name, shape, dir.name))
        return "Layout([{}])".format(", ".join(field_reprs))
86
87
class Record(ValueCastable):
    """A bundle of named Signals and nested Records described by a Layout.

    Fields are reachable as attributes (``rec.foo``) or by subscript
    (``rec["foo"]``). Used as a value, a record lowers to the concatenation of
    its fields in layout order (see ``as_value``).
    """

    @staticmethod
    def like(other, *, name=None, name_suffix=None, src_loc_at=0):
        """Create a new record with the same layout and field shapes as ``other``.

        The new name is ``name`` if given, else ``other.name + name_suffix`` if
        ``name_suffix`` is given, else it is inferred from the assignment
        target at the call site (or ``None``). Each field is created with
        ``Signal.like`` / ``Record.like`` and named ``<new name>__<field>``.
        """
        if name is not None:
            new_name = str(name)
        elif name_suffix is not None:
            # NOTE(review): raises TypeError if other.name is None.
            new_name = other.name + str(name_suffix)
        else:
            new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        fields = {}
        for field_name, field in other.fields.items():
            if isinstance(field, Record):
                fields[field_name] = Record.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)
            else:
                fields[field_name] = Signal.like(field, name=concat(new_name, field_name),
                                                 src_loc_at=1 + src_loc_at)

        # Honour the caller's src_loc_at for the record itself, exactly as is
        # done for its fields above.
        return Record(other.layout, name=new_name, fields=fields,
                      src_loc_at=1 + src_loc_at)

    def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
        """Create a record and its fields.

        Parameters:
            layout: a Layout, or anything accepted by ``Layout.cast``.
            name: record name; if None it is inferred from the assignment
                target at the call site (it may remain None).
            fields: optional mapping of field name to an existing Signal or
                Record to reuse instead of creating a new one; each reused
                object must match the layout's shape for that field.
            src_loc_at: forwarded to the tracer to choose which caller frame
                is recorded as ``src_loc``.

        Newly created fields are named ``<name>__<field>``, or just
        ``<field>`` when the record is unnamed.
        """
        if name is None:
            name = tracer.get_var_name(depth=2 + src_loc_at, default=None)

        self.name = name
        self.src_loc = tracer.get_src_loc(src_loc_at)

        def concat(a, b):
            if a is None:
                return b
            return "{}__{}".format(a, b)

        self.layout = Layout.cast(layout, src_loc_at=1 + src_loc_at)
        self.fields = OrderedDict()
        for field_name, field_shape, _ in self.layout:
            if fields is not None and field_name in fields:
                field = fields[field_name]
                if isinstance(field_shape, Layout):
                    assert isinstance(field, Record) and field_shape == field.layout, \
                        "Field {!r} must be a Record with layout {!r}, not {!r}" \
                        .format(field_name, field_shape, field)
                else:
                    assert isinstance(field, Signal) and Shape.cast(field_shape) == field.shape(), \
                        "Field {!r} must be a Signal with shape {!r}, not {!r}" \
                        .format(field_name, field_shape, field)
                self.fields[field_name] = field
            else:
                if isinstance(field_shape, Layout):
                    self.fields[field_name] = Record(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)
                else:
                    self.fields[field_name] = Signal(field_shape, name=concat(name, field_name),
                                                     src_loc_at=1 + src_loc_at)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If ``fields`` itself is
        # missing (e.g. an instance created via __new__ by copy/pickle),
        # ``self[name]`` would re-enter __getattr__ forever; fail normally.
        if name == "fields":
            raise AttributeError("'{}' object has no attribute '{}'"
                                 .format(type(self).__name__, name))
        return self[name]

    def __getitem__(self, item):
        """Index the record.

        - ``str``: return that field (AttributeError listing the valid names
          if there is no such field).
        - ``tuple`` of names: return a new Record over the selected fields,
          sharing the existing field objects.
        - anything else (int, slice): index the record as a flat bit vector
          via ``Value.__getitem__``.
        """
        if isinstance(item, str):
            try:
                return self.fields[item]
            except KeyError:
                if self.name is None:
                    reference = "Unnamed record"
                else:
                    reference = "Record '{}'".format(self.name)
                raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?"
                                     .format(reference, item, ", ".join(self.fields))) from None
        elif isinstance(item, tuple):
            return Record(self.layout[item], fields={
                field_name: field_value
                for field_name, field_value in self.fields.items()
                if field_name in item
            })
        else:
            # Value.__getitem__ signals bad indices with IndexError/TypeError,
            # which are propagated unchanged.
            return Value.__getitem__(self, item)

    @ValueCastable.lowermethod
    def as_value(self):
        """Lower the record to the concatenation of its fields, in layout order."""
        return Cat(self.fields.values())

    def __len__(self):
        """Return the length of the lowered value (see ``as_value``)."""
        return len(self.as_value())

    def _lhs_signals(self):
        return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())

    def _rhs_signals(self):
        return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet())

    def __repr__(self):
        fields = []
        for field_name, field in self.fields.items():
            if isinstance(field, Signal):
                fields.append(field_name)
            else:
                fields.append(repr(field))
        name = self.name
        if name is None:
            name = "<unnamed>"
        return "(rec {} {})".format(name, " ".join(fields))

    def connect(self, *subordinates, include=None, exclude=None):
        """Return a list of statements connecting this record to ``subordinates``.

        For each selected field:
        - ``DIR_FANOUT``: every subordinate's field is assigned from this
          record's field.
        - ``DIR_FANIN``: this record's field is assigned the bitwise OR of the
          subordinates' fields (nothing is emitted if there are none).
        - nested layout: connected recursively, using ``include[field]`` /
          ``exclude[field]`` as the nested selection. NOTE(review): nested
          selection therefore requires ``include``/``exclude`` to be mappings.

        Parameters:
            include: if not None, only these field names are connected.
            exclude: if not None, these field names are skipped.

        Raises:
            AttributeError: a name in ``include``/``exclude`` is not a field of
                this record, or a subordinate lacks a connected field.
            TypeError: a selected non-nested field has direction ``DIR_NONE``.
        """
        def rec_name(record):
            if record.name is None:
                return "unnamed record"
            else:
                return "record '{}'".format(record.name)

        for field in include or {}:
            if field not in self.fields:
                raise AttributeError("Cannot include field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))
        for field in exclude or {}:
            if field not in self.fields:
                raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
                                     .format(field, rec_name(self)))

        stmts = []
        for field in self.fields:
            if include is not None and field not in include:
                continue
            if exclude is not None and field in exclude:
                continue

            shape, direction = self.layout[field]
            if not isinstance(shape, Layout) and direction == DIR_NONE:
                raise TypeError("Cannot connect field '{}' of {} because it does not have "
                                "a direction"
                                .format(field, rec_name(self)))

            item = self.fields[field]
            subord_items = []
            for subord in subordinates:
                if field not in subord.fields:
                    raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
                                         "because the subordinate record does not have this field"
                                         .format(field, rec_name(self), rec_name(subord)))
                subord_items.append(subord.fields[field])

            if isinstance(shape, Layout):
                sub_include = include[field] if include and field in include else None
                sub_exclude = exclude[field] if exclude and field in exclude else None
                stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude)
            elif direction == DIR_FANOUT:
                stmts += [sub_item.eq(item) for sub_item in subord_items]
            elif direction == DIR_FANIN:
                # reduce() raises on an empty sequence; with no subordinates
                # there is nothing to OR together, so emit nothing (as the
                # fan-out case does).
                if subord_items:
                    stmts.append(item.eq(reduce(lambda a, b: a | b, subord_items)))

        return stmts
250
def _valueproxy(name):
    """Return a Record method that forwards to ``Value.<name>``.

    The record is first lowered with ``Value.cast``; the Value method then
    does the actual work on the lowered value.
    """
    target = getattr(Value, name)

    def forward(self, *args, **kwargs):
        lowered = Value.cast(self)
        return target(lowered, *args, **kwargs)

    # Copy __name__, __doc__, __wrapped__, ... from the Value method.
    return wraps(target)(forward)


# Operators and Value methods Record exposes by delegation. Note that this
# also replaces Record's own __len__ with the forwarding version.
_PROXIED_NAMES = (
    "__bool__",
    "__invert__", "__neg__",
    "__add__", "__radd__", "__sub__", "__rsub__",
    "__mul__", "__rmul__",
    "__mod__", "__rmod__", "__floordiv__", "__rfloordiv__",
    "__lshift__", "__rlshift__", "__rshift__", "__rrshift__",
    "__and__", "__rand__", "__xor__", "__rxor__", "__or__", "__ror__",
    "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__",
    "__abs__", "__len__",
    "as_unsigned", "as_signed", "bool", "any", "all", "xor", "implies",
    "bit_select", "word_select", "matches",
    "shift_left", "shift_right", "rotate_left", "rotate_right", "eq",
)

for _proxied in _PROXIED_NAMES:
    setattr(Record, _proxied, _valueproxy(_proxied))

# Leave no helper names behind in the module namespace.
del _valueproxy, _PROXIED_NAMES, _proxied