1from collections import defaultdict, namedtuple

3import sympy as sp

5from pystencils.field import Field

6from pystencils.sympyextensions import normalize_product, prod

9def _default_diff_sort_key(d):

10 return str(d.superscript), str(d.target)

class Diff(sp.Expr):
    """Sympy Node representing a derivative.

    The difference to sympy's built in differential is:
        - shortened latex representation
        - all simplifications have to be done manually
        - optional marker displayed as superscript

    Constructor args:
        argument: expression the derivative acts on; a pystencils Field is replaced by its center access,
                  plain Python numbers are sympified. A zero argument yields ``sp.Rational(0, 1)`` instead of a node.
        target: subscript, usually the variable the derivative is taken w.r.t.; -1 means "no target"
        superscript: optional marker displayed as superscript; -1 means "no superscript"
    """
    is_number = False
    is_Rational = False
    _diff_wrt = True

    def __new__(cls, argument, target=-1, superscript=-1):
        if isinstance(argument, Field):
            argument = argument.center
        elif isinstance(argument, (int, float)):
            # plain Python numbers have no ``expand`` method - convert them to sympy numbers first
            argument = sp.sympify(argument)
        if argument == 0:
            return sp.Rational(0, 1)
        return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript))

    @property
    def is_commutative(self):
        """False as soon as any symbol inside the node is non-commutative, True otherwise."""
        any_non_commutative = any(not s.is_commutative for s in self.atoms(sp.Symbol))
        if any_non_commutative:
            return False
        else:
            return True

    def get_arg_recursive(self):
        """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned"""
        if not isinstance(self.arg, Diff):
            return self.arg
        else:
            return self.arg.get_arg_recursive()

    def change_arg_recursive(self, new_arg):
        """Returns a Diff node with the given 'new_arg' instead of the current argument. For nested derivatives
        a new nested derivative is returned where the inner Diff has the 'new_arg'"""
        if not isinstance(self.arg, Diff):
            return Diff(new_arg, self.target, self.superscript)
        else:
            return Diff(self.arg.change_arg_recursive(new_arg), self.target, self.superscript)

    def split_linear(self, functions):
        """
        Applies linearity property of Diff: i.e. 'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)'
        The parameter functions is a list of all symbols that are considered functions, not constants.
        For the example above: functions=[a, b]

        Returns 0 if the remaining (non-constant) part is a constant symbol or a number.
        """
        constant, variable = 1, 1

        if self.arg.func != sp.Mul:
            constant, variable = 1, self.arg
        else:
            for factor in normalize_product(self.arg):
                # functions and nested derivatives stay inside, everything else is pulled in front
                if factor in functions or isinstance(factor, Diff):
                    variable *= factor
                else:
                    constant *= factor

        if isinstance(variable, sp.Symbol) and variable not in functions:
            return 0

        if isinstance(variable, int) or variable.is_number:
            return 0
        else:
            return constant * Diff(variable, target=self.target, superscript=self.superscript)

    @property
    def arg(self):
        """Expression the derivative acts on"""
        return self.args[0]

    @property
    def target(self):
        """Subscript, usually the variable the Diff is w.r.t. """
        return self.args[1]

    @property
    def superscript(self):
        """Superscript, for example used as the Chapman-Enskog order index"""
        return self.args[2]

    def _latex(self, printer, *_):
        result = r"{\partial"
        superscript = self.superscript
        # Negative numbers (the default -1) mean "no superscript". Symbolic superscripts are always shown:
        # comparing them with ``>= 0`` would give an undecidable Relational and raise a TypeError.
        if not (superscript.is_number and superscript < 0):
            result += "^{(%s)}" % (superscript,)
        if self.target != -1:
            result += "_{%s}" % (printer.doprint(self.target),)

        contents = printer.doprint(self.arg)
        # simple arguments are printed without parentheses
        if isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff:
            result += " " + contents
        else:
            result += " (" + contents + ") "

        result += "}"
        return result

    def __str__(self):
        return f"D({self.arg})"

    def interpolated_access(self, offset, **kwargs):
        """Represents an interpolated access on a spatially differentiated field

        Args:
            offset (Tuple[sympy.Expr]): Absolute position to determine the value of the spatial derivative
            **kwargs: forwarded to the field's ``interpolated_access``
        """
        from pystencils.interpolation_astnodes import DiffInterpolatorAccess
        assert isinstance(self.arg.field, Field), "Must be field to enable interpolated accesses"
        return DiffInterpolatorAccess(self.arg.field.interpolated_access(offset, **kwargs).symbol, self.target, *offset)

