"""
This module extends the pyplot module with functions to show scalar and vector fields in the usual
simulation coordinate system (y axis points up), instead of the "image coordinate system" (y axis points down)
that matplotlib normally uses.
"""

import warnings
from itertools import cycle

import numpy as np
from matplotlib.pyplot import *

10 

11 

def vector_field(array, step=2, **kwargs):
    """Draw a 2D vector field as a quiver (arrow) plot in simulation coordinates.

    Args:
        array: numpy array with 3 dimensions; the first two index the spatial x and y coordinates,
               the last one has size 2 and holds the two vector components
        step: only every ``step``-th cell along each spatial axis is drawn; increase it for high
              resolution arrays
        kwargs: keyword arguments forwarded to :func:`matplotlib.pyplot.quiver`

    Returns:
        quiver plot object
    """
    assert len(array.shape) == 3, "Wrong shape of array - did you forget to slice your 3D domain first?"
    assert array.shape[2] == 2, "Last array dimension is expected to store 2D vectors"
    # quiver takes (row=y, column=x) data, so bring the y axis to the front
    transposed = array.swapaxes(0, 1)
    subsampled = transposed[::step, ::step]
    result = quiver(subsampled[..., 0], subsampled[..., 1], **kwargs)
    axis('equal')
    return result

30 

31 

def vector_field_magnitude(array, **kwargs):
    """Plots the magnitude (Euclidean length) of a vector field as colormap.

    Args:
        array: numpy array with 3 dimensions; the first two are the spatial x, y coordinates, the last
               one has size 2 or 3 and stores the vector components. For masked arrays, a cell is
               masked when its first vector component is masked.
        kwargs: keyword arguments passed to :func:`matplotlib.pyplot.imshow`

    Returns:
        imshow object
    """
    import numpy as np

    assert len(array.shape) == 3, "Wrong shape of array - did you forget to slice your 3D domain first?"
    assert array.shape[2] in (2, 3), "Wrong size of the last coordinate. Has to be a 2D or 3D vector field."
    magnitude = np.linalg.norm(np.ma.getdata(array), axis=2, ord=2)
    if isinstance(array, np.ma.MaskedArray):
        # getmaskarray expands the scalar ``nomask`` to a full boolean array, so indexing never fails
        magnitude = np.ma.masked_array(magnitude, mask=np.ma.getmaskarray(array)[:, :, 0])
    return scalar_field(magnitude, **kwargs)

50 

51 

def scalar_field(array, **kwargs):
    """Show a 2D scalar field as a colormap with the y axis pointing upwards.

    Behaves like :func:`matplotlib.pyplot.imshow`, except that the first array index is drawn
    horizontally (x) and the second vertically (y, increasing upwards).

    Args:
        array: two dimensional numpy array, indexed as ``array[x, y]``
        kwargs: keyword arguments passed to :func:`matplotlib.pyplot.imshow`

    Returns:
        imshow object
    """
    import numpy
    image = imshow(numpy.swapaxes(array, 0, 1), origin='lower', **kwargs)
    axis('equal')
    return image

69 

70 

