1from itertools import chain 

2from typing import Callable, List, Sequence, Union 

3from collections import defaultdict 

4 

5import sympy as sp 

6 

7from pystencils.assignment import Assignment 

8from pystencils.astnodes import Node 

9from pystencils.field import AbstractField, Field 

10from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect 

11 

12 

def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Reorder ``assignments`` so that every symbol is defined before it is read.

    Definitions: an element with both ``lhs`` and ``rhs`` attributes defines its ``lhs``; any other
    AST ``Node`` defines its ``symbols_defined``.
    Uses: an ``Assignment`` reads the free symbols of its ``rhs``; a ``Node`` reads its
    ``undefined_symbols``.

    Raises:
        NotImplementedError: if an element is neither assignment-like nor a ``Node``.
    """
    edges = []
    for producer_idx, producer in enumerate(assignments):
        if hasattr(producer, 'lhs') and hasattr(producer, 'rhs'):
            defined = [producer.lhs]
        elif isinstance(producer, Node):
            defined = producer.symbols_defined
        else:
            raise NotImplementedError(f"Cannot sort topologically. Object of type {type(producer)} cannot be handled.")

        # Add an edge producer -> consumer for every element that reads a symbol the producer defines.
        for symbol in defined:
            for consumer_idx, consumer in enumerate(assignments):
                reads_symbol = (
                    (isinstance(consumer, Assignment) and symbol in consumer.rhs.free_symbols)
                    or (isinstance(consumer, Node) and symbol in consumer.undefined_symbols)
                )
                if reads_symbol:
                    edges.append((producer_idx, consumer_idx))

    order = sp.topological_sort((range(len(assignments)), edges))
    return [assignments[idx] for idx in order]

31 

32 

def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search covers both the existing subexpressions and the main assignments, using
    ``sympy.cse``. New common subexpressions get their symbols from
    ``ac.subexpression_symbol_generator``. Objects that are not ``Assignment`` instances
    (e.g. AST nodes) are not passed to ``sympy.cse``; they are kept and placed among the
    subexpressions.

    Args:
        ac: the assignment collection to process
        **kwargs: forwarded to ``sympy.cse``

    Returns:
        a copy of ``ac`` whose main assignments are the rewritten main assignments and whose
        subexpressions are the topologically sorted union of the non-assignment objects, the
        newly found common subexpressions and the rewritten original subexpressions
    """
    symbol_gen = ac.subexpression_symbol_generator

    all_objects = list(chain(ac.subexpressions, ac.main_assignments))
    all_assignments = [e for e in all_objects if isinstance(e, Assignment)]
    other_objects = [e for e in all_objects if not isinstance(e, Assignment)]
    replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)

    replacement_eqs = [Assignment(*r) for r in replacements]

    # Only Assignment instances were passed to sp.cse, so the boundary between rewritten
    # subexpressions and rewritten main assignments is the number of Assignment subexpressions.
    # Using len(ac.subexpressions) would shift the split (and move main assignments into the
    # subexpressions) whenever the subexpressions also contain other objects such as AST nodes.
    num_subexpression_assignments = sum(1 for e in ac.subexpressions if isinstance(e, Assignment))
    modified_subexpressions = new_eq[:num_subexpression_assignments]
    modified_update_equations = new_eq[num_subexpression_assignments:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)

53 

54 

