convert shift_rot pipeline to XLEN=32/64
[soc.git] / src / soc / fu / shift_rot / rotator.py
1 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
2 #
3 from nmigen.compat.sim import run_simulation
4
5 from nmigen import (Elaboratable, Signal, Module, Const, Cat, Repl,
6 unsigned, signed)
7 from soc.fu.shift_rot.rotl import ROTL
8 from nmigen.back.pysim import Settle
9 from nmutil.extend import exts
10 from nmutil.mask import Mask
11
12
13 # note BE bit numbering
def right_mask(m, mask_begin, width):
    """Return a combinatorial mask with the low ``width - mask_begin`` bits set.

    In the big-endian bit numbering used by this file (bit 0 is the most
    significant bit) that corresponds to bits ``mask_begin .. width-1``.

    m          -- Module whose comb domain receives the assignments
    mask_begin -- value giving the first (BE-numbered) bit of the mask
    width      -- width of the mask in bits (python int)

    The mask is zero when ``mask_begin`` is greater than ``width``.
    """
    comb = m.d.comb
    mask = Signal(width, name="right_mask", reset_less=True)

    # expressions are built up-front; they are only *used* inside the If/Else
    in_range = mask_begin <= width
    low_ones = (1 << (width - mask_begin)) - 1

    with m.If(in_range):
        comb += mask.eq(low_ones)
    with m.Else():
        comb += mask.eq(0)
    return mask
21
22
def left_mask(m, mask_end, width):
    """Return a combinatorial mask with the low ``width-1-mask_end`` bits clear.

    All remaining (upper) bits are set.  In the big-endian bit numbering
    used by this file (bit 0 is the most significant bit) that corresponds
    to bits ``0 .. mask_end``.

    m        -- Module whose comb domain receives the assignment
    mask_end -- value giving the last (BE-numbered) bit of the mask
    width    -- width of the mask in bits (python int)

    No range check is applied to ``mask_end``.
    """
    mask = Signal(width, name="left_mask", reset_less=True)
    # ones below the mask end; inverting them leaves the upper bits set
    low_ones = (1 << (width - 1 - mask_end)) - 1
    m.d.comb += mask.eq(~low_ones)
    return mask
