1from typing import List

3import sympy as sp

5from pystencils.assignment import Assignment

6from pystencils.astnodes import Node

7from pystencils.sympyextensions import is_constant

8from pystencils.transformations import generic_visit

class PlaceholderFunction:
    """Marker mixin for the sympy functions generated by ``to_placeholder_function``.

    The class carries no behavior of its own. Generated placeholder classes inherit from
    it so that ``isinstance(e, PlaceholderFunction)`` can recognise them, as done in
    ``remove_placeholder_functions``.
    """

def to_placeholder_function(expr, name):
    """Replace an expression by an opaque sympy function called ``name``.

    - Replacing an expression with just a symbol would lead to problems when calculating
      derivatives: sympy would treat the symbol as independent of every variable.
    - A placeholder function keeps the dependency on each symbol of ``expr``. Its
      derivative with respect to such a symbol is the derivative itself if that is
      constant, and otherwise a new symbol ``_d<name>_d<symbol>`` standing for it.

    The definitions of the placeholder symbol and of its non-constant derivative symbols
    are stored on the generated class as ``subexpressions``.
    ``remove_placeholder_functions`` extracts them.

    Args:
        expr: sympy expression to hide behind the placeholder.
        name: name of the generated function class and of the symbol (``value``) that
            replaces the placeholder when it is removed.

    Returns:
        The generated function applied to the symbols of ``expr``, sorted by name.

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]

    """
    # Sort by name so argument order and derivative-symbol order are deterministic.
    symbols = sorted(expr.atoms(sp.Symbol), key=lambda e: e.name)
    derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
    derivatives = [sp.diff(expr, s) for s in symbols]

    assignments = [Assignment(sp.Symbol(name), expr)]
    # Constant derivatives are substituted directly in fdiff, so they need no definition.
    assignments += [Assignment(symbol, derivative)
                    for symbol, derivative in zip(derivative_symbols, derivatives)
                    if not is_constant(derivative)]

    def fdiff(_, argindex=1):
        # sympy passes the 1-based index of the argument to differentiate by; the
        # parameter name and default follow sympy's ``Function.fdiff(self, argindex=1)``.
        result = derivatives[argindex - 1]
        return result if is_constant(result) else derivative_symbols[argindex - 1]

    func = type(name, (sp.Function, PlaceholderFunction),
                {'fdiff': fdiff,
                 'value': sp.Symbol(name),
                 'subexpressions': assignments,
                 'nargs': len(symbols)})
    return func(*symbols)

def remove_placeholder_functions(expr):
    """Replace every placeholder function in ``expr`` by its ``value`` symbol.

    Args:
        expr: expression, or any structure accepted by ``generic_visit`` (for example a
            list of assignments), that may contain placeholder functions created by
            ``to_placeholder_function``.

    Returns:
        Tuple ``(new_expr, subexpressions)``:

        - ``new_expr`` is ``expr`` with each placeholder replaced by its symbol.
        - ``subexpressions`` is the list of assignments defining those symbols and their
          derivative symbols, without duplicate left-hand sides.

        Placeholders referenced inside another placeholder's definitions are resolved as
        well. Their definitions are listed before the assignments that use them.
    """
    subexpressions = []
    seen_lhs = set()  # left-hand sides already collected; O(1) duplicate check

    def visit(e):
        if isinstance(e, Node):
            # AST nodes are passed through untouched.
            return e
        if isinstance(e, PlaceholderFunction):
            for se in e.subexpressions:
                if se.lhs in seen_lhs:
                    continue
                seen_lhs.add(se.lhs)
                # A definition may itself contain placeholders. Resolving them here appends
                # their definitions first, which keeps the list in dependency order.
                new_rhs = visit(se.rhs)
                subexpressions.append(se if new_rhs == se.rhs else Assignment(se.lhs, new_rhs))
            return e.value
        new_args = [visit(a) for a in e.args]
        return e.func(*new_args) if new_args else e

    return generic_visit(expr, visit), subexpressions

def prepend_placeholder_functions(assignments: List[Assignment]):
    """Strip placeholder functions from ``assignments`` and put their definitions in front.

    Returns:
        The subexpressions collected by ``remove_placeholder_functions``, followed by
        ``assignments`` with every placeholder replaced by its symbol.
    """
    stripped, definitions = remove_placeholder_functions(assignments)
    ordered = definitions + stripped
    return ordered