from enum import Enum
from collections import OrderedDict
+from functools import reduce
from .. import tracer
from ..tools import union
if name is None:
name = "<unnamed>"
return "(rec {} {})".format(name, " ".join(fields))
+
+ def connect(self, *subordinates, include=None, exclude=None):
+ def rec_name(record):
+ if record.name is None:
+ return "unnamed record"
+ else:
+ return "record '{}'".format(record.name)
+
+ for field in include or {}:
+ if field not in self.fields:
+ raise AttributeError("Cannot include field '{}' because it is not present in {}"
+ .format(field, rec_name(self)))
+ for field in exclude or {}:
+ if field not in self.fields:
+ raise AttributeError("Cannot exclude field '{}' because it is not present in {}"
+ .format(field, rec_name(self)))
+
+ stmts = []
+ for field in self.fields:
+ if include is not None and field not in include:
+ continue
+ if exclude is not None and field in exclude:
+ continue
+
+ shape, direction = self.layout[field]
+ if not isinstance(shape, Layout) and direction == DIR_NONE:
+ raise TypeError("Cannot connect field '{}' of {} because it does not have "
+ "a direction"
+ .format(field, rec_name(self)))
+
+ item = self.fields[field]
+ subord_items = []
+ for subord in subordinates:
+ if field not in subord.fields:
+ raise AttributeError("Cannot connect field '{}' of {} to subordinate {} "
+ "because the subordinate record does not have this field"
+ .format(field, rec_name(self), rec_name(subord)))
+ subord_items.append(subord.fields[field])
+
+ if isinstance(shape, Layout):
+ sub_include = include[field] if include and field in include else None
+ sub_exclude = exclude[field] if exclude and field in exclude else None
+ stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude)
+ else:
+ if direction == DIR_FANOUT:
+ stmts += [sub_item.eq(item) for sub_item in subord_items]
+ if direction == DIR_FANIN:
+ stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]
+
+ return stmts
with self.assertRaises(NameError,
msg="Unnamed record does not have a field 'en'. Did you mean one of: stb, ack?"):
r.en
+
+
+class ConnectTestCase(FHDLTestCase):
+ def setUp_flat(self):
+ self.core_layout = [
+ ("addr", 32, DIR_FANOUT),
+ ("data_r", 32, DIR_FANIN),
+ ("data_w", 32, DIR_FANIN),
+ ]
+ self.periph_layout = [
+ ("addr", 32, DIR_FANOUT),
+ ("data_r", 32, DIR_FANIN),
+ ("data_w", 32, DIR_FANIN),
+ ]
+
+ def setUp_nested(self):
+ self.core_layout = [
+ ("addr", 32, DIR_FANOUT),
+ ("data", [
+ ("r", 32, DIR_FANIN),
+ ("w", 32, DIR_FANIN),
+ ]),
+ ]
+ self.periph_layout = [
+ ("addr", 32, DIR_FANOUT),
+ ("data", [
+ ("r", 32, DIR_FANIN),
+ ("w", 32, DIR_FANIN),
+ ]),
+ ]
+
+ def test_flat(self):
+ self.setUp_flat()
+
+ core = Record(self.core_layout)
+ periph1 = Record(self.periph_layout)
+ periph2 = Record(self.periph_layout)
+
+ stmts = core.connect(periph1, periph2)
+ self.assertRepr(stmts, """(
+ (eq (sig periph1__addr) (sig core__addr))
+ (eq (sig periph2__addr) (sig core__addr))
+ (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
+ (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
+ )""")
+
+ def test_flat_include(self):
+ self.setUp_flat()
+
+ core = Record(self.core_layout)
+ periph1 = Record(self.periph_layout)
+ periph2 = Record(self.periph_layout)
+
+ stmts = core.connect(periph1, periph2, include={"addr": True})
+ self.assertRepr(stmts, """(
+ (eq (sig periph1__addr) (sig core__addr))
+ (eq (sig periph2__addr) (sig core__addr))
+ )""")
+
+ def test_flat_exclude(self):
+ self.setUp_flat()
+
+ core = Record(self.core_layout)
+ periph1 = Record(self.periph_layout)
+ periph2 = Record(self.periph_layout)
+
+ stmts = core.connect(periph1, periph2, exclude={"addr": True})
+ self.assertRepr(stmts, """(
+ (eq (sig core__data_r) (| (sig periph1__data_r) (sig periph2__data_r)))
+ (eq (sig core__data_w) (| (sig periph1__data_w) (sig periph2__data_w)))
+ )""")
+
+ def test_nested(self):
+ self.setUp_nested()
+
+ core = Record(self.core_layout)
+ periph1 = Record(self.periph_layout)
+ periph2 = Record(self.periph_layout)
+
+ stmts = core.connect(periph1, periph2)
+ self.maxDiff = None
+ self.assertRepr(stmts, """(
+ (eq (sig periph1__addr) (sig core__addr))
+ (eq (sig periph2__addr) (sig core__addr))
+ (eq (sig core__data__r) (| (sig periph1__data__r) (sig periph2__data__r)))
+ (eq (sig core__data__w) (| (sig periph1__data__w) (sig periph2__data__w)))
+ )""")
+
+ def test_wrong_include_exclude(self):
+ self.setUp_flat()
+
+ core = Record(self.core_layout)
+ periph = Record(self.periph_layout)
+
+ with self.assertRaises(AttributeError,
+ msg="Cannot include field 'foo' because it is not present in record 'core'"):
+ core.connect(periph, include={"foo": True})
+
+ with self.assertRaises(AttributeError,
+ msg="Cannot exclude field 'foo' because it is not present in record 'core'"):
+ core.connect(periph, exclude={"foo": True})
+
+ def test_wrong_direction(self):
+ recs = [Record([("x", 1)]) for _ in range(2)]
+
+ with self.assertRaises(TypeError,
+ msg="Cannot connect field 'x' of unnamed record because it does not have "
+ "a direction"):
+ recs[0].connect(recs[1])
+
+ def test_wrong_missing_field(self):
+ core = Record([("addr", 32, DIR_FANOUT)])
+ periph = Record([])
+
+ with self.assertRaises(AttributeError,
+ msg="Cannot connect field 'addr' of record 'core' to subordinate record 'periph' "
+ "because the subordinate record does not have this field"):
+ core.connect(periph)