1from typing import Tuple 

2 

3import sympy as sp 

4 

5from pystencils.astnodes import LoopOverCoordinate 

6from pystencils.cache import memorycache 

7from pystencils.fd import Diff 

8from pystencils.field import Field 

9from pystencils.transformations import generic_visit 

10 

11from .derivation import FiniteDifferenceStencilDerivation 

12from .derivative import diff_args 

13 

14 

def fd_stencils_standard(indices, dx, fa):
    """Discretize a derivative of ``fa`` with standard central differences.

    Args:
        indices: coordinate indices, one entry per differentiation; its length
            is the derivative order.
        dx: grid spacing.
        fa: field access providing ``neighbor(coordinate, offset)``.

    Returns:
        The finite-difference expression approximating the derivative.

    Raises:
        AssertionError: if any index is negative.
        NotImplementedError: if the derivative order is not 1 or 2.
    """
    derivative_order = len(indices)
    assert all(i >= 0 for i in indices), "Can only discretize objects with (integer) subscripts"

    if derivative_order == 1:
        (axis,) = indices
        forward = fa.neighbor(axis, 1)
        backward = fa.neighbor(axis, -1)
        return (forward - backward) / (2 * dx)

    if derivative_order == 2:
        first, second = indices
        if first == second:
            # Classic 3-point second derivative along a single axis.
            return (-2 * fa + fa.neighbor(first, -1) + fa.neighbor(first, +1)) / (dx ** 2)
        # Mixed derivative: the four diagonal corners, signed by the product of offsets.
        corners = ((1, 1), (-1, 1), (1, -1), (-1, -1))
        weighted_corners = (s1 * s2 * fa.neighbor(first, s1).neighbor(second, s2)
                            for s1, s2 in corners)
        return sum(weighted_corners) / (4 * dx ** 2)

    raise NotImplementedError("Supports only derivatives up to order 2")

29 

30 

def fd_stencils_isotropic(indices, dx, fa):
    """Discretize a derivative of ``fa`` using isotropic stencils where available.

    1D fields are delegated to ``fd_stencils_standard``. On 2D fields, first
    derivatives and pure (same-axis) second derivatives use 3x3 stencils that
    also weight the diagonal neighbors. Mixed second derivatives fall back to
    ``fd_stencils_standard``.

    Args:
        indices: coordinate indices, one entry per differentiation.
        dx: grid spacing.
        fa: field access; ``fa.field.spatial_dimensions`` selects the branch.

    Raises:
        NotImplementedError: for fields that are neither 1D nor 2D, or for 2D
            derivatives whose order is not 1 or 2.
    """
    spatial_dims = fa.field.spatial_dimensions
    if spatial_dims == 1:
        return fd_stencils_standard(indices, dx, fa)
    if spatial_dims != 2:
        raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")

    derivative_order = len(indices)

    if derivative_order == 1:
        axis = indices[0]
        assert 0 <= axis < 2
        transverse = 1 if axis == 0 else 0
        # Weights across the transverse axis for each neighbor column; they sum to 1/2,
        # so upper minus lower reproduces the central difference for linear data.
        weights = {-1: sp.Rational(1, 12) / dx,
                   0: sp.Rational(1, 3) / dx,
                   1: sp.Rational(1, 12) / dx}

        def column(side):
            # Weighted sum over the 3 cells of the column at offset ``side`` along ``axis``.
            return sum(fa.neighbor(axis, side).neighbor(transverse, off) * w
                       for off, w in weights.items())

        return column(+1) - column(-1)

    if derivative_order == 2:
        first, second = indices
        if first != second:
            return fd_stencils_standard(indices, dx, fa)
        transverse = 1 if first == 0 else 0
        corner_sum = sum(fa.neighbor(0, i).neighbor(1, j) for i in (-1, 1) for j in (-1, 1))
        diagonals = sp.Rational(1, 12) * corner_sum
        along = sp.Rational(5, 6) * sum(fa.neighbor(first, s) for s in (-1, 1))
        across = -sp.Rational(1, 6) * sum(fa.neighbor(transverse, s) for s in (-1, 1))
        center = -sp.Rational(5, 3) * fa
        return (diagonals + along + across + center) / (dx ** 2)

    raise NotImplementedError("Supports only derivatives up to order 2 for 1D and 2D setups")

60 

61 

def fd_stencils_forth_order_isotropic(indices, dx, fa):
    """Discretize a first derivative of ``fa`` with a 4th-order isotropic 2D stencil.

    The stencils come from ``forth_order_2d_derivation()`` (one per axis) and are
    applied to ``fa`` and scaled by ``1 / dx``.

    Args:
        indices: coordinate indices of the derivative; exactly one entry is supported.
        dx: grid spacing.
        fa: field access the stencil is applied to.

    Raises:
        NotImplementedError: if the derivative is not first order, or its coordinate
            is not 0 or 1.
    """
    if len(indices) != 1:
        raise NotImplementedError("Fourth order finite difference discretization is "
                                  "currently only supported for first derivatives")
    coordinate = indices[0]
    # NOTE(review): only the derivative coordinate is checked here, not the field's
    # dimensionality; an x/y derivative of a 3D field is not rejected.
    if coordinate not in (0, 1):
        raise NotImplementedError("Fourth order finite difference discretization is only implemented for 2D")

    stencils = forth_order_2d_derivation()
    return stencils[coordinate].apply(fa) / dx

73 

74 

