1from collections import defaultdict, namedtuple 

2 

3import sympy as sp 

4 

5from pystencils.field import Field 

6from pystencils.sympyextensions import normalize_product, prod 

7 

8 

9def _default_diff_sort_key(d): 

10 return str(d.superscript), str(d.target) 

11 

12 

class Diff(sp.Expr):
    """Sympy Node representing a derivative.

    The difference to sympy's built in differential is:
        - shortened latex representation
        - all simplifications have to be done manually
        - optional marker displayed as superscript

    Args:
        argument: expression the derivative acts on. A ``Field`` is replaced by its center access and plain
                  Python numbers are sympified. An argument equal to 0 yields ``0`` instead of a Diff node.
        target: subscript, usually the variable the derivative is taken w.r.t.; -1 means no subscript
        superscript: optional marker displayed as superscript; only displayed if >= 0
    """
    is_number = False
    is_Rational = False
    _diff_wrt = True

    def __new__(cls, argument, target=-1, superscript=-1):
        if argument == 0:
            return sp.Rational(0, 1)
        if isinstance(argument, Field):
            argument = argument.center
        # sympify so that plain Python numbers (e.g. Diff(2)) provide an ``expand`` method
        argument = sp.sympify(argument)
        return sp.Expr.__new__(cls, argument.expand(), sp.sympify(target), sp.sympify(superscript))

    @property
    def is_commutative(self):
        """False if any symbol inside this node is non-commutative, True otherwise"""
        any_non_commutative = any(not s.is_commutative for s in self.atoms(sp.Symbol))
        if any_non_commutative:
            return False
        else:
            return True

    def get_arg_recursive(self):
        """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned"""
        if not isinstance(self.arg, Diff):
            return self.arg
        else:
            return self.arg.get_arg_recursive()

    def change_arg_recursive(self, new_arg):
        """Returns a Diff node with the given 'new_arg' instead of the current argument. For nested derivatives
        a new nested derivative is returned where the inner Diff has the 'new_arg'"""
        if not isinstance(self.arg, Diff):
            return Diff(new_arg, self.target, self.superscript)
        else:
            return Diff(self.arg.change_arg_recursive(new_arg), self.target, self.superscript)

    def split_linear(self, functions):
        """
        Applies linearity property of Diff: i.e. 'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)'
        The parameter functions is a list of all symbols that are considered functions, not constants.
        For the example above: functions=[a, b]

        Returns 0 if the non-constant part is a number or a symbol that is not in 'functions'.
        """
        constant, variable = 1, 1

        if self.arg.func != sp.Mul:
            constant, variable = 1, self.arg
        else:
            # factors that are functions (or derivatives) stay inside, everything else is pulled out
            for factor in normalize_product(self.arg):
                if factor in functions or isinstance(factor, Diff):
                    variable *= factor
                else:
                    constant *= factor

        if isinstance(variable, sp.Symbol) and variable not in functions:
            return 0

        if isinstance(variable, int) or variable.is_number:
            return 0
        else:
            return constant * Diff(variable, target=self.target, superscript=self.superscript)

    @property
    def arg(self):
        """Expression the derivative acts on"""
        return self.args[0]

    @property
    def target(self):
        """Subscript, usually the variable the Diff is w.r.t. """
        return self.args[1]

    @property
    def superscript(self):
        """Superscript, for example used as the Chapman-Enskog order index"""
        return self.args[2]

    def _latex(self, printer, *_):
        """LaTeX form: superscript shown if >= 0, subscript if target != -1, composite arguments in parentheses"""
        result = r"{\partial"
        if self.superscript >= 0:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            result += "_{%s}" % (printer.doprint(self.target),)

        contents = printer.doprint(self.arg)
        if isinstance(self.arg, int) or isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff:
            result += " " + contents
        else:
            result += " (" + contents + ") "

        result += "}"
        return result

    def __str__(self):
        return f"D({self.arg})"

    def interpolated_access(self, offset, **kwargs):
        """Represents an interpolated access on a spatially differentiated field

        Args:
            offset (Tuple[sympy.Expr]): Absolute position to determine the value of the spatial derivative
        """
        from pystencils.interpolation_astnodes import DiffInterpolatorAccess
        assert isinstance(self.arg.field, Field), "Must be field to enable interpolated accesses"
        return DiffInterpolatorAccess(self.arg.field.interpolated_access(offset, **kwargs).symbol, self.target, *offset)

