Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / minerva / cache.py
1 from nmigen import (Elaboratable, Module, Const, Signal, Record, Array,
2 Mux, Memory)
3 from nmigen.asserts import Assume, Initial
4 from nmigen.lib.coding import Encoder
5 from nmigen.utils import log2_int
6
7
8 __all__ = ["L1Cache"]
9
10
class L1Cache(Elaboratable):
    """1- or 2-way set-associative L1 cache with a line-refill engine.

    Stage 1 (``s1_*``) presents the address whose line index is used to read
    the tag and data memories.  Stage 2 (``s2_*``) compares the stored tags
    with ``s2_addr.tag``.  It then either returns the addressed word on
    ``s2_rdata`` or raises ``s2_miss``.  A miss on a valid read request
    (``s2_re & s2_valid``) starts a refill of the whole line over the
    ``bus_*`` interface.  The refill begins at the missed word and wraps
    round the line.

    Parameters
    ----------
    nways : int
        Number of ways, 1 or 2.
    nlines : int
        Lines per way; must be a non-zero power of 2.
    nwords : int
        32-bit words per line: 4, 8 or 16.
    base, limit : int
        Bounds of the cached address range.  Only ``limit - base`` is used
        here, to size the tag field.  It must be a power of 2 (a
        ``log2_int`` requirement) and large enough to leave a tag of at
        least 0 bits.  NOTE(review): presumably byte addresses, since two
        extra bits (the byte offset within a 32-bit word) are dropped from
        the tag width.

    Raises
    ------
    ValueError
        If ``nlines``, ``nwords`` or ``nways`` is invalid, if
        ``limit - base`` is not a power of 2, or if the range is too small
        for the configured cache geometry.
    """

    def __init__(self, nways, nlines, nwords, base, limit):
        if not nlines or nlines & nlines-1:
            raise ValueError("nlines must be a power "
                             "of 2, not {!r}".format(nlines))
        if nwords not in {4, 8, 16}:
            raise ValueError("nwords must be 4, 8 or 16, "
                             "not {!r}".format(nwords))
        if nways not in {1, 2}:
            raise ValueError("nways must be 1 or 2, not {!r}".format(nways))

        self.nways = nways
        self.nlines = nlines
        self.nwords = nwords
        self.base = base
        self.limit = limit

        offsetbits = log2_int(nwords)
        linebits = log2_int(nlines)
        # Whatever address bits the range needs beyond line + word offset
        # form the tag.  The extra "- 2" drops the byte offset within a
        # 32-bit word (the address records below index whole words).
        tagbits = log2_int(limit-base) - linebits - offsetbits - 2
        if tagbits < 0:
            raise ValueError("limit - base ({!r}) is too small for {} lines "
                             "of {} words".format(limit - base, nlines,
                                                  nwords))

        # Stage 1: address lookup.  Its line index addresses the tag/data
        # memory read ports unless stage 1 is stalled.
        self.s1_addr = Record([("offset", offsetbits),
                               ("line", linebits),
                               ("tag", tagbits)])
        self.s1_flush = Signal()   # with s1_valid: invalidate every line
        self.s1_stall = Signal()   # read memories with the s2 line instead
        self.s1_valid = Signal()

        # Stage 2: tag compare, hit data / miss, refill and eviction.
        self.s2_addr = Record.like(self.s1_addr)
        self.s2_re = Signal()      # read request; a miss on it starts a refill
        self.s2_evict = Signal()   # with s2_valid: drop s2_addr's line if cached
        self.s2_valid = Signal()
        self.bus_valid = Signal()  # bus_rdata holds a refill word this cycle
        self.bus_error = Signal()  # bus error: abort refill, invalidate line
        self.bus_rdata = Signal(32)  # refill word returned by the bus

        # Outputs.
        self.s2_miss = Signal()      # s2_addr did not hit in exactly one way
        self.s2_rdata = Signal(32)   # word read from the hit way at s2_addr
        self.bus_re = Signal()       # high while a refill is being requested
        self.bus_addr = Record.like(self.s1_addr)  # word being refilled
        self.bus_last = Signal()     # current refill word is the line's last

    def elaborate(self, platform):
        m = Module()

        # Per-way view: data and tag from the memories' read ports, the valid
        # bit of s2_addr.line, and whether this way is the refill target.
        ways = Array(Record([("data", self.nwords * 32),
                             ("tag", self.s2_addr.tag.shape()),
                             ("valid", 1),
                             ("bus_re", 1)])
                     for _ in range(self.nways))

        # Way that refills write into.  With two ways it alternates after
        # every refill whose last word arrives without a bus error.
        if self.nways == 1:
            way_lru = Const(0)
        elif self.nways == 2:
            way_lru = Signal()
            with m.If(self.bus_re & self.bus_valid & self.bus_last &
                      ~self.bus_error):
                m.d.sync += way_lru.eq(~way_lru)

        m.d.comb += ways[way_lru].bus_re.eq(self.bus_re)

        # Hit detection.  Encoder.n is asserted when not exactly one input is
        # set, so s2_miss covers "no way hit" (and "several ways hit").
        way_hit = m.submodules.way_hit = Encoder(self.nways)
        for j, way in enumerate(ways):
            hit = (way.tag == self.s2_addr.tag) & way.valid
            m.d.comb += way_hit.i[j].eq(hit)

        rdata = ways[way_hit.o].data.word_select(self.s2_addr.offset, 32)
        m.d.comb += [
            self.s2_miss.eq(way_hit.n),
            self.s2_rdata.eq(rdata),
        ]

        with m.FSM() as fsm:
            # Offset of the refill's final word.  The burst starts at the
            # missed word and wraps round the line (offset arithmetic wraps
            # modulo nwords), so it ends one word before where it started.
            last_offs = Signal.like(self.s2_addr.offset)

            with m.State("CHECK"):
                # Start a refill on a valid read request that missed.
                with m.If(self.s2_re & self.s2_miss & self.s2_valid):
                    m.d.sync += [
                        self.bus_addr.eq(self.s2_addr),
                        self.bus_re.eq(1),
                        last_offs.eq(self.s2_addr.offset - 1),
                    ]
                    m.next = "REFILL"

            with m.State("REFILL"):
                m.d.comb += self.bus_last.eq(self.bus_addr.offset == last_offs)
                with m.If(self.bus_valid):
                    m.d.sync += self.bus_addr.offset.eq(self.bus_addr.offset+1)
                # Stop requesting once the last word arrives or the bus errors.
                with m.If((self.bus_valid & self.bus_last) | self.bus_error):
                    m.d.sync += self.bus_re.eq(0)
                # Go back to CHECK only after the request has dropped and
                # stage 1 is no longer stalled.
                with m.If(~self.bus_re & ~self.s1_stall):
                    m.next = "CHECK"

        if platform == "formal":
            # Formal runs start with the FSM idle.
            with m.If(Initial()):
                m.d.comb += Assume(fsm.ongoing("CHECK"))

        # Memory read address: hold the s2 line while stage 1 is stalled,
        # otherwise look up the incoming s1 line.  Shared by every port.
        read_line = Mux(self.s1_stall, self.s2_addr.line, self.s1_addr.line)

        for way in ways:
            valid_lines = Signal(self.nlines)
            # A refill word for this way arrives this cycle / it is the last.
            refill_word = way.bus_re & self.bus_valid
            refill_done = refill_word & self.bus_last

            # Valid-bit maintenance, in priority order: a flush clears every
            # line; a bus error invalidates the line being refilled; the last
            # refill word validates it; an eviction invalidates the s2 line
            # when this way holds the matching tag.
            with m.If(self.s1_flush & self.s1_valid):
                m.d.sync += valid_lines.eq(0)
            with m.Elif(way.bus_re & self.bus_error):
                m.d.sync += valid_lines.bit_select(self.bus_addr.line, 1).eq(0)
            with m.Elif(refill_done):
                m.d.sync += valid_lines.bit_select(self.bus_addr.line, 1).eq(1)
            with m.Elif(self.s2_evict & self.s2_valid &
                        (way.tag == self.s2_addr.tag)):
                m.d.sync += valid_lines.bit_select(self.s2_addr.line, 1).eq(0)

            tag_mem = Memory(width=len(way.tag), depth=self.nlines)
            tag_rp = tag_mem.read_port()
            tag_wp = tag_mem.write_port()
            m.submodules += tag_rp, tag_wp

            # One 32-bit write lane per word, so a refill writes one word
            # of the line per bus beat.
            data_mem = Memory(width=len(way.data), depth=self.nlines)
            data_rp = data_mem.read_port()
            data_wp = data_mem.write_port(granularity=32)
            m.submodules += data_rp, data_wp

            m.d.comb += [
                tag_rp.addr.eq(read_line),
                data_rp.addr.eq(read_line),

                # The tag is written together with the line's last word.
                tag_wp.addr.eq(self.bus_addr.line),
                tag_wp.en.eq(refill_done),
                tag_wp.data.eq(self.bus_addr.tag),

                # Each refill word goes to the lane selected by its offset.
                data_wp.addr.eq(self.bus_addr.line),
                data_wp.en.bit_select(self.bus_addr.offset, 1).eq(refill_word),
                data_wp.data.eq(self.bus_rdata << self.bus_addr.offset*32),

                way.valid.eq(valid_lines.bit_select(self.s2_addr.line, 1)),
                way.tag.eq(tag_rp.data),
                way.data.eq(data_rp.data),
            ]

            if platform == "formal":
                # Formal runs start with every line of this way invalid.
                with m.If(Initial()):
                    m.d.comb += Assume(~valid_lines.bool())

        return m