add LGPLv3+ notice and add copyright holders
[soc.git] / src / soc / fu / alu / main_stage.py
# This stage does most of the work of executing the arithmetic
# instructions: additions, compares and sign-extension, as well as
# carry and overflow generation.  This module should not, however,
# gate the carry or overflow outputs: that is up to the output stage.
6
7 # License: LGPLv3+
8 # Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
9 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
10 # (michael: note that there are multiple copyright holders)
11
12 from nmigen import (Module, Signal, Cat, Repl, Mux, Const)
13 from nmutil.pipemodbase import PipeModBase
14 from nmutil.extend import exts, extz
15 from soc.fu.alu.pipe_data import ALUInputData, ALUOutputData
16 from ieee754.part.partsig import PartitionedSignal
17 from soc.decoder.power_enums import MicrOp
18
19 from soc.decoder.power_fields import DecodeFields
20 from soc.decoder.power_fieldsn import SignalBitRange
21
22
23 # microwatt calc_ov function.
def calc_ov(msb_a, msb_b, ca, msb_r):
    """Signed-overflow flag for an addition (microwatt formulation).

    msb_a, msb_b: sign bits of the two addends.
    ca:           carry out of the sign-bit position.
    msb_r:        sign bit of the sum.

    Overflow is reported only when the addends share a sign bit and the
    carry out disagrees with the sign bit of the result.
    """
    signs_match = ~(msb_a ^ msb_b)
    carry_disagrees = ca ^ msb_r
    return carry_disagrees & signs_match
