"""This submodule offers functions to work with stencils in expression and offset-list form."""

2from collections import defaultdict 

3from typing import Sequence 

4 

5import numpy as np 

6import sympy as sp 

7 

8 

def inverse_direction(direction):
    """Return the direction with every component negated, as a tuple.

    Example:
        >>> inverse_direction((1, -1, 0))
        (-1, 1, 0)
    """
    return tuple(-component for component in direction)

17 

18 

def inverse_direction_string(direction):
    """Return the direction string that points the opposite way of ``direction``.

    The string is converted to a numerical offset, negated, and converted back.
    """
    offset = direction_string_to_offset(direction)
    flipped = inverse_direction(offset)
    return offset_to_direction_string(flipped)

22 

23 

def is_valid(stencil, max_neighborhood=None):
    """Check that a nested sequence is a well-formed stencil.

    A stencil is valid if all entries have the same length as the first entry.
    When ``max_neighborhood`` is given, no component of any entry may have an
    absolute value larger than it.

    Examples:
        >>> is_valid([(1, 0), (1, 0, 0)])  # stencil entries have different length
        False
        >>> is_valid([(2, 0), (1, 0)])
        True
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
        False
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
        True
    """
    reference_length = len(stencil[0])
    check_range = max_neighborhood is not None
    for entry in stencil:
        if len(entry) != reference_length:
            return False
        if check_range and any(abs(component) > max_neighborhood for component in entry):
            return False
    return True

49 

50 

def is_symmetric(stencil):
    """Test that for every direction d in the stencil, -d is in the stencil too.

    Entries are normalised to tuples before comparison, so stencils given as
    lists (or other sequences) of components are handled the same way as
    stencils made of tuples. Components must be hashable.

    Args:
        stencil: iterable of directions, each a sequence of offset components

    Returns:
        True if the stencil contains the negation of each of its directions
        (an empty stencil is symmetric), False otherwise.

    Examples:
        >>> is_symmetric([(1, 0), (0, 1)])
        False
        >>> is_symmetric([(1, 0), (-1, 0)])
        True
        >>> is_symmetric([[1, 0], [-1, 0]])
        True
    """
    # A set gives constant-time membership tests and makes the comparison
    # independent of the container type used for the individual entries.
    directions = {tuple(d) for d in stencil}
    return all(tuple(-component for component in d) in directions for d in directions)

64 

65 

def have_same_entries(s1, s2):
    """Check if two stencils contain the same directions, ignoring their order.

    Two stencils are considered the same if they have equal length and the
    sets of their entries are equal. Entries are normalised to tuples, so
    stencils whose entries are lists are accepted as well.

    Args:
        s1: first stencil, a sequence of directions
        s2: second stencil, a sequence of directions

    Returns:
        True if both stencils have the same length and the same entries.

    Examples:
        >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        >>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
        >>> have_same_entries(stencil1, stencil2)
        True
        >>> have_same_entries(stencil1, stencil3)
        False
    """
    if len(s1) != len(s2):
        return False
    # Compare in both directions: a one-sided difference would miss entries
    # that occur only in s2 when s1 contains duplicates.
    return {tuple(d) for d in s1} == {tuple(d) for d in s2}

81 

82 

83# -------------------------------------Expression - Coefficient Form Conversion ---------------------------------------- 

84 

85 

def coefficient_dict(expr):
    """Extract the coefficients in front of field accesses in an expression.

    The expression may only access a single field at a single index.

    Returns:
        center, coefficient dict, nonlinear part
        where center is the accessed field at the accessed index, the coefficient dict maps
        offsets to coefficients (missing offsets yield 0), and the nonlinear part is the
        expanded expression minus the sum of all ``coefficient * field access`` terms.

    Raises:
        ValueError: if the expression does not access exactly one field, or accesses
                    it at more than one index

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f(3) : double[2D]")
        >>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
        >>> assert nonlinear_part == 123 and field == f(1)
        >>> sorted(coeffs.items())
        [((-1, 0), 3), ((0, 1), 2)]
    """
    from pystencils.field import Field

    expanded = expr.expand()
    accesses = expanded.atoms(Field.Access)
    accessed_fields = {access.field for access in accesses}
    index_set = {access.index for access in accesses}

    if len(accessed_fields) != 1:
        raise ValueError("Could not extract stencil coefficients. "
                         "Expression has to be a linear function of exactly one field.")
    if len(index_set) != 1:
        raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices")

    field = accessed_fields.pop()
    idx = index_set.pop()

    # Offsets that never occur in the expression report a coefficient of 0
    coeffs = defaultdict(int)
    for access in accesses:
        coeffs[access.offsets] = expanded.coeff(access)

    linear_part = sum(coefficient * field[offset](*idx) for offset, coefficient in coeffs.items())
    return field(*idx), coeffs, expanded - linear_part

