"""This submodule offers functions to work with stencils in expression and offset-list form."""

from collections import Counter, defaultdict
from typing import Sequence

import numpy as np
import sympy as sp

def inverse_direction(direction):
    """Return the component-wise negation of a direction as a tuple.

    Args:
        direction: sequence of numeric offset components

    Returns:
        tuple with every component of ``direction`` negated

    Example:
        >>> inverse_direction((1, -1, 0))
        (-1, 1, 0)
    """
    return tuple(-component for component in direction)

def inverse_direction_string(direction):
    """Return the compass string pointing in the opposite direction.

    The string is parsed into a 3D offset, negated and formatted back into compass notation.

    Example:
        >>> inverse_direction_string('NE')
        'SW'
    """
    offset = direction_string_to_offset(direction)
    opposite = inverse_direction(offset)
    return offset_to_direction_string(opposite)

def is_valid(stencil, max_neighborhood=None):
    """
    Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
    If max_neighborhood is specified, it is also verified that the stencil does not contain any direction components
    with absolute value greater than the maximal neighborhood.

    An empty stencil is considered valid: it has no entries that could disagree.

    Args:
        stencil: sequence of direction sequences
        max_neighborhood: optional upper bound for the absolute value of every direction component

    Returns:
        True if all directions have equal length (and respect ``max_neighborhood``), False otherwise

    Examples:
        >>> is_valid([(1, 0), (1, 0, 0)])  # stencil entries have different length
        False
        >>> is_valid([(2, 0), (1, 0)])
        True
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
        False
        >>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
        True
        >>> is_valid([])
        True
    """
    # len() instead of truthiness so numpy-array stencils are handled as well
    if len(stencil) == 0:
        return True
    expected_dim = len(stencil[0])
    for direction in stencil:
        if len(direction) != expected_dim:
            return False
        if max_neighborhood is not None and any(abs(c) > max_neighborhood for c in direction):
            return False
    return True

def is_symmetric(stencil):
    """Tests for every direction d, that -d is also in the stencil

    Directions may be given as tuples, lists or array rows; they are compared component-wise.

    Args:
        stencil: sequence of directions

    Returns:
        True if the stencil is closed under negation of its directions

    Examples:
        >>> is_symmetric([(1, 0), (0, 1)])
        False
        >>> is_symmetric([(1, 0), (-1, 0)])
        True
        >>> is_symmetric([[1, 0], [-1, 0]])
        True
    """
    # Normalize to tuples: a negated tuple never compares equal to a list entry, and a set
    # turns the per-direction membership test from O(n) into O(1).
    directions = {tuple(d) for d in stencil}
    return all(tuple(-c for c in d) in directions for d in directions)

def have_same_entries(s1, s2):
    """Checks if two stencils are the same, ignoring the order of their directions

    Directions are compared component-wise (tuples and lists are treated alike) and with
    multiplicity, so the result is the same regardless of argument order.

    Args:
        s1: first stencil (sequence of directions)
        s2: second stencil (sequence of directions)

    Returns:
        True if both stencils contain exactly the same directions, equally often

    Examples:
        >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
        >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
        >>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
        >>> have_same_entries(stencil1, stencil2)
        True
        >>> have_same_entries(stencil1, stencil3)
        False
    """
    # Multiset comparison: a plain set difference with a length check is not symmetric when
    # entries repeat, e.g. [(1, 0), (1, 0)] vs [(1, 0), (0, 1)].
    return Counter(tuple(d) for d in s1) == Counter(tuple(d) for d in s2)

# ------------------------------------- Expression - Coefficient Form Conversion ---------------------------------------

