1from typing import Tuple

3import sympy as sp

5from pystencils.astnodes import LoopOverCoordinate

6from pystencils.cache import memorycache

7from pystencils.fd import Diff

8from pystencils.field import Field

9from pystencils.transformations import generic_visit

11from .derivation import FiniteDifferenceStencilDerivation

12from .derivative import diff_args

def fd_stencils_standard(indices, dx, fa):
    """Second-order central finite-difference approximation of a derivative.

    Args:
        indices: sequence of spatial coordinate indices; its length is the
            derivative order (only 1 and 2 are supported). Two equal indices
            request a pure second derivative, two different ones a mixed
            derivative.
        dx: grid spacing, used for every direction.
        fa: field access at which the derivative is evaluated; it must offer
            ``neighbor(coordinate, offset)``.

    Returns:
        The difference-quotient expression built from neighbours of ``fa``.

    Raises:
        AssertionError: if any index is negative.
        NotImplementedError: for derivative orders other than 1 and 2.
    """
    derivative_order = len(indices)
    assert all(i >= 0 for i in indices), "Can only discretize objects with (integer) subscripts"

    if derivative_order == 1:
        axis = indices[0]
        forward, backward = fa.neighbor(axis, 1), fa.neighbor(axis, -1)
        return (forward - backward) / (2 * dx)

    if derivative_order == 2:
        first, second = indices[0], indices[1]
        if first == second:
            # Classic 3-point Laplacian-style stencil along a single axis.
            return (-2 * fa + fa.neighbor(first, -1) + fa.neighbor(first, +1)) / (dx ** 2)
        # Mixed derivative: signed sum over the four diagonal corners.
        corner_offsets = ((1, 1), (-1, 1), (1, -1), (-1, -1))
        corner_sum = sum(s1 * s2 * fa.neighbor(first, s1).neighbor(second, s2)
                         for s1, s2 in corner_offsets)
        return corner_sum / (4 * dx ** 2)

    raise NotImplementedError("Supports only derivatives up to order 2")

def fd_stencils_isotropic(indices, dx, fa):
    """Isotropic finite-difference stencils for first and second derivatives.

    For 1D fields this simply delegates to :func:`fd_stencils_standard`.
    For 2D fields, first derivatives and pure second derivatives use 3x3
    stencils that also involve the diagonal neighbours; mixed second
    derivatives fall back to the standard stencil.

    Args:
        indices: sequence of coordinate indices; its length is the derivative order.
        dx: grid spacing.
        fa: field access; ``fa.field.spatial_dimensions`` selects the scheme.

    Returns:
        The finite-difference expression.

    Raises:
        AssertionError: for a 2D first derivative along an axis other than 0 or 1.
        NotImplementedError: for fields that are neither 1D nor 2D, or for 2D
            derivative orders other than 1 and 2.
    """
    spatial_dims = fa.field.spatial_dimensions
    if spatial_dims == 1:
        return fd_stencils_standard(indices, dx, fa)
    if spatial_dims != 2:
        raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")

    derivative_order = len(indices)

    if derivative_order == 1:
        axis = indices[0]
        assert 0 <= axis < 2
        cross_axis = 1 if axis == 0 else 0
        # Weights for the three cells on one side (-1, 0, +1 across the
        # derivative axis); they sum to 1/2, matching a central difference.
        side_weights = ((-1, sp.Rational(1, 12) / dx),
                        (0, sp.Rational(1, 3) / dx),
                        (1, sp.Rational(1, 12) / dx))

        def _weighted_side(step):
            """Weighted sum of the three cells one step along ``axis``."""
            return sum(fa.neighbor(axis, step).neighbor(cross_axis, offset) * weight
                       for offset, weight in side_weights)

        return _weighted_side(+1) - _weighted_side(-1)

    if derivative_order == 2:
        first, second = indices[0], indices[1]
        if first != second:
            return fd_stencils_standard(indices, dx, fa)
        cross_axis = 1 if first == 0 else 0
        corners = sp.Rational(1, 12) * sum(fa.neighbor(0, a).neighbor(1, b) for a in (-1, 1) for b in (-1, 1))
        along = sp.Rational(5, 6) * sum(fa.neighbor(first, s) for s in (-1, 1))
        across = - sp.Rational(1, 6) * sum(fa.neighbor(cross_axis, s) for s in (-1, 1))
        centre = - sp.Rational(5, 3) * fa
        return (corners + along + across + centre) / (dx ** 2)

    raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")

def fd_stencils_forth_order_isotropic(indices, dx, fa):
    """Fourth-order isotropic first-derivative stencil for 2D fields.

    The stencil weights are taken from :func:`forth_order_2d_derivation`,
    which derives one stencil per direction over 2D offsets.

    Args:
        indices: one-element sequence holding the derivative direction (0 or 1).
        dx: grid spacing.
        fa: field access of a field with two spatial dimensions.

    Returns:
        The stencil applied to ``fa``, divided by ``dx``.

    Raises:
        NotImplementedError: if the derivative is not of first order, the
            direction is not 0 or 1, or the field is not two-dimensional.
    """
    if len(indices) != 1:
        raise NotImplementedError("Forth order finite difference discretization is "
                                  "currently only supported for first derivatives")
    direction = indices[0]
    # The derived stencils are built over two-component (i, j) offsets, so they
    # are only meaningful for 2D fields. Checking just the direction would let a
    # 1D/3D field differentiated along axis 0/1 slip through silently.
    if direction not in (0, 1) or fa.field.spatial_dimensions != 2:
        raise NotImplementedError("Forth order finite difference discretization is only implemented for 2D")

    stencils = forth_order_2d_derivation()
    return stencils[direction].apply(fa) / dx