126 

127 

def coefficients(expr):
    """Return two lists: the accessed offsets and their coefficients, in matching order.

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part

    >>> import pystencils as ps
    >>> f = ps.fields("f(3) : double[2D]")
    >>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
    """
    _, coeff_map, remainder = coefficient_dict(expr)
    assert remainder == 0
    offsets = list(coeff_map)
    values = list(coeff_map.values())
    return offsets, values

142 

143 

def coefficient_list(expr, matrix_form=False):
    """Return the stencil coefficients of an expression as nested lists.

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part

    With ``matrix_form=True`` the y axis is reversed and the result is wrapped in
    ``sp.Matrix`` (a list of matrices, one per z layer, for three dimensions).

    Raises:
        ValueError: if the field does not have 1, 2 or 3 spatial dimensions

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f: double[2D]")
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
        [[0, 0, 0], [3, 0, 0], [0, 2, 0]]
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
        Matrix([
        [0, 2, 0],
        [3, 0, 0],
        [0, 0, 0]])
    """
    field_center, coeffs, nonlinear_part = coefficient_dict(expr)
    assert nonlinear_part == 0
    dim = field_center.field.spatial_dimensions

    # Largest absolute offset that occurs along each axis
    extent = defaultdict(int)
    for offset in coeffs:
        for axis, component in enumerate(offset):
            extent[axis] = max(extent[axis], abs(component))

    def axis_range(axis):
        return range(-extent[axis], extent[axis] + 1)

    if dim == 1:
        row = [coeffs[(i,)] for i in axis_range(0)]
        return sp.Matrix(row) if matrix_form else row

    ys = list(axis_range(1))
    if matrix_form:
        ys.reverse()

    if dim == 2:
        rows = [[coeffs[(i, j)] for i in axis_range(0)] for j in ys]
        return sp.Matrix(rows) if matrix_form else rows

    if dim == 3:
        layers = [[[coeffs[(i, j, k)] for i in axis_range(0)] for j in ys] for k in axis_range(2)]
        return [sp.Matrix(layer) for layer in layers] if matrix_form else layers

    raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")

190 

191 

192# ------------------------------------- Point-on-compass notation ------------------------------------------------------ 

193 

194 

def offset_component_to_direction_string(coordinate_id: int, value: int) -> str:
    """Translate a single numerical offset component to string notation.

    x offsets are labeled with east 'E' and 'W',
    y offsets with north 'N' and 'S' and
    z offsets with top 'T' and bottom 'B'
    If the absolute value of the offset is bigger than 1, this number is prefixed.
    A zero offset gives the empty string.

    Args:
        coordinate_id: integer 0, 1 or 2 standing for x,y and z
        value: integer offset

    Examples:
        >>> offset_component_to_direction_string(0, 1)
        'E'
        >>> offset_component_to_direction_string(1, 2)
        '2N'
    """
    assert 0 <= coordinate_id < 3, "Works only for at most 3D arrays"
    if value == 0:
        return ""

    # (negative, positive) letter for each coordinate: west/east, south/north, bottom/top
    negative_name, positive_name = (('W', 'E'), ('S', 'N'), ('B', 'T'))[coordinate_id]
    letter = negative_name if value < 0 else positive_name

    magnitude = abs(value)
    if magnitude > 1:
        return "%d%s" % (magnitude, letter)
    return letter

226 

227 

def offset_to_direction_string(offsets: Sequence[int]) -> str:
    """
    Translate a numerical offset to string notation.
    For details see :func:`offset_component_to_direction_string`

    The component strings are joined in reverse coordinate order; an all-zero offset
    gives 'C'. Offsets with more than three entries are returned as ``str(offsets)``.

    Args:
        offsets: 3-tuple with x,y,z offset

    Examples:
        >>> offset_to_direction_string([1, -1, 0])
        'SE'
        >>> offset_to_direction_string(([-3, 0, -2]))
        '2B3W'
    """
    if len(offsets) > 3:
        return str(offsets)

    parts = ["", "", ""]
    for axis, component in enumerate(offsets):
        parts[axis] = offset_component_to_direction_string(axis, component)

    label = "".join(parts[::-1])
    return label if label != "" else "C"

250 

251 