def coefficient_dict(expr):
    """Extracts coefficients in front of field accesses in a expression.

    The expression is expanded first. It may only access a single field at a single index.

    Returns:
        center, coefficient dict, nonlinear part
        where center is the single field that is accessed in expression accessed at center
        and coefficient dict maps offsets to coefficients (offsets that are not accessed map to 0).
        The nonlinear part is everything that is not in the form of coefficient times field access.

    Raises:
        ValueError: if not exactly one field is accessed, or if it is accessed at several indices

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f(3) : double[2D]")
        >>> field, coeffs, nonlinear_part = coefficient_dict(2 * f[0, 1](1) + 3 * f[-1, 0](1) + 123)
        >>> assert nonlinear_part == 123 and field == f(1)
        >>> sorted(coeffs.items())
        [((-1, 0), 3), ((0, 1), 2)]
    """
    from pystencils.field import Field

    expanded = expr.expand()
    accesses = expanded.atoms(Field.Access)
    distinct_fields = {access.field for access in accesses}
    distinct_indices = {access.index for access in accesses}

    if len(distinct_fields) != 1:
        raise ValueError("Could not extract stencil coefficients. "
                         "Expression has to be a linear function of exactly one field.")
    if len(distinct_indices) != 1:
        raise ValueError("Could not extract stencil coefficients. Field is accessed at multiple indices")

    (field,) = distinct_fields
    (index,) = distinct_indices

    offset_to_coeff = defaultdict(lambda: 0)
    for access in accesses:
        offset_to_coeff[access.offsets] = expanded.coeff(access)

    # whatever the coefficient * access terms do not explain is reported as nonlinear part
    linear_part = sum(coeff * field[offset](*index) for offset, coeff in offset_to_coeff.items())
    return field(*index), offset_to_coeff, expanded - linear_part

def coefficients(expr):
    """Returns two lists - one with accessed offsets and one with their coefficients.

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part.
    The i-th coefficient belongs to the i-th offset.

    >>> import pystencils as ps
    >>> f = ps.fields("f(3) : double[2D]")
    >>> coff = coefficients(2 * f[0, 1](1) + 3 * f[-1, 0](1))
    """
    _, offset_to_coeff, nonlinear_part = coefficient_dict(expr)
    assert nonlinear_part == 0
    offsets = list(offset_to_coeff.keys())
    values = [offset_to_coeff[offset] for offset in offsets]
    return offsets, values

def coefficient_list(expr, matrix_form=False):
    """Returns stencil coefficients in the form of nested lists

    Same restrictions as `coefficient_dict` apply. Expression must not have any nonlinear part.
    Along every axis the range is symmetric, from minus to plus the largest absolute accessed offset.
    The innermost list runs over x, the next over y, the outermost (3D only) over z.
    With ``matrix_form`` the y order is reversed so that north is at the top of the matrix.

    Raises:
        ValueError: for fields with more than 3 spatial dimensions

    Examples:
        >>> import pystencils as ps
        >>> f = ps.fields("f: double[2D]")
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0])
        [[0, 0, 0], [3, 0, 0], [0, 2, 0]]
        >>> coefficient_list(2 * f[0, 1] + 3 * f[-1, 0], matrix_form=True)
        Matrix([
        [0, 2, 0],
        [3, 0, 0],
        [0, 0, 0]])
    """
    field_center, offset_to_coeff, nonlinear_part = coefficient_dict(expr)
    assert nonlinear_part == 0
    dim = field_center.field.spatial_dimensions

    # largest absolute offset per axis; axes that are never offset stay at 0
    extent = defaultdict(lambda: 0)
    for offset in offset_to_coeff.keys():
        for axis, component in enumerate(offset):
            extent[axis] = max(extent[axis], abs(component))

    def symmetric_range(axis):
        return range(-extent[axis], extent[axis] + 1)

    if dim == 1:
        row = [offset_to_coeff[(x,)] for x in symmetric_range(0)]
        return sp.Matrix(row) if matrix_form else row

    ys = list(symmetric_range(1))
    if matrix_form:
        ys.reverse()

    if dim == 2:
        grid = [[offset_to_coeff[(x, y)] for x in symmetric_range(0)]
                for y in ys]
        return sp.Matrix(grid) if matrix_form else grid
    if dim == 3:
        cube = [[[offset_to_coeff[(x, y, z)] for x in symmetric_range(0)]
                 for y in ys]
                for z in symmetric_range(2)]
        return [sp.Matrix(layer) for layer in cube] if matrix_form else cube
    raise ValueError("Can only handle fields with 1,2 or 3 spatial dimensions")

