import hashlib
import re
from collections import namedtuple
from typing import List, Set

import numpy as np
import sympy as sp
from sympy.core import S
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

from pystencils.astnodes import KernelFunction, Node, CachelineSize
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
    reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)

try:
    from sympy.printing.c import C99CodePrinter as CCodePrinter  # for sympy versions > 1.6
except ImportError:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter

25 

# Names exported by ``from <module> import *``.
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']


# A header name has to start with `<` or `"` and end with `>` or `"`; get_headers checks every header against it.
HEADER_REGEX = re.compile(r'^[<"].*[">]$')

# NOTE(review): not read anywhere in the visible part of this file; presumably a switch for kerncraft - confirm.
KERNCRAFT_NO_TERNARY_MODE = False

32 

33 

def generate_c(ast_node: Node,
               signature_only: bool = False,
               dialect='c',
               custom_backend=None,
               with_globals=True) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
    code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
    create_kernel functions.

    As a side effect, the symbols defined by the required global declarations of the tree are collected in
    ``ast_node.global_variables`` (the attribute is created if it does not exist yet).

    Args:
        ast_node: ast representation of kernel
        signature_only: generate signature without function body
        dialect: 'c', 'cuda' or 'opencl'; only used if no custom backend is passed
        custom_backend: use own custom printer for code generation
        with_globals: prepend the required global declarations to the code of a kernel function
    Returns:
        C-like code for the ast node and its descendants
    Raises:
        ValueError: if no custom backend is passed and the dialect is not known
    """
    global_declarations = get_global_declarations(ast_node)
    for d in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(d.symbols_defined)
        else:
            # Store a copy: the following update() calls must not modify the set owned by the declaration itself.
            ast_node.global_variables = set(d.symbols_defined)

    if custom_backend:
        printer = custom_backend
    elif dialect == 'c':
        # nodes without an `instruction_set` attribute are printed without vectorization support
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=getattr(ast_node, 'instruction_set', None))
    elif dialect == 'cuda':
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    elif dialect == 'opencl':
        from pystencils.backends.opencl_backend import OpenClBackend
        printer = OpenClBackend(signature_only=signature_only)
    else:
        raise ValueError("Unknown dialect: " + str(dialect))

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        if with_globals and global_declarations:
            code = "\n" + code
            # every declaration is put in front of what has been printed so far,
            # i.e. the declarations end up in reversed order above the kernel
            for declaration in global_declarations:
                code = printer(declaration) + "\n" + code

    return code

85 

86 

def get_global_declarations(ast):
    """Collect the global declarations required by the tree rooted at ``ast``.

    The tree is walked in pre-order. Each visited object that has a ``required_global_declarations``
    attribute contributes its entries; the children of an object are taken from its ``args`` attribute,
    objects without ``args`` are leaves.

    Returns:
        list of the distinct declarations, sorted by their string representation
    """
    collected = []
    pending = [ast]
    while pending:
        current = pending.pop()
        if hasattr(current, "required_global_declarations"):
            collected.extend(current.required_global_declarations)
        if hasattr(current, "args"):
            # pushed in reverse, such that the children are popped (visited) in their original order
            pending.extend(reversed(list(current.args)))
    return sorted(set(collected), key=str)

102 

103 

def get_headers(ast_node: Node) -> List[str]:
    """Return the header files necessary to compile the printed C-like code.

    Headers are gathered from the instruction set of a kernel function, from the ``headers`` attribute
    of the node, recursively from all arguments that are sympy expressions or AST nodes, and from the
    required global declarations of the node.

    Args:
        ast_node: root of the (sub)tree to inspect
    Returns:
        sorted list of header names without duplicates, each of the form ``"..."`` or ``<...>``
    """
    headers = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        headers.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        headers.update(ast_node.headers)
    for a in ast_node.args:
        if isinstance(a, (sp.Expr, Node)):
            headers.update(get_headers(a))

    for g in get_global_declarations(ast_node):
        if isinstance(g, Node):
            headers.update(get_headers(g))

    for h in headers:
        assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'

    # sorted, to get the same order of headers in every run
    return sorted(headers)

125 

126 

127# --------------------------------------- Backend Specific Nodes ------------------------------------------------------- 

128 

129 

class CustomCodeNode(Node):
    """AST node that holds a piece of code as plain text.

    Two nodes compare equal if they are of the same type, hold the same code and read/define the same
    symbols. The hash is computed from the code only, which is consistent with this equality.
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        """
        Args:
            code: code of this node; stored with a newline prepended
            symbols_read: iterable of symbols that are read by the code
            symbols_defined: iterable of symbols that are defined by the code
            parent: parent node, passed on to the Node constructor
        """
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read = set(symbols_read)
        self._symbols_defined = set(symbols_defined)
        # header files required by the code, see get_headers
        self.headers = []

    def get_code(self, dialect, vector_instruction_set, print_arg):
        """Return the code to print; this implementation ignores all arguments."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the code (the internal set, not a copy)."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols that are read, but not defined by the code."""
        return self._symbols_read - self._symbols_defined

    def __eq__(self, other):
        # Was spelled `__eq___` before and hence never used by `==`, such that equal declarations
        # were not recognized as duplicates although __hash__ is based on the code.
        if type(self) is not type(other):
            return NotImplemented
        return (self._code == other._code
                and self._symbols_read == other._symbols_read
                and self._symbols_defined == other._symbols_defined)

    def __hash__(self):
        return hash(self._code)