27
28
class Rotator(Elaboratable):
    """Rotator: covers multiple POWER9 rotate functions

    supported modes:

    * sl[wd]
    * rlw*, rldic, rldicr, rldimi
    * rldicl, sr[wd]
    * sra[wd][i]

    use as follows:

    * shift = RB[0:7]
    * arith = 1 when is_signed
    * right_shift = 1 when insn_type is OP_SHR
    * clear_left = 1 when insn_type is OP_RLC or OP_RLCL
    * clear_right = 1 when insn_type is OP_RLC or OP_RLCR

    operation (all combinatorial):

    * rs (optionally with its bottom 32 bits replicated or sign-extended
      into the top half) is rotated left by a count derived from shift
    * a mask-begin / mask-end pair (big-endian bit numbering) is worked
      out from mb, mb_extra, me and the shift amount
    * result_o is selected from the rotated value, ra and the masks
      according to a 2-bit output mode
    * carry_out_o reports whether any bit of rs lies outside the left mask

    width is the data width in bits of ra, rs and result_o.
    """

    def __init__(self, width):
        self.width = width
        # input
        self.me = Signal(5, reset_less=True)  # ME field
        self.mb = Signal(5, reset_less=True)  # MB field
        # extra bit of mb in MD-form
        self.mb_extra = Signal(1, reset_less=True)
        self.ra = Signal(width, reset_less=True)  # RA
        self.rs = Signal(width, reset_less=True)  # RS
        self.shift = Signal(7, reset_less=True)  # RB[0:7]
        # replicate the bottom 32 bits of rs into both halves
        self.is_32bit = Signal(reset_less=True)
        # negate the rotate count (see class docstring for when to set)
        self.right_shift = Signal(reset_less=True)
        self.arith = Signal(reset_less=True)
        self.clear_left = Signal(reset_less=True)
        self.clear_right = Signal(reset_less=True)
        # fill the top half with bit 31 of rs (only when is_32bit is clear)
        self.sign_ext_rs = Signal(reset_less=True)
        # output
        self.result_o = Signal(width, reset_less=True)
        self.carry_out_o = Signal(reset_less=True)

    def elaborate(self, platform):
        """Build and return the Module containing the rotate/mask datapath.

        NOTE: statement order matters here; later comb assignments to a
        signal (or slice of it) override earlier ones.
        """
        width = self.width
        m = Module()
        comb = m.d.comb
        ra, rs = self.ra, self.rs

        # temporaries
        rot_count = Signal(6, reset_less=True)
        # sized from width like the other datapath signals (was fixed at 64)
        rot = Signal(width, reset_less=True)
        sh = Signal(7, reset_less=True)
        mb = Signal(7, reset_less=True)
        me = Signal(7, reset_less=True)
        mr = Signal(width, reset_less=True)
        ml = Signal(width, reset_less=True)
        output_mode = Signal(2, reset_less=True)
        hi32 = Signal(32, reset_less=True)
        repl32 = Signal(width, reset_less=True)

        # First replicate bottom 32 bits to both halves if 32-bit
        with m.If(self.is_32bit):
            comb += hi32.eq(rs[0:32])
        with m.Elif(self.sign_ext_rs):
            # sign-extend bottom 32 bits
            comb += hi32.eq(Repl(rs[31], 32))
        with m.Else():
            # only a 64-bit rs has a top half to pass through; otherwise
            # hi32 is left undriven in this branch
            if width == 64:
                comb += hi32.eq(rs[32:64])
        comb += repl32.eq(Cat(rs[0:32], hi32))

        shift_signed = Signal(signed(6))
        comb += shift_signed.eq(self.shift[0:6])

        # Negate shift count for right shifts
        with m.If(self.right_shift):
            comb += rot_count.eq(-shift_signed)
        with m.Else():
            comb += rot_count.eq(self.shift[0:6])

        # ROTL submodule
        m.submodules.rotl = rotl = ROTL(width)
        comb += rotl.a.eq(repl32)
        comb += rotl.b.eq(rot_count)
        comb += rot.eq(rotl.o)

        # Trim shift count to 6 bits for 32-bit shifts
        comb += sh.eq(Cat(self.shift[0:6], self.shift[6] & ~self.is_32bit))

        # XXX errr... we should already have these, in Fields?  oh well
        # Work out mask begin/end indexes (caution, big-endian bit numbering)

        # mask-begin (mb)
        with m.If(self.clear_left):
            comb += mb.eq(self.mb)
            # the top two bits are then overridden according to the mode
            with m.If(self.is_32bit):
                comb += mb[5:7].eq(Const(0b01, 2))
            with m.Else():
                comb += mb[5:7].eq(Cat(self.mb_extra, Const(0b0, 1)))
        with m.Elif(self.right_shift):
            # this is basically mb = sh + (is_32bit? 32: 0);
            comb += mb.eq(sh)
            with m.If(self.is_32bit):
                comb += mb[5:7].eq(Cat(~sh[5], sh[5]))
        with m.Else():
            comb += mb.eq(Cat(Const(0b0, 5), self.is_32bit, Const(0b0, 1)))

        # mask-end (me)
        with m.If(self.clear_right & self.is_32bit):
            # TODO: track down where this is.  have to use fields.
            comb += me.eq(Cat(self.me, Const(0b01, 2)))
        with m.Elif(self.clear_right & ~self.clear_left):
            # this is me, have to use fields
            comb += me.eq(Cat(self.mb, self.mb_extra, Const(0b0, 1)))
        with m.Else():
            # effectively, 63 - sh
            comb += me.eq(Cat(~sh[0:6], sh[6]))

        # Calculate left and right masks.  The local names deliberately
        # differ from the module-level right_mask()/left_mask() helpers so
        # that those are not shadowed; the submodule names are unchanged.
        # (presumably Mask drives .mask from .shift -- see nmutil.mask)
        m.submodules.right_mask = mask_r = Mask(width)
        with m.If(mb <= width):
            comb += mask_r.shift.eq(width-mb)
            comb += mr.eq(mask_r.mask)
        with m.Else():
            comb += mr.eq(0)

        m.submodules.left_mask = mask_l = Mask(width)
        comb += mask_l.shift.eq(width-1-me)
        comb += ml.eq(~mask_l.mask)

        # Work out output mode
        # 00 for sl[wd]
        # 0w for rlw*, rldic, rldicr, rldimi, where w = 1 iff mb > me
        # 10 for rldicl, sr[wd]
        # 1z for sra[wd][i], z = 1 if rs is negative
        with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
            comb += output_mode.eq(Cat(self.arith &
                                       repl32[width-1], Const(1, 1)))
        with m.Else():
            mbgt = self.clear_right & (mb[0:6] > me[0:6])
            comb += output_mode.eq(Cat(mbgt, Const(0, 1)))

        # Generate output from rotated input and masks
        with m.Switch(output_mode):
            with m.Case(0b00):
                # rotated bits where both masks are set, ra elsewhere
                comb += self.result_o.eq((rot & (mr & ml)) | (ra & ~(mr & ml)))
            with m.Case(0b01):
                # rotated bits where either mask is set, ra elsewhere
                comb += self.result_o.eq((rot & (mr | ml)) | (ra & ~(mr | ml)))
            with m.Case(0b10):
                comb += self.result_o.eq(rot & mr)
            with m.Case(0b11):
                comb += self.result_o.eq(rot | ~mr)
                # Generate carry output for arithmetic shift right of -ve value
                # NOTE(review): kept inside this case because the comment
                # ties it to the negative arithmetic-shift mode -- confirm
                # against upstream, the original indentation was not visible
                comb += self.carry_out_o.eq((rs & ~ml).bool())

        return m
185
186
if __name__ == '__main__':

    # Manual check of left_mask(): sweep the mask end from 63 down to 0,
    # print each resulting 64-bit mask and dump a VCD trace.
    m = Module()
    mr = Signal(64)
    mb = Signal(6)
    m.d.comb += mr.eq(left_mask(m, mb, 64))

    def sweep():
        for step, mask_end in enumerate(reversed(range(64))):
            yield mb.eq(mask_end)
            yield Settle()
            mask_value = yield mr
            print(step, hex(mask_value))

    run_simulation(m, [sweep()],
                   vcd_name="test_mask.vcd")