7ea9b0da54da304cdc505c718d8c0a97489b5662
[soc.git] / src / soc / fu / shift_rot / rotator.py
1 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
2 #
3 from nmigen.compat.sim import run_simulation
4
5 from nmigen import (Elaboratable, Signal, Module, Const, Cat, Repl,
6 unsigned, signed)
7 from soc.fu.shift_rot.rotl import ROTL
8 from nmutil.extend import exts
9 from nmigen.back.pysim import Settle
10
11
12 # note: masks use big-endian (MSB0) bit numbering, i.e. bit 0 is the MSB
def right_mask(m, mask_begin):
    """Build a 64-bit mask with big-endian bits mask_begin..63 set.

    In big-endian (MSB0) numbering, BE bit i is LSB0 bit 63-i. The mask
    is therefore the low ``64 - mask_begin`` bits. When mask_begin >= 64
    the mask is all zeros: 64 yields (1 << 0) - 1, and anything larger
    leaves the signal at its reset value.

    :param m: Module that receives the combinatorial logic
    :param mask_begin: unsigned Value holding the mask-begin index
    :returns: 64-bit Signal carrying the mask
    """
    mask = Signal(64, name="right_mask", reset_less=True)
    low_bits = 64 - mask_begin  # number of low-order ones to set
    with m.If(mask_begin <= 64):
        m.d.comb += mask.eq((1 << low_bits) - 1)
    return mask
18
def left_mask(m, mask_end):
    """Build a 64-bit mask with big-endian bits 0..mask_end set.

    In big-endian (MSB0) numbering, BE bit i is LSB0 bit 63-i. The mask
    is therefore the top ``mask_end + 1`` bits. When mask_end > 63 the
    mask is all zeros.

    :param m: Module that receives the combinatorial logic
    :param mask_end: unsigned Value holding the mask-end index
    :returns: 64-bit Signal carrying the mask
    """
    ret = Signal(64, name="left_mask", reset_less=True)
    # Handle the out-of-range case explicitly, mirroring the range guard
    # in right_mask. The previous unguarded form reached 0 only because
    # "63 - mask_end" wrapped modulo its width. Outside the guard, ret
    # keeps its reset value of 0.
    with m.If(mask_end <= 63):
        m.d.comb += ret.eq(~((1 << (63 - mask_end)) - 1))
    return ret
