Specify names for TAP signals.
[c4m-jtag.git] / c4m / nmigen / jtag / tap.py
index 725e7e6321df9c5b5f511c74519de746f790916d..3bc33e78aac43b3df54b3b4d5d2166adab55a6c0 100755 (executable)
 #!/bin/env python3
-import os
+import os, textwrap
 
 from nmigen import *
 from nmigen.build import *
 from nmigen.lib.io import *
+from nmigen.hdl.rec import Direction
 from nmigen.tracer import get_var_name
 
-from c4m_repo.nmigen.lib import Wishbone
+from nmigen_soc.wishbone import Interface as WishboneInterface
+
+from .bus import Interface
 
 __all__ = [
-    "TAP",
+    "TAP", "ShiftReg",
 ]
 
 
-class ShiftReg(Elaboratable):
-    def __init__(self, ircodes, length, domain):
-        # The sr record will be returned to user code
-        self.sr = Record([("i", length), ("o", length), ("oe", len(ircodes)), ("ack", 1)])
-        # The next attributes are for JTAG class usage only
-        self.ir = None # made None as width is not known yet
-        self.tdi = Signal()
-        self.tdo = Signal()
-        self.tdo_en = Signal()
-        self.capture = Signal()
-        self.shift = Signal()
-        self.update = Signal()
-        self.jtag_cd = None # The JTAG clock domain
-
-        ##
-
-        self._ircodes = ircodes
-        self._domain = domain
-
-    def elaborate(self, platform):
-        length = len(self.sr.o)
-        domain = self._domain
-
-        m = Module()
-
-        m.domains.jtag = self.jtag_cd
-
-        sr_jtag = Signal(length)
-
-        assert isinstance(self.ir, Signal)
-        isir = Signal(len(self._ircodes))
-        capture = Signal()
-        shift = Signal()
-        update = Signal()
-        m.d.comb += [
-            isir.eq(Cat(self.ir == ircode for ircode in self._ircodes)),
-            capture.eq((isir != 0) & self.capture),
-            shift.eq((isir != 0) & self.shift),
-            update.eq((isir != 0) & self.update),
-        ]
-
-        # On update set o, oe and wait for ack
-        # update signal is on JTAG clockdomain, latch it
-        update_core = Signal()
-        m.d[domain] += update_core.eq(update) # This is CDC from JTAG domain to given domain
-        with m.FSM(domain=domain):
-            with m.State("IDLE"):
-                m.d.comb += self.sr.oe.eq(0)
-                with m.If(update_core):
-                    # Latch sr_jtag cross domain but it should be stable due to latching of update_core
-                    m.d[domain] += self.sr.o.eq(sr_jtag)
-                    # Wait one cycle to raise oe so sr.o has one more cycle to stabilize
-                    m.next = "WAIT4ACK"
-            with m.State("WAIT4ACK"):
-                m.d.comb += self.sr.oe.eq(isir)
-                with m.If(self.sr.ack):
-                    m.next = "WAIT4END"
-            with m.State("WAIT4END"):
-                m.d.comb += self.sr.oe.eq(0)
-                with m.If(~update_core):
-                    m.next = "IDLE"
-
-        m.d.comb += [
-            self.tdo.eq(sr_jtag[0]),
-            self.tdo_en.eq(shift),
-        ]
-
-        with m.If(shift):
-            m.d.jtag += sr_jtag.eq(Cat(sr_jtag[1:], self.tdi))
-        with m.If(capture):
-            m.d.jtag += sr_jtag.eq(self.sr.i)
-
-        return m
-
-class JTAGWishbone(Elaboratable):
-    def __init__(self, sr_addr, sr_data, wb, domain):
-        self._sr_addr = sr_addr
-        self._sr_data = sr_data
-        self._wb = wb
-        self._domain = domain
-
-        # To be set by JTAG
-        self._ir = None
-
-    def elaborate(self, platform):
-        sr_addr = self._sr_addr
-        sr_data = self._sr_data
-        wb = self._wb
-        domain = self._domain
-        ir = self._ir
-
-        m = Module()
-
-        if hasattr(wb, "sel"):
-            # Always selected
-            m.d.comb += [s.eq(1) for s in wb.sel]
-
-        # Immediately ack oe
-        m.d[domain] += [
-            sr_addr.ack.eq(sr_addr.oe),
-            sr_data.ack.eq(sr_data.oe != 0),
+class ShiftReg(Record):
+    """Object with interface for extra shift registers on a TAP.
+
+    Parameters
+    ----------
+    sr_length : int
+    cmds : int, default=1
+        The number of corresponding JTAG instructions
+
+    This object is normally only allocated and returned from ``TAP.add_shiftreg``
+    It is a Record subclass.
+
+    Attributes
+    ----------
+    i: length=sr_length, FANIN
+        The input data sampled during capture state of the TAP
+    ie: length=cmds, FANOUT
+        Indicates that data is to be sampled by the JTAG TAP and
+        should be held stable. The bit indicates the corresponding
+        instruction for which data is asked.
+        This signal is kept high for a whole JTAG TAP clock cycle
+        and may thus be kept higher for more than one clock cycle
+        on the domain where ShiftReg is used.
+        The JTAG protocol does not allow insertion of wait states
+        so data need to be provided before ie goes down. The speed
+        of the response will determine the max. frequency for the
+        JTAG interface.
+    o: length=sr_length, FANOUT
+        The value of the shift register.
+    oe: length=cmds, FANOUT
+        Indicates that output needs to be sampled downstream because
+        JTAG TAP in in the Update state. The bit indicated the corresponding
+        instruction. The bit is only kept high for one clock cycle.
+    """
+    def __init__(self, *, sr_length, cmds=1, name=None, src_loc_at=0):
+        layout = [
+            ("i", sr_length, Direction.FANIN),
+            ("ie", cmds, Direction.FANOUT),
+            ("o", sr_length, Direction.FANOUT),
+            ("oe", cmds, Direction.FANOUT),
         ]
-
-        with m.FSM(domain=domain) as fsm:
-            with m.State("IDLE"):
-                m.d.comb += [
-                    wb.cyc.eq(0),
-                    wb.stb.eq(0),
-                    wb.we.eq(0),
-                ]
-                with m.If(sr_addr.oe): # WBADDR code
-                    m.d[domain] += wb.addr.eq(sr_addr.o)
-                    m.next = "READ"
-                with m.If(sr_data.oe[0]): # WBREAD code
-                    m.d[domain] += wb.addr.eq(wb.addr + 1)
-                    m.next = "READ"
-                with m.If(sr_data.oe[1]): # WBWRITE code
-                    m.d[domain] += wb.dat_w.eq(sr_data.o)
-                    m.next = "WRITEREAD"
-            with m.State("READ"):
-                m.d.comb += [
-                    wb.cyc.eq(1),
-                    wb.stb.eq(1),
-                    wb.we.eq(0),
-                ]
-                with m.If(~wb.stall):
-                    m.next = "READACK"
-            with m.State("READACK"):
-                m.d.comb += [
-                    wb.cyc.eq(1),
-                    wb.stb.eq(0),
-                    wb.we.eq(0),
-                ]
-                with m.If(wb.ack):
-                    m.d[domain] += sr_data.i.eq(wb.dat_r)
-                    m.next = "IDLE"
-            with m.State("WRITEREAD"):
-                m.d.comb += [
-                    wb.cyc.eq(1),
-                    wb.stb.eq(1),
-                    wb.we.eq(1),
-                ]
-                with m.If(~wb.stall):
-                    m.next = "WRITEREADACK"
-            with m.State("WRITEREADACK"):
-                m.d.comb += [
-                    wb.cyc.eq(1),
-                    wb.stb.eq(0),
-                    wb.we.eq(0),
-                ]
-                with m.If(wb.ack):
-                    m.d[domain] += wb.addr.eq(wb.addr + 1)
-                    m.next = "READ"
-
-        return m
+        super().__init__(layout, name=name, src_loc_at=src_loc_at+1)
 
 
 class TAP(Elaboratable):
@@ -190,23 +80,115 @@ class TAP(Elaboratable):
             platform.add_file(prefix + fname, f)
             f.close()
 
-
-    def __init__(self, io_count, *, ir_width=None, manufacturer_id=Const(0b10001111111, 11),
-                 part_number=Const(1, 16), version=Const(0, 4)
+    _controller_templ = textwrap.dedent(r"""
+    library ieee;
+    use ieee.std_logic_1164.ALL;
+    
+    use work.c4m_jtag.ALL;
+    
+    entity {name} is
+      port (
+        -- The TAP signals
+        TCK:        in std_logic;
+        TMS:        in std_logic;
+        TDI:        in std_logic;
+        TDO:        out std_logic;
+        TRST_N:     in std_logic;
+    
+        -- The FSM state indicators
+        RESET:      out std_logic;
+        CAPTURE:    out std_logic;
+        SHIFT:      out std_logic;
+        UPDATE:     out std_logic;
+    
+        -- The Instruction Register
+        IR:         out std_logic_vector({ir_width}-1 downto 0);
+    
+        -- The I/O access ports
+        CORE_IN:    out std_logic_vector({ios}-1 downto 0);
+        CORE_EN:    in std_logic_vector({ios}-1 downto 0);
+        CORE_OUT:   in std_logic_vector({ios}-1 downto 0);
+    
+        -- The pad connections
+        PAD_IN:     in std_logic_vector({ios}-1 downto 0);
+        PAD_EN:     out std_logic_vector({ios}-1 downto 0);
+        PAD_OUT:    out std_logic_vector({ios}-1 downto 0)
+      );
+    end {name};
+    
+    architecture rtl of {name} is
+    begin
+      jtag : c4m_jtag_tap_controller
+        generic map(
+          DEBUG => FALSE,
+          IR_WIDTH => {ir_width},
+          IOS => {ios},
+          MANUFACTURER => "{manufacturer:011b}",
+          PART_NUMBER => "{part:016b}",
+          VERSION => "{version:04b}"
+        )
+        port map(
+          TCK => TCK,
+          TMS => TMS,
+          TDI => TDI,
+          TDO => TDO,
+          TRST_N => TRST_N,
+          RESET => RESET,
+          CAPTURE => CAPTURE,
+          SHIFT => SHIFT,
+          UPDATE => UPDATE,
+          IR => IR,
+          CORE_IN => CORE_IN,
+          CORE_EN => CORE_EN,
+          CORE_OUT => CORE_OUT,
+          PAD_IN => PAD_IN,
+          PAD_EN => PAD_EN,
+          PAD_OUT => PAD_OUT
+        );
+    end architecture rtl;
+    """)
+    _cell_inst = 0
+    @classmethod
+    def _add_instance(cls, platform, prefix, *, ir_width, ios, manufacturer, part, version):
+        name = "jtag_controller_i{}".format(cls._cell_inst)
+        cls._cell_inst += 1
+
+        platform.add_file(
+            "{}{}.vhdl".format(prefix, name),
+            cls._controller_templ.format(
+                name=name, ir_width=ir_width, ios=ios,
+                manufacturer=manufacturer, part=part, version=version,
+            )
+        )
+
+        return name
+
+
+    def __init__(
+        self, io_count, *, with_reset=False, ir_width=None,
+        manufacturer_id=Const(0b10001111111, 11), part_number=Const(1, 16),
+        version=Const(0, 4),
+        name=None, src_loc_at=0
     ):
         assert(isinstance(io_count, int) and io_count > 0)
         assert((ir_width is None) or (isinstance(ir_width, int) and ir_width >= 2))
         assert(len(version) == 4)
 
-        # TODO: Handle IOs with different directions
-        self.tck  = Signal()
-        self.tms  = Signal()
-        self.tdo  = Signal()
-        self.tdi  = Signal()
-        self.core = Array(Pin(1, "io") for _ in range(io_count)) # Signals to use for core
-        self.pad  = Array(Pin(1, "io") for _ in range(io_count)) # Signals going to IO pads
+        if name is None:
+            name = get_var_name(depth=src_loc_at+2, default="TAP")
+        self.name = name
+        self.bus = Interface(with_reset=with_reset, name=self.name+"_bus",
+                             src_loc_at=src_loc_at+1)
 
-        self.jtag_cd = ClockDomain(name="jtag", local=True) # Own clock domain using TCK as clock signal
+        # TODO: Handle IOs with different directions
+        self.core = Array(
+            Pin(1, "io", name=name+"_coreio"+str(i), src_loc_at=src_loc_at+1)
+            for i in range(io_count)
+        ) # Signals to use for core
+        self.pad  = Array(
+            Pin(1, "io", name=name+"_padio"+str(i), src_loc_at=src_loc_at+1)
+            for i in range(io_count)
+        ) # Signals going to IO pads
 
         ##
 
@@ -221,24 +203,37 @@ class TAP(Elaboratable):
 
         self._wbs = []
 
+
     def elaborate(self, platform):
-        TAP._add_files(platform, "jtag" + os.path.sep)
+        self.__class__._add_files(platform, "jtag" + os.path.sep)
 
         m = Module()
 
-        tdo_jtag = Signal()
-        reset = Signal()
-        capture = Signal()
-        shift = Signal()
-        update = Signal()
-
-
+        # Determine ir_width if not fixed.
         ir_max = max(self._ircodes) + 1 # One extra code needed with all ones
         ir_width = len("{:b}".format(ir_max))
         if self._ir_width is not None:
             assert self._ir_width >= ir_width, "Specified JTAG IR width not big enough for allocated shiift registers"
             ir_width = self._ir_width
-        ir = Signal(ir_width)
+
+        cell = self.__class__._add_instance(
+            platform, "jtag" + os.path.sep, ir_width=ir_width, ios=self._io_count,
+            manufacturer=self._manufacturer_id.value, part=self._part_number.value,
+            version=self._version.value,
+        )
+
+        sigs = Record([
+            ("capture", 1),
+            ("shift", 1),
+            ("update", 1),
+            ("ir", ir_width),
+            ("tdo_jtag", 1),
+        ])
+
+        reset = Signal()
+
+        trst_n = Signal()
+        m.d.comb += trst_n.eq(~self.bus.trst if hasattr(self.bus, "trst") else Const(1))
 
         core_i = Cat(pin.i for pin in self.core)
         core_o = Cat(pin.o for pin in self.core)
@@ -247,69 +242,39 @@ class TAP(Elaboratable):
         pad_o = Cat(pin.o for pin in self.pad)
         pad_oe = Cat(pin.oe for pin in self.pad)
 
-        params = {
-            "p_IOS": self._io_count,
-            "p_IR_WIDTH": ir_width,
-            "p_MANUFACTURER": self._manufacturer_id,
-            "p_PART_NUMBER": self._part_number,
-            "p_VERSION": self._version,
-            "i_TCK": self.tck,
-            "i_TMS": self.tms,
-            "i_TDI": self.tdi,
-            "o_TDO": tdo_jtag,
-            "i_TRST_N": Const(1),
-            "o_RESET": reset,
-            "o_DRCAPTURE": capture,
-            "o_DRSHIFT": shift,
-            "o_DRUPDATE": update,
-            "o_IR": ir,
-            "o_CORE_IN": core_i,
-            "i_CORE_OUT": core_o,
-            "i_CORE_EN": core_oe,
-            "i_PAD_IN": pad_i,
-            "o_PAD_OUT": pad_o,
-            "o_PAD_EN": pad_oe,
-        }
-        m.submodules.tap = Instance("c4m_jtag_tap_controller", **params)
-
+        m.submodules.tap = Instance(cell,
+            i_TCK=self.bus.tck,
+            i_TMS=self.bus.tms,
+            i_TDI=self.bus.tdi,
+            o_TDO=sigs.tdo_jtag,
+            i_TRST_N=trst_n,
+            o_RESET=reset,
+            o_CAPTURE=sigs.capture,
+            o_SHIFT=sigs.shift,
+            o_UPDATE=sigs.update,
+            o_IR=sigs.ir,
+            o_CORE_IN=core_i,
+            i_CORE_OUT=core_o,
+            i_CORE_EN=core_oe,
+            i_PAD_IN=pad_i,
+            o_PAD_OUT=pad_o,
+            o_PAD_EN=pad_oe,
+        )
+
+        # Own clock domain using TCK as clock signal
+        m.domains.jtag = jtag_cd = ClockDomain(name="jtag", local=True)
         m.d.comb += [
-            self.jtag_cd.clk.eq(self.tck),
-            self.jtag_cd.rst.eq(reset),
+            jtag_cd.clk.eq(self.bus.tck),
+            jtag_cd.rst.eq(reset),
         ]
 
-        for i, sr in enumerate(self._srs):
-            m.submodules["sr{}".format(i)] = sr
-            sr.ir = ir
-            m.d.comb += [
-                sr.tdi.eq(self.tdi),
-                sr.capture.eq(capture),
-                sr.shift.eq(shift),
-                sr.update.eq(update),
-            ]
-
-        if len(self._srs) > 0:
-            first = True
-            for sr in self._srs:
-                if first:
-                    first = False
-                    with m.If(sr.tdo_en):
-                        m.d.comb += self.tdo.eq(sr.tdo)
-                else:
-                    with m.Elif(sr.tdo_en):
-                        m.d.comb += self.tdo.eq(sr.tdo)
-            with m.Else():
-                m.d.comb += self.tdo.eq(tdo_jtag)
-        else:
-            m.d.comb += self.tdo.eq(tdo_jtag)
-
-        for i, wb in enumerate(self._wbs):
-            m.submodules["wb{}".format(i)] = wb
-            wb._ir = ir
+        self._elaborate_shiftregs(m, sigs)
+        self._elaborate_wishbones(m)
 
         return m
 
 
-    def add_shiftreg(self, ircode, length, domain="sync"):
+    def add_shiftreg(self, ircode, length, domain="sync", name=None, src_loc_at=0):
         """Add a shift register to the JTAG interface
 
         Parameters:
@@ -324,34 +289,145 @@ class TAP(Elaboratable):
         except TypeError:
             ir_it = ircodes = (ircode,)
         for _ircode in ir_it:
-            assert(isinstance(_ircode, int) and _ircode > 0 and _ircode not in self._ircodes)
+            if not isinstance(_ircode, int) or _ircode <= 0:
+                raise ValueError("IR code '{}' is not an int greater than 0".format(_ircode))
+            if _ircode in self._ircodes:
+                raise ValueError("IR code '{}' already taken".format(_ircode))
 
-        sr = ShiftReg(ircodes, length, domain)
-        sr.jtag_cd = self.jtag_cd
         self._ircodes.extend(ircodes)
-        self._srs.append(sr)
 
-        return sr.sr
+        if name is None:
+            name = self.name + "_sr{}".format(len(self._srs))
+        sr = ShiftReg(sr_length=length, cmds=len(ircodes), name=name, src_loc_at=src_loc_at+1)
+        self._srs.append((ircodes, domain, sr))
 
+        return sr
 
-    def add_wishbone(self, ircodes, address_width, data_width, sel_width=None, domain="sync"):
+    def _elaborate_shiftregs(self, m, sigs):
+        # tdos is tuple of (tdo, tdo_en) for each shiftreg
+        tdos = []
+        for ircodes, domain, sr in self._srs:
+            reg = Signal(len(sr.o), name=sr.name+"_reg")
+            m.d.comb += sr.o.eq(reg)
+
+            isir = Signal(len(ircodes), name=sr.name+"_isir")
+            capture = Signal(name=sr.name+"_capture")
+            shift = Signal(name=sr.name+"_shift")
+            update = Signal(name=sr.name+"_update")
+            m.d.comb += [
+                isir.eq(Cat(sigs.ir == ircode for ircode in ircodes)),
+                capture.eq((isir != 0) & sigs.capture),
+                shift.eq((isir != 0) & sigs.shift),
+                update.eq((isir != 0) & sigs.update),
+            ]
+
+            # update signal is on the JTAG clockdomain, sr.oe is on `domain` clockdomain
+            # latch update in `domain` clockdomain and see when it has falling edge.
+            # At that edge put isir in sr.oe for one `domain` clockdomain
+            update_core = Signal(name=sr.name+"_update_core")
+            update_core_prev = Signal(name=sr.name+"_update_core_prev")
+            m.d[domain] += [
+                update_core.eq(update), # This is CDC from JTAG domain to given domain
+                update_core_prev.eq(update_core)
+            ]
+            with m.If(update_core_prev & ~update_core == 0):
+                # Falling edge of update
+                m.d[domain] += sr.oe.eq(isir)
+            with m.Else():
+                m.d[domain] += sr.oe.eq(0)
+
+            with m.If(shift):
+                m.d.jtag += reg.eq(Cat(reg[1:], self.bus.tdi))
+            with m.If(capture):
+                m.d.jtag += reg.eq(sr.i)
+
+            # tdo = reg[0], tdo_en = shift
+            tdos.append((reg[0], shift))
+
+        for i, (tdo, tdo_en) in enumerate(tdos):
+            if i == 0:
+                with m.If(shift):
+                    m.d.comb += self.bus.tdo.eq(tdo)
+            else:
+                with m.Elif(shift):
+                    m.d.comb += self.bus.tdo.eq(tdo)
+
+        if len(tdos) > 0:
+            with m.Else():
+                m.d.comb += self.bus.tdo.eq(sigs.tdo_jtag)
+        else:
+            # Always connect tdo_jtag to 
+            m.d.comb += self.bus.tdo.eq(sigs.tdo_jtag)
+
+
+    def add_wishbone(self, *, ircodes, address_width, data_width, granularity=None, domain="sync"):
         """Add a wishbone interface
 
+        In order to allow high JTAG clock speed, data will be cached. This means that if data is
+        output the value of the next address will be read automatically.
+
         Parameters:
-        - ircodes: sequence of three integer for the JTAG IR codes;
+        -----------
+        ircodes: sequence of three integer for the JTAG IR codes;
           they represent resp. WBADDR, WBREAD and WBREADWRITE. First code
           has a shift register of length 'address_width', the two other codes
           share a shift register of length data_width.
-        address_width: width of the address
-        - data_width: width of the data"""
+        address_width: width of the address
+        data_width: width of the data
 
-        assert len(ircodes) == 3
+        Returns:
+        wb: nmigen_soc.wishbone.bus.Interface
+            The Wishbone interface, is pipelined and has stall field.
+        """
+        if len(ircodes) != 3:
+            raise ValueError("3 IR Codes have to be provided")
 
         sr_addr = self.add_shiftreg(ircodes[0], address_width, domain=domain)
         sr_data = self.add_shiftreg(ircodes[1:], data_width, domain=domain)
 
-        wb = Wishbone(data_width=data_width, address_width=address_width, sel_width=sel_width, master=True)
+        wb = WishboneInterface(data_width=data_width, addr_width=address_width,
+                               granularity=granularity, features={"stall", "lock", "err", "rty"})
 
-        self._wbs.append(JTAGWishbone(sr_addr, sr_data, wb, domain))
+        self._wbs.append((sr_addr, sr_data, wb, domain))
 
         return wb
+
+    def _elaborate_wishbones(self, m):
+        for sr_addr, sr_data, wb, domain in self._wbs:
+            if hasattr(wb, "sel"):
+                # Always selected
+                m.d.comb += [s.eq(1) for s in wb.sel]
+
+            with m.FSM(domain=domain) as fsm:
+                with m.State("IDLE"):
+                    with m.If(sr_addr.oe): # WBADDR code
+                        m.d[domain] += wb.adr.eq(sr_addr.o)
+                        m.next = "READ"
+                    with m.Elif(sr_data.oe[0]): # WBREAD code
+                        # If data is
+                        m.d[domain] += wb.adr.eq(wb.adr + 1)
+                        m.next = "READ"
+                    with m.Elif(sr_data.oe[1]): # WBWRITE code
+                        m.d[domain] += wb.dat_w.eq(sr_data.o)
+                        m.next = "WRITEREAD"
+                with m.State("READ"):
+                    with m.If(~wb.stall):
+                        m.next = "READACK"
+                with m.State("READACK"):
+                    with m.If(wb.ack):
+                        # Store read data in sr_data.i and keep it there til next read
+                        m.d[domain] += sr_data.i.eq(wb.dat_r)
+                        m.next = "IDLE"
+                with m.State("WRITEREAD"):
+                    with m.If(~wb.stall):
+                        m.next = "WRITEREADACK"
+                with m.State("WRITEREADACK"):
+                    with m.If(wb.ack):
+                        m.d[domain] += wb.adr.eq(wb.adr + 1)
+                        m.next = "READ"
+
+                m.d.comb += [
+                    wb.cyc.eq(~fsm.ongoing("IDLE")),
+                    wb.stb.eq(fsm.ongoing("READ") | fsm.ongoing("WRITEREAD")),
+                    wb.we.eq(fsm.ongoing("WRITEREAD")),
+                ]