def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Extract common subexpressions from a plain list of assignments.

    The list is wrapped in an ``AssignmentCollection`` (``[]`` as first argument, ``assignments``
    as second). ``sympy_cse`` is run on that collection, and its ``all_assignments`` are returned.
    """
    # Local import - presumably to avoid a circular import at module load time (TODO confirm).
    from pystencils.simp.assignment_collection import AssignmentCollection
    collection = AssignmentCollection([], assignments)
    simplified = sympy_cse(collection)
    return simplified.all_assignments

60 

61 

def subexpression_substitution_in_existing_subexpressions(ac):
    """Rewrite every subexpression in terms of the subexpressions that precede it.

    Each earlier subexpression is processed in order:
    1. ``subs_additive`` is called with its lhs and rhs (``required_match_replacement=1.0``).
    2. ``subs`` then replaces any remaining exact occurrence of its rhs with its lhs.

    Returns:
        a copy of ``ac`` with the main assignments unchanged and the rewritten subexpressions
    """
    subexpressions = ac.subexpressions
    rewritten = []
    for position, current in enumerate(subexpressions):
        rhs = current.rhs
        for earlier in subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
    return ac.copy(ac.main_assignments, rewritten)

74 

75 

def subexpression_substitution_in_main_assignments(ac):
    """Replace existing subexpressions inside the rhs of the main assignments.

    For every main assignment, ``subs_additive`` is applied once per subexpression, in order,
    with the subexpression's lhs and rhs (``required_match_replacement=1.0``).

    Returns:
        a copy of ``ac`` whose main assignments are the rewritten ones
    """
    def substitute(rhs):
        for sub_expr in ac.subexpressions:
            rhs = subs_additive(rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return rhs

    return ac.copy([Assignment(main.lhs, substitute(main.rhs)) for main in ac.main_assignments])

85 

86 

def add_subexpressions_for_constants(ac):
    """Extracts constant factors to subexpressions in the given assignment collection.

    SymPy will exclude common factors from a sum only if they are symbols. This simplification
    can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
    the number of multiplications is reduced and in some cases, more common subexpressions can be found.

    The transformation works in three steps:
    1. Every argument of an ``Add`` or ``Mul`` that is constant (per ``is_constant``) and whose
       absolute value is not 1 is replaced by a symbol from ``ac.subexpression_symbol_generator``.
       A negative constant ``c`` becomes the negated symbol of ``-c``, so ``c`` and ``-c`` share
       one symbol.
    2. The rhs of all assignments are then collected with respect to the new symbols
       (``recursive_collect``).
    3. One subexpression ``symbol = constant`` is prepended per constant.

    Returns:
        a copy of ``ac`` with the rewritten main assignments and subexpressions
    """
    constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))

    def visit(expr):
        args = list(expr.args)
        if not args:
            return expr
        if isinstance(expr, (sp.Add, sp.Mul)):
            for i, arg in enumerate(args):
                if is_constant(arg) and abs(arg) != 1:
                    # `arg < 0` raises TypeError for constants SymPy cannot order (e.g. complex ones
                    # such as 2*I). `is_negative` is None for those, so they are simply mapped to a
                    # symbol as they are - still exact, only without sharing with their negation.
                    if arg.is_negative:
                        args[i] = -constants_to_subexp_dict[-arg]
                    else:
                        args[i] = constants_to_subexp_dict[arg]
        return expr.func(*(visit(a) for a in args))

    main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
    subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]

    symbols_to_collect = set(constants_to_subexp_dict.values())

    main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
    subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]

    subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
    return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)

118 

119 

def add_subexpressions_for_divisions(ac):
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.

    A division is an ``sp.Pow`` whose exponent is a negative integer number. All rhs are
    searched recursively, including the bases and exponents of powers that are not divisions
    themselves. Each distinct division gets one symbol from ``ac.subexpression_symbol_generator``.
    Symbols are assigned in order of the divisions' string form, and the divisions are added as
    subexpressions.

    Returns:
        the result of ``ac.new_with_substitutions`` (lhs are not substituted)
    """
    divisors = set()

    def search_divisors(term):
        # NOTE(review): only the exponent is checked, so a constant base (e.g. pi in 1/pi, if kept
        # as a Pow) is matched too - confirm this against the "no constant" wording above.
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors.add(term)
        else:
            # Recurse into everything that is not itself a division - including powers such as
            # (x + 1/y)**2 or sqrt(x + 1/y), whose bases may contain divisions.
            for a in term.args:
                search_divisors(a)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    # Sort so that the symbol assignment does not depend on set iteration order.
    divisors = sorted(divisors, key=str)
    # zip over the divisors first: zip stops on the first exhausted iterable, so the symbol
    # generator is not advanced one extra time after the divisors run out.
    substitutions = dict(zip(divisors, ac.subexpression_symbol_generator))
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)

142 

143 

