1import collections.abc 

2import itertools 

3import uuid 

4from typing import Any, List, Optional, Sequence, Set, Union 

5 

6import sympy as sp 

7 

8import pystencils 

9from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type 

10from pystencils.field import Field 

11from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol 

12from pystencils.sympyextensions import fast_subs 

13 

# Type alias for a child of an AST node: either another ``Node`` or a sympy expression.
NodeOrExpr = Union['Node', sp.Expr]

15 

16 

class Node:
    """Base class for all AST nodes.

    Subclasses provide ``args``, ``symbols_defined`` and ``undefined_symbols``;
    this base class builds generic in-place substitution (``subs``) and a
    recursive descendant search (``atoms``) on top of ``args``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! Substitute, similar to sympy's but modifies the AST inplace.

        Sympy children are substituted out-of-place and the result is written
        back via ``self.args[position] = ...``; every other child must
        substitute in place and return ``None``.

        NOTE(review): the write-back only persists if ``args`` returns the
        node's own storage list; subclasses whose ``args`` builds a fresh list
        silently drop substituted sympy children — confirm per subclass.
        """
        for position, child in enumerate(self.args):
            substituted = child.subs(subs_dict)
            if not isinstance(child, sp.Expr):
                # AST nodes substitute in place and return nothing.
                assert substituted is None
                continue
            self.args[position] = substituted

    @property
    def func(self):
        """The node's class (mirrors sympy's ``expr.func``)."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        matches = set()
        for child in self.args:
            matches.update(child.atoms(arg_type))
            if isinstance(child, arg_type):
                matches.add(child)
        return matches

59 

60 

class Conditional(Node):
    """Conditional that maps to a 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true
        false_block: optional block which is run if conditional is false
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        self.condition_expr = condition_expr

        def handle_child(c):
            # None stays None; any non-Block branch is wrapped into a one-element
            # Block. The (possibly wrapped) branch is re-parented to this node.
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.true_block = handle_child(true_block)
        self.false_block = handle_child(false_block)

    def subs(self, subs_dict):
        """In-place substitution in both branches and (out-of-place) in the condition."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        result = [self.condition_expr, self.true_block]
        if self.false_block:
            result.append(self.false_block)
        return result

    @property
    def symbols_defined(self):
        # Nothing defined inside the branches is reported as defined by the conditional.
        return set()

    @property
    def undefined_symbols(self):
        result = self.true_block.undefined_symbols
        if self.false_block:
            result.update(self.false_block.undefined_symbols)
        if hasattr(self.condition_expr, 'atoms'):
            result.update(self.condition_expr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        result = f'if:({self.condition_expr!r}) '
        if self.true_block:
            result += f'\n\t{self.true_block} '
        if self.false_block:
            # Append the else branch; assigning here would discard the condition
            # and the true branch printed above.
            result += 'else: '
            result += f'\n\t{self.false_block} '

        return result

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block"""
        self.parent.replace(self, [self.false_block] if self.false_block else [])

136 

137 

