1import itertools 

2from collections import defaultdict 

3 

4import numpy as np 

5import sympy as sp 

6 

7from pystencils.field import Field 

8from pystencils.stencil import direction_string_to_offset 

9from pystencils.sympyextensions import multidimensional_sum, prod 

10from pystencils.utils import LinearEquationSystem, fully_contains 

11 

12 

class FiniteDifferenceStencilDerivation:
    """Derives finite difference stencils.

    Can derive standard finite difference stencils, as well as isotropic versions
    (see Isotropic Finite Differences by A. Kumar)

    Args:
        derivative_coordinates: tuple indicating which derivative should be approximated,
                                (1, ) stands for first derivative in second direction (y),
                                (0, 1) would be a mixed second derivative in x and y
                                (0, 0, 0) would be a third derivative in x direction
        stencil: list of offset tuples, defining the stencil
        dx: spacing between grid points, one for all directions, i.e. dx=dy=dz

    Examples:
        Central differences
        >>> fd_1d = FiniteDifferenceStencilDerivation((0,), stencil=[(-1,), (0,), (1,)])
        >>> result = fd_1d.get_stencil()
        >>> result
        Finite difference stencil of accuracy 2, isotropic error: False
        >>> result.weights
        [-1/2, 0, 1/2]

        Forward differences
        >>> fd_1d = FiniteDifferenceStencilDerivation((0,), stencil=[(0,), (1,)])
        >>> result = fd_1d.get_stencil()
        >>> result
        Finite difference stencil of accuracy 1, isotropic error: False
        >>> result.weights
        [-1, 1]
    """

    def __init__(self, derivative_coordinates, stencil, dx=1):
        # The dimensionality is taken from the length of the first stencil offset.
        self.dim = len(stencil[0])
        self.field = Field.create_generic('f', spatial_dimensions=self.dim)
        # Stored sorted so that it can be compared against the sorted keys of the error term dict.
        self._derivative = tuple(sorted(derivative_coordinates))
        self._stencil = stencil
        self._dx = dx
        # One symbolic weight per stencil offset, keyed by the offset as a tuple.
        self.weights = {tuple(d): self.symbolic_weight(*d) for d in self._stencil}

    def assume_symmetric(self, dim, anti_symmetric=False):
        """Adds restriction that weight in opposite directions of a dimension are equal (symmetric) or
        the negative of each other (anti symmetric)

        For example: dim=1, assumes that w(1, 1) == w(1, -1), if anti_symmetric=False or
                                         w(1, 1) == -w(1, -1) if anti_symmetric=True

        Every offset with a negative component along `dim` gets the (possibly negated) weight of the
        offset mirrored along `dim`. A `KeyError` is raised if the mirrored offset is not in the stencil.
        """
        update = {}
        for direction in self.weights:
            if direction[dim] >= 0:
                continue
            inv_direction = tuple(-offset if i == dim else offset for i, offset in enumerate(direction))
            inv_weight = self.weights[inv_direction]
            update[direction] = -inv_weight if anti_symmetric else inv_weight
        # Applied after the loop so the dict is not modified while it is being iterated.
        self.weights.update(update)

    def set_weight(self, offset, value):
        """Fixes the weight of the stencil entry `offset` (a tuple that is part of the stencil) to `value`."""
        assert offset in self.weights
        self.weights[offset] = value

    def get_stencil(self, isotropic=False) -> 'FiniteDifferenceStencilDerivation.Result':
        """Solves for the stencil weights and returns them together with the reached accuracy.

        Error term equations of increasing order are added to a linear equation system for as long as
        the system still reports a solution ('single' or 'multiple'). The first order for which the
        system reports 'none' is not added and ends the search.

        Args:
            isotropic: if True, the isotropy equations of the first unsatisfiable order are added as well,
                       provided the system stays solvable with them

        Returns:
            a `Result` with the stencil, the solved weights (in stencil order), the accuracy
            (reached order minus the number of derivative coordinates) and the isotropy flag

        Raises:
            AssertionError: if the equation system reports an unknown solution structure
        """
        # Keys of `self.weights` are tuples, so stencil entries are converted before the lookup.
        # This also supports stencils whose offsets are given as lists.
        weights = [self.weights[tuple(d)] for d in self._stencil]
        system = LinearEquationSystem(sp.Matrix(weights).atoms(sp.Symbol))

        order = 0
        while True:
            new_system = system.copy()
            new_system.add_equations(self.error_term_equations(order))
            sol_structure = new_system.solution_structure()
            if sol_structure == 'none':
                break
            if sol_structure not in ('single', 'multiple'):
                # Raised explicitly instead of `assert False`, which would be stripped under -O
                # and leave this loop running with an unknown state.
                raise AssertionError(f"unexpected solution structure: {sol_structure!r}")
            system = new_system
            order += 1

        accuracy = order - len(self._derivative)
        error_is_isotropic = False
        if isotropic:
            new_system = system.copy()
            new_system.add_equations(self.isotropy_equations(order))
            error_is_isotropic = new_system.solution_structure() != 'none'
            if error_is_isotropic:
                system = new_system

        solve_res = system.solution()
        # sympify: weights fixed through `set_weight` may be plain Python numbers without `.subs`
        weight_list = [sp.sympify(self.weights[tuple(d)]).subs(solve_res) for d in self._stencil]
        return self.Result(self._stencil, weight_list, accuracy, error_is_isotropic)

    @staticmethod
    def symbolic_weight(*args):
        """Returns the symbol used for the weight at the offset `args`, e.g. `w_(1,-1)` for (1, -1)."""
        str_args = [str(e) for e in args]
        return sp.Symbol(f"w_({','.join(str_args)})")

    def error_term_dict(self, order):
        """Collects the error term coefficients of the given order.

        Returns:
            a defaultdict (missing keys read as 0) that maps sorted tuples of `order` coordinate indices
            to the sum over all stencil entries of ``weight / order! * product of scaled offset components``.
            If the sorted derivative tuple is one of the keys, 1 is subtracted from its entry.
        """
        error_terms = defaultdict(lambda: 0)
        fac = sp.factorial(order)  # loop invariant
        for direction in self._stencil:
            weight = self.weights[tuple(direction)]
            x = tuple(self._dx * d_i for d_i in direction)
            for offset in multidimensional_sum(order, dim=self.field.spatial_dimensions):
                error_terms[tuple(sorted(offset))] += weight / fac * prod(x[off] for off in offset)
        if self._derivative in error_terms:
            error_terms[self._derivative] -= 1
        return error_terms

    def error_term_equations(self, order):
        """Returns the error term expressions of the given order as a list (each is required to vanish)."""
        return list(self.error_term_dict(order).values())

    def isotropy_equations(self, order):
        """Returns the expressions that are added to the equation system to request an isotropic error.

        Error terms whose index tuple does not fully contain the derivative (as decided by
        `fully_contains`) are returned unchanged. For the others the derivative indices are removed and
        the remaining indices are cyclically shifted modulo `self.dim`. If all shifts collapse to a
        single sorted index tuple the error term itself is returned, otherwise the non-zero differences
        between the error terms of consecutive shifted variants are returned.
        """
        def cycle_int_sequence(sequence, modulus):
            """Increments all entries of `sequence` modulo `modulus` until a tuple repeats and
            returns the distinct sorted tuples that were visited."""
            result = []
            arr = np.array(sequence, dtype=int)
            while tuple(arr) not in result:
                result.append(tuple(arr))
                arr = (arr + 1) % modulus
            return tuple(set(tuple(sorted(t)) for t in result))

        error_dict = self.error_term_dict(order)
        eqs = []
        # Iterate over a snapshot of the keys: lookups in the defaultdict below may insert new keys.
        for derivative_tuple in list(error_dict.keys()):
            if not fully_contains(self._derivative, derivative_tuple):
                eqs.append(error_dict[derivative_tuple])
                continue

            remaining = list(derivative_tuple)
            for e in self._derivative:
                del remaining[remaining.index(e)]
            permutations = cycle_int_sequence(remaining, self.dim)
            if len(permutations) == 1:
                eqs.append(error_dict[derivative_tuple])
                continue

            for i in range(1, len(permutations)):
                new_eq = (error_dict[tuple(sorted(permutations[i] + self._derivative))]
                          - error_dict[tuple(sorted(permutations[i - 1] + self._derivative))])
                if new_eq:
                    eqs.append(new_eq)
        return eqs

    class Result:
        """Derived stencil: offsets, their weights, the accuracy and whether the error is isotropic."""

        def __init__(self, stencil, weights, accuracy, is_isotropic):
            self.stencil = stencil  # sequence of offsets
            self.weights = weights  # one weight per offset, same order as `stencil`
            self.accuracy = accuracy
            self.is_isotropic = is_isotropic

        def visualize(self):
            """Plots the stencil with its weights using `pystencils.stencil.plot`."""
            from pystencils.stencil import plot
            plot(self.stencil, data=self.weights)

        def apply(self, field_access: 'Field.Access'):
            """Returns the weighted sum of `field_access` shifted by every stencil offset."""
            f = field_access
            return sum(f.get_shifted(*offset) * weight for offset, weight in zip(self.stencil, self.weights))

        def __array__(self, dtype=None, copy=None):
            """NumPy array protocol: returns the array representation of `as_array` as a new ndarray.

            `dtype` and `copy` are the optional keywords NumPy passes to `__array__`. The result is always
            newly built, so a request with ``copy=False`` cannot be fulfilled and raises `ValueError`.
            """
            if copy is False:
                raise ValueError("a copy is always required to convert the stencil weights to an ndarray")
            return np.array(self.as_array().tolist(), dtype=dtype)

        @staticmethod
        def _array_index(direction, max_offset):
            """Maps an offset to its index in the array representation: the first index decreases with the
            second offset component (y), the second index increases with the first component (x) and
            any further component is shifted by `max_offset`."""
            index = (max_offset - direction[1], max_offset + direction[0])
            return index + tuple(max_offset + d for d in direction[2:])

        def as_array(self):
            """Returns the weights as a sympy N-dim array with ``2 * max_offset + 1`` entries per dimension.

            Entries that do not belong to a stencil offset are zero. Only available for 2D and 3D stencils.
            """
            dim = len(self.stencil[0])
            assert (dim == 2 or dim == 3), "Only 2D or 3D matrix representations are available"
            max_offset = max(max(abs(e) for e in direction) for direction in self.stencil)
            shape = (2 * max_offset + 1,) * dim

            number_of_elements = int(np.prod(shape))
            result = sp.MutableDenseNDimArray([0] * number_of_elements, shape)

            for direction, weight in zip(self.stencil, self.weights):
                result[self._array_index(direction, max_offset)] = weight

            return result

        def rotate_weights_and_apply(self, field_access: 'Field.Access', axes):
            """derive gradient weights of other direction with already calculated weights of one direction
            via rotation and apply them to a field.

            The array representation is rotated once by 90 degrees in the plane given by `axes`
            (forwarded to `numpy.rot90`), the rotated weights are read back at the stencil offsets and
            the weighted sum of the shifted `field_access` is returned.
            """
            dim = len(self.stencil[0])
            assert (dim == 2 or dim == 3), "This function is only for 2D or 3D stencils available"
            rotated_weights = np.rot90(self.__array__(), 1, axes)

            max_offset = max(max(abs(e) for e in direction) for direction in self.stencil)
            result = [rotated_weights[self._array_index(direction, max_offset)] for direction in self.stencil]

            f = field_access
            return sum(f.get_shifted(*offset) * weight for offset, weight in zip(self.stencil, result))

        def __repr__(self):
            return "Finite difference stencil of accuracy {}, isotropic error: {}".format(self.accuracy,
                                                                                          self.is_isotropic)