def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
    """Replace every ``Diff`` of a field access in ``expr`` by a finite-difference stencil.

    Args:
        expr: expression to discretize; traversed via ``generic_visit``.
        dx: grid spacing passed to the stencil function.
        stencil: stencil function ``(indices, dx, field_access) -> expr``, or one of
            the names ``'standard'`` / ``'isotropic'``.

    Raises:
        ValueError: for an unknown stencil name, or when a ``Diff`` argument is not a
            field access.
    """
    if isinstance(stencil, str):
        named_stencils = {'standard': fd_stencils_standard,
                          'isotropic': fd_stencils_isotropic}
        if stencil not in named_stencils:
            raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
        stencil = named_stencils[stencil]

    def visitor(term):
        if not isinstance(term, Diff):
            # Rebuild the node from discretized children; leaves are returned unchanged.
            children = [discretize_spatial(sub, dx, stencil) for sub in term.args]
            return term.func(*children) if children else term
        arg, *indices = diff_args(term)
        if not isinstance(arg, Field.Access):
            raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
        return stencil(indices, dx, arg)

    return generic_visit(expr, visitor)

95 

96 

def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
    """Discretize derivatives, using a staggered scheme for derivatives of non-field terms.

    A ``Diff`` whose argument is a field access is discretized directly with
    ``stencil``. A first-order ``Diff`` of any other expression is computed as
    the difference of that expression evaluated half a cell forward and half a
    cell backward along the derivative coordinate, divided by ``dx``. Within
    such an expression:

    * inner first derivatives along the same coordinate become one-sided differences,
    * inner first derivatives along other coordinates are averaged over both cells,
    * field accesses are averaged over both cells,
    * loop counter symbols of that coordinate are shifted by exactly +-1/2.

    Args:
        expr: expression to discretize; traversed via ``generic_visit``.
        dx: grid spacing.
        stencil: stencil function ``(indices, dx, field_access) -> expr``, or one of
            the names ``'standard'`` / ``'isotropic'`` (as in ``discretize_spatial``).

    Raises:
        ValueError: for an unknown stencil name; for an outer derivative of a
            non-field term that is not first order; for inner derivatives that are
            not first order or whose argument is not a field access.
    """
    if isinstance(stencil, str):
        # Accept the same stencil names as ``discretize_spatial``.
        if stencil == 'standard':
            stencil = fd_stencils_standard
        elif stencil == 'isotropic':
            stencil = fd_stencils_isotropic
        else:
            raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")

    def staggered_visitor(e, coordinate, sign):
        """Evaluate ``e`` half a cell away from the center in direction ``sign``."""
        if isinstance(e, Diff):
            arg, *indices = diff_args(e)
            if len(indices) != 1:
                raise ValueError("Function supports only up to second derivatives")
            if not isinstance(arg, Field.Access):
                raise ValueError("Argument of inner derivative has to be field access")
            target = indices[0]
            if target == coordinate:
                assert sign in (-1, 1)
                # One-sided difference between the cell and its neighbor, multiplied by
                # ``sign`` so it always approximates a +coordinate derivative.
                return (arg.neighbor(coordinate, sign) - arg) / dx * sign
            else:
                # Derivative transverse to the staggered direction: average both cells.
                return (stencil(indices, dx, arg.neighbor(coordinate, sign))
                        + stencil(indices, dx, arg)) / 2
        elif isinstance(e, Field.Access):
            # Checked before sp.Symbol, presumably because Field.Access is a Symbol
            # subclass — confirm before reordering these branches.
            return (e.neighbor(coordinate, sign) + e) / 2
        elif isinstance(e, sp.Symbol):
            loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
            # Exact half-cell shift; ``sign / 2`` would inject the float 0.5.
            return e + sp.Rational(sign, 2) if loop_idx == coordinate else e
        else:
            new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
            return e.func(*new_args) if new_args else e

    def visitor(e):
        if isinstance(e, Diff):
            arg, *indices = diff_args(e)
            if isinstance(arg, Field.Access):
                return stencil(indices, dx, arg)
            else:
                if not len(indices) == 1:
                    raise ValueError("This term is not supported by the staggered discretization strategy")
                target = indices[0]
                return (staggered_visitor(arg, target, 1) - staggered_visitor(arg, target, -1)) / dx
        else:
            new_args = [visitor(a) for a in e.args]
            return e.func(*new_args) if new_args else e

    return generic_visit(expr, visitor)

136 

137 

138# -------------------------------------- special stencils -------------------------------------------------------------- 

@memorycache(maxsize=1)
def forth_order_2d_derivation() -> Tuple[FiniteDifferenceStencilDerivation.Result, ...]:
    """Derive the x- and y-first-derivative stencils on a 5x5 neighborhood.

    Each stencil is antisymmetric along its derivative axis and symmetric along
    the other axis, and is requested with ``isotropic=True``.

    Returns:
        ``(x_stencil, y_stencil)`` as produced by
        ``FiniteDifferenceStencilDerivation.get_stencil``.
    """
    # Symmetry, isotropy and 4th order conditions are not enough to fully specify the
    # stencil, so the weight of the second neighbor along the derivative axis is fixed
    # to a somewhat arbitrary value.
    fixed_weight = sp.Rational(1, 10)
    neighborhood = [(i, j) for i in range(-2, 3) for j in range(-2, 3)]

    def derive(axis):
        # Builds the stencil for d/d(axis); the other axis is 1 - axis in 2D.
        other_axis = 1 - axis
        second_neighbor = (2, 0) if axis == 0 else (0, 2)
        derivation = FiniteDifferenceStencilDerivation((axis,), neighborhood)
        derivation.set_weight(second_neighbor, fixed_weight)
        derivation.assume_symmetric(axis, anti_symmetric=True)
        derivation.assume_symmetric(other_axis)
        return derivation.get_stencil(isotropic=True)

    return tuple(derive(axis) for axis in (0, 1))