class KernelFunction(Node):
    """Root node of a kernel: holds the body plus target/backend/compilation metadata."""

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            """Name of the first associated field."""
            first_field = self.fields[0]
            return first_field.name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel",
                 assignments=None):
        super().__init__()
        self._body = body
        self._body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        self.instruction_set = None
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function
        self.assignments = assignments

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return (self._body,)

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        from pystencils.interpolation_astnodes import InterpolatorAccess
        accesses = itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess))
        return {access.field for access in accesses}

    @property
    def fields_written(self) -> Set[Field]:
        """Fields of all ``SympyAssignment`` left-hand sides that are ``ResolvedFieldAccess`` objects."""
        return {assignment.lhs.field
                for assignment in self.atoms(SympyAssignment)
                if isinstance(assignment.lhs, ResolvedFieldAccess)}

    @property
    def fields_read(self) -> Set[Field]:
        """Fields of all rhs free symbols (of ``SympyAssignment`` nodes) that carry a ``field`` attribute."""
        read = set()
        for assignment in self.atoms(SympyAssignment):
            read.update(sym.field for sym in assignment.rhs.free_symbols if hasattr(sym, 'field'))
        return read

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function.

        This function is expensive, cache the result where possible!
        """
        fields_by_name = {field.name: field for field in self.fields_accessed}

        def fields_of(symbol):
            # Field-related symbols name one field (``field_name``) or several (``field_names``).
            if hasattr(symbol, 'field_name'):
                return (fields_by_name[symbol.field_name],)
            if hasattr(symbol, 'field_names'):
                return tuple(fields_by_name[name] for name in symbol.field_names)
            return ()

        free_symbols = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(sym, fields_of(sym)) for sym in free_symbols]
        if hasattr(self, 'indexing'):
            parameters.extend(self.Parameter(sym, []) for sym in self.indexing.symbolic_parameters())
        parameters.sort(key=lambda param: param.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        indented_body = "\t" + "\t".join(str(self.body).splitlines(True))
        return f'{type(self).__name__} {self.function_name}({params})\n{indented_body}'

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return f'{type(self).__name__} {self.function_name}({params})'

    def compile(self, *args, **kwargs):
        """Compile via the backend-provided compile function; raises ValueError if none was set."""
        compile_function = self._compile_function
        if compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return compile_function(self, *args, **kwargs)

277 

278 

class SkipIteration(Node):
    """Childless leaf node that neither defines nor uses any symbol.

    ``early_out`` places it inside a ``Conditional``; its code representation is
    left to the backend (presumably a ``continue`` — confirm in the printer).
    """

    @property
    def args(self):
        return list()

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

291 

292 

class Block(Node):
    """Ordered sequence of child nodes.

    Children are usually AST nodes; plain ``pystencils.Assignment`` objects are
    additionally understood by ``symbols_defined`` and ``undefined_symbols``.
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()
        self._nodes = nodes
        self.parent = None
        for n in self._nodes:
            try:
                n.parent = self
            except AttributeError:
                # Some children cannot store a ``parent`` attribute (presumably
                # sympy objects); they are kept without a back-reference.
                pass

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        """In-place substitution: forwards ``subs`` to every child."""
        for a in self.args:
            a.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Replace every child by ``fast_subs(child, subs_dict, skip)`` and return this block."""
        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
        return self

    def insert_front(self, node, if_not_exists=False):
        """Insert ``node`` (or all nodes of an iterable, in order) at the start of the block.

        With ``if_not_exists`` nothing happens when the current first child equals ``node``.
        """
        if if_not_exists and len(self._nodes) > 0 and self._nodes[0] == node:
            return
        if isinstance(node, collections.abc.Iterable):
            node = list(node)
            for n in node:
                n.parent = self

            self._nodes = node + self._nodes
        else:
            node.parent = self
            self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before, if_not_exists=False):
        """Insert ``new_node`` directly before the child ``insert_before``.

        Declarations (``SympyAssignment`` with ``is_declaration``) are moved further
        up, in front of any directly preceding loops/conditionals. With
        ``if_not_exists`` nothing is inserted if the node at the target position
        already equals ``new_node``.
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        idx = self._nodes.index(insert_before)

        # move all assignment (definitions to the top)
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
                    idx -= 1
                else:
                    break
        if not if_not_exists or self._nodes[idx] != new_node:
            self._nodes.insert(idx, new_node)

    def insert_after(self, new_node, insert_after, if_not_exists=False):
        """Insert ``new_node`` directly after the child ``insert_after``.

        Declarations are moved up in front of directly preceding loops/conditionals,
        as in ``insert_before``. With ``if_not_exists`` nothing is inserted if a
        neighbour of the target position already equals ``new_node``.
        """
        new_node.parent = self
        assert self._nodes.count(insert_after) == 1
        idx = self._nodes.index(insert_after) + 1

        # move all assignment (definitions to the top)
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
                    idx -= 1
                else:
                    break
        # Guard idx > 0: after hoisting, idx can be 0 and ``self._nodes[-1]`` would
        # wrongly compare against the LAST child of the block.
        already_present = ((idx > 0 and self._nodes[idx - 1] == new_node)
                           or (idx < len(self._nodes) and self._nodes[idx] == new_node))
        if not if_not_exists or not already_present:
            self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a node, or every element of a list/tuple, re-parenting each."""
        if isinstance(node, list) or isinstance(node, tuple):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Detach and return the current child list, leaving the block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or splice in a list of nodes at its position."""
        assert self._nodes.count(child) == 1
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if type(replacements) is list:
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        result = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                # A plain Assignment defines only its left-hand side (consistent with
                # ``undefined_symbols``); its rhs symbols are used, not defined.
                result.add(a.lhs)
            else:
                result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        result = set()
        defined_symbols = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                result.update(a.free_symbols)
                defined_symbols.update({a.lhs})
            else:
                result.update(a.undefined_symbols)
                defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"

