From 05dce556bd8d47e1ea1e430b59d234d2dc01bd66 Mon Sep 17 00:00:00 2001 From: Michael Nolan Date: Tue, 28 Jan 2020 16:58:58 -0500 Subject: [PATCH] Use anyconst for the inputs of the dut in FMax formal proof --- src/ieee754/fpmax/formal/proof_fmax_mod.py | 75 +++++++++++++--------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/src/ieee754/fpmax/formal/proof_fmax_mod.py b/src/ieee754/fpmax/formal/proof_fmax_mod.py index abaa2d86..24554bc5 100644 --- a/src/ieee754/fpmax/formal/proof_fmax_mod.py +++ b/src/ieee754/fpmax/formal/proof_fmax_mod.py @@ -2,7 +2,7 @@ # Copyright (C) 2020 Michael Nolan from nmigen import Module, Signal, Elaboratable, Cat, Mux -from nmigen.asserts import Assert, Assume +from nmigen.asserts import Assert, Assume, AnyConst from nmigen.cli import rtlil from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord @@ -18,69 +18,84 @@ class FPMAXDriver(Elaboratable): def __init__(self, pspec): # inputs and outputs self.pspec = pspec - self.a = Signal(pspec.width) - self.b = Signal(pspec.width) - self.z = Signal(pspec.width) - self.opc = Signal(pspec.op_wid) - self.muxid = Signal(pspec.id_wid) def elaborate(self, platform): m = Module() - width = self.pspec.width + + # setup the inputs and outputs of the DUT as anyconst + a = Signal(width) + b = Signal(width) + z = Signal(width) + opc = Signal(self.pspec.op_wid) + muxid = Signal(self.pspec.id_wid) + m.d.comb += [a.eq(AnyConst(width)), + b.eq(AnyConst(width)), + opc.eq(AnyConst(self.pspec.op_wid)), + muxid.eq(AnyConst(self.pspec.id_wid))] + m.submodules.dut = dut = FPMAXPipeMod(self.pspec) - a1 = FPNumBaseRecord(self.pspec.width, False) - b1 = FPNumBaseRecord(self.pspec.width, False) - z1 = FPNumBaseRecord(self.pspec.width, False) + # Decode the inputs and outputs so they're easier to work with + a1 = FPNumBaseRecord(width, False) + b1 = FPNumBaseRecord(width, False) + z1 = FPNumBaseRecord(width, False) m.submodules.sc_decode_a = a1 = FPNumDecode(None, a1) m.submodules.sc_decode_b = b1 = FPNumDecode(None, b1) m.submodules.sc_decode_z = z1 = FPNumDecode(None, z1) + m.d.comb += [a1.v.eq(a), + b1.v.eq(b), + z1.v.eq(z)] - m.d.comb += [a1.v.eq(self.a), - b1.v.eq(self.b), - z1.v.eq(self.z)] - - m.d.comb += Assert((z1.v == a1.v) | \ - (z1.v == b1.v) | \ + # Since this calculates the min/max of two values, the value + # it returns should either be one of the two values, or NaN + m.d.comb += Assert((z1.v == a1.v) | (z1.v == b1.v) | (z1.v == a1.fp.nan2(0))) + # If both the operands are NaN, max/min should return NaN with m.If(a1.is_nan & b1.is_nan): m.d.comb += Assert(z1.is_nan) + # If only one of the operands is NaN, fmax and fmin should + # return the other operand with m.Elif(a1.is_nan & ~b1.is_nan): m.d.comb += Assert(z1.v == b1.v) with m.Elif(b1.is_nan & ~a1.is_nan): m.d.comb += Assert(z1.v == a1.v) + # If none of the operands are NaN, then compare the values and + # determine the largest or smallest with m.Else(): + # Selects whether the result should be the left hand side + # (a) or right hand side (b) isrhs = Signal() # if a1 is negative and b1 isn't, then we should return b1 with m.If(a1.s != b1.s): m.d.comb += isrhs.eq(a1.s > b1.s) with m.Else(): - # if they both have the same sign + # if they both have the same sign, compare the + # exponent/mantissa as an integer gt = Signal() - m.d.comb += gt.eq(self.a[0:width-1] < self.b[0:width-1]) + m.d.comb += gt.eq(a[0:width-1] < b[0:width-1]) + # Invert the result we got if both sign bits are set + # (A bigger exponent/mantissa with a set sign bit + # means a smaller value) m.d.comb += isrhs.eq(gt ^ a1.s) - with m.If(self.opc == 0): + with m.If(opc == 0): m.d.comb += Assert(z1.v == - Mux(self.opc[0] ^ isrhs, + Mux(opc[0] ^ isrhs, b1.v, a1.v)) - # connect up the inputs and outputs. I think these could - # theoretically be $anyconst/$anysync but I'm not sure nmigen - # has support for that - m.d.comb += dut.i.a.eq(self.a) - m.d.comb += dut.i.b.eq(self.b) - m.d.comb += dut.i.ctx.op.eq(self.opc) - m.d.comb += dut.i.muxid.eq(self.muxid) - m.d.comb += self.z.eq(dut.o.z) - + # connect up the inputs and outputs. + m.d.comb += dut.i.a.eq(a) + m.d.comb += dut.i.b.eq(b) + m.d.comb += dut.i.ctx.op.eq(opc) + m.d.comb += dut.i.muxid.eq(muxid) + m.d.comb += z.eq(dut.o.z) return m def ports(self): - return [self.a, self.b, self.z, self.opc, self.muxid] + return [] def run_test(bits=32): -- 2.30.2