035afa47832357746df258bc2b606b72b621cf88
[soc.git] / src / soc / shift_rot / rotator.py
1 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
2 #
3
4 from nmigen import (Elaboratable, Signal, Module, Const, Cat,
5 unsigned, signed)
6 from soc.shift_rot.rotl import ROTL
7
8 # note BE bit numbering
def right_mask(m, mask_begin):
    """Build a 64-bit mask with big-endian bits mask_begin..63 set.

    Viewed little-endian, the low ``64 - mask_begin`` bits are set.  When
    mask_begin is 64 or larger the mask is all zeros (mask_begin == 64
    yields ``(1 << 0) - 1 == 0``; larger values leave the signal undriven,
    so it takes its reset value of 0).

    :param m: Module the combinatorial logic is added to.
    :param mask_begin: unsigned mask begin index (big-endian numbering).
    :returns: 64-bit reset-less Signal holding the mask.
    """
    mask = Signal(64, name="right_mask", reset_less=True)
    # number of low-order (little-endian) bits to set
    width = 64 - mask_begin
    with m.If(mask_begin <= 64):
        m.d.comb += mask.eq((1 << width) - 1)
    return mask
14
def left_mask(m, mask_end):
    """Build a 64-bit mask with big-endian bits 0..mask_end set.

    Viewed little-endian, the top ``mask_end + 1`` bits (bits
    ``63 - mask_end`` .. 63) are set.  A mask_end of 64 or more yields an
    all-zero mask: the signal is left undriven and takes its reset value 0.

    :param m: Module the combinatorial logic is added to.
    :param mask_end: unsigned mask end index (big-endian numbering).
    :returns: 64-bit reset-less Signal holding the mask.
    """
    ret = Signal(64, name="left_mask", reset_less=True)
    # Guard explicitly (as microwatt does with "mask_end(6) = '0'" and as
    # right_mask does) instead of relying on unsigned wrap-around of
    # (63 - mask_end) to happen to produce zero for out-of-range indexes.
    with m.If(mask_end <= 63):
        m.d.comb += ret.eq(~((1 << (63 - mask_end)) - 1))
    return ret
