Fix carry output of adder/subtracter
authorMichael Nolan <mtnolan2640@gmail.com>
Mon, 10 Feb 2020 16:23:11 +0000 (11:23 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Mon, 10 Feb 2020 16:23:11 +0000 (11:23 -0500)
src/ieee754/part/test/test_partsig.py
src/ieee754/part_mul_add/adder.py

index 53647082033ed646fad84d4067ef07d60fdf233c..b209aec18a5ab582f802ffb138fbbd057a045d41 100644 (file)
@@ -12,6 +12,7 @@ from ieee754.part_mux.part_mux import PMux
 from random import randint
 import unittest
 import itertools
+import math
 
 
 def perms(k):
@@ -71,7 +72,7 @@ class TestAddMod(Elaboratable):
         sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                            self.carry_in)
         comb += self.sub_output.eq(sub_out)
-        comb += self.sub_carry_out.eq(add_carry)
+        comb += self.sub_carry_out.eq(sub_carry)
         comb += self.neg_output.eq(-self.a)
         ppts = self.partpoints
         comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
@@ -97,14 +98,21 @@ class TestPartitionPoints(unittest.TestCase):
 
             def test_add_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
-                return mask & ((a & mask) + (b & mask) + lsb)
+                sum = (a & mask) + (b & mask) + lsb
+                result = mask & sum
+                carry = (sum & mask) != sum
+                print(a, b, sum, mask)
+                return result, carry
 
             def test_sub_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
-                return mask & ((a & mask) + (~b & mask) + lsb)
+                sum = (a & mask) + (~b & mask) + lsb
+                result = mask & sum
+                carry = (sum & mask) != sum
+                return result, carry
 
             def test_neg_fn(carry_in, a, b, mask):
-                return mask & ((a & mask) + (~0 & mask))
+                return test_add_fn(0, a, ~0, mask)
 
             def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                 rand_data = []
@@ -124,14 +132,25 @@ class TestPartitionPoints(unittest.TestCase):
                     yield module.carry_in.eq(carry_sig)
                     yield Delay(0.1e-6)
                     y = 0
+                    carry_result = 0
                     for i, mask in enumerate(mask_list):
-                        y |= test_fn(carry, a, b, mask)
+                        res, c = test_fn(carry, a, b, mask)
+                        y |= res
+                        lsb = mask & ~(mask - 1)
+                        bit_set = int(math.log2(lsb))
+                        carry_result |= c << int(bit_set/4)
                     outval = (yield getattr(module, "%s_output" % mod_attr))
                     # TODO: get (and test) carry output as well
                     print(a, b, outval, carry)
                     msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
-                        f" => 0x{y:X} != 0x{outval:X}"
+                        f" => 0x{y:X} != 0x{outval:X} ({mod_attr})"
                     self.assertEqual(y, outval, msg)
+                    if hasattr(module, "%s_carry_out" % mod_attr):
+                        c_outval = (yield getattr(module,
+                                                  "%s_carry_out" % mod_attr))
+                        msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
+                            f" => 0x{carry_result:X} != 0x{c_outval:X} ({mod_attr})"
+                        self.assertEqual(carry_result, c_outval, msg)
 
             for (test_fn, mod_attr) in ((test_add_fn, "add"),
                                         (test_sub_fn, "sub"),
index 2e504368c1c664a3b51b553338ca4231d0431d3b..ebcbe75e12594078082e1ae48580b9fa5977115f 100644 (file)
@@ -6,16 +6,10 @@ See:
 * https://libre-riscv.org/3d_gpu/architecture/dynamic_simd/add/
 """
 
-from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
-from nmigen.hdl.ast import Assign
-from abc import ABCMeta, abstractmethod
-from nmigen.cli import main
-from functools import reduce
-from operator import or_
-from ieee754.pipeline import PipelineSpec
-from nmutil.pipemodbase import PipeModBase
+from nmigen import Signal, Module, Elaboratable, Cat
 
 from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_cmp.ripple import MoveMSBDown
 
 
 class FullAdder(Elaboratable):
@@ -176,13 +170,17 @@ class PartitionedAdder(Elaboratable):
         """Elaborate this module."""
         m = Module()
         comb = m.d.comb
+
+        carry_tmp = Signal(self.carry_out.width)
+        m.submodules.ripple = ripple = MoveMSBDown(self.carry_out.width)
+
         expanded_a = Signal(self._expanded_width, reset_less=True)
         expanded_b = Signal(self._expanded_width, reset_less=True)
         expanded_o = Signal(self._expanded_width, reset_less=True)
 
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
-        al, bl, ol, cl, ea, eb, eo, co = [],[],[],[],[],[],[],[]
+        al, bl, ol, cl, ea, eb, eo, co = [], [], [], [], [], [], [], []
 
         # partition points are "breaks" (extra zeros or 1s) in what would
         # otherwise be a massive long add.  when the "break" points are 0,
@@ -201,20 +199,22 @@ class PartitionedAdder(Elaboratable):
         expanded_index += 1
 
         for i in range(self.width):
-            pi = i/self.pmul # double the range of the partition point test
+            pi = i/self.pmul  # double the range of the partition point test
             if pi.is_integer() and pi in self.part_pts:
-                # add extra bit set to 0 + 0 for enabled partition points
+                # add extra bit set to carry + carry for enabled
+                # partition points
                 a_bit = Signal(name="a_bit_%d" % i, reset_less=True)
-                carry_in = self.carry_in[carry_bit] # convenience
+                carry_in = self.carry_in[carry_bit]  # convenience
                 m.d.comb += a_bit.eq(self.part_pts[pi].implies(carry_in))
+
                 # and 1 + 0 for disabled partition points
                 ea.append(expanded_a[expanded_index])
-                al.append(a_bit) # add extra bit in a
+                al.append(a_bit)  # add extra bit in a
                 eb.append(expanded_b[expanded_index])
-                bl.append(carry_in & self.part_pts[pi]) # yes, add a zero
+                bl.append(carry_in & self.part_pts[pi])  # carry bit
                 co.append(expanded_o[expanded_index])
-                cl.append(self.carry_out[carry_bit-1])
-                expanded_index += 1 # skip the extra point.  NOT in the output
+                cl.append(carry_tmp[carry_bit-1])
+                expanded_index += 1  # skip the extra point.  NOT in the output
                 carry_bit += 1
             ea.append(expanded_a[expanded_index])
             eb.append(expanded_b[expanded_index])
@@ -225,8 +225,8 @@ class PartitionedAdder(Elaboratable):
             expanded_index += 1
         al.append(0)
         bl.append(0)
-        co.append(expanded_o[expanded_index])
-        cl.append(self.carry_out[carry_bit-1])
+        co.append(expanded_o[-1])
+        cl.append(carry_tmp[carry_bit-1])
 
         # combine above using Cat
         comb += Cat(*ea).eq(Cat(*al))
@@ -238,6 +238,8 @@ class PartitionedAdder(Elaboratable):
         # special hardware on FPGAs
         comb += expanded_o.eq(expanded_a + expanded_b)
 
-        return m
-
+        comb += ripple.results_in.eq(carry_tmp)
+        comb += ripple.gates.eq(self.part_pts.as_sig())
+        comb += self.carry_out.eq(ripple.output)
 
+        return m