import itertools
from collections import deque
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union

import sympy as sp

import pystencils
from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs


class AssignmentCollection:
    """
    A collection of equations with subexpression definitions, also represented as assignments,
    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and the
    left hand sides of the main equations themselves are not altered.
    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
    assignment collections to transport information to the simplification system.

    Attributes:
        main_assignments: list of assignments
        subexpressions: list of assignments defining subexpressions used in main equations
        simplification_hints: dict that is used to annotate the assignment collection with hints that are
                              used by the simplification system. See documentation of the simplification rules for
                              potentially required hints and their meaning.
        subexpression_symbol_generator: iterator yielding new symbols; used when new subexpressions are added
                                        to obtain symbols that are unique for this AssignmentCollection

    """

    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------

    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
                 subexpressions: Optional[Union[List[Assignment], Dict[sp.Expr, sp.Expr]]] = None,
                 simplification_hints: Optional[Dict[str, Any]] = None,
                 subexpression_symbol_generator: Optional[Iterator[sp.Symbol]] = None) -> None:
        """Create a collection from main assignments and (optionally) subexpressions.

        Args:
            main_assignments: list of assignments, or dict mapping left hand sides to right hand sides.
                              List entries that are themselves iterables are flattened one level.
            subexpressions: same format as ``main_assignments``; None means no subexpressions
            simplification_hints: initial hints dict; a fresh empty dict is used when None
            subexpression_symbol_generator: iterator of new symbols; a fresh :class:`SymbolGen` when None
        """
        if subexpressions is None:
            # fresh container per call instead of a shared mutable default argument
            subexpressions = []
        if isinstance(main_assignments, dict):
            main_assignments = [Assignment(k, v) for k, v in main_assignments.items()]
        if isinstance(subexpressions, dict):
            subexpressions = [Assignment(k, v) for k, v in subexpressions.items()]

        self.main_assignments = self._flatten(main_assignments)
        self.subexpressions = self._flatten(subexpressions)

        self.simplification_hints = {} if simplification_hints is None else simplification_hints

        if subexpression_symbol_generator is None:
            self.subexpression_symbol_generator = SymbolGen()
        else:
            self.subexpression_symbol_generator = subexpression_symbol_generator

    @staticmethod
    def _flatten(assignments) -> list:
        """Expand entries that are iterables (one level deep); keep all other entries unchanged."""
        return list(itertools.chain.from_iterable(
            (a if isinstance(a, Iterable) else [a]) for a in assignments))

    def add_simplification_hint(self, key: str, value: Any) -> None:
        """Adds an entry to the simplification_hints dictionary and checks that it does not exist yet."""
        assert key not in self.simplification_hints, "This hint already exists"
        self.simplification_hints[key] = value

    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
        """Adds a subexpression to current collection.

        Args:
            rhs: right hand side of new subexpression
            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
            topological_sort: sort the subexpressions topologically after insertion, to make sure that
                              definition of a symbol comes before its usage. If False, subexpression is appended.

        Returns:
            left hand side symbol (which could have been generated)
        """
        if lhs is None:
            lhs = next(self.subexpression_symbol_generator)
        self.subexpressions.append(Assignment(lhs, rhs))
        if topological_sort:
            self.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
        return lhs

    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
        if sort_subexpressions:
            self.subexpressions = sort_assignments_topologically(self.subexpressions)
        if sort_main_assignments:
            self.main_assignments = sort_assignments_topologically(self.main_assignments)

    # ---------------------------------------------- Properties -------------------------------------------------------

    @property
    def all_assignments(self) -> List[Assignment]:
        """Subexpressions and main equations as a single (new) list, subexpressions first."""
        return self.subexpressions + self.main_assignments

    @property
    def free_symbols(self) -> Set[sp.Symbol]:
        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
        free_symbols = set()
        for eq in self.all_assignments:
            if isinstance(eq, Assignment):
                free_symbols.update(eq.rhs.atoms(sp.Symbol))
            elif isinstance(eq, pystencils.astnodes.Node):
                free_symbols.update(eq.undefined_symbols)

        return free_symbols - self.bound_symbols

    @property
    def bound_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur on the left hand side of a main assignment or a subexpression.

        Raises:
            AssertionError: if a symbol is assigned more than once (collection not in SSA form)
        """
        assignments = [a for a in self.all_assignments if isinstance(a, Assignment)]
        bound_symbols_set = {assignment.lhs for assignment in assignments}

        assert len(bound_symbols_set) == len(assignments), \
            "Not in SSA form - same symbol assigned multiple times"

        # AST nodes mixed into the collection contribute the symbols they define
        return bound_symbols_set.union(*[
            assignment.symbols_defined for assignment in self.all_assignments
            if isinstance(assignment, pystencils.astnodes.Node)
        ])

    @property
    def free_fields(self):
        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
        return {s.field for s in self.free_symbols if hasattr(s, 'field')}

    @property
    def bound_fields(self):
        """All fields accessed on the left hand side of a main assignment or a subexpression."""
        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}

    @property
    def defined_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur as left-hand-sides of one of the main equations"""
        return ({assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)}
                .union(*[assignment.symbols_defined for assignment in self.main_assignments
                         if isinstance(assignment, pystencils.astnodes.Node)]))

    @property
    def operation_count(self):
        """See :func:`count_operations` """
        return count_operations(self.all_assignments, only_type=None)

    def atoms(self, *args):
        """Union of ``atoms(*args)`` over all assignments."""
        return set().union(*[a.atoms(*args) for a in self.all_assignments])

    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
        """Returns the passed symbols together with all symbols they (transitively) depend on.

        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
        'b' is required to compute 'a'. Starting from ``symbols``, the right hand sides of their defining
        assignments are followed recursively.
        """
        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
        queue = deque(symbols)
        handled_symbols = set()

        while queue:
            symbol = queue.popleft()
            if symbol in handled_symbols:
                continue
            if symbol in assignment_dict:
                # everything on the defining right hand side is required to compute `symbol`
                queue.extend(assignment_dict[symbol].atoms(sp.Symbol))
            handled_symbols.add(symbol)

        return handled_symbols

    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
        """Returns a python function to evaluate this equation collection.

        Args:
            symbols: symbol(s) which are the parameter for the created function
            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'

        Examples:
              >>> a, b, c, d = sp.symbols("a b c d")
              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
              ...                           subexpressions=[Assignment(b, a + b / 2)])
              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
              >>> python_function(4)
              {c: 6, d: 18}
        """
        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
        assignments = assignments.new_without_subexpressions().main_assignments
        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}

        def f(*args, **kwargs):
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f

    # ---------------------------- Creating new modified collections ---------------------------------------------------

    def copy(self,
             main_assignments: Optional[List[Assignment]] = None,
             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
        """Returns a copy with optionally replaced main_assignments and/or subexpressions."""

        res = copy(self)
        res.simplification_hints = self.simplification_hints.copy()
        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)

        if main_assignments is not None:
            res.main_assignments = main_assignments
        else:
            res.main_assignments = self.main_assignments.copy()

        if subexpressions is not None:
            res.subexpressions = subexpressions
        else:
            res.subexpressions = self.subexpressions.copy()

        return res

    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
                               substitute_on_lhs: bool = True,
                               sort_topologically: bool = True) -> 'AssignmentCollection':
        """Returns new object, where terms are substituted according to the passed substitution dict.

        Args:
            substitutions: dict that is passed to sympy subs, substitutions are done in main assignments and
                           subexpressions
            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
            sort_topologically: if subexpressions are added as substitutions and this parameter is true,
                                the subexpressions are sorted topologically after insertion
        Returns:
            New AssignmentCollection where substitutions have been applied, self is not altered.
        """
        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)

        if add_substitutions_as_subexpressions:
            # each substitution old -> new becomes a subexpression `new <- old`
            transformed_subexpressions = [Assignment(b, a) for a, b in
                                          substitutions.items()] + transformed_subexpressions
            if sort_topologically:
                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
        return self.copy(transformed_assignments, transformed_subexpressions)

    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

        Note: renamed symbols are drawn from ``self.subexpression_symbol_generator``, which is advanced.
        """
        own_definitions = {e.lhs for e in self.main_assignments}
        other_definitions = {e.lhs for e in other.main_assignments}
        assert len(own_definitions.intersection(other_definitions)) == 0, \
            "Cannot merge collections, since both define the same symbols"

        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitution_dict = {}

        processed_other_subexpression_equations = []
        for other_subexpression_eq in other.subexpressions:
            if other_subexpression_eq.lhs in own_subexpression_symbols:
                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                    continue  # exactly the same subexpression equation exists already
                else:
                    # different definition - a new name has to be introduced
                    new_lhs = next(self.subexpression_symbol_generator)
                    new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
                    processed_other_subexpression_equations.append(new_eq)
                    substitution_dict[other_subexpression_eq.lhs] = new_lhs
            else:
                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))

        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
        return self.copy(self.main_assignments + processed_other_main_assignments,
                         self.subexpressions + processed_other_subexpression_equations)

    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.

        Returns:
            new AssignmentCollection, self is not altered
        """
        symbols_to_extract = set(symbols_to_extract)
        dependent_symbols = self.dependent_symbols(symbols_to_extract)
        new_assignments = [eq for eq in self.all_assignments if eq.lhs in symbols_to_extract]

        new_sub_expr = [eq for eq in self.subexpressions
                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
        return self.copy(new_assignments, new_sub_expr)

    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
        """Returns new collection that only contains subexpressions required to compute the main assignments."""
        all_lhs = [eq.lhs for eq in self.main_assignments]
        return self.new_filtered(all_lhs)

    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.

        Returns self unchanged if no subexpression defines ``symbol``.
        """
        new_subexpressions = []
        subs_dict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subs_dict = {se.lhs: se.rhs}
            else:
                new_subexpressions.append(se)
        if subs_dict is None:
            return self

        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
        return self.copy(new_eqs, new_subexpressions)

    def new_without_subexpressions(self, subexpressions_to_keep: Optional[Iterable[sp.Symbol]] = None
                                   ) -> 'AssignmentCollection':
        """Returns a new collection where all subexpressions have been inserted.

        Args:
            subexpressions_to_keep: left hand sides of subexpressions that are kept as subexpressions instead of
                                    being inserted (inserted subexpressions are still substituted into them)
        """
        if len(self.subexpressions) == 0:
            return self.copy()

        keep = set(subexpressions_to_keep) if subexpressions_to_keep is not None else set()

        kept_subexpressions = []
        substitution_dict = {}
        # Subexpressions are processed in list order, so each one only sees substitutions of earlier ones;
        # full expansion therefore relies on the list being topologically sorted.
        for subexpression in self.subexpressions:
            if substitution_dict:
                subexpression = fast_subs(subexpression, substitution_dict)
            if subexpression.lhs in keep:
                kept_subexpressions.append(subexpression)
            else:
                substitution_dict[subexpression.lhs] = subexpression.rhs

        new_assignments = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
        return self.copy(new_assignments, kept_subexpressions)

    # ----------------------------------------- Display and Printing -------------------------------------------------

    def _repr_html_(self):
        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
        def make_html_equation_table(equations):
            no_border = 'style="border:none"'
            # `$$ ... $$` delimiters let MathJax render each LaTeX equation as display math
            rows = ''.join(f'<tr {no_border}> <td {no_border}>$${sp.latex(eq)}$$</td> </tr> ' for eq in equations)
            return f'<table style="border:none; width: 100%; ">{rows}</table>'

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += make_html_equation_table(self.subexpressions)
        result += "<div>Main Assignments:</div>"
        result += make_html_equation_table(self.main_assignments)
        return result

    def __repr__(self):
        return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"

    def __str__(self):
        lines = ["Subexpressions:"]
        lines.extend(f"\t{eq}" for eq in self.subexpressions)
        lines.append("Main Assignments:")
        lines.extend(f"\t{eq}" for eq in self.main_assignments)
        return "\n".join(lines) + "\n"

    def __iter__(self):
        return self.all_assignments.__iter__()

    @property
    def main_assignments_dict(self):
        """Main assignments as dict mapping left hand side to right hand side."""
        return {a.lhs: a.rhs for a in self.main_assignments}

    @property
    def subexpressions_dict(self):
        """Subexpressions as dict mapping left hand side to right hand side."""
        return {a.lhs: a.rhs for a in self.subexpressions}

    def set_main_assignments_from_dict(self, main_assignments_dict):
        """Replaces the main assignments by ``Assignment(k, v)`` for every item of the given dict."""
        self.main_assignments = [Assignment(k, v) for k, v in main_assignments_dict.items()]

    def set_sub_expressions_from_dict(self, sub_expressions_dict):
        """Replaces the subexpressions by ``Assignment(k, v)`` for every item of the given dict."""
        self.subexpressions = [Assignment(k, v) for k, v in sub_expressions_dict.items()]

    def find(self, *args, **kwargs):
        """Union of ``find`` results over all assignments; an empty collection yields an empty set."""
        # `set().union(...)` instead of `set.union(...)`: the latter raises TypeError with zero arguments
        return set().union(*[a.find(*args, **kwargs) for a in self.all_assignments])

    def match(self, *args, **kwargs):
        """Merges the non-empty ``match`` results of all assignments into one dict."""
        rtn = {}
        for a in self.all_assignments:
            partial_result = a.match(*args, **kwargs)
            if partial_result:
                rtn.update(partial_result)
        return rtn

    def subs(self, *args, **kwargs):
        """Returns a new collection with ``subs`` applied to every assignment."""
        return AssignmentCollection(
            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
        )

    def replace(self, *args, **kwargs):
        """Returns a new collection with ``replace`` applied to every assignment."""
        return AssignmentCollection(
            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
        )

    def __eq__(self, other):
        # compare as unordered sets of assignments; defer to the other operand for foreign types
        if not isinstance(other, AssignmentCollection):
            return NotImplemented
        return set(self.all_assignments) == set(other.all_assignments)

    def __bool__(self):
        return bool(self.all_assignments)

class SymbolGen:
    """Default symbol generator producing numbered symbols ``<symbol>_0, <symbol>_1, ...``.

    With the default prefix this yields ``xi_0, xi_1, ...``. The generator is its own iterator and
    never terminates.

    Args:
        symbol: name prefix of the generated symbols
        dtype: if not None, ``pystencils.TypedSymbol`` instances with this data type are generated
               instead of plain sympy symbols
    """

    def __init__(self, symbol="xi", dtype=None):
        self._ctr = 0  # index of the next symbol to generate
        self._symbol = symbol
        self._dtype = dtype

    def __iter__(self):
        return self

    def __next__(self):
        name = f"{self._symbol}_{self._ctr}"
        self._ctr += 1
        if self._dtype is not None:
            return pystencils.TypedSymbol(name, self._dtype)
        return sp.Symbol(name)