26
27
class ALUMainStage(PipeModBase):
    """Main (compute) stage of the ALU pipeline.

    Implements OP_CMP, OP_ADD, OP_EXTS and OP_CMPEQB, selected by
    ``self.i.ctx.op.insn_type``.  Input format is ALUInputData and output
    format is ALUOutputData.  XER.SO and the pipeline context are passed
    through unchanged for every operation.
    """
    def __init__(self, pspec):
        super().__init__(pspec, "main")
        # decoder for instruction fields of the raw instruction word;
        # elaborate() reads the X-Form L field from it for OP_CMP.
        self.fields = DecodeFields(SignalBitRange, [self.i.ctx.op.insn])
        self.fields.create_specs()

    def ispec(self):
        return ALUInputData(self.pspec) # defines pipeline stage input format

    def ospec(self):
        return ALUOutputData(self.pspec) # defines pipeline stage output format

    def elaborate(self, platform):
        """Build the combinatorial logic of this stage.

        One shared adder (add_a + add_b) serves both OP_ADD and OP_CMP;
        its result is then interpreted per operation inside the
        ``m.Switch(op.insn_type)`` below.

        :param platform: not used by this stage.
        :returns: the nmigen Module containing the logic.
        """
        m = Module()
        comb = m.d.comb

        # convenience variables
        cry_o, o, cr0 = self.o.xer_ca, self.o.o, self.o.cr0
        xer_so_i, ov_o = self.i.xer_so, self.o.xer_ov
        a, b, cry_i, op = self.i.a, self.i.b, self.i.xer_ca, self.i.ctx.op

        # get L-field for OP_CMP
        x_fields = self.fields.FormX
        L = x_fields.L[0]

        # check if op is 32-bit, and get sign bit from operand a
        # (for OP_CMP, L=0 selects a 32-bit compare, hence is_32bit = ~L)
        is_32bit = Signal(reset_less=True)

        with m.If(op.insn_type == MicrOp.OP_CMP):
            comb += is_32bit.eq(~L)

        # little trick: do the add using only one add (not 2)
        # LSB: carry-in [0]. op/result: [1:-1]. MSB: carry-out [-1]
        # i.e. add_o[1:-1] == a_i + b_i + cry_i[0], add_o[-1] == carry-out
        add_a = Signal(a.width + 2, reset_less=True)
        add_b = Signal(a.width + 2, reset_less=True)
        add_o = Signal(a.width + 2, reset_less=True)

        a_i = Signal.like(a)
        b_i = Signal.like(b)
        # NOTE(review): is_32bit is only driven inside the OP_CMP branch
        # above, and OP_CMP is matched first here, so the Elif(is_32bit)
        # extension branch appears unreachable (a_i/b_i always equal a/b).
        # Confirm whether this is intentional before relying on it.
        with m.If(op.insn_type == MicrOp.OP_CMP): # another temporary hack
            comb += a_i.eq(a) # reaaaally need to move CMP
            comb += b_i.eq(b) # into trap pipeline
        with m.Elif(is_32bit):
            with m.If(op.is_signed):
                comb += a_i.eq(exts(a, 32, 64))
                comb += b_i.eq(exts(b, 32, 64))
            with m.Else():
                comb += a_i.eq(extz(a, 32, 64))
                comb += b_i.eq(extz(b, 32, 64))
        with m.Else():
            comb += a_i.eq(a)
            comb += b_i.eq(b)

        with m.If((op.insn_type == MicrOp.OP_ADD) |
                  (op.insn_type == MicrOp.OP_CMP)):
            # in bit 0, 1+carry_in creates carry into bit 1 and above
            comb += add_a.eq(Cat(cry_i[0], a_i, Const(0, 1)))
            comb += add_b.eq(Cat(Const(1, 1), b_i, Const(0, 1)))
            comb += add_o.eq(add_a + add_b)

        ##########################
        # main switch-statement for handling arithmetic operations

        with m.Switch(op.insn_type):

            ###################
            #### CMP, CMPL v3.0B p85-86

            with m.Case(MicrOp.OP_CMP):
                a_n = Signal(64) # temporary - inverted a
                tval = Signal(5)
                a_lt = Signal()
                carry_32 = Signal()
                carry_64 = Signal()
                zerolo = Signal()
                zerohi = Signal()
                msb_a = Signal()
                msb_b = Signal()
                # NOTE(review): newcrf is declared but never driven or read.
                newcrf = Signal(4)

                # this is supposed to be inverted (b-a, not a-b)
                # presumably the input stage supplies ~RA with carry-in 1,
                # so add_o holds RB - RA and a_n recovers RA; TODO confirm
                comb += a_n.eq(~a) # sigh a gets inverted
                # add_o[33] is sum bit 32 (offset by the carry-in LSB);
                # XOR-ing out the operand bits leaves the carry into bit 32
                comb += carry_32.eq(add_o[33] ^ a[32] ^ b[32])
                # carry-out of the full-width add (assumes a is 64-bit wide)
                comb += carry_64.eq(add_o[65])

                comb += zerolo.eq(~((a_n[0:32] ^ b[0:32]).bool()))
                comb += zerohi.eq(~((a_n[32:64] ^ b[32:64]).bool()))

                # tval layout: [0]/[1] = greater/less for unsigned compare,
                # [2] = equal, [3]/[4] = greater/less for signed compare
                with m.If(zerolo & (is_32bit | zerohi)):
                    # values are equal
                    comb += tval[2].eq(1)
                with m.Else():
                    comb += msb_a.eq(Mux(is_32bit, a_n[31], a_n[63]))
                    comb += msb_b.eq(Mux(is_32bit, b[31], b[63]))
                    C0 = Const(0, 1)
                    with m.If(msb_a != msb_b):
                        # Subtraction might overflow, but
                        # comparison is clear from MSB difference.
                        # for signed, 0 is greater; for unsigned, 1 is greater
                        comb += tval.eq(Cat(msb_a, msb_b, C0, msb_b, msb_a))
                    with m.Else():
                        # Subtraction cannot overflow since MSBs are equal.
                        # carry = 1 indicates RA is smaller (signed or unsigned)
                        comb += a_lt.eq(Mux(is_32bit, carry_32, carry_64))
                        comb += tval.eq(Cat(~a_lt, a_lt, C0, ~a_lt, a_lt))
                # cr0.data packing used here: [0]=SO, [1]=EQ, [2]=GT, [3]=LT
                comb += cr0.data[0:2].eq(Cat(xer_so_i[0], tval[2]))
                with m.If(op.is_signed):
                    comb += cr0.data[2:4].eq(tval[3:5])
                with m.Else():
                    comb += cr0.data[2:4].eq(tval[0:2])
                comb += cr0.ok.eq(1)

            ###################
            #### add v3.0B p67, p69-72

            with m.Case(MicrOp.OP_ADD):
                # bit 0 is not part of the result, top bit is the carry-out
                comb += o.data.eq(add_o[1:-1])
                comb += o.ok.eq(1) # output register

                # see microwatt OP_ADD code
                # https://bugs.libre-soc.org/show_bug.cgi?id=319#c5
                ca = Signal(2, reset_less=True)
                comb += ca[0].eq(add_o[-1]) # XER.CA
                # carry into bit 32: sum bit 32 XOR both operand bits 32
                comb += ca[1].eq(add_o[33] ^ (a_i[32] ^ b_i[32])) # XER.CA32
                comb += cry_o.data.eq(ca)
                comb += cry_o.ok.eq(1)
                # 32-bit (ov[1]) and 64-bit (ov[0]) overflow
                ov = Signal(2, reset_less=True)
                comb += ov[0].eq(calc_ov(a_i[-1], b_i[-1], ca[0], add_o[-2]))
                comb += ov[1].eq(calc_ov(a_i[31], b_i[31], ca[1], add_o[32]))
                comb += ov_o.data.eq(ov)
                comb += ov_o.ok.eq(1)

            ###################
            #### exts (sign-extend) v3.0B p96, p99

            with m.Case(MicrOp.OP_EXTS):
                # data_len appears to be the source width in bytes
                # (1/2/4 -> sign-extend 8/16/32 bits to 64).  The three
                # conditions are mutually exclusive, so these separate Ifs
                # act like an If/Elif chain.
                with m.If(op.data_len == 1):
                    comb += o.data.eq(exts(a, 8, 64))
                with m.If(op.data_len == 2):
                    comb += o.data.eq(exts(a, 16, 64))
                with m.If(op.data_len == 4):
                    comb += o.data.eq(exts(a, 32, 64))
                comb += o.ok.eq(1) # output register

            ###################
            #### cmpeqb v3.0B p88

            with m.Case(MicrOp.OP_CMPEQB):
                # compare the low byte of a against each of the 8 bytes of b
                eqs = Signal(8, reset_less=True)
                src1 = Signal(8, reset_less=True)
                comb += src1.eq(a[0:8])
                for i in range(8):
                    comb += eqs[i].eq(src1 == b[8*i:8*(i+1)])
                comb += o.data[0].eq(eqs.any())
                comb += o.ok.eq(0) # use o.data but do *not* actually output
                # "any byte matched" goes in cr0.data[2] (GT position);
                # all other CR bits are cleared
                comb += cr0.data.eq(Cat(Const(0, 2), eqs.any(), Const(0, 1)))
                comb += cr0.ok.eq(1)

        ###### sticky overflow and context, both pass-through #####

        comb += self.o.xer_so.data.eq(xer_so_i)
        comb += self.o.ctx.eq(self.i.ctx)

        return m