add minerva source from https://github.com/lambdaconcept/minerva
[soc.git] / src / soc / minerva / cache.py
1 from nmigen import *
2 from nmigen.asserts import *
3 from nmigen.lib.coding import Encoder
4 from nmigen.utils import log2_int
5
6
# Only the cache itself is part of this module's public interface.
__all__ = [
    "L1Cache",
]
8
9
class L1Cache(Elaboratable):
    """L1 cache with 1 or 2 ways and a two-stage lookup.

    Each way holds ``nlines`` lines of ``nwords`` 32-bit words. Addresses are
    split into ``offset`` (word within a line), ``line`` (line index) and
    ``tag`` fields. With one way this is direct-mapped; with two ways each
    line index selects one entry per way (2-way set-associative).

    Stage 1 (``s1_*``) supplies the line index used to read the tag and data
    memories. Stage 2 (``s2_*``) compares the stored tags against
    ``s2_addr.tag`` and reports ``s2_miss`` / ``s2_rdata``. While ``s1_stall``
    is asserted, the memories are read at ``s2_addr.line`` instead.

    A miss (``s2_re & s2_miss & s2_valid``) starts a refill. ``bus_re`` is
    raised and ``bus_addr`` walks the line one word per ``bus_valid``. The
    walk starts at the missing word and wraps around the line, and
    ``bus_last`` flags the final word. Each received word is written into the
    victim way.

    Inputs:  s1_addr, s1_flush, s1_stall, s1_valid, s2_addr, s2_re, s2_evict,
             s2_valid, bus_valid, bus_error, bus_rdata.
    Outputs: s2_miss, s2_rdata, bus_re, bus_addr, bus_last.
    """

    def __init__(self, nways, nlines, nwords, base, limit):
        """Validate the cache geometry and create the interface signals.

        Parameters
        ----------
        nways : int
            Number of ways; must be 1 or 2.
        nlines : int
            Number of lines per way; must be a non-zero power of 2.
        nwords : int
            Number of 32-bit words per line; must be 4, 8 or 16.
        base, limit : int
            Bounds of the cached address range. ``limit - base`` must be a
            power of 2 (``log2_int`` rejects other values) and at least
            ``nlines * nwords * 4``.

        Raises
        ------
        ValueError
            If any of the constraints above is violated.
        """
        if not nlines or nlines & nlines-1:
            raise ValueError("nlines must be a power of 2, not {!r}".format(nlines))
        if nwords not in {4, 8, 16}:
            raise ValueError("nwords must be 4, 8 or 16, not {!r}".format(nwords))
        if nways not in {1, 2}:
            raise ValueError("nways must be 1 or 2, not {!r}".format(nways))

        self.nways = nways
        self.nlines = nlines
        self.nwords = nwords
        self.base = base
        self.limit = limit

        offsetbits = log2_int(nwords)
        linebits = log2_int(nlines)
        # The remaining address bits form the tag. The extra "- 2" presumably
        # drops the byte-within-word bits of 32-bit words, i.e. base/limit are
        # byte addresses -- TODO confirm with callers.
        tagbits = log2_int(limit - base) - linebits - offsetbits - 2
        if tagbits < 0:
            # Fail here with a clear message. Otherwise nMigen rejects the
            # negative field width later, with a less helpful error.
            raise ValueError("limit - base must be at least nlines * nwords * 4 = {}, not {!r}"
                             .format(nlines * nwords * 4, limit - base))

        # Stage 1 inputs: address whose line is read out of the memories.
        self.s1_addr = Record([("offset", offsetbits), ("line", linebits), ("tag", tagbits)])
        self.s1_flush = Signal()  # with s1_valid: invalidate every line of every way
        self.s1_stall = Signal()  # read memories at s2_addr.line; also holds the FSM in REFILL
        self.s1_valid = Signal()
        # Stage 2 inputs: address being looked up.
        self.s2_addr = Record.like(self.s1_addr)
        self.s2_re = Signal()     # with s2_miss & s2_valid: start a refill
        self.s2_evict = Signal()  # with s2_valid: invalidate s2_addr.line in ways whose tag matches
        self.s2_valid = Signal()
        # Bus inputs, used while a refill is in progress.
        self.bus_valid = Signal()
        self.bus_error = Signal()
        self.bus_rdata = Signal(32)

        # Outputs.
        self.s2_miss = Signal()
        self.s2_rdata = Signal(32)  # only meaningful when s2_miss is low
        self.bus_re = Signal()      # high while a refill is in progress
        self.bus_addr = Record.like(self.s1_addr)
        self.bus_last = Signal()    # the current refill word is the last of the line

    def elaborate(self, platform):
        """Build the tag/data memories, hit logic, refill FSM and valid bits."""
        m = Module()

        # Per-way view of the line currently read out of the memories, plus a
        # per-way refill enable.
        ways = Array(Record([("data", self.nwords * 32),
                             ("tag", self.s2_addr.tag.shape()),
                             ("valid", 1),
                             ("bus_re", 1)])
                     for _ in range(self.nways))

        # Victim way selection. With two ways, way_lru simply toggles after
        # each refill that completes without a bus error, so refills
        # alternate between the ways.
        if self.nways == 1:
            way_lru = Const(0)
        elif self.nways == 2:
            way_lru = Signal()
            with m.If(self.bus_re & self.bus_valid & self.bus_last & ~self.bus_error):
                m.d.sync += way_lru.eq(~way_lru)

        # Route the refill enable to the victim way only.
        m.d.comb += ways[way_lru].bus_re.eq(self.bus_re)

        # A way hits when its stored tag matches and its line is valid.
        way_hit = m.submodules.way_hit = Encoder(self.nways)
        for j, way in enumerate(ways):
            m.d.comb += way_hit.i[j].eq((way.tag == self.s2_addr.tag) & way.valid)

        # A miss is reported when the hit vector does not encode a single
        # way. Read data comes from the way selected by the encoder.
        m.d.comb += [
            self.s2_miss.eq(way_hit.n),
            self.s2_rdata.eq(ways[way_hit.o].data.word_select(self.s2_addr.offset, 32))
        ]

        with m.FSM() as fsm:
            # Offset of the final word of the refill: the word just before the
            # missing one, modulo the line size (the subtraction wraps).
            last_offset = Signal.like(self.s2_addr.offset)

            with m.State("CHECK"):
                with m.If(self.s2_re & self.s2_miss & self.s2_valid):
                    m.d.sync += [
                        self.bus_addr.eq(self.s2_addr),
                        self.bus_re.eq(1),
                        last_offset.eq(self.s2_addr.offset - 1)
                    ]
                    m.next = "REFILL"

            with m.State("REFILL"):
                m.d.comb += self.bus_last.eq(self.bus_addr.offset == last_offset)
                # Advance to the next word of the line on every bus beat.
                with m.If(self.bus_valid):
                    m.d.sync += self.bus_addr.offset.eq(self.bus_addr.offset + 1)
                # Stop after the last word, or immediately on a bus error.
                with m.If(self.bus_valid & self.bus_last | self.bus_error):
                    m.d.sync += self.bus_re.eq(0)
                with m.If(~self.bus_re & ~self.s1_stall):
                    m.next = "CHECK"

        if platform == "formal":
            with m.If(Initial()):
                m.d.comb += Assume(fsm.ongoing("CHECK"))

        for way in ways:
            # One valid bit per line of this way.
            valid_lines = Signal(self.nlines)

            # Priority: flush, then bus error (drop the partially refilled
            # line), then refill complete (line becomes valid), then eviction.
            with m.If(self.s1_flush & self.s1_valid):
                m.d.sync += valid_lines.eq(0)
            with m.Elif(way.bus_re & self.bus_error):
                m.d.sync += valid_lines.bit_select(self.bus_addr.line, 1).eq(0)
            with m.Elif(way.bus_re & self.bus_valid & self.bus_last):
                m.d.sync += valid_lines.bit_select(self.bus_addr.line, 1).eq(1)
            with m.Elif(self.s2_evict & self.s2_valid & (way.tag == self.s2_addr.tag)):
                m.d.sync += valid_lines.bit_select(self.s2_addr.line, 1).eq(0)

            tag_mem = Memory(width=len(way.tag), depth=self.nlines)
            tag_rp = tag_mem.read_port()
            tag_wp = tag_mem.write_port()
            m.submodules += tag_rp, tag_wp

            data_mem = Memory(width=len(way.data), depth=self.nlines)
            data_rp = data_mem.read_port()
            # Written one 32-bit word at a time during a refill.
            data_wp = data_mem.write_port(granularity=32)
            m.submodules += data_rp, data_wp

            m.d.comb += [
                # While stalled, keep reading the stage-2 line, presumably so
                # the read-out stays in step with s2_addr.
                tag_rp.addr.eq(Mux(self.s1_stall, self.s2_addr.line, self.s1_addr.line)),
                data_rp.addr.eq(Mux(self.s1_stall, self.s2_addr.line, self.s1_addr.line)),

                # The tag is written together with the last word of the line.
                tag_wp.addr.eq(self.bus_addr.line),
                tag_wp.en.eq(way.bus_re & self.bus_valid & self.bus_last),
                tag_wp.data.eq(self.bus_addr.tag),

                # Each bus beat writes bus_rdata into the word at bus_addr.offset.
                data_wp.addr.eq(self.bus_addr.line),
                data_wp.en.bit_select(self.bus_addr.offset, 1).eq(way.bus_re & self.bus_valid),
                data_wp.data.eq(self.bus_rdata << self.bus_addr.offset*32),

                way.valid.eq(valid_lines.bit_select(self.s2_addr.line, 1)),
                way.tag.eq(tag_rp.data),
                way.data.eq(data_rp.data)
            ]

            if platform == "formal":
                with m.If(Initial()):
                    m.d.comb += Assume(~valid_lines.bool())

        return m