Convert a few more tests to be able to use cxxsim
[soc.git] / src / soc / fu / div / fsm.py
1 import enum
2 from nmigen import Elaboratable, Module, Signal, Shape, unsigned, Cat, Mux
3 from soc.fu.div.pipe_data import CoreInputData, CoreOutputData, DivPipeSpec
4 from nmutil.iocontrol import PrevControl, NextControl
5 from nmutil.singlepipe import ControlBase
6 from ieee754.div_rem_sqrt_rsqrt.core import DivPipeCoreOperation
7
8
class FSMDivCoreConfig:
    """Static configuration constants for the FSM divider core.

    Only class-level attributes are defined; instances carry no state.
    """

    # NOTE(review): a single "stage" presumably reflects that the FSM core
    # iterates internally instead of being pipelined -- confirm with users.
    n_stages = 1
    # presumably the operand width in bits (matches the 64-bit divisor used
    # by FSMDivCoreInputData) -- TODO confirm
    bit_width = 64
    fract_width = 64
13
14
class FSMDivCoreInputData:
    """Input record of the FSM divider core.

    Members: a 128-bit ``dividend``, a 64-bit ``divisor_radicand`` and the
    requested ``operation`` (a ``DivPipeCoreOperation`` signal).
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        # keep the direct ``self.x = Signal(...)`` form: nmigen derives the
        # signal name from the assignment target
        self.dividend = Signal(128, reset_less=reset_less)
        self.divisor_radicand = Signal(64, reset_less=reset_less)
        self.operation = DivPipeCoreOperation.create_signal(
            reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.dividend, self.divisor_radicand, self.operation)

    def eq(self, rhs):
        """ Assign member signals.

        Returns a list of assignment statements copying each member of
        ``rhs`` (looked up by attribute name) into the matching member here.
        """
        member_names = ("dividend", "divisor_radicand", "operation")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in member_names]
35
36
class FSMDivCoreOutputData:
    """Output record of the FSM divider core.

    Members: a 64-bit ``quotient_root`` and a 192-bit (3 * 64)
    ``remainder``.
    """

    def __init__(self, core_config, reset_less=True):
        self.core_config = core_config
        # direct attribute assignment keeps nmigen's traced signal names
        self.quotient_root = Signal(64, reset_less=reset_less)
        self.remainder = Signal(3 * 64, reset_less=reset_less)

    def __iter__(self):
        """ Get member signals. """
        yield from (self.quotient_root, self.remainder)

    def eq(self, rhs):
        """ Assign member signals.

        Returns a list of assignment statements copying ``rhs``'s
        ``quotient_root`` and ``remainder`` into this record.
        """
        member_names = ("quotient_root", "remainder")
        return [getattr(self, member).eq(getattr(rhs, member))
                for member in member_names]
53
54
class DivStateNext(Elaboratable):
    """Combinatorial single step of restoring binary long division.

    Reads the current state on ``i`` plus ``divisor`` and drives ``o`` with
    the state after one more quotient bit has been determined.  When ``i``
    is already done, ``o`` simply mirrors ``i``.

    NOTE(review): the sign-bit test below is only a correct "no borrow"
    check when the quotient fits in ``quotient_width`` bits, i.e. the upper
    half of the dividend is smaller than the divisor -- TODO confirm callers
    guarantee this.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.i = DivState(quotient_width=quotient_width, name="i")
        self.divisor = Signal(quotient_width)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        width = self.quotient_width
        cur = self.i

        # trial subtraction: divisor aligned so its MSB sits one bit below
        # the top of the double-width dividend/quotient register
        aligned_divisor = self.divisor << (width - 1)
        difference = Signal(width * 2)
        m.d.comb += difference.eq(cur.dividend_quotient - aligned_divisor)

        # top bit clear => the subtraction did not go negative
        next_quotient_bit = Signal()
        m.d.comb += next_quotient_bit.eq(~difference[width * 2 - 1])

        # keep the difference on success, otherwise restore the old value
        value = Signal(width * 2)
        m.d.comb += value.eq(Mux(next_quotient_bit,
                                 difference,
                                 cur.dividend_quotient))

        with m.If(~cur.done):
            m.d.comb += [
                self.o.q_bits_known.eq(cur.q_bits_known + 1),
                # shift left by one, inserting the new quotient bit at bit 0
                self.o.dividend_quotient.eq(Cat(next_quotient_bit, value)),
            ]
        with m.Else():
            # all quotient bits known: pass the state through unchanged
            m.d.comb += self.o.eq(cur)
        return m
84
85
class DivStateInit(Elaboratable):
    """Combinatorially build the starting ``DivState`` of a division.

    ``o`` reports zero known quotient bits and holds the full
    double-width ``dividend`` in its ``dividend_quotient`` register.
    """

    def __init__(self, quotient_width):
        self.quotient_width = quotient_width
        self.dividend = Signal(quotient_width * 2)
        self.o = DivState(quotient_width=quotient_width, name="o")

    def elaborate(self, platform):
        m = Module()
        m.d.comb += [
            self.o.q_bits_known.eq(0),
            self.o.dividend_quotient.eq(self.dividend),
        ]
        return m