158 

159 

class PrintNode(CustomCodeNode):
    """Custom code node that writes name and value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """
        Args:
            symbol_to_print: symbol to print; its ``name`` attribute is used in the generated code
        """
        name = symbol_to_print.name
        statement = f'\nstd::cout << "{name} = " << {name} << std::endl; \n'
        super(PrintNode, self).__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout is declared in <iostream>
        self.headers.append("<iostream>")

166 

167 

168# ------------------------------------------- Printer ------------------------------------------------------------------ 

169 

170 

171# noinspection PyPep8Naming 

class CBackend:
    """Printer that converts AST nodes to C-like code.

    A node is printed by the first method named ``_print_<ClassName>`` that is found for one of the classes
    in the method resolution order of the node type. sympy expressions inside the nodes are printed by
    ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
        """
        Args:
            sympy_printer: printer for sympy expressions; if None a VectorizedCustomSympyPrinter is created when
                           an instruction set is passed, a CustomSympyPrinter otherwise
            signature_only: print only the declaration of kernel functions, not their body
            vector_instruction_set: mapping with the templates of the vector instructions, or None
            dialect: passed to custom code nodes; for 'cuda' launch bounds are added to kernel declarations
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        # NOTE(review): whitespace seems to be collapsed in the reviewed listing - confirm the indentation width
        self._indent = " "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print a node and return the code as string.

        While printing, ``VectorType.instruction_set`` is set to the instruction set of this backend.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            # restore the class attribute also if printing fails, otherwise the instruction set of this
            # backend would leak into everything that is printed afterwards
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch to the ``_print_<ClassName>`` method of the node; strings are returned as they are.

        Raises:
            NotImplementedError: if there is no print method for any class in the MRO of the node type
        """
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)

    def _print_Type(self, node):
        return str(node)

    def _print_KernelFunction(self, node):
        """Print the declaration of the kernel and, unless only the signature is requested, its body."""
        function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()]
        launch_bounds = ""
        if self._dialect == 'cuda':
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = f"__launch_bounds__({max_threads}) "
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                         ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print the children line by line, indented and enclosed in braces."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        # splitlines(True) keeps the line ends, joining with the indent thus indents every line but the first
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return f"{node.pragma_line}\n{self._print_Block(node)}"

    def _print_LoopOverCoordinate(self, node):
        """Print a for-loop with an int64_t counter, preceded by the prefix lines of the node."""
        counter_symbol = node.loop_counter_name
        start = f"int64_t {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
        condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
        update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
        loop_str = f"for ({start}; {condition}; {update})"

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return f"{prefix}{loop_str}\n{self._print(node.body)}"

    def _print_SympyAssignment(self, node):
        """Print a declaration, a scalar assignment or a vector store.

        A vector store is printed if the left hand side is a cast_func of vector type. The store instruction
        is chosen according to alignment, mask, stride and the nontemporal flag of the access.
        """
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                if node.is_const:
                    prefix = 'const '
                else:
                    prefix = ''
                # constness is printed by the prefix only, hence it is removed from the printed type
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStoreA' if aligned else 'maskStoreU'
                    if instr not in self._vector_instruction_set:
                        # No masked store available: build one from load, blend and store. The template is
                        # added to the instruction set, i.e. it is built only once.
                        self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
                            '{0}', self._vector_instruction_set['blendv'].format(
                                self._vector_instruction_set['load' + instr[-1]].format('{0}'), '{1}', '{2}'))
                    printed_mask = self.sympy_printer.doprint(mask)
                    # for these vector types the mask is printed with a cast to the integer vector type
                    if data_type.base_type.base_name == 'double':
                        if self._vector_instruction_set['double'] == '__m256d':
                            printed_mask = f"_mm256_castpd_si256({printed_mask})"
                        elif self._vector_instruction_set['double'] == '__m128d':
                            printed_mask = f"_mm_castpd_si128({printed_mask})"
                    elif data_type.base_type.base_name == 'float':
                        if self._vector_instruction_set['float'] == '__m256':
                            printed_mask = f"_mm256_castps_si256({printed_mask})"
                        elif self._vector_instruction_set['float'] == '__m128':
                            printed_mask = f"_mm_castps_si128({printed_mask})"

                # a right hand side that is not of vector type is cast to a vector type
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])

                if stride != 1:
                    instr = 'maskScatter' if mask != True else 'scatter'  # NOQA
                    return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                      stride, printed_mask) + ';'

                pre_code = ''
                if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                    # the cacheline is zeroed, if the pointer is at the begin of a cacheline
                    pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \
                        self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n'

                code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                  printed_mask) + ';'
                # true, if the store writes the last vector of a cacheline
                flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
                if nontemporal and 'flushCacheline' in self._vector_instruction_set:
                    code2 = self._vector_instruction_set['flushCacheline'].format(
                        ptr, self.sympy_printer.doprint(rhs)) + ';'
                    code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
                elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
                    # The right hand side is stored in a temporary, because it is used in both branches.
                    # The name of the temporary is derived from a hash of the printed right hand side.
                    tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
                    code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
                        + self.sympy_printer.doprint(rhs) + ';'
                    code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask) + ';'
                    code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask) \
                        + ';'
                    code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
                return pre_code + code
            else:
                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"

    def _print_NontemporalFence(self, _):
        """Print the 'streamFence' instruction; nothing is printed if it is not available."""
        # the instruction set is None for a backend without vectorization
        if self._vector_instruction_set and 'streamFence' in self._vector_instruction_set:
            return self._vector_instruction_set['streamFence'] + ';'
        else:
            return ''

    def _print_CachelineSize(self, node):
        """Print the definitions of the cacheline symbols; nothing if 'cachelineSize' is not available."""
        # the instruction set is None for a backend without vectorization
        if self._vector_instruction_set and 'cachelineSize' in self._vector_instruction_set:
            code = f'const size_t {node.symbol} = {self._vector_instruction_set["cachelineSize"]};\n'
            code += f'const size_t {node.mask_symbol} = {node.symbol} - 1;\n'
            vectorsize = self._vector_instruction_set['bytes']
            code += f'const size_t {node.last_symbol} = {node.symbol} - {vectorsize};\n'
            return code
        else:
            return ''

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an aligned allocation with variants for MSVC, C++17/C11 and posix_memalign.

        The alignment is the vector size in bytes if an instruction set is used, the item size otherwise.
        """
        if self._vector_instruction_set:
            align = self._vector_instruction_set['bytes']
        else:
            align = node.symbol.dtype.base_type.numpy_dtype.itemsize

        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "#if defined(_MSC_VER)\n"
        code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n"
        code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
        code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
        code += "#else\n"
        code += "{dtype} {name};\n"
        code += "posix_memalign((void**) &{name}, {align}, {size});\n"
        code += "{name} += {offset};\n"
        code += "#endif"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print the release of a temporary buffer; the offset added at allocation is subtracted again."""
        if self._vector_instruction_set:
            align = self._vector_instruction_set['bytes']
        else:
            align = node.symbol.dtype.base_type.numpy_dtype.itemsize

        name = self.sympy_printer.doprint(node.symbol.name)
        offset = node.offset(align)
        code = "#if defined(_MSC_VER)\n"
        code += "_aligned_free(%s - %d);\n" % (name, offset)
        code += "#else\n"
        code += "free(%s - %d);\n" % (name, offset)
        code += "#endif"
        return code

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set, print_arg=self.sympy_printer._print)

    def _print_SourceCodeComment(self, node):
        return f"/* {node.text } */"

    def _print_EmptyLine(self, node):
        return ""

    def _print_Conditional(self, node):
        """Print an if/else statement; for a constant condition only the selected block is printed.

        Raises:
            ValueError: if the condition is of vector type
        """
        if type(node.condition_expr) is BooleanTrue:
            return self._print_Block(node.true_block)
        elif type(node.condition_expr) is BooleanFalse:
            # without else-branch (presumably None then, see the check below) there is nothing to print
            if node.false_block is None:
                return ""
            return self._print_Block(node.false_block)
        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = f"if ({condition_expr})\n{true_block} "
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += f"else {false_block}"
        return result

