From d4904def3c67d3b54b56ed4054cceb84ca484054 Mon Sep 17 00:00:00 2001 From: Michael Nolan Date: Wed, 12 Feb 2020 11:40:23 -0500 Subject: [PATCH] Add formal proof for dynamic shifter --- .../formal/proof_shift_dynamic.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 src/ieee754/part_shift_scalar/formal/proof_shift_dynamic.py diff --git a/src/ieee754/part_shift_scalar/formal/proof_shift_dynamic.py b/src/ieee754/part_shift_scalar/formal/proof_shift_dynamic.py new file mode 100644 index 00000000..ea2cc745 --- /dev/null +++ b/src/ieee754/part_shift_scalar/formal/proof_shift_dynamic.py @@ -0,0 +1,113 @@ +# Proof of correctness for partitioned dynamic shifter +# Copyright (C) 2020 Michael Nolan + +from nmigen import Module, Signal, Elaboratable, Mux, Cat +from nmigen.asserts import Assert, AnyConst, Assume +from nmigen.test.utils import FHDLTestCase +from nmigen.cli import rtlil + +from ieee754.part_mul_add.partpoints import PartitionPoints +from ieee754.part_shift_scalar.part_shift_dynamic import \ + PartitionedDynamicShift +import unittest + + +# This defines a module to drive the device under test and assert +# properties about its outputs +class ShifterDriver(Elaboratable): + def __init__(self): + # inputs and outputs + pass + + def get_intervals(self, signal, points): + start = 0 + interval = [] + keys = list(points.keys()) + [signal.width] + for key in keys: + end = key + interval.append(signal[start:end]) + start = end + return interval + + def elaborate(self, platform): + m = Module() + comb = m.d.comb + width = 24 + mwidth = 3 + + # setup the inputs and outputs of the DUT as anyconst + a = Signal(width) + b = Signal(width) + out = Signal(width) + points = PartitionPoints() + gates = Signal(mwidth-1) + step = int(width/mwidth) + for i in range(mwidth-1): + points[(i+1)*step] = gates[i] + print(points) + + comb += [a.eq(AnyConst(width)), + b.eq(AnyConst(width)), + gates.eq(AnyConst(mwidth-1))] + + m.submodules.dut = dut = PartitionedDynamicShift(width, points) + + a_intervals = self.get_intervals(a, points) + b_intervals = self.get_intervals(b, points) + out_intervals = self.get_intervals(out, points) + + comb += [dut.a.eq(a), + dut.b.eq(b), + out.eq(dut.output)] + + + with m.Switch(points.as_sig()): + with m.Case(0b00): + comb += Assume(b < 24) + comb += Assert(out == (a<