class DiffOperator(sp.Expr):
    """Un-applied differential, i.e. differential operator

    Args:
        target: the differential is w.r.t to this variable.
                This target is mainly for display purposes (its the subscript) and to distinguish DiffOperators
                If the target is '-1' no subscript is displayed
        superscript: optional marker displayed as superscript
                     is not displayed if set to '-1'

    The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the
    DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's
    """
    is_commutative = True
    is_number = False
    is_Rational = False

    def __new__(cls, target=-1, superscript=-1):
        return sp.Expr.__new__(cls, sp.sympify(target), sp.sympify(superscript))

    @property
    def target(self):
        """Subscript of the operator (the variable it differentiates w.r.t.), -1 if none"""
        return self.args[0]

    @property
    def superscript(self):
        """Superscript marker of the operator, -1 if none"""
        return self.args[1]

    def _latex(self, printer=None, *_):
        result = r"{\partial"
        superscript = self.superscript
        # Negative numbers (the default -1) mean "no superscript"; symbolic superscripts are always shown
        # instead of failing on an undecidable ``>= 0`` comparison.
        if not (superscript.is_number and superscript < 0):
            result += "^{(%s)}" % (superscript,)
        if self.target != -1:
            # render the target through the printer (as Diff._latex does) so e.g. Greek symbols become LaTeX
            target = printer.doprint(self.target) if printer is not None else str(self.target)
            result += "_{%s}" % (target,)
        result += "}"
        return result

    @staticmethod
    def apply(expr, argument, apply_to_constants=True):
        """
        Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node.
        Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
        i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), x)

        Args:
            expr: expression built from DiffOperators (sums, products and powers thereof)
            argument: expression the operators are applied to
            apply_to_constants: if True, terms without any DiffOperator are multiplied by 'argument',
                                otherwise such terms are returned unchanged
        """

        def handle_mul(mul):
            args = normalize_product(mul)
            diffs = [a for a in args if isinstance(a, DiffOperator)]
            if len(diffs) == 0:
                return mul * argument if apply_to_constants else mul
            rest = [a for a in args if not isinstance(a, DiffOperator)]
            diffs.sort(key=_default_diff_sort_key)
            result = argument
            for d in reversed(diffs):
                result = Diff(result, target=d.target, superscript=d.superscript)
            return prod(rest) * result

        expr = expr.expand()
        if expr.func in (sp.Mul, sp.Pow) or isinstance(expr, DiffOperator):
            # a lone operator is a product with a single factor - it must be applied, not treated as a constant
            return handle_mul(expr)
        elif expr.func == sp.Add:
            return expr.func(*[handle_mul(a) for a in expr.args])
        else:
            return expr * argument if apply_to_constants else expr

191# ----------------------------------------------------------------------------------------------------------------------

def diff(expr, *args):
    """Shortcut function to create nested derivatives

    The first index becomes the outermost derivative, the last index the innermost one.
    Without indices the expression is returned unchanged.

    >>> f = sp.Symbol("f")
    >>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
    True
    """
    nested = expr
    for target in args[::-1]:
        nested = Diff(nested, target)
    return nested

def diff_args(expr):
    """Extracts the indices and argument of possibly nested derivative - inverse of diff function

    Returns ``(argument, outer_target, ..., inner_target)``; for a non-Diff input simply ``(expr,)``.

    >>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
    >>> e = diff(*args)
    >>> assert diff_args(e) == args
    """
    targets = []
    node = expr
    # walk from the outermost derivative inwards, recording each target on the way
    while isinstance(node, Diff):
        targets.append(node.args[1])
        node = node.args[0]
    return (node, *targets)

def diff_terms(expr):
    """Returns set of all derivatives in an expression.

    This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
    since this function only returns the outer derivatives

    Example:
        >>> x, y = sp.symbols("x, y")
        >>> diff_terms( diff(x, 0, 0) )
        {Diff(Diff(x, 0, -1), 0, -1)}
        >>> diff_terms( diff(x, 0, 0) + y )
        {Diff(Diff(x, 0, -1), 0, -1)}
    """
    result = set()

    def visit(e):
        if isinstance(e, Diff):
            # record only the outermost derivative; its (possibly nested) argument is not searched
            result.add(e)
        else:
            for a in e.args:
                visit(a)

    visit(expr)
    return result

