1import numpy as np 

2import sympy as sp 

3from sympy.codegen.ast import Assignment 

4from sympy.printing.latex import LatexPrinter 

5 

# Names exported by a star import of this module.
__all__ = [
    'Assignment',
    'assignment_from_stencil',
]

7 

8 

def print_assignment_latex(printer, expr):
    """Render an assignment as LaTeX in the form ``lhs \\leftarrow rhs``.

    sympy cannot print Assignments as LaTeX on its own, so this function is
    attached to sympy's ``LatexPrinter`` to provide that rule.

    Args:
        printer: the active sympy LaTeX printer; both sides are rendered with its ``doprint``
        expr: the assignment to print (read via its ``lhs`` and ``rhs`` attributes)

    Returns:
        the LaTeX source string
    """
    lhs_latex = printer.doprint(expr.lhs)
    rhs_latex = printer.doprint(expr.rhs)
    return rf"{lhs_latex} \leftarrow {rhs_latex}"

14 

15 

def assignment_str(assignment):
    """Return the plain-text form ``lhs ← rhs`` of an assignment."""
    return f"{assignment.lhs} ← {assignment.rhs}"

18 

19 

# Reference to sympy's original Assignment constructor; the wrapper below delegates to it.
_old_new = sp.codegen.ast.Assignment.__new__


def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
    """Construct an Assignment, additionally supporting element-wise vector assignments.

    If both ``lhs`` and ``rhs`` are lists, tuples or sympy matrices (mutable or
    immutable), a tuple of scalar Assignments is returned, one per pair
    ``(lhs[i], rhs[i])``. Otherwise all arguments are forwarded unchanged to sympy's
    original ``Assignment.__new__``.

    Raises:
        AssertionError: if a vector assignment is requested with sequences of different length
    """
    vector_types = (list, tuple, sp.MatrixBase)
    if isinstance(lhs, vector_types) and isinstance(rhs, vector_types):
        if len(lhs) != len(rhs):
            # Explicit check instead of a bare ``assert``: asserts are stripped under
            # ``python -O``, and zip() below would then silently drop surplus elements.
            # AssertionError is kept for backward compatibility with the former assert.
            raise AssertionError(f'{lhs} and {rhs} must have same length when performing vector assignment!')
        return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
    return _old_new(cls, lhs, rhs, *args, **kwargs)

28 

29 

# Hook the helpers above into sympy: LaTeX printing, plain-text printing and
# element-wise (vector) construction of Assignment objects.
LatexPrinter._print_Assignment = print_assignment_latex
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__


def _mutable_dense_matrix_hash(self):
    """Hash a mutable sympy matrix by the tuple of its entries (the shape is not included)."""
    return hash(tuple(self))


# Give sympy's MutableDenseMatrix a value-based hash.
# NOTE(review): the hash changes when the matrix is mutated, so such matrices must not be
# mutated while used as dict keys or set members.
sp.MutableDenseMatrix.__hash__ = _mutable_dense_matrix_hash

35 

36 

# Apparently, in SymPy 1.4 Assignment.__hash__ is not implemented (fixed in later releases).
# For SymPy 1.4 and older, hash an Assignment by its (lhs, rhs) pair.
try:
    sympy_version = sp.__version__.split('.')
    # Compare (major, minor) as a tuple. The former component-wise test
    # ``major <= 1 and minor <= 4`` was not a version comparison (e.g. 0.7 failed it).
    _needs_assignment_hash = (int(sympy_version[0]), int(sympy_version[1])) <= (1, 4)
except (AttributeError, IndexError, ValueError):
    # Version string missing or not of the form '<int>.<int>...' (e.g. '1.4rc1'):
    # as before, leave Assignment.__hash__ untouched instead of failing the import.
    _needs_assignment_hash = False

if _needs_assignment_hash:
    def hash_fun(self):
        """Hash an Assignment by its left- and right-hand side."""
        return hash((self.lhs, self.rhs))

    Assignment.__hash__ = hash_fun

48 

49 

def assignment_from_stencil(stencil_array, input_field, output_field,
                            normalization_factor=None, order='visual') -> Assignment:
    """Creates an assignment that applies a stencil to ``input_field`` and stores the result in ``output_field``

    Args:
        stencil_array: nested list or numpy array defining the stencil weights
        input_field: field or field access, defining where the stencil should be applied to
        output_field: field or field access where the result is written to
        normalization_factor: optional normalization factor for the stencil; the weighted sum is
                              multiplied by it if it is truthy (``None`` and ``0`` are ignored)
        order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
               For details see examples. For one-dimensional stencils both orders are identical.

    Returns:
        Assignment that can be used to create a kernel

    Raises:
        ValueError: if ``order`` is neither 'visual' nor 'numpy'

    Examples:
        >>> import pystencils as ps
        >>> f, g = ps.fields("f, g: [2D]")
        >>> stencil = [[0, 2, 0],
        ...            [3, 4, 5],
        ...            [0, 6, 0]]

        By default 'visual' ordering is used - i.e. the stencil is applied as the nested lists are written down
        >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
        >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
        True

        'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
        >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
        >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
        True

        You can also pass field accesses to apply the stencil at an already shifted position:
        >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
        >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
        True
    """
    from pystencils.field import Field

    stencil_array = np.array(stencil_array)
    if order == 'visual':
        # In visual order, rows run along y (top row = largest y offset) and columns along x.
        # Swap the first two axes so that axis 0 is x, then flip the y axis.
        # A 1D stencil is already written left to right along x, so there is nothing to
        # reorder (np.swapaxes(arr, 0, 1) would raise for it).
        if stencil_array.ndim >= 2:
            stencil_array = np.swapaxes(stencil_array, 0, 1)
            stencil_array = np.flip(stencil_array, axis=1)
    elif order != 'numpy':
        raise ValueError("'order' has to be either 'visual' or 'numpy'")

    if isinstance(input_field, Field):
        input_field = input_field.center
    if isinstance(output_field, Field):
        output_field = output_field.center

    # Index of the stencil's central entry along each axis; this entry is applied at shift 0.
    offset = tuple(s // 2 for s in stencil_array.shape)

    rhs = 0
    for index, factor in np.ndenumerate(stencil_array):
        shift = tuple(i - o for i, o in zip(index, offset))
        rhs += factor * input_field.get_shifted(*shift)

    if normalization_factor:
        rhs *= normalization_factor

    return Assignment(output_field, rhs)