416 

417 

class PragmaBlock(Block):
    """Block annotated with a verbatim pragma line, which is also its ``repr``."""

    def __init__(self, pragma_line, nodes):
        super().__init__(nodes)
        self.pragma_line = pragma_line
        # Unlike Block.__init__, a child that cannot take a ``parent`` attribute raises here.
        for child in nodes:
            child.parent = self

    def __repr__(self):
        return self.pragma_line

427 

428 

class LoopOverCoordinate(Node):
    """Counting loop ``for(ctr = start; ctr < stop; ctr += step)`` over one coordinate.

    The counter symbol is derived from ``coordinate_to_loop_over``; block loops use a
    separate name prefix.
    """

    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Copy of this loop (bounds, step, prefix lines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """In-place substitution in the body and in any bound/step that supports ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, Node) or isinstance(possible_symbol, sp.Basic):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of a (non-block) loop counter symbol, otherwise ``None``.

        A loop counter's name starts with ``LOOP_COUNTER_NAME_PREFIX``, its dtype is
        ``int`` and the name part after ``prefix + 1`` characters is an integer.
        Symbols without a ``dtype`` or whose suffix is not an integer (e.g. a user
        symbol ``ctrl_x``) yield ``None`` instead of raising.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                   self.loop_counter_name, self.stop,
                                                                   self.loop_counter_name, self.step,
                                                                   ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                             self.loop_counter_name, self.stop,
                                                             self.loop_counter_name, self.step)

561 

562 

class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` between sympy expressions.

    The assignment counts as a declaration unless its lhs is a ``cast_func``,
    ``Field.Access``, ``sp.Indexed`` or ``TemporaryMemoryAllocation``; only
    declarations report their lhs in ``symbols_defined``.
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = sp.sympify(lhs_symbol)
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        if isinstance(self._lhs_symbol, cast_func):
            return False
        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
            return False
        return True

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        # Re-evaluate the declaration status whenever the lhs changes.
        self._lhs_symbol = new_value
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Apply ``sympy.codegen.rewriting.optimize`` to the rhs (best-effort)."""
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            # Deliberately best-effort: on any failure the rhs is left unchanged.
            pass

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there a field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace the lhs or rhs equal to ``child``; raises ValueError if ``child`` is neither."""
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # Report the node that was not found, not the intended replacement.
            raise ValueError(f'{child} is not in args of {self.__class__}')

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} \\leftarrow {printed_rhs}$"

    def __hash__(self):
        return hash((self.lhs, self.rhs))

    def __eq__(self, other):
        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)

654 

655 

class ResolvedFieldAccess(sp.Indexed):
    """Field access resolved to a linearized index into a typed base pointer.

    The sympy arguments are ``(base, linearized_index)``; the originating
    ``field``, the access ``offsets`` and ``idx_coordinate_values`` are kept as
    plain attributes and take part in hashing via ``_hashable_content``.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        if not isinstance(base, sp.IndexedBase):
            # A bare pointer symbol is wrapped into a one-element IndexedBase.
            assert isinstance(base, TypedSymbol)
            base = sp.IndexedBase(base, shape=(1,))
        assert isinstance(base.label, TypedSymbol)
        instance = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
        instance.field = field
        instance.offsets = offsets
        instance.idx_coordinate_values = idx_coordinate_values
        return instance

    def _eval_subs(self, old, new):
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        if self in substitutions:
            return substitutions[self]
        base, index = self.args[0], self.args[1]
        new_base = base.subs(substitutions)
        new_index = index.subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index, self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        inherited = super(ResolvedFieldAccess, self)._hashable_content()
        extra = tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
        return inherited + extra

    @property
    def typed_symbol(self):
        """The ``TypedSymbol`` used as base pointer."""
        return self.base.label

    def __str__(self):
        indexed_text = super(ResolvedFieldAccess, self).__str__()
        return f"{indexed_text} ({self.typed_symbol.dtype})"

    def __getnewargs__(self):
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values

    def __getnewargs_ex__(self):
        return ResolvedFieldAccess.__getnewargs__(self), {}

697 

698 

class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        return {self.symbol}

    @property
    def undefined_symbols(self):
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Raises AssertionError if ``byte_alignment`` is not a multiple of the element size.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # Integer division: the assert guarantees exact divisibility, and an element
        # count must be an integer (true division would yield a float).
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

737 

738 

class TemporaryMemoryFree(Node):
    """Release counterpart of an allocation node; symbol and offset come from ``alloc_node``."""

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the paired allocation node."""
        allocation = self.alloc_node
        return allocation.symbol

    def offset(self, byte_alignment):
        """Delegates to the paired allocation node's ``offset``."""
        allocation = self.alloc_node
        return allocation.offset(byte_alignment)

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return list()

762 

763 

def early_out(condition):
    """Build a ``Conditional`` on ``vec_all(condition)`` whose body is a single ``SkipIteration``."""
    from pystencils.cpu.vectorization import vec_all
    guard = vec_all(condition)
    skip_body = Block([SkipIteration()])
    return Conditional(guard, skip_body)

767 

768 

def get_dummy_symbol(dtype='bool'):
    """Return a new ``TypedSymbol`` named ``dummy`` + random uuid4 hex, typed ``create_type(dtype)``."""
    unique_name = 'dummy' + uuid.uuid4().hex
    return TypedSymbol(unique_name, create_type(dtype))

771 

772 

class SourceCodeComment(Node):
    """Leaf node holding a comment; ``str()`` renders it as ``/* text */``."""

    def __init__(self, text):
        # Initialize the Node base so ``parent`` exists like on every other node.
        super(SourceCodeComment, self).__init__(parent=None)
        self.text = text

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()

794 

795 

class EmptyLine(Node):
    """Leaf node that renders as an empty string."""

    def __init__(self):
        # Initialize the Node base so ``parent`` exists like on every other node.
        super(EmptyLine, self).__init__(parent=None)

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()

817 

818 

class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    Sympy arguments, in order: the field access, the out-of-bounds condition and the
    (sympified) value used when out of bounds.
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        fallback = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, fallback)

    @property
    def access(self):
        return self.args[0]

    @property
    def outofbounds_condition(self):
        return self.args[1]

    @property
    def outofbounds_value(self):
        return self.args[2]

    def __getnewargs__(self):
        return self.access, self.outofbounds_condition, self.outofbounds_value

    def __getnewargs_ex__(self):
        return ConditionalFieldAccess.__getnewargs__(self), {}