# ------------------------------------- Point-on-compass notation ------------------------------------------------------

def offset_component_to_direction_string(coordinate_id: int, value: int) -> str:
    """Translates numerical offset to string notation.

    x offsets are labeled with east 'E' and west 'W',
    y offsets with north 'N' and south 'S' and
    z offsets with top 'T' and bottom 'B'.
    A zero offset yields the empty string.
    If the absolute value of the offset is bigger than 1, this number is prefixed.

    Args:
        coordinate_id: integer 0, 1 or 2 standing for x,y and z
        value: integer offset

    Examples:
        >>> offset_component_to_direction_string(0, 1)
        'E'
        >>> offset_component_to_direction_string(1, 2)
        '2N'
    """
    assert 0 <= coordinate_id < 3, "Works only for at most 3D arrays"
    # (negative, positive) letter per axis: west/east, south/north, bottom/top
    negative_name, positive_name = (('W', 'E'), ('S', 'N'), ('B', 'T'))[coordinate_id]
    if value == 0:
        return ""
    letter = negative_name if value < 0 else positive_name
    magnitude = abs(value)
    return "%d%s" % (magnitude, letter) if magnitude > 1 else letter

def offset_to_direction_string(offsets: Sequence[int]) -> str:
    """
    Translates numerical offset to string notation.

    For details see :func:`offset_component_to_direction_string`.
    Components are concatenated in z, y, x order; an all-zero offset is named 'C' (center).
    Offsets with more than three components are returned as ``str(offsets)``.

    Args:
        offsets: 3-tuple with x,y,z offset

    Examples:
        >>> offset_to_direction_string([1, -1, 0])
        'SE'
        >>> offset_to_direction_string(([-3, 0, -2]))
        '2B3W'
    """
    if len(offsets) > 3:
        return str(offsets)
    parts = [offset_component_to_direction_string(axis, component)
             for axis, component in enumerate(offsets)]
    name = "".join(parts[::-1])
    return name or "C"

def direction_string_to_offset(direction: str, dim: int = 3):
    """
    Reverse mapping of :func:`offset_to_direction_string`

    Each letter (C, W, E, S, N, B, T) contributes a unit vector, optionally scaled by a
    decimal number written directly in front of it; all contributions are summed.

    Args:
        direction: string representation of offset
        dim: dimension of offset, i.e the length of the returned list

    Returns:
        integer numpy array holding the first ``dim`` components of the 3D offset

    Raises:
        KeyError: for characters that are not valid direction letters
        IndexError: if the string ends with digits that are not followed by a letter

    Examples:
        >>> direction_string_to_offset('NW', dim=3)
        array([-1, 1, 0])
        >>> direction_string_to_offset('NW', dim=2)
        array([-1, 1])
        >>> direction_string_to_offset(offset_to_direction_string((3,-2,1)))
        array([ 3, -2, 1])
    """
    unit_vectors = {
        'C': np.array([0, 0, 0]),
        'W': np.array([-1, 0, 0]),
        'E': np.array([1, 0, 0]),
        'S': np.array([0, -1, 0]),
        'N': np.array([0, 1, 0]),
        'B': np.array([0, 0, -1]),
        'T': np.array([0, 0, 1]),
    }
    offset = np.array([0, 0, 0])

    position = 0
    while position < len(direction):
        letter_position = position
        while direction[letter_position].isdigit():
            letter_position += 1
        multiplier = int(direction[position:letter_position]) if letter_position > position else 1
        offset += multiplier * unit_vectors[direction[letter_position]]
        position = letter_position + 1
    return offset[:dim]

# -------------------------------------- Visualization -----------------------------------------------------------------

def plot(stencil, **kwargs):
    """Plot a stencil with the visualization matching its dimension.

    2D stencils go to :func:`plot_2d`. All other stencils are drawn with :func:`plot_3d`,
    or with :func:`plot_3d_slicing` when the keyword ``slice`` is truthy. ``slice`` is
    consumed here; all remaining keyword arguments are passed through.
    """
    if len(stencil[0]) == 2:
        plot_2d(stencil, **kwargs)
        return
    if kwargs.pop('slice', False):
        plot_3d_slicing(stencil, **kwargs)
    else:
        plot_3d(stencil, **kwargs)

