1import functools 

2 

3import llvmlite.ir as ir 

4import llvmlite.llvmpy.core as lc 

5import sympy as sp 

6from sympy import Indexed, S 

7from sympy.printing.printer import Printer 

8 

9from pystencils.assignment import Assignment 

10from pystencils.data_types import ( 

11 collate_types, create_composite_type_from_string, create_type, get_type_of_expression, 

12 to_llvm_type) 

13from pystencils.llvm.control_flow import Loop 

14 

15 

16# From Numba 

def set_cuda_kernel(lfunc):
    """Mark ``lfunc`` as a kernel in the NVVM metadata of its module.

    An ``(lfunc, "kernel", 1)`` metadata node is appended to the module's
    ``nvvm.annotations`` named metadata, and a ``nvvmir.version`` named
    metadata entry is added to the same module.

    Args:
        lfunc: function object exposing the owning module as ``lfunc.module``.
    """
    from llvmlite.llvmpy.core import MetaData, MetaDataString, Constant, Type

    module = lfunc.module

    # Build the annotation operands in the same order they appear in the node.
    kernel_tag = MetaDataString.get(module, "kernel")
    flag = Constant.int(Type.int(), 1)
    annotation = MetaData.get(module, (lfunc, kernel_tag, flag))

    module.get_or_insert_named_metadata('nvvm.annotations').add(annotation)

    # set nvvm ir version
    int32 = ir.IntType(32)
    version_fields = [int32(1), int32(2), int32(2), int32(0)]
    version_node = module.add_metadata(version_fields)
    module.add_named_metadata('nvvmir.version', version_node)

32 

33 

34# From Numba 

def _call_sreg(builder, name):
    """Emit a call to the argument-less, integer-returning function ``name``.

    The callee is fetched from, or inserted into, the module that ``builder``
    belongs to.

    Args:
        builder: IR builder used to emit the call.
        name: name of the function to call.

    Returns:
        The value produced by the emitted call instruction.
    """
    signature = lc.Type.function(lc.Type.int(), ())
    callee = builder.module.get_or_insert_function(signature, name=name)
    return builder.call(callee, ())

40 

41 

def generate_llvm(ast_node, module=None, builder=None, target='cpu'):
    """Print ``ast_node`` as LLVM code using an :class:`LLVMPrinter`.

    Args:
        ast_node: node handed to the printer.
        module: module to emit into; a new ``lc.Module()`` is created when ``None``.
        builder: IR builder to use; a new ``ir.IRBuilder()`` is created when ``None``.
        target: forwarded to the printer as its ``target``.

    Returns:
        Whatever the printer returns for ``ast_node``.
    """
    module = lc.Module() if module is None else module
    builder = ir.IRBuilder() if builder is None else builder
    return LLVMPrinter(module, builder, target=target)._print(ast_node)

50 

51 

52# noinspection PyPep8Naming 

