From: Michael Nolan Date: Wed, 26 Feb 2020 16:56:24 +0000 (-0500) Subject: Add shift right to test_partsig and partsig X-Git-Tag: ls180-24jan2020~117 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e42917b58f03e32aefb38dbea1deea4f359f5511;p=ieee754fpu.git Add shift right to test_partsig and partsig --- diff --git a/src/ieee754/part/partsig.py b/src/ieee754/part/partsig.py index 33cc5fea..1fc7d4b0 100644 --- a/src/ieee754/part/partsig.py +++ b/src/ieee754/part/partsig.py @@ -93,7 +93,7 @@ class PartitionedSignal: # TODO: detect if the 2nd operand is a Const, a Signal or a # PartitionedSignal. if it's a Const or a Signal, a global shift # can occur. if it's a PartitionedSignal, that's much more interesting. - def ls_op(self, op1, op2, carry): + def ls_op(self, op1, op2, carry, shr_flag=0): op1 = getsig(op1) if isinstance(op2, Const) or isinstance(op2, Signal): scalar = True @@ -109,9 +109,11 @@ class PartitionedSignal: if scalar: comb += pa.data.eq(op1) comb += pa.shifter.eq(op2) + comb += pa.shift_right.eq(shr_flag) else: comb += pa.a.eq(op1) comb += pa.b.eq(op2) + comb += pa.shift_right.eq(shr_flag) # XXX TODO: carry-in, carry-out #comb += pa.carry_in.eq(carry) return (pa.output, 0) @@ -126,8 +128,9 @@ class PartitionedSignal: return Operator("<<", [other, self]) def __rshift__(self, other): - raise NotImplementedError - return Operator(">>", [self, other]) + z = Const(0, len(self.partpoints)+1) + result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry + return result def __rrshift__(self, other): raise NotImplementedError diff --git a/src/ieee754/part/test/test_partsig.py b/src/ieee754/part/test/test_partsig.py index bfb88468..1c980bad 100644 --- a/src/ieee754/part/test/test_partsig.py +++ b/src/ieee754/part/test/test_partsig.py @@ -57,6 +57,8 @@ class TestAddMod2(Elaboratable): self.add_output = Signal(width) self.ls_output = Signal(width) # left shift self.ls_scal_output = Signal(width) # left shift + self.rs_output = Signal(width) # left shift + self.rs_scal_output = Signal(width) # left shift self.sub_output = Signal(width) self.eq_output = Signal(len(partpoints)+1) self.gt_output = Signal(len(partpoints)+1) @@ -98,11 +100,13 @@ class TestAddMod2(Elaboratable): sync += self.neg_output.eq(-self.a) # left shift sync += self.ls_output.eq(self.a << self.b) + sync += self.rs_output.eq(self.a >> self.b) ppts = self.partpoints sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b)) # scalar left shift comb += self.bsig.eq(self.b.sig) sync += self.ls_scal_output.eq(self.a << self.bsig) + sync += self.rs_scal_output.eq(self.a >> self.bsig) return m @@ -116,6 +120,8 @@ class TestAddMod(Elaboratable): self.add_output = Signal(width) self.ls_output = Signal(width) # left shift self.ls_scal_output = Signal(width) # left shift + self.rs_output = Signal(width) # left shift + self.rs_scal_output = Signal(width) # left shift self.sub_output = Signal(width) self.eq_output = Signal(len(partpoints)+1) self.gt_output = Signal(len(partpoints)+1) @@ -157,11 +163,13 @@ class TestAddMod(Elaboratable): comb += self.neg_output.eq(-self.a) # left shift comb += self.ls_output.eq(self.a << self.b) + comb += self.rs_output.eq(self.a >> self.b) ppts = self.partpoints comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b)) # scalar left shift comb += self.bsig.eq(self.b.sig) comb += self.ls_scal_output.eq(self.a << self.bsig) + comb += self.rs_scal_output.eq(self.a >> self.bsig) return m @@ -199,6 +207,23 @@ class TestPartitionPoints(unittest.TestCase): print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result)) return result, carry + def test_rs_scal_fn(carry_in, a, b, mask): + # reduce range of b + bits = count_bits(mask) + newb = b & ((bits-1)) + print ("%x %x %x bits %d trunc %x" % \ + (a, b, mask, bits, newb)) + b = newb + # TODO: carry + carry_in = 0 + lsb = mask & ~(mask-1) if carry_in else 0 + sum = ((a & mask) >> b) + result = mask & sum + carry = (sum & mask) != sum + carry = 0 + print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result)) + return result, carry + def test_ls_fn(carry_in, a, b, mask): # reduce range of b bits = count_bits(mask) @@ -219,6 +244,26 @@ class TestPartitionPoints(unittest.TestCase): print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result)) return result, carry + def test_rs_fn(carry_in, a, b, mask): + # reduce range of b + bits = count_bits(mask) + fz = first_zero(mask) + newb = b & ((bits-1)<>fz + sum = ((a & mask) >> b) + result = mask & sum + carry = (sum & mask) != sum + carry = 0 + print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result)) + return result, carry + def test_add_fn(carry_in, a, b, mask): lsb = mask & ~(mask-1) if carry_in else 0 sum = (a & mask) + (b & mask) + lsb @@ -279,6 +324,8 @@ class TestPartitionPoints(unittest.TestCase): for (test_fn, mod_attr) in ( (test_ls_scal_fn, "ls_scal"), (test_ls_fn, "ls"), + (test_rs_scal_fn, "rs_scal"), + (test_rs_fn, "rs"), (test_add_fn, "add"), (test_sub_fn, "sub"), (test_neg_fn, "neg"),