Add shift right to test_partsig and partsig
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 16:56:24 +0000 (11:56 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 16:56:24 +0000 (11:56 -0500)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py

index 33cc5feac86dda735d36623de2e815c45f5854b9..1fc7d4b052aa85becb9113b891563106450bcd80 100644 (file)
@@ -93,7 +93,7 @@ class PartitionedSignal:
     # TODO: detect if the 2nd operand is a Const, a Signal or a
     # PartitionedSignal.  if it's a Const or a Signal, a global shift
     # can occur.  if it's a PartitionedSignal, that's much more interesting.
-    def ls_op(self, op1, op2, carry):
+    def ls_op(self, op1, op2, carry, shr_flag=0):
         op1 = getsig(op1)
         if isinstance(op2, Const) or isinstance(op2, Signal):
             scalar = True
@@ -109,9 +109,11 @@ class PartitionedSignal:
         if scalar:
             comb += pa.data.eq(op1)
             comb += pa.shifter.eq(op2)
+            comb += pa.shift_right.eq(shr_flag)
         else:
             comb += pa.a.eq(op1)
             comb += pa.b.eq(op2)
+            comb += pa.shift_right.eq(shr_flag)
         # XXX TODO: carry-in, carry-out
         #comb += pa.carry_in.eq(carry)
         return (pa.output, 0)
@@ -126,8 +128,9 @@ class PartitionedSignal:
         return Operator("<<", [other, self])
 
     def __rshift__(self, other):
-        raise NotImplementedError
-        return Operator(">>", [self, other])
+        z = Const(0, len(self.partpoints)+1)
+        result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry
+        return result
 
     def __rrshift__(self, other):
         raise NotImplementedError
index bfb8846828aba3ba227c53422bfac58d1231c484..1c980bad35ec8ae5d1411768fd82c6a64416ef8e 100644 (file)
@@ -57,6 +57,8 @@ class TestAddMod2(Elaboratable):
         self.add_output = Signal(width)
         self.ls_output = Signal(width) # left shift
         self.ls_scal_output = Signal(width) # left shift
+        self.rs_output = Signal(width) # left shift
+        self.rs_scal_output = Signal(width) # left shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -98,11 +100,13 @@ class TestAddMod2(Elaboratable):
         sync += self.neg_output.eq(-self.a)
         # left shift
         sync += self.ls_output.eq(self.a << self.b)
+        sync += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
         sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
         # scalar left shift
         comb += self.bsig.eq(self.b.sig)
         sync += self.ls_scal_output.eq(self.a << self.bsig)
+        sync += self.rs_scal_output.eq(self.a >> self.bsig)
 
         return m
 
@@ -116,6 +120,8 @@ class TestAddMod(Elaboratable):
         self.add_output = Signal(width)
         self.ls_output = Signal(width) # left shift
         self.ls_scal_output = Signal(width) # left shift
+        self.rs_output = Signal(width) # left shift
+        self.rs_scal_output = Signal(width) # left shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -157,11 +163,13 @@ class TestAddMod(Elaboratable):
         comb += self.neg_output.eq(-self.a)
         # left shift
         comb += self.ls_output.eq(self.a << self.b)
+        comb += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
         comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
         # scalar left shift
         comb += self.bsig.eq(self.b.sig)
         comb += self.ls_scal_output.eq(self.a << self.bsig)
+        comb += self.rs_scal_output.eq(self.a >> self.bsig)
 
         return m
 
@@ -199,6 +207,23 @@ class TestPartitionPoints(unittest.TestCase):
                 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                 return result, carry
 
+            def test_rs_scal_fn(carry_in, a, b, mask):
+                # reduce range of b
+                bits = count_bits(mask)
+                newb = b & ((bits-1))
+                print ("%x %x %x bits %d trunc %x" % \
+                        (a, b, mask, bits, newb))
+                b = newb
+                # TODO: carry
+                carry_in = 0
+                lsb = mask & ~(mask-1) if carry_in else 0
+                sum = ((a & mask) >> b)
+                result = mask & sum
+                carry = (sum & mask) != sum
+                carry = 0
+                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+                return result, carry
+
             def test_ls_fn(carry_in, a, b, mask):
                 # reduce range of b
                 bits = count_bits(mask)
@@ -219,6 +244,26 @@ class TestPartitionPoints(unittest.TestCase):
                 print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                 return result, carry
 
+            def test_rs_fn(carry_in, a, b, mask):
+                # reduce range of b
+                bits = count_bits(mask)
+                fz = first_zero(mask)
+                newb = b & ((bits-1)<<fz)
+                print ("%x %x %x bits %d zero %d trunc %x" % \
+                        (a, b, mask, bits, fz, newb))
+                b = newb
+                # TODO: carry
+                carry_in = 0
+                lsb = mask & ~(mask-1) if carry_in else 0
+                b = (b & mask)
+                b = b >>fz
+                sum = ((a & mask) >> b)
+                result = mask & sum
+                carry = (sum & mask) != sum
+                carry = 0
+                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
+                return result, carry
+
             def test_add_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
                 sum = (a & mask) + (b & mask) + lsb
@@ -279,6 +324,8 @@ class TestPartitionPoints(unittest.TestCase):
             for (test_fn, mod_attr) in (
                                         (test_ls_scal_fn, "ls_scal"),
                                         (test_ls_fn, "ls"),
+                                        (test_rs_scal_fn, "rs_scal"),
+                                        (test_rs_fn, "rs"),
                                         (test_add_fn, "add"),
                                         (test_sub_fn, "sub"),
                                         (test_neg_fn, "neg"),