add cont_tb and rewrite cont
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Dec 2014 10:15:01 +0000 (11:15 +0100)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Fri, 19 Dec 2014 10:15:01 +0000 (11:15 +0100)
lib/sata/link/cont.py
lib/sata/test/Makefile
lib/sata/test/common.py
lib/sata/test/cont_tb.py [new file with mode: 0644]

index fca64c70f77e69262de049725e2c59c84d2ad80b..9553bb40d0ad9dc76dee7c539b7102697f546400 100644 (file)
@@ -10,73 +10,66 @@ class SATACONTInserter(Module):
 
                ###
 
-               # Detect consecutive primitives
-               # tn insert CONT
                counter = Counter(max=4)
                self.submodules += counter
 
-               is_primitive = Signal()
-               last_was_primitive = Signal()
-               last_primitive = Signal(32)
+               is_data = Signal()
+               was_data = Signal()
                change = Signal()
+               self.comb += is_data.eq(sink.charisk == 0)
 
-               cont_insert = Signal()
-               scrambler_insert = Signal()
-               last_primitive_insert = Signal()
-               last_primitive_insert_d = Signal()
-
-               self.comb += [
-                       is_primitive.eq(sink.charisk != 0),
-                       change.eq((sink.data != last_primitive) | ~is_primitive),
-                       cont_insert.eq(~change & (counter.value == 1)),
-                       scrambler_insert.eq(~change & (counter.value == 2)),
-                       last_primitive_insert.eq((counter.value == 2) & (
-                               (~is_primitive & last_was_primitive) |
-                               (is_primitive & (last_primitive == primitives["HOLD"]) & (last_primitive != sink.data))))
-               ]
-
-               self.sync += \
+               last_data = Signal(32)
+               last_primitive = Signal(32)
+               last_charisk = Signal(4)
+               self.sync += [
                        If(sink.stb & source.ack,
-                               last_primitive_insert_d.eq(last_primitive_insert),
-                               If(is_primitive,
+                               last_data.eq(sink.data),
+                               last_charisk.eq(sink.charisk),
+                               If(~is_data,
                                        last_primitive.eq(sink.data),
-                                       last_was_primitive.eq(1)
-                               ).Else(
-                                       last_was_primitive.eq(0)
-                               )
-                       )
-               self.comb += \
-                       If(sink.stb & source.ack,
-                               If(change | last_primitive_insert_d,
-                                       counter.reset.eq(1)
-                               ).Else(
-                                       counter.ce.eq(~scrambler_insert)
-                               )
+                               ),
+                               was_data.eq(is_data)
                        )
+               ]
+               was_hold = last_primitive == primitives["HOLD"]
+
+               self.comb += change.eq(
+                       (sink.data != last_data) |
+                       (sink.charisk != last_charisk) |
+                       is_data
+               )
 
-               # scrambler (between CONT and next primitive)
+               # scrambler
                scrambler = InsertReset(Scrambler())
                self.submodules += scrambler