def add_subexpressions_for_sums(ac):
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Only the terms of innermost sums (sums none of whose terms contain another sum) are extracted.
    Plain symbols are skipped, except for field accesses. Each distinct addend gets one symbol from
    ``ac.subexpression_symbol_generator``, in order of first occurrence, and the addends are added
    as subexpressions.

    Returns:
        the result of ``ac.new_with_substitutions`` (lhs are not substituted)
    """
    addends = []

    def contains_sum(term):
        if term.func == sp.Add:
            return True
        if term.is_Atom:
            return False
        return any(contains_sum(a) for a in term.args)

    def search_addends(term):
        if term.func == sp.Add and not any(contains_sum(a) for a in term.args):
            addends.extend(term.args)
        for a in term.args:
            search_addends(a)

    for eq in ac.all_assignments:
        search_addends(eq.rhs)

    # dict.fromkeys drops duplicate addends (keeping first-occurrence order). Without it, an addend
    # occurring in several sums consumed one generated symbol per occurrence, and only the last was used.
    candidates = [a for a in dict.fromkeys(addends)
                  if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]
    # zip over the candidates first so that the generator is not advanced once more after they run out.
    substitutions = dict(zip(candidates, ac.subexpression_symbol_generator))
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)

169 

170 

def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
    r"""Substitutes field accesses on rhs of assignments with subexpressions

    Can change semantics of the update rule (which is the goal of this transformation)
    This is useful if a field should be updated in place - all values are loaded before into subexpression variables,
    then the new values are computed and written to the same field in-place.

    Args:
        ac: the assignment collection to transform
        subexpressions: collect field reads from the rhs of ``ac.subexpressions``
        main_assignments: collect field reads from the rhs of ``ac.main_assignments``

    Returns:
        the result of ``ac.new_with_substitutions``. Each distinct field access gets one symbol from
        ``ac.subexpression_symbol_generator``, in order of the accesses' string form. The new
        subexpressions are added without topological sorting, and lhs are not substituted.
    """
    sources = []
    if subexpressions:
        sources.append(ac.subexpressions)
    if main_assignments:
        sources.append(ac.main_assignments)

    field_reads = set()
    for assignment in chain.from_iterable(sources):
        # objects without lhs/rhs (e.g. AST nodes) have no rhs to scan and are skipped
        if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
            field_reads.update(assignment.rhs.atoms(Field.Access))

    # Iterate in a fixed order. SymPy hashes build on (randomized) string hashes, so plain set order
    # can differ between interpreter runs. That would change which symbol each access gets, and the
    # order of the unsorted new subexpressions, making the generated code non-reproducible.
    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in sorted(field_reads, key=str)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                     substitute_on_lhs=False, sort_topologically=False)

191 

192 

def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Return a new list with ``transformation`` applied to the rhs of every assignment.

    Each element with both ``lhs`` and ``rhs`` attributes is replaced by
    ``Assignment(lhs, transformation(rhs, *args, **kwargs))``. Any other element (e.g. an AST
    node) is copied to the result unchanged, and the transformation is not called for it.
    """
    def transformed(item):
        if not (hasattr(item, 'lhs') and hasattr(item, 'rhs')):
            return item
        return Assignment(item.lhs, transformation(item.rhs, *args, **kwargs))

    return [transformed(item) for item in assignment_list]

199 

200 

def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Return a new list with ``transformation`` applied to both sides of every assignment.

    Each element with both ``lhs`` and ``rhs`` attributes is replaced by a new ``Assignment``
    whose lhs and rhs are ``transformation(side, *args, **kwargs)``. The lhs is transformed first.
    Any other element (e.g. an AST node) is copied to the result unchanged.
    """
    result = []
    for item in assignment_list:
        if hasattr(item, 'lhs') and hasattr(item, 'rhs'):
            new_lhs = transformation(item.lhs, *args, **kwargs)
            new_rhs = transformation(item.rhs, *args, **kwargs)
            result.append(Assignment(new_lhs, new_rhs))
        else:
            result.append(item)
    return result

206 

207 

def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Create a simplification step that applies ``operation`` to the rhs of every main assignment.

    Despite its name, the step only transforms ``ac.main_assignments``; subexpressions are left
    unchanged (``apply_on_all_subexpressions`` covers those). Non-assignment entries are passed
    through by ``transform_rhs``.

    Args:
        operation: callable applied to each main assignment's rhs

    Returns:
        a function ``f(ac)`` that returns ``ac.copy(...)`` with the transformed main assignments.
        Its ``__name__`` is the operation's name, or the operation's type name if it has no
        ``__name__`` (e.g. ``functools.partial`` objects).
    """

    def f(ac):
        return ac.copy(transform_rhs(ac.main_assignments, operation))

    # Callables like functools.partial have no __name__; reading it directly raised AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f

216 

217 

def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Create a simplification step that applies ``operation`` to the rhs of every subexpression.

    Main assignments are left unchanged. Non-assignment entries are passed through by
    ``transform_rhs``.

    Args:
        operation: callable applied to each subexpression's rhs

    Returns:
        a function ``f(ac)`` that returns ``ac.copy(ac.main_assignments, ...)`` with the transformed
        subexpressions. Its ``__name__`` is the operation's name, or the operation's type name if it
        has no ``__name__`` (e.g. ``functools.partial`` objects).
    """

    def f(ac):
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))

    # Callables like functools.partial have no __name__; reading it directly raised AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f