123 

124 

class DiffOperator(sp.Expr):
    """Un-applied differential, i.e. differential operator

    Args:
        target: the differential is w.r.t to this variable.
                This target is mainly for display purposes (it's the subscript) and to distinguish DiffOperators
                If the target is '-1' no subscript is displayed
        superscript: optional marker displayed as superscript
                     is not displayed if set to '-1'

    The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the
    DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's
    """
    is_commutative = True
    is_number = False
    is_Rational = False

    def __new__(cls, target=-1, superscript=-1):
        return sp.Expr.__new__(cls, sp.sympify(target), sp.sympify(superscript))

    @property
    def target(self):
        """Subscript, usually the variable the operator differentiates w.r.t."""
        return self.args[0]

    @property
    def superscript(self):
        """Optional superscript marker"""
        return self.args[1]

    def _latex(self, printer=None, *_):
        """LaTeX form; the target is rendered with the active printer (like Diff._latex) when one is given"""
        result = r"{\partial"
        if self.superscript >= 0:
            result += "^{(%s)}" % (self.superscript,)
        if self.target != -1:
            target_str = printer.doprint(self.target) if printer is not None else str(self.target)
            result += "_{%s}" % (target_str,)
        result += "}"
        return result

    @staticmethod
    def apply(expr, argument, apply_to_constants=True):
        """
        Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node.
        Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
        i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), x)

        Args:
            expr: expression built from DiffOperators (sums, products, powers, or a single operator)
            argument: expression the operators are applied to
            apply_to_constants: if True, summands without any DiffOperator are multiplied by 'argument',
                                otherwise they are left unchanged
        """

        def handle_mul(mul):
            args = normalize_product(mul)
            diffs = [a for a in args if isinstance(a, DiffOperator)]
            if len(diffs) == 0:
                return mul * argument if apply_to_constants else mul
            rest = [a for a in args if not isinstance(a, DiffOperator)]
            diffs.sort(key=_default_diff_sort_key)
            result = argument
            for d in reversed(diffs):
                result = Diff(result, target=d.target, superscript=d.superscript)
            return prod(rest) * result

        expr = sp.sympify(expr).expand()
        if expr.func == sp.Add:
            return sp.Add(*[handle_mul(a) for a in expr.args])
        # products, powers, constants AND a bare DiffOperator are all handled by handle_mul;
        # a lone DiffOperator must be turned into a Diff, not treated as a constant factor
        return handle_mul(expr)

189 

190 

191# ---------------------------------------------------------------------------------------------------------------------- 

192 

193 

def diff(expr, *args):
    """Shortcut function to create nested derivatives

    The first index becomes the outermost derivative and the last index the innermost one.
    Without any index, ``expr`` is returned unchanged.

    >>> f = sp.Symbol("f")
    >>> diff(f, 0, 0, 1) == Diff(Diff( Diff(f, 1), 0), 0)
    True
    """
    nested = expr
    for target in args[::-1]:
        nested = Diff(nested, target)
    return nested

207 

208 

def diff_args(expr):
    """Extracts the indices and argument of possibly nested derivative - inverse of diff function

    Returns a tuple ``(innermost_argument, outer_target, ..., inner_target)``.

    >>> args = (sp.Symbol("x"), 0, 1, 2, 5, 1)
    >>> e = diff(*args)
    >>> assert diff_args(e) == args
    """
    targets = []
    current = expr
    while isinstance(current, Diff):
        targets.append(current.args[1])
        current = current.args[0]
    return (current, *targets)

221 

222 

def diff_terms(expr):
    """Returns set of all derivatives in an expression.

    This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
    since this function only returns the outer derivatives

    Example:
        >>> x, y = sp.symbols("x, y")
        >>> diff_terms( diff(x, 0, 0) )
        {Diff(Diff(x, 0, -1), 0, -1)}
        >>> diff_terms( diff(x, 0, 0) + y )
        {Diff(Diff(x, 0, -1), 0, -1)}
    """
    found = set()
    pending = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, Diff):
            # outermost derivative: record it and do not descend into it
            found.add(node)
        else:
            pending.extend(node.args)
    return found

