1import itertools 

2from copy import copy 

3from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union 

4 

5import sympy as sp 

6 

7import pystencils 

8from pystencils.assignment import Assignment 

9from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) 

10from pystencils.sympyextensions import count_operations, fast_subs 

11 

12 

class AssignmentCollection:
    """
    A collection of equations with subexpression definitions, also represented as assignments,
    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and the
    left-hand sides of the main equations themselves are not altered.
    Additionally, a dictionary of simplification hints is stored, which is filled by the functions that create
    assignment collections to pass information on to the simplification system.

    Attributes:
        main_assignments: list of assignments
        subexpressions: list of assignments defining subexpressions used in main equations
        simplification_hints: dict that is used to annotate the assignment collection with hints that are
                              used by the simplification system. See documentation of the simplification rules for
                              potentially required hints and their meaning.
        subexpression_symbol_generator: iterator yielding new symbols; used whenever a subexpression is added
                                        without an explicit left-hand side, so that generated symbols are
                                        unique within this AssignmentCollection.
    """

    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------

    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
                 subexpressions: Optional[Union[List[Assignment], Dict[sp.Expr, sp.Expr]]] = None,
                 simplification_hints: Optional[Dict[str, Any]] = None,
                 subexpression_symbol_generator: Optional[Iterator[sp.Symbol]] = None) -> None:
        """Creates a collection from main assignments and optional subexpressions.

        Args:
            main_assignments: list of assignments, or dict mapping left-hand sides to right-hand sides.
                              List entries that are themselves iterable are flattened one level.
            subexpressions: same format as ``main_assignments``; ``None`` means no subexpressions.
            simplification_hints: initial hint dictionary; a fresh empty dict is used if ``None``.
            subexpression_symbol_generator: iterator of fresh symbols; a new ``SymbolGen`` if ``None``.
        """
        # None instead of a mutable ``{}`` default, so that no default object is shared between calls.
        if subexpressions is None:
            subexpressions = []
        if isinstance(main_assignments, dict):
            main_assignments = [Assignment(k, v) for k, v in main_assignments.items()]
        if isinstance(subexpressions, dict):
            subexpressions = [Assignment(k, v) for k, v in subexpressions.items()]

        # Flatten one level: an entry that is itself iterable (e.g. a list of assignments) is expanded in place.
        self.main_assignments = list(itertools.chain.from_iterable(
            a if isinstance(a, Iterable) else [a] for a in main_assignments))
        self.subexpressions = list(itertools.chain.from_iterable(
            a if isinstance(a, Iterable) else [a] for a in subexpressions))

        self.simplification_hints = {} if simplification_hints is None else simplification_hints

        if subexpression_symbol_generator is None:
            self.subexpression_symbol_generator = SymbolGen()
        else:
            self.subexpression_symbol_generator = subexpression_symbol_generator

    def add_simplification_hint(self, key: str, value: Any) -> None:
        """Adds an entry to the simplification_hints dictionary and checks that it does not exist yet."""
        assert key not in self.simplification_hints, "This hint already exists"
        self.simplification_hints[key] = value

    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
        """Adds a subexpression to current collection.

        Args:
            rhs: right hand side of new subexpression
            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
            topological_sort: sort the subexpressions topologically after insertion, to make sure that
                              definition of a symbol comes before its usage. If False, subexpression is appended.

        Returns:
            left hand side symbol (which could have been generated)
        """
        if lhs is None:
            lhs = next(self.subexpression_symbol_generator)
        eq = Assignment(lhs, rhs)
        self.subexpressions.append(eq)
        if topological_sort:
            self.topological_sort(sort_subexpressions=True,
                                  sort_main_assignments=False)
        return lhs

    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
        if sort_subexpressions:
            self.subexpressions = sort_assignments_topologically(self.subexpressions)
        if sort_main_assignments:
            self.main_assignments = sort_assignments_topologically(self.main_assignments)

    # ---------------------------------------------- Properties -------------------------------------------------------

    @property
    def all_assignments(self) -> List[Assignment]:
        """Subexpression and main equations as a single list (subexpressions first)."""
        return self.subexpressions + self.main_assignments

    @property
    def free_symbols(self) -> Set[sp.Symbol]:
        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
        free_symbols = set()
        for eq in self.all_assignments:
            if isinstance(eq, Assignment):
                free_symbols.update(eq.rhs.atoms(sp.Symbol))
            elif isinstance(eq, pystencils.astnodes.Node):
                free_symbols.update(eq.undefined_symbols)

        return free_symbols - self.bound_symbols

    @property
    def bound_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur on the left hand side of a main assignment or a subexpression.

        AST nodes contribute their ``symbols_defined``. Asserts that plain assignments are in SSA form,
        i.e. no symbol is assigned twice.
        """
        plain_assignments = [a for a in self.all_assignments if isinstance(a, Assignment)]
        bound_symbols_set = {a.lhs for a in plain_assignments}

        assert len(bound_symbols_set) == len(plain_assignments), \
            "Not in SSA form - same symbol assigned multiple times"

        return bound_symbols_set.union(*[
            assignment.symbols_defined for assignment in self.all_assignments
            if isinstance(assignment, pystencils.astnodes.Node)
        ])

    @property
    def free_fields(self):
        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
        return {s.field for s in self.free_symbols if hasattr(s, 'field')}

    @property
    def bound_fields(self):
        """All field accessed on the left hand side of a main assignment or a subexpression."""
        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}

    @property
    def defined_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur as left-hand-sides of one of the main equations"""
        return ({assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)}
                .union(*[assignment.symbols_defined for assignment in self.main_assignments
                         if isinstance(assignment, pystencils.astnodes.Node)]))

    @property
    def operation_count(self):
        """See :func:`count_operations` """
        return count_operations(self.all_assignments, only_type=None)

    def atoms(self, *args):
        """Union of ``a.atoms(*args)`` over all assignments."""
        return set().union(*[a.atoms(*args) for a in self.all_assignments])

    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
        """Returns the passed symbols together with all symbols they (transitively) depend on.

        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
        'b' is required to compute 'a'. Starting from ``symbols``, the right-hand sides of their defining
        assignments are followed recursively; symbols without a defining assignment are included as well.
        """
        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}

        # The resulting closure does not depend on traversal order, so a stack with O(1) pops is used.
        pending = list(symbols)
        handled_symbols = set()
        while pending:
            symbol = pending.pop()
            if symbol in handled_symbols:
                continue
            handled_symbols.add(symbol)
            if symbol in assignment_dict:
                pending.extend(assignment_dict[symbol].atoms(sp.Symbol))

        return handled_symbols

    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
        """Returns a python function to evaluate this equation collection.

        Args:
            symbols: symbol(s) which are the parameter for the created function
            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'

        Examples:
              >>> a, b, c, d = sp.symbols("a b c d")
              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
              ...                           subexpressions=[Assignment(b, a + b / 2)])
              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
              >>> python_function(4)
              {c: 6, d: 18}
        """
        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
        assignments = assignments.new_without_subexpressions().main_assignments
        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}

        def f(*args, **kwargs):
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f

    # ---------------------------- Creating new modified collections ---------------------------------------------------

    def copy(self,
             main_assignments: Optional[List[Assignment]] = None,
             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
        """Returns a copy with optionally replaced main_assignments and/or subexpressions.

        The hint dictionary and the symbol generator are copied, so the copy can be modified independently.
        """
        res = copy(self)
        res.simplification_hints = self.simplification_hints.copy()
        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)

        if main_assignments is not None:
            res.main_assignments = main_assignments
        else:
            res.main_assignments = self.main_assignments.copy()

        if subexpressions is not None:
            res.subexpressions = subexpressions
        else:
            res.subexpressions = self.subexpressions.copy()

        return res

    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
                               substitute_on_lhs: bool = True,
                               sort_topologically: bool = True) -> 'AssignmentCollection':
        """Returns new object, where terms are substituted according to the passed substitution dict.

        Args:
            substitutions: dict that is passed to sympy subs; substitutions are applied to the main assignments
                           and the subexpressions
            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
            sort_topologically: if subexpressions are added as substitutions and this parameter is true,
                                the subexpressions are sorted topologically after insertion
        Returns:
            New AssignmentCollection where substitutions have been applied, self is not altered.
        """
        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)

        if add_substitutions_as_subexpressions:
            # each substitution old -> new becomes the subexpression "new = old"
            transformed_subexpressions = [Assignment(b, a) for a, b in
                                          substitutions.items()] + transformed_subexpressions
            if sort_topologically:
                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
        return self.copy(transformed_assignments, transformed_subexpressions)

    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

        Note:
            New names for clashing subexpressions are drawn from ``self.subexpression_symbol_generator``,
            which is advanced by this call.
        """
        own_definitions = {e.lhs for e in self.main_assignments}
        other_definitions = {e.lhs for e in other.main_assignments}
        assert len(own_definitions.intersection(other_definitions)) == 0, \
            "Cannot merge collections, since both define the same symbols"

        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
        # Symbols already assigned in either collection; freshly generated names must avoid all of them.
        taken_symbols = {e.lhs for e in self.all_assignments} | {e.lhs for e in other.all_assignments}
        substitution_dict = {}

        processed_other_subexpression_equations = []
        for other_subexpression_eq in other.subexpressions:
            if other_subexpression_eq.lhs in own_subexpression_symbols:
                # Compare only after applying earlier renamings: the raw rhs may still refer to other's
                # version of a symbol that was renamed because it clashes with a different own definition.
                new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
                if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                    continue  # exactly the same subexpression equation exists already
                # different definition - a new, unused name has to be introduced
                new_lhs = next(self.subexpression_symbol_generator)
                while new_lhs in taken_symbols:
                    new_lhs = next(self.subexpression_symbol_generator)
                taken_symbols.add(new_lhs)
                processed_other_subexpression_equations.append(Assignment(new_lhs, new_rhs))
                substitution_dict[other_subexpression_eq.lhs] = new_lhs
            else:
                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))

        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
        return self.copy(self.main_assignments + processed_other_main_assignments,
                         self.subexpressions + processed_other_subexpression_equations)

    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.

        Returns:
            new AssignmentCollection, self is not altered
        """
        symbols_to_extract = set(symbols_to_extract)
        dependent_symbols = self.dependent_symbols(symbols_to_extract)
        new_assignments = [eq for eq in self.all_assignments if eq.lhs in symbols_to_extract]

        new_sub_expr = [eq for eq in self.subexpressions
                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
        return self.copy(new_assignments, new_sub_expr)

    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
        """Returns new collection that only contains subexpressions required to compute the main assignments."""
        all_lhs = [eq.lhs for eq in self.main_assignments]
        return self.new_filtered(all_lhs)

    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.

        Returns self unchanged if no subexpression defines ``symbol``.
        """
        new_subexpressions = []
        subs_dict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subs_dict = {se.lhs: se.rhs}
            else:
                new_subexpressions.append(se)
        if subs_dict is None:
            return self

        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
        return self.copy(new_eqs, new_subexpressions)

    def new_without_subexpressions(self,
                                   subexpressions_to_keep: Iterable[sp.Symbol] = frozenset()
                                   ) -> 'AssignmentCollection':
        """Returns a new collection where all subexpressions have been inserted.

        Args:
            subexpressions_to_keep: left-hand sides of subexpressions that stay subexpressions (with the
                                    eliminated subexpressions before them substituted in) instead of being inserted.
        """
        if len(self.subexpressions) == 0:
            return self.copy()

        subexpressions_to_keep = set(subexpressions_to_keep)
        substitution_dict = {}
        kept_subexpressions = []

        # Subexpressions are processed in order, so each one only needs the substitutions of its predecessors.
        for subexpression in self.subexpressions:
            if substitution_dict:
                subexpression = fast_subs(subexpression, substitution_dict)
            if subexpression.lhs in subexpressions_to_keep:
                kept_subexpressions.append(subexpression)
            else:
                substitution_dict[subexpression.lhs] = subexpression.rhs

        new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
        return self.copy(new_assignment, kept_subexpressions)

    # ----------------------------------------- Display and Printing -------------------------------------------------

    def _repr_html_(self):
        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
        def make_html_equation_table(equations):
            no_border = 'style="border:none"'
            html_table = '<table style="border:none; width: 100%; ">'
            line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> '
            for eq in equations:
                format_dict = {'eq': sp.latex(eq),
                               'nb': no_border, }
                html_table += line.format(**format_dict)
            html_table += "</table>"
            return html_table

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += make_html_equation_table(self.subexpressions)
        result += "<div>Main Assignments:</div>"
        result += make_html_equation_table(self.main_assignments)
        return result

    def __repr__(self):
        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"

    def __str__(self):
        result = "Subexpressions:\n"
        for eq in self.subexpressions:
            result += f"\t{eq}\n"
        result += "Main Assignments:\n"
        for eq in self.main_assignments:
            result += f"\t{eq}\n"
        return result

    def __iter__(self):
        return self.all_assignments.__iter__()

    @property
    def main_assignments_dict(self):
        """Main assignments as a dict mapping left-hand side to right-hand side."""
        return {a.lhs: a.rhs for a in self.main_assignments}

    @property
    def subexpressions_dict(self):
        """Subexpressions as a dict mapping left-hand side to right-hand side."""
        return {a.lhs: a.rhs for a in self.subexpressions}

    def set_main_assignments_from_dict(self, main_assignments_dict):
        """Replaces the main assignments by ``Assignment(k, v)`` for every item of the given dict."""
        self.main_assignments = [Assignment(k, v)
                                 for k, v in main_assignments_dict.items()]

    def set_sub_expressions_from_dict(self, sub_expressions_dict):
        """Replaces the subexpressions by ``Assignment(k, v)`` for every item of the given dict."""
        self.subexpressions = [Assignment(k, v)
                               for k, v in sub_expressions_dict.items()]

    def find(self, *args, **kwargs):
        """Union of ``a.find(*args, **kwargs)`` over all assignments; an empty set for an empty collection."""
        # ``set().union`` instead of ``set.union``: the unbound form raises TypeError when called without sets.
        return set().union(
            *[a.find(*args, **kwargs) for a in self.all_assignments]
        )

    def match(self, *args, **kwargs):
        """Merges the non-empty ``match`` results of all assignments into one dict (later ones win on conflict)."""
        rtn = {}
        for a in self.all_assignments:
            partial_result = a.match(*args, **kwargs)
            if partial_result:
                rtn.update(partial_result)
        return rtn

    def subs(self, *args, **kwargs):
        """Returns a new collection with ``subs`` applied to every assignment.

        Simplification hints and the symbol generator state are carried over (via ``copy``), so symbols
        generated later cannot clash with existing subexpression symbols.
        """
        return self.copy(
            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
        )

    def replace(self, *args, **kwargs):
        """Returns a new collection with ``replace`` applied to every assignment.

        Simplification hints and the symbol generator state are carried over (via ``copy``), so symbols
        generated later cannot clash with existing subexpression symbols.
        """
        return self.copy(
            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
        )

    def __eq__(self, other):
        """Order-insensitive comparison of all assignments; non-collections are not comparable."""
        if not isinstance(other, AssignmentCollection):
            return NotImplemented
        return set(self.all_assignments) == set(other.all_assignments)

    def __bool__(self):
        return bool(self.all_assignments)

435 

436 

class SymbolGen:
    """Default symbol generator producing numbered symbols ``xi_0, xi_1, ...``.

    Args:
        symbol: name prefix of the generated symbols; the n-th symbol is named ``f"{symbol}_{n}"``.
        dtype: if not None, ``pystencils.TypedSymbol`` objects with this data type are produced
               instead of plain ``sympy.Symbol`` objects.

    The generator never raises StopIteration and is its own iterator. A shallow copy
    (``copy.copy``) continues independently from the same counter value.
    """

    def __init__(self, symbol="xi", dtype=None):
        self._ctr = 0  # index of the next symbol to hand out
        self._symbol = symbol
        self._dtype = dtype

    def __iter__(self):
        return self

    def __next__(self):
        name = f"{self._symbol}_{self._ctr}"
        self._ctr += 1
        if self._dtype is not None:
            return pystencils.TypedSymbol(name, self._dtype)
        return sp.Symbol(name)