1import ctypes as ct 

2import subprocess 

3from functools import partial 

4from itertools import chain 

5from os.path import exists, join 

6 

7import llvmlite.binding as llvm 

8import llvmlite.ir as ir 

9import numpy as np 

10 

11from pystencils.data_types import create_composite_type_from_string 

12from pystencils.field import FieldType 

13 

14from ..data_types import StructType, ctypes_from_llvm, to_ctypes 

15from .llvm import generate_llvm 

16 

17 

18def build_ctypes_argument_list(parameter_specification, argument_dict): 

19 argument_dict = {k: v for k, v in argument_dict.items()} 

20 ct_arguments = [] 

21 array_shapes = set() 

22 index_arr_shapes = set() 

23 

24 for param in parameter_specification: 

25 if param.is_field_parameter: 

26 try: 

27 field_arr = argument_dict[param.field_name] 

28 except KeyError: 

29 raise KeyError("Missing field parameter for kernel call " + param.field_name) 

30 

31 symbolic_field = param.fields[0] 

32 if param.is_field_pointer: 

33 ct_arguments.append(field_arr.ctypes.data_as(to_ctypes(param.symbol.dtype))) 

34 if symbolic_field.has_fixed_shape: 

35 symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape) 

36 if isinstance(symbolic_field.dtype, StructType): 

37 symbolic_field_shape = symbolic_field_shape[:-1] 

38 if symbolic_field_shape != field_arr.shape: 

39 raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" % 

40 (param.field_name, str(field_arr.shape), str(symbolic_field.shape))) 

41 if symbolic_field.has_fixed_shape: 

42 symbolic_field_strides = tuple(int(i) * field_arr.itemsize for i in symbolic_field.strides) 

43 if isinstance(symbolic_field.dtype, StructType): 

44 symbolic_field_strides = symbolic_field_strides[:-1] 

45 if symbolic_field_strides != field_arr.strides: 

46 raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % 

47 (param.field_name, str(field_arr.strides), str(symbolic_field_strides))) 

48 

49 if FieldType.is_indexed(symbolic_field): 

50 index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions]) 

51 elif FieldType.is_generic(symbolic_field): 

52 array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions]) 

53 

54 elif param.is_field_shape: 

55 data_type = to_ctypes(param.symbol.dtype) 

56 ct_arguments.append(data_type(field_arr.shape[param.symbol.coordinate])) 

57 elif param.is_field_stride: 

58 data_type = to_ctypes(param.symbol.dtype) 

59 assert field_arr.strides[param.symbol.coordinate] % field_arr.itemsize == 0 

60 item_stride = field_arr.strides[param.symbol.coordinate] // field_arr.itemsize 

61 ct_arguments.append(data_type(item_stride)) 

62 else: 

63 assert False 

64 else: 

65 try: 

66 value = argument_dict[param.symbol.name] 

67 except KeyError: 

68 raise KeyError("Missing parameter for kernel call " + param.symbol.name) 

69 expected_type = to_ctypes(param.symbol.dtype) 

70 ct_arguments.append(expected_type(value)) 

71 

72 if len(array_shapes) > 1: 

73 raise ValueError("All passed arrays have to have the same size " + str(array_shapes)) 

74 if len(index_arr_shapes) > 1: 

75 raise ValueError("All passed index arrays have to have the same size " + str(array_shapes)) 

76 

77 return ct_arguments 

78 

79 

80def make_python_function_incomplete_params(kernel_function_node, argument_dict, func): 

81 parameters = kernel_function_node.get_parameters() 

82 

83 cache = {} 

84 cache_values = [] 

85 

86 def wrapper(**kwargs): 

87 key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v)) 

88 for k, v in kwargs.items())) 

89 try: 

90 args = cache[key] 

91 func(*args) 

92 except KeyError: 

93 full_arguments = argument_dict.copy() 

94 full_arguments.update(kwargs) 

95 args = build_ctypes_argument_list(parameters, full_arguments) 

96 cache[key] = args 

97 cache_values.append(kwargs) # keep objects alive such that ids remain unique 

98 func(*args) 

99 wrapper.ast = kernel_function_node 

100 wrapper.parameters = kernel_function_node.get_parameters() 

101 return wrapper 

102 

103 

104def generate_and_jit(ast): 

105 target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu' 

106 gen = generate_llvm(ast, target=target) 

107 if isinstance(gen, ir.Module): 