398 

399 

400# ------------------------------------------ Helper function & classes ------------------------------------------------- 

401 

402 

403# noinspection PyPep8Naming 

class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions with support for the custom functions and types of this package."""

    def __init__(self):
        super().__init__()
        self._float_type = create_type("float32")

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if not expr.free_symbols:
            # constant expression: print its numeric value
            return self._typed_number(expr.evalf(), get_type_of_expression(expr))

        exponent = expr.exp
        if exponent.is_integer and exponent.is_number:
            if 0 < exponent < 8:
                return f"({self._print(sp.Mul(*[expr.base] * exponent, evaluate=False))})"
            if - 8 < exponent < 0:
                return f"1 / ({self._print(sp.Mul(*([expr.base] * -exponent), evaluate=False))})"
        return super()._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return f"(({self._print(expr.lhs)}) == ({self._print(expr.rhs)}))"

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multi_line = super()._print_Piecewise(expr)
        return multi_line.replace("\n", "")

    def _print_Abs(self, expr):
        """Print `abs` for an integer argument, `fabs` otherwise."""
        argument = expr.args[0]
        function = 'abs' if argument.is_integer else 'fabs'
        return f'{function}({self._print(argument)})'

    def _print_Type(self, node):
        return str(node)

    def _print_Function(self, expr):
        """Print the custom functions of this package; other functions are printed as ``name(arg, ...)``.

        The order of the checks matters, since some of the checked classes may be derived from others.
        """
        infix_functions = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        # expressions that bring their own C representation
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        if isinstance(expr, reinterpret_cast_func):
            operand, target_type = expr.args
            return f"*(({self._print(PointerType(target_type, restrict=False))})(& {self._print(operand)}))"
        if isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return f"&({self._print(expr.args[0])})"
        if isinstance(expr, cast_func):
            operand, target_type = expr.args
            if isinstance(operand, sp.Number) and operand.is_finite:
                # finite numbers are printed as literal of the target type instead of a cast
                return self._typed_number(operand, target_type)
            return f"(({target_type})({self._print(operand)}))"
        if isinstance(expr, fast_division):
            return f"({self._print(expr.args[0] / expr.args[1])})"
        if isinstance(expr, fast_sqrt):
            return f"({self._print(sp.sqrt(expr.args[0]))})"
        if isinstance(expr, (vec_any, vec_all)):
            return self._print(expr.args[0])
        if isinstance(expr, fast_inv_sqrt):
            return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
        if isinstance(expr, sp.Abs):
            return f"abs({self._print(expr.args[0])})"
        if isinstance(expr, sp.Max):
            return self._print(expr)
        if isinstance(expr, sp.Mod):
            dividend, divisor = expr.args[0], expr.args[1]
            if dividend.is_integer and divisor.is_integer:
                return f"({self._print(dividend)} % {self._print(divisor)})"
            return f"fmod({self._print(dividend)}, {self._print(divisor)})"
        if expr.func in infix_functions:
            return f"({self._print(expr.args[0])} {infix_functions[expr.func]} {self._print(expr.args[1])})"
        if expr.func == int_power_of_2:
            return f"(1 << ({self._print(expr.args[0])}))"
        if expr.func == int_div:
            return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"

        name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
        printed_args = ', '.join(self._print(a) for a in expr.args)
        return f'{name}({printed_args})'

    def _typed_number(self, number, dtype):
        """Print a number as literal of the given type.

        float32 literals get a decimal point (if missing) and the suffix `f`, float64 literals get a
        decimal point (if missing), for all other types the printed number is returned unchanged.
        """
        text = self._print(number)
        has_point = '.' in text
        if dtype.numpy_dtype == np.float32:
            return (text if has_point else text + '.0') + 'f'
        if dtype.numpy_dtype == np.float64:
            return text if has_point else text + '.0'
        return text

    def _reduction_lambda(self, expr, accumulator, operator, neutral_element):
        """Print a sum or product as C++ lambda that is called immediately.

        Only the first entry of `expr.limits` is used, the upper limit is inclusive.

        Args:
            expr: the sum or product expression
            accumulator: name of the variable that holds the result inside the lambda
            operator: '+' or '*', used for the compound assignment to the accumulator
            neutral_element: initial value of the accumulator
        """
        # NOTE(review): whitespace seems to be collapsed in the reviewed listing - confirm the indentation
        template = """[&]() {{
 {dtype} {acc} = ({dtype}) {neutral};
 for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
 {acc} {op}= {expr};
 }}
 return {acc};
}}()"""
        limits = expr.limits[0]
        printed_var = self._print(limits[0])
        return template.format(
            dtype=get_type_of_expression(expr.args[0]),
            acc=accumulator,
            neutral=neutral_element,
            op=operator,
            iterator_dtype='int',
            var=printed_var,
            start=self._print(limits[1]),
            expr=self._print(expr.function),
            increment=str(1),
            condition=printed_var + ' <= ' + self._print(limits[2])  # if start < end else '>='
        )

    def _print_Sum(self, expr):
        return self._reduction_lambda(expr, 'sum', '+', '0')

    def _print_Product(self, expr):
        return self._reduction_lambda(expr, 'product', '*', '1')

    def _print_ConditionalFieldAccess(self, node):
        """Printed as Piecewise: the out-of-bounds value if the out-of-bounds condition holds, else the access."""
        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))

    def _nested_ternary(self, args, comparison):
        """Reduce the arguments pairwise to nested ternary expressions `((a cmp b) ? a : b)`."""
        if len(args) == 1:
            return self._print(args[0])
        middle = len(args) // 2
        left = self._nested_ternary(args[:middle], comparison)
        right = self._nested_ternary(args[middle:], comparison)
        return f"(({left} {comparison} {right}) ? {left} : {right})"

    def _print_Max(self, expr):
        return self._nested_ternary(expr.args, '>')

    def _print_Min(self, expr):
        return self._nested_ternary(expr.args, '<')

    def _print_re(self, expr):
        return f"real({self._print(expr.args[0])})"

    def _print_im(self, expr):
        return f"imag({self._print(expr.args[0])})"

    def _print_ImaginaryUnit(self, expr):
        return "complex<double>{0,1}"

    def _print_TypedImaginaryUnit(self, expr):
        """Print the imaginary unit as complex<float> or complex<double>, depending on the dtype.

        Raises:
            NotImplementedError: for dtypes other than complex64 and complex128
        """
        np_dtype = expr.dtype.numpy_dtype
        if np_dtype == np.complex64:
            return "complex<float>{0,1}"
        if np_dtype == np.complex128:
            return "complex<double>{0,1}"
        raise NotImplementedError(
            "only complex64 and complex128 supported")

    def _print_Complex(self, expr):
        # NOTE(review): _typed_number reads `dtype.numpy_dtype`, which the numpy type passed here does not
        # have - this looks broken, confirm whether this method is used at all
        return self._typed_number(expr, np.complex64)

593 

594 

595# noinspection PyPep8Naming 

596class VectorizedCustomSympyPrinter(CustomSympyPrinter): 

597 SummandInfo = namedtuple("SummandInfo", ['sign', 'term']) 

598 

599 def __init__(self, instruction_set): 

600 super(VectorizedCustomSympyPrinter, self).__init__() 

601 self.instruction_set = instruction_set 

602 

603 def _scalarFallback(self, func_name, expr, *args, **kwargs): 

604 expr_type = get_type_of_expression(expr) 

605 if type(expr_type) is not VectorType: 

606 return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs) 

607 else: 

608 assert self.instruction_set['width'] == expr_type.width 

609 return None 

610 

611 def _print_Abs(self, expr): 

612 if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): 

613 return self.instruction_set['abs'].format(self._print(expr.args[0])) 

614 return super()._print_Abs(expr) 

615 

616 def _print_Function(self, expr): 

617 if isinstance(expr, vector_memory_access): 

618 arg, data_type, aligned, _, mask, stride = expr.args 

619 if stride != 1: 619 ↛ 620line 619 didn't jump to line 620, because the condition on line 619 was never true

620 return self.instruction_set['gather'].format("& " + self._print(arg), stride) 

621 instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] 

622 return instruction.format("& " + self._print(arg)) 

623 elif isinstance(expr, cast_func): 

624 arg, data_type = expr.args 

625 if type(data_type) is VectorType: 

626 # vector_memory_access is a cast_func itself so it should't be directly inside a cast_func 

627 assert not isinstance(arg, vector_memory_access) 

628 if isinstance(arg, sp.Tuple): 

629 is_boolean = get_type_of_expression(arg[0]) == create_type("bool") 

630 is_integer = get_type_of_expression(arg[0]) == create_type("int") 

631 printed_args = [self._print(a) for a in arg] 

632 instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' 

633 if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: 633 ↛ 634line 633 didn't jump to line 634, because the condition on line 633 was never true

634 increments = np.array(arg)[1:] - np.array(arg)[:-1] 

635 if len(set(increments)) == 1: 

636 return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0]) 

637 return self.instruction_set[instruction].format(*printed_args) 

638 else: 

639 is_boolean = get_type_of_expression(arg) == create_type("bool") 

640 is_integer = get_type_of_expression(arg) == create_type("int") or \ 

641 (isinstance(arg, TypedSymbol) and arg.dtype.is_int()) 

642 instruction = 'makeVecConstBool' if is_boolean else \ 

643 'makeVecConstInt' if is_integer else 'makeVecConst' 

644 return self.instruction_set[instruction].format(self._print(arg)) 

645 elif expr.func == fast_division: 

646 result = self._scalarFallback('_print_Function', expr) 

647 if not result: 647 ↛ 649line 647 didn't jump to line 649, because the condition on line 647 was never false

648 result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) 

649 return result 

650 elif expr.func == fast_sqrt: 

651 return f"({self._print(sp.sqrt(expr.args[0]))})" 

652 elif expr.func == fast_inv_sqrt: 

653 result = self._scalarFallback('_print_Function', expr) 

654 if not result: 654 ↛ 671line 654 didn't jump to line 671, because the condition on line 654 was never false

655 if 'rsqrt' in self.instruction_set: 655 ↛ 658line 655 didn't jump to line 658, because the condition on line 655 was never false

656 return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) 

657 else: 

658 return f"({self._print(1 / sp.sqrt(expr.args[0]))})" 

659 elif isinstance(expr, vec_any) or isinstance(expr, vec_all): 

660 instr = 'any' if isinstance(expr, vec_any) else 'all' 

661 expr_type = get_type_of_expression(expr.args[0]) 

662 if type(expr_type) is not VectorType: 662 ↛ 663line 662 didn't jump to line 663, because the condition on line 662 was never true

663 return self._print(expr.args[0]) 

664 else: 

665 if isinstance(expr.args[0], sp.Rel): 665 ↛ 669line 665 didn't jump to line 669, because the condition on line 665 was never false

666 op = expr.args[0].rel_op 

667 if (instr, op) in self.instruction_set: 667 ↛ 669line 667 didn't jump to line 669, because the condition on line 667 was never false

668 return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args]) 

669 return self.instruction_set[instr].format(self._print(expr.args[0])) 

670 

671 return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) 

672 

673 def _print_And(self, expr): 

674 result = self._scalarFallback('_print_And', expr) 

675 if result: 675 ↛ 676line 675 didn't jump to line 676, because the condition on line 675 was never true

676 return result 

677 

678 arg_strings = [self._print(a) for a in expr.args] 

679 assert len(arg_strings) > 0 

680 result = arg_strings[0] 

681 for item in arg_strings[1:]: 

682 result = self.instruction_set['&'].format(result, item) 

683 return result 

684 

def _print_Or(self, expr):
    """Print a logical OR by chaining the '|' instruction over all arguments.

    The scalar fallback is consulted first and its result is returned as-is when it is truthy.
    Otherwise every argument is printed and the resulting strings are combined left to right,
    so ``Or(a, b, c)`` becomes ``|(|(a, b), c)`` in terms of the instruction template.
    """
    scalar_code = self._scalarFallback('_print_Or', expr)
    if scalar_code:
        return scalar_code

    operand_codes = [self._print(operand) for operand in expr.args]
    assert len(operand_codes) > 0

    # fold from the left: the accumulated code is always the first template argument
    combined = operand_codes[0]
    for operand_code in operand_codes[1:]:
        combined = self.instruction_set['|'].format(combined, operand_code)
    return combined

696 

def _print_Add(self, expr, order=None):
    """Print a sum using the '+' / '-' instructions of the instruction set.

    Args:
        expr: the sum to print
        order: accepted for signature compatibility, not used here

    Returns:
        the scalar fallback result if it is truthy, otherwise the nested instruction calls
        for the sum
    """
    # Any exception raised by the scalar fallback is treated like "no scalar result"
    try:
        scalar_code = self._scalarFallback('_print_Add', expr)
    except Exception:
        scalar_code = None
    if scalar_code:
        return scalar_code

    terms = expr.args

    def is_integer_term(term):
        # integer cast, integer literal or integer-typed symbol
        if type(term) is cast_func and str(term.dtype) == self.instruction_set['int']:
            return True
        if isinstance(term, sp.Integer):
            return True
        return type(term) is TypedSymbol and isinstance(term.dtype, BasicType) and term.dtype.is_int()

    # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
    op_suffix = ""
    if all([is_integer_term(term) for term in terms]):
        cast_dtypes = {term.dtype for term in terms if type(term) is cast_func}
        assert len(cast_dtypes) == 1
        int_dtype = cast_dtypes.pop()
        # bring integer literals and typed symbols to the dtype of the cast terms
        terms = [cast_func(term, int_dtype) if isinstance(term, (sp.Integer, TypedSymbol)) else term
                 for term in terms]
        op_suffix = "int"

    summands = []
    for term in terms:
        if term.func == sp.Mul:
            # products report their sign separately so that a subtraction can be emitted
            sign, code = self._print_Mul(term, inside_add=True)
        else:
            sign, code = 1, self._print(term)
        summands.append(self.SummandInfo(sign, code))

    # Positive terms come first (stable sort keeps the relative order within each sign)
    summands.sort(key=lambda info: info.sign, reverse=True)
    # Without any positive term, start the chain from a literal zero
    if summands[0].sign == -1:
        summands.insert(0, self.SummandInfo(1, "0"))

    assert len(summands) >= 2
    first, *remaining = summands
    processed = first.term
    for summand in remaining:
        operator = '-' if summand.sign == -1 else '+'
        processed = self.instruction_set[operator + op_suffix].format(processed, summand.term)
    return processed

737 

def _print_Pow(self, expr):
    """Print a power with a small set of supported exponents.

    Supported are integer exponents strictly between -8 and 8 (excluding 0), which are
    expanded into repeated multiplication, and the exponents 0.5 and -0.5, which use the
    'sqrt' instruction. Negative exponents are printed as a division of a vector constant
    1.0 by the positive power.

    Raises:
        ValueError: for any other exponent
    """
    scalar_code = self._scalarFallback('_print_Pow', expr)
    if scalar_code:
        return scalar_code

    vec_one = self.instruction_set['makeVecConst'].format(1.0)
    exponent = expr.exp

    if exponent.is_integer and exponent.is_number and 0 < exponent < 8:
        # unevaluated product base * base * ... so that it is printed as multiplications
        repeated_product = sp.Mul(*[expr.base] * exponent, evaluate=False)
        return "(" + self._print(repeated_product) + ")"

    if exponent == -1:
        return self.instruction_set['/'].format(vec_one, self._print(expr.base))

    if exponent == 0.5:
        return self.instruction_set['sqrt'].format(self._print(expr.base))

    if exponent == -0.5:
        sqrt_code = self.instruction_set['sqrt'].format(self._print(expr.base))
        return self.instruction_set['/'].format(vec_one, sqrt_code)

    if exponent.is_integer and exponent.is_number and -8 < exponent < 0:
        return self.instruction_set['/'].format(
            vec_one, self._print(sp.Mul(*[expr.base] * (-exponent), evaluate=False)))

    raise ValueError("Generic exponential not supported: " + str(expr))

760 

def _print_Mul(self, expr, inside_add=False):
    """Print a product using the '*' and '/' instructions of the instruction set.

    Factors that are commutative powers with a negative rational exponent are collected in
    the denominator, all other factors in the numerator.

    Args:
        expr: the product to print
        inside_add: if True, the sign of the numeric coefficient is split off and returned
                    separately so that the caller can emit an addition or a subtraction

    Returns:
        the code string if ``inside_add`` is False, otherwise a ``(sign, code)`` tuple where
        ``sign`` is 1 or -1
    """
    result = self._scalarFallback('_print_Mul', expr)
    if result:
        # Callers passing inside_add=True unpack a (sign, code) pair, so the fallback code
        # has to be wrapped as well; it is used unchanged and reported with sign +1.
        return (1, result) if inside_add else result

    # noinspection PyProtectedMember
    from sympy.core.mul import _keep_coeff

    # split off the sign of the numeric coefficient
    c, e = expr.as_coeff_Mul()
    if c < 0:
        expr = _keep_coeff(-c, e)
        sign = -1
    else:
        sign = 1

    a = []  # items in the numerator
    b = []  # items that are in the denominator (if any)

    # Gather args for numerator/denominator
    for item in expr.as_ordered_factors():
        if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
            if item.exp != -1:
                b.append(sp.Pow(item.base, -item.exp, evaluate=False))
            else:
                b.append(sp.Pow(item.base, -item.exp))
        else:
            a.append(item)

    # an empty numerator is printed as 1
    a = a or [S.One]

    a_str = [self._print(x) for x in a]
    b_str = [self._print(x) for x in b]

    result = a_str[0]
    for item in a_str[1:]:
        result = self.instruction_set['*'].format(result, item)

    if b:
        denominator_str = b_str[0]
        for item in b_str[1:]:
            denominator_str = self.instruction_set['*'].format(denominator_str, item)
        result = self.instruction_set['/'].format(result, denominator_str)

    if inside_add:
        return sign, result
    if sign < 0:
        # outside of a sum the sign is applied as a multiplication with -1
        return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
    return result

811 

def _print_Relational(self, expr):
    """Print a relational using the instruction registered under its ``rel_op``.

    The scalar fallback result is returned if it is truthy; otherwise the instruction
    template is filled with the printed left- and right-hand side.
    """
    scalar_code = self._scalarFallback('_print_Relational', expr)
    if scalar_code:
        return scalar_code

    template = self.instruction_set[expr.rel_op]
    lhs_code = self._print(expr.lhs)
    rhs_code = self._print(expr.rhs)
    return template.format(lhs_code, rhs_code)

817 

def _print_Equality(self, expr):
    """Print an equality using the '==' instruction of the instruction set.

    The scalar fallback result is returned if it is truthy; otherwise the instruction
    template is filled with the printed left- and right-hand side.
    """
    scalar_code = self._scalarFallback('_print_Equality', expr)
    if scalar_code:
        return scalar_code

    template = self.instruction_set['==']
    lhs_code = self._print(expr.lhs)
    rhs_code = self._print(expr.rhs)
    return template.format(lhs_code, rhs_code)

823 

def _print_Piecewise(self, expr):
    """Print a piecewise expression as nested selections, starting from the default branch.

    The last branch has to be the default case. The remaining branches are processed from
    last to first, each one wrapping the code built so far: conditions that are a
    ``cast_func`` around a bool-typed expression become a C ternary operator, all other
    conditions use the 'blendv' instruction.

    Raises:
        ValueError: if the condition of the last branch does not wrap ``True``
    """
    scalar_code = self._scalarFallback('_print_Piecewise', expr)
    if scalar_code:
        return scalar_code

    default_branch = expr.args[-1]
    if default_branch.cond.args[0] is not sp.sympify(True):
        # We need the last conditional to be a True, otherwise the resulting
        # function may not return a result.
        raise ValueError("All Piecewise expressions must contain an "
                         "(expr, True) statement to be used as a default "
                         "condition. Without one, the generated "
                         "expression may not evaluate to anything under "
                         "some condition.")

    code = self._print(default_branch[0])
    for branch_value, condition in reversed(expr.args[:-1]):
        is_scalar_bool = (isinstance(condition, cast_func)
                          and get_type_of_expression(condition.args[0]) == create_type("bool"))
        if not is_scalar_bool:
            # noinspection SpellCheckingInspection
            code = self.instruction_set['blendv'].format(code, self._print(branch_value), self._print(condition))
        elif KERNCRAFT_NO_TERNARY_MODE:
            # the branch is dropped, the code built so far is kept unchanged
            print("Warning - skipping ternary op")
        else:
            code = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(branch_value),
                                                 code)
    return code