From 453f43543348a61fe7c1752e88e71a9ddbbe49a0 Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Thu, 4 Aug 2022 22:53:08 -0700
Subject: [PATCH] add tree_reduction and pop_count based off of
 dead-code-elimination of prefix_sum_ops

---
Review notes (this section is ignored by `git am`):
 * added docstrings/comments to tree_reduction_ops, tree_reduction and
   pop_count.
 * pop_count() on an int now extracts its bits as ints instead of bools, so
   a 1-bit input (which goes through no additions) returns an int rather
   than a bool; added test_pop_count_int_type as a regression test.
 * hunk lengths and the diffstat were updated to match; the post-image blob
   ids in the "index" lines are still those of the original commit.

 src/nmutil/prefix_sum.py           |  73 +++++++++++++++
 src/nmutil/test/test_prefix_sum.py | 140 ++++++++++++++++++++++++++++-
 2 files changed, 212 insertions(+), 1 deletion(-)

diff --git a/src/nmutil/prefix_sum.py b/src/nmutil/prefix_sum.py
index 4aa89b2..3908d1b 100644
--- a/src/nmutil/prefix_sum.py
+++ b/src/nmutil/prefix_sum.py
@@ -7,6 +7,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 import operator
+from nmigen.hdl.ast import Value, Const
 
 
 @dataclass(order=True, unsafe_hash=True, frozen=True)
