1from itertools import chain

2from typing import Callable, List, Sequence, Union

3from collections import defaultdict

5import sympy as sp

7from pystencils.assignment import Assignment

8from pystencils.astnodes import Node

9from pystencils.field import AbstractField, Field

10from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect

def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs.

    Args:
        assignments: sequence of objects that either expose ``lhs`` and ``rhs`` attributes
                     or are `Node` instances.

    Returns:
        A new list containing the same objects, ordered by `sp.topological_sort` over the
        "defines a symbol that the other one uses" relation.

    Raises:
        NotImplementedError: if an element neither has ``lhs``/``rhs`` nor is a `Node`.
    """
    # Collect, once per element, the symbols it defines and the symbols it reads.
    # Gathering the read symbols up front avoids re-evaluating ``rhs.free_symbols`` /
    # ``undefined_symbols`` for every (defined symbol, element) pair.
    defined_by = []
    used_by = []
    for element in assignments:
        if hasattr(element, 'lhs') and hasattr(element, 'rhs'):
            defined_by.append([element.lhs])
        elif isinstance(element, Node):
            defined_by.append(element.symbols_defined)
        else:
            raise NotImplementedError(f"Cannot sort topologically. Object of type {type(element)} cannot be handled.")

        used = set()
        if isinstance(element, Assignment):
            used.update(element.rhs.free_symbols)
        if isinstance(element, Node):
            used.update(element.undefined_symbols)
        used_by.append(used)

    # An edge (producer, consumer) means: `producer` defines a symbol that `consumer` reads.
    edges = [(producer, consumer)
             for producer, symbols in enumerate(defined_by)
             for symbol in symbols
             for consumer, used in enumerate(used_by)
             if symbol in used]

    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]

def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search is done in both the existing subexpressions and the main assignments,
    using the sympy common subexpression elimination (`sp.cse`). Objects that are not
    `Assignment` instances are not passed to `sp.cse`; they are placed unchanged into
    the subexpression list of the result.

    Args:
        ac: assignment collection providing ``subexpressions``, ``main_assignments``,
            ``subexpression_symbol_generator`` and ``copy``
        **kwargs: forwarded to `sp.cse`

    Returns:
        The result of ``ac.copy`` called with the rewritten main assignments and the
        topologically sorted subexpressions (including the newly found ones).
    """
    symbol_gen = ac.subexpression_symbol_generator

    sub_assignments = [e for e in ac.subexpressions if isinstance(e, Assignment)]
    main_assignments = [e for e in ac.main_assignments if isinstance(e, Assignment)]
    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]

    replacements, new_eq = sp.cse(sub_assignments + main_assignments, symbols=symbol_gen, **kwargs)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # Only Assignment instances were handed to cse, so the boundary between rewritten
    # subexpressions and rewritten main assignments is the number of *Assignment*
    # subexpressions - not len(ac.subexpressions), which also counts other objects.
    split = len(sub_assignments)
    modified_subexpressions = new_eq[:split]
    modified_update_equations = new_eq[split:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)

def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Run `sympy_cse` on a plain list of assignments.

    The list is wrapped as the subexpressions of a temporary `AssignmentCollection`
    with no main assignments; ``all_assignments`` of the simplified collection is returned.
    """
    # Imported locally, as in the original implementation.
    from pystencils.simp.assignment_collection import AssignmentCollection

    simplified = sympy_cse(AssignmentCollection([], assignments))
    return simplified.all_assignments

def subexpression_substitution_in_existing_subexpressions(ac):
    """Rewrite each subexpression in terms of the subexpressions that precede it.

    For every subexpression, all earlier subexpressions are substituted into its rhs:
    first through `subs_additive`, then through a plain ``subs`` of the earlier rhs by
    its lhs. The main assignments are passed on unchanged to ``ac.copy``.
    """
    rewritten = []
    for position, current in enumerate(ac.subexpressions):
        rhs = current.rhs
        # Only subexpressions listed before the current one are candidates.
        for earlier in ac.subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))

    return ac.copy(ac.main_assignments, rewritten)

def subexpression_substitution_in_main_assignments(ac):
    """Substitute the collection's existing subexpressions into its main assignments.

    Every subexpression is applied, in order, to the rhs of each main assignment via
    `subs_additive`; the rewritten assignments are handed to ``ac.copy``.
    """

    def substitute_all(rhs):
        # Apply the subexpressions one after another, in their stored order.
        for candidate in ac.subexpressions:
            rhs = subs_additive(rhs, candidate.lhs, candidate.rhs, required_match_replacement=1.0)
        return rhs

    rewritten = [Assignment(main.lhs, substitute_all(main.rhs)) for main in ac.main_assignments]
    return ac.copy(rewritten)

88 """Extracts constant factors to subexpressions in the given assignment collection.