def collect_diffs(expr):
    """Rewrites expression into a sum of distinct derivatives with pre-factors

    The outermost derivative terms of the expression (see diff_terms) are used as collection symbols.
    """
    outer_derivatives = diff_terms(expr)
    return expr.collect(outer_derivatives)

def zero_diffs(expr, label):
    """Replaces all differentials with the given target by 0

    Nested derivatives are searched as well; a derivative with the matching target is replaced
    as a whole, including everything inside it.

    Example:
        >>> x, y, f = sp.symbols("x y f")
        >>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
        >>> zero_diffs(expression, x)
        Diff(f, y, -1) + 7
    """

    def replace(node):
        if isinstance(node, Diff) and node.target == label:
            return 0
        rebuilt_args = [replace(child) for child in node.args]
        if not rebuilt_args:
            return node
        return node.func(*rebuilt_args)

    return replace(expr)

def evaluate_diffs(expr, var=None):
    """Replaces pystencils diff objects by sympy diff objects and evaluates them.

    Replaces Diff nodes by sp.diff , the free variable is either the target (if var=None) otherwise
    the specified var

    Args:
        expr: expression possibly containing (nested) Diff nodes
        var: if given, every Diff is evaluated w.r.t. this variable; if None, each Diff node is
             evaluated w.r.t. its own target
    """
    if isinstance(expr, Diff):
        # Resolve the variable per node and pass the *original* 'var' down: with var=None each nested
        # Diff has to use its own target, not the target of the enclosing derivative.
        free_var = expr.target if var is None else var
        return sp.diff(evaluate_diffs(expr.arg, var), free_var)
    else:
        new_args = [evaluate_diffs(arg, var) for arg in expr.args]
        return expr.func(*new_args) if new_args else expr

def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
    by the sorting key 'sort_key' such that the derivative terms can be further simplified

    The expression is first expanded and made linear in its derivatives (expand_diff_linear with the
    given 'functions' / 'constants'); then every chain of nested Diffs is rebuilt with its nodes sorted
    by 'sort_key' (the first node in sorted order becomes the outermost derivative).
    """

    def reorder(node):
        if not isinstance(node, Diff):
            children = [reorder(child) for child in node.args]
            return node.func(*children) if children else node

        # collect the whole chain Diff(Diff(...Diff(inner)...)) from the outside in
        chain = [node]
        while isinstance(chain[-1].arg, Diff):
            chain.append(chain[-1].arg)

        rebuilt = reorder(chain[-1].arg)
        for d in reversed(sorted(chain, key=sort_key)):
            rebuilt = Diff(rebuilt, target=d.target, superscript=d.superscript)
        return rebuilt

    linearized = expand_diff_linear(expression.expand(), functions, constants)
    return reorder(linearized.expand())

def expand_diff_full(expr, functions=None, constants=None):
    """Expands all derivatives by applying linearity and the product rule.

    Args:
        expr: expression, sp.Tuple or sp.Matrix (processed element-wise) containing Diff nodes
        functions: symbols considered functions, i.e. they can not be pulled before the derivative.
                   If None, all symbols of 'expr' are viewed as functions.
        constants: symbols which are considered constants; they are removed from 'functions'
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
    if constants is not None:
        # build a new set instead of mutating a caller-supplied collection (also works for lists)
        functions = set(functions).difference(constants)

    def visit(e):
        if not isinstance(e, sp.Tuple):
            e = e.expand()

        if e.func == Diff:
            result = 0
            diff_kwargs = {'target': e.target, 'superscript': e.superscript}
            diff_inner = visit(e.args[0])
            if diff_inner.func not in (sp.Add, sp.Mul):
                return e
            for term in diff_inner.args if diff_inner.func == sp.Add else [diff_inner]:
                independent_terms = 1
                dependent_terms = []
                for factor in normalize_product(term):
                    if factor in functions or isinstance(factor, Diff):
                        dependent_terms.append(factor)
                    else:
                        independent_terms *= factor
                # product rule: differentiate one dependent factor at a time, keep the others as pre-factor
                for i, dependent_term in enumerate(dependent_terms):
                    other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
                    processed_diff = normalize_diff_order(Diff(dependent_term, **diff_kwargs))
                    result += independent_terms * prod(other_dependent_terms) * processed_diff
            return result
        elif isinstance(e, sp.Piecewise):
            return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
        elif isinstance(e, sp.Tuple):
            # check the visited node, not the top-level 'expr': otherwise every sub-expression of a
            # Tuple input would be rebuilt as a Tuple
            new_args = [visit(arg) for arg in e.args]
            return sp.Tuple(*new_args)
        else:
            new_args = [visit(arg) for arg in e.args]
            return e.func(*new_args) if new_args else e

    if isinstance(expr, sp.Matrix):
        return expr.applyfunc(visit)
    else:
        return visit(expr)