108 return compile_llvm(gen, target, ast) 

109 else: 

110 return compile_llvm(gen.module, target, ast) 

111 

112 

113def make_python_function(ast, argument_dict={}, func=None): 

114 if func is None: 

115 jit = generate_and_jit(ast) 

116 func = jit.get_function_ptr(ast.function_name) 

117 try: 

118 args = build_ctypes_argument_list(ast.get_parameters(), argument_dict) 

119 except KeyError: 

120 # not all parameters specified yet 

121 return make_python_function_incomplete_params(ast, argument_dict, func) 

122 return lambda: func(*args) 

123 

124 

125def compile_llvm(module, target='cpu', ast=None): 

126 jit = CudaJit(ast) if target == "gpu" else Jit() 

127 jit.parse(module) 

128 jit.optimize() 

129 jit.compile() 

130 return jit 

131 

132 

133class Jit(object): 

134 def __init__(self): 

135 llvm.initialize() 

136 llvm.initialize_all_targets() 

137 llvm.initialize_native_target() 

138 llvm.initialize_native_asmprinter() 

139 

140 self.module = None 

141 self._llvmmod = llvm.parse_assembly("") 

142 self.target = llvm.Target.from_default_triple() 

143 self.cpu = llvm.get_host_cpu_name() 

144 try: 

145 self.cpu_features = llvm.get_host_cpu_features() 

146 self.target_machine = self.target.create_target_machine(cpu=self.cpu, features=self.cpu_features.flatten(), 

147 opt=2) 

148 except RuntimeError: 

149 self.target_machine = self.target.create_target_machine(cpu=self.cpu, opt=2) 

150 llvm.check_jit_execution() 

151 self.ee = llvm.create_mcjit_compiler(self.llvmmod, self.target_machine) 

152 self.ee.finalize_object() 

153 self.fptr = None 

154 

155 @property 

156 def llvmmod(self): 

157 return self._llvmmod 

158 

159 @llvmmod.setter 

160 def llvmmod(self, mod): 

161 self.ee.remove_module(self.llvmmod) 

162 self.ee.add_module(mod) 

163 self.ee.finalize_object() 

164 self.compile() 

165 self._llvmmod = mod 

166 

167 def parse(self, module): 

168 self.module = module 

169 llvmmod = llvm.parse_assembly(str(module)) 

170 llvmmod.verify() 

171 llvmmod.triple = self.target.triple 

172 llvmmod.name = 'module' 

173 self.llvmmod = llvmmod 

174 

175 def write_ll(self, file): 

176 with open(file, 'w') as f: 

177 f.write(str(self.llvmmod)) 

178 

179 def write_assembly(self, file): 

180 with open(file, 'w') as f: 

181 f.write(self.target_machine.emit_assembly(self.llvmmod)) 

182 

183 def write_object_file(self, file): 

184 with open(file, 'wb') as f: 

185 f.write(self.target_machine.emit_object(self.llvmmod)) 

186 

187 def optimize(self): 

188 pmb = llvm.create_pass_manager_builder() 

189 pmb.opt_level = 2 

190 pmb.disable_unit_at_a_time = False 

191 pmb.loop_vectorize = True 

192 pmb.slp_vectorize = True 

193 # TODO possible to pass for functions 

194 pm = llvm.create_module_pass_manager() 

195 pm.add_instruction_combining_pass() 

196 pm.add_function_attrs_pass() 

197 pm.add_constant_merge_pass() 

198 pm.add_licm_pass() 

199 pmb.populate(pm) 

200 pm.run(self.llvmmod) 

201 

202 def compile(self): 

203 fptr = {} 

204 for func in self.module.functions: 

205 if not func.is_declaration: 

206 return_type = None 

207 if func.ftype.return_type != ir.VoidType(): 

208 return_type = to_ctypes(create_composite_type_from_string(str(func.ftype.return_type))) 

209 args = [ctypes_from_llvm(arg) for arg in func.ftype.args] 

210 function_address = self.ee.get_function_address(func.name) 

211 fptr[func.name] = ct.CFUNCTYPE(return_type, *args)(function_address) 

212 self.fptr = fptr 

213 

214 def __call__(self, func, *args, **kwargs): 

215 target_function = next(f for f in self.module.functions if f.name == func) 

216 arg_types = [ctypes_from_llvm(arg.type) for arg in target_function.args] 

217 

218 transformed_args = [] 

219 for i, arg in enumerate(args): 

