1from os.path import dirname, join 

2 

3import pystencils.data_types 

4from pystencils.astnodes import Node 

5from pystencils.backends.cbackend import CustomSympyPrinter, generate_c 

6from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter 

7from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt 

8 

# Table of OpenCL 1.1 built-in function names, read one name per line from a
# text file shipped next to this module. Each name maps to itself: OpenCL math
# built-ins are generic, so the printer emits the name verbatim.
with open(join(dirname(__file__), 'opencl1.1_known_functions.txt'), encoding='utf-8') as f:
    lines = f.readlines()
    # Filter on the *stripped* line: a blank/whitespace-only line still ends in
    # '\n' (truthy) and would otherwise insert an empty-string key.
    OPENCL_KNOWN_FUNCTIONS = {name: name for name in (line.strip() for line in lines) if name}

12 

13 

def generate_opencl(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
    """Render an abstract syntax tree node (built for target 'gpu') as OpenCL source.

    Thin wrapper around ``generate_c`` that selects the ``'opencl'`` dialect.

    Args:
        ast_node: AST representation of the kernel to print
        signature_only: if true, emit only the function signature without its body
        custom_backend: optional user-supplied printer used for code generation
        with_globals: enable usage of global variables in the generated code

    Returns:
        OpenCL code for ``ast_node`` and all of its descendants
    """
    options = dict(dialect='opencl', custom_backend=custom_backend, with_globals=with_globals)
    return generate_c(ast_node, signature_only, **options)

28 

29 

class OpenClBackend(CudaBackend):
    """Code-generation backend for OpenCL, derived from the CUDA backend.

    Differences from ``CudaBackend`` implemented here: pointer types are
    prefixed with ``__global``, and thread-block synchronization as well as
    texture declarations are rejected with ``NotImplementedError``.
    """

    def __init__(self,
                 sympy_printer=None,
                 signature_only=False):
        # Use the OpenCL expression printer unless a (truthy) printer was supplied.
        printer = sympy_printer or OpenClSympyPrinter()
        super().__init__(printer, signature_only)
        self._dialect = 'opencl'

    def _print_Type(self, node):
        """Print ``node`` like the parent backend, qualifying pointer types as ``__global``."""
        type_code = super()._print_Type(node)
        if not isinstance(node, pystencils.data_types.PointerType):
            return type_code
        return "__global " + type_code

    def _print_ThreadBlockSynchronization(self, node):
        # Not supported by this backend.
        raise NotImplementedError()

    def _print_TextureDeclaration(self, node):
        # Not supported by this backend.
        raise NotImplementedError()

53 

54 

class OpenClSympyPrinter(CudaSympyPrinter):
    """Sympy expression printer that emits OpenCL C.

    Builds on ``CudaSympyPrinter`` but
      * translates CUDA thread-indexing symbols (``threadIdx.x`` etc.) into
        OpenCL work-item function calls,
      * uses the generic C math-function printing instead of CUDA's
        float/double-specific variants, and
      * maps pystencils' fast approximations to OpenCL ``native_*`` built-ins.
    """
    language = "OpenCL"

    # Axis suffix of a CUDA indexing symbol -> dimension argument of the
    # corresponding OpenCL work-item function.
    DIMENSION_MAPPING = {
        'x': '0',
        'y': '1',
        'z': '2'
    }
    # CUDA built-in indexing variable -> equivalent OpenCL work-item function.
    # CUDA's gridDim is the number of *blocks* in the grid, which is OpenCL's
    # number of work-groups (get_num_groups) -- not the total number of
    # work-items (get_global_size == gridDim * blockDim).
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'get_group_id',
        'threadIdx': 'get_local_id',
        'blockDim': 'get_local_size',
        'gridDim': 'get_num_groups'
    }

    def __init__(self):
        # Deliberately initialise as the plain C printer rather than via
        # CudaSympyPrinter.__init__ (presumably to skip CUDA-specific setup),
        # then install the OpenCL built-in function table.
        CustomSympyPrinter.__init__(self)
        self.known_functions = OPENCL_KNOWN_FUNCTIONS

    def _print_Type(self, node):
        """Print ``node`` like the parent printer, qualifying pointer types as ``__global``."""
        code = super()._print_Type(node)
        if isinstance(node, pystencils.data_types.PointerType):
            return "__global " + code
        else:
            return code

    def _print_ThreadIndexingSymbol(self, node):
        """Print a CUDA indexing symbol such as ``threadIdx.x`` as an OpenCL call.

        E.g. ``threadIdx.x`` becomes ``(int) get_local_id(0)``; the ``(int)``
        cast converts the ``size_t`` returned by OpenCL work-item functions.

        Raises:
            KeyError: if the symbol's name or axis has no OpenCL mapping.
        """
        function_name, dimension = node.name.split(".")
        dimension = self.DIMENSION_MAPPING[dimension]
        function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
        return f"(int) {function_name}({dimension})"

    def _print_TextureAccess(self, node):
        # Not supported by this printer.
        raise NotImplementedError()

    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
    # since built-in math functions are generic.
    # In CUDA, you have to differentiate between `sin` and `sinf`
    try:
        _print_math_func = CustomSympyPrinter._print_math_func
    except AttributeError:
        # Older CustomSympyPrinter versions lack this hook; keep the inherited one.
        pass
    _print_Pow = CustomSympyPrinter._print_Pow

    def _print_Function(self, expr):
        """Print a function call, mapping fast approximations to OpenCL ``native_*`` built-ins.

        ``fast_division`` -> ``native_divide``, ``fast_sqrt`` -> ``native_sqrt``,
        ``fast_inv_sqrt`` -> ``native_rsqrt``; every other function is printed
        by ``CustomSympyPrinter``.
        """
        if isinstance(expr, fast_division):
            return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args)
        elif isinstance(expr, fast_sqrt):
            # Join the printed arguments: formatting a tuple would emit its
            # Python repr, e.g. "native_sqrt(('x',))", which is not valid OpenCL.
            return f"native_sqrt({', '.join(self._print(a) for a in expr.args)})"
        elif isinstance(expr, fast_inv_sqrt):
            return f"native_rsqrt({', '.join(self._print(a) for a in expr.args)})"
        return CustomSympyPrinter._print_Function(self, expr)