97
98
class DivState:
    """State carried between steps of the bit-serial divider.

    ``dividend_quotient`` is ``2 * quotient_width`` bits wide.  While a
    division runs, each step (see ``DivStateNext``) shifts it left by one
    and inserts the new quotient bit at bit 0; ``q_bits_known`` counts the
    quotient bits produced so far.
    """

    def __init__(self, quotient_width, name):
        self.quotient_width = quotient_width
        prefix = name + "_"
        self.q_bits_known = Signal(range(quotient_width + 1),
                                   name=prefix + "q_bits_known")
        self.dividend_quotient = Signal(unsigned(2 * quotient_width),
                                        name=prefix + "dividend_quotient")

    @property
    def done(self):
        """Expression that is true once every quotient bit is known."""
        return self.q_bits_known == self.quotient_width

    @property
    def quotient(self):
        """ Low half of ``dividend_quotient`` -- requires self.done is True """
        width = self.quotient_width
        return self.dividend_quotient[:width]

    @property
    def remainder(self):
        """ High half of ``dividend_quotient`` -- requires self.done is True """
        width = self.quotient_width
        return self.dividend_quotient[width:2 * width]

    def eq(self, rhs):
        """Return assignments copying ``rhs``'s counter and register here."""
        return [getattr(self, field).eq(getattr(rhs, field))
                for field in ("q_bits_known", "dividend_quotient")]
124
125
class FSMDivCoreStage(ControlBase):
    """Bit-serial divider wrapped in a ``ControlBase`` valid/ready handshake.

    While ``empty`` it accepts a new input (``p.ready_o`` is high), latches
    it and starts dividing.  Each clock, ``div_state_next`` computes one more
    quotient bit and the result is registered into ``saved_state``.  Once
    the state reports done, ``n.valid_o`` is raised.  When the consumer
    takes the result, the stage becomes empty again.

    NOTE(review): ``core_i.operation`` is never read here, so the requested
    ``DivPipeCoreOperation`` appears to be ignored -- confirm that only
    unsigned divide/remainder is expected from this core.
    """

    def __init__(self, pspec):
        super().__init__()
        self.pspec = pspec
        self.p.data_i = CoreInputData(pspec)
        self.n.data_o = CoreOutputData(pspec)
        # copy of the accepted input, held for the whole division
        self.saved_input_data = CoreInputData(pspec)
        # high when no division is in progress
        self.empty = Signal(reset=1)
        quotient_width = 64
        self.saved_state = DivState(quotient_width, name="saved_state")
        self.div_state_next = DivStateNext(quotient_width)
        self.div_state_init = DivStateInit(quotient_width)
        # NOTE(review): not referenced in this file; kept for compatibility
        self.divisor = Signal(unsigned(quotient_width))

    def elaborate(self, platform):
        m = super().elaborate(platform)
        step = self.div_state_next
        init = self.div_state_init
        m.submodules.div_state_next = step
        m.submodules.div_state_init = init

        data_i, data_o = self.p.data_i, self.n.data_o
        core_i, core_o = data_i.core, data_o.core
        core_saved_i = self.saved_input_data.core

        # TODO: handle cancellation

        # a fresh division always starts from the incoming dividend
        m.d.comb += init.dividend.eq(core_i.dividend)

        # non-core output fields come from the latched input (presumably
        # what eq_without_core copies); core results come from the step
        m.d.comb += data_o.eq_without_core(self.saved_input_data)
        m.d.comb += core_o.quotient_root.eq(step.o.quotient)
        # fract width of `DivPipeCoreOutputData.remainder`
        remainder_fract_width = 64 * 3
        # fract width of `DivPipeCoreInputData.dividend`
        dividend_fract_width = 64 * 2
        rem_start = remainder_fract_width - dividend_fract_width
        m.d.comb += core_o.remainder.eq(step.o.remainder << rem_start)

        # handshake: output valid once the division is done; only accept
        # new input while idle
        m.d.comb += self.n.valid_o.eq(~self.empty & step.o.done)
        m.d.comb += self.p.ready_o.eq(self.empty)

        # register one step of progress every clock
        m.d.sync += self.saved_state.eq(step.o)

        with m.If(self.empty):
            # idle: step on the incoming operands, latch them if valid
            m.d.comb += [step.i.eq(init.o),
                         step.divisor.eq(core_i.divisor_radicand)]
            with m.If(self.p.valid_i):
                m.d.sync += [self.empty.eq(0),
                             self.saved_input_data.eq(data_i)]
        with m.Else():
            # busy: continue from the registered state and latched divisor
            m.d.comb += [step.i.eq(self.saved_state),
                         step.divisor.eq(core_saved_i.divisor_radicand)]
            with m.If(self.n.ready_i & self.n.valid_o):
                # result consumed: go idle
                m.d.sync += self.empty.eq(1)

        return m

    def __iter__(self):
        """Yield the input-side then output-side handshake/data signals."""
        yield from self.p
        yield from self.n

    def ports(self):
        """Return every port signal as a list."""
        return [*self]
186 def ports(self):
187 return list(self)