247 

248 

def collect_diffs(expr):
    """Rewrites expression into a sum of distinct derivatives with pre-factors

    The outer derivative terms are handed to ``collect`` sorted by ``sp.default_sort_key``. The order of the
    collected terms may influence how summands containing several derivatives are grouped, so a fixed order
    keeps the output form independent of set/hash ordering.
    """
    return expr.collect(sorted(diff_terms(expr), key=sp.default_sort_key))

252 

253 

def zero_diffs(expr, label):
    """Replaces every derivative whose target equals 'label' by 0 (including everything it acts on)

    Example:
        >>> x, y, f = sp.symbols("x y f")
        >>> expression = Diff(f, x) + Diff(f, y) + Diff(Diff(f, y), x) + 7
        >>> zero_diffs(expression, x)
        Diff(f, y, -1) + 7
    """

    def strip(node):
        if isinstance(node, Diff) and node.target == label:
            return 0
        children = [strip(child) for child in node.args]
        if not children:
            return node
        return node.func(*children)

    return strip(expr)

272 

273 

def evaluate_diffs(expr, var=None):
    """Replaces pystencils diff objects by sympy diff objects and evaluates them.

    Every Diff node is replaced by ``sp.diff``. If ``var`` is None, each node is differentiated w.r.t. its own
    target (so ``Diff(Diff(f, x), y)`` becomes d^2 f / (dx dy)); otherwise all nodes are differentiated
    w.r.t. ``var``.

    Args:
        expr: expression containing Diff nodes
        var: optional variable used for all derivatives instead of the node targets
    """
    if isinstance(expr, Diff):
        wrt = expr.target if var is None else var
        # pass the caller's 'var' down (not this node's target) so nested Diffs keep their own targets
        return sp.diff(evaluate_diffs(expr.arg, var), wrt)
    new_args = [evaluate_diffs(arg, var) for arg in expr.args]
    return expr.func(*new_args) if new_args else expr

287 

288 

def normalize_diff_order(expression, functions=None, constants=None, sort_key=_default_diff_sort_key):
    """Brings nested derivatives into a standard order.

    Assumes the order of differentiation can be exchanged. The expression is first expanded with
    ``expand_diff_linear`` (using 'functions' / 'constants'); then every chain of directly nested Diff nodes is
    rebuilt in the order defined by 'sort_key', so that the derivative terms can be further simplified.
    """

    def reorder(node):
        if not isinstance(node, Diff):
            children = [reorder(child) for child in node.args]
            if not children:
                return node
            return node.func(*children)

        # gather the chain of directly nested Diff nodes, outermost first
        chain = [node]
        innermost_arg = node.arg
        while isinstance(innermost_arg, Diff):
            chain.append(innermost_arg)
            innermost_arg = innermost_arg.arg

        rebuilt = reorder(innermost_arg)
        for d in reversed(sorted(chain, key=sort_key)):
            rebuilt = Diff(rebuilt, target=d.target, superscript=d.superscript)
        return rebuilt

    expanded = expand_diff_linear(expression.expand(), functions, constants).expand()
    return reorder(expanded)

312 

313 

