from os.path import dirname, join

from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode

with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
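    # cuda_known_functions.txt lists one known CUDA function name per line; each
    # name is mapped to itself so that calls to these functions are printed
    # unchanged by the printer.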

    lines = f.readlines()
    CUDA_KNOWN_FUNCTIONS = {line.strip(): line.strip() for line in lines if line.strip()}


def generate_cuda(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
    """Generates CUDA code from an abstract syntax tree (AST) node.

    Args:
        ast_node: AST representation of the kernel
        signature_only: generate only the function signature, without the body
        custom_backend: optional custom printer to use for code generation
        with_globals: enable usage of global variables

    Returns:
        CUDA code for the AST node and its descendants
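
    Example:
        A minimal sketch, assuming a kernel AST built with
        ``pystencils.create_kernel`` for a GPU target (the field and assignment
        setup here is an illustrative placeholder, not part of this module)::

            import pystencils as ps

            f, g = ps.fields("f, g: double[2D]")
            ast = ps.create_kernel([ps.Assignment(g.center, 2 * f.center)], target='gpu')
            cuda_code = generate_cuda(ast)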

24 """ 

    return generate_c(ast_node, signature_only, dialect='cuda',
                      custom_backend=custom_backend, with_globals=with_globals)


class CudaBackend(CBackend):

    def __init__(self, sympy_printer=None,
                 signature_only=False):
        if not sympy_printer:
            sympy_printer = CudaSympyPrinter()

        super().__init__(sympy_printer, signature_only, dialect='cuda')

    def _print_SharedMemoryAllocation(self, node):
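        # Emits a static shared-memory declaration, e.g.
        # "__shared__ double sharedMem[16*16];" (type, name and extents are
        # illustrative).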

        dtype = node.symbol.dtype
        name = self.sympy_printer.doprint(node.symbol.name)
        num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
        code = f"__shared__ {dtype} {name}[{num_elements}];"
        return code

    @staticmethod
    def _print_ThreadBlockSynchronization(node):
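        # __syncthreads() is a block-wide barrier: every thread of the block must
        # reach it before any thread is allowed to continue.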

        code = "__syncthreads();"
        return code

    def _print_TextureDeclaration(self, node):
        # Element types wider than 4 bytes (e.g. double) cannot be read from CUDA
        # textures directly; PyCuda's fp_tex_* helper types are used instead
        # (see _print_InterpolatorAccess below).
        if node.texture.field.dtype.numpy_dtype.itemsize > 4:
            code = f"texture<fp_tex_{node.texture.field.dtype}, " \
                   f"cudaTextureType{node.texture.field.spatial_dimensions}D, " \
                   f"cudaReadModeElementType> {node.texture};"
        else:
            code = f"texture<{node.texture.field.dtype}, " \
                   f"cudaTextureType{node.texture.field.spatial_dimensions}D, " \
                   f"cudaReadModeElementType> {node.texture};"
        return code

    def _print_SkipIteration(self, _):
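        # Each GPU thread processes a single iteration point, so skipping the
        # current iteration amounts to returning from the kernel.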

        return "return;"


class CudaSympyPrinter(CustomSympyPrinter):
    language = "CUDA"

    def __init__(self):
        super().__init__()
        self.known_functions.update(CUDA_KNOWN_FUNCTIONS)

    def _print_InterpolatorAccess(self, node):
        dtype = node.interpolator.field.dtype.numpy_dtype

        if type(node) is DiffInterpolatorAccess:
            # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
            diff_coordinate = list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]
            template = f"cubicTex%iD_1st_derivative_{diff_coordinate}(%s, %s)"
        elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
            template = "cubicTex%iDSimple(%s, %s)"
        else:
            if dtype.itemsize > 4:
                # Use PyCuda hack!
                # https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
                template = "fp_tex%iD(%s, %s)"
            else:
                template = "tex%iD(%s, %s)"

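        # Renders e.g. "tex2D(tex_f, y + 0.5, x + 0.5)" for a 2D single-precision
        # texture (texture name is illustrative; note that the offsets are printed
        # in reversed order).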

        code = template % (
            node.interpolator.field.spatial_dimensions,
            str(node.interpolator),
            # + 0.5 comes from Nvidia's staggered indexing
            ', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
        )
        return code

    def _print_Function(self, expr):
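        # Map pystencils' fast-math wrappers to CUDA's single-precision
        # intrinsics.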

        if isinstance(expr, fast_division):
            assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} were given"
            return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif isinstance(expr, fast_sqrt):
            assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} were given"
            return f"__fsqrt_rn({self._print(expr.args[0])})"
        elif isinstance(expr, fast_inv_sqrt):
            assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} were given"
            return f"__frsqrt_rn({self._print(expr.args[0])})"
        return super()._print_Function(expr)