def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
    """Replace every ``Diff`` in ``expr`` by a finite-difference approximation.

    Args:
        expr: expression (or anything accepted by ``generic_visit``) that may
            contain ``Diff`` objects.
        dx: grid spacing handed to the stencil.
        stencil: callable ``stencil(indices, dx, field_access)``, or one of the
            names ``'standard'`` / ``'isotropic'``.

    Returns:
        The expression with all derivatives discretized.

    Raises:
        ValueError: for an unknown stencil name, or when a derivative's
            argument is not a field access.
    """
    named_stencils = {'standard': fd_stencils_standard, 'isotropic': fd_stencils_isotropic}
    if isinstance(stencil, str):
        if stencil not in named_stencils:
            raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
        stencil = named_stencils[stencil]

    def _discretize_node(node):
        """Discretize a single node, rebuilding non-derivative nodes from their children."""
        if not isinstance(node, Diff):
            rebuilt = [discretize_spatial(child, dx, stencil) for child in node.args]
            return node.func(*rebuilt) if rebuilt else node
        field_access, *directions = diff_args(node)
        if isinstance(field_access, Field.Access):
            return stencil(directions, dx, field_access)
        raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")

    return generic_visit(expr, _discretize_node)

def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
    """Discretize derivatives in ``expr`` with a staggered-grid scheme.

    A derivative whose argument is directly a field access is handled by
    ``stencil``. A first derivative ``d/dx_k`` of a compound expression is
    approximated as the difference of that expression evaluated at the two
    half-way points along axis ``k`` (sign +1 and -1), divided by ``dx``.

    Args:
        expr: expression (or anything accepted by ``generic_visit``) containing ``Diff`` objects.
        dx: grid spacing.
        stencil: callable ``stencil(indices, dx, field_access)`` used for plain
            derivatives and for inner derivatives across the staggered axis.

    Returns:
        The discretized expression.

    Raises:
        ValueError: for unsupported derivative nestings (see messages).
    """
    def _at_half_point(node, axis, direction):
        """Approximate ``node`` half a cell away along ``axis`` (``direction`` is +1 or -1)."""
        if isinstance(node, Diff):
            inner, *inner_axes = diff_args(node)
            if len(inner_axes) != 1:
                raise ValueError("Function supports only up to second derivatives")
            if not isinstance(inner, Field.Access):
                raise ValueError("Argument of inner derivative has to be field access")
            if inner_axes[0] == axis:
                # Inner derivative along the staggered axis: one-sided difference
                # between the cell and its neighbour, oriented by ``direction``.
                return (inner.neighbor(axis, direction) - inner) / dx * direction
            # Inner derivative across the staggered axis: average the stencil
            # evaluated at the cell and at its neighbour.
            return (stencil(inner_axes, dx, inner.neighbor(axis, direction))
                    + stencil(inner_axes, dx, inner)) / 2
        # Field accesses are tested before the generic sp.Symbol case (ordering
        # kept from the original; presumably field accesses are symbols too).
        if isinstance(node, Field.Access):
            return (node.neighbor(axis, direction) + node) / 2
        if isinstance(node, sp.Symbol):
            counter_axis = LoopOverCoordinate.is_loop_counter_symbol(node)
            if counter_axis != axis:
                return node
            # NOTE(review): ``direction / 2`` is Python true division, so this adds a
            # float (0.5 / -0.5) rather than an exact rational — confirm intended.
            return node + direction / 2
        children = [_at_half_point(child, axis, direction) for child in node.args]
        return node.func(*children) if children else node

    def _discretize(node):
        """Top-level visitor: discretize derivatives, recurse into everything else."""
        if not isinstance(node, Diff):
            children = [_discretize(child) for child in node.args]
            return node.func(*children) if children else node
        operand, *axes = diff_args(node)
        if isinstance(operand, Field.Access):
            return stencil(axes, dx, operand)
        if len(axes) != 1:
            raise ValueError("This term is not support by the staggered discretization strategy")
        axis = axes[0]
        return (_at_half_point(operand, axis, 1) - _at_half_point(operand, axis, -1)) / dx

    return generic_visit(expr, _discretize)

# -------------------------------------- special stencils --------------------------------------------------------------

@memorycache(maxsize=1)
def forth_order_2d_derivation() -> Tuple[FiniteDifferenceStencilDerivation.Result, ...]:
    """Derive fourth-order isotropic first-derivative stencils for 2D.

    Returns:
        Tuple ``(x_stencil, y_stencil)``, each derived over the full 5x5
        neighbourhood of offsets ``(i, j)`` with ``i, j`` in ``-2..2``.
        The function takes no arguments and is wrapped in
        ``memorycache(maxsize=1)``, so the derivation result is cached.
    """
    # Symmetry, isotropy and 4th-order conditions alone do not fully determine
    # the stencil, so the weight of the second neighbour along the derivative
    # direction is pinned to a somewhat arbitrary value.
    pinned_weight = sp.Rational(1, 10)
    reach = range(-2, 3)
    neighbourhood = [(i, j) for i in reach for j in reach]

    derived = []
    for axis in (0, 1):
        cross_axis = 1 - axis
        derivation = FiniteDifferenceStencilDerivation((axis,), neighbourhood)
        second_neighbour = (2, 0) if axis == 0 else (0, 2)
        derivation.set_weight(second_neighbour, pinned_weight)
        # Antisymmetric along the derivative direction, symmetric across it.
        derivation.assume_symmetric(axis, anti_symmetric=True)
        derivation.assume_symmetric(cross_axis)
        derived.append(derivation.get_stencil(isotropic=True))
    return tuple(derived)