1import copy 

2import numpy as np 

3import sympy as sp 

4 

5from pystencils.data_types import TypedSymbol, cast_func 

6from pystencils.astnodes import LoopOverCoordinate 

7from pystencils.backends.cbackend import CustomCodeNode 

8from pystencils.sympyextensions import fast_subs 

9 

10 

class RNGBase(CustomCodeNode):
    """Custom AST node that emits a call to a C random number generator function.

    Concrete subclasses set the class attributes
      * ``_name``      -- name of the C function that is called (its prefix before the first
                          underscore selects the header ``"<prefix>_rand.h"``),
      * ``_data_type`` -- dtype of every generated value,
      * ``_num_vars``  -- number of values produced by one call,
      * ``_num_keys``  -- number of key arguments the function takes.

    The emitted call passes, in this order: the time step, three spatial coordinates
    (loop counters plus offsets, zero-padded when fewer than three dimensions are used),
    the keys, and finally the names of the result variables.
    """

    # Class-wide counter; each constructed node uses the current value to give its result
    # symbols unique names (``random_<id>_<i>``) and then increments it.
    id = 0

    def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
        """Create an RNG node for a loop nest with ``dim`` spatial dimensions.

        Args:
            dim: number of spatial dimensions (1, 2 or 3); the loop counter of each dimension
                becomes one coordinate argument of the RNG call.
            time_step: expression passed as the first RNG argument.
            offsets: per-dimension values added to the loop counters; defaults to all zeros.
            keys: exactly ``_num_keys`` key values; defaults to all zeros.

        Raises:
            ValueError: if ``dim`` is not in 1..3, or the number of keys/offsets does not match.
        """
        # The call always carries exactly three coordinate arguments, so more than three
        # dimensions cannot be represented.
        if not 1 <= dim <= 3:
            raise ValueError(f"RNG supports 1 to 3 spatial dimensions, got {dim}")
        if keys is None:
            keys = (0,) * self._num_keys
        if offsets is None:
            offsets = (0,) * dim
        if len(keys) != self._num_keys:
            raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
        if len(offsets) != dim:
            raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")

        coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offset for i, offset in enumerate(offsets)]
        # Pad unused dimensions with zeros so the argument count is independent of dim
        # (padding by a single zero would leave dim == 1 one argument short).
        coordinates += [0] * (3 - dim)

        self._args = sp.sympify([time_step, *coordinates, *keys])
        self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
                                    for i in range(self._num_vars))
        # Every free symbol used in the arguments is read by this node.
        symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
        super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)

        self.headers = [f'"{self._name.split("_")[0]}_rand.h"']

        RNGBase.id += 1

    @property
    def args(self):
        """Argument expressions of the call: time step, three coordinates, then the keys."""
        return self._args

    def fast_subs(self, subs_dict, skip):
        """Return a deep copy of this node with ``subs_dict`` substituted into every argument.

        NOTE(review): the ``symbols_read`` set handed to ``CustomCodeNode`` at construction is
        not recomputed here; confirm whether callers rely on it after substitution.
        """
        rng = copy.deepcopy(self)
        rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
        return rng

    def get_code(self, dialect, vector_instruction_set, print_arg):
        """Return C code that declares the result variables and calls the RNG function.

        Args:
            dialect: not used by this implementation.
            vector_instruction_set: mapping from a base type name to the vector type name used
                in declarations, or a falsy value when generating scalar code.
            print_arg: callable rendering one argument expression as C code.
        """
        # A vectorized RNG whose first coordinate no longer contains a cast_func has become
        # scalar through substitution. The test is the same for every result symbol, so it
        # is evaluated once.
        scalar = not vector_instruction_set or not self.args[1].atoms(cast_func)

        declarations = []
        for symbol in self.result_symbols:
            c_type = symbol.dtype if scalar else vector_instruction_set[symbol.dtype.base_name]
            declarations.append(f"{c_type} {symbol.name};\n")

        call_args = [print_arg(a) for a in self.args] + [s.name for s in self.result_symbols]
        return "\n" + "".join(declarations) + f"{self._name}({', '.join(call_args)});\n"

    def __repr__(self):
        """Render as ``<results> \\leftarrow <Name>_RNG(<args>)`` (LaTeX-style arrow)."""
        return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
            self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"

63 

64 

class PhiloxTwoDoubles(RNGBase):
    """RNG node calling ``philox_double2``: two float64 values per call, two key arguments."""

    _num_keys = 2
    _num_vars = 2
    _data_type = np.float64
    _name = "philox_double2"

70 

71 

class PhiloxFourFloats(RNGBase):
    """RNG node calling ``philox_float4``: four float32 values per call, two key arguments."""

    _num_keys = 2
    _num_vars = 4
    _data_type = np.float32
    _name = "philox_float4"

77 

78 

class AESNITwoDoubles(RNGBase):
    """RNG node calling ``aesni_double2``: two float64 values per call, four key arguments."""

    _num_keys = 4
    _num_vars = 2
    _data_type = np.float64
    _name = "aesni_double2"

84 

85 

class AESNIFourFloats(RNGBase):
    """RNG node calling ``aesni_float4``: four float32 values per call, four key arguments."""

    _num_keys = 4
    _num_vars = 4
    _data_type = np.float32
    _name = "aesni_float4"

91 

92 

def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles,
                  time_step=TypedSymbol("time_step", np.uint32), offsets=None):
    """Return a symbol generator for random numbers

    The generator is endless. It creates RNG nodes one after another, each keyed with
    ``(stream_index, seed)`` (padded with zeros up to ``rng_node._num_keys`` keys), where
    ``stream_index`` counts up from 0. The symbols of each node are yielded one at a time.
    Just before the first symbol of a node is yielded, that node is inserted at the front
    of ``assignment_list``.

    Args:
        assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments
                         will be inserted
        dim: 2 or 3 for two or three spatial dimensions
        seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple
              symbol generators, please pass them different seeds so you don't get the same stream of random numbers!
        rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles,
                  AESNIFourFloats).
        time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step
        offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates
                 of the local origin

    Yields:
        The result symbols of successive RNG nodes.
    """
    stream_index = 0
    while True:
        node_keys = (stream_index, seed) + (0,) * (rng_node._num_keys - 2)
        rng = rng_node(dim, keys=node_keys, time_step=time_step, offsets=offsets)
        symbols = rng.result_symbols
        if symbols:
            # Register the node lazily: it only enters the list once one of its symbols is requested.
            assignment_list.insert(0, rng)
        yield from symbols
        stream_index += 1