def group_by_targets(sl):
groups = []
- for statement_order, statement in enumerate(flat_iteration(sl)):
- targets = list_targets(statement)
-
- chk_groups = [(targets.isdisjoint(g[0]), g) for g in groups]
- merge_groups = [g for dj, g in chk_groups if not dj]
- groups = [g for dj, g in chk_groups if dj]
-
- new_group = (set(targets), [(statement_order, statement)])
-
- for g in merge_groups:
- new_group[0].update(g[0])
- new_group[1].extend(g[1])
-
- groups.append(new_group)
-
- return [(target, _resort_statements(stmts))
- for target, stmts in groups]
+ seen = set()
+ for order, stmt in enumerate(flat_iteration(sl)):
+ targets = set(list_targets(stmt))
+ group = [(order, stmt)]
+ disjoint = targets.isdisjoint(seen)
+ seen |= targets
+ if not disjoint:
+ groups, old_groups = [], groups
+ for old_targets, old_group in old_groups:
+ if targets.isdisjoint(old_targets):
+ groups.append((old_targets, old_group))
+ else:
+ targets |= old_targets
+ group += old_group
+ groups.append((targets, group))
+ return [(targets, _resort_statements(stmts))
+ for targets, stmts in groups]
def list_special_ios(f, ins, outs, inouts):
r = set()