220 if isinstance(arg, np.ndarray): 

221 transformed_args.append(arg.ctypes.data_as(arg_types[i])) 

222 else: 

223 transformed_args.append(arg) 

224 

225 self.fptr[func](*transformed_args) 

226 

227 def print_functions(self): 

228 for f in self.module.functions: 

229 print(f.ftype.return_type, f.name, f.args) 

230 

231 def get_function_ptr(self, name): 

232 fptr = self.fptr[name] 

233 fptr.jit = self 

234 return fptr 

235 

236 

237# Following code more or less from numba 

238class CudaJit(Jit): 

239 

240 CUDA_TRIPLE = {32: 'nvptx-nvidia-cuda', 

241 64: 'nvptx64-nvidia-cuda'} 

242 MACHINE_BITS = tuple.__itemsize__ * 8 

243 data_layout = { 

244 32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-' 

245 'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'), 

246 64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-' 

247 'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')} 

248 

249 default_data_layout = data_layout[MACHINE_BITS] 

250 

251 def __init__(self, ast): 

252 # super().__init__() 

253 

254 # self.target = llvm.Target.from_triple(self.CUDA_TRIPLE[self.MACHINE_BITS]) 

255 self._data_layout = self.default_data_layout[self.MACHINE_BITS] 

256 # self._target_data = llvm.create_target_data(self._data_layout) 

257 self.indexing = ast.indexing 

258 

259 def optimize(self): 

260 pmb = llvm.create_pass_manager_builder() 

261 pmb.opt_level = 2 

262 pmb.disable_unit_at_a_time = False 

263 pmb.loop_vectorize = False 

264 pmb.slp_vectorize = False 

265 # TODO possible to pass for functions 

266 pm = llvm.create_module_pass_manager() 

267 pm.add_instruction_combining_pass() 

268 pm.add_function_attrs_pass() 

269 pm.add_constant_merge_pass() 

270 pm.add_licm_pass() 

271 pmb.populate(pm) 

272 pm.run(self.llvmmod) 

273 pm.run(self.llvmmod) 

274 

275 def write_ll(self, file): 

276 with open(file, 'w') as f: 

277 f.write(str(self.llvmmod)) 

278 

279 def parse(self, module): 

280 

281 llvmmod = module 

282 llvmmod.triple = self.CUDA_TRIPLE[self.MACHINE_BITS] 

283 llvmmod.data_layout = self.default_data_layout 

284 llvmmod.verify() 

285 llvmmod.name = 'module' 

286 

287 self._llvmmod = llvm.parse_assembly(str(llvmmod)) 

288 

289 def compile(self): 

290 from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_llc_command 

291 import hashlib 

292 compiler_cache = get_cache_config()['object_cache'] 

293 ir_file = join(compiler_cache, hashlib.md5(str(self._llvmmod).encode()).hexdigest() + '.ll') 

294 ptx_file = ir_file.replace('.ll', '.ptx') 

295 try: 

296 from pycuda.driver import Context 

297 arch = "sm_%d%d" % Context.get_device().compute_capability() 

298 except Exception: 

299 arch = "sm_35" 

300 

301 if not exists(ptx_file): 

302 self.write_ll(ir_file) 

303 if 'llc' in get_compiler_config(): 

304 llc_command = get_compiler_config()['llc'] 

305 else: 

306 llc_command = get_llc_command() or 'llc' 

307 

308 subprocess.check_call([llc_command, '-mcpu=' + arch, ir_file, '-o', ptx_file]) 

309 

310 # cubin_file = ir_file.replace('.ll', '.cubin') 

311 # if not exists(cubin_file): 

312 # subprocess.check_call(['ptxas', '--gpu-name', arch, ptx_file, '-o', cubin_file]) 

313 import pycuda.driver 

314 

315 cuda_module = pycuda.driver.module_from_file(ptx_file) # also works: cubin_file 

316 self.cuda_module = cuda_module 

317 

318 def __call__(self, func, *args, **kwargs): 

319 shape = [a.shape for a in chain(args, kwargs.values()) if hasattr(a, 'shape')][0] 

320 block_and_thread_numbers = self.indexing.call_parameters(shape) 

321 block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block']) 

322 block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid']) 

323 self.cuda_module.get_function(func)(*args, **kwargs, **block_and_thread_numbers) 

324 

325 def get_function_ptr(self, name): 

326 return partial(self._call__, name)