def direction_string_to_offset(direction: str, dim: int = 3):
    """
    Reverse mapping of :func:`offset_to_direction_string`

    The string is read as a sequence of direction letters, each optionally preceded
    by a decimal factor; the scaled unit offsets of all letters are added up.

    Args:
        direction: string representation of offset
        dim: dimension of offset, i.e the length of the returned list

    Examples:
        >>> direction_string_to_offset('NW', dim=3)
        array([-1,  1,  0])
        >>> direction_string_to_offset('NW', dim=2)
        array([-1,  1])
        >>> direction_string_to_offset(offset_to_direction_string((3,-2,1)))
        array([ 3, -2,  1])
    """
    unit_steps = {
        'C': np.array([0, 0, 0]),

        'W': np.array([-1, 0, 0]),
        'E': np.array([1, 0, 0]),

        'S': np.array([0, -1, 0]),
        'N': np.array([0, 1, 0]),

        'B': np.array([0, 0, -1]),
        'T': np.array([0, 0, 1]),
    }
    total = np.array([0, 0, 0])

    pos = 0
    end = len(direction)
    while pos < end:
        # Skip over the optional numeric prefix of the next direction letter
        digits_start = pos
        while direction[pos].isdigit():
            pos += 1
        multiplier = int(direction[digits_start:pos]) if pos > digits_start else 1

        total += multiplier * unit_steps[direction[pos]]
        pos += 1
    return total[:dim]

294 

295 

296# -------------------------------------- Visualization ----------------------------------------------------------------- 

297 

298 

def plot(stencil, **kwargs):
    """Plot a stencil, choosing the plot function from the length of its first entry.

    Entries of length two are drawn with :func:`plot_2d`. Otherwise the optional keyword
    ``slice`` selects between :func:`plot_3d_slicing` (truthy) and :func:`plot_3d`
    (default); it is removed from the keyword arguments before they are forwarded.
    """
    if len(stencil[0]) == 2:
        plot_2d(stencil, **kwargs)
        return

    use_slices = kwargs.pop('slice', False)
    draw = plot_3d_slicing if use_slices else plot_3d
    draw(stencil, **kwargs)

313 

314 

def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
    """
    Creates a matplotlib 2D plot of the stencil

    Every non-zero direction is drawn as an arrow from the origin; each direction
    gets a text box placed slightly beyond its end point.

    Args:
        stencil: sequence of directions
        axes: optional matplotlib axes
        figure: optional matplotlib figure
        data: data to annotate the directions with, if none given, the indices are used
        textsize: size of annotation text
        **kwargs: accepted but not used by this function
    """
    from matplotlib.patches import BoxStyle
    import matplotlib.pyplot as plt

    def label_shift(component, magnitude=0.18):
        # Moves the label away from the arrow head, in the direction the arrow points
        if component < 0:
            return -magnitude
        if component > 0:
            return +magnitude
        return 0

    if axes is None:
        if figure is None:
            figure = plt.gcf()
        axes = figure.gca()

    text_box_style = BoxStyle("Round", pad=0.3)
    head_length = 0.1
    max_offsets = [max(abs(int(d[c])) for d in stencil) for c in (0, 1)]

    if data is None:
        data = list(range(len(stencil)))

    for direction, annotation in zip(stencil, data):
        assert len(direction) == 2, "Works only for 2D stencils"
        direction = tuple(int(i) for i in direction)
        if direction[0] != 0 or direction[1] != 0:
            axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')

        if isinstance(annotation, sp.Basic):
            label = "$" + sp.latex(annotation) + "$"
        else:
            label = str(annotation)

        text_x, text_y = [direction[c] + label_shift(direction[c]) for c in (0, 1)]
        axes.text(x=text_x, y=text_y, s=label, verticalalignment='center',
                  zorder=30, horizontalalignment='center', size=textsize,
                  bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))

    axes.set_axis_off()
    axes.set_aspect('equal')

    # Keep a non-empty view along axes where all offsets are zero
    half_widths = [m if m > 0 else 0.1 for m in max_offsets]
    border = 0.1
    axes.set_xlim([-border - half_widths[0], border + half_widths[0]])
    axes.set_ylim([-border - half_widths[1], border + half_widths[1]])

370 

371 