@@ -219,6 +220,78 @@ def render_prefix_sum_diagram(item_count, *, work_efficient=False,
     return "\n".join(map(str.rstrip, lines))
 
 
+def tree_reduction_ops(item_count):
+    """Yield the ops that reduce `item_count` items into the last item.
+
+    Runs a backwards liveness pass (dead-code elimination) over
+    `prefix_sum_ops(item_count=item_count)` and keeps only the ops
+    that contribute to the final value of item `item_count - 1`. The
+    surviving ops are yielded in their original order.
+    """
+    assert isinstance(item_count, int) and item_count >= 1
+    ops = list(prefix_sum_ops(item_count=item_count))
+    # only the last item is needed after all the ops have run
+    items_live_flags = [False] * item_count
+    items_live_flags[-1] = True
+    ops_live_flags = [False] * len(ops)
+    # walk backwards, propagating liveness from each output to its inputs
+    for i in reversed(range(len(ops))):
+        op = ops[i]
+        out_live = items_live_flags[op.out]
+        # `op` overwrites `op.out`, so the old value is dead unless
+        # `op` also reads it through `op.lhs` or `op.rhs`
+        items_live_flags[op.out] = False
+        items_live_flags[op.lhs] |= out_live
+        items_live_flags[op.rhs] |= out_live
+        ops_live_flags[i] = out_live
+    for op, live_flag in zip(ops, ops_live_flags):
+        if live_flag:
+            yield op
+
+
+def tree_reduction(items, fn=operator.add):
+    """Reduce `items` to a single value by applying `fn` pairwise.
+
+    `fn(lhs, rhs)` is applied in the order given by
+    `tree_reduction_ops`. `items` is copied into a new list first and
+    must not be empty. Returns the final value of the last item.
+    """
+    items = list(items)
+    for op in tree_reduction_ops(len(items)):
+        items[op.out] = fn(items[op.lhs], items[op.rhs])
+    return items[-1]
+
+
+def pop_count(v, *, width=None, process_temporary=lambda v: v):
+    """Return the number of set bits in the low `width` bits of `v`.
+
+    v: an nmigen `Value` or a Python `int`.
+    width: the number of bits to count. For a `Value` it may be omitted
+        and must otherwise equal `len(v)`; for an `int` it is required.
+    process_temporary: called on each intermediate sum; what it returns
+        is used in place of that sum (e.g. to route it through a `Signal`).
+
+    With the default `process_temporary`, an `int` input gives an `int`.
+    A zero-width `Value` gives `Const(0)`.
+    """
+    if isinstance(v, Value):
+        if width is None:
+            width = len(v)
+        assert width == len(v)
+        bits = [v[i] for i in range(width)]
+        if len(bits) == 0:
+            return Const(0)
+    else:
+        assert isinstance(width, int) and width >= 0
+        assert isinstance(v, int)
+        # extract bits as `int`, not `bool`: a 1-bit input is returned
+        # as-is without passing through any addition
+        bits = [(v >> i) & 1 for i in range(width)]
+        if len(bits) == 0:
+            return 0
+    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
+
+
 if __name__ == "__main__":
     print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
           "\n"
diff --git a/src/nmutil/test/test_prefix_sum.py b/src/nmutil/test/test_prefix_sum.py
index 63aa68e..7f1c45a 100644
--- a/src/nmutil/test/test_prefix_sum.py
+++ b/src/nmutil/test/test_prefix_sum.py
@@ -4,11 +4,17 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+from functools import reduce
 from nmutil.formaltest import FHDLTestCase
+from nmutil.sim_util import write_il
 from itertools import accumulate
 import operator
-from nmutil.prefix_sum import prefix_sum, render_prefix_sum_diagram
+from nmutil.prefix_sum import (Op, pop_count, prefix_sum,
+                               render_prefix_sum_diagram,
+                               tree_reduction, tree_reduction_ops)
 import unittest
+from nmigen.hdl.ast import Signal, AnyConst, Assert
+from nmigen.hdl.dsl import Module
 
 
 def reference_prefix_sum(items, fn):
@@ -28,6 +34,138 @@ class TestPrefixSum(FHDLTestCase):
             work_efficient = prefix_sum(input_items, work_efficient=True)
             self.assertEqual(expected, work_efficient)
 
+    def test_tree_reduction_str(self):
+        input_items = ("a", "b", "c", "d", "e", "f", "g", "h", "i")
+        expected = reduce(operator.add, input_items)
+        with self.subTest(expected=repr(expected)):
+            work_efficient = tree_reduction(input_items)
+            self.assertEqual(expected, work_efficient)
+
+    def test_tree_reduction_ops_9(self):
+        ops = list(tree_reduction_ops(9))
+        self.assertEqual(ops, [
+            Op(out=8, lhs=7, rhs=8, row=0),
+            Op(out=6, lhs=5, rhs=6, row=0),
+            Op(out=4, lhs=3, rhs=4, row=0),
+            Op(out=2, lhs=1, rhs=2, row=0),
+            Op(out=8, lhs=6, rhs=8, row=1),
+            Op(out=4, lhs=2, rhs=4, row=1),
+            Op(out=8, lhs=4, rhs=8, row=2),
+            Op(out=8, lhs=0, rhs=8, row=3),
+        ])
+
+    def test_tree_reduction_ops_8(self):
+        ops = list(tree_reduction_ops(8))
+        self.assertEqual(ops, [
+            Op(out=7, lhs=6, rhs=7, row=0),
+            Op(out=5, lhs=4, rhs=5, row=0),
+            Op(out=3, lhs=2, rhs=3, row=0),
+            Op(out=1, lhs=0, rhs=1, row=0),
+            Op(out=7, lhs=5, rhs=7, row=1),
+            Op(out=3, lhs=1, rhs=3, row=1),
+            Op(out=7, lhs=3, rhs=7, row=2),
+        ])
+
+    def tst_pop_count_int(self, width):
+        assert isinstance(width, int)
+        for v in range(1 << width):
+            expected = f"{v:b}".count("1")
+            with self.subTest(v=v, expected=expected):
+                self.assertEqual(expected, pop_count(v, width=width))
+
+    def test_pop_count_int_0(self):
+        self.tst_pop_count_int(0)
+
+    def test_pop_count_int_1(self):
+        self.tst_pop_count_int(1)
+
+    def test_pop_count_int_2(self):
+        self.tst_pop_count_int(2)
+
+    def test_pop_count_int_3(self):
+        self.tst_pop_count_int(3)
+
+    def test_pop_count_int_4(self):
+        self.tst_pop_count_int(4)
+
+    def test_pop_count_int_5(self):
+        self.tst_pop_count_int(5)
+
+    def test_pop_count_int_6(self):
+        self.tst_pop_count_int(6)
+
+    def test_pop_count_int_7(self):
+        self.tst_pop_count_int(7)
+
+    def test_pop_count_int_8(self):
+        self.tst_pop_count_int(8)
+
+    def test_pop_count_int_9(self):
+        self.tst_pop_count_int(9)
+
+    def test_pop_count_int_10(self):
+        self.tst_pop_count_int(10)
+
+    def test_pop_count_int_type(self):
+        # a 1-bit input goes through no additions, it must still be an int
+        for v in range(2):
+            with self.subTest(v=v):
+                self.assertIs(type(pop_count(v, width=1)), int)
+
+    def tst_pop_count_formal(self, width):
+        assert isinstance(width, int)
+        m = Module()
+        v = Signal(width)
+        out = Signal(16)
+
+        def process_temporary(v):
+            sig = Signal.like(v)
+            m.d.comb += sig.eq(v)
+            return sig
+
+        m.d.comb += out.eq(pop_count(v, process_temporary=process_temporary))
+        write_il(self, m, [v, out])
+        m.d.comb += v.eq(AnyConst(width))
+        expected = Signal(16)
+        m.d.comb += expected.eq(reduce(operator.add,
+                                       (v[i] for i in range(width)),
+                                       0))
+        m.d.comb += Assert(out == expected)
+        self.assertFormal(m)
+
+    def test_pop_count_formal_0(self):
+        self.tst_pop_count_formal(0)
+
+    def test_pop_count_formal_1(self):
+        self.tst_pop_count_formal(1)
+
+    def test_pop_count_formal_2(self):
+        self.tst_pop_count_formal(2)
+
+    def test_pop_count_formal_3(self):
+        self.tst_pop_count_formal(3)
+
+    def test_pop_count_formal_4(self):
+        self.tst_pop_count_formal(4)
+
+    def test_pop_count_formal_5(self):
+        self.tst_pop_count_formal(5)
+
+    def test_pop_count_formal_6(self):
+        self.tst_pop_count_formal(6)
+
+    def test_pop_count_formal_7(self):
+        self.tst_pop_count_formal(7)
+
+    def test_pop_count_formal_8(self):
+        self.tst_pop_count_formal(8)
+
+    def test_pop_count_formal_9(self):
+        self.tst_pop_count_formal(9)
+
+    def test_pop_count_formal_10(self):
+        self.tst_pop_count_formal(10)
+
     def test_render_work_efficient(self):
         text = render_prefix_sum_diagram(16, work_efficient=True, plus="@")
         expected = r"""
-- 
2.30.2