tests: Accept SourceFilters as sources for GTest.
authorGabe Black <gabeblack@google.com>
Sun, 3 Dec 2017 09:20:12 +0000 (01:20 -0800)
committerGabe Black <gabeblack@google.com>
Thu, 7 Dec 2017 10:07:53 +0000 (10:07 +0000)
This change introduces the idea of a SourceFilter which is an object
that can filter a SourceList and which can be composed with other
SourceFilters using | and & operators. This means a filter can be
constructed ahead of time, possibly before all sources have been
discovered, and then later applied to any SourceList necessary.

This change also modifies GTest so that it accepts SourceFilters in
addition to normal source files. These filters will be applied to the
final list of all sources, and the result included in the build for
that test.

By default, gtests will build in all sources tagged with 'gtest lib'.
This change also introduces the keyword argument "skip_lib" which will
exclude those files. They can then be left out entirely, or they can be
re-included as part of a more elaborate filter. That would be useful if
someone wanted to write a unit test for, for instance, the warn, etc.
macros which rely on the gtest logging support. Those classes could
be replaced by something under the control of the unit test, while
still including the rest of the gtest library.

Change-Id: I13a846dc884b86b9fdcaf809edefd57bb4168b8e
Reviewed-on: https://gem5-review.googlesource.com/6262
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>

src/SConscript

index cd42c27c50365e74ce6e26da2837c59a578f3cf2..7cd711693800239d026130c2de9007bb0deadb5a 100755 (executable)
@@ -30,6 +30,7 @@
 
 import array
 import bisect
+import functools
 import imp
 import marshal
 import os
@@ -62,32 +63,68 @@ from m5.util import code_formatter, compareVersions
 # When specifying a source file of some type, a set of tags can be
 # specified for that file.
 
+class SourceFilter(object):
+    def __init__(self, predicate):
+        self.predicate = predicate
+
+    def __or__(self, other):
+        return SourceFilter(lambda tags: self.predicate(tags) or
+                                         other.predicate(tags))
+
+    def __and__(self, other):
+        return SourceFilter(lambda tags: self.predicate(tags) and
+                                         other.predicate(tags))
+
+def with_tags_that(predicate):
+    '''Return a list of sources with tags that satisfy a predicate.'''
+    return SourceFilter(predicate)
+
+def with_any_tags(*tags):
+    '''Return a list of sources with any of the supplied tags.'''
+    return SourceFilter(lambda stags: len(set(tags) & stags) > 0)
+
+def with_all_tags(*tags):
+    '''Return a list of sources with all of the supplied tags.'''
+    return SourceFilter(lambda stags: set(tags) <= stags)
+
+def with_tag(tag):
+    '''Return a list of sources with the supplied tag.'''
+    return SourceFilter(lambda stags: tag in stags)
+
+def without_tags(*tags):
+    '''Return a list of sources without any of the supplied tags.'''
+    return SourceFilter(lambda stags: len(set(tags) & stags) == 0)
+
+def without_tag(tag):
+    '''Return a list of sources with the supplied tag.'''
+    return SourceFilter(lambda stags: tag not in stags)
+
+source_filter_factories = {
+    'with_tags_that': with_tags_that,
+    'with_any_tags': with_any_tags,
+    'with_all_tags': with_all_tags,
+    'with_tag': with_tag,
+    'without_tags': without_tags,
+    'without_tag': without_tag,
+}
+
+Export(source_filter_factories)
+
 class SourceList(list):
-    def with_tags_that(self, predicate):
-        '''Return a list of sources with tags that satisfy a predicate.'''
+    def apply_filter(self, f):
         def match(source):
-            return predicate(source.tags)
+            return f.predicate(source.tags)
         return SourceList(filter(match, self))
 
-    def with_any_tags(self, *tags):
-        '''Return a list of sources with any of the supplied tags.'''
-        return self.with_tags_that(lambda stags: len(set(tags) & stags) > 0)
-
-    def with_all_tags(self, *tags):
-        '''Return a list of sources with all of the supplied tags.'''
-        return self.with_tags_that(lambda stags: set(tags) <= stags)
+    def __getattr__(self, name):
+        func = source_filter_factories.get(name, None)
+        if not func:
+            raise AttributeError
 
-    def with_tag(self, tag):
-        '''Return a list of sources with the supplied tag.'''
-        return self.with_tags_that(lambda stags: tag in stags)
-
-    def without_tags(self, *tags):
-        '''Return a list of sources without any of the supplied tags.'''
-        return self.with_tags_that(lambda stags: len(set(tags) & stags) == 0)
-
-    def without_tag(self, tag):
-        '''Return a list of sources with the supplied tag.'''
-        return self.with_tags_that(lambda stags: tag not in stags)
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            return self.apply_filter(func(*args, **kwargs))
+        return wrapper
 
 class SourceMeta(type):
     '''Meta class for source files that keeps track of all files of a
@@ -294,11 +331,14 @@ class UnitTest(object):
 
 class GTest(UnitTest):
     '''Create a unit test based on the google test framework.'''
-
     all = []
     def __init__(self, *args, **kwargs):
+        isFilter = lambda arg: isinstance(arg, SourceFilter)
+        self.filters = filter(isFilter, args)
+        args = filter(lambda a: not isFilter(a), args)
         super(GTest, self).__init__(*args, **kwargs)
         self.dir = Dir('.')
+        self.skip_lib = kwargs.pop('skip_lib', False)
 
 # Children should have access
 Export('Source')
@@ -1049,9 +1089,14 @@ def makeEnv(env, label, objsfx, strip=False, disable_partial=False, **kwargs):
     gtest_env = new_env.Clone()
     gtest_env.Append(LIBS=gtest_env['GTEST_LIBS'])
     gtest_env.Append(CPPFLAGS=gtest_env['GTEST_CPPFLAGS'])
+    gtestlib_sources = Source.all.with_tag('gtest lib')
     gtests = []
     for test in GTest.all:
-        test_sources = Source.all.with_tag(str(test.target))
+        test_sources = test.sources
+        if not test.skip_lib:
+            test_sources += gtestlib_sources
+        for f in test.filters:
+            test_sources += Source.all.apply_filter(f)
         test_objs = [ s.static(gtest_env) for s in test_sources ]
         gtests.append(gtest_env.Program(
             test.dir.File('%s.%s' % (test.target, label)), test_objs))