import hashlib
import pickle
import warnings
from collections import OrderedDict, defaultdict, namedtuple
from copy import deepcopy
from types import MappingProxyType

import numpy as np
import sympy as sp
from sympy.core.numbers import ImaginaryUnit
from sympy.logic.boolalg import Boolean, BooleanFunction

import pystencils.astnodes as ast
import pystencils.integer_functions
from pystencils.assignment import Assignment
from pystencils.data_types import (
    PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type,
    get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func)
from pystencils.field import AbstractField, Field, FieldType
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice


class NestedScopes:
    """Symbol visibility model using nested scopes

    - every accessed symbol that was not defined before is added as a "free parameter"
    - free parameters are global, i.e. they are not in scopes
    - push/pop adds or removes a scope

    >>> s = NestedScopes()
    >>> s.access_symbol("a")
    >>> s.is_defined("a")
    False
    >>> s.free_parameters
    {'a'}
    >>> s.define_symbol("b")
    >>> s.is_defined("b")
    True
    >>> s.push()
    >>> s.is_defined_locally("b")
    False
    >>> s.define_symbol("c")
    >>> s.pop()
    >>> s.is_defined("c")
    False
    """

    def __init__(self):
        self.free_parameters = set()
        self._defined = [set()]

    def access_symbol(self, symbol):
        if not self.is_defined(symbol):
            self.free_parameters.add(symbol)

    def define_symbol(self, symbol):
        self._defined[-1].add(symbol)

    def is_defined(self, symbol):
        return any(symbol in scope for scope in self._defined)

    def is_defined_locally(self, symbol):
        return symbol in self._defined[-1]

    def push(self):
        self._defined.append(set())

    def pop(self):
        self._defined.pop()
        assert self.depth >= 1

    @property
    def depth(self):
        return len(self._defined)


def filtered_tree_iteration(node, node_type, stop_type=None):
    for arg in node.args:
        if isinstance(arg, node_type):
            yield arg
        elif stop_type and isinstance(arg, stop_type):
            # do not descend into subtrees rooted at a stop_type node
            continue

        yield from filtered_tree_iteration(arg, node_type, stop_type)


def generic_visit(term, visitor):
    if isinstance(term, AssignmentCollection):
        new_main_assignments = generic_visit(term.main_assignments, visitor)
        new_subexpressions = generic_visit(term.subexpressions, visitor)
        return term.copy(new_main_assignments, new_subexpressions)
    elif isinstance(term, list):
        return [generic_visit(e, visitor) for e in term]
    elif isinstance(term, Assignment):
        return Assignment(term.lhs, generic_visit(term.rhs, visitor))
    elif isinstance(term, sp.Matrix):
        return term.applyfunc(lambda e: generic_visit(e, visitor))
    else:
        return visitor(term)


def unify_shape_symbols(body, common_shape, fields):
    """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.

    When creating a kernel with variable array sizes, all passed arrays must have the same size.
    This is ensured when the kernel is called. Inside the kernel this means that only one symbol has to be used
    instead of one per field. For example shape_arr1[0] and shape_arr2[0] must be equal, so they should also be
    represented by the same symbol.

    Args:
        body: ast node for the kernel part where the substitutions are made; modified in-place
        common_shape: shape of the field that was chosen
        fields: all fields whose shapes should be replaced by common_shape
    """
    substitutions = {}
    for field in fields:
        assert len(field.spatial_shape) == len(common_shape)
        if not field.has_fixed_shape:
            for common_shape_component, shape_component in zip(common_shape, field.spatial_shape):
                if shape_component != common_shape_component:
                    substitutions[shape_component] = common_shape_component
    if substitutions:
        body.subs(substitutions)


def get_common_shape(field_set):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
    ValueError is raised"""
    nr_of_fixed_shaped_fields = 0
    for f in field_set:
        if f.has_fixed_shape:
            nr_of_fixed_shaped_fields += 1

    if nr_of_fixed_shaped_fields > 0 and nr_of_fixed_shaped_fields != len(field_set):
        fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
        var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
        raise ValueError(msg)

    shape_set = set([f.spatial_shape for f in field_set])
    if nr_of_fixed_shaped_fields == len(field_set):
        if len(shape_set) != 1:
            raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))

    shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0]
    return shape


def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
    """Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    Args:
        body: Block object with inner loop contents
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper number of ghost layers.
                      If None, the number of ghost layers is determined automatically and assumed to be equal
                      for all dimensions
        loop_order: loop ordering from outer to inner loop (optimal ordering is same as layout)

    Returns:
        tuple of loop-node, ghost_layer_info
    """
    # find correct ordering by inspecting participating FieldAccesses
    field_accesses = body.atoms(AbstractField.AbstractAccess)
    field_accesses = {e for e in field_accesses if not e.is_absolute_access}

    # exclude accesses to buffers from field_list, because buffers are treated separately
    field_list = [e.field for e in field_accesses
                  if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
    if len(field_list) == 0:  # when kernel contains only custom fields
        field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)]

    fields = set(field_list)

    if loop_order is None:
        loop_order = get_optimal_loop_ordering(fields)

    shape = get_common_shape(fields)
    unify_shape_symbols(body, common_shape=shape, fields=fields)

    if iteration_slice is not None:
        iteration_slice = normalize_slice(iteration_slice, shape)

    if ghost_layers is None:
        required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
        ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
    if isinstance(ghost_layers, int):
        ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)

    current_body = body
    for i, loop_coordinate in enumerate(reversed(loop_order)):
        if iteration_slice is None:
            begin = ghost_layers[loop_coordinate][0]
            end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
            new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
            current_body = ast.Block([new_loop])
        else:
            slice_component = iteration_slice[loop_coordinate]
            if type(slice_component) is slice:
                sc = slice_component
                new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
                current_body = ast.Block([new_loop])
            else:
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
                                                 sp.sympify(slice_component))
                current_body.insert_front(assignment)

    return current_body, ghost_layers

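# Rough usage sketch for make_loop_over_domain, kept as a comment so the module stays import-clean.
# The field and symbol names below are made up for illustration; in the real pipeline the assignments
# are typed via add_types() first, but the shape of the call is the same:
#
#   src = Field.create_generic('src', spatial_dimensions=2)
#   dst = Field.create_generic('dst', spatial_dimensions=2)
#   update = ast.SympyAssignment(dst.center, (src[1, 0] + src[-1, 0]) / 2)
#   body = ast.Block([update])
#   loop_node, ghost_layer_info = make_loop_over_domain(body)
#
# With the stencil above the widest offset is 1, so ghost_layer_info would be [(1, 1), (1, 1)] and the
# generated loops run from 1 to shape[i] - 1 in both coordinates.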

def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    Args:
        field_access: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
        coordinates: mapping of coordinate ids to their values, where stride*value is calculated
        previous_ptr: the pointer which is de-referenced

    Returns:
        tuple with the new pointer symbol and the calculated offset

    Examples:
        >>> field = Field.create_generic('myfield', spatial_dimensions=2, index_dimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prev_pointer = TypedSymbol("ptr", "double")
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x}, prev_pointer)
        (ptr_01, _stride_myfield_0*x + _stride_myfield_0)
        >>> create_intermediate_base_pointer(field[1,-2](5), {0: x, 1 : y }, prev_pointer)
        (ptr_01_1m2, _stride_myfield_0*x + _stride_myfield_0 + _stride_myfield_1*y - 2*_stride_myfield_1)
    """
    field = field_access.field
    offset = 0
    name = ""
    list_to_hash = []
    for coordinate_id, coordinate_value in coordinates.items():
        offset += field.strides[coordinate_id] * coordinate_value

        if coordinate_id < field.spatial_dimensions:
            offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
            if field_access.offsets[coordinate_id].is_Integer:
                name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
            else:
                list_to_hash.append(field_access.offsets[coordinate_id])
        else:
            if type(coordinate_value) is int:
                name += "_%d%d" % (coordinate_id, coordinate_value)
            else:
                list_to_hash.append(coordinate_value)

    if len(list_to_hash) > 0:
        name += hashlib.md5(pickle.dumps(list_to_hash)).hexdigest()[:16]

    name = name.replace("-", 'm')
    new_ptr = TypedSymbol(previous_ptr.name + name, previous_ptr.dtype)
    return new_ptr, offset


def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
    """
    Creates base pointer specification for :func:`resolve_field_accesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [(0,), (2, 3)] creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly into the field access. These specifications are defined relative to the loop ordering.
    This function translates the more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    Args:
        base_pointer_specification: nested list with above specifications
        loop_order: list with ordering of loops from outer to inner
        spatial_dimensions: number of spatial dimensions
        index_dimensions: number of index dimensions

    Returns:
        list of tuples that can be passed to :func:`resolve_field_accesses`

    Examples:
        >>> parse_base_pointer_info([['spatialOuter0'], ['index0']], loop_order=[2,1,0],
        ...                         spatial_dimensions=3, index_dimensions=1)
        [[0], [3], [1, 2]]
    """
    result = []
    specified_coordinates = set()
    loop_order = list(reversed(loop_order))
    for spec_group in base_pointer_specification:
        new_group = []

        def add_new_element(elem):
            if elem >= spatial_dimensions + index_dimensions:
                raise ValueError("Coordinate %d does not exist" % (elem,))
            new_group.append(elem)
            if elem in specified_coordinates:
                raise ValueError("Coordinate %d specified two times" % (elem,))
            specified_coordinates.add(elem)

        for element in spec_group:
            if type(element) is int:
                add_new_element(element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    add_new_element(loop_order[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    add_new_element(loop_order[-index])
                elif element == "all":
                    for i in range(spatial_dimensions):
                        add_new_element(i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                add_new_element(spatial_dimensions + index)
            else:
                raise ValueError(f"Unknown specification {element}")

        result.append(new_group)

    all_coordinates = set(range(spatial_dimensions + index_dimensions))
    rest = all_coordinates - specified_coordinates
    if rest:
        result.append(list(rest))

    return result

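# Mechanical reading of the doctest above (comment only): for loop_order=[2, 1, 0],
# 'index0' resolves to coordinate spatial_dimensions + 0 = 3, 'spatialOuter0' resolves to
# coordinate 0 for this loop order, and all coordinates that were not mentioned ({1, 2})
# are collected automatically into a final group - giving [[0], [3], [1, 2]].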

def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
    """Used for buffer fields to determine the linearized index of the buffer dependent on loop counter symbols.

    Args:
        ast_node: ast before any field accesses are resolved
        loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                       for GPU kernels: list of 'loop counters' from inner to outer loop
        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default

    Returns:
        base buffer index - required by 'resolve_buffer_accesses' function
    """
    if loop_counters is None or loop_iterations is None:
        loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
        loops.reverse()
        parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
        assert len(loops) == len(parents_of_innermost_loop)
        assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))

        loop_iterations = [(l.stop - l.start) / l.step for l in loops]
        loop_counters = [l.loop_counter_symbol for l in loops]

    field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
    buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
    loop_counters = [v * len(buffer_accesses) for v in loop_counters]

    base_buffer_index = loop_counters[0]
    stride = 1
    for idx, var in enumerate(loop_counters[1:]):
        cur_stride = loop_iterations[idx]
        stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
        base_buffer_index += var * stride
    return base_buffer_index

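# Worked sketch of the linearization above (comment only): assume two loops with counter symbols
# ctr_0 (innermost) and ctr_1, loop_iterations = [N0, N1], and k buffer accesses per cell.
# Every counter is first scaled by k, so the returned expression is
#   base_buffer_index = k*ctr_0 + k*ctr_1*N0
# resolve_buffer_accesses() later adds the per-cell index on top of this base, so the buffer
# entries belonging to one cell stay adjacent in memory.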

def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr

            # do not apply the transformation if the field is not a buffer
            if not FieldType.is_buffer(field_access.field):
                return expr

            buffer = field_access.field
            field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)

            buffer_index = base_buffer_index
            if len(field_access.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            if len(field_access.index) > 0:
                cell_index = field_access.index[0]
                buffer_index += cell_index

            result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
                                             field_access.index)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        else:
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)

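# Sketch of the rewrite performed by resolve_buffer_accesses (comment only): a buffer access with
# index dimension value 2, e.g. buf(2), becomes a ResolvedFieldAccess that reads the buffer's
# FieldPointerSymbol at position base_buffer_index + 2, where base_buffer_index is the expression
# computed by get_base_buffer_index() above.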

def resolve_field_accesses(ast_node, read_only_field_names=set(),
                           field_to_base_pointer_info=MappingProxyType({}),
                           field_to_fixed_coordinates=MappingProxyType({})):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    Args:
        ast_node: the AST root
        read_only_field_names: set of field names which are considered read-only
        field_to_base_pointer_info: a list of tuples indicating which intermediate base pointers should be created
                                    for details see :func:`parse_base_pointer_info`
        field_to_fixed_coordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates

    Returns:
        transformed AST
    """
    field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
    field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))

    def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
        if isinstance(expr, AbstractField.AbstractAccess):
            field_access = expr
            field = field_access.field

            if field_access.indirect_addressing_fields:
                new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
                                    for off in field_access.offsets)
                new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
                                    if isinstance(ind, sp.Basic) else ind
                                    for ind in field_access.index)
                field_access = Field.Access(field_access.field, new_offsets,
                                            new_indices, field_access.is_absolute_access)

            if field.name in field_to_base_pointer_info:
                base_pointer_info = field_to_base_pointer_info[field.name]
            else:
                base_pointer_info = [
                    list(range(field.index_dimensions + field.spatial_dimensions))
                ]

            field_ptr = FieldPointerSymbol(
                field.name,
                field.dtype,
                const=field.name in read_only_field_names)

            def create_coordinate_dict(group_param):
                coordinates = {}
                for e in group_param:
                    if e < field.spatial_dimensions:
                        if field.name in field_to_fixed_coordinates:
                            if not field_access.is_absolute_access:
                                coordinates[e] = field_to_fixed_coordinates[field.name][e]
                            else:
                                coordinates[e] = 0
                        else:
                            if not field_access.is_absolute_access:
                                coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
                            else:
                                coordinates[e] = 0
                        coordinates[e] *= field.dtype.item_size
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.index_dimensions == 1
                            accessed_field_name = field_access.index[0]
                            if isinstance(accessed_field_name, sp.Symbol):
                                accessed_field_name = accessed_field_name.name
                            assert isinstance(accessed_field_name, str)
                            coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
                        else:
                            coordinates[e] = field_access.index[e - field.spatial_dimensions]

                return coordinates

            last_pointer = field_ptr

            for group in reversed(base_pointer_info[1:]):
                coord_dict = create_coordinate_dict(group)
                new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                if new_ptr not in enclosing_block.symbols_defined:
                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
                    enclosing_block.insert_before(new_assignment, sympy_assignment)
                last_pointer = new_ptr

            coord_dict = create_coordinate_dict(base_pointer_info[0])
            _, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
            result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
                                             field_access.offsets, field_access.index)

            if isinstance(get_base_type(field_access.field.dtype), StructType):
                accessed_field_name = field_access.index[0]
                if isinstance(accessed_field_name, sp.Symbol):
                    accessed_field_name = accessed_field_name.name
                new_type = field_access.field.dtype.get_element_type(accessed_field_name)
                result = reinterpret_cast_func(result, new_type)

            return visit_sympy_expr(result, enclosing_block, sympy_assignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            if hasattr(expr, 'args'):
                new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
            else:
                new_args = []
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*new_args, **kwargs) if new_args else expr

    def visit_node(sub_ast):
        if isinstance(sub_ast, ast.SympyAssignment):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
            sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
        elif isinstance(sub_ast, ast.Conditional):
            enclosing_block = sub_ast.parent
            assert type(enclosing_block) is ast.Block
            sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
            visit_node(sub_ast.true_block)
            if sub_ast.false_block:
                visit_node(sub_ast.false_block)
        else:
            if isinstance(sub_ast, (bool, int, float)):
                return
            for a in sub_ast.args:
                visit_node(a)

    return visit_node(ast_node)

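# Sketch of the default lowering (comment only): if a field has no entry in field_to_base_pointer_info,
# a single group containing all coordinates is used, so no intermediate pointers are introduced and an
# access like src[1, 0] on a 2D scalar field turns into one ResolvedFieldAccess of the field's pointer
# symbol with an offset of the form
#   _stride_src_0*(ctr_0 + 1) + _stride_src_1*ctr_1
# (stride symbols named as in the create_intermediate_base_pointer doctest, ctr_i the loop counters).
# Passing e.g. parse_base_pointer_info([['spatialInner0']], ...) instead keeps only the innermost
# coordinate in the final access and emits an intermediate base pointer assignment for the remaining
# coordinates, which move_constants_before_loop() can then hoist out of the inner loop.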

def move_constants_before_loop(ast_node):
    """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.

    Call this after creating the loop structure with :func:`make_loop_over_domain`
    """
    def find_block_to_move_to(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to
        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node.parent, ast.Block)

        last_block = node.parent
        last_block_child = node
        element = node.parent
        prev_element = node
        while element:
            if isinstance(element, ast.Block):
                last_block = element
                last_block_child = prev_element

            if isinstance(element, ast.Conditional):
                break
            else:
                critical_symbols = set([s.name for s in element.symbols_defined])
                if set([s.name for s in node.undefined_symbols]).intersection(critical_symbols):
                    break
            prev_element = element
            element = element.parent
        return last_block, last_block_child

    def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
        for arg in target_block.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
                return arg
        return None

    def get_blocks(node, result_list):
        if isinstance(node, ast.Block):
            result_list.append(node)
        if isinstance(node, ast.Node):
            for a in node.args:
                get_blocks(a, result_list)

    all_blocks = []
    get_blocks(ast_node, all_blocks)
    for block in all_blocks:
        children = block.take_child_nodes()
        # Every time a symbol can be replaced in the current block because the assignment
        # was found in a parent block, but with a different lhs symbol (same rhs)
        # the outer symbol is inserted here as key.
        substitute_variables = {}
        for child in children:
            # Before traversing the next child, all symbols are substituted first.
            child.subs(substitute_variables)

            if not isinstance(child, ast.SympyAssignment):  # only move SympyAssignments
                block.append(child)
                continue

            target, child_to_insert_before = find_block_to_move_to(child)
            if target == block:     # movement not possible
                target.append(child)
            else:
                if isinstance(child, ast.SympyAssignment):
                    exists_already = check_if_assignment_already_in_block(child, target, False)
                else:
                    exists_already = False

                if not exists_already:
                    rhs_identical = check_if_assignment_already_in_block(child, target, True)
                    if rhs_identical:
                        # there is already an assignment out there with the same rhs
                        # -> replace all lhs symbols in this block with the lhs of the outer assignment
                        # -> remove the local assignment (do not re-append child to the former block)
                        substitute_variables[child.lhs] = rhs_identical.lhs
                    else:
                        target.insert_before(child, child_to_insert_before)
                elif exists_already and exists_already.rhs == child.rhs:
                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
                        assert target.args.count(exists_already) == 1
                        assert target.args.count(child_to_insert_before) == 1
                        target.args.remove(exists_already)
                        target.insert_before(exists_already, child_to_insert_before)
                else:
                    # this variable already exists in outer block, but with different rhs
                    # -> symbol has to be renamed
                    assert isinstance(child.lhs, TypedSymbol)
                    new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
                                         child_to_insert_before)
                    substitute_variables[child.lhs] = new_symbol

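# Illustrative sketch (comment only, made-up symbol names): an assignment inside the innermost loop
# body whose right-hand side uses no loop counter and no symbol defined by a surrounding node, e.g.
#   tmp = 0.5 * some_kernel_parameter
# is moved in front of the outermost loop by move_constants_before_loop, while an assignment that
# reads ctr_0 stays inside the loop that defines that counter.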

def split_inner_loop(ast_node: ast.Node, symbol_groups):
    """
    Splits inner loop into multiple loops to minimize the number of simultaneous load/store streams

    Args:
        ast_node: AST root
        symbol_groups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
                       updates these symbols and their dependent symbols. Symbols which are in none of the
                       symbol_groups and which no symbol in a symbol group depends on, are not updated!
    """
    all_loops = ast_node.atoms(ast.LoopOverCoordinate)
    inner_loop = [l for l in all_loops if l.is_innermost_loop]
    assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    inner_loop = inner_loop[0]
    assert type(inner_loop.body) is ast.Block
    outer_loop = [l for l in all_loops if l.is_outermost_loop]
    assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
    outer_loop = outer_loop[0]

    symbols_with_temporary_array = OrderedDict()
    assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))

    assignment_groups = []
    for symbol_group in symbol_groups:
        # get all dependent symbols
        symbols_to_process = list(symbol_group)
        symbols_resolved = set()
        while symbols_to_process:
            s = symbols_to_process.pop()
            if s in symbols_resolved:
                continue

            if s in assignment_map:  # if there is no assignment inside the loop body it is independent already
                for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
                    if not isinstance(new_symbol, AbstractField.AbstractAccess) and \
                            new_symbol not in symbols_with_temporary_array:
                        symbols_to_process.append(new_symbol)
            symbols_resolved.add(s)

        for symbol in symbol_group:
            if not isinstance(symbol, AbstractField.AbstractAccess):
                assert type(symbol) is TypedSymbol
                new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
                symbols_with_temporary_array[symbol] = sp.IndexedBase(
                    new_ts, shape=(1,))[inner_loop.loop_counter_symbol]

        assignment_group = []
        for assignment in inner_loop.body.args:
            if assignment.lhs in symbols_resolved:
                new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
                if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                    assert type(assignment.lhs) is TypedSymbol
                    new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
                    new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
                else:
                    new_lhs = assignment.lhs
                assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
        assignment_groups.append(assignment_group)

    new_loops = [
        inner_loop.new_loop_with_different_body(ast.Block(group))
        for group in assignment_groups
    ]
    inner_loop.parent.replace(inner_loop, ast.Block(new_loops))

    for tmp_array in symbols_with_temporary_array:
        tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
        alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
        free_node = ast.TemporaryMemoryFree(alloc_node)
        outer_loop.parent.insert_front(alloc_node)
        outer_loop.parent.append(free_node)


def cut_loop(loop_node, cutting_points):
    """Cuts loop at given cutting points.

    One loop is transformed into len(cutting_points)+1 new loops that range from
    old_begin to cutting_points[0], cutting_points[0] to cutting_points[1], ..., cutting_points[-1] to old_end

    Modifies the ast in place

    Returns:
        list of new loop nodes
    """
    if loop_node.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")
    new_loops = ast.Block([])
    new_start = loop_node.start
    cutting_points = list(cutting_points) + [loop_node.stop]
    for new_end in cutting_points:
        if new_end - new_start == 1:
            new_body = deepcopy(loop_node.body)
            new_body.subs({loop_node.loop_counter_symbol: new_start})
            new_loops.append(new_body)
        elif new_end - new_start == 0:
            pass
        else:
            new_loop = ast.LoopOverCoordinate(
                deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
                new_start, new_end, loop_node.step)
            new_loops.append(new_loop)
        new_start = new_end
    loop_node.parent.replace(loop_node, new_loops)
    return new_loops

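# Illustrative sketch (comment only): for a loop over [0, N) with counter ctr_0,
#   cut_loop(loop, [N - 1])
# produces a loop over [0, N-1) plus a peeled copy of the body for the last iteration
# (the single-iteration range N-1..N is emitted as the body with ctr_0 substituted by N-1).
# remove_conditionals_in_staggered_kernel() further below uses exactly this to peel the
# first and/or last iteration of each loop.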

def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
    """Removes conditionals that are always true/false.

    Args:
        node: ast node, all descendants of this node are simplified
        loop_counter_simplification: if enabled, tries to detect if a conditional is always true/false
                                     depending on the surrounding loop. For example if the surrounding loop goes
                                     from x=0 to 10 and the condition is x < 0, it is removed.
                                     This analysis needs the integer set library (ISL) islpy, so it is not done by
                                     default.
    """
    for conditional in node.atoms(ast.Conditional):
        conditional.condition_expr = sp.simplify(conditional.condition_expr)
        if conditional.condition_expr == sp.true:
            conditional.parent.replace(conditional, [conditional.true_block])
        elif conditional.condition_expr == sp.false:
            conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
        elif loop_counter_simplification:
            try:
                # noinspection PyUnresolvedReferences
                from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
                simplify_loop_counter_dependent_conditional(conditional)
            except ImportError:
                warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")


def cleanup_blocks(node: ast.Node) -> None:
    """Curly Brace Removal: Removes empty blocks, and replaces blocks with a single child by its child"""
    if isinstance(node, ast.SympyAssignment):
        return
    elif isinstance(node, ast.Block):
        for a in list(node.args):
            cleanup_blocks(a)
        if len(node.args) <= 1 and isinstance(node.parent, ast.Block):
            node.parent.replace(node, node.args)
            return
    else:
        for a in node.args:
            cleanup_blocks(a)


class KernelConstraintsCheck:
    """Checks if the input to create_kernel is valid.

    Tests the following conditions:

    - SSA Form for pure symbols:
        - Every pure symbol may occur only once as left-hand side of an assignment
        - Every pure symbol that is read may not be written to later
    - Independence / Parallelization condition:
        - A field that is written may only be read at exactly the same spatial position

    (Pure symbols are symbols that are not Field.Accesses)
    """
    FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])

    def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True):
        self._type_for_symbol = type_for_symbol

        self.scopes = NestedScopes()
        self._field_writes = defaultdict(set)
        self.fields_read = set()
        self.check_independence_condition = check_independence_condition
        self.check_double_write_condition = check_double_write_condition

    def process_assignment(self, assignment):
        # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
        new_rhs = self.process_expression(assignment.rhs)
        new_lhs = self._process_lhs(assignment.lhs)
        return ast.SympyAssignment(new_lhs, new_rhs)

    def process_expression(self, rhs, type_constants=True):
        from pystencils.interpolation_astnodes import InterpolatorAccess

        self._update_accesses_rhs(rhs)
        if isinstance(rhs, AbstractField.AbstractAccess):
            self.fields_read.add(rhs.field)
            self.fields_read.update(rhs.indirect_addressing_fields)
            return rhs
        elif isinstance(rhs, InterpolatorAccess):
            new_args = [self.process_expression(arg, type_constants) for arg in rhs.offsets]
            if new_args:
                rhs.offsets = new_args
            return rhs
        elif isinstance(rhs, ImaginaryUnit):
            return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type']))
        elif isinstance(rhs, TypedSymbol):
            return rhs
        elif isinstance(rhs, sp.Symbol):
            return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
        elif type_constants and isinstance(rhs, np.generic):
            return cast_func(rhs, create_type(rhs.dtype))
        elif type_constants and isinstance(rhs, sp.Number):
            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
        # Very important that this clause comes before BooleanFunction
        elif isinstance(rhs, sp.Equality):
            if isinstance(rhs.args[1], sp.Number):
                return sp.Equality(
                    self.process_expression(rhs.args[0], type_constants),
                    rhs.args[1])
            else:
                return sp.Equality(
                    self.process_expression(rhs.args[0], type_constants),
                    self.process_expression(rhs.args[1], type_constants))
        elif isinstance(rhs, cast_func):
            return cast_func(
                self.process_expression(rhs.args[0], type_constants=False),
                rhs.dtype)
        elif isinstance(rhs, BooleanFunction) or \
                type(rhs) in pystencils.integer_functions.__dict__.values():
            new_args = [self.process_expression(a, type_constants) for a in rhs.args]
            types_of_expressions = [get_type_of_expression(a) for a in new_args]
            arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True)
            new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type
                        else cast_func(a, arg_type)
                        for a in new_args]
            return rhs.func(*new_args)
        elif isinstance(rhs, sp.Mul):
            new_args = [
                self.process_expression(arg, type_constants)
                if arg not in (-1, 1) else arg for arg in rhs.args
            ]
            return rhs.func(*new_args) if new_args else rhs
        elif isinstance(rhs, sp.Indexed):
            return rhs
        else:
            if isinstance(rhs, sp.Pow):
                # don't process exponents -> they should remain integers
                return sp.Pow(
                    self.process_expression(rhs.args[0], type_constants),
                    rhs.args[1])
            else:
                new_args = [
                    self.process_expression(arg, type_constants)
                    for arg in rhs.args
                ]
                return rhs.func(*new_args) if new_args else rhs

    @property
    def fields_written(self):
        return set(k.field for k, v in self._field_writes.items() if len(v))

    def _process_lhs(self, lhs):
        assert isinstance(lhs, sp.Symbol)
        self._update_accesses_lhs(lhs)
        if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
            return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
        else:
            return lhs

    def _update_accesses_lhs(self, lhs):
        if isinstance(lhs, AbstractField.AbstractAccess):
            fai = self.FieldAndIndex(lhs.field, lhs.index)
            self._field_writes[fai].add(lhs.offsets)
            if self.check_double_write_condition and len(self._field_writes[fai]) > 1:
                raise ValueError(
                    f"Field {lhs.field.name} is written at two different locations")
        elif isinstance(lhs, sp.Symbol):
            if self.scopes.is_defined_locally(lhs):
                raise ValueError(f"Assignments not in SSA form, multiple assignments to {lhs.name}")
            if lhs in self.scopes.free_parameters:
                raise ValueError(f"Symbol {lhs.name} is written after it has been read")
            self.scopes.define_symbol(lhs)

    def _update_accesses_rhs(self, rhs):
        if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
            writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
            for write_offset in writes:
                assert len(writes) == 1
                if write_offset != rhs.offsets:
                    raise ValueError("Violation of loop independence condition. Field "
                                     "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
            self.fields_read.add(rhs.field)
        elif isinstance(rhs, sp.Symbol):
            self.scopes.access_symbol(rhs)

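# Illustrative sketches of inputs the checks above reject (comments only; f is a made-up 2D field,
# a and b plain made-up symbols):
#   Assignment(f[0, 0], 1) followed by Assignment(f[1, 0], 2)
#       -> ValueError, the field is written at two different locations
#   Assignment(a, 2 * b) followed by Assignment(a, 3 * b)
#       -> ValueError, the assignments are not in SSA form
#   Assignment(f[0, 0], a) followed by an assignment whose rhs reads f[1, 0]
#       -> ValueError (with check_independence_condition=True), read and write offsets differ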

def add_types(eqs, type_for_symbol, check_independence_condition):
    """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.

    Additionally returns sets of all fields which are read/written

    Args:
        eqs: list of equations
        type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
        check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
                                      kernels

    Returns:
        ``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
        list of equations where symbols have been replaced by typed symbols
    """
    if isinstance(type_for_symbol, (str, type)) or not hasattr(type_for_symbol, '__getitem__'):
        type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)

    check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)

    def visit(obj):
        if isinstance(obj, (list, tuple)):
            return [visit(e) for e in obj]
        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
            return check.process_assignment(obj)
        elif isinstance(obj, ast.Conditional):
            check.scopes.push()
            # Disable double write check inside conditionals
            # would be triggered by e.g. in-kernel boundaries
            old_double_write = check.check_double_write_condition
            check.check_double_write_condition = False
            false_block = None if obj.false_block is None else visit(obj.false_block)
            result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
                                     true_block=visit(obj.true_block),
                                     false_block=false_block)
            check.check_double_write_condition = old_double_write
            check.scopes.pop()
            return result
        elif isinstance(obj, ast.Block):
            check.scopes.push()
            result = ast.Block([visit(e) for e in obj.args])
            check.scopes.pop()
            return result
        elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
            return obj
        else:
            raise ValueError("Invalid object in kernel " + str(type(obj)))

    typed_equations = visit(eqs)

    return check.fields_read, check.fields_written, typed_equations


def insert_casts(node):
    """Checks the types and inserts casts and pointer arithmetic where necessary.

    Args:
        node: the head node of the ast

    Returns:
        modified AST
    """
    def cast(zipped_args_types, target_dtype):
        """
        Adds casts to the arguments if their type differs from the target type
        :param zipped_args_types: a zipped list of args and types
        :param target_dtype: the target data type
        :return: args with possible casts
        """
        casted_args = []
        for argument, data_type in zipped_args_types:
            if data_type.numpy_dtype != target_dtype.numpy_dtype:  # ignoring const
                casted_args.append(cast_func(argument, target_dtype))
            else:
                casted_args.append(argument)
        return casted_args

    def pointer_arithmetic(expr_args):
        """
        Creates a valid pointer arithmetic function
        :param expr_args: arguments of the add expression
        :return: pointer_arithmetic_func
        """
        pointer = None
        new_args = []
        for arg, data_type in expr_args:
            if data_type.func is PointerType:
                assert pointer is None
                pointer = arg
        for arg, data_type in expr_args:
            if arg != pointer:
                assert data_type.is_int() or data_type.is_uint()
                new_args.append(arg)
        new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
        return pointer_arithmetic_func(pointer, new_args)

    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
        return node
    args = []
    for arg in node.args:
        args.append(insert_casts(arg))
    # TODO indexed, LoopOverCoordinate
    if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
        # TODO optimize pow, don't cast integer on double
        types = [get_type_of_expression(arg) for arg in args]
        assert len(types) > 0
        # Never ever, ever collate to float type for boolean functions!
        target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction))
        zipped = list(zip(args, types))
        if target.func is PointerType:
            assert node.func is sp.Add
            return pointer_arithmetic(zipped)
        else:
            return node.func(*cast(zipped, target))
    elif node.func is ast.SympyAssignment:
        lhs = args[0]
        rhs = args[1]
        target = get_type_of_expression(lhs)
        if target.func is PointerType:
            return node.func(*args)  # TODO fix, not complete
        else:
            return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
    elif node.func is ast.ResolvedFieldAccess:
        return node
    elif node.func is ast.Block:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is ast.LoopOverCoordinate:
        for old_arg, new_arg in zip(node.args, args):
            node.replace(old_arg, new_arg)
        return node
    elif node.func is sp.Piecewise:
        expressions = [expr for (expr, _) in args]
        types = [get_type_of_expression(expr) for expr in expressions]
        target = collate_types(types)
        zipped = list(zip(expressions, types))
        casted_expressions = cast(zipped, target)
        args = [
            arg.func(*[expr, arg.cond])
            for (arg, expr) in zip(args, casted_expressions)
        ]

    return node.func(*args)

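# Illustrative sketch (comment only): for a SympyAssignment whose lhs is a double TypedSymbol and
# whose rhs collates to an integer type, insert_casts wraps the rhs as cast_func(rhs, <lhs type>).
# For a sum that contains exactly one pointer-typed argument, the sum is instead rewritten into
# pointer_arithmetic_func(pointer, <sum of the remaining integer offsets>).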

def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at the last,
    or the first and last, element."""

    all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
    assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
    inner_loop = all_inner_loops.pop()

    for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
        if include_first:
            cut_loop(loop, [loop.start + 1, loop.stop - 1])
        else:
            cut_loop(loop, [loop.stop - 1])

    simplify_conditionals(function_node.body, loop_counter_simplification=True)
    cleanup_blocks(function_node.body)

    move_constants_before_loop(function_node.body)
    cleanup_blocks(function_node.body)


# --------------------------------------- Helper Functions ------------------------------------------------------------


def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool', otherwise the default type
    (usually 'double') is used.

    Args:
        eqs: list of equations
        default_type: the type for non-boolean symbols
    Returns:
        dictionary, mapping symbol name to type
    """
    result = defaultdict(lambda: default_type)
    if hasattr(default_type, 'numpy_dtype'):
        result['_complex_type'] = (np.zeros((1,), default_type.numpy_dtype) * 1j).dtype
    else:
        result['_complex_type'] = (np.zeros((1,), default_type) * 1j).dtype
    for eq in eqs:
        if isinstance(eq, ast.Conditional):
            result.update(typing_from_sympy_inspection(eq.true_block.args))
            if eq.false_block:
                result.update(typing_from_sympy_inspection(eq.false_block.args))
        elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
            continue
        else:
            from pystencils.cpu.vectorization import vec_all, vec_any
            if isinstance(eq.rhs, (vec_all, vec_any)):
                result[eq.lhs.name] = "bool"
            # problematic case here is when rhs is a symbol: then it is impossible to decide here without
            # further information what type the left hand side is - default fallback is the dict value then
            if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
                result[eq.lhs.name] = "bool"
            try:
                result[eq.lhs.name] = get_type_of_expression(eq.rhs,
                                                             default_float_type=default_type,
                                                             default_int_type=default_int_type,
                                                             symbol_type_dict=result)
            except Exception:
                pass  # gracefully fail in case get_type_of_expression cannot determine type
    return result

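# Illustrative sketch (comment only; x, y and flag are made-up symbol names):
#   Assignment(flag, x > 3)   -> 'flag' is mapped to a boolean type, since the rhs is a sympy Boolean
#   Assignment(y, 2 * x)      -> 'y' falls back to the default type, e.g. "double"
# Symbols that never appear on a left-hand side keep the defaultdict fallback, i.e. default_type.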

def get_next_parent_of_type(node, parent_type):
    """Returns the next parent node of given type or None, if root is reached.

    Traverses the AST node's parents until a parent of given type was found.
    If no such parent is found, None is returned
    """
    parent = node.parent
    while parent is not None:
        if isinstance(parent, parent_type):
            return parent
        parent = parent.parent
    return None


def parents_of_type(node, parent_type, include_current=False):
    """Generator for all parent nodes of given type"""
    parent = node if include_current else node.parent
    while parent is not None:
        if isinstance(parent, parent_type):
            yield parent
        parent = parent.parent


def get_optimal_loop_ordering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields have different memory layout or different sizes an exception is thrown.

    Args:
        fields: sequence of fields

    Returns:
        list of coordinate ids, where the first list entry should be the outermost loop
    """
    assert len(fields) > 0
    ref_field = next(iter(fields))
    for field in fields:
        if field.spatial_dimensions != ref_field.spatial_dimensions:
            raise ValueError(
                "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                + str({f.name: f.spatial_shape for f in fields}))

    layouts = set([field.layout for field in fields])
    if len(layouts) > 1:
        raise ValueError(
            "Due to different layout of the fields no optimal loop ordering exists "
            + str({f.name: f.layout for f in fields}))
    layout = list(layouts)[0]
    return list(layout)


def get_loop_hierarchy(ast_node):
    """Determines the loop structure around a given AST node, i.e. the node has to be inside the loops.

    Returns:
        sequence of coordinate ids of the surrounding loops, ordered from outermost to innermost loop
    """
    result = []
    node = ast_node
    while node is not None:
        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
        if node:
            result.append(node.coordinate_to_loop_over)
    return reversed(result)


def get_loop_counter_symbol_hierarchy(ast_node):
    """Determines the loop counter symbols around a given AST node.
    :param ast_node: the AST node
    :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
    """
    result = []
    node = ast_node
    while node is not None:
        node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
        if node:
            result.append(node.loop_counter_symbol)
    return result


def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
    """Replaces the stride of the innermost loop of a variable sized kernel with 1 (assumes optimal loop ordering).

    Variable sized kernels can handle arbitrary field sizes and field shapes. However, the kernel is most efficient
    if the innermost loop accesses the fields with stride 1. The inner loop can also only be vectorized if the inner
    stride is 1. This transformation hard codes this inner stride to one to enable e.g. vectorization.

    Warning: the assumption is not checked at runtime!
    """
    inner_loops = []
    inner_loop_counters = set()
    for loop in filtered_tree_iteration(ast_node,
                                        ast.LoopOverCoordinate,
                                        stop_type=ast.SympyAssignment):
        if loop.is_innermost_loop:
            inner_loops.append(loop)
            inner_loop_counters.add(loop.coordinate_to_loop_over)

    if len(inner_loop_counters) != 1:
        raise ValueError("Inner loops iterate over different coordinates")

    inner_loop_counter = inner_loop_counters.pop()

    parameters = ast_node.get_parameters()
    stride_params = [
        p.symbol for p in parameters
        if p.is_field_stride and p.symbol.coordinate == inner_loop_counter
    ]
    subs_dict = {stride_param: 1 for stride_param in stride_params}
    if subs_dict:
        ast_node.subs(subs_dict)


def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
    """Blocking of loops to enhance cache locality. Modifies the ast node in-place.

    Args:
        ast_node: kernel function node before vectorization transformation has been applied
        block_size: sequence defining block size in x, y, (z) direction.
                    If chosen as zero the direction will not be used for blocking.

    Returns:
        number of dimensions blocked
    """
    loops = [
        l for l in filtered_tree_iteration(
            ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
    ]
    body = ast_node.body

    coordinates = []
    coordinates_taken_into_account = 0
    loop_starts = {}
    loop_stops = {}

    for loop in loops:
        coord = loop.coordinate_to_loop_over
        if coord not in coordinates:
            coordinates.append(coord)
            loop_starts[coord] = loop.start
            loop_stops[coord] = loop.stop
        else:
            assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
                f"Multiple loops over coordinate {coord} with different loop bounds"

    # Create the outer loops that iterate over the blocks
    outer_loop = None
    for coord in reversed(coordinates):
        if block_size[coord] == 0:
            continue
        coordinates_taken_into_account += 1
        body = ast.Block([outer_loop]) if outer_loop else body
        outer_loop = ast.LoopOverCoordinate(body,
                                            coord,
                                            loop_starts[coord],
                                            loop_stops[coord],
                                            step=block_size[coord],
                                            is_block_loop=True)

    ast_node.body = ast.Block([outer_loop])

    # modify the existing loops to only iterate within one block
    for inner_loop in loops:
        coord = inner_loop.coordinate_to_loop_over
        if block_size[coord] == 0:
            continue
        block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
        loop_range = inner_loop.stop - inner_loop.start
        if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0:
            stop = block_ctr + block_size[coord]
        else:
            stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
        inner_loop.start = block_ctr
        inner_loop.stop = stop
    return coordinates_taken_into_account

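# Illustrative usage sketch (comment only), assuming `kernel_ast` is a CPU KernelFunction with two
# spatial loops:
#   loop_blocking(kernel_ast, block_size=(32, 32))   # returns 2
# New block loops with step 32 are created around the body; each original loop is rewritten to run
# from the block counter to block counter + 32, or to min(old stop, block counter + 32) when the
# iteration range is not evenly divisible. A block size of 0 leaves that direction unblocked.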

def implement_interpolations(ast_node: ast.Node,
                             implement_by_texture_accesses: bool = False,
                             vectorize: bool = False,
                             use_hardware_interpolation_for_f32=True):
    from pystencils.interpolation_astnodes import (InterpolatorAccess, TextureCachedField)
    # TODO: perform this function on assignments, when unify_shape_symbols allows differently sized fields

    assert not (implement_by_texture_accesses and vectorize), \
        "can only implement interpolations either by texture accesses or CPU vectorization"
    FLOAT32_T = create_type('float32')

    interpolation_accesses = ast_node.atoms(InterpolatorAccess)
    if not interpolation_accesses:
        return ast_node

    def can_use_hw_interpolation(i):
        return (use_hardware_interpolation_for_f32
                and implement_by_texture_accesses
                and i.dtype == FLOAT32_T
                and isinstance(i.symbol.interpolator, TextureCachedField))

    if implement_by_texture_accesses:

        for i in interpolation_accesses:
            from pystencils.interpolation_astnodes import _InterpolationSymbol

            try:
                import pycuda.driver as cuda
                texture = TextureCachedField.from_interpolator(i.interpolator)
                if can_use_hw_interpolation(i):
                    texture.filter_mode = cuda.filter_mode.LINEAR
                else:
                    texture.filter_mode = cuda.filter_mode.POINT
                    texture.read_as_integer = True
            except Exception as e:
                raise e
            i.symbol = _InterpolationSymbol(str(texture), i.symbol.field, texture)

    # from pystencils.math_optimizations import ReplaceOptim, optimize_ast

    # ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
    #                                                 and not can_use_hw_interpolation(i),
    #                                                 lambda e: e.implementation_with_stencils()
    #                                                 )

    # RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
    #                                lambda e: e.args[0]
    #                                )
    if vectorize:
        # TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
        raise NotImplementedError()
    else:
        substitutions = {i: i.implementation_with_stencils()
                         for i in interpolation_accesses if not can_use_hw_interpolation(i)}
        if isinstance(ast_node, AssignmentCollection):
            ast_node = ast_node.subs(substitutions)
        else:
            ast_node.subs(substitutions)

    return ast_node