Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / fu / logical / countzero.py
1 # https://github.com/antonblanchard/microwatt/blob/master/countzero.vhdl
2 from nmigen import Memory, Module, Signal, Cat, Elaboratable
3 from nmigen.hdl.rec import Record, Layout
4 from nmigen.cli import main
5
6
def or4(a, b, c, d):
    """Pack the OR-reduction of four bit groups into one 4-bit value.

    Bit 0 of the result is ``a.any()``, bit 1 is ``b.any()``, bit 2 is
    ``c.any()`` and bit 3 is ``d.any()``.
    """
    groups = (a, b, c, d)
    return Cat(*(group.any() for group in groups))
9
10
class IntermediateResult(Record):
    """Values handed from the 16-bit group selection to the final count.

    Fields:

    * ``v16``         - the selected 16-bit group of the operand
    * ``sel_hi``      - index (0-3) of that group within the 64-bit operand
    * ``is_32bit``    - copy of the 32-bit mode flag
    * ``count_right`` - copy of the count-direction flag
    """

    def __init__(self, name=None):
        # v16 must be a full 16 bits wide: it is assigned 16-bit slices of
        # the operand and is later read as v16[12:16].  A narrower field
        # silently drops the top bit of the selected group.
        layout = (('v16', 16),
                  ('sel_hi', 2),
                  ('is_32bit', 1),
                  ('count_right', 1))
        Record.__init__(self, Layout(layout), name=name)
18
19
class ZeroCounter(Elaboratable):
    """Combinatorial leading/trailing zero counter for a 64-bit operand.

    Inputs:

    * ``rs_i``          - 64-bit operand
    * ``count_right_i`` - 1: search from bit 0 upwards (trailing zeros),
                          0: search from the top bit downwards (leading
                          zeros)
    * ``is_32bit_i``    - 1: only the low 32 bits of ``rs_i`` are examined

    Output:

    * ``result_o``      - the zero count; 64 (or 32 in 32-bit mode) when
                          the examined operand is entirely zero
    """

    def __init__(self):
        self.rs_i = Signal(64, reset_less=True)
        self.count_right_i = Signal(1, reset_less=True)
        self.is_32bit_i = Signal(1, reset_less=True)
        self.result_o = Signal(64, reset_less=True)

    def ports(self):
        """Return the list of external signals of this module."""
        return [self.rs_i, self.count_right_i, self.is_32bit_i, self.result_o]

    def elaborate(self, platform):
        """Build the combinatorial logic: narrow 64 -> 16 -> 4 -> 1 bits."""
        m = Module()

        # TODO: replace this with m.submodule.pe1 = PriorityEncoder(4)
        #                         m.submodule.pe2 = PriorityEncoder(4)
        #                         m.submodule.pe3 = PriorityEncoder(4)
        #                         etc.
        # and where right will assign input to v and !right will assign v[::-1]
        # so as to reverse the order of the input bits.

        def encoder(v, right):
            """
            Return the index of the leftmost or rightmost 1 in a set of 4 bits.
            Assumes v is not "0000"; if it is, return (right ? "11" : "00").
            """
            ret = Signal(2, reset_less=True)
            with m.If(right):
                # rightmost: lowest-numbered set bit wins
                with m.If(v[0]):
                    m.d.comb += ret.eq(0)
                with m.Elif(v[1]):
                    m.d.comb += ret.eq(1)
                with m.Elif(v[2]):
                    m.d.comb += ret.eq(2)
                with m.Else():
                    m.d.comb += ret.eq(3)
            with m.Else():
                # leftmost: highest-numbered set bit wins
                with m.If(v[3]):
                    m.d.comb += ret.eq(3)
                with m.Elif(v[2]):
                    m.d.comb += ret.eq(2)
                with m.Elif(v[1]):
                    m.d.comb += ret.eq(1)
                with m.Else():
                    m.d.comb += ret.eq(0)
            return ret

        r = IntermediateResult()
        r_in = IntermediateResult()

        m.d.comb += r.eq(r_in)  # make the module entirely combinatorial for now

        v = IntermediateResult()
        y = Signal(4, reset_less=True)
        z = Signal(4, reset_less=True)
        sel = Signal(6, reset_less=True)
        v4 = Signal(4, reset_less=True)

        # Test 4 groups of 16 bits each.
        # The top 2 groups are considered to be zero in 32-bit mode.
        m.d.comb += z.eq(or4(self.rs_i[0:16], self.rs_i[16:32],
                             self.rs_i[32:48], self.rs_i[48:64]))
        with m.If(self.is_32bit_i):
            m.d.comb += v.sel_hi[1].eq(0)
            with m.If(self.count_right_i):
                # group 0 if it holds a 1, otherwise group 1
                m.d.comb += v.sel_hi[0].eq(~z[0])
            with m.Else():
                # group 1 if it holds a 1, otherwise group 0
                m.d.comb += v.sel_hi[0].eq(z[1])
        with m.Else():
            m.d.comb += v.sel_hi.eq(encoder(z, self.count_right_i))

        # Select the leftmost/rightmost non-zero group of 16 bits
        with m.Switch(v.sel_hi):
            for i in range(4):
                with m.Case(i):
                    m.d.comb += v.v16.eq(self.rs_i[i*16:(i+1)*16])

        # This is the point where a pipelined version would latch the
        # intermediate result; here r simply follows r_in combinatorially.
        m.d.comb += v.is_32bit.eq(self.is_32bit_i)
        m.d.comb += v.count_right.eq(self.count_right_i)
        m.d.comb += r_in.eq(v)
        m.d.comb += sel[4:6].eq(r.sel_hi)

        # Test 4 groups of 4 bits
        m.d.comb += y.eq(or4(r.v16[0:4], r.v16[4:8],
                             r.v16[8:12], r.v16[12:16]))
        m.d.comb += sel[2:4].eq(encoder(y, r.count_right))

        # Select the leftmost/rightmost non-zero group of 4 bits
        with m.Switch(sel[2:4]):
            for i in range(4):
                with m.Case(i):
                    m.d.comb += v4.eq(r.v16[i*4:(i+1)*4])

        m.d.comb += sel[0:2].eq(encoder(v4, r.count_right))

        # sel is now the index of the leftmost/rightmost 1 bit in rs
        o = self.result_o
        with m.If(v4 == 0):
            # operand is zero, return 32 for 32-bit, else 64
            m.d.comb += o[5:7].eq(Cat(r.is_32bit, ~r.is_32bit))
        with m.Elif(r.count_right):
            # sel is the index of the lowest set bit, which is already the
            # number of trailing zeros
            m.d.comb += o.eq(sel)
        with m.Else():
            # sel is the index of the highest set bit: the number of
            # leading zeros is (63 - sel), trimmed to 5 bits in 32-bit mode
            m.d.comb += o.eq(Cat(~sel[0:5], ~(sel[5] | r.is_32bit)))

        return m