23
24
class Rotator(Elaboratable):
    """Rotator: covers multiple POWER9 rotate functions

    Manual translation of microwatt's rotator.vhdl into nmigen.

    supported modes:

    * sl[wd]
    * rlw*, rldic, rldicr, rldimi
    * rldicl, sr[wd]
    * sra[wd][i]

    use as follows:

    * shift = RB[0:7]
    * arith = 1 when is_signed
    * right_shift = 1 when insn_type is OP_SHR
    * clear_left = 1 when insn_type is OP_RLC or OP_RLCL
    * clear_right = 1 when insn_type is OP_RLC or OP_RLCR

    Inputs (all reset_less):

    * me, mb (5 bits each): ME / MB instruction fields
    * mb_extra (1 bit): extra bit of mb in MD-form
    * ra, rs (64 bits): RA and RS operands
    * shift (7 bits): shift / rotate count
    * is_32bit, right_shift, arith, clear_left, clear_right,
      sign_ext_rs: single-bit mode controls

    Outputs:

    * result_o (64 bits): rotated and masked result
    * carry_out_o: only in output mode 0b11 (sra[wd][i]), set when any
      bit of rs outside the left mask is 1; 0 in every other mode
    """
    def __init__(self):
        # input
        self.me = Signal(5, reset_less=True)       # ME field
        self.mb = Signal(5, reset_less=True)       # MB field
        self.mb_extra = Signal(1, reset_less=True) # extra bit of mb in MD-form
        self.ra = Signal(64, reset_less=True)      # RA
        self.rs = Signal(64, reset_less=True)      # RS
        self.shift = Signal(7, reset_less=True)    # RB[0:7]
        self.is_32bit = Signal(reset_less=True)
        self.right_shift = Signal(reset_less=True)
        self.arith = Signal(reset_less=True)
        self.clear_left = Signal(reset_less=True)
        self.clear_right = Signal(reset_less=True)
        self.sign_ext_rs = Signal(reset_less=True)
        # output
        self.result_o = Signal(64, reset_less=True)
        self.carry_out_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        ra, rs = self.ra, self.rs

        # temporaries
        rot_count = Signal(6, reset_less=True)
        rot = Signal(64, reset_less=True)
        sh = Signal(7, reset_less=True)
        mb = Signal(7, reset_less=True)
        me = Signal(7, reset_less=True)
        mr = Signal(64, reset_less=True)
        ml = Signal(64, reset_less=True)
        output_mode = Signal(2, reset_less=True)
        hi32 = Signal(32, reset_less=True)
        repl32 = Signal(64, reset_less=True)

        # First replicate bottom 32 bits to both halves if 32-bit
        with m.If(self.is_32bit):
            comb += hi32.eq(rs[0:32])
        with m.Elif(self.sign_ext_rs):
            # sign-extend bottom 32 bits
            comb += hi32.eq(Repl(rs[31], 32))
        with m.Else():
            comb += hi32.eq(rs[32:64])
        comb += repl32.eq(Cat(rs[0:32], hi32))

        shift_signed = Signal(signed(6))
        comb += shift_signed.eq(self.shift[0:6])

        # Negate shift count for right shifts: truncated to 6 bits this is
        # (-shift) mod 64, so a right rotate by n becomes a left rotate
        # by 64-n
        with m.If(self.right_shift):
            comb += rot_count.eq(-shift_signed)
        with m.Else():
            comb += rot_count.eq(self.shift[0:6])

        # ROTL submodule (soc.fu.shift_rot.rotl). Presumably rotates
        # rotl.a left by rotl.b bits -- TODO confirm against rotl.py
        m.submodules.rotl = rotl = ROTL(64)
        comb += rotl.a.eq(repl32)
        comb += rotl.b.eq(rot_count)
        comb += rot.eq(rotl.o)

        # Trim shift count to 6 bits for 32-bit shifts
        comb += sh.eq(Cat(self.shift[0:6], self.shift[6] & ~self.is_32bit))

        # XXX errr... we should already have these, in Fields? oh well
        # Work out mask begin/end indexes (caution, big-endian bit numbering)

        # mask-begin (mb)
        with m.If(self.clear_left):
            comb += mb.eq(self.mb)
            with m.If(self.is_32bit):
                comb += mb[5:7].eq(Const(0b01, 2))
            with m.Else():
                comb += mb[5:7].eq(Cat(self.mb_extra, Const(0b0, 1)))
        with m.Elif(self.right_shift):
            # this is basically mb = sh + (is_32bit? 32: 0);
            comb += mb.eq(sh)
            with m.If(self.is_32bit):
                comb += mb[5:7].eq(Cat(~sh[5], sh[5]))
        with m.Else():
            comb += mb.eq(Cat(Const(0b0, 5), self.is_32bit, Const(0b0, 1)))

        # mask-end (me)
        with m.If(self.clear_right & self.is_32bit):
            # TODO: track down where this is. have to use fields.
            comb += me.eq(Cat(self.me, Const(0b01, 2)))
        with m.Elif(self.clear_right & ~self.clear_left):
            # this is me, have to use fields
            comb += me.eq(Cat(self.mb, self.mb_extra, Const(0b0, 1)))
        with m.Else():
            # effectively, 63 - sh
            comb += me.eq(Cat(~sh[0:6], sh[6]))

        # Calculate left and right masks
        comb += mr.eq(right_mask(m, mb))
        comb += ml.eq(left_mask(m, me))

        # Work out output mode
        # 00 for sl[wd]
        # 0w for rlw*, rldic, rldicr, rldimi, where w = 1 iff mb > me
        # 10 for rldicl, sr[wd]
        # 1z for sra[wd][i], z = 1 if rs is negative
        with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1)))
        with m.Else():
            mbgt = self.clear_right & (mb[0:6] > me[0:6])
            comb += output_mode.eq(Cat(mbgt, Const(0, 1)))

        # Generate output from rotated input and masks
        with m.Switch(output_mode):
            with m.Case(0b00):
                comb += self.result_o.eq((rot & (mr & ml)) | (ra & ~(mr & ml)))
            with m.Case(0b01):
                comb += self.result_o.eq((rot & (mr | ml)) | (ra & ~(mr | ml)))
            with m.Case(0b10):
                comb += self.result_o.eq(rot & mr)
            with m.Case(0b11):
                comb += self.result_o.eq(rot | ~mr)
                # Generate carry output for arithmetic shift right of -ve
                # value. It is driven only in this mode; in all other modes
                # carry_out_o stays at its reset value of 0.
                comb += self.carry_out_o.eq((rs & ~ml).bool())

        return m
164
if __name__ == '__main__':
    # Quick visual check of left_mask: sweep the mask end from 63 down to
    # 0 and print each resulting 64-bit mask. The signal names mr/mb are
    # kept unchanged so the VCD trace matches; mr carries the *left* mask.
    m = Module()
    mr = Signal(64)
    mb = Signal(6)
    m.d.comb += mr.eq(left_mask(m, mb))

    def sweep():
        for step, end in enumerate(range(63, -1, -1)):
            yield mb.eq(end)
            yield Settle()
            print(step, hex((yield mr)))

    run_simulation(m, [sweep()], vcd_name="test_mask.vcd")
182