hdl.{ast,dsl},back.rtlil: track source locations for switch cases.
authorwhitequark <cz@m-labs.hk>
Tue, 9 Jul 2019 19:18:02 +0000 (19:18 +0000)
committerwhitequark <cz@m-labs.hk>
Tue, 9 Jul 2019 19:26:47 +0000 (19:26 +0000)
This is a very new Yosys feature, and will require a Yosys build
newer than YosysHQ/yosys@93bc5aff.

nmigen/back/rtlil.py
nmigen/hdl/ast.py
nmigen/hdl/dsl.py
nmigen/hdl/xfrm.py

index 6e3c87fb846cd8c30e78c6f681ecf1065bfe70c9..1a83675bb95a3f94321be8f52a56f7992f79ce63 100644 (file)
@@ -196,7 +196,8 @@ class _SwitchBuilder(_ProxiedBuilder, _AttrBuilder):
     def __exit__(self, *args):
         self._append("{}end\n", "  " * self.indent)
 
-    def case(self, *values):
+    def case(self, *values, attrs={}, src=""):
+        self._attributes(attrs, src=src, indent=self.indent + 1)
         if values == ():
             self._append("{}case\n", "  " * (self.indent + 1))
         else:
@@ -602,10 +603,10 @@ class _StatementCompiler(xfrm.StatementVisitor):
         self._has_rhs     = False
 
     @contextmanager
-    def case(self, switch, values):
+    def case(self, switch, values, src=""):
         try:
             old_case = self._case
-            with switch.case(*values) as self._case:
+            with switch.case(*values, src=src) as self._case:
                 yield
         finally:
             self._case = old_case
@@ -658,7 +659,11 @@ class _StatementCompiler(xfrm.StatementVisitor):
 
         with self._case.switch(test_sigspec, src=src(stmt.src_loc)) as switch:
             for values, stmts in stmt.cases.items():
-                with self.case(switch, values):
+                if values in stmt.case_src_locs:
+                    case_src = src(stmt.case_src_locs[values])
+                else:
+                    case_src = ""
+                with self.case(switch, values, src=case_src):
                     self.on_statements(stmts)
 
     def on_statement(self, stmt):
index 949a82516ad77b449e9040946c357adeb4557d3a..85310c1e00322e8709cd3a762dcaa85364359e74 100644 (file)
@@ -1033,18 +1033,22 @@ class Assume(Property):
 
 # @final
 class Switch(Statement):
-    def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
+    def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):
         if src_loc is None:
             super().__init__(src_loc_at=src_loc_at)
         else:
             # Switch is a bit special in terms of location tracking because it is usually created
             # long after the control has left the statement that directly caused its creation.
             self.src_loc = src_loc
+        # Switch is also a bit special in that its parts also have location information. It can't
+        # be automatically traced, so whatever constructs a Switch may optionally provide it.
+        self.case_src_locs = {}
 
         self.test  = Value.wrap(test)
         self.cases = OrderedDict()
-        for keys, stmts in cases.items():
+        for orig_keys, stmts in cases.items():
             # Map: None -> (); key -> (key,); (key...) -> (key...)
+            keys = orig_keys
             if keys is None:
                 keys = ()
             if not isinstance(keys, tuple):
@@ -1064,6 +1068,8 @@ class Switch(Statement):
             if not isinstance(stmts, Iterable):
                 stmts = [stmts]
             self.cases[new_keys] = Statement.wrap(stmts)
+            if orig_keys in case_src_locs:
+                self.case_src_locs[new_keys] = case_src_locs[orig_keys]
 
     def _lhs_signals(self):
         signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
index f154ed3305e2f9f2ff7cdef9e6137e330a43eb2c..58db1cf0fb4222bddba30a39aa7a587a020e518e 100644 (file)
@@ -161,10 +161,12 @@ class Module(_ModuleBuilderRoot, Elaboratable):
     @contextmanager
     def If(self, cond):
         self._check_context("If", context=None)
+        src_loc = tracer.get_src_loc(src_loc_at=1)
         if_data = self._set_ctrl("If", {
-            "tests":   [],
-            "bodies":  [],
-            "src_loc": tracer.get_src_loc(src_loc_at=1),
+            "tests":    [],
+            "bodies":   [],
+            "src_loc":  src_loc,
+            "src_locs": [],
         })
         try:
             _outer_case, self._statements = self._statements, []
@@ -173,6 +175,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             self._flush_ctrl()
             if_data["tests"].append(cond)
             if_data["bodies"].append(self._statements)
+            if_data["src_locs"].append(src_loc)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
@@ -180,6 +183,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
     @contextmanager
     def Elif(self, cond):
         self._check_context("Elif", context=None)
+        src_loc = tracer.get_src_loc(src_loc_at=1)
         if_data = self._get_ctrl("If")
         if if_data is None:
             raise SyntaxError("Elif without preceding If")