class LLVMPrinter(Printer):
    """Convert expressions to LLVM IR.

    Each ``_print_<ClassName>`` method emits instructions through ``self.builder``
    and returns the resulting IR value (or ``None`` for pure statements).
    """

    # from: https://llvm.org/docs/NVPTXUsage.html#nvptx-intrinsics
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'llvm.nvvm.read.ptx.sreg.ctaid',
        'threadIdx': 'llvm.nvvm.read.ptx.sreg.tid',
        'blockDim': 'llvm.nvvm.read.ptx.sreg.ntid',
        'gridDim': 'llvm.nvvm.read.ptx.sreg.nctaid'
    }

    def __init__(self, module, builder, fn=None, target='cpu', *args, **kwargs):
        """Set up the printer state.

        Args:
            module: module that receives function definitions and declarations.
            builder: IR builder used to emit instructions.
            fn: initial value of ``self.fn``; overwritten when a kernel function is printed.
            target: compared against ``'gpu'`` to select NVVM-specific behaviour.
            *args: forwarded to the sympy ``Printer`` constructor.
            **kwargs: forwarded to the sympy ``Printer`` constructor, except for the optional
                ``func_arg_map`` entry (name -> IR value) which is consumed here.
        """
        self.func_arg_map = kwargs.pop("func_arg_map", {})
        super(LLVMPrinter, self).__init__(*args, **kwargs)
        self.fp_type = ir.DoubleType()
        self.fp_pointer = self.fp_type.as_pointer()
        self.integer = ir.IntType(64)
        self.integer_pointer = self.integer.as_pointer()
        self.void = ir.VoidType()
        self.module = module
        self.builder = builder
        self.fn = fn
        self.ext_fn = {}  # keep track of wrappers to external functions
        self.tmp_var = {}  # symbol -> IR value for temporaries such as loop counters
        self.target = target

    def _add_tmp_var(self, name, value):
        """Register ``value`` as the IR value of the temporary ``name``."""
        self.tmp_var[name] = value

    def _remove_tmp_var(self, name):
        """Forget the temporary ``name``; raises ``KeyError`` if it is not registered."""
        del self.tmp_var[name]

    def _external_fp_function(self, name, arg_count):
        """Return the function ``name`` taking ``arg_count`` doubles and returning a double.

        The function is created in ``self.module`` on first use and cached in ``self.ext_fn``.
        """
        fn = self.ext_fn.get(name)
        if fn is None:
            fn_type = ir.FunctionType(self.fp_type, [self.fp_type] * arg_count)
            fn = ir.Function(self.module, fn_type, name)
            self.ext_fn[name] = fn
        return fn

    def _print_Number(self, n):
        """Return a 64-bit integer or double constant, depending on the inferred type of ``n``.

        Raises:
            NotImplementedError: if the inferred type is neither ``int`` nor ``double``.
        """
        dtype = get_type_of_expression(n)
        if dtype == create_type("int"):
            return ir.Constant(self.integer, int(n))
        elif dtype == create_type("double"):
            return ir.Constant(self.fp_type, float(n))
        else:
            raise NotImplementedError("Numbers can only have int and double", n)

    def _print_Float(self, expr):
        """Return ``expr`` as a double constant."""
        return ir.Constant(self.fp_type, float(expr))

    def _print_Integer(self, expr):
        """Return ``expr`` as a 64-bit integer constant."""
        return ir.Constant(self.integer, int(expr))

    def _print_int(self, i):
        """Return the Python integer ``i`` as a 64-bit integer constant."""
        return ir.Constant(self.integer, i)

    def _print_Symbol(self, s):
        """Resolve symbol ``s`` to an IR value.

        Temporaries (keyed by symbol) take precedence over ``func_arg_map`` (keyed by name).

        Raises:
            LookupError: if the symbol is known in neither mapping.
        """
        val = self.tmp_var.get(s)
        if val is None:
            # look up parameter with name s
            val = self.func_arg_map.get(s.name)
        if val is None:
            raise LookupError(f"Symbol not found: {s}")
        return val

    def _print_Pow(self, expr):
        """Emit code for a power expression.

        Exponents -1, 1/2, 2 and 3 are special-cased (division, ``sqrt`` call, repeated
        multiplication); everything else becomes a call to the external ``pow``.
        """
        base0 = self._print(expr.base)
        if expr.exp == S.NegativeOne:
            return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
        if expr.exp == S.Half:
            fn = self._external_fp_function("sqrt", 1)
            return self.builder.call(fn, [base0], "sqrt")
        if expr.exp == 2:
            return self.builder.fmul(base0, base0)
        elif expr.exp == 3:
            return self.builder.fmul(self.builder.fmul(base0, base0), base0)

        exp0 = self._print(expr.exp)
        fn = self._external_fp_function("pow", 2)
        return self.builder.call(fn, [base0, exp0], "pow")

    def _print_Mul(self, expr):
        """Multiply all arguments left to right, using ``fmul`` for doubles and ``mul`` otherwise."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        if get_type_of_expression(expr) == create_type('double'):
            mul = self.builder.fmul
        else:  # int TODO unsigned/signed
            mul = self.builder.mul
        for node in nodes[1:]:
            e = mul(e, node)
        return e

    def _print_Add(self, expr):
        """Add all arguments left to right, using ``fadd`` for doubles and ``add`` otherwise."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        if get_type_of_expression(expr) == create_type('double'):
            add = self.builder.fadd
        else:  # int TODO unsigned/signed
            add = self.builder.add
        for node in nodes[1:]:
            e = add(e, node)
        return e

    def _print_Or(self, expr):
        """Combine all arguments left to right with the builder's ``or_``."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.or_(e, node)
        return e

    def _print_And(self, expr):
        """Combine all arguments left to right with the builder's ``and_``."""
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.and_(e, node)
        return e

    def _print_StrictLessThan(self, expr):
        """Emit a ``<`` comparison."""
        return self._comparison('<', expr)

    def _print_LessThan(self, expr):
        """Emit a ``<=`` comparison."""
        return self._comparison('<=', expr)

    def _print_StrictGreaterThan(self, expr):
        """Emit a ``>`` comparison."""
        return self._comparison('>', expr)

    def _print_GreaterThan(self, expr):
        """Emit a ``>=`` comparison."""
        return self._comparison('>=', expr)

    def _print_Unequality(self, expr):
        """Emit a ``!=`` comparison."""
        return self._comparison('!=', expr)

    def _print_Equality(self, expr):
        """Emit a ``==`` comparison."""
        return self._comparison('==', expr)

    def _comparison(self, cmpop, expr):
        """Emit the comparison ``cmpop`` between ``expr.lhs`` and ``expr.rhs``.

        If the collated argument type is ``double`` the builder's ``fcmp_unordered`` is used,
        otherwise ``icmp_signed``.
        """
        if collate_types([get_type_of_expression(arg) for arg in expr.args]) == create_type('double'):
            comparison = self.builder.fcmp_unordered
        else:
            comparison = self.builder.icmp_signed
        return comparison(cmpop, self._print(expr.lhs), self._print(expr.rhs))

    def _print_KernelFunction(self, func):
        """Define ``func`` as a void function in ``self.module`` and emit its body.

        The function arguments are named after the kernel parameters and registered in
        ``func_arg_map``. ``self.builder`` is replaced by a builder positioned in the new
        function's entry block. For the ``'gpu'`` target the function is additionally
        marked as a CUDA kernel.

        Returns:
            The created function.
        """
        # KernelFunction does not possess a return type
        return_type = self.void
        parameters = func.get_parameters()
        parameter_type = [to_llvm_type(parameter.symbol.dtype, nvvm_target=self.target == 'gpu')
                          for parameter in parameters]
        func_type = ir.FunctionType(return_type, tuple(parameter_type))
        name = func.function_name
        fn = ir.Function(self.module, func_type, name)
        self.ext_fn[name] = fn

        # set proper names to arguments
        for arg, parameter in zip(fn.args, parameters):
            arg.name = parameter.symbol.name
            self.func_arg_map[parameter.symbol.name] = arg

        # func.attributes.add("inlinehint")
        # func.attributes.add("argmemonly")
        block = fn.append_basic_block(name="entry")
        self.builder = ir.IRBuilder(block)  # TODO use goto_block instead
        self._print(func.body)
        self.builder.ret_void()
        self.fn = fn
        if self.target == 'gpu':
            set_cuda_kernel(fn)

        return fn

    def _print_Block(self, block):
        """Print every node of ``block`` in order; returns nothing."""
        for node in block.args:
            self._print(node)

    def _print_LoopOverCoordinate(self, loop):
        """Emit ``loop`` via the ``Loop`` context manager.

        The value yielded by ``Loop`` is bound to the loop counter symbol while the body
        is printed (presumably the counter's IR value -- confirm against ``Loop``).
        """
        with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
                  loop.loop_counter_name, loop.loop_counter_symbol.name) as i:
            self._add_tmp_var(loop.loop_counter_symbol, i)
            self._print(loop.body)
            self._remove_tmp_var(loop.loop_counter_symbol)

    def _print_SympyAssignment(self, assignment):
        """Emit an assignment.

        For an ``Indexed`` left-hand side the value is stored through a GEP on the base
        pointer. Otherwise the value is recorded in ``func_arg_map`` under the left-hand
        side's name so that later symbol lookups resolve to it.
        """
        expr = self._print(assignment.rhs)
        lhs = assignment.lhs
        if isinstance(lhs, Indexed):
            ptr = self._print(lhs.base.label)
            index = self._print(lhs.args[1])
            gep = self.builder.gep(ptr, [index])
            return self.builder.store(expr, gep)
        self.func_arg_map[assignment.lhs.name] = expr
        return expr

    def _print_boolean_cast_func(self, conversion):
        """Handled exactly like a regular cast."""
        return self._print_cast_func(conversion)

    def _print_cast_func(self, conversion):
        """Emit the conversion instruction for a cast expression.

        Returns the operand unchanged when source and target types are equal.

        Raises:
            KeyError: if the (from, to) type pair is not in the conversion table.
        """
        node = self._print(conversion.args[0])
        to_dtype = get_type_of_expression(conversion)
        from_dtype = get_type_of_expression(conversion.args[0])
        if from_dtype == to_dtype:
            # Reuse the value printed above; printing the operand again would emit its IR twice.
            return node

        make_type = create_composite_type_from_string
        builder = self.builder
        # Signed integer widening must sign-extend, otherwise negative values are corrupted.
        widen_to_int = functools.partial(builder.sext, node, self.integer)
        int_to_fp = functools.partial(builder.sitofp, node, self.fp_type)
        fp_to_int = functools.partial(builder.fptosi, node, self.integer)
        ptr_to_int = functools.partial(builder.ptrtoint, node, self.integer)
        int_to_ptr = functools.partial(builder.inttoptr, node, self.fp_pointer)

        # (From, to)
        decision = {
            (make_type("int32"), make_type("int64")): widen_to_int,
            (make_type("int16"), make_type("int64")): widen_to_int,
            (make_type("int"), make_type("double")): int_to_fp,
            (make_type("int16"), make_type("double")): int_to_fp,
            (make_type("double"), make_type("int")): fp_to_int,
            (make_type("double *"), make_type("int")): ptr_to_int,
            (make_type("int"), make_type("double *")): int_to_ptr,
            (make_type("double * restrict"), make_type("int")): ptr_to_int,
            (make_type("int"), make_type("double * restrict")): int_to_ptr,
            (make_type("double * restrict const"), make_type("int")): ptr_to_int,
            (make_type("int"), make_type("double * restrict const")): int_to_ptr,
        }
        # TODO float, TEST: const, restrict
        # TODO bitcast, addrspacecast
        # TODO unsigned/signed fills
        return decision[(from_dtype, to_dtype)]()

    def _print_pointer_arithmetic_func(self, pointer):
        """Return a GEP of the first argument offset by the second argument."""
        ptr = self._print(pointer.args[0])
        index = self._print(pointer.args[1])
        return self.builder.gep(ptr, [index])

    def _print_Indexed(self, indexed):
        """Load the element addressed by the base pointer and the first index."""
        ptr = self._print(indexed.base.label)
        index = self._print(indexed.args[1])
        gep = self.builder.gep(ptr, [index])
        return self.builder.load(gep, name=indexed.base.label.name)

    def _print_Piecewise(self, piece):
        """Emit a chain of conditional branches joined by a phi node.

        Raises:
            ValueError: if the last piece is not an ``(expr, True)`` default.
            NotImplementedError: if the expression contains an ``Assignment``.
        """
        # Compare with '!=' rather than truth-testing: a symbolic relational has no truth value.
        if piece.args[-1].cond != sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        if piece.has(Assignment):
            raise NotImplementedError('The llvm-backend does not support assignments '
                                      'in the Piecewise function. It is questionable '
                                      'whether to implement it. So far there is no '
                                      'use-case to test it.')

        phi_data = []
        after_block = self.builder.append_basic_block()
        for (expr, condition) in piece.args:
            if condition == sp.sympify(True):  # Don't use 'is' use '=='!
                value = self._print(expr)
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(after_block)
            else:
                cond = self._print(condition)
                true_block = self.builder.append_basic_block()
                false_block = self.builder.append_basic_block()
                self.builder.cbranch(cond, true_block, false_block)
                self.builder.position_at_end(true_block)
                value = self._print(expr)
                # Printing the expression may have moved the builder to another block
                # (e.g. a nested Piecewise); the phi must name the block that branches.
                phi_data.append((value, self.builder.block))
                self.builder.branch(after_block)
                self.builder.position_at_end(false_block)

        phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu'))
        for (val, block) in phi_data:
            phi.add_incoming(val, block)
        return phi

    def _print_Conditional(self, node):
        """Emit an if/else for ``node``; this is a statement and returns nothing."""
        cond = self._print(node.condition_expr)
        with self.builder.if_else(cond) as (then, otherwise):
            with then:
                self._print(node.true_block)  # emit instructions for when the predicate is true
            with otherwise:
                self._print(node.false_block)  # emit instructions for when the predicate is false

        # No return!

    def _print_Function(self, expr):
        """Call the external double -> double function named after ``expr.func``.

        Only the first argument of ``expr`` is passed on.
        """
        name = expr.func.__name__
        e0 = self._print(expr.args[0])
        fn = self._external_fp_function(name, 1)
        return self.builder.call(fn, [e0], name)

    def empty_printer(self, expr):
        """Raise ``TypeError`` describing the unsupported expression ``expr``.

        NOTE(review): sympy's ``Printer`` falls back to a method named ``emptyPrinter``;
        confirm whether this method is meant to be wired to that hook.
        """
        import inspect
        # getmro expects a class; instances are reported through their type.
        cls = expr if inspect.isclass(expr) else type(expr)
        mro = inspect.getmro(cls)
        raise TypeError("Unsupported type for LLVM JIT conversion: Expression:\"%s\", Type:\"%s\", MRO:%s"
                        % (expr, type(expr), mro))

    def _print_ThreadIndexingSymbol(self, node):
        """Read the special register named by ``node`` and widen the result to 64 bit.

        The symbol name is expected to have the form ``<kind>.<dimension>`` where ``<kind>``
        is a key of ``INDEXING_FUNCTION_MAPPING``.
        """
        symbol_name: str = node.name
        function_name, dimension = symbol_name.split(".")
        function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
        name = f"{function_name}.{dimension}"

        return self.builder.zext(_call_sreg(self.builder, name), self.integer)