add twin partial address mapper class
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2
3 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
4 from nmutil.latch import SRLatch, latchregister
5 from nmigen.back.pysim import Simulator, Delay
6 from nmigen.cli import verilog, rtlil
7
8 from soc.scoreboard.addr_match import LenExpand
9 #from nmutil.queue import Queue
10
class LDData(Record):
    """Record carrying the result of a load: an error flag plus the data.

    Fields:
    * err  - 1 bit
    * data - dwidth bits

    dwidth: width (in bits) of the data field
    name:   optional record name, handed through to Record unchanged
    """

    def __init__(self, dwidth, name=None):
        layout = [('err', 1), ('data', dwidth)]
        super().__init__(layout, name=name)
14
15
class LDLatch(Elaboratable):
    """Holds one LD result (ld_i) and presents it on ld_o.

    An SR latch ("in_l") is set by valid_i.  valid_o is the latch output
    ANDed with valid_i, and ld_i is passed to ld_o through latchregister
    under the condition (latch output AND valid_o).

    Note: addr_i and mask_i are declared (and listed as ports) but are not
    used inside elaborate().  The latch's reset input is not driven here.

    dwidth: width of the data field of ld_i / ld_o
    awidth: width of addr_i
    mlen:   width of mask_i
    """

    def __init__(self, dwidth, awidth, mlen):
        # inputs
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        # outputs
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        in_l = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = in_l

        latched = in_l.q
        m.d.comb += [
            in_l.s.eq(self.valid_i),                  # set on valid input
            self.valid_o.eq(latched & self.valid_i),  # valid out
        ]
        # presumably captures ld_i into ld_o while the condition holds;
        # see nmutil.latch.latchregister for the exact semantics
        latchregister(m, self.ld_i, self.ld_o,
                      latched & self.valid_o, "ld_i_r")

        return m

    def __iter__(self):
        yield from (
            self.addr_i,
            self.mask_i,
            self.ld_i.err,
            self.ld_i.data,
            self.ld_o.err,
            self.ld_o.data,
            self.valid_i,
            self.valid_o,
        )

    def ports(self):
        """Return all signals yielded by __iter__ as a list."""
        return [*self]
49
class LDSTSplitter(Elaboratable):
    """Splits one (possibly misaligned) LD into up to two line-sized LDs.

    The low dlen bits of addr_i give the position within a line; the
    remaining bits are the line address.  len_i is expanded (via LenExpand)
    into a mask of 2 * (1<<dlen) bits: the low half goes to LD1, the high
    half to LD2.  LD2 is ignored for validity purposes when its mask is zero.
    The two results are recombined by shifting into ld_data_o.

    dwidth: width of the data field of the LD records
    awidth: width of addr_i
    dlen:   width of len_i, and number of low address bits used as the
            in-line offset (a line is 1<<dlen positions)

    Inputs:
    * addr_i, len_i, is_ld_i (is_ld_i is not used in elaborate())
    * ld_valid_i            - overall LD request valid
    * sld_valid_i[0..1]     - per-split LD valid
    * sld_data_i[0..1]      - per-split LD data (LDData records)

    Outputs:
    * ld_data_o             - recombined data and error
    * valid_o               - recombined data valid
    * sld_valid_o[0..1]     - per-split valid, from each LDLatch
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.is_ld_i = Signal(reset_less=True)
        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.ld_valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(dwidth, "ld_data_i1"),
                                 LDData(dwidth, "ld_data_i2")))

        #self.is_st_i = Signal(reset_less=True)
        #self.st_data_i = Signal(dwidth, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1)  # TODO exception if rolls

        # set up connections to LD-split.  note: not active if mask is zero
        mzero = Const(0, mlen)
        for i, (ld, mask) in enumerate(((ld1, mask1),
                                        (ld2, mask2))):
            ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
            comb += ld_valid.eq(self.ld_valid_i & self.sld_valid_i[i])
            comb += ld.valid_i.eq(ld_valid & (mask != mzero))
            comb += ld.ld_i.eq(self.sld_data_i[i])
            comb += self.sld_valid_o[i].eq(ld.valid_o)

        # sort out valid: mask2 zero we ignore 2nd LD
        with m.If(mask2 == mzero):
            comb += self.valid_o.eq(self.sld_valid_o[0])
        with m.Else():
            comb += self.valid_o.eq(self.sld_valid_o.all())

        # all bits valid (including when a data error occurs!) decode ld1/ld2
        with m.If(self.valid_o):
            # errors cause error condition
            # NOTE(review): ld2.ld_o is OR-ed in here (and in the data below)
            # even when mask2 is zero - confirm ld2.ld_o is guaranteed to be
            # zero in that case.
            comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)
            # data needs recombining via shifting.
            # reset_less for consistency with the other comb-only signals
            ashift1 = Signal(self.dlen, reset_less=True)
            ashift2 = Signal(self.dlen, reset_less=True)
            comb += ashift1.eq(self.addr_i[:self.dlen])
            # NOTE(review): when ashift1 is 0 this value is 1<<dlen, which
            # does not fit in dlen bits - confirm that wrapping to 0 is
            # the intended behaviour.
            comb += ashift2.eq(mlen - ashift1)
            # note that data from LD1 will be in *cache-line* byte position
            # likewise from LD2 but we *know* it is at the start of the line
            comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                           (ld2.ld_o.data << ashift2))

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.ld_valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data
        # sld_valid_o is driven in elaborate(): list it so that it becomes a
        # real port.  Appended last to keep the order of the existing ports.
        yield self.sld_valid_o

    def ports(self):
        """Return all signals yielded by __iter__ as a list."""
        return list(self)
136
def sim(dut):
    """Simulate a single misaligned LD through the splitter and check it.

    Two processes are run: send_in drives the request (len, addr, valid),
    waits for valid_o and checks ld_data_o against the expected data; lds
    waits for the request and then supplies the two split-LD results, one
    per clock.  A VCD trace is written to ldst_splitter.vcd.

    dut: the LDSTSplitter instance to simulate (expected to have dlen=4,
         matching the constants below)
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    mlen = 1 << dlen  # number of positions in one line
    addr = 0b1100
    ld_len = 8
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> mlen  # part of the mask spilling into the 2nd line
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << mlen)-1)  # part of the mask within the 1st line
    print("dmask", bin(dmask))

    def send_in():
        print("send_in")
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.ld_valid_i.eq(1)
        print("waiting")
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        while True:
            valid_i = yield dut.ld_valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second line: shift by the line length (was a hard-coded 16,
        # which only matched dlen == 4)
        data2 = (shfdata >> mlen) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_in)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
202
203
if __name__ == '__main__':
    # dwidth=32, awidth=48, dlen=4
    dut = LDSTSplitter(32, 48, 4)

    # dump the design as RTLIL, then run the simulation on the same instance
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(dut)