def expand_diff_linear(expr, functions=None, constants=None):
    """Expands all derivative nodes by applying Diff.split_linear

    Args:
        expr: expression containing derivatives
        functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
                   if None, all symbols are viewed as functions
        constants: sequence of symbols which are considered constants and can be pulled before the derivative;
                   they are removed from 'functions'
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
    if constants is not None:
        # build a new set instead of mutating a caller-supplied collection (also works for lists)
        functions = set(functions).difference(constants)

    if isinstance(expr, Diff):
        # 'functions' already excludes the constants, so they need not be passed down again
        arg = expand_diff_linear(expr.arg, functions)
        if hasattr(arg, 'func') and arg.func == sp.Add:
            result = 0
            for a in arg.args:
                result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
            return result
        else:
            derivative = Diff(arg, target=expr.target, superscript=expr.superscript)
            if derivative == 0:
                return 0
            else:
                return derivative.split_linear(functions)
    elif isinstance(expr, sp.Piecewise):
        return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args))
    elif isinstance(expr, sp.Tuple):
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        return sp.Tuple(*new_args)
    else:
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        result = sp.expand(expr.func(*new_args) if new_args else expr)
        return result

def expand_diff_products(expr):
    """Fully expands all derivatives by applying product rule

    Derivatives of sums are distributed over the summands, and each resulting derivative is expanded again.
    Derivatives of products (including powers with positive integer exponent, see normalize_product) are
    expanded with the product rule.
    """
    if isinstance(expr, Diff):
        arg = expand_diff_products(expr.args[0])

        if arg.func == sp.Add:
            # linearity: D(a + b) = D(a) + D(b); each summand may itself be a product
            return sp.Add(*[expand_diff_products(Diff(e, target=expr.target, superscript=expr.superscript))
                            for e in arg.args])

        if arg.func not in (sp.Mul, sp.Pow):
            return Diff(arg, target=expr.target, superscript=expr.superscript)
        else:
            prod_list = normalize_product(arg)
            result = 0
            for i, factor in enumerate(prod_list):
                pre_factor = prod(prod_list[:i] + prod_list[i + 1:])
                result += pre_factor * Diff(factor, target=expr.target, superscript=expr.superscript)
            return result
    else:
        new_args = [expand_diff_products(e) for e in expr.args]
        return expr.func(*new_args) if new_args else expr

def combine_diff_products(expr):
    """Inverse product rule

    Tries to rewrite sums such as ``a * Diff(b) + b * Diff(a)`` back into ``Diff(a * b)``. Sums are
    grouped by (target, superscript) of their derivatives; other node types are processed argument-wise.
    """

    def expr_to_diff_decomposition(expression):
        """Decomposes a sp.Add node containing Diff nodes into:
        diff_dict: maps (target, superscript) -> [ (pre_factor, argument), ... ]
        i.e. a partial(b) ( a is pre-factor, b is argument)
            in case of partial(a) partial(b) two entries are created  (0.5 partial(a), b), (0.5 partial(b), a)
        and 'rest', the sum of all terms without derivatives.
        """
        DiffInfo = namedtuple("DiffInfo", ["target", "superscript"])

        class DiffSplit:
            def __init__(self, fac, argument):
                self.pre_factor = fac
                self.argument = argument

            def __repr__(self):
                return str((self.pre_factor, self.argument))

        diff_dict = defaultdict(list)
        rest = 0
        for term in expression.args:
            if isinstance(term, Diff):
                diff_dict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
            else:
                mul_args = normalize_product(term)
                diffs = [d for d in mul_args if isinstance(d, Diff)]
                factor = prod(d for d in mul_args if not isinstance(d, Diff))
                if len(diffs) == 0:
                    rest += factor
                else:
                    # a term with n derivatives is split into n equally weighted entries
                    for i, diff_term in enumerate(diffs):
                        all_but_current = [d for j, d in enumerate(diffs) if i != j]
                        pre_factor = factor * prod(all_but_current) * sp.Rational(1, len(diffs))
                        diff_dict[DiffInfo(diff_term.target, diff_term.superscript)].append(
                            DiffSplit(pre_factor, diff_term.arg))

        return diff_dict, rest

    def match_diff_splits(own, other):
        # Returns the factor left over for own's derivative after combining, or None if the
        # combination would make the pre-factors more complicated (more operations).
        own_fac = own.pre_factor / other.argument
        other_fac = other.pre_factor / own.argument
        count = sp.count_ops
        if count(own_fac) > count(own.pre_factor) or count(other_fac) > count(other.pre_factor):
            return None

        new_other_factor = own_fac - other_fac
        return new_other_factor

    def process_diff_list(diff_list, label, superscript):
        # Note: mutates diff_list (entries of the local decomposition) while combining.
        if len(diff_list) == 0:
            return 0
        elif len(diff_list) == 1:
            return diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)

        result = 0
        matches = []
        for i in range(1, len(diff_list)):
            match_result = match_diff_splits(diff_list[i], diff_list[0])
            if match_result is not None:
                matches.append((i, match_result))

        if len(matches) == 0:
            result += diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)
        else:
            # pick the partner whose leftover factor is simplest
            other_idx, match_result = sorted(matches, key=lambda e: sp.count_ops(e[1]))[0]
            new_argument = diff_list[0].argument * diff_list[other_idx].argument
            result += (diff_list[0].pre_factor / diff_list[other_idx].argument) * Diff(new_argument, label, superscript)
            if match_result == 0:
                del diff_list[other_idx]
            else:
                diff_list[other_idx].pre_factor = match_result * diff_list[0].argument
        result += process_diff_list(diff_list[1:], label, superscript)
        return result

    def combine(expression):
        expression = expression.expand()
        if expression.func == sp.Add:
            diff_dict, rest = expr_to_diff_decomposition(expression)
            for (label, superscript), diff_list in diff_dict.items():
                rest += process_diff_list(diff_list, label, superscript)
            return rest
        else:
            new_args = [combine_diff_products(e) for e in expression.args]
            return expression.func(*new_args) if new_args else expression

    return combine(expr)

def replace_generic_laplacian(expr, dim=None):
    """Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.

    This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
    For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
    dimension can be determined, or 'dim' has to be passed explicitly.

    >>> l = Diff(Diff(sp.symbols('x')))
    >>> replace_generic_laplacian(l, 3)
    Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)

    """
    if isinstance(expr, Diff):
        arg, *indices = diff_args(expr)
        if len(indices) == 2 and all(i == -1 for i in indices):
            if isinstance(arg, Field.Access):
                dim = arg.field.spatial_dimensions
            # the dimension is only required when a generic Laplacian is actually expanded
            assert dim is not None, "Spatial dimension unknown: pass 'dim' or use field accesses"
            return sum(diff(arg, i, i) for i in range(dim))
        else:
            return expr
    else:
        new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
        return expr.func(*new_args) if new_args else expr

535def functional_derivative(functional, v):

536 r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation

538 .. math ::

540 \frac{\delta F}{\delta v} =

541 \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}

543 - assumes that gradients are represented by Diff() node

544 - Diff(Diff(r)) represents the divergence of r

545 - the constants parameter is a list with symbols not affected by the derivative. This is used for simplification

546 of the derivative terms.

547 """

548 diffs = functional.atoms(Diff)

549 bulk_substitutions = {d: sp.Dummy() for d in diffs}

550 bulk_substitutions_inverse = {v: k for k, v in bulk_substitutions.items()}

551 non_diff_part = functional.subs(bulk_substitutions)

552 partial_f_partial_v = sp.diff(non_diff_part, v).subs(bulk_substitutions_inverse)

555 for diff_obj in diffs:

556 if diff_obj.args[0] != v:

557 continue

558 dummy = sp.Dummy()

559 partial_f_partial_grad_v = functional.subs(diff_obj, dummy).diff(dummy).subs(dummy, diff_obj)