def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
    """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.

    The directions are grouped by their component along ``slice_axis`` (-1, 0 or 1);
    each group is drawn with :func:`plot_2d` into its own subplot.

    Args:
        stencil: stencil as sequence of directions
        slice_axis: 0, 1, or 2 indicating the axis to slice through
        figure: optional matplotlib figure
        data: optional data to print as text besides the arrows
        **kwargs: forwarded to :func:`plot_2d`
    """
    import matplotlib.pyplot as plt

    for d in stencil:
        for element in d:
            assert element in (-1, 0, 1), "This function can only first neighborhood stencils"

    if figure is None:
        figure = plt.gcf()

    axes = [figure.add_subplot(1, 3, column) for column in (1, 2, 3)]
    directions_per_cut = [[], [], []]
    data_per_cut = [[], [], []]
    axes_names = ['x', 'y', 'z']

    for position, d in enumerate(stencil):
        cut = d[slice_axis] + 1  # maps component -1, 0, 1 to list index 0, 1, 2
        in_plane = tuple([element for j, element in enumerate(d) if j != slice_axis])
        directions_per_cut[cut].append(in_plane)
        data_per_cut[cut].append(position if data is None else data[position])

    for cut_axes, cut_directions, cut_data in zip(axes, directions_per_cut, data_per_cut):
        plot_2d(cut_directions, axes=cut_axes, data=cut_data, **kwargs)
    for cut, cut_axes in enumerate(axes):
        cut_axes.set_title("Cut at %s=%d" % (axes_names[slice_axis], cut - 1), y=1.08)

405 

406 

def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
    """
    Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`plot_2d`
    If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``

    Args:
        stencil: sequence of 3D directions
        figure: optional matplotlib figure; a new one is created if neither figure nor axes is given
        axes: optional matplotlib 3D axes to draw into
        data: optional sequence of annotations, one per direction; entries that are
              None or an empty string are not drawn
        textsize: size of annotation text
    """
    from itertools import combinations, product

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.patches import BoxStyle, FancyArrowPatch
    from mpl_toolkits.mplot3d import proj3d

    class Arrow3D(FancyArrowPatch):
        """2D arrow patch whose end points are given in 3D data coordinates."""

        def __init__(self, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def _project(self):
            # The projection matrix is read from the axes; it is no longer
            # attached to the renderer in current Matplotlib versions.
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            return zs

        def do_3d_projection(self, renderer=None):
            # Called by newer Matplotlib versions for artists in 3D axes;
            # the return value is used as drawing depth.
            return np.min(self._project())

        def draw(self, renderer):
            self._project()
            FancyArrowPatch.draw(self, renderer)

    if axes is None:
        if figure is None:
            figure = plt.figure()
        # Figure.gca() no longer accepts keyword arguments in current Matplotlib
        axes = figure.add_subplot(1, 1, 1, projection='3d')
        try:
            axes.set_aspect("equal")
        except NotImplementedError:
            pass

    if data is None:
        data = [None] * len(stencil)

    text_offset = 1.25
    text_box_style = BoxStyle("Round", pad=0.3)

    # Draw cell (cube)
    r = [-1, 1]
    for s, e in combinations(np.array(list(product(r, r, r))), 2):
        # Corners that differ in exactly one coordinate are joined by an edge
        if np.sum(np.abs(s - e)) == r[1] - r[0]:
            axes.plot3D(*zip(s, e), color="k", alpha=0.5)

    for d, annotation in zip(stencil, data):
        assert len(d) == 3, "Works only for 3D stencils"
        d = tuple(int(i) for i in d)
        if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
            if d[0] == 0:
                color = '#348abd'
            elif d[1] == 0:
                color = '#fac364'
            elif sum(abs(component) for component in d) == 2:
                color = '#95bd50'
            else:
                color = '#808080'

            a = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
            axes.add_artist(a)

        # Test against None instead of truthiness, so that a label of 0
        # (e.g. the first index of ``range(len(stencil))``) is drawn as well.
        if annotation is None or (isinstance(annotation, str) and annotation == ""):
            continue

        if isinstance(annotation, sp.Basic):
            annotation = "$" + sp.latex(annotation) + "$"
        else:
            annotation = str(annotation)

        axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
                  s=annotation, verticalalignment='center', zorder=30,
                  size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))

    axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_zlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_axis_off()

480 axes.set_axis_off() 

481 

482 

def plot_expression(expr, **kwargs):
    """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.

    One-dimensional stencils are not drawn; their coefficients are returned in matrix form.
    Keyword arguments are forwarded to the 2D or 3D plot function.
    """
    stencil, coeffs = coefficients(expr)
    dim = len(stencil[0])
    assert 0 < dim <= 3

    if dim == 1:
        return coefficient_list(expr, matrix_form=True)

    plotters = {2: plot_2d, 3: plot_3d_slicing}
    if dim in plotters:
        return plotters[dim](stencil, data=coeffs, **kwargs)
    return None