1import warnings 

2from typing import Container, Union 

3 

4import numpy as np 

5import sympy as sp 

6from sympy.logic.boolalg import BooleanFunction 

7 

8import pystencils.astnodes as ast 

9from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set 

10from pystencils.data_types import ( 

11 PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access) 

12from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt 

13from pystencils.field import Field 

14from pystencils.integer_functions import modulo_ceil, modulo_floor 

15from pystencils.sympyextensions import fast_subs 

16from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one 

17 

18 

# noinspection PyPep8Naming
class vec_any(sp.Function):
    """Symbolic marker: condition holds if ANY SIMD lane of the argument is true.

    Used to wrap loop-counter-dependent conditionals so the vector backend can
    lower them to a horizontal any-reduction (see ``mask_conditionals``).
    """
    # exactly one argument: the boolean vector expression to reduce
    nargs = (1,)

22 

23 

# noinspection PyPep8Naming
class vec_all(sp.Function):
    """Symbolic marker: condition holds if ALL SIMD lanes of the argument are true.

    Counterpart to ``vec_any``; conditions already wrapped in ``vec_any``/``vec_all``
    are skipped by ``mask_conditionals``.
    """
    # exactly one argument: the boolean vector expression to reduce
    nargs = (1,)

27 

28 

def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
              assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
              assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True):
    """Explicit vectorization using SIMD vectorization via intrinsics.

    Args:
        kernel_ast: abstract syntax tree (KernelFunction node)
        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
        assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                        used. If true, some of the loads are assumed to be from aligned memory addresses.
                        For example if x is the fastest coordinate, the access to center can be fetched via an
                        aligned-load instruction, for the west or east accesses potentially slower unaligend-load
                        instructions have to be used.
        nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used.
                     If true, nontemporal access instructions are used for all fields.
        assume_inner_stride_one: kernels with non-constant inner loop bound and strides can not be vectorized since
                                 the inner loop stride is a runtime variable and thus might not be always 1.
                                 If this parameter is set to true, the inner stride is assumed to be always one.
                                 This has to be ensured at runtime!
        assume_sufficient_line_padding: if True and assume_inner_stride_one, no tail loop is created but loop is
                                        extended by at most (vector_width-1) elements
                                        assumes that at the end of each line there is enough padding with dummy data
                                        depending on the access pattern there might be additional padding
                                        required at the end of the array
    """
    # 'best' picks the widest set reported by the runtime probe, falling back to avx
    if instruction_set == 'best':
        if get_supported_instruction_sets():
            instruction_set = get_supported_instruction_sets()[-1]
        else:
            instruction_set = 'avx'
    if instruction_set is None:
        return

    all_fields = kernel_ast.fields_accessed
    # normalize the nontemporal argument to a container of fields
    if nontemporal is None or nontemporal is False:
        nontemporal = {}
    elif nontemporal is True:
        nontemporal = all_fields

    if assume_inner_stride_one:
        replace_inner_stride_with_one(kernel_ast)

    # all float fields must share one dtype, since one vector width is chosen for the kernel
    field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
    if len(field_float_dtypes) != 1:
        raise NotImplementedError("Cannot vectorize kernels that contain accesses "
                                  "to differently typed floating point fields")
    float_size = field_float_dtypes.pop().numpy_dtype.itemsize
    assert float_size in (8, 4)
    vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
                                           instruction_set=instruction_set)
    vector_width = vector_is['width']
    kernel_ast.instruction_set = vector_is

    # widen RNG result symbols first, then rewrite loops/accesses, then fix up types
    vectorize_rng(kernel_ast, vector_width)
    scattergather = 'scatter' in vector_is and 'gather' in vector_is
    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
                                                scattergather, assume_sufficient_line_padding)
    insert_vector_casts(kernel_ast)

87 

88 

