+++ /dev/null
-#!/usr/bin/env python3
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import Callable, Dict, Sequence, Iterable, NoReturn, Set
-from itertools import permutations
-import pprint
-import enum
-import random
-import sys
-import unittest
-
-
-def unreachable(v):
- # type: (NoReturn) -> NoReturn
- raise TypeError("v was expected to be a never type")
-
-
-def lut3(imm8, a, b, c):
- # type: (int, int, int, int) -> int
- a &= 1 # ensure a, b, c are in-range
- b &= 1
- c &= 1
- idx = c << 2 | b << 1 | a
- return (imm8 >> idx) & 1 # idx by LSB0 order
-
-
-def ternlogi_8(RT, RA, RB, imm8):
- # type: (int, int, int, int) -> int
- for i in range(64):
- RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i
- return RT
-
-
-# found using: python3 ternlogi_simplification_experiment.py generate 24798
-def expand_encoded_imm(imm7):
- imm7 &= 0x7F
- v0 = imm7
- v1 = v0 >> 0x2
- v3 = v1 ^ v0
- v7 = v1 << 0x4
- v8 = v3 ^ v7
- return v8 & 0xFF
-
-
-def ternlogi_7(RT, RA, RB, imm7):
- # type: (int, int, int, int) -> int
- imm8 = expand_encoded_imm(imm7)
- for i in range(64):
- RT |= lut3(imm8, (RB >> i) & 1, (RA >> i) & 1, (RT >> i) & 1) << i
- return RT
-
-
-@dataclass(unsafe_hash=True, order=True, frozen=True)
-class TruthTable:
- truth_table: "tuple[bool, ...]"
- num_vars: int
-
- def __post_init__(self):
- assert len(self.truth_table) == 1 << self.num_vars
-
- @staticmethod
- def all_truth_tables(num_vars):
- # type: (int) -> Iterable[TruthTable]
- for bits in range(1 << (1 << num_vars)):
- yield TruthTable.from_bits(num_vars, bits)
-
- @property
- def bits(self):
- # type: () -> int
- retval = 0
- for i, v in enumerate(self.truth_table):
- if v:
- retval |= 1 << i
- return retval
-
- @staticmethod
- def from_bits(num_vars, bits):
- # type: (int, int) -> TruthTable
- return TruthTable.from_fn_with_index(
- num_vars, lambda index: bool((bits >> index) & 1))
-
- def vars_to_index(self, variables):
- # type: (Sequence[bool]) -> int
- index = 0
- for i in range(self.num_vars):
- if variables[i]:
- index |= 1 << i
- return index
-
- @staticmethod
- def index_to_vars(num_vars, index):
- # type: (int, int) -> tuple[bool, ...]
- return tuple(bool((index >> i) & 1) for i in range(num_vars))
-
- def __getitem__(self, variables):
- # type: (Sequence[bool]) -> bool
- return self.truth_table[self.vars_to_index(variables)]
-
- @staticmethod
- def from_fn_with_index(num_vars, fn):
- # type: (int, Callable[[int], bool]) -> TruthTable
- truth_table = [] # type: list[bool]
- for i in range(1 << num_vars):
- truth_table.append(fn(i))
- return TruthTable(truth_table=tuple(truth_table), num_vars=num_vars)
-
- @staticmethod
- def from_fn(num_vars, fn):
- # type: (int, Callable[[tuple[bool, ...]], bool]) -> TruthTable
- return TruthTable.from_fn_with_index(
- num_vars, lambda i: fn(TruthTable.index_to_vars(num_vars, i)))
-
- @staticmethod
- def for_const(num_vars, v):
- # type: (int, bool) -> TruthTable
- return TruthTable.from_fn_with_index(num_vars, lambda _: v)
-
- @staticmethod
- def for_var(num_vars, var_index):
- # type: (int, int) -> TruthTable
- return TruthTable.from_fn(
- num_vars, lambda variables: variables[var_index])
-
- def __invert__(self):
- # type: () -> TruthTable
- return TruthTable(tuple(not v for v in self.truth_table),
- num_vars=self.num_vars)
-
- def __and__(self, rhs):
- # type: (TruthTable) -> TruthTable
- assert self.num_vars == rhs.num_vars, "mismatched num_vars"
- return TruthTable(tuple(
- l & r for l, r in zip(self.truth_table, rhs.truth_table)),
- num_vars=self.num_vars)
-
- __rand__ = __and__
-
- def __or__(self, rhs):
- # type: (TruthTable) -> TruthTable
- assert self.num_vars == rhs.num_vars, "mismatched num_vars"
- return TruthTable(tuple(
- l | r for l, r in zip(self.truth_table, rhs.truth_table)),
- num_vars=self.num_vars)
-
- __ror__ = __or__
-
- def __xor__(self, rhs):
- # type: (TruthTable) -> TruthTable
- assert self.num_vars == rhs.num_vars, "mismatched num_vars"
- return TruthTable(tuple(
- l ^ r for l, r in zip(self.truth_table, rhs.truth_table)),
- num_vars=self.num_vars)
-
- __rxor__ = __xor__
-
-
-class TestTruthTable(unittest.TestCase):
- def test_all_truth_tables(self):
- self.assertEqual(list(TruthTable.all_truth_tables(num_vars=2)), [
- TruthTable(truth_table=(False, False, False, False), num_vars=2),
- TruthTable(truth_table=(True, False, False, False), num_vars=2),
- TruthTable(truth_table=(False, True, False, False), num_vars=2),
- TruthTable(truth_table=(True, True, False, False), num_vars=2),
- TruthTable(truth_table=(False, False, True, False), num_vars=2),
- TruthTable(truth_table=(True, False, True, False), num_vars=2),
- TruthTable(truth_table=(False, True, True, False), num_vars=2),
- TruthTable(truth_table=(True, True, True, False), num_vars=2),
- TruthTable(truth_table=(False, False, False, True), num_vars=2),
- TruthTable(truth_table=(True, False, False, True), num_vars=2),
- TruthTable(truth_table=(False, True, False, True), num_vars=2),
- TruthTable(truth_table=(True, True, False, True), num_vars=2),
- TruthTable(truth_table=(False, False, True, True), num_vars=2),
- TruthTable(truth_table=(True, False, True, True), num_vars=2),
- TruthTable(truth_table=(False, True, True, True), num_vars=2),
- TruthTable(truth_table=(True, True, True, True), num_vars=2),
- ])
-
- def test_bits(self):
- bits = [i.bits for i in TruthTable.all_truth_tables(num_vars=2)]
- expected = list(range(16))
- self.assertEqual(bits, expected)
-
- def test_const(self):
- self.assertEqual(TruthTable.for_const(3, False), TruthTable(
- truth_table=(False,) * 8, num_vars=3))
- self.assertEqual(TruthTable.for_const(3, True), TruthTable(
- truth_table=(True,) * 8, num_vars=3))
-
- def test_var(self):
- self.assertEqual(TruthTable.for_var(3, 0), TruthTable(
- truth_table=(False, True) * 4, num_vars=3))
- self.assertEqual(TruthTable.for_var(3, 1), TruthTable(
- truth_table=(False, False, True, True) * 2, num_vars=3))
- self.assertEqual(TruthTable.for_var(3, 2), TruthTable(
- truth_table=(False,) * 4 + (True,) * 4, num_vars=3))
-
- def test_invert(self):
- self.assertEqual(~TruthTable.for_var(3, 0), TruthTable(
- truth_table=(True, False) * 4, num_vars=3))
-
- def test_and(self):
- self.assertEqual(
- TruthTable.for_var(3, 0) & TruthTable.for_var(3, 1),
- TruthTable(truth_table=(False, False, False, True) * 2,
- num_vars=3))
-
- def test_or(self):
- self.assertEqual(
- TruthTable.for_var(3, 0) | TruthTable.for_var(3, 1),
- TruthTable(truth_table=(False, True, True, True) * 2,
- num_vars=3))
-
- def test_xor(self):
- self.assertEqual(
- TruthTable.for_var(3, 0) ^ TruthTable.for_var(3, 1),
- TruthTable(truth_table=(False, True, True, False) * 2,
- num_vars=3))
-
-
-_VARS_PERMS_SET = Set["tuple[int, ...]"]
-_VARS_PERMS = Dict[TruthTable, _VARS_PERMS_SET]
-
-
-@dataclass(unsafe_hash=True, frozen=True)
-class EquivalenceClass:
- """ an equivalence class of all truth tables where variables can be
- arbitrarily permuted.
- """
- truth_tables: "frozenset[TruthTable]"
-
- @staticmethod
- def get_with_vars_perms(truth_table):
- # type: (TruthTable) -> tuple[EquivalenceClass, _VARS_PERMS]
- truth_tables = set() # type: set[TruthTable]
- vars_perms = defaultdict(set) # type: _VARS_PERMS
- iter_ = EquivalenceClass.permuted_truth_tables(truth_table)
- for tt, vars_perm in iter_:
- truth_tables.add(tt)
- vars_perms[tt].add(vars_perm)
- return EquivalenceClass(frozenset(truth_tables)), vars_perms
-
- @staticmethod
- def get(truth_table):
- # type: (TruthTable) -> EquivalenceClass
- retval, _ = EquivalenceClass.get_with_vars_perms(truth_table)
- return retval
-
- @staticmethod
- def permuted_truth_tables(truth_table):
- # type: (TruthTable) -> Iterable[tuple[TruthTable, tuple[int, ...]]]
- for vars_perm in permutations(range(truth_table.num_vars)):
- def fn(variables):
- # type: (Sequence[bool]) -> bool
- return truth_table[[variables[i] for i in vars_perm]]
- perm_tt = TruthTable.from_fn(truth_table.num_vars, fn)
- yield (perm_tt, vars_perm)
-
- @property
- def representative_truth_table(self):
- # type: () -> TruthTable
- return min(self.truth_tables)
-
- def __repr__(self):
- # type: () -> str
- return f"EquivalenceClass({self.representative_truth_table!r})"
-
- def __lt__(self, rhs):
- # type: (EquivalenceClass) -> bool
- return self.representative_truth_table < rhs.representative_truth_table
-
- def __gt__(self, rhs):
- # type: (EquivalenceClass) -> bool
- return self.representative_truth_table > rhs.representative_truth_table
-
- def __le__(self, rhs):
- # type: (EquivalenceClass) -> bool
- return self.representative_truth_table <= rhs.representative_truth_table
-
- def __ge__(self, rhs):
- # type: (EquivalenceClass) -> bool
- return self.representative_truth_table >= rhs.representative_truth_table
-
-
-class TestEquivalenceClass(unittest.TestCase):
- maxDiff = None
-
- def test_get_with_vars_perms(self):
- truth_table = TruthTable.from_bits(2, 0x4)
- eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table)
- self.assertEqual(eqv_cls.truth_tables, frozenset(vars_perms.keys()))
- self.assertEqual(dict(vars_perms), {
- TruthTable(truth_table=(False, False, True, False),
- num_vars=2): {(0, 1)},
- TruthTable(truth_table=(False, True, False, False),
- num_vars=2): {(1, 0)},
- })
- for tt in TruthTable.all_truth_tables(3):
- with self.subTest(tt=repr(tt)):
- eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(tt)
- self.assertEqual(eqv_cls.truth_tables,
- frozenset(vars_perms.keys()))
- expected_vars_perms = defaultdict(set) # type: _VARS_PERMS
- expected_vars_perms[tt].add((0, 1, 2))
- perm = TruthTable.from_fn(3, lambda v: tt[(v[0], v[2], v[1])])
- expected_vars_perms[perm].add((0, 2, 1))
- perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[0], v[2])])
- expected_vars_perms[perm].add((1, 0, 2))
- perm = TruthTable.from_fn(3, lambda v: tt[(v[1], v[2], v[0])])
- expected_vars_perms[perm].add((1, 2, 0))
- perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[0], v[1])])
- expected_vars_perms[perm].add((2, 0, 1))
- perm = TruthTable.from_fn(3, lambda v: tt[(v[2], v[1], v[0])])
- expected_vars_perms[perm].add((2, 1, 0))
- l = dict(vars_perms)
- r = dict(expected_vars_perms)
- if l != r:
- l = pprint.pformat(l)
- r = pprint.pformat(r)
- self.fail(f"vars_perms not as expected:\nvars_perms:\n{l}"
- f"\nexpected_vars_perms:\n{r}")
-
-
-NUM_VARS = 3
-
-
-def needed_equivalence_classes():
- # type: () -> set[EquivalenceClass]
- ZERO = TruthTable.for_const(NUM_VARS, False)
- ONE = TruthTable.for_const(NUM_VARS, True)
- A = TruthTable.for_var(NUM_VARS, 0)
- B = TruthTable.for_var(NUM_VARS, 1)
- equivalence_classes = set() # type: set[EquivalenceClass]
- for truth_table in TruthTable.all_truth_tables(NUM_VARS):
- equivalence_classes.add(EquivalenceClass.get(truth_table))
-
- # remove equivalence classes that are already covered by
- # existing PowerISA instructions:
-
- # li
- equivalence_classes.remove(EquivalenceClass.get(ZERO))
- equivalence_classes.remove(EquivalenceClass.get(ONE))
-
- # mv
- equivalence_classes.remove(EquivalenceClass.get(A))
-
- # not
- equivalence_classes.remove(EquivalenceClass.get(~A))
-
- # and
- equivalence_classes.remove(EquivalenceClass.get(A & B))
-
- # andc
- equivalence_classes.remove(EquivalenceClass.get(A & ~B))
-
- # nand
- equivalence_classes.remove(EquivalenceClass.get(~(A & B)))
-
- # or
- equivalence_classes.remove(EquivalenceClass.get(A | B))
-
- # orc
- equivalence_classes.remove(EquivalenceClass.get(A | ~B))
-
- # nor
- equivalence_classes.remove(EquivalenceClass.get(~(A | B)))
-
- # xor
- equivalence_classes.remove(EquivalenceClass.get(A ^ B))
-
- # eqv
- equivalence_classes.remove(EquivalenceClass.get(~(A ^ B)))
-
- return equivalence_classes
-
-
-NEEDED_EQUIVALENCE_CLASSES = frozenset(needed_equivalence_classes())
-
-
-class MissingNeededEquivalenceClasses(Exception):
- pass
-
-
-@dataclass(frozen=True, unsafe_hash=True)
-class Imm7AndVarsPerm:
- imm7: int
- vars_perm: "tuple[int, ...]"
-
- def __repr__(self):
- # type: () -> str
- return (f"Imm7AndVarsPerm(imm7={self.imm7:#04x}, "
- f"vars_perm={self.vars_perm})")
-
-
-def verify_covers_needed_equivalence_classes(expand_encoded_imm):
- # type: (Callable[[int], int]) -> str
- equivalence_classes = set(NEEDED_EQUIVALENCE_CLASSES)
- imm8_imm7_map = defaultdict(set) # type: dict[int, set[Imm7AndVarsPerm]]
- log_lines = [] # type: list[str]
- for imm7 in range(1 << 7):
- imm8 = expand_encoded_imm(imm7)
- truth_table = TruthTable.from_bits(num_vars=NUM_VARS, bits=imm8)
- log_lines.append(
- f"imm7={hex(imm7)} imm8={hex(imm8)} truth_table={truth_table}")
- eqv_cls, vars_perms = EquivalenceClass.get_with_vars_perms(truth_table)
- if eqv_cls in equivalence_classes:
- log_lines.append(f"new equivalence class: {eqv_cls}")
- equivalence_classes.remove(eqv_cls)
- for truth_table, vars_perms_set in vars_perms.items():
- imm8 = truth_table.bits
- for vars_perm in vars_perms_set:
- imm8_imm7_map[imm8].add(Imm7AndVarsPerm(
- imm7=imm7, vars_perm=vars_perm))
-
- for imm8 in range(1 << 8):
- imm8_imm7_map[imm8] # fill in empty set defaults
-
- log_lines.append("imm8_imm7_map:")
- log_lines.append(pprint.pformat(dict(imm8_imm7_map)))
- log_lines.append("equivalence classes we failed to cover:")
- log_lines.append(pprint.pformat(sorted(equivalence_classes)))
- if len(equivalence_classes) != 0:
- log_lines.append("failed to cover all equivalence classes")
- raise MissingNeededEquivalenceClasses("\n".join(log_lines))
- return "\n".join(log_lines)
-
-
-@enum.unique
-class OpKind(enum.Enum):
- Shl = "<<"
- Shr = ">>"
- And = "&"
- Or = "|"
- Xor = "^"
- Input = "imm7"
-
- @property
- def is_shift(self):
- return self is OpKind.Shl or self is OpKind.Shr
-
-
-@dataclass(unsafe_hash=True, frozen=True)
-class Operation:
- ssa_reg: int
- kind: OpKind
- lhs: "Operation | int"
- rhs: "Operation | int"
-
- @staticmethod
- def make_random(op_count, rand):
- # type: (int, random.Random) -> Operation
- inp = Operation(ssa_reg=0, kind=OpKind.Input, lhs=0, rhs=0)
- values = [None, inp] # type: list[Operation | None]
- kinds = tuple(OpKind)
-
- def make_arg():
- # type: () -> Operation | int
- arg = rand.choice(values)
- if arg is None:
- return rand.randrange(0, 1 << 8)
- return arg
- for ssa_reg in range(1, op_count):
- kind = rand.choice(kinds)
- lhs = make_arg()
- if kind.is_shift:
- rhs = rand.randint(1, 7)
- else:
- rhs = make_arg()
- values.append(Operation(ssa_reg=ssa_reg, kind=kind,
- lhs=lhs, rhs=rhs))
- retval = values[-1]
- assert retval is not None
- return retval
-
- def to_python(self):
- # type: () -> str
- lines = [
- "def expand_encoded_imm(imm7):",
- " imm7 &= 0x7F",
- ]
- ssa_regs = set() # type: set[int]
-
- def visit(op):
- # type: (Operation | int) -> str
- if isinstance(op, int):
- return hex(op)
- reg = f"v{op.ssa_reg}"
- if op.ssa_reg in ssa_regs:
- return reg
- ssa_regs.add(op.ssa_reg)
- lhs = visit(op.lhs)
- rhs = visit(op.rhs)
- if op.kind is OpKind.Shl:
- lines.append(f" {reg} = {lhs} << {rhs}")
- elif op.kind is OpKind.Shr:
- lines.append(f" {reg} = {lhs} >> {rhs}")
- elif op.kind is OpKind.And:
- lines.append(f" {reg} = {lhs} & {rhs}")
- elif op.kind is OpKind.Or:
- lines.append(f" {reg} = {lhs} | {rhs}")
- elif op.kind is OpKind.Xor:
- lines.append(f" {reg} = {lhs} ^ {rhs}")
- elif op.kind is OpKind.Input:
- lines.append(f" {reg} = imm7")
- else:
- unreachable(op.kind)
- return reg
- final_reg = visit(self)
- lines.append(f" return {final_reg} & 0xFF")
- return "\n".join(lines)
-
- def __repr__(self):
- python_str = self.to_python()
- return f"Operation{{\n{python_str}\n}}"
-
-
-@enum.unique
-class Cmd(enum.Enum):
- Generate = "generate"
- Check = "check"
- UnitTest = "unittest"
- Help = "--help"
-
-
-def parse_command():
- cmd = ""
- if len(sys.argv) > 1:
- cmd = sys.argv[1]
- try:
- retval = Cmd(cmd)
- if retval is not Cmd.Help:
- return retval
- except ValueError:
- pass
- cmds = "|".join(i.value for i in Cmd)
- print(f"usage: {sys.argv[0]} {cmds}", file=sys.stderr)
- exit(1)
-
-
-def main():
- cmd = parse_command()
- if cmd is Cmd.UnitTest:
- del sys.argv[1]
- unittest.main()
- return
- print("equivalence classes we need to cover:")
- pprint.pprint(sorted(NEEDED_EQUIVALENCE_CLASSES))
- print(flush=True)
- if cmd is Cmd.Check:
- log_str = verify_covers_needed_equivalence_classes(
- expand_encoded_imm=expand_encoded_imm)
- print(log_str)
- elif cmd is Cmd.Generate:
- rand = random.Random()
- seed = sys.argv[-1]
- rand.seed(seed, version=2) # seed for reproducibility
- for i in range(100000):
- if i % 1000 == 0:
- print(f"try #{i} seed={seed!r}", flush=True)
- op_count = rand.randint(1, 10)
- python_str = Operation.make_random(op_count, rand).to_python()
- try:
- globals_ = {}
- exec(python_str, globals_)
- fn = globals_["expand_encoded_imm"]
- verify_covers_needed_equivalence_classes(expand_encoded_imm=fn)
- except MissingNeededEquivalenceClasses:
- continue
- success_str = f"# found working function:\nseed={seed!r}\n\n{python_str}"
- print(success_str)
- with open("found_working_function.txt", "xt") as f:
- print(success_str, file=f)
- return
-
-
-if __name__ == "__main__":
- main()