From 8268a10d4827a095f7c776d3424e7b5abc581056 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 9 Mar 2023 00:10:47 -0800 Subject: [PATCH] add experiment for seeing if changing ternlogi to have 7-bit immediate could even work turns out, it probably can. I wrote a script that looks for possible expand_encoded_imm() functions and filled in one I found. --- .../ternlogi_simplification_experiment.mdwn | 56 ++ .../ternlogi_simplification_experiment.py | 576 ++++++++++++++++++ 2 files changed, 632 insertions(+) create mode 100644 openpower/sv/bitmanip/ternlogi_simplification_experiment.mdwn create mode 100755 openpower/sv/bitmanip/ternlogi_simplification_experiment.py diff --git a/openpower/sv/bitmanip/ternlogi_simplification_experiment.mdwn b/openpower/sv/bitmanip/ternlogi_simplification_experiment.mdwn new file mode 100644 index 000000000..140b24eca --- /dev/null +++ b/openpower/sv/bitmanip/ternlogi_simplification_experiment.mdwn @@ -0,0 +1,56 @@ +# experiment to see if changing ternlogi to have a 7-bit immediate would even work + +I created `openpower/sv/bitmanip/ternlogi_simplification_experiment.py` to try +to find a expand_encoded_imm function that works, it should be possible since +we only need to encode 68 distinct cases into 7 bits. + +`ternlogi` would be: + +| 0.5|6.10|11.15|16.20| 21..27|28.30|31| +| -- | -- | --- | --- | ----- | --- |--| +| NN | RT | RA | RB | imm7 | 00 |Rc| + +``` +def lut3(imm8, a, b, c): + idx = c << 2 | b << 1 | a + return imm8[idx] # idx by LSB0 order + +# TODO: independently verify the expand_encoded_imm function I found works, +# ternlogi_simplification_experiment.py could be buggy. + +# found using: python3 ternlogi_simplification_experiment.py generate 24798 +def expand_encoded_imm(imm7): + imm7 &= 0x7F + v0 = imm7 + v1 = v0 >> 0x2 + v3 = v1 ^ v0 + v7 = v1 << 0x4 + v8 = v3 ^ v7 + return v8 & 0xFF + +imm8 <- expand_encoded_imm(imm7) + +for i in range(64): + RT[i] = lut3(imm8, RB[i], RA[i], RT[i]) +``` + +This definition of `expand_encoded_imm` allows covering all possible ternary +logic functions with the following imho reasonable exceptions: +* you may need to permute RA, RB, and RT +* if the function is already implemented by any single existing instruction, +you may need to use the existing instruction instead. +(turns out you probably don't even need to do this for the +`expand_encoded_imm` function I found.) + + I checked against (see `needed_equivalence_classes()`): + * li + * mv + * not + * and + * andc + * nand + * or + * orc + * nor + * xor + * eqv diff --git a/openpower/sv/bitmanip/ternlogi_simplification_experiment.py b/openpower/sv/bitmanip/ternlogi_simplification_experiment.py new file mode 100755 index 000000000..d1a3d610c --- /dev/null +++ b/openpower/sv/bitmanip/ternlogi_simplification_experiment.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 +from collections import defaultdict +from dataclasses import dataclass +from typing import Callable, Dict, Sequence, Iterable, NoReturn, Set +from itertools import permutations +import pprint +import enum +import random +import sys +import unittest + + +def unreachable(v): + # type: (NoReturn) -> NoReturn + raise TypeError("v was expected to be a never type") + + +def lut3(imm8, a, b, c): + # type: (int, int, int, int) -> int + a &= 1 # ensure a, b, c are in-range + b &= 1 + c &= 1 + idx = c << 2 | b << 1 | a + return (imm8 >> idx) & 1 # idx by LSB0 order + + +def ternlogi_8(RT, RA, RB, imm8): + # type: (int, int, int, int) -> int + for i in range(64): + RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i + return RT + + +# found using: python3 ternlogi_simplification_experiment.py generate 24798 +def expand_encoded_imm(imm7): + imm7 &= 0x7F + v0 = imm7 + v1 = v0 >> 0x2 + v3 = v1 ^ v0 + v7 = v1 << 0x4 + v8 = v3 ^ v7 + return v8 & 0xFF + + +def ternlogi_7(RT, RA, RB, imm7): + # type: (int, int, int, int) -> int + imm8 = expand_encoded_imm(imm7) + for i in range(64): + RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i + return RT + + +@dataclass(unsafe_hash=True, order=True, frozen=True) +class TruthTable: + truth_table: "tuple[bool, ...]" + num_vars: int + + def __post_init__(self): + assert len(self.truth_table) == 1 << self.num_vars + + @staticmethod + def all_truth_tables(num_vars): + # type: (int) -> Iterable[TruthTable] + for bits in range(1 << (1 << num_vars)): + yield TruthTable.from_bits(num_vars, bits) + + @property + def bits(self): + # type: () -> int + retval = 0 + for i, v in enumerate(self.truth_table): + if v: + retval |= 1 << i + return retval + + @staticmethod + def from_bits(num_vars, bits): + # type: (int, int) -> TruthTable + return TruthTable.from_fn_with_index( + num_vars, lambda index: bool((bits >> index) & 1)) + + def vars_to_index(self, variables): + # type: (Sequence[bool]) -> int + index = 0 + for i in range(self.num_vars): + if variables[i]: + index |= 1 << i + return index + + @staticmethod + def index_to_vars(num_vars, index): + # type: (int, int) -> tuple[bool, ...] + return tuple(bool((index >> i) & 1) for i in range(num_vars)) + + def __getitem__(self, variables): + # type: (Sequence[bool]) -> bool + return self.truth_table[self.vars_to_index(variables)] + + @staticmethod + def from_fn_with_index(num_vars, fn): + # type: (int, Callable[[int], bool]) -> TruthTable + truth_table = [] # type: list[bool] + for i in range(1 << num_vars): + truth_table.append(fn(i)) + return TruthTable(truth_table=tuple(truth_table), num_vars=num_vars) + + @staticmethod + def from_fn(num_vars, fn): + # type: (int, Callable[[tuple[bool, ...]], bool]) -> TruthTable + return TruthTable.from_fn_with_index( + num_vars, lambda i: fn(TruthTable.index_to_vars(num_vars, i))) + + @staticmethod + def for_const(num_vars, v): + # type: (int, bool) -> TruthTable + return TruthTable.from_fn_with_index(num_vars, lambda _: v) + + @staticmethod + def for_var(num_vars, var_index): + # type: (int, int) -> TruthTable + return TruthTable.from_fn( + num_vars, lambda variables: variables[var_index]) + + def __invert__(self): + # type: () -> TruthTable + return TruthTable(tuple(not v for v in self.truth_table), + num_vars=self.num_vars) + + def __and__(self, rhs): + # type: (TruthTable) -> TruthTable + assert self.num_vars == rhs.num_vars, "mismatched num_vars" + return TruthTable(tuple( + l & r for l, r in zip(self.truth_table, rhs.truth_table)), + num_vars=self.num_vars) + + __rand__ = __and__ + + def __or__(self, rhs): + # type: (TruthTable) -> TruthTable + assert self.num_vars == rhs.num_vars, "mismatched num_vars" + return TruthTable(tuple( + l | r for l, r in zip(self.truth_table, rhs.truth_table)), + num_vars=self.num_vars) + + __ror__ = __or__ + + def __xor__(self, rhs): + # type: (TruthTable) -> TruthTable + assert self.num_vars == rhs.num_vars, "mismatched num_vars" + return TruthTable(tuple( + l ^ r for l, r in zip(self.truth_table, rhs.truth_table)), + num_vars=self.num_vars) + + __rxor__ = __xor__ + + +class TestTruthTable(unittest.TestCase): + def test_all_truth_tables(self): + self.assertEqual(list(TruthTable.all_truth_tables(num_vars=2)), [ + TruthTable(truth_table=(False, False, False, False), num_vars=2), + TruthTable(truth_table=(True, False, False, False), num_vars=2), + TruthTable(truth_table=(False, True, False, False), num_vars=2), + TruthTable(truth_table=(True, True, False, False), num_vars=2), + TruthTable(truth_table=(False, False, True, False), num_vars=2), + TruthTable(truth_table=(True, False, True, False), num_vars=2), + TruthTable(truth_table=(False, True, True, False), num_vars=2), + TruthTable(truth_table=(True, True, True, False), num_vars=2), + TruthTable(truth_table=(False, False, False, True), num_vars=2), + TruthTable(truth_table=(True, False, False, True), num_vars=2), + TruthTable(truth_table=(False, True, False, True), num_vars=2), + TruthTable(truth_table=(True, True, False, True), num_vars=2), + TruthTable(truth_table=(False, False, True, True), num_vars=2), + TruthTable(truth_table=(True, False, True, True), num_vars=2), + TruthTable(truth_table=(False, True, True, True), num_vars=2), + TruthTable(truth_table=(True, True, True, True), num_vars=2), + ]) + + def test_bits(self): + bits = [i.bits for i in TruthTable.all_truth_tables(num_vars=2)] + expected = list(range(16)) + self.assertEqual(bits, expected) + + def test_const(self): + self.assertEqual(TruthTable.for_const(3, False), TruthTable( + truth_table=(False,) * 8, num_vars=3)) + self.assertEqual(TruthTable.for_const(3, True), TruthTable( + truth_table=(True,) * 8, num_vars=3)) + + def test_var(self): + self.assertEqual(TruthTable.for_var(3, 0), TruthTable( + truth_table=(False, True) * 4, num_vars=3)) + self.assertEqual(TruthTable.for_var(3, 1), TruthTable( + truth_table=(False, False, True, True) * 2, num_vars=3)) + self.assertEqual(TruthTable.for_var(3, 2), TruthTable( + truth_table=(False,) * 4 + (True,) * 4, num_vars=3)) + + def test_invert(self): + self.assertEqual(~TruthTable.for_var(3, 0), TruthTable( + truth_table=(True, False) * 4, num_vars=3)) + + def test_and(self): + self.assertEqual( + TruthTable.for_var(3, 0) & TruthTable.for_var(3, 1), + TruthTable(truth_table=(False, False, False, True) * 2, + num_vars=3)) + + def test_or(self): + self.assertEqual( + TruthTable.for_var(3, 0) | TruthTable.for_var(3, 1), + TruthTable(truth_table=(False, True, True, True) * 2, + num_vars=3)) + + def test_xor(self): + self.assertEqual( + TruthTable.for_var(3, 0) ^ TruthTable.for_var(3, 1), + TruthTable(truth_table=(False, True, True, False) * 2, + num_vars=3)) + + +_VARS_PERMS_SET = Set["tuple[int, ...]"] +_VARS_PERMS = Dict[TruthTable, _VARS_PERMS_SET] + + +@dataclass(unsafe_hash=True, frozen=True) +class EquivalenceClass: + """ an equivalence class of all truth tables where variables can be + arbitrarily permuted. + """ + truth_tables: "frozenset[TruthTable]" + + @staticmethod + def get_with_vars_perms(truth_table): + # type: (TruthTable) -> tuple[EquivalenceClass, _VARS_PERMS] + truth_tables = set() # type: set[TruthTable] + vars_perms = defaultdict(set) # type: _VARS_PERMS + iter_ = EquivalenceClass.permuted_truth_tables(truth_table) + for tt, vars_perm in iter_: + truth_tables.add(tt) + vars_perms[tt].add(vars_perm) + return EquivalenceClass(frozenset(truth_tables)), vars_perms + + @staticmethod + def get(truth_table): + # type: (TruthTable) -> EquivalenceClass + retval, _ = EquivalenceClass.get_with_vars_perms(truth_table) + return retval + + @staticmethod + def permuted_truth_tables(truth_table): + # type: (TruthTable) -> Iterable[tuple[TruthTable, tuple[int, ...]]] + for vars_perm in permutations(range(truth_table.num_vars)): + def fn(variables): + # type: (Sequence[bool]) -> bool + return truth_table[[variables[i] for i in vars_perm]] + perm_tt = TruthTable.from_fn(truth_table.num_vars, fn) + yield (perm_tt, vars_perm) + + @property + def representative_truth_table(self): + # type: () -> TruthTable + return min(self.truth_tables) + + def __repr__(self): + # type: () -> str + return f"EquivalenceClass({self.representative_truth_table!r})" + + def __lt__(self, rhs): + # type: (EquivalenceClass) -> bool + return self.representative_truth_table < rhs.representative_truth_table + + def __gt__(self, rhs): + # type: (EquivalenceClass) -> bool + return self.representative_truth_table > rhs.representative_truth_table + + def __le__(self, rhs): + # type: (EquivalenceClass) -> bool + return self.representative_truth_table <= rhs.representative_truth_table + + def __ge__(self, rhs): + # type: (EquivalenceClass) -> bool + return self.representative_truth_table >= rhs.representative_truth_table + + +class TestEquivalenceClass(unittest.TestCase): + maxDiff = None + + def test_get_with_vars_perms(self): + truth_table = TruthTable.from_bits(2, 0x4) + eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table) + self.assertEqual(eqv_cls.truth_tables, frozenset(vars_perms.keys())) + self.assertEqual(dict(vars_perms), { + TruthTable(truth_table=(False, False, True, False), + num_vars=2): {(0, 1)}, + TruthTable(truth_table=(False, True, False, False), + num_vars=2): {(1, 0)}, + }) + for tt in TruthTable.all_truth_tables(3): + with self.subTest(tt=repr(tt)): + eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(tt) + self.assertEqual(eqv_cls.truth_tables, + frozenset(vars_perms.keys())) + expected_vars_perms = defaultdict(set) # type: _VARS_PERMS + expected_vars_perms[tt].add((0, 1, 2)) + perm = TruthTable.from_fn(3, lambda v: tt[(v[0], v[2], v[1])]) + expected_vars_perms[perm].add((0, 2, 1)) + perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[0], v[2])]) + expected_vars_perms[perm].add((1, 0, 2)) + perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[2], v[0])]) + expected_vars_perms[perm].add((1, 2, 0)) + perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[0], v[1])]) + expected_vars_perms[perm].add((2, 0, 1)) + perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[1], v[0])]) + expected_vars_perms[perm].add((2, 1, 0)) + l = dict(vars_perms) + r = dict(expected_vars_perms) + if l != r: + l = pprint.pformat(l) + r = pprint.pformat(r) + self.fail(f"vars_perms not as expected:\nvars_perms:\n{l}" + f"\nexpected_vars_perms:\n{r}") + + +NUM_VARS = 3 + + +def needed_equivalence_classes(): + # type: () -> set[EquivalenceClass] + ZERO = TruthTable.for_const(NUM_VARS, False) + ONE = TruthTable.for_const(NUM_VARS, True) + A = TruthTable.for_var(NUM_VARS, 0) + B = TruthTable.for_var(NUM_VARS, 1) + equivalence_classes = set() # type: set[EquivalenceClass] + for truth_table in TruthTable.all_truth_tables(NUM_VARS): + equivalence_classes.add(EquivalenceClass.get(truth_table)) + + # remove equivalence classes that are already covered by + # existing PowerISA instructions: + + # li + equivalence_classes.remove(EquivalenceClass.get(ZERO)) + equivalence_classes.remove(EquivalenceClass.get(ONE)) + + # mv + equivalence_classes.remove(EquivalenceClass.get(A)) + + # not + equivalence_classes.remove(EquivalenceClass.get(~A)) + + # and + equivalence_classes.remove(EquivalenceClass.get(A & B)) + + # andc + equivalence_classes.remove(EquivalenceClass.get(A & ~B)) + + # nand + equivalence_classes.remove(EquivalenceClass.get(~(A & B))) + + # or + equivalence_classes.remove(EquivalenceClass.get(A | B)) + + # orc + equivalence_classes.remove(EquivalenceClass.get(A | ~B)) + + # nor + equivalence_classes.remove(EquivalenceClass.get(~(A | B))) + + # xor + equivalence_classes.remove(EquivalenceClass.get(A ^ B)) + + # eqv + equivalence_classes.remove(EquivalenceClass.get(~(A ^ B))) + + return equivalence_classes + + +NEEDED_EQUIVALENCE_CLASSES = frozenset(needed_equivalence_classes()) + + +class MissingNeededEquivalenceClasses(Exception): + pass + + +@dataclass(frozen=True, unsafe_hash=True) +class Imm7AndVarsPerm: + imm7: int + vars_perm: "tuple[int, ...]" + + def __repr__(self): + # type: () -> str + return (f"Imm7AndVarsPerm(imm7={self.imm7:#04x}, " + f"vars_perm={self.vars_perm})") + + +def verify_covers_needed_equivalence_classes(expand_encoded_imm): + # type: (Callable[[int], int]) -> str + equivalence_classes = set(NEEDED_EQUIVALENCE_CLASSES) + imm8_imm7_map = defaultdict(set) # type: dict[int, set[Imm7AndVarsPerm]] + log_lines = [] # type: list[str] + for imm7 in range(1 << 7): + imm8 = expand_encoded_imm(imm7) + truth_table = TruthTable.from_bits(num_vars=NUM_VARS, bits=imm8) + log_lines.append( + f"imm7={hex(imm7)} imm8={hex(imm8)} truth_table={truth_table}") + eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table) + if eqv_cls in equivalence_classes: + log_lines.append(f"new equivalence class: {eqv_cls}") + equivalence_classes.remove(eqv_cls) + for truth_table, vars_perms_set in vars_perms.items(): + imm8 = truth_table.bits + for vars_perm in vars_perms_set: + imm8_imm7_map[imm8].add(Imm7AndVarsPerm( + imm7=imm7, vars_perm=vars_perm)) + + for imm8 in range(1 << 8): + imm8_imm7_map[imm8] # fill in empty set defaults + + log_lines.append("imm8_imm7_map:") + log_lines.append(pprint.pformat(dict(imm8_imm7_map))) + log_lines.append("equivalence classes we failed to cover:") + log_lines.append(pprint.pformat(sorted(equivalence_classes))) + if len(equivalence_classes) != 0: + log_lines.append("failed to cover all equivalence classes") + raise MissingNeededEquivalenceClasses("\n".join(log_lines)) + return "\n".join(log_lines) + + +@enum.unique +class OpKind(enum.Enum): + Shl = "<<" + Shr = ">>" + And = "&" + Or = "|" + Xor = "^" + Input = "imm7" + + @property + def is_shift(self): + return self is OpKind.Shl or self is OpKind.Shr + + +@dataclass(unsafe_hash=True, frozen=True) +class Operation: + ssa_reg: int + kind: OpKind + lhs: "Operation | int" + rhs: "Operation | int" + + @staticmethod + def make_random(op_count, rand): + # type: (int, random.Random) -> Operation + inp = Operation(ssa_reg=0, kind=OpKind.Input, lhs=0, rhs=0) + values = [None, inp] # type: list[Operation | None] + kinds = tuple(OpKind) + + def make_arg(): + # type: () -> Operation | int + arg = rand.choice(values) + if arg is None: + return rand.randrange(0, 1 << 8) + return arg + for ssa_reg in range(1, op_count): + kind = rand.choice(kinds) + lhs = make_arg() + if kind.is_shift: + rhs = rand.randint(1, 7) + else: + rhs = make_arg() + values.append(Operation(ssa_reg=ssa_reg, kind=kind, + lhs=lhs, rhs=rhs)) + retval = values[-1] + assert retval is not None + return retval + + def to_python(self): + # type: () -> str + lines = [ + "def expand_encoded_imm(imm7):", + " imm7 &= 0x7F", + ] + ssa_regs = set() # type: set[int] + + def visit(op): + # type: (Operation | int) -> str + if isinstance(op, int): + return hex(op) + reg = f"v{op.ssa_reg}" + if op.ssa_reg in ssa_regs: + return reg + ssa_regs.add(op.ssa_reg) + lhs = visit(op.lhs) + rhs = visit(op.rhs) + if op.kind is OpKind.Shl: + lines.append(f" {reg} = {lhs} << {rhs}") + elif op.kind is OpKind.Shr: + lines.append(f" {reg} = {lhs} >> {rhs}") + elif op.kind is OpKind.And: + lines.append(f" {reg} = {lhs} & {rhs}") + elif op.kind is OpKind.Or: + lines.append(f" {reg} = {lhs} | {rhs}") + elif op.kind is OpKind.Xor: + lines.append(f" {reg} = {lhs} ^ {rhs}") + elif op.kind is OpKind.Input: + lines.append(f" {reg} = imm7") + else: + unreachable(op.kind) + return reg + final_reg = visit(self) + lines.append(f" return {final_reg} & 0xFF") + return "\n".join(lines) + + def __repr__(self): + python_str = self.to_python() + return f"Operation{{\n{python_str}\n}}" + + +@enum.unique +class Cmd(enum.Enum): + Generate = "generate" + Check = "check" + UnitTest = "unittest" + Help = "--help" + + +def parse_command(): + cmd = "" + if len(sys.argv) > 1: + cmd = sys.argv[1] + try: + retval = Cmd(cmd) + if retval is not Cmd.Help: + return retval + except ValueError: + pass + cmds = "|".join(i.value for i in Cmd) + print(f"usage: {sys.argv[0]} {cmds}", file=sys.stderr) + exit(1) + + +def main(): + cmd = parse_command() + if cmd is Cmd.UnitTest: + del sys.argv[1] + unittest.main() + return + print("equivalence classes we need to cover:") + pprint.pprint(sorted(NEEDED_EQUIVALENCE_CLASSES)) + print(flush=True) + if cmd is Cmd.Check: + log_str = verify_covers_needed_equivalence_classes( + expand_encoded_imm=expand_encoded_imm) + print(log_str) + elif cmd is Cmd.Generate: + rand = random.Random() + seed = sys.argv[-1] + rand.seed(seed, version=2) # seed for reproducibility + for i in range(100000): + if i % 1000 == 0: + print(f"try #{i} seed={seed!r}", flush=True) + op_count = rand.randint(1, 10) + python_str = Operation.make_random(op_count, rand).to_python() + try: + globals_ = {} + exec(python_str, globals_) + fn = globals_["expand_encoded_imm"] + verify_covers_needed_equivalence_classes(expand_encoded_imm=fn) + except MissingNeededEquivalenceClasses: + continue + success_str = f"# found working function:\nseed={seed!r}\n\n{python_str}" + print(success_str) + with open("found_working_function.txt", "xt") as f: + print(success_str, file=f) + return + + +if __name__ == "__main__": + main() -- 2.30.2