def scalar_field_surface(array, **kwargs):
    """Draw a 2D scalar field as a 3D surface on a new 3D subplot of the current figure.

    Args:
        array: the two dimensional numpy array to plot, indexed as ``array[x, y]``
        kwargs: keyword arguments passed to :func:`mpl_toolkits.mplot3d.Axes3D.plot_surface`;
                ``rstride``/``cstride`` default to 2, ``color`` to ``'b'`` and ``cmap`` to ``cm.coolwarm``

    Returns:
        the object returned by ``plot_surface``
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - imported for its side effect (presumably '3d' projection registration)
    from matplotlib import cm

    axes_3d = gcf().add_subplot(111, projection='3d')
    size_x, size_y = array.shape[0], array.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(size_x), np.arange(size_y), indexing='ij')
    for option, default in (('rstride', 2), ('cstride', 2), ('color', 'b'), ('cmap', cm.coolwarm)):
        kwargs.setdefault(option, default)
    return axes_3d.plot_surface(grid_x, grid_y, array, **kwargs)

89 

90 

def scalar_field_alpha_value(array, color, clip=False, **kwargs):
    """Plots an image with the same color everywhere, using the array values as transparency.

    The passed array determines the alpha channel: regions where the (normalized) array is 1 are
    fully opaque, regions where it is 0 are fully transparent.

    Args:
        array: 2D array with alpha values, indexed as ``array[x, y]``
        color: fill color, anything accepted by :func:`matplotlib.colors.to_rgba`
        clip: if True, all values larger than 1 are set to 1 and all values smaller than 0 are set to 0.
              If False, the array is always linearly rescaled to the [0, 1] interval (even when it already
              lies inside it); a constant array has no range to rescale and is clipped instead.
        **kwargs: arguments passed to imshow

    Returns:
        imshow object
    """
    import numpy
    import matplotlib.colors

    assert len(array.shape) == 2, "Wrong shape of array - did you forget to slice your 3D domain first?"
    array = numpy.swapaxes(array, 0, 1)

    if clip:
        normalized_field = numpy.clip(array, 0, 1)
    else:
        minimum, maximum = numpy.min(array), numpy.max(array)
        if maximum > minimum:
            normalized_field = (array - minimum) / (maximum - minimum)
        else:
            # constant field: rescaling would divide 0 by 0 and produce NaN alpha values
            normalized_field = numpy.clip(array, 0, 1)

    rgba = matplotlib.colors.to_rgba(color)
    field_to_plot = numpy.empty(array.shape + (4,))
    field_to_plot[..., :3] = rgba[:3]  # same color everywhere
    field_to_plot[..., 3] = normalized_field  # only the alpha channel varies

    res = imshow(field_to_plot, origin='lower', **kwargs)
    axis('equal')
    return res

132 

133 

def scalar_field_contour(array, **kwargs):
    """Small wrapper around contour to transform the coordinate system.

    The first array index is drawn horizontally (x), the second vertically (y).

    Args:
        array: two dimensional numpy array, indexed as ``array[x, y]``
        kwargs: keyword arguments passed to :func:`matplotlib.pyplot.contour`

    Returns:
        the object returned by :func:`matplotlib.pyplot.contour`
    """
    import numpy as np

    res = contour(np.swapaxes(array, 0, 1), **kwargs)
    axis('equal')
    return res

143 

144 

def multiple_scalar_fields(array, **kwargs):
    """Plot every slice along the last axis of a 3D array in its own subplot, side by side.

    Each subplot is titled with the slice index and gets its own colorbar.

    Args:
        array: 3D array to plot; ``array[..., i]`` is drawn in the (i+1)-th subplot
        **kwargs: passed along to imshow
    """
    assert len(array.shape) == 3
    panel_count = array.shape[-1]
    for panel_number in range(1, panel_count + 1):
        component = panel_number - 1
        subplot(1, panel_count, panel_number)
        title(str(component))
        scalar_field(array[..., component], **kwargs)
        colorbar()

159 

160 

def phase_plot(phase_field: np.ndarray, linewidth=1.0, clip=True) -> None:
    """Draw every phase of a phase field as an alpha-blended color layer.

    Phase ``phase_field[..., i]`` is drawn in its own color (the five colors repeat for more phases),
    using the phase value as transparency. Optionally the 0.5 iso-line of every phase is drawn in black
    on top. Warnings emitted while plotting are suppressed.

    Args:
        phase_field: array with len(shape) == 3; the first two dimensions are spatial, the last one
                     indexes the phase components
        linewidth: line width of the 0.5 contour lines drawn over the alpha blended phase images;
                   a falsy value (e.g. 0 or None) disables the contour lines
        clip: forwarded to scalar_field_alpha_value
    """
    palette = cycle(['#fe0002', '#00fe00', '#0000ff', '#ffa800', '#f600ff'])

    assert len(phase_field.shape) == 3
    phase_count = phase_field.shape[-1]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for phase_idx, phase_color in zip(range(phase_count), palette):
            scalar_field_alpha_value(phase_field[..., phase_idx], phase_color, clip=clip, interpolation='bilinear')
        if linewidth:
            for phase_idx in range(phase_count):
                scalar_field_contour(phase_field[..., phase_idx], levels=[0.5], colors='k', linewidths=[linewidth])

181 

182 

def sympy_function(expr, x_values=None, **kwargs):
    """Plots the graph of a sympy term that depends on one free symbol only.

    Args:
        expr: sympy term that depends on exactly one free symbol, which is plotted on the x axis.
              Bound variables (e.g. the index of a ``Sum``) do not count.
        x_values: describes sampling of x axis. Possible values are:
                  * tuple of (start, stop) or (start, stop, nr_of_steps)
                  * None, then start=0, stop=1, nr_of_steps=100
                  * 1D numpy array with x values
        **kwargs: passed on to :func:`matplotlib.pyplot.plot`

    Returns:
        plot object

    Raises:
        ValueError: if x_values has none of the supported forms
    """
    import numpy as np
    import sympy as sp

    if x_values is None:
        x_arr = np.linspace(0, 1, 100)
    elif isinstance(x_values, tuple):
        x_arr = np.linspace(*x_values)
    elif isinstance(x_values, np.ndarray):
        assert len(x_values.shape) == 1
        x_arr = x_values
    else:
        raise ValueError("Invalid value for parameter x_values")
    # free_symbols excludes bound (dummy) variables, unlike atoms(sp.Symbol)
    symbols = expr.free_symbols
    assert len(symbols) == 1, "Sympy expression may only depend on one variable only. Depends on " + str(symbols)
    variable = next(iter(symbols))
    y_arr = sp.lambdify(variable, expr)(x_arr)
    return plot(x_arr, y_arr, **kwargs)

211 

212 

# ------------------------------------------- Animations ---------------------------------------------------------------

214 

def __scale_array(arr):
    """Rescale a vector field so that its longest (unmasked) vector has length 1.

    Args:
        arr: numpy array (optionally masked) with 3 dimensions; the last one holds the vector components.
             For masked arrays a cell counts as masked when its first component is masked.

    Returns:
        ``arr`` divided by its maximal Euclidean vector length. If that maximum is zero, or every cell
        is masked, ``arr`` is returned unchanged instead of being turned into NaN/inf values.
    """
    import numpy as np

    norm_arr = np.linalg.norm(np.ma.getdata(arr), axis=2, ord=2)
    if isinstance(arr, np.ma.MaskedArray):
        # getmaskarray also works when arr.mask is the scalar ``nomask``
        norm_arr = np.ma.masked_array(norm_arr, mask=np.ma.getmaskarray(arr)[..., 0])
    max_norm = norm_arr.max()
    if max_norm is np.ma.masked or max_norm == 0:
        # nothing to scale by (all masked or zero field, e.g. a fluid at rest)
        return arr
    return arr / max_norm

221 

222 

def vector_field_animation(run_function, step=2, rescale=True, plot_setup_function=lambda *_: None,
                           plot_update_function=lambda *_: None, interval=200, frames=180, **kwargs):
    """Creates a matplotlib animation of a vector field using a quiver plot.

    Args:
        run_function: callable without arguments, returning a 2D vector field i.e. numpy array with len(shape)==3
        step: see documentation of vector_field function
        rescale: if True, every frame is divided by its largest vector length, so the longest arrow has unit
                 length; the quiver defaults ``scale=0.6``, ``angles='xy'``, ``scale_units='xy'`` are then
                 applied unless given in kwargs
        plot_setup_function: optional callable with the quiver object as argument,
                             that can be used to set up the plot (title, legend,..)
        plot_update_function: optional callable with the quiver object as argument
                              that is called after the quiver object was updated
        interval: delay between frames in milliseconds (see matplotlib.FuncAnimation)
        frames: how many frames should be generated, see matplotlib.FuncAnimation
        **kwargs: passed to quiver plot

    Returns:
        matplotlib animation object
    """
    import matplotlib.animation as animation
    import numpy as np

    fig = gcf()
    field = run_function()
    if rescale:
        field = __scale_array(field)
        kwargs.setdefault('scale', 0.6)
        kwargs.setdefault('angles', 'xy')
        kwargs.setdefault('scale_units', 'xy')

    quiver_plot = vector_field(field, step=step, **kwargs)
    plot_setup_function(quiver_plot)

    def update_figure(*_):
        f = run_function()
        # same (y, x) layout that vector_field uses for the initial frame
        f = np.swapaxes(f, 0, 1)
        if rescale:
            f = __scale_array(f)
        quiver_plot.set_UVC(f[::step, ::step, 0], f[::step, ::step, 1])
        plot_update_function(quiver_plot)
        # return the updated artist (previously an always-None placeholder was returned)
        return quiver_plot,

    return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)

267 

268 

def vector_field_magnitude_animation(run_function, plot_setup_function=lambda *_: None, rescale=False,
                                     plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
    """Animation of a vector field, showing the magnitude as colormap.

    For arguments, see vector_field_animation. If ``rescale`` is True, every frame is divided by its
    largest vector length before the magnitude is computed. ``kwargs`` are passed to
    vector_field_magnitude for the initial frame.

    Returns:
        matplotlib animation object
    """
    import matplotlib.animation as animation
    import numpy as np

    fig = gcf()
    field = run_function()
    if rescale:
        field = __scale_array(field)
    im = vector_field_magnitude(field, **kwargs)
    plot_setup_function(im)

    def update_figure(*_):
        f = run_function()
        if rescale:
            f = __scale_array(f)
        normed = np.linalg.norm(np.ma.getdata(f), axis=2, ord=2)
        if isinstance(f, np.ma.MaskedArray):
            # getmaskarray also handles masked arrays whose mask is the scalar ``nomask``
            normed = np.ma.masked_array(normed, mask=np.ma.getmaskarray(f)[:, :, 0])
        im.set_array(np.swapaxes(normed, 0, 1))
        plot_update_function(im)
        return im,

    return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)

299 

300 

def scalar_field_animation(run_function, plot_setup_function=lambda *_: None, rescale=True,
                           plot_update_function=lambda *_: None, interval=30, frames=180, **kwargs):
    """Animation of scalar field as colored image, see `scalar_field`.

    Args:
        run_function: callable without arguments returning the 2D scalar field of the next frame
        plot_setup_function: optional callable receiving the image object once, after the first frame
        rescale: if True, every frame is linearly mapped onto [0, 1] and the colormap limits are fixed
                 to 0 and 1; a constant frame is mapped to 0
        plot_update_function: optional callable receiving the image object after every update
        interval: delay between frames in milliseconds (see matplotlib.FuncAnimation)
        frames: how many frames should be generated, see matplotlib.FuncAnimation
        **kwargs: passed to scalar_field for the initial frame

    Returns:
        matplotlib animation object
    """
    import matplotlib.animation as animation
    import numpy as np

    def to_unit_interval(values):
        # linear map of the value range onto [0, 1]; arithmetic keeps the mask of masked arrays
        f_min, f_max = np.min(values), np.max(values)
        if f_max == f_min:
            # constant frame: avoid the 0/0 division that would yield NaN
            return values - f_min
        return (values - f_min) / (f_max - f_min)

    fig = gcf()
    field = run_function()
    if rescale:
        im = scalar_field(to_unit_interval(field), vmin=0.0, vmax=1.0, **kwargs)
    else:
        im = scalar_field(field, **kwargs)
    plot_setup_function(im)

    def update_figure(*_):
        f = run_function()
        if rescale:
            f = to_unit_interval(f)
        # swapaxes keeps the mask of masked arrays, so no re-wrapping is needed
        im.set_array(np.swapaxes(f, 0, 1))
        plot_update_function(im)
        return im,

    return animation.FuncAnimation(fig, update_figure, interval=interval, frames=frames)

330 

331 

def surface_plot_animation(run_function, frames=90, interval=30, zlim=None, **kwargs):
    """Animate a 2D scalar field as a 3D surface plot.

    ``run_function`` is called once for the initial surface and once per frame; every frame clears the
    3D axes and redraws the surface.

    Args:
        run_function: callable without arguments returning the 2D scalar field of the next frame
        frames: number of frames, see matplotlib.FuncAnimation
        interval: delay between frames in milliseconds, see matplotlib.FuncAnimation
        zlim: optional (low, high) limits for the z axis, re-applied after every redraw
        **kwargs: passed to ``plot_surface``; ``rstride``/``cstride`` default to 2, ``color`` to ``'b'``
                  and ``cmap`` to ``cm.coolwarm``

    Returns:
        matplotlib animation object
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - imported for its side effect (presumably '3d' projection registration)
    import matplotlib.animation as animation
    from matplotlib import cm

    figure_handle = gcf()
    axes_3d = figure_handle.add_subplot(111, projection='3d')
    initial_values = run_function()
    grid_x, grid_y = np.meshgrid(np.arange(initial_values.shape[0]), np.arange(initial_values.shape[1]),
                                 indexing='ij')
    for option, default in (('rstride', 2), ('cstride', 2), ('color', 'b'), ('cmap', cm.coolwarm)):
        kwargs.setdefault(option, default)

    def draw_surface(values):
        surface = axes_3d.plot_surface(grid_x, grid_y, values, **kwargs)
        if zlim is not None:
            axes_3d.set_zlim(*zlim)
        return surface

    draw_surface(initial_values)

    def update_figure(*_):
        next_values = run_function()
        axes_3d.clear()
        return draw_surface(next_values),

    return animation.FuncAnimation(figure_handle, update_figure, interval=interval, frames=frames, blit=False)