19
20
class Rotator(Elaboratable):
    """Rotator: covers multiple POWER9 rotate functions

    supported modes:

    * sl[wd]
    * rlw*, rldic, rldicr, rldimi
    * rldicl, sr[wd]
    * sra[wd][i]

    use as follows:

    * shift = RB[0:7]
    * arith = 1 when is_signed
    * right_shift = 1 when insn_type is OP_SHR
    * clear_left = 1 when insn_type is OP_RLC or OP_RLCL
    * clear_right = 1 when insn_type is OP_RLC or OP_RLCR

    outputs:

    * result_o = rotated input combined with the computed masks
    * carry_out_o = OR of (rs & ~ml) in output mode 0b11 (arithmetic
      right shift of a negative value), otherwise 0
    """
    def __init__(self):
        # input
        self.me = Signal(5, reset_less=True)     # ME field
        self.mb = Signal(5, reset_less=True)     # MB field
        # XO field: used below as bit 5 of the 6-bit mb/me for 64-bit forms
        self.XO = Signal(1, reset_less=True)
        self.ra = Signal(64, reset_less=True)    # RA
        self.rs = Signal(64, reset_less=True)    # RS
        self.shift = Signal(7, reset_less=True)  # RB[0:7]
        self.is_32bit = Signal(reset_less=True)
        self.right_shift = Signal(reset_less=True)
        self.arith = Signal(reset_less=True)
        self.clear_left = Signal(reset_less=True)
        self.clear_right = Signal(reset_less=True)
        # output
        self.result_o = Signal(64, reset_less=True)
        self.carry_out_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        ra, rs = self.ra, self.rs

        # temporaries
        rot_in = Signal(64, reset_less=True)
        rot_count = Signal(6, reset_less=True)
        rot = Signal(64, reset_less=True)
        sh = Signal(7, reset_less=True)
        mb = Signal(7, reset_less=True)
        me = Signal(7, reset_less=True)
        mr = Signal(64, reset_less=True)
        ml = Signal(64, reset_less=True)
        output_mode = Signal(2, reset_less=True)

        # First replicate bottom 32 bits to both halves if 32-bit
        comb += rot_in[0:32].eq(rs[0:32])
        with m.If(self.is_32bit):
            comb += rot_in[32:64].eq(rs[0:32])
        with m.Else():
            comb += rot_in[32:64].eq(rs[32:64])

        shift_signed = Signal(signed(6))
        comb += shift_signed.eq(self.shift[0:6])

        # Negate shift count for right shifts: a left rotate by
        # (-n mod 64) equals a right rotate by n.
        with m.If(self.right_shift):
            comb += rot_count.eq(-shift_signed)
        with m.Else():
            comb += rot_count.eq(self.shift[0:6])

        # ROTL submodule
        m.submodules.rotl = rotl = ROTL(64)
        comb += rotl.a.eq(rot_in)
        comb += rotl.b.eq(rot_count)
        comb += rot.eq(rotl.o)

        # Trim shift count to 6 bits for 32-bit shifts
        comb += sh.eq(Cat(self.shift[0:6], self.shift[6] & ~self.is_32bit))

        # XXX errr... we should already have these, in Fields? oh well
        # Work out mask begin/end indexes (caution, big-endian bit numbering)

        # mask-begin (mb)
        with m.If(self.clear_left):
            with m.If(self.is_32bit):
                comb += mb.eq(Cat(self.mb, Const(0b01, 2)))
            with m.Else():
                comb += mb.eq(Cat(self.mb, self.XO, Const(0b0, 1)))
        with m.Elif(self.right_shift):
            # this is basically mb = sh + (is_32bit? 32: 0);
            with m.If(self.is_32bit):
                comb += mb.eq(Cat(sh[0:5], ~sh[5], sh[5]))
            with m.Else():
                comb += mb.eq(sh)
        with m.Else():
            comb += mb.eq(Cat(Const(0b0, 5), self.is_32bit, Const(0b0, 1)))

        # mask-end (me)
        with m.If(self.clear_right & self.is_32bit):
            # TODO: track down where this is. have to use fields.
            comb += me.eq(Cat(self.me, Const(0b01, 2)))
        with m.Elif(self.clear_right & ~self.clear_left):
            # this is me, have to use fields: the value is taken from the
            # mb/XO inputs here
            comb += me.eq(Cat(self.mb, self.XO, Const(0b0, 1)))
        with m.Else():
            # effectively, 63 - sh.  Must use the *trimmed* count sh, not
            # the raw shift: for 32-bit shifts bit 6 of the shift amount is
            # ignored, and using it would force ml to zero.
            comb += me.eq(Cat(~sh[0:6], sh[6]))

        # Calculate left and right masks
        comb += mr.eq(right_mask(m, mb))
        comb += ml.eq(left_mask(m, me))

        # Work out output mode
        # 00 for sl[wd]
        # 0w for rlw*, rldic, rldicr, rldimi, where w = 1 iff mb > me
        # 10 for rldicl, sr[wd]
        # 1z for sra[wd][i], z = 1 if rs is negative
        with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
            comb += output_mode.eq(Cat(self.arith & rot_in[63], Const(1, 1)))
        with m.Else():
            mbgt = self.clear_right & (mb[0:6] > me[0:6])
            comb += output_mode.eq(Cat(mbgt, Const(0, 1)))

        # Generate output from rotated input and masks
        with m.Switch(output_mode):
            with m.Case(0b00):
                comb += self.result_o.eq((rot & (mr & ml)) | (ra & ~(mr & ml)))
            with m.Case(0b01):
                comb += self.result_o.eq((rot & (mr | ml)) | (ra & ~(mr | ml)))
            with m.Case(0b10):
                comb += self.result_o.eq(rot & mr)
            with m.Case(0b11):
                comb += self.result_o.eq(rot | ~mr)

        # Generate carry output for arithmetic shift right of -ve value.
        # (rs & ~ml) is 64 bits wide: it must be OR-reduced to one bit (a
        # plain .eq() into the 1-bit carry would keep only bit 0), and the
        # carry is only driven in output mode 0b11; otherwise it stays 0.
        with m.If(output_mode == 0b11):
            comb += self.carry_out_o.eq((rs & ~ml).bool())

        return m
156