1from typing import List 

2 

3import sympy as sp 

4 

5from pystencils.assignment import Assignment 

6from pystencils.astnodes import Node 

7from pystencils.sympyextensions import is_constant 

8from pystencils.transformations import generic_visit 

9 

10 

class PlaceholderFunction:
    """Marker mixin for the sympy function classes generated by ``to_placeholder_function``.

    It carries no behavior of its own; ``remove_placeholder_functions`` relies on ``isinstance`` checks
    against this class to recognize placeholder instances inside an expression tree.
    """

13 

14 

def to_placeholder_function(expr, name):
    """Replaces an expression by a sympy function.

    - replacing an expression with just a symbol would lead to problems when calculating derivatives
    - placeholder functions get rid of this problem

    A new ``sp.Function`` subclass called ``name`` is created and applied to the symbols of ``expr``
    (sorted by name). Differentiating the placeholder with respect to one of these symbols yields the
    derivative itself if it is constant, otherwise the symbol ``_d<name>_d<symbol>``. The definitions of
    ``name`` and of every non-constant derivative symbol are stored on the class as ``subexpressions``
    and can be extracted with ``remove_placeholder_functions``.

    Args:
        expr: sympy expression that should be hidden behind the placeholder
        name: name of the generated function class and of the symbol that replaces it later

    Returns:
        instance of the generated placeholder function, applied to the sorted symbols of ``expr``

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t) # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]

    """
    from sympy.core.function import ArgumentIndexError

    # sort by name so argument order (and thus fdiff indices) is deterministic
    symbols = sorted(expr.atoms(sp.Symbol), key=lambda s: s.name)
    derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
    derivatives = [sp.diff(expr, s) for s in symbols]
    # evaluate is_constant once per derivative; reused by the assignments and by every fdiff call
    constant_flags = [is_constant(d) for d in derivatives]

    assignments = [Assignment(sp.Symbol(name), expr)]
    assignments += [Assignment(symbol, derivative)
                    for symbol, derivative, constant in zip(derivative_symbols, derivatives, constant_flags)
                    if not constant]

    def fdiff(self, argindex=1):
        # sympy uses 1-based argument indices; reject anything out of range (sympy's convention) instead of
        # letting argindex=0 wrap around to the last derivative through negative list indexing
        if not 1 <= argindex <= len(derivatives):
            raise ArgumentIndexError(self, argindex)
        i = argindex - 1
        # constant derivatives are returned directly, all others are represented by their symbol
        return derivatives[i] if constant_flags[i] else derivative_symbols[i]

    func = type(name, (sp.Function, PlaceholderFunction),
                {'fdiff': fdiff,
                 'value': sp.Symbol(name),
                 'subexpressions': assignments,
                 'nargs': len(symbols)})
    return func(*symbols)

55 

56 

def remove_placeholder_functions(expr):
    """Replaces placeholder functions by their value symbols and collects their definitions.

    Every ``PlaceholderFunction`` instance found in ``expr`` is replaced by its ``value`` symbol, and its
    ``subexpressions`` are gathered. If several placeholders define the same left-hand side, the first
    definition encountered is kept. ``Node`` objects are returned unchanged and are not descended into.

    Args:
        expr: object handed to ``generic_visit`` -- a sympy expression, or (as used by
              ``prepend_placeholder_functions``) presumably a list of assignments; TODO confirm which
              containers ``generic_visit`` supports

    Returns:
        tuple ``(new_expr, subexpressions)`` where ``subexpressions`` is the list of collected assignments

    NOTE(review): the right-hand sides of the collected subexpressions are not visited, so placeholders
    nested inside another placeholder's definition are left in place -- verify this cannot happen.
    """
    subexpressions = []
    # left-hand sides already collected; maintained incrementally instead of being rebuilt per check
    defined_lhs = set()

    def visit(e):
        if isinstance(e, Node):
            return e
        elif isinstance(e, PlaceholderFunction):
            for se in e.subexpressions:
                if se.lhs not in defined_lhs:
                    defined_lhs.add(se.lhs)
                    subexpressions.append(se)
            return e.value
        else:
            new_args = [visit(a) for a in e.args]
            # leaves (no args) are returned as-is rather than reconstructed
            return e.func(*new_args) if new_args else e

    return generic_visit(expr, visit), subexpressions

73 

74 

def prepend_placeholder_functions(assignments: List[Assignment]):
    """Resolves placeholder functions in ``assignments`` and puts their definitions first.

    Returns:
        the subexpressions collected by ``remove_placeholder_functions``, followed by the rewritten
        assignments
    """
    rewritten, definitions = remove_placeholder_functions(assignments)
    return definitions + rewritten