-               self.comb += [
-                       scrambler.reset.eq(ResetSignal()), #XXX: should be reseted on COMINIT / COMRESET
-                       scrambler.ce.eq(scrambler_insert & source.stb & source.ack)
-               ]
 
                # Datapath
                self.comb += [
                        Record.connect(sink, source),
                        If(sink.stb,
-                               If(cont_insert,
-                                       source.charisk.eq(0b0001),
-                                       source.data.eq(primitives["CONT"])
-                               ).Elif(scrambler_insert,
-                                       source.charisk.eq(0b0000),
-                                       source.data.eq(scrambler.value)
-                               ).Elif(last_primitive_insert,
-                                       source.stb.eq(1),
-                                       sink.ack.eq(0),
-                                       source.charisk.eq(0b0001),
-                                       source.data.eq(last_primitive)
+                               If(~change,
+                                       counter.ce.eq(sink.ack & (counter.value !=2)),
+                                       # insert CONT
+                                       If(counter.value == 1,
+                                               source.charisk.eq(0b0001),
+                                               source.data.eq(primitives["CONT"])
+                                       # insert scrambled data for EMI
+                                       ).Elif(counter.value == 2,
+                                               scrambler.ce.eq(sink.ack),
+                                               source.charisk.eq(0b0000),
+                                               source.data.eq(scrambler.value)
+                                       )
+                               ).Else(
+                                       counter.reset.eq(source.ack),
+                                       If(counter.value == 2,
+                                               # Reinsert last primitive
+                                               If(is_data | (~is_data & was_hold),
+                                                       source.stb.eq(1),
+                                                       sink.ack.eq(0),
+                                                       source.charisk.eq(0b0001),
+                                                       source.data.eq(last_primitive)
+                                               )
+                                       )
                                )
                        )
                ]
@@ -88,29 +81,32 @@ class SATACONTRemover(Module):
 
                ###
 
-               # Detect CONT
-               is_primitive = Signal()
+               is_data = Signal()
                is_cont = Signal()
                in_cont = Signal()
                cont_ongoing = Signal()
 
                self.comb += [
-                       is_primitive.eq(sink.charisk != 0),
-                       is_cont.eq(is_primitive & (sink.data == primitives["CONT"]))
+                       is_data.eq(sink.charisk == 0),
+                       is_cont.eq(~is_data & (sink.data == primitives["CONT"]))
                ]
                self.sync += \
-                       If(is_cont,
-                               in_cont.eq(1)
-                       ).Elif(is_primitive,
-                               in_cont.eq(0)
+                       If(sink.stb & sink.ack,
+                               If(is_cont,
+                                       in_cont.eq(1)
+                               ).Elif(~is_data,
+                                       in_cont.eq(0)
+                               )
                        )
-               self.comb += cont_ongoing.eq(is_cont | (in_cont & ~is_primitive))
+               self.comb += cont_ongoing.eq(is_cont | (in_cont & is_data))
 
                # Datapath
-               last_primitive = Signal()
+               last_primitive = Signal(32)
                self.sync += [
-                       If(is_primitive & ~is_cont,
-                               last_primitive.eq(sink.data)
+                       If(sink.stb & sink.ack,
+                               If(~is_data & ~is_cont,
+                                       last_primitive.eq(sink.data)
+                               )
                        )
                ]
                self.comb += [
index ef772ae9665839697049c0696632cdfd0fbd3151..cfdd7214abde957eee4b2f91368d1c923724546d 100644 (file)
@@ -14,6 +14,9 @@ scrambler_tb:
        $(CC) $(CFLAGS) $(INC) -o scrambler scrambler.c
        $(CMD) scrambler_tb.py
 
+cont_tb:
+       $(CMD) cont_tb.py
+
 link_tb:
        $(CMD) link_tb.py
 
@@ -23,7 +26,5 @@ command_tb:
 bist_tb:
        $(CMD) bist_tb.py
 
-all: crc_tb scrambler_tb link_tb command_tb
-
 clean:
        rm crc scrambler *.vcd
index b79442a63b5265cc8a2131bc7636fed5b38209da..1469e4b53bb1d69c528f91dd0adf8cc468d2f2d2 100644 (file)
@@ -42,6 +42,8 @@ class PacketStreamer(Module):
                self.packet = packet_class()
                self.packet.done = 1
 
+               self.source_data = 0
+
        def send(self, packet, blocking=True):
                packet = copy.deepcopy(packet)
                self.packets.append(packet)
@@ -54,22 +56,26 @@ class PacketStreamer(Module):
                        self.packet = self.packets.pop(0)
                if not self.packet.ongoing and not self.packet.done:
                        selfp.source.stb = 1
-                       selfp.source.sop = 1
+                       if self.source.description.packetized:
+                               selfp.source.sop = 1
+                       self.source_data = self.packet.pop(0)
                        if len(self.packet) > 0:
                                if hasattr(selfp.source, "data"):
-                                       selfp.source.data = self.packet.pop(0)
+                                       selfp.source.data = self.source_data
                                else:
-                                       selfp.source.d = self.packet.pop(0)
+                                       selfp.source.d = self.source_data
                        self.packet.ongoing = True
                elif selfp.source.stb == 1 and selfp.source.ack == 1:
-                       selfp.source.sop = 0
-                       selfp.source.eop = (len(self.packet) == 1)
+                       if self.source.description.packetized:
+                               selfp.source.sop = 0
+                               selfp.source.eop = (len(self.packet) == 1)
                        if len(self.packet) > 0:
                                selfp.source.stb = 1
+                               self.source_data = self.packet.pop(0)
                                if hasattr(selfp.source, "data"):
-                                       selfp.source.data = self.packet.pop(0)
+                                       selfp.source.data = self.source_data
                                else:
-                                       selfp.source.d = self.packet.pop(0)
+                                       selfp.source.d = self.source_data
                        else:
                                self.packet.done = 1
                                selfp.source.stb = 0
@@ -81,22 +87,28 @@ class PacketLogger(Module):
                self.packet_class = packet_class
                self.packet = packet_class()
 
-       def receive(self):
+       def receive(self, length=None):
                self.packet.done = 0
-               while self.packet.done == 0:
-                       yield
+               if length is None:
+                       while self.packet.done == 0:
+                               yield
+               else:
+                       while length > len(self.packet):
+                               yield
 
        def do_simulation(self, selfp):
                selfp.sink.ack = 1
-               if selfp.sink.stb == 1 and selfp.sink.sop == 1:
-                       self.packet = self.packet_class()
+               if self.sink.description.packetized:
+                       if selfp.sink.stb == 1 and selfp.sink.sop == 1:
+                               self.packet = self.packet_class()
                if selfp.sink.stb:
                        if hasattr(selfp.sink, "data"):
                                self.packet.append(selfp.sink.data)
                        else:
                                self.packet.append(selfp.sink.d)
-               if selfp.sink.stb == 1 and selfp.sink.eop == 1:
-                       self.packet.done = True
+               if self.sink.description.packetized:
+                       if selfp.sink.stb == 1 and selfp.sink.eop == 1:
+                               self.packet.done = True
 
 class Randomizer(Module):
        def __init__(self, description, level=0):
diff --git a/lib/sata/test/cont_tb.py b/lib/sata/test/cont_tb.py
new file mode 100644 (file)
index 0000000..61c800d
--- /dev/null
@@ -0,0 +1,95 @@
+from lib.sata.common import *
+from lib.sata.link.cont import SATACONTInserter, SATACONTRemover
+
+from lib.sata.test.common import *
+
+class ContPacket(list):
+       def __init__(self, data=[]):
+               self.ongoing = False
+               self.done = False
+               for d in data:
+                       self.append(d)
+
+class ContStreamer(PacketStreamer):
+       def __init__(self):
+               PacketStreamer.__init__(self, phy_description(32), ContPacket)
+
+       def do_simulation(self, selfp):
+               PacketStreamer.do_simulation(self, selfp)
+               selfp.source.charisk = 0
+               # Note: for simplicity we generate charisk by detecting
+               # primitives in data
+               for k, v in primitives.items():
+                       try:
+                               if self.source_data == v:
+                                       selfp.source.charisk = 0b0001
+                       except:
+                               pass
+
+class ContLogger(PacketLogger):
+       def __init__(self):
+               PacketLogger.__init__(self, phy_description(32), ContPacket)
+
+class TB(Module):
+       def __init__(self):
+               self.streamer = ContStreamer()
+               self.streamer_randomizer = Randomizer(phy_description(32), level=50)
+               self.inserter = SATACONTInserter(phy_description(32))
+               self.remover = SATACONTRemover(phy_description(32))
+               self.logger_randomizer = Randomizer(phy_description(32), level=50)
+               self.logger = ContLogger()
+
+               self.pipeline = Pipeline(
+                       self.streamer,
+                       self.streamer_randomizer,
+                       self.inserter,
+                       self.remover,
+                       self.logger_randomizer,
+                       self.logger
+               )
+
+       def gen_simulation(self, selfp):
+               test_packet = ContPacket([
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["ALIGN"],
+                       primitives["ALIGN"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       #primitives["SYNC"],
+                       0x00000000,
+                       0x00000001,
+                       0x00000002,
+                       0x00000003,
+                       0x00000004,
+                       0x00000005,
+                       0x00000006,
+                       0x00000007,
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["ALIGN"],
+                       primitives["ALIGN"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"],
+                       primitives["SYNC"]]*4
+                       )
+               streamer_packet = ContPacket(test_packet)
+               yield from self.streamer.send(streamer_packet)
+               yield from self.logger.receive(len(test_packet))
+               #for d in self.logger.packet:
+               #       print("%08x" %d)
+
+               # check results
+               s, l, e = check(streamer_packet, self.logger.packet)
+               print("shift "+ str(s) + " / length " + str(l) + " / errors " + str(e))
+
+
+if __name__ == "__main__":
+       run_simulation(TB(), ncycles=1024, vcd_name="my.vcd", keep_files=True)