def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs):
    """
    Creates a matplotlib 2D plot of the stencil

    Every direction except the center is drawn as a black arrow starting at the origin.
    Every direction (including the center) gets a text label taken from ``data``.

    Args:
        stencil: sequence of 2D directions; may be empty (e.g. an empty slice handed over
                 by :func:`plot_3d_slicing`)
        axes: optional matplotlib axes
        figure: optional matplotlib figure, only used when ``axes`` is None
        data: data to annotate the directions with, if none given, the indices are used.
              sympy objects are rendered as LaTeX math, everything else via ``str``
        textsize: size of annotation text
        kwargs: accepted for call compatibility, not used
    """
    from matplotlib.patches import BoxStyle
    import matplotlib.pyplot as plt

    if axes is None:
        if figure is None:
            figure = plt.gcf()
        axes = figure.gca()

    text_box_style = BoxStyle("Round", pad=0.3)
    head_length = 0.1
    # default=0 keeps an empty stencil from raising "max() arg is an empty sequence"
    max_offsets = [max((abs(int(d[c])) for d in stencil), default=0) for c in (0, 1)]

    if data is None:
        data = list(range(len(stencil)))

    def position_correction(d, magnitude=0.18):
        # move the label a little beyond the arrow tip, away from the origin
        if d < 0:
            return -magnitude
        elif d > 0:
            return +magnitude
        else:
            return 0

    for direction, annotation in zip(stencil, data):
        assert len(direction) == 2, "Works only for 2D stencils"
        direction = tuple(int(i) for i in direction)
        if not (direction[0] == 0 and direction[1] == 0):
            axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')

        if isinstance(annotation, sp.Basic):
            # plain '$' delimiters switch matplotlib to math text; "\$" would be a literal dollar
            annotation = "$" + sp.latex(annotation) + "$"
        else:
            annotation = str(annotation)

        text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
        axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
                  zorder=30, horizontalalignment='center', size=textsize,
                  bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))

    axes.set_axis_off()
    axes.set_aspect('equal')
    # avoid zero-width axis limits when the stencil has no extent along an axis
    max_offsets = [m if m > 0 else 0.1 for m in max_offsets]
    border = 0.1
    axes.set_xlim([-border - max_offsets[0], border + max_offsets[0]])
    axes.set_ylim([-border - max_offsets[1], border + max_offsets[1]])

def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
    """Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis.

    Directions are grouped by their component along ``slice_axis`` (-1, 0 or 1). Each group
    is drawn with :func:`plot_2d` into one of three side-by-side subplots, after removing
    that component from every direction.

    Args:
        stencil: stencil as sequence of directions
        slice_axis: 0, 1, or 2 indicating the axis to slice through
        figure: optional matplotlib figure, defaults to the current figure
        data: optional data to print as text besides the arrows; defaults to the direction indices
        kwargs: passed on to :func:`plot_2d`
    """
    import matplotlib.pyplot as plt

    for d in stencil:
        for element in d:
            # larger components would be filed into a wrong (or negative-index) slice below
            assert element == -1 or element == 0 or element == 1, \
                "This function can only handle first neighborhood stencils"

    if figure is None:
        figure = plt.gcf()

    axes = [figure.add_subplot(1, 3, i + 1) for i in range(3)]
    sliced_directions = [[], [], []]
    sliced_data = [[], [], []]
    axes_names = ['x', 'y', 'z']

    for i, d in enumerate(stencil):
        split_idx = d[slice_axis] + 1  # component -1 / 0 / 1 -> subplot 0 / 1 / 2
        reduced_dir = tuple(element for j, element in enumerate(d) if j != slice_axis)
        sliced_directions[split_idx].append(reduced_dir)
        sliced_data[split_idx].append(i if data is None else data[i])

    for i in range(3):
        plot_2d(sliced_directions[i], axes=axes[i], data=sliced_data[i], **kwargs)
    for i in [-1, 0, 1]:
        axes[i + 1].set_title("Cut at %s=%d" % (axes_names[slice_axis], i), y=1.08)

