dewildcardify unitsg
[soc.git] / src / soc / minerva / units / decoder.py
1 from functools import reduce
2 from itertools import starmap
3 from operator import or_
4
5 from nmigen import Elaboratable, Module, Signal, Cat
6
7 from ..isa import Opcode, Funct3, Funct7, Funct12
8
9
# Names exported by ``from <this module> import *``.
__all__ = ["InstructionDecoder"]
11
12
class Type:
    """Integer tags for the instruction formats told apart by the decoder.

    The tags are consecutive integers starting at 0, in the order
    R, I, S, B, U, J, so ``Type.J + 1`` is the number of formats.
    """

    R, I, S, B, U, J = range(6)
20
21
class InstructionDecoder(Elaboratable):
    """Combinational decoder from a 32-bit instruction word to control signals.

    The decoder slices ``instruction`` into its opcode / funct / immediate
    fields, selects the sign-extended immediate that belongs to the
    instruction format, and raises one flag per instruction class (adder,
    logic, shift, load, store, ...).  ``illegal`` is raised when the word
    matches none of the classes.

    Parameters
    ----------
    with_muldiv : bool
        When true, ``multiply`` and ``divide`` are driven from the MULDIV
        encodings.  When false they are left undriven, so those encodings
        match no class and are reported through ``illegal``.

    Attributes
    ----------
    instruction : Signal(32), in
        Instruction word to decode.
    rd, rs1, rs2 : Signal(5), out
        Register index fields ``insn[7:12]``, ``insn[15:20]``, ``insn[20:25]``.
    rd_we, rs1_re, rs2_re : Signal(), out
        Register write / read enables, derived from the instruction format
        only (R/I/U/J write rd, R/I/S/B read rs1, R/S/B read rs2).
    immediate : Signal((32, True)), out
        Immediate of the decoded format, sign-extended to 32 bits.
    funct3 : Signal(3), out
        Raw ``insn[12:15]`` field.
    bypass_x, bypass_m : Signal(), out
        OR of the classes ``adder | logic | lui | auipc | csr`` and
        ``compare | divide | shift`` respectively.
    illegal : Signal(), out
        High when ``insn[0:2] != 0b11`` or no instruction class matched.

    The remaining single-bit outputs (``load``, ``store``, ``fence_i``,
    ``adder``, ``adder_sub``, ``logic``, ``multiply``, ``divide``, ``shift``,
    ``direction``, ``sext``, ``lui``, ``auipc``, ``jump``, ``branch``,
    ``compare``, ``csr``, ``csr_we``, ``privileged``, ``ecall``, ``ebreak``,
    ``mret``) are per-class decode flags; see ``elaborate`` for the exact
    encodings each one matches.
    """

    def __init__(self, with_muldiv):
        self.with_muldiv = with_muldiv

        self.instruction = Signal(32)

        self.rd = Signal(5)
        self.rd_we = Signal()
        self.rs1 = Signal(5)
        self.rs1_re = Signal()
        self.rs2 = Signal(5)
        self.rs2_re = Signal()
        self.immediate = Signal((32, True))
        self.bypass_x = Signal()
        self.bypass_m = Signal()
        self.load = Signal()
        self.store = Signal()
        self.fence_i = Signal()
        self.adder = Signal()
        self.adder_sub = Signal()
        self.logic = Signal()
        self.multiply = Signal()
        self.divide = Signal()
        self.shift = Signal()
        self.direction = Signal()
        self.sext = Signal()
        self.lui = Signal()
        self.auipc = Signal()
        self.jump = Signal()
        self.branch = Signal()
        self.compare = Signal()
        self.csr = Signal()
        self.csr_we = Signal()
        self.privileged = Signal()
        self.ecall = Signal()
        self.ebreak = Signal()
        self.mret = Signal()
        self.funct3 = Signal(3)
        self.illegal = Signal()

    def elaborate(self, platform):
        """Build and return the purely combinational decode logic."""
        m = Module()

        # Raw instruction fields.
        opcode = Signal(5)
        funct3 = Signal(3)
        funct7 = Signal(7)
        funct12 = Signal(12)

        # Per-format immediates; the signed ones are sign-extended when
        # assigned to the wider ``self.immediate``.
        iimm12 = Signal((12, True))
        simm12 = Signal((12, True))
        bimm12 = Signal((13, True))
        uimm20 = Signal(20)
        jimm20 = Signal((21, True))

        insn = self.instruction
        fmt = Signal(range(Type.J + 1))

        m.d.comb += [
            # insn[0:2] is not part of ``opcode``; it is checked separately
            # in the ``illegal`` expression at the end.
            opcode.eq(insn[2:7]),
            funct3.eq(insn[12:15]),
            funct7.eq(insn[25:32]),
            funct12.eq(insn[20:32]),

            iimm12.eq(insn[20:32]),
            simm12.eq(Cat(insn[7:12], insn[25:32])),
            # B and J immediates are scattered over the word and have an
            # implicit zero LSB, hence the leading constant 0.
            bimm12.eq(Cat(0, insn[8:12], insn[25:31], insn[7], insn[31])),
            uimm20.eq(insn[12:32]),
            jimm20.eq(Cat(0, insn[21:31], insn[20], insn[12:20], insn[31])),
        ]

        # Opcode -> instruction format.
        # NOTE(review): opcodes not listed here leave ``fmt`` undriven;
        # presumably it then holds its reset value 0 (== Type.R) - confirm.
        with m.Switch(opcode):
            with m.Case(Opcode.LUI):
                m.d.comb += fmt.eq(Type.U)
            with m.Case(Opcode.AUIPC):
                m.d.comb += fmt.eq(Type.U)
            with m.Case(Opcode.JAL):
                m.d.comb += fmt.eq(Type.J)
            with m.Case(Opcode.JALR):
                m.d.comb += fmt.eq(Type.I)
            with m.Case(Opcode.BRANCH):
                m.d.comb += fmt.eq(Type.B)
            with m.Case(Opcode.LOAD):
                m.d.comb += fmt.eq(Type.I)
            with m.Case(Opcode.STORE):
                m.d.comb += fmt.eq(Type.S)
            with m.Case(Opcode.OP_IMM_32):
                m.d.comb += fmt.eq(Type.I)
            with m.Case(Opcode.OP_32):
                m.d.comb += fmt.eq(Type.R)
            with m.Case(Opcode.MISC_MEM):
                m.d.comb += fmt.eq(Type.I)
            with m.Case(Opcode.SYSTEM):
                m.d.comb += fmt.eq(Type.I)

        # Format -> immediate.  Type.R has no case: it carries no immediate.
        with m.Switch(fmt):
            with m.Case(Type.I):
                m.d.comb += self.immediate.eq(iimm12)
            with m.Case(Type.S):
                m.d.comb += self.immediate.eq(simm12)
            with m.Case(Type.B):
                m.d.comb += self.immediate.eq(bimm12)
            with m.Case(Type.U):
                m.d.comb += self.immediate.eq(uimm20 << 12)
            with m.Case(Type.J):
                m.d.comb += self.immediate.eq(jimm20)

        m.d.comb += [
            self.rd.eq(insn[7:12]),
            self.rs1.eq(insn[15:20]),
            self.rs2.eq(insn[20:25]),

            # Register port enables depend on the format only.
            self.rd_we.eq(reduce(or_, (fmt == T for T in (Type.R, Type.I, Type.U, Type.J)))),
            self.rs1_re.eq(reduce(or_, (fmt == T for T in (Type.R, Type.I, Type.S, Type.B)))),
            self.rs2_re.eq(reduce(or_, (fmt == T for T in (Type.R, Type.S, Type.B)))),

            self.funct3.eq(funct3)
        ]

        def matcher(encodings):
            """Return an expression that is true if any encoding matches.

            Each encoding is a tuple ``(opcode[, funct3[, funct7[, funct12]]])``.
            A field that is omitted or ``None`` is a don't-care.
            """
            def match_one(opc, f3=None, f7=None, f12=None):
                return ((opcode == opc if opc is not None else 1)
                        & (funct3 == f3 if f3 is not None else 1)
                        & (funct7 == f7 if f7 is not None else 1)
                        & (funct12 == f12 if f12 is not None else 1))

            return reduce(or_, starmap(match_one, encodings))

        m.d.comb += [
            self.compare.eq(matcher([
                (Opcode.OP_IMM_32, Funct3.SLT,  None),  # slti
                (Opcode.OP_IMM_32, Funct3.SLTU, None),  # sltiu
                (Opcode.OP_32,     Funct3.SLT,  0),     # slt
                (Opcode.OP_32,     Funct3.SLTU, 0)      # sltu
            ])),
            self.branch.eq(matcher([
                (Opcode.BRANCH, Funct3.BEQ,  None),  # beq
                (Opcode.BRANCH, Funct3.BNE,  None),  # bne
                (Opcode.BRANCH, Funct3.BLT,  None),  # blt
                (Opcode.BRANCH, Funct3.BGE,  None),  # bge
                (Opcode.BRANCH, Funct3.BLTU, None),  # bltu
                (Opcode.BRANCH, Funct3.BGEU, None)   # bgeu
            ])),

            self.adder.eq(matcher([
                (Opcode.OP_IMM_32, Funct3.ADD, None),        # addi
                (Opcode.OP_32,     Funct3.ADD, Funct7.ADD),  # add
                (Opcode.OP_32,     Funct3.ADD, Funct7.SUB)   # sub
            ])),
            # Gated by rs2_re so that an immediate-format word (addi) whose
            # upper immediate bits equal Funct7.SUB does not select subtract.
            self.adder_sub.eq(self.rs2_re & (funct7 == Funct7.SUB)),

            self.logic.eq(matcher([
                (Opcode.OP_IMM_32, Funct3.XOR, None),  # xori
                (Opcode.OP_IMM_32, Funct3.OR,  None),  # ori
                (Opcode.OP_IMM_32, Funct3.AND, None),  # andi
                (Opcode.OP_32,     Funct3.XOR, 0),     # xor
                (Opcode.OP_32,     Funct3.OR,  0),     # or
                (Opcode.OP_32,     Funct3.AND, 0)      # and
            ])),
        ]

        # Without muldiv support these two flags stay undriven, so the
        # MULDIV encodings fall through to ``illegal``.
        if self.with_muldiv:
            m.d.comb += [
                self.multiply.eq(matcher([
                    (Opcode.OP_32, Funct3.MUL,    Funct7.MULDIV),  # mul
                    (Opcode.OP_32, Funct3.MULH,   Funct7.MULDIV),  # mulh
                    (Opcode.OP_32, Funct3.MULHSU, Funct7.MULDIV),  # mulhsu
                    (Opcode.OP_32, Funct3.MULHU,  Funct7.MULDIV),  # mulhu
                ])),

                self.divide.eq(matcher([
                    (Opcode.OP_32, Funct3.DIV,  Funct7.MULDIV),  # div
                    (Opcode.OP_32, Funct3.DIVU, Funct7.MULDIV),  # divu
                    (Opcode.OP_32, Funct3.REM,  Funct7.MULDIV),  # rem
                    (Opcode.OP_32, Funct3.REMU, Funct7.MULDIV)   # remu
                ])),
            ]

        m.d.comb += [
            self.shift.eq(matcher([
                (Opcode.OP_IMM_32, Funct3.SLL, 0),           # slli
                (Opcode.OP_IMM_32, Funct3.SR,  Funct7.SRL),  # srli
                (Opcode.OP_IMM_32, Funct3.SR,  Funct7.SRA),  # srai
                (Opcode.OP_32,     Funct3.SLL, 0),           # sll
                (Opcode.OP_32,     Funct3.SR,  Funct7.SRL),  # srl
                (Opcode.OP_32,     Funct3.SR,  Funct7.SRA)   # sra
            ])),
            # direction / sext are raw field compares; they are only
            # meaningful to a consumer while ``shift`` is asserted.
            self.direction.eq(funct3 == Funct3.SR),
            self.sext.eq(funct7 == Funct7.SRA),

            self.lui.eq(opcode == Opcode.LUI),
            self.auipc.eq(opcode == Opcode.AUIPC),

            self.jump.eq(matcher([
                (Opcode.JAL,  None),  # jal
                (Opcode.JALR, 0)      # jalr
            ])),

            self.load.eq(matcher([
                (Opcode.LOAD, Funct3.B),   # lb
                (Opcode.LOAD, Funct3.BU),  # lbu
                (Opcode.LOAD, Funct3.H),   # lh
                (Opcode.LOAD, Funct3.HU),  # lhu
                (Opcode.LOAD, Funct3.W)    # lw
            ])),
            self.store.eq(matcher([
                (Opcode.STORE, Funct3.B),  # sb
                (Opcode.STORE, Funct3.H),  # sh
                (Opcode.STORE, Funct3.W)   # sw
            ])),

            self.fence_i.eq(matcher([
                (Opcode.MISC_MEM, Funct3.FENCEI)  # fence.i
            ])),

            self.csr.eq(matcher([
                (Opcode.SYSTEM, Funct3.CSRRW),   # csrrw
                (Opcode.SYSTEM, Funct3.CSRRS),   # csrrs
                (Opcode.SYSTEM, Funct3.CSRRC),   # csrrc
                (Opcode.SYSTEM, Funct3.CSRRWI),  # csrrwi
                (Opcode.SYSTEM, Funct3.CSRRSI),  # csrrsi
                (Opcode.SYSTEM, Funct3.CSRRCI)   # csrrci
            ])),
            # Write enable is a raw field function (funct3 bit 1 clear, or a
            # non-zero rs1 field); it is not qualified by ``csr`` here.
            self.csr_we.eq(~funct3[1] | (self.rs1 != 0)),

            self.privileged.eq((opcode == Opcode.SYSTEM) & (funct3 == Funct3.PRIV)),
            self.ecall.eq(self.privileged & (funct12 == Funct12.ECALL)),
            self.ebreak.eq(self.privileged & (funct12 == Funct12.EBREAK)),
            self.mret.eq(self.privileged & (funct12 == Funct12.MRET)),

            self.bypass_x.eq(self.adder | self.logic | self.lui | self.auipc | self.csr),
            self.bypass_m.eq(self.compare | self.divide | self.shift),

            # An instruction is illegal when its two low bits are not 0b11 or
            # when it matched none of the decoded classes.  ``fence_i`` must
            # be part of that set: it is the only class that matches a
            # MISC_MEM word, so leaving it out reports every fence.i as
            # illegal even though it was decoded.
            # NOTE(review): other MISC_MEM encodings still match no class
            # and are therefore reported illegal - confirm this is intended.
            self.illegal.eq((self.instruction[:2] != 0b11) | ~reduce(or_, (
                self.compare, self.branch, self.adder, self.logic,
                self.multiply, self.divide, self.shift,
                self.lui, self.auipc, self.jump, self.load, self.store,
                self.csr, self.ecall, self.ebreak, self.mret, self.fence_i
            )))
        ]

        return m