90 SymPy will exclude common factors from a sum only if they are symbols. This simplification

91 can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,

92 the number of multiplications is reduced and in some cases, more common subexpressions can be found.

93 """

94 constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))

96 def visit(expr):

97 args = list(expr.args)

98 if len(args) == 0:

99 return expr

100 if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):

101 for i, arg in enumerate(args):

102 if is_constant(arg) and abs(arg) != 1:

103 if arg < 0:

104 args[i] = - constants_to_subexp_dict[- arg]

105 else:

106 args[i] = constants_to_subexp_dict[arg]

107 return expr.func(*(visit(a) for a in args))

108 main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]

109 subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]

111 symbols_to_collect = set(constants_to_subexp_dict.values())

113 main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]

114 subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]

116 subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions

117 return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)

121 r"""Introduces subexpressions for all divisions which have no constant in the denominator.

123 For example :math:\frac{1}{x} is replaced while :math:\frac{1}{3} is not replaced.

124 """

125 divisors = set()

127 def search_divisors(term):

128 if term.func == sp.Pow:

129 if term.exp.is_integer and term.exp.is_number and term.exp < 0:

131 else:

132 for a in term.args:

133 search_divisors(a)

135 for eq in ac.all_assignments:

136 search_divisors(eq.rhs)

138 divisors = sorted(list(divisors), key=lambda x: str(x))

139 new_symbol_gen = ac.subexpression_symbol_generator

140 substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}

145 r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions."""

148 def contains_sum(term):

150 return True

151 if term.is_Atom:

152 return False

153 return any([contains_sum(a) for a in term.args])

157 if all([not contains_sum(a) for a in term.args]):

159 for a in term.args:

162 for eq in ac.all_assignments:

165 addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]

166 new_symbol_gen = ac.subexpression_symbol_generator

168 return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)

172 r"""Substitutes field accesses on rhs of assignments with subexpressions

174 Can change semantics of the update rule (which is the goal of this transformation)

175 This is useful if a field should be update in place - all values are loaded before into subexpression variables,

176 then the new values are computed and written to the same field in-place.

177 """

179 to_iterate = []

180 if subexpressions:

181 to_iterate = chain(to_iterate, ac.subexpressions)

182 if main_assignments:

183 to_iterate = chain(to_iterate, ac.main_assignments)

185 for assignment in to_iterate:

186 if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):

188 substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}

190 substitute_on_lhs=False, sort_topologically=False)

def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Apply ``transformation`` to the rhs of every assignment-like element of the list.

    Elements that lack an ``lhs`` or ``rhs`` attribute (e.g. AST nodes) are passed through
    unchanged. Additional positional and keyword arguments are forwarded to the
    transformation function.
    """
    transformed = []
    for entry in assignment_list:
        if not (hasattr(entry, 'lhs') and hasattr(entry, 'rhs')):
            transformed.append(entry)
            continue
        transformed.append(Assignment(entry.lhs, transformation(entry.rhs, *args, **kwargs)))
    return transformed

def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Apply ``transformation`` to both lhs and rhs of every assignment-like element.

    Elements that lack an ``lhs`` or ``rhs`` attribute are passed through unchanged.
    Additional positional and keyword arguments are forwarded to the transformation
    function; the lhs is transformed before the rhs.
    """
    transformed = []
    for entry in assignment_list:
        if not (hasattr(entry, 'lhs') and hasattr(entry, 'rhs')):
            transformed.append(entry)
            continue
        new_lhs = transformation(entry.lhs, *args, **kwargs)
        new_rhs = transformation(entry.rhs, *args, **kwargs)
        transformed.append(Assignment(new_lhs, new_rhs))
    return transformed

def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Applies a given operation to all equations in collection.

    Args:
        operation: callable that is applied (through `transform_rhs`) to the rhs of
                   every main assignment

    Returns:
        A function that takes an assignment collection and returns ``ac.copy`` of the
        transformed main assignments. Its ``__name__`` is taken from ``operation``.
    """

    def f(ac):
        return ac.copy(transform_rhs(ac.main_assignments, operation))

    # Callables such as functools.partial objects or instances defining __call__ have
    # no __name__; fall back to their type name instead of raising AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f

def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Applies the given operation on all subexpressions of the AC.

    Args:
        operation: callable that is applied (through `transform_rhs`) to the rhs of
                   every subexpression

    Returns:
        A function that takes an assignment collection and returns ``ac.copy`` of the
        unchanged main assignments together with the transformed subexpressions.
        Its ``__name__`` is taken from ``operation``.
    """

    def f(ac):
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))

    # Callables such as functools.partial objects or instances defining __call__ have
    # no __name__; fall back to their type name instead of raising AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f