222 

223 

class FiniteDifferenceStaggeredStencilDerivation:
    """Derives a finite difference stencil for application at a staggered position

    Args:
        neighbor: the neighbor direction string or vector at whose staggered position to calculate the derivative
        dim: how many dimensions (2 or 3)
        derivative: a tuple of directions over which to perform derivatives
    """

    def __init__(self, neighbor, dim, derivative=tuple()):
        if isinstance(neighbor, str):
            neighbor = direction_string_to_offset(neighbor)
        self._validate_arguments(neighbor, dim, derivative)
        neighbor = sp.Matrix(neighbor[:dim])
        # The staggered position lies halfway between the cell and its neighbor.
        pos = neighbor / 2

        def unitvec(i):
            """return the `i`-th unit vector in `dim` dimensions"""
            a = np.zeros(dim, dtype=int)
            a[i] = 1
            return a

        def flipped(a, i):
            """return `a` with its `i`-th element's sign flipped"""
            a = a.copy()
            a[i] *= -1
            return a

        # determine the points to use, coordinates are relative to position
        l1_norm = np.linalg.norm(neighbor, 1)
        if l1_norm == 1:
            main_points = [neighbor / 2, neighbor / -2]
        elif l1_norm == 2:
            nonzero_indices = [i for i, v in enumerate(neighbor) if v != 0 and i < dim]
            main_points = [neighbor / 2, neighbor / -2, flipped(neighbor / 2, nonzero_indices[0]),
                           flipped(neighbor / -2, nonzero_indices[0])]
        else:
            main_points = [neighbor.multiply_elementwise(sp.Matrix(c) / 2)
                           for c in itertools.product([-1, 1], repeat=3)]
        points = list(main_points)
        # Along every axis in which the neighbor has no component, the main points are
        # additionally included shifted by one cell in both directions.
        zero_indices = [i for i, v in enumerate(neighbor) if v == 0 and i < dim]
        for i in zero_indices:
            points += [point + sp.Matrix(unitvec(i)) for point in main_points]
            points += [point - sp.Matrix(unitvec(i)) for point in main_points]
        self._stencil = tuple(tuple(p) for p in points)
        # The same points, shifted by the staggered position (what `points` returns).
        shifted_points = tuple(tuple(p + pos) for p in points)

        # determine the stencil weights
        if len(derivative) == 0:
            weights = None
        else:
            derivation = FiniteDifferenceStencilDerivation(derivative, self._stencil).get_stencil()
            if not derivation.accuracy:
                raise Exception('the requested derivative cannot be performed with the available neighbors')
            weights = derivation.weights

            # if the weights are underdefined, we can choose the free symbols to find the sparsest stencil
            # (sorted by name so that the order of the candidates does not depend on set/hash ordering)
            free_weights = sorted(set(itertools.chain.from_iterable(w.free_symbols for w in weights)), key=str)
            if free_weights:
                zero_counts = defaultdict(list)
                trial_values = [-1, -sp.Rational(1, 2), 0, 1, sp.Rational(1, 2)]
                for values in itertools.product(trial_values, repeat=len(free_weights)):
                    subs = dict(zip(free_weights, values))
                    candidate = [w.subs(subs) for w in derivation.weights]
                    if not all(a == 0 for a in candidate):
                        zero_count = sum(1 for w in candidate if w == 0)
                        zero_counts[zero_count].append(candidate)
                best = zero_counts[max(zero_counts)]
                if len(best) > 1:  # if there are multiple, pick the one that contains a nonzero center weight
                    center = shifted_points.index((0, 0, 0)[:dim])
                    with_center = [b for b in best if b[center] != 0]
                    # If no candidate has a nonzero center weight, keep all sparsest candidates
                    # instead of failing on an empty list. Every candidate is a substitution into the
                    # same parametrized solution, so their average computed below is one as well.
                    if with_center:
                        best = with_center
                if len(best) > 1:  # if there are still multiple, they are equivalent, so we average
                    weights = [sum(b[i] for b in best) / len(best) for i in range(len(derivation.weights))]
                else:
                    weights = best[0]
                assert weights

        self._points = shifted_points
        self._weights = weights

    @staticmethod
    def _validate_arguments(neighbor, dim, derivative):
        """Checks that `neighbor` has no nonzero component beyond `dim` and that all derivative
        directions are smaller than `dim`. Fails with an AssertionError otherwise."""
        # Element-wise check: works for tuples, lists and arrays, including vectors of length `dim`.
        assert all(v == 0 for v in neighbor[dim:]), "neighbor has a nonzero component outside of `dim`"
        # Length check instead of an identity comparison with the empty tuple, so any empty sequence passes.
        assert len(derivative) == 0 or max(derivative) < dim, "derivative direction outside of `dim`"

    @property
    def points(self):
        """return the points of the stencil"""
        return self._points

    @property
    def stencil(self):
        """return the points of the stencil relative to the staggered position specified by neighbor"""
        return self._stencil

    @property
    def weights(self):
        """return the weights of the stencil"""
        assert self._weights is not None
        return self._weights

    def _plot_data(self):
        """Returns `(points, weights)` as arrays for plotting. Without weights all points are returned
        and the weights are None, otherwise only the entries with a nonzero weight are kept."""
        if self._weights is None:
            return np.array(self._points, dtype=int), None
        nonzero = [(p, w) for p, w in zip(self._points, self._weights) if w != 0]
        pts = np.array([p for p, _ in nonzero], dtype=int)
        ws = np.array([w for _, w in nonzero], dtype=float)
        return pts, ws

    def visualize(self):
        """Plots the stencil points (and their nonzero weights, if any) with `pystencils.stencil.plot`."""
        pts, ws = self._plot_data()
        from pystencils.stencil import plot
        plot(pts, data=ws)

    def apply(self, access: 'Field.Access'):
        """Returns the weighted sum of `access` shifted to every stencil point (requires weights)."""
        return sum([access.get_shifted(*point) * weight for point, weight in zip(self.points, self.weights)])