From 8b9e4bb7bdc717dd5cafee843f2e08ba7d027636 Mon Sep 17 00:00:00 2001 From: Michael Nolan Date: Wed, 5 Feb 2020 10:37:14 -0500 Subject: [PATCH] Add module to handle partitioned eq, gt, and ge comparisons --- src/ieee754/part_cmp/eq_gt_ge.py | 89 +++++++++++++ .../part_cmp/experiments/gt_combiner.py | 12 ++ src/ieee754/part_cmp/formal/proof_eq_gt_ge.py | 121 ++++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 src/ieee754/part_cmp/eq_gt_ge.py create mode 100644 src/ieee754/part_cmp/formal/proof_eq_gt_ge.py diff --git a/src/ieee754/part_cmp/eq_gt_ge.py b/src/ieee754/part_cmp/eq_gt_ge.py new file mode 100644 index 00000000..c1d223d4 --- /dev/null +++ b/src/ieee754/part_cmp/eq_gt_ge.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information + +""" +Copyright (C) 2020 Luke Kenneth Casson Leighton + +dynamically-partitionable "comparison" class, directly equivalent +to Signal.__eq__ except SIMD-partitionable + +See: + +* http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/eq +* http://bugs.libre-riscv.org/show_bug.cgi?id=132 +""" + +from nmigen import Signal, Module, Elaboratable, Cat, C, Mux, Repl +from nmigen.cli import main + +from ieee754.part_mul_add.partpoints import PartitionPoints +from ieee754.part_cmp.experiments.gt_combiner import GTCombiner + + +class PartitionedEqGtGe(Elaboratable): + + # Expansion of the partitioned equals module to handle Greater + # Than and Greater than or Equal to. The function being evaluated + # is selected by the opcode signal, where: + # opcode 0x00 - EQ + # opcode 0x01 - GT + # opcode 0x02 - GE + def __init__(self, width, partition_points): + """Create a ``PartitionedEq`` operator + """ + self.width = width + self.a = Signal(width, reset_less=True) + self.b = Signal(width, reset_less=True) + self.opcode = Signal(2) + self.partition_points = PartitionPoints(partition_points) + self.mwidth = len(self.partition_points)+1 + self.output = Signal(self.mwidth, reset_less=True) + if not self.partition_points.fits_in_width(width): + raise ValueError("partition_points doesn't fit in width") + + def elaborate(self, platform): + m = Module() + comb = m.d.comb + m.submodules.gtc = gtc = GTCombiner(self.mwidth) + + # make a series of "eqs" and "gts", splitting a and b into partition chunks + eqs = Signal(self.mwidth, reset_less=True) + eql = [] + gts = Signal(self.mwidth, reset_less=True) + gtl = [] + + keys = list(self.partition_points.keys()) + [self.width] + start = 0 + for i in range(len(keys)): + end = keys[i] + eql.append(self.a[start:end] == self.b[start:end]) + gtl.append(self.a[start:end] > self.b[start:end]) + start = end # for next time round loop + comb += eqs.eq(Cat(*eql)) + comb += gts.eq(Cat(*gtl)) + + # Signal to control the constant injected into the partition next to a closed gate + aux_input = Signal() + # Signal to enable or disable the gt input for the gt partition combiner + gt_en = Signal() + + with m.Switch(self.opcode): + with m.Case(0b00): # equals + comb += aux_input.eq(1) + comb += gt_en.eq(0) + with m.Case(0b01): # greater than + comb += aux_input.eq(0) + comb += gt_en.eq(1) + with m.Case(0b10): # greater than or equal to + comb += aux_input.eq(1) + comb += gt_en.eq(1) + + comb += gtc.gates.eq(self.partition_points.as_sig()) + comb += gtc.eqs.eq(eqs) + comb += gtc.gts.eq(gts) + comb += gtc.aux_input.eq(aux_input) + comb += gtc.gt_en.eq(gt_en) + comb += self.output.eq(gtc.outputs) + + + return m diff --git a/src/ieee754/part_cmp/experiments/gt_combiner.py b/src/ieee754/part_cmp/experiments/gt_combiner.py index 421a39be..a8c9c85e 100644 --- a/src/ieee754/part_cmp/experiments/gt_combiner.py +++ b/src/ieee754/part_cmp/experiments/gt_combiner.py @@ -32,8 +32,20 @@ class GTCombiner(Elaboratable): def __init__(self, width): self.width = width + + # These two signals allow this module to do more than just a + # partitioned greater than comparison. + # - If aux_input is set to 0 and gt_en is set to 1, then this + # module performs a partitioned greater than comparision + # - If aux_input is set to 1 and gt_en is set to 0, then this + # module is functionally equivalent to the eq_combiner + # module. + # - If aux_input is set to 1 and gt_en is set to 1, then this + # module performs a partitioned greater than or equals + # comparison self.aux_input = Signal(reset_less=True) # right hand side mux input self.gt_en = Signal(reset_less=True) # enable or disable the gt signal + self.eqs = Signal(width, reset_less=True) # the flags for EQ self.gts = Signal(width, reset_less=True) # the flags for GT self.gates = Signal(width-1, reset_less=True) diff --git a/src/ieee754/part_cmp/formal/proof_eq_gt_ge.py b/src/ieee754/part_cmp/formal/proof_eq_gt_ge.py new file mode 100644 index 00000000..497e234e --- /dev/null +++ b/src/ieee754/part_cmp/formal/proof_eq_gt_ge.py @@ -0,0 +1,121 @@ +# Proof of correctness for partitioned equals module +# Copyright (C) 2020 Michael Nolan + +from nmigen import Module, Signal, Elaboratable, Mux, Cat +from nmigen.asserts import Assert, AnyConst, Assume +from nmigen.test.utils import FHDLTestCase +from nmigen.cli import rtlil + +from ieee754.part_mul_add.partpoints import PartitionPoints +from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe +import unittest + + +# This defines a module to drive the device under test and assert +# properties about its outputs +class EqualsDriver(Elaboratable): + def __init__(self): + # inputs and outputs + pass + + def get_intervals(self, signal, points): + start = 0 + interval = [] + keys = list(points.keys()) + [signal.width] + for key in keys: + end = key + interval.append(signal[start:end]) + start = end + return interval + + def elaborate(self, platform): + m = Module() + comb = m.d.comb + width = 24 + mwidth = 3 + + # setup the inputs and outputs of the DUT as anyconst + a = Signal(width) + b = Signal(width) + points = PartitionPoints() + gates = Signal(mwidth-1) + opcode = Signal(2) + for i in range(mwidth-1): + points[i*8+8] = gates[i] + out = Signal(mwidth) + + + comb += [a.eq(AnyConst(width)), + b.eq(AnyConst(width)), + opcode.eq(AnyConst(opcode.width)), + gates.eq(AnyConst(mwidth-1))] + + m.submodules.dut = dut = PartitionedEqGtGe(width, points) + + a_intervals = self.get_intervals(a, points) + b_intervals = self.get_intervals(b, points) + + with m.If(opcode == 0b00): + with m.Switch(gates): + with m.Case(0b00): + comb += Assert(out[-1] == (a == b)) + with m.Case(0b01): + comb += Assert(out[2] == ((a_intervals[1] == b_intervals[1]) & + (a_intervals[2] == b_intervals[2]))) + comb += Assert(out[0] == (a_intervals[0] == b_intervals[0])) + with m.Case(0b10): + comb += Assert(out[1] == ((a_intervals[0] == b_intervals[0]) & + (a_intervals[1] == b_intervals[1]))) + comb += Assert(out[2] == (a_intervals[2] == b_intervals[2])) + with m.Case(0b11): + for i in range(mwidth-1): + comb += Assert(out[i] == (a_intervals[i] == b_intervals[i])) + with m.If(opcode == 0b01): + with m.Switch(gates): + with m.Case(0b00): + comb += Assert(out[-1] == (a > b)) + with m.Case(0b01): + comb += Assert(out[0] == (a_intervals[0] > b_intervals[0])) + + comb += Assert(out[1] == 0) + comb += Assert(out[2] == (Cat(*a_intervals[1:3]) > Cat(*b_intervals[1:3]))) + with m.Case(0b10): + comb += Assert(out[0] == 0) + comb += Assert(out[1] == (Cat(*a_intervals[0:2]) > Cat(*b_intervals[0:2]))) + comb += Assert(out[2] == (a_intervals[2] > b_intervals[2])) + with m.Case(0b11): + for i in range(mwidth-1): + comb += Assert(out[i] == (a_intervals[i] > b_intervals[i])) + with m.If(opcode == 0b10): + with m.Switch(gates): + with m.Case(0b00): + comb += Assert(out[-1] == (a >= b)) + with m.Case(0b01): + comb += Assert(out[0] == (a_intervals[0] >= b_intervals[0])) + + comb += Assert(out[1] == 0) + comb += Assert(out[2] == (Cat(*a_intervals[1:3]) >= Cat(*b_intervals[1:3]))) + with m.Case(0b10): + comb += Assert(out[0] == 0) + comb += Assert(out[1] == (Cat(*a_intervals[0:2]) >= Cat(*b_intervals[0:2]))) + comb += Assert(out[2] == (a_intervals[2] >= b_intervals[2])) + with m.Case(0b11): + for i in range(mwidth-1): + comb += Assert(out[i] == (a_intervals[i] >= b_intervals[i])) + + + + comb += [dut.a.eq(a), + dut.b.eq(b), + dut.opcode.eq(opcode), + out.eq(dut.output)] + return m + +class PartitionedEqTestCase(FHDLTestCase): + def test_eq(self): + module = EqualsDriver() + self.assertFormal(module, mode="bmc", depth=4) + +if __name__ == "__main__": + unittest.main() + -- 2.30.2