@@ -190,6 +194,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             self._flush_ctrl()
             if_data["tests"].append(cond)
             if_data["bodies"].append(self._statements)
+            if_data["src_locs"].append(src_loc)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
@@ -197,6 +202,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
     @contextmanager
     def Else(self):
         self._check_context("Else", context=None)
+        src_loc = tracer.get_src_loc(src_loc_at=1)
         if_data = self._get_ctrl("If")
         if if_data is None:
             raise SyntaxError("Else without preceding If/Elif")
@@ -206,6 +212,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             yield
             self._flush_ctrl()
             if_data["bodies"].append(self._statements)
+            if_data["src_locs"].append(src_loc)
         finally:
             self.domain._depth -= 1
             self._statements = _outer_case
@@ -218,6 +225,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             "test":    Value.wrap(test),
             "cases":   OrderedDict(),
             "src_loc": tracer.get_src_loc(src_loc_at=1),
+            "case_src_locs": {},
         })
         try:
             self._ctrl_context = "Switch"
@@ -231,6 +239,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
     @contextmanager
     def Case(self, *values):
         self._check_context("Case", context="Switch")
+        src_loc = tracer.get_src_loc(src_loc_at=1)
         switch_data = self._get_ctrl("Switch")
         new_values = ()
         for value in values:
@@ -254,6 +263,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             # which means the branch will always match.
             if not (values and not new_values):
                 switch_data["cases"][new_values] = self._statements
+                switch_data["case_src_locs"][new_values] = src_loc
         finally:
             self._ctrl_context = "Switch"
             self._statements = _outer_case
@@ -272,6 +282,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             "decoding": OrderedDict(),
             "states":   OrderedDict(),
             "src_loc":  tracer.get_src_loc(src_loc_at=1),
+            "state_src_locs": {},
         })
         self._generated[name] = fsm = \
             FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
@@ -287,6 +298,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
     @contextmanager
     def State(self, name):
         self._check_context("FSM State", context="FSM")
+        src_loc = tracer.get_src_loc(src_loc_at=1)
         fsm_data = self._get_ctrl("FSM")
         if name in fsm_data["states"]:
             raise SyntaxError("FSM state '{}' is already defined".format(name))
@@ -298,6 +310,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             yield
             self._flush_ctrl()
             fsm_data["states"][name] = self._statements
+            fsm_data["state_src_locs"][name] = src_loc
         finally:
             self._ctrl_context = "FSM"
             self._statements = _outer_case
@@ -327,6 +340,7 @@ class Module(_ModuleBuilderRoot, Elaboratable):
 
         if name == "If":
             if_tests, if_bodies = data["tests"], data["bodies"]
+            if_src_locs = data["src_locs"]
 
             tests, cases = [], OrderedDict()
             for if_test, if_case in zip(if_tests + [None], if_bodies):
@@ -342,16 +356,20 @@ class Module(_ModuleBuilderRoot, Elaboratable):
                     match = None
                 cases[match] = if_case
 
-            self._statements.append(Switch(Cat(tests), cases, src_loc=src_loc))
+            self._statements.append(Switch(Cat(tests), cases,
+                src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
 
         if name == "Switch":
             switch_test, switch_cases = data["test"], data["cases"]
+            switch_case_src_locs = data["case_src_locs"]
 
-            self._statements.append(Switch(switch_test, switch_cases, src_loc=src_loc))
+            self._statements.append(Switch(switch_test, switch_cases,
+                src_loc=src_loc, case_src_locs=switch_case_src_locs))
 
         if name == "FSM":
             fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \
                 data["signal"], data["reset"], data["encoding"], data["decoding"], data["states"]
+            fsm_state_src_locs = data["state_src_locs"]
             if not fsm_states:
                 return
             fsm_signal.nbits = bits_for(len(fsm_encoding) - 1)
@@ -364,7 +382,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
             fsm_signal.decoder = lambda n: "{}/{}".format(fsm_decoding[n], n)
             self._statements.append(Switch(fsm_signal,
                 OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()),
-                src_loc=src_loc))
+                src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
+                                                for name in fsm_states}))
 
     def _add_statement(self, assigns, domain, depth, compat_mode=False):
         def domain_name(domain):
index 434067665a52f04d9c96bfa7b6692b229502b37b..334e1e41c665939675cc242553c63fd75f3f7c83 100644 (file)
@@ -216,6 +216,8 @@ class StatementVisitor(metaclass=ABCMeta):
             new_stmt = self.on_unknown_statement(stmt)
         if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt):
             new_stmt.src_loc = stmt.src_loc
+            if isinstance(new_stmt, Switch) and isinstance(stmt, Switch):
+                new_stmt.case_src_locs = stmt.case_src_locs
         return new_stmt
 
     def __call__(self, stmt):