def expand_diff_full(expr, functions=None, constants=None):
    """Fully expands derivatives by applying linearity and the product rule.

    Args:
        expr: expression containing Diff nodes; a sympy matrix (mutable or immutable) is processed element-wise
        functions: symbols considered functions, i.e. not pulled in front of the derivative.
                   If None, all symbols of 'expr' are treated as functions.
        constants: symbols removed from the default 'functions' set (only used when 'functions' is None)

    Returns:
        expression (or matrix) with expanded derivatives
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    def visit(e):
        if not isinstance(e, sp.Tuple):
            e = e.expand()

        if e.func == Diff:
            result = 0
            diff_kwargs = {'target': e.target, 'superscript': e.superscript}
            diff_inner = visit(e.args[0])
            if diff_inner.func not in (sp.Add, sp.Mul):
                return e
            for term in diff_inner.args if diff_inner.func == sp.Add else [diff_inner]:
                independent_terms = 1
                dependent_terms = []
                for factor in normalize_product(term):
                    if factor in functions or isinstance(factor, Diff):
                        dependent_terms.append(factor)
                    else:
                        independent_terms *= factor
                # product rule: differentiate one dependent factor at a time, keep the others as pre-factor
                for i, dependent_term in enumerate(dependent_terms):
                    other_dependent_terms = dependent_terms[:i] + dependent_terms[i + 1:]
                    processed_diff = normalize_diff_order(Diff(dependent_term, **diff_kwargs))
                    result += independent_terms * prod(other_dependent_terms) * processed_diff
            return result
        elif isinstance(e, sp.Piecewise):
            return sp.Piecewise(*((expand_diff_full(a, functions, constants), b) for a, b in e.args))
        elif isinstance(e, sp.Tuple):
            # check the visited node 'e', not the top-level 'expr', otherwise sub-expressions become Tuples
            new_args = [visit(arg) for arg in e.args]
            return sp.Tuple(*new_args)
        else:
            new_args = [visit(arg) for arg in e.args]
            return e.func(*new_args) if new_args else e

    if isinstance(expr, sp.MatrixBase):
        return expr.applyfunc(visit)
    else:
        return visit(expr)

358 

359 

def expand_diff_linear(expr, functions=None, constants=None):
    """Expands all derivative nodes by applying Diff.split_linear

    Args:
        expr: expression containing derivatives; a sympy matrix (mutable or immutable) is processed element-wise
        functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
                   if None, all symbols are viewed as functions
        constants: sequence of symbols which are considered constants and can be pulled before the derivative;
                   only taken into account when 'functions' is None

    Returns:
        expression (or matrix) where derivatives of sums are split and constant factors are pulled out
    """
    if functions is None:
        functions = expr.atoms(sp.Symbol)
        if constants is not None:
            functions.difference_update(constants)

    if isinstance(expr, sp.MatrixBase):
        return expr.applyfunc(lambda element: expand_diff_linear(element, functions, constants))

    if isinstance(expr, Diff):
        arg = expand_diff_linear(expr.arg, functions)
        if hasattr(arg, 'func') and arg.func == sp.Add:
            result = 0
            for a in arg.args:
                result += Diff(a, target=expr.target, superscript=expr.superscript).split_linear(functions)
            return result
        else:
            diff_node = Diff(arg, target=expr.target, superscript=expr.superscript)
            # Diff(...) collapses to 0 for a zero argument, which has no split_linear
            if diff_node == 0:
                return 0
            else:
                return diff_node.split_linear(functions)
    elif isinstance(expr, sp.Piecewise):
        return sp.Piecewise(*((expand_diff_linear(a, functions, constants), b) for a, b in expr.args))
    elif isinstance(expr, sp.Tuple):
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        return sp.Tuple(*new_args)
    else:
        new_args = [expand_diff_linear(e, functions) for e in expr.args]
        result = sp.expand(expr.func(*new_args) if new_args else expr)
        return result

396 

397 

def expand_diff_products(expr):
    """Fully expands all derivatives by applying product rule

    Derivatives of sums are split by linearity and every summand is expanded again, so e.g.
    ``Diff(a*b + c)`` becomes ``b*Diff(a) + a*Diff(b) + Diff(c)``. Numbers are treated as constants:
    a numeric argument yields 0 and numeric factors of a product contribute no term.

    Args:
        expr: sympy expression possibly containing Diff nodes

    Returns:
        expression in which the product rule has been applied to all Diff nodes
    """
    if not isinstance(expr, Diff):
        new_args = [expand_diff_products(e) for e in expr.args]
        return expr.func(*new_args) if new_args else expr

    diff_kwargs = {'target': expr.target, 'superscript': expr.superscript}
    arg = expand_diff_products(expr.args[0])
    if arg.func == sp.Add:
        # linearity; each summand may itself be a product that needs the product rule
        return sp.Add(*[expand_diff_products(Diff(term, **diff_kwargs)) for term in arg.args])
    if arg.is_number:
        return sp.Integer(0)
    if arg.func not in (sp.Mul, sp.Pow):
        return Diff(arg, **diff_kwargs)

    prod_list = normalize_product(arg)
    result = 0
    for i, factor in enumerate(prod_list):
        if factor.is_number:
            continue  # the derivative of a constant factor vanishes
        pre_factor = prod(other for j, other in enumerate(prod_list) if j != i)
        result += pre_factor * Diff(factor, **diff_kwargs)
    return result

418 

419 

def combine_diff_products(expr):
    """Inverse product rule

    Tries to merge summands like ``b * Diff(a) + a * Diff(b)`` back into ``Diff(a * b)``. Only derivatives that
    share the same (target, superscript) are merged, and pairing is heuristic (see ``match_diff_splits``).
    Expressions that are not sums after expansion are processed argument by argument.

    Args:
        expr: sympy expression containing Diff nodes

    Returns:
        equivalent expression with derivative products combined where the heuristic finds a match
    """

    def expr_to_diff_decomposition(expression):
        """Decomposes a sp.Add node containing Diff nodes into:
        diff_dict: maps (target, superscript) -> [ (pre_factor, argument), ... ]
        i.e.  a partial(b) ( a is pre-factor, b is argument)
            in case of partial(a) partial(b) two entries are created  (0.5 partial(a), b), (0.5 partial(b), a)
        rest: sum of all summands that contain no Diff factor

        Returns:
            tuple (diff_dict, rest)
        """
        DiffInfo = namedtuple("DiffInfo", ["target", "superscript"])

        class DiffSplit:
            # one candidate term pre_factor * Diff(argument); pre_factor may be rewritten by process_diff_list
            def __init__(self, fac, argument):
                self.pre_factor = fac
                self.argument = argument

            def __repr__(self):
                return str((self.pre_factor, self.argument))

        assert isinstance(expression, sp.Add)
        diff_dict = defaultdict(list)
        rest = 0
        for term in expression.args:
            if isinstance(term, Diff):
                diff_dict[DiffInfo(term.target, term.superscript)].append(DiffSplit(1, term.arg))
            else:
                mul_args = normalize_product(term)
                diffs = [d for d in mul_args if isinstance(d, Diff)]
                factor = prod(d for d in mul_args if not isinstance(d, Diff))
                if len(diffs) == 0:
                    # no derivative in this summand: 'factor' is the whole summand
                    rest += factor
                else:
                    # a summand with k Diff factors is split into k entries, each weighted with 1/k
                    for i, diff in enumerate(diffs):
                        all_but_current = [d for j, d in enumerate(diffs) if i != j]
                        pre_factor = factor * prod(all_but_current) * sp.Rational(1, len(diffs))
                        diff_dict[DiffInfo(diff.target, diff.superscript)].append(DiffSplit(pre_factor, diff.arg))

        return diff_dict, rest

    def match_diff_splits(own, other):
        # Checks whether own = p_own * Diff(a_own) and other = p_other * Diff(a_other) can be merged into
        # (p_other / a_own) * Diff(a_other * a_own).
        # Heuristic: if dividing a pre-factor by the partner's argument increases its operation count, the
        # division did not cancel and None (no match) is returned.
        # Otherwise returns p_own / a_other - p_other / a_own; multiplied by a_other this is the part of
        # own's pre-factor that is not covered by the merged derivative.
        own_fac = own.pre_factor / other.argument
        other_fac = other.pre_factor / own.argument
        count = sp.count_ops
        if count(own_fac) > count(own.pre_factor) or count(other_fac) > count(other.pre_factor):
            return None

        new_other_factor = own_fac - other_fac
        return new_other_factor

    def process_diff_list(diff_list, label, superscript):
        # Turns a list of DiffSplit entries (all with the same target/superscript) back into Diff terms:
        # the first entry is merged with its best-matching partner (if any), the rest is handled recursively.
        # NOTE: entries of diff_list are mutated / deleted in place.
        if len(diff_list) == 0:
            return 0
        elif len(diff_list) == 1:
            return diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)

        result = 0
        matches = []
        for i in range(1, len(diff_list)):
            match_result = match_diff_splits(diff_list[i], diff_list[0])
            if match_result is not None:
                matches.append((i, match_result))

        if len(matches) == 0:
            result += diff_list[0].pre_factor * Diff(diff_list[0].argument, label, superscript)
        else:
            # choose the partner whose leftover factor has the fewest operations
            other_idx, match_result = sorted(matches, key=lambda e: sp.count_ops(e[1]))[0]
            new_argument = diff_list[0].argument * diff_list[other_idx].argument
            result += (diff_list[0].pre_factor / diff_list[other_idx].argument) * Diff(new_argument, label, superscript)
            if match_result == 0:
                # partner term is completely absorbed by the merged derivative
                del diff_list[other_idx]
            else:
                # keep only the part of the partner term that the merged derivative does not cover
                diff_list[other_idx].pre_factor = match_result * diff_list[0].argument
        result += process_diff_list(diff_list[1:], label, superscript)
        return result

    def combine(expression):
        expression = expression.expand()
        if isinstance(expression, sp.Add):
            diff_dict, rest = expr_to_diff_decomposition(expression)
            for (label, superscript), diff_list in diff_dict.items():
                rest += process_diff_list(diff_list, label, superscript)
            return rest
        else:
            new_args = [combine_diff_products(e) for e in expression.args]
            return expression.func(*new_args) if new_args else expression

    return combine(expr)

507 

508 

def replace_generic_laplacian(expr, dim=None):
    """Laplacian can be written as Diff(Diff(term)) without explicitly giving the dimensions.

    This function replaces these constructs by diff(term, 0, 0) + diff(term, 1, 1) + ...
    For this to work, the arguments of the derivative have to be field or field accesses such that the spatial
    dimension can be determined, or 'dim' has to be given. The dimension is only required when a generic
    Laplacian is actually found; other derivatives are returned unchanged.

    >>> l = Diff(Diff(sp.symbols('x')))
    >>> replace_generic_laplacian(l, 3)
    Diff(Diff(x, 0, -1), 0, -1) + Diff(Diff(x, 1, -1), 1, -1) + Diff(Diff(x, 2, -1), 2, -1)

    """
    if not isinstance(expr, Diff):
        new_args = [replace_generic_laplacian(a, dim) for a in expr.args]
        return expr.func(*new_args) if new_args else expr

    arg, *indices = diff_args(expr)
    is_generic_laplacian = len(indices) == 2 and all(i == -1 for i in indices)
    if not is_generic_laplacian:
        return expr

    if isinstance(arg, Field.Access):
        dim = arg.field.spatial_dimensions
    assert dim is not None, "Cannot determine dimension of generic Laplacian - pass 'dim' or use field accesses"
    return sum(diff(arg, i, i) for i in range(dim))

533 

534 

def functional_derivative(functional, v):
    r"""Computes functional derivative of functional with respect to v using Euler-Lagrange equation

    .. math ::

       \frac{\delta F}{\delta v} =
               \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}

    - assumes that gradients are represented by Diff() node
    - Diff(Diff(r)) represents the divergence of r
    - the bulk term is computed with every Diff node temporarily replaced by an independent dummy symbol
    """
    diff_nodes = functional.atoms(Diff)

    # bulk part dF/dv: hide all derivative nodes behind dummies so that sp.diff treats them as independent
    to_dummy = {}
    for node in diff_nodes:
        to_dummy[node] = sp.Dummy()
    from_dummy = {dummy: node for node, dummy in to_dummy.items()}
    bulk_derivative = sp.diff(functional.subs(to_dummy), v).subs(from_dummy)

    # gradient part: one Diff of dF/d(Diff(v)) for every derivative node acting directly on v
    gradient_part = 0
    for node in diff_nodes:
        if node.args[0] == v:
            placeholder = sp.Dummy()
            d_functional_d_grad = functional.subs(node, placeholder).diff(placeholder).subs(placeholder, node)
            gradient_part += Diff(d_functional_d_grad, target=node.target, superscript=node.superscript)

    return bulk_derivative - gradient_part