From 3ba53bdab4714334a6138d2cc24b1a66f82eedbf Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 16 Dec 2018 13:30:20 +0000 Subject: [PATCH] back.rtlil: reorganize value compiler into LHS/RHS. This also implements Cat on LHS. --- nmigen/back/rtlil.py | 258 +++++++++++++++++++++++-------------------- 1 file changed, 140 insertions(+), 118 deletions(-) diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 9049066..5641322 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -195,37 +195,12 @@ def src(src_loc): return "{}:{}".format(file, line) -class _ValueTransformer(xfrm.AbstractValueTransformer): - operator_map = { - (1, "~"): "$not", - (1, "-"): "$neg", - (1, "b"): "$reduce_bool", - (2, "+"): "$add", - (2, "-"): "$sub", - (2, "*"): "$mul", - (2, "/"): "$div", - (2, "%"): "$mod", - (2, "**"): "$pow", - (2, "<<"): "$sshl", - (2, ">>"): "$sshr", - (2, "&"): "$and", - (2, "^"): "$xor", - (2, "|"): "$or", - (2, "=="): "$eq", - (2, "!="): "$ne", - (2, "<"): "$lt", - (2, "<="): "$le", - (2, ">"): "$gt", - (2, ">="): "$ge", - (3, "m"): "$mux", - } - +class _ValueCompilerState: def __init__(self, rtlil): - self.rtlil = rtlil - self.wires = ast.ValueDict() - self.driven = ast.ValueDict() - self.ports = ast.ValueDict() - self.is_lhs = False + self.rtlil = rtlil + self.wires = ast.ValueDict() + self.driven = ast.ValueDict() + self.ports = ast.ValueDict() self.sub_name = None def add_driven(self, signal, sync): @@ -241,13 +216,36 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): kind = "inout" self.ports[signal] = (len(self.ports), kind) - @contextmanager - def lhs(self): - try: - self.is_lhs = True - yield - finally: - self.is_lhs = False + def resolve(self, signal): + if signal in self.wires: + return self.wires[signal] + + if signal in self.ports: + port_id, port_kind = self.ports[signal] + else: + port_id = port_kind = None + if self.sub_name: + wire_name = "{}_{}".format(self.sub_name, signal.name) + else: + wire_name = signal.name + + for attr_name, attr_signal in signal.attrs.items(): + self.rtlil.attribute(attr_name, attr_signal) + wire_curr = self.rtlil.wire(width=signal.nbits, name=wire_name, + port_id=port_id, port_kind=port_kind, + src=src(signal.src_loc)) + if signal in self.driven: + wire_next = self.rtlil.wire(width=signal.nbits, name=wire_curr + "$next", + src=src(signal.src_loc)) + else: + wire_next = None + self.wires[signal] = (wire_curr, wire_next) + + return wire_curr, wire_next + + def resolve_curr(self, signal): + wire_curr, wire_next = self.resolve(signal) + return wire_curr @contextmanager def hierarchy(self, sub_name): @@ -257,12 +255,58 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): finally: self.sub_name = None + +class _ValueCompiler(xfrm.AbstractValueTransformer): + def __init__(self, state): + self.s = state + def on_unknown(self, value): if value is None: return None else: super().on_unknown(value) + def on_ClockSignal(self, value): + raise NotImplementedError # :nocov: + + def on_ResetSignal(self, value): + raise NotImplementedError # :nocov: + + def on_Slice(self, value): + if value.end == value.start + 1: + return "{} [{}]".format(self(value.value), value.start) + else: + return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start) + + def on_Cat(self, value): + return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands]))) + + +class _RHSValueCompiler(_ValueCompiler): + operator_map = { + (1, "~"): "$not", + (1, "-"): "$neg", + (1, "b"): "$reduce_bool", + (2, "+"): "$add", + (2, "-"): "$sub", + (2, "*"): "$mul", + (2, "/"): "$div", + (2, "%"): "$mod", + (2, "**"): "$pow", + (2, "<<"): "$sshl", + (2, ">>"): "$sshr", + (2, "&"): "$and", + (2, "^"): "$xor", + (2, "|"): "$or", + (2, "=="): "$eq", + (2, "!="): "$ne", + (2, "<"): "$lt", + (2, "<="): "$le", + (2, ">"): "$gt", + (2, ">="): "$ge", + (3, "m"): "$mux", + } + def on_Const(self, value): if isinstance(value.value, str): return "{}'{}".format(value.nbits, value.value) @@ -270,48 +314,15 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): return "{}'{:b}".format(value.nbits, value.value) def on_Signal(self, value): - if value in self.wires: - wire_curr, wire_next = self.wires[value] - else: - if value in self.ports: - port_id, port_kind = self.ports[value] - else: - port_id = port_kind = None - if self.sub_name: - wire_name = "{}_{}".format(self.sub_name, value.name) - else: - wire_name = value.name - for attr_name, attr_value in value.attrs.items(): - self.rtlil.attribute(attr_name, attr_value) - wire_curr = self.rtlil.wire(width=value.nbits, name=wire_name, - port_id=port_id, port_kind=port_kind, - src=src(value.src_loc)) - if value in self.driven: - wire_next = self.rtlil.wire(width=value.nbits, name=wire_curr + "$next", - src=src(value.src_loc)) - else: - wire_next = None - self.wires[value] = (wire_curr, wire_next) - - if self.is_lhs: - if wire_next is None: - raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value))) - return wire_next - else: - return wire_curr - - def on_ClockSignal(self, value): - raise NotImplementedError # :nocov: - - def on_ResetSignal(self, value): - raise NotImplementedError # :nocov: + wire_curr, wire_next = self.s.resolve(value) + return wire_curr def on_Operator_unary(self, value): arg, = value.operands arg_bits, arg_sign = arg.shape() res_bits, res_sign = value.shape() - res = self.rtlil.wire(width=res_bits) - self.rtlil.cell(self.operator_map[(1, value.op)], ports={ + res = self.s.rtlil.wire(width=res_bits) + self.s.rtlil.cell(self.operator_map[(1, value.op)], ports={ "\\A": self(arg), "\\Y": res, }, params={ @@ -327,8 +338,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): value_bits, value_sign = value.shape() if new_bits > value_bits: - res = self.rtlil.wire(width=new_bits) - self.rtlil.cell("$pos", ports={ + res = self.s.rtlil.wire(width=new_bits) + self.s.rtlil.cell("$pos", ports={ "\\A": self(value), "\\Y": res, }, params={ @@ -353,8 +364,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign) rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign) res_bits, res_sign = value.shape() - res = self.rtlil.wire(width=res_bits) - self.rtlil.cell(self.operator_map[(2, value.op)], ports={ + res = self.s.rtlil.wire(width=res_bits) + self.s.rtlil.cell(self.operator_map[(2, value.op)], ports={ "\\A": lhs_wire, "\\B": rhs_wire, "\\Y": res, @@ -375,8 +386,8 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): lhs_bits = rhs_bits = res_bits = max(lhs_bits, rhs_bits, res_bits) lhs_wire = self.match_shape(lhs, lhs_bits, lhs_sign) rhs_wire = self.match_shape(rhs, rhs_bits, rhs_sign) - res = self.rtlil.wire(width=res_bits) - self.rtlil.cell("$mux", ports={ + res = self.s.rtlil.wire(width=res_bits) + self.s.rtlil.cell("$mux", ports={ "\\A": lhs_wire, "\\B": rhs_wire, "\\S": self(sel), @@ -395,20 +406,11 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): assert value.op == "m" return self.on_Operator_mux(value) else: - raise TypeError - - def on_Slice(self, value): - if value.end == value.start + 1: - return "{} [{}]".format(self(value.value), value.start) - else: - return "{} [{}:{}]".format(self(value.value), value.end - 1, value.start) + raise TypeError # :nocov: def on_Part(self, value): raise NotImplementedError - def on_Cat(self, value): - return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.operands]))) - def on_Repl(self, value): return "{{ {} }}".format(" ".join(self(value.value) for _ in range(value.count))) @@ -416,27 +418,52 @@ class _ValueTransformer(xfrm.AbstractValueTransformer): raise NotImplementedError +class _LHSValueCompiler(_ValueCompiler): + def on_Const(self, value): + raise TypeError # :nocov: + + def on_Operator(self, value): + raise TypeError # :nocov: + + def on_Signal(self, value): + wire_curr, wire_next = self.s.resolve(value) + if wire_next is None: + raise ValueError("Cannot return lhs for non-driven signal {}".format(repr(value))) + return wire_next + + def on_Part(self, value): + raise NotImplementedError + + def on_Repl(self, value): + raise TypeError # :nocov: + + def on_ArrayProxy(self, value): + raise NotImplementedError + + def convert_fragment(builder, fragment, name, top): with builder.module(name or "anonymous", attrs={"top": 1} if top else {}) as module: - xformer = _ValueTransformer(module) + compiler_state = _ValueCompilerState(module) + rhs_compiler = _RHSValueCompiler(compiler_state) + lhs_compiler = _LHSValueCompiler(compiler_state) # Register all signals driven in the current fragment. This must be done first, as it # affects further codegen; e.g. whether sig$next signals will be generated and used. for domain, signal in fragment.iter_drivers(): - xformer.add_driven(signal, sync=domain is not None) + compiler_state.add_driven(signal, sync=domain is not None) # Transform all signals used as ports in the current fragment eagerly and outside of # any hierarchy, to make sure they get sensible (non-prefixed) names. for signal in fragment.ports: - xformer.add_port(signal, fragment.ports[signal]) - xformer(signal) + compiler_state.add_port(signal, fragment.ports[signal]) + rhs_compiler(signal) # Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make # sure they get sensible (non-prefixed) names. This does not affect semantics. for domain, _ in fragment.iter_sync(): cd = fragment.domains[domain] - xformer(cd.clk) - xformer(cd.rst) + rhs_compiler(cd.clk) + rhs_compiler(cd.rst) # Transform all subfragments to their respective cells. Transforming signals connected # to their ports into wires eagerly makes sure they get sensible (prefixed with submodule @@ -444,9 +471,9 @@ def convert_fragment(builder, fragment, name, top): for subfragment, sub_name in fragment.subfragments: sub_name, sub_port_map = \ convert_fragment(builder, subfragment, top=False, name=sub_name) - with xformer.hierarchy(sub_name): + with compiler_state.hierarchy(sub_name): module.cell(sub_name, name=sub_name, ports={ - p: xformer(s) for p, s in sub_port_map.items() + p: rhs_compiler(s) for p, s in sub_port_map.items() }) with module.process() as process: @@ -455,11 +482,10 @@ def convert_fragment(builder, fragment, name, top): # For every signal in sync domains, assign \sig$next to the current value (\sig). for domain, signal in fragment.iter_drivers(): if domain is None: - prev_value = xformer(ast.Const(signal.reset, signal.nbits)) + prev_value = ast.Const(signal.reset, signal.nbits) else: - prev_value = xformer(signal) - with xformer.lhs(): - case.assign(xformer(signal), prev_value) + prev_value = signal + case.assign(lhs_compiler(signal), rhs_compiler(prev_value)) # Convert statements into decision trees. def _convert_stmts(case, stmts): @@ -468,17 +494,15 @@ def convert_fragment(builder, fragment, name, top): lhs_bits, lhs_sign = stmt.lhs.shape() rhs_bits, rhs_sign = stmt.rhs.shape() if lhs_bits == rhs_bits: - rhs_sigspec = xformer(stmt.rhs) + rhs_sigspec = rhs_compiler(stmt.rhs) else: # In RTLIL, LHS and RHS of assignment must have exactly same width. - rhs_sigspec = xformer.match_shape( + rhs_sigspec = rhs_compiler.match_shape( stmt.rhs, lhs_bits, rhs_sign) - with xformer.lhs(): - lhs_sigspec = xformer(stmt.lhs) - case.assign(lhs_sigspec, rhs_sigspec) + case.assign(lhs_compiler(stmt.lhs), rhs_sigspec) elif isinstance(stmt, ast.Switch): - with case.switch(xformer(stmt.test)) as switch: + with case.switch(rhs_compiler(stmt.test)) as switch: for value, nested_stmts in stmt.cases.items(): with switch.case(value) as nested_case: _convert_stmts(nested_case, nested_stmts) @@ -489,12 +513,11 @@ def convert_fragment(builder, fragment, name, top): _convert_stmts(case, fragment.statements) # For every signal in the sync domain, assign \sig's initial value (which will end up - # as the \init reg attribute) to the reset value. Note that this assigns \sig, - # not \sig$next. + # as the \init reg attribute) to the reset value. with process.sync("init") as sync: for domain, signal in fragment.iter_sync(): - sync.update(xformer(signal), - xformer(ast.Const(signal.reset, signal.nbits))) + wire_curr, wire_next = compiler_state.resolve(signal) + sync.update(wire_curr, rhs_compiler(ast.Const(signal.reset, signal.nbits))) # For every signal in every domain, assign \sig to \sig$next. The sensitivity list, # however, differs between domains: for comb domains, it is `always`, for sync domains @@ -506,23 +529,22 @@ def convert_fragment(builder, fragment, name, top): triggers.append(("always",)) else: cd = fragment.domains[domain] - triggers.append(("posedge", xformer(cd.clk))) + triggers.append(("posedge", compiler_state.resolve_curr(cd.clk))) if cd.async_reset: - triggers.append(("posedge", xformer(cd.rst))) + triggers.append(("posedge", compiler_state.resolve_curr(cd.rst))) for trigger in triggers: with process.sync(*trigger) as sync: for signal in signals: - lhs_sigspec = xformer(signal) - with xformer.lhs(): - sync.update(lhs_sigspec, xformer(signal)) + wire_curr, wire_next = compiler_state.resolve(signal) + sync.update(wire_curr, wire_next) # Finally, collect the names we've given to our ports in RTLIL, and correlate these with # the signals represented by these ports. If we are a submodule, this will be necessary # to create a cell for us in the parent module. port_map = OrderedDict() for signal in fragment.ports: - port_map[xformer(signal)] = signal + port_map[compiler_state.resolve_curr(signal)] = signal return module.name, port_map -- 2.30.2