def vectorize_rng(kernel_ast, vector_width):
    """Replace scalar result symbols on RNG nodes with vectorial ones.

    Walks the AST, and for every RNG node widens each result symbol's dtype to a
    ``VectorType`` of ``vector_width``; all other uses of those symbols in the
    kernel body are substituted accordingly.
    """
    # local import to avoid a circular dependency at module load time — TODO confirm
    from pystencils.rng import RNGBase
    subst = {}

    def visit_node(node):
        for arg in node.args:
            if isinstance(arg, RNGBase):
                new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
                                      for s in arg.result_symbols]
                subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)})
                # mutate the RNG node's defined-symbol set in place
                arg._symbols_defined = set(new_result_symbols)
            else:
                visit_node(arg)
    visit_node(kernel_ast)
    # skip RNG nodes themselves: their symbols were already replaced above
    fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase))

105 

106 

def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                scattergather, assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.

    Args:
        ast_node: kernel AST to transform in place
        vector_width: number of lanes of the chosen instruction set
        assume_aligned: whether the first inner cell of each line is aligned
        nontemporal_fields: container of fields/field names that get streaming stores
        scattergather: True if the instruction set supports scatter and gather, enabling strided accesses
        assume_sufficient_line_padding: if True (together with assume_aligned) the loop is extended instead of
                                        creating a tail loop
    """
    all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
    inner_loops = [n for n in all_loops if n.is_innermost_loop]
    zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        # cut off loop tail, that is not a multiple of four
        if assume_aligned and assume_sufficient_line_padding:
            # padding assumed: round the stop up to the next multiple of the vector width
            loop_range = loop_node.stop - loop_node.start
            new_stop = loop_node.start + modulo_ceil(loop_range, vector_width)
            loop_node.stop = new_stop
        else:
            # otherwise split into a vectorizable main loop and a scalar tail loop
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol in index.atoms(sp.Symbol):
                # access is vectorizable if the counter enters the index as a plain offset,
                # or (with scatter/gather support) with a counter-independent stride
                loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
                if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
                    successful = False
                    break
                typed_symbol = base.label
                assert type(typed_symbol.dtype) is PointerType, \
                    f"Type of access is {typed_symbol.dtype}, {indexed}"

                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
                use_aligned_access = aligned_access and assume_aligned
                nontemporal = False
                if hasattr(indexed, 'field'):
                    nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
                                                              stride if scattergather else 1)
                if nontemporal:
                    # insert NontemporalFence after the outermost loop
                    parent = loop_node.parent
                    while type(parent.parent.parent) is not ast.KernelFunction:
                        parent = parent.parent
                    parent.parent.insert_after(ast.NontemporalFence(), parent, if_not_exists=True)
                    # insert CachelineSize at the beginning of the kernel
                    parent.parent.insert_front(ast.CachelineSize(), if_not_exists=True)
        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)
        # replace the scalar loop counter with a vector counter: broadcast(counter) + (0, 1, ..., intwidth-1)
        vector_int_width = ast_node.instruction_set['intwidth']
        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
            + cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))

        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)

176 

177 

def mask_conditionals(loop_body):
    """Convert conditionals that depend on the loop counter into masked vector stores.

    Conditionals whose condition does not involve the loop counter (or that are already
    wrapped in ``vec_any``/``vec_all``) are kept as scalar branches; all others have their
    mask folded into the vector memory accesses of contained assignments and their
    condition wrapped in ``vec_any``.
    """
    def visit_node(node, mask):
        if isinstance(node, ast.Conditional):
            cond = node.condition_expr
            skip = (loop_body.loop_counter_symbol not in cond.atoms(sp.Symbol)) or cond.func in (vec_all, vec_any)
            cond = True if skip else cond

            true_mask = sp.And(cond, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                false_mask = sp.And(sp.Not(node.condition_expr), mask)
                # bug fix: recurse into the false block with the negated mask;
                # the previous `visit_node(node, false_mask)` re-visited the
                # conditional itself and recursed without bound
                visit_node(node.false_block, false_mask)
            if not skip:
                node.condition_expr = vec_any(node.condition_expr)
        elif isinstance(node, ast.SympyAssignment):
            if mask is not True:
                # fold the accumulated mask into arg 4 (the mask slot) of each memory access
                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                     for ma in node.atoms(vector_memory_access)}
                node.subs(s)
        else:
            for arg in node.args:
                visit_node(arg, mask)

    visit_node(loop_body, mask=True)

202 

203 

def insert_vector_casts(ast_node):
    """Inserts necessary casts from scalar values to vector values."""

    # expression heads whose arguments may be mixed scalar/vector and need collation
    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)

    def visit_expr(expr):
        """Return *expr* with scalar subexpressions cast to vector type where required."""
        if isinstance(expr, vector_memory_access):
            # arg 4 is the mask expression; only it may contain expressions needing casts
            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:])
        elif isinstance(expr, cast_func):
            return expr
        elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
            # no native abs intrinsic: lower to a select (piecewise negation)
            new_arg = visit_expr(expr.args[0])
            base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
                else get_type_of_expression(expr.args[0])
            pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                              (new_arg, True))
            return visit_expr(pw)
        elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
            default_type = 'double'
            if expr.func is sp.Mul and expr.args[0] == -1:
                # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                dtype = int
                for arg in expr.atoms(vector_memory_access):
                    if arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                for arg in expr.atoms(TypedSymbol):
                    if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                if dtype is not int:
                    if dtype is np.float32:
                        default_type = 'float'
                    expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
            new_args = [visit_expr(a) for a in expr.args]
            arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
            if not any(type(t) is VectorType for t in arg_types):
                # purely scalar expression: nothing to cast
                return expr
            else:
                target_type = collate_types(arg_types)
                casted_args = [
                    cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
                    for a, t in zip(new_args, arg_types)]
                return expr.func(*casted_args)
        elif expr.func is sp.Pow:
            # only the base may need casting; the exponent stays untouched
            new_arg = visit_expr(expr.args[0])
            return expr.func(new_arg, expr.args[1])
        elif expr.func == sp.Piecewise:
            new_results = [visit_expr(a[0]) for a in expr.args]
            new_conditions = [visit_expr(a[1]) for a in expr.args]
            types_of_results = [get_type_of_expression(a) for a in new_results]
            types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(types_of_conditions)
            # widen whichever side (results/conditions) is still scalar to the other's vector width
            if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType:
                result_target_type = VectorType(result_target_type, width=condition_target_type.width)
            if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
                condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

            casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
                              for a, t in zip(new_results, types_of_results)]

            casted_conditions = [cast_func(a, condition_target_type)
                                 if t != condition_target_type and a is not True else a
                                 for a, t in zip(new_conditions, types_of_conditions)]

            return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
        else:
            return expr

    def visit_node(node, substitution_dict):
        """Walk statements, widening assigned symbols and propagating the substitutions downward."""
        substitution_dict = substitution_dict.copy()
        for arg in node.args:
            if isinstance(arg, ast.SympyAssignment):
                assignment = arg
                subs_expr = fast_subs(assignment.rhs, substitution_dict,
                                      skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
                assignment.rhs = visit_expr(subs_expr)
                rhs_type = get_type_of_expression(assignment.rhs)
                if isinstance(assignment.lhs, TypedSymbol):
                    lhs_type = assignment.lhs.dtype
                    if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                        # rhs became vectorial: widen the lhs symbol and remember the replacement
                        new_lhs_type = VectorType(lhs_type, rhs_type.width)
                        new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
                        substitution_dict[assignment.lhs] = new_lhs
                        assignment.lhs = new_lhs
                elif isinstance(assignment.lhs, vector_memory_access):
                    assignment.lhs = visit_expr(assignment.lhs)
            elif isinstance(arg, ast.Conditional):
                arg.condition_expr = fast_subs(arg.condition_expr, substitution_dict,
                                               skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
                arg.condition_expr = visit_expr(arg.condition_expr)
                visit_node(arg, substitution_dict)
            else:
                visit_node(arg, substitution_dict)

    visit_node(ast_node, {})