845 

846 

class NontemporalFence(Node):
    """Childless marker node; all instances compare (and hash) equal."""

    def __init__(self):
        super(NontemporalFence, self).__init__(parent=None)

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return []

    def __eq__(self, other):
        return isinstance(other, NontemporalFence)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, making instances unhashable
        # (e.g. ``Node.atoms(NontemporalFence)`` would raise TypeError). Since all
        # fences are equal, they must share one hash value.
        return hash(NontemporalFence)

865 

866 

class CachelineSize(Node):
    """Childless node defining the class-level symbols ``_clsize``, ``_clsize_mask`` and ``_cl_lastvec``.

    All instances compare (and hash) equal.
    """

    symbol = sp.Symbol("_clsize")
    mask_symbol = sp.Symbol("_clsize_mask")
    last_symbol = sp.Symbol("_cl_lastvec")

    def __init__(self):
        super(CachelineSize, self).__init__(parent=None)

    @property
    def symbols_defined(self):
        return {self.symbol, self.mask_symbol, self.last_symbol}

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return []

    def __eq__(self, other):
        return isinstance(other, CachelineSize)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, making instances unhashable
        # (e.g. ``Node.atoms(CachelineSize)`` would raise TypeError). Since all
        # instances are equal, they must share one hash value.
        return hash(CachelineSize)