Fix carry output of adder/subtracter
[ieee754fpu.git] / src / ieee754 / part_mul_add / adder.py
index 2e504368c1c664a3b51b553338ca4231d0431d3b..ebcbe75e12594078082e1ae48580b9fa5977115f 100644 (file)
@@ -6,16 +6,10 @@ See:
 * https://libre-riscv.org/3d_gpu/architecture/dynamic_simd/add/
 """
 
-from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
-from nmigen.hdl.ast import Assign
-from abc import ABCMeta, abstractmethod
-from nmigen.cli import main
-from functools import reduce
-from operator import or_
-from ieee754.pipeline import PipelineSpec
-from nmutil.pipemodbase import PipeModBase
+from nmigen import Signal, Module, Elaboratable, Cat
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_cmp.ripple import MoveMSBDown
 
 
 class FullAdder(Elaboratable):
@@ -176,13 +170,17 @@ class PartitionedAdder(Elaboratable):
         """Elaborate this module."""
         m = Module()
         comb = m.d.comb
+
+        carry_tmp = Signal(self.carry_out.width)
+        m.submodules.ripple = ripple = MoveMSBDown(self.carry_out.width)
+
         expanded_a = Signal(self._expanded_width, reset_less=True)
         expanded_b = Signal(self._expanded_width, reset_less=True)
         expanded_o = Signal(self._expanded_width, reset_less=True)
 
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
-        al, bl, ol, cl, ea, eb, eo, co = [],[],[],[],[],[],[],[]
+        al, bl, ol, cl, ea, eb, eo, co = [], [], [], [], [], [], [], []
 
         # partition points are "breaks" (extra zeros or 1s) in what would
         # otherwise be a massive long add.  when the "break" points are 0,
@@ -201,20 +199,22 @@ class PartitionedAdder(Elaboratable):
         expanded_index += 1
 
         for i in range(self.width):
-            pi = i/self.pmul # double the range of the partition point test
+            pi = i/self.pmul  # double the range of the partition point test
             if pi.is_integer() and pi in self.part_pts:
-                # add extra bit set to 0 + 0 for enabled partition points
+                # add extra bit set to carry + carry for enabled
+                # partition points
                 a_bit = Signal(name="a_bit_%d" % i, reset_less=True)
-                carry_in = self.carry_in[carry_bit] # convenience
+                carry_in = self.carry_in[carry_bit]  # convenience
                 m.d.comb += a_bit.eq(self.part_pts[pi].implies(carry_in))
+
                 # and 1 + 0 for disabled partition points
                 ea.append(expanded_a[expanded_index])
-                al.append(a_bit) # add extra bit in a
+                al.append(a_bit)  # add extra bit in a
                 eb.append(expanded_b[expanded_index])
-                bl.append(carry_in & self.part_pts[pi]) # yes, add a zero
+                bl.append(carry_in & self.part_pts[pi])  # carry bit
                 co.append(expanded_o[expanded_index])
-                cl.append(self.carry_out[carry_bit-1])
-                expanded_index += 1 # skip the extra point.  NOT in the output
+                cl.append(carry_tmp[carry_bit-1])
+                expanded_index += 1  # skip the extra point.  NOT in the output
                 carry_bit += 1
             ea.append(expanded_a[expanded_index])
             eb.append(expanded_b[expanded_index])
@@ -225,8 +225,8 @@ class PartitionedAdder(Elaboratable):
             expanded_index += 1
         al.append(0)
         bl.append(0)
-        co.append(expanded_o[expanded_index])
-        cl.append(self.carry_out[carry_bit-1])
+        co.append(expanded_o[-1])
+        cl.append(carry_tmp[carry_bit-1])
 
         # combine above using Cat
         comb += Cat(*ea).eq(Cat(*al))
@@ -238,6 +238,8 @@ class PartitionedAdder(Elaboratable):
         # special hardware on FPGAs
         comb += expanded_o.eq(expanded_a + expanded_b)
 
-        return m
-
+        comb += ripple.results_in.eq(carry_tmp)
+        comb += ripple.gates.eq(self.part_pts.as_sig())
+        comb += self.carry_out.eq(ripple.output)
 
+        return m