From 7aa760d82803909ddef24d203541d54fcf583cbc Mon Sep 17 00:00:00 2001 From: Dmitry Selyutin Date: Fri, 9 Jun 2023 22:20:38 +0300 Subject: [PATCH] insndb/core: support class walking --- src/openpower/insndb/core.py | 52 +++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/src/openpower/insndb/core.py b/src/openpower/insndb/core.py index 6a2054be..05042d2d 100644 --- a/src/openpower/insndb/core.py +++ b/src/openpower/insndb/core.py @@ -4,6 +4,7 @@ import csv as _csv import dataclasses as _dataclasses import enum as _enum import functools as _functools +import inspect as _inspect import os as _os import operator as _operator import pathlib as _pathlib @@ -56,8 +57,21 @@ from openpower.decoder.power_fields import ( from openpower.decoder.pseudo.pagereader import ISA as _ISA +class walkmethod: + def __init__(self, walk): + self.__walk = walk + return super().__init__() + + def __get__(self, instance, owner): + entity = instance + if instance is None: + entity = owner + return _functools.partial(self.__walk, entity) + + class Node: - def walk(self, match=None): + @walkmethod + def walk(clsself, match=None): return () @@ -76,14 +90,20 @@ class DataclassMeta(type): class Dataclass(metaclass=DataclassMeta): - def walk(self, match=None): + @walkmethod + def walk(clsself, match=None): if match is None: match = lambda subnode: True - def subnode(field): - return getattr(self, field.name) + def field_type(field): + return field.type - yield from filter(match, map(subnode, _dataclasses.fields(self))) + def field_value(field): + return getattr(clsself, field.name) + + field = (field_type if isinstance(clsself, type) else field_value) + + yield from filter(match, map(field, _dataclasses.fields(clsself))) class Visitor: @@ -3724,13 +3744,17 @@ class Records(tuple): def __new__(cls, records): return super().__new__(cls, sorted(records)) - def walk(self, match=None): + @walkmethod + def walk(clsself, match=None): if match is None: match = lambda subnode: True - for record in self: - if match(record): - yield record + if isinstance(clsself, type): + yield Record + else: + for record in clsself: + if match(record): + yield record class Database(Node): @@ -3767,12 +3791,16 @@ class Database(Node): return super().__init__() - def walk(self, match=None): + @walkmethod + def walk(clsself, match=None): if match is None: match = lambda subnode: True - if match(self.__db): - yield self.__db + if isinstance(clsself, type): + yield Records + else: + if match(clsself.__db): + yield clsself.__db def __repr__(self): return repr(self.__db) -- 2.30.2