def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
    """
    Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`plot_2d`

    The cube [-1, 1]^3 is drawn as a wireframe and every non-center direction as an arrow
    from the origin. Arrow colors are chosen as follows:
    - zero x component: blue
    - otherwise, zero y component: orange
    - otherwise, |x|+|y|+|z| == 2: green
    - everything else: grey

    If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``
    Individual entries of ``data`` that are None are left unlabeled.

    Args:
        stencil: sequence of 3D directions
        figure: optional matplotlib figure, only used when ``axes`` is None
        axes: optional matplotlib 3D axes
        data: optional labels, one per direction
        textsize: size of annotation text
    """
    from itertools import combinations, product

    from matplotlib.patches import BoxStyle, FancyArrowPatch
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import proj3d

    class Arrow3D(FancyArrowPatch):
        """Arrow patch whose start and end points are given in 3D data coordinates."""

        def __init__(self, xs, ys, zs, *args, **kwargs):
            FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def do_3d_projection(self, renderer=None):
            # Project with the owning axes' matrix: ``renderer.M`` no longer exists in recent
            # Matplotlib, and Axes3D calls this hook on patches to depth-sort them.
            xs3d, ys3d, zs3d = self._verts3d
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            return np.min(zs)

        def draw(self, renderer):
            self.do_3d_projection(renderer)
            FancyArrowPatch.draw(self, renderer)

    if axes is None:
        if figure is None:
            figure = plt.figure()
        # Figure.gca() no longer accepts a ``projection`` argument (removed in Matplotlib 3.6)
        axes = figure.add_subplot(111, projection='3d')
    try:
        axes.set_aspect("equal")
    except NotImplementedError:
        # some Matplotlib versions do not support equal aspect on 3D axes
        pass

    if data is None:
        data = [None] * len(stencil)

    text_offset = 1.25
    text_box_style = BoxStyle("Round", pad=0.3)

    # Draw cell (cube): connect all corner pairs that differ in exactly one coordinate
    r = [-1, 1]
    for s, e in combinations(np.array(list(product(r, r, r))), 2):
        if np.sum(np.abs(s - e)) == r[1] - r[0]:
            axes.plot3D(*zip(s, e), color="k", alpha=0.5)

    for d, annotation in zip(stencil, data):
        assert len(d) == 3, "Works only for 3D stencils"
        d = tuple(int(i) for i in d)
        if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
            if d[0] == 0:
                color = '#348abd'
            elif d[1] == 0:
                color = '#fac364'
            elif sum(abs(component) for component in d) == 2:
                color = '#95bd50'
            else:
                color = '#808080'

            arrow = Arrow3D([0, d[0]], [0, d[1]], [0, d[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
            axes.add_artist(arrow)

        # "is not None" so that a label 0 (e.g. the index of the first direction) is still drawn
        if annotation is not None:
            if isinstance(annotation, sp.Basic):
                annotation = "$" + sp.latex(annotation) + "$"
            else:
                annotation = str(annotation)

            axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
                      s=annotation, verticalalignment='center', zorder=30,
                      size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))

    axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_zlim([-text_offset * 1.1, text_offset * 1.1])
    axes.set_axis_off()

def plot_expression(expr, **kwargs):
    """Displays coefficients of a linear update expression of a single field as matplotlib arrow drawing.

    1D expressions are not drawn; their coefficients are returned as a sympy column matrix
    instead. 2D expressions use :func:`plot_2d`, 3D expressions :func:`plot_3d_slicing`, with the
    coefficients as labels. Extra keyword arguments are passed to the plotting function.
    """
    stencil, coeffs = coefficients(expr)
    dim = len(stencil[0])
    assert 0 < dim <= 3
    if dim == 1:
        return coefficient_list(expr, matrix_form=True)
    plotters = {2: plot_2d, 3: plot_3d_slicing}
    if dim in plotters:
        return plotters[dim](stencil, data=coeffs, **kwargs)