1from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment 

2from pystencils.data_types import BasicType, StructType, TypedSymbol 

3from pystencils.field import Field, FieldType 

4from pystencils.gpucuda.cudajit import make_python_function 

5from pystencils.gpucuda.indexing import BlockIndexing 

6from pystencils.transformations import ( 

7 add_types, get_base_buffer_index, get_common_shape, implement_interpolations, 

8 parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols) 

9 

10 

def create_cuda_kernel(assignments,
                       function_name="kernel",
                       type_info=None,
                       indexing_creator=BlockIndexing,
                       iteration_slice=None,
                       ghost_layers=None,
                       skip_independence_check=False,
                       use_textures_for_interpolation=True):
    """Build a ``KernelFunction`` AST for a GPU kernel from a list of assignments.

    The assignments are typed via ``add_types``, prefixed with assignments that define the
    loop counter symbols from the coordinates supplied by the indexing object, wrapped in the
    guard produced by the indexing object, and finally have their buffer and field accesses
    resolved.

    Args:
        assignments: non-empty collection of assignments making up the kernel body.
        function_name: name handed to the ``KernelFunction`` node.
        type_info: forwarded unchanged to ``add_types``.
        indexing_creator: callable invoked as ``indexing_creator(field=..., iteration_slice=...)``.
            The returned object has to offer ``coordinates``, ``guard(block, shape)`` and
            ``index_variables``.
        iteration_slice: slices handed to ``indexing_creator``. If ``None`` the slices are
            derived from ``ghost_layers``.
        ghost_layers: either an ``int`` (same width on both sides of every dimension), a sequence
            with one ``(lower, upper)`` pair per dimension of the common shape, or ``None`` to
            use the largest ``required_ghost_layers`` of all non-absolute field accesses.
            Only consulted for the iteration region when ``iteration_slice`` is ``None``.
        skip_independence_check: its negation is passed as third argument to ``add_types``
            (presumably switching the independence check off - confirm against ``add_types``).
        use_textures_for_interpolation: forwarded to ``implement_interpolations`` as
            ``implement_by_texture_accesses``.

    Returns:
        The ``KernelFunction`` node; the indexing object is attached as ``ast.indexing``.

    Raises:
        AssertionError: if ``assignments`` is empty.
    """
    assert assignments, "Assignments must not be empty!"
    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    # buffer and custom fields are excluded when determining shape and indexing
    buffers = {f for f in all_fields if FieldType.is_buffer(f) or FieldType.is_custom(f)}
    fields_without_buffers = all_fields - buffers

    # collect every access first, then drop the absolute ones in a single pass
    field_accesses = set()
    for eq in assignments:
        field_accesses.update(eq.atoms(Field.Access))
    field_accesses = {acc for acc in field_accesses if not acc.is_absolute_access}

    common_shape = get_common_shape(fields_without_buffers)
    dim = len(common_shape)

    if iteration_slice is None:
        # determine iteration slice from ghost layers
        if ghost_layers is None:
            # determine required number of ghost layers from field access;
            # if only absolute accesses are present nothing constrains the region -> 0 ghost layers
            required_ghost_layers = max((fa.required_ghost_layers for fa in field_accesses), default=0)
            ghost_layers = [(required_ghost_layers, required_ghost_layers)] * dim
        elif isinstance(ghost_layers, int):
            ghost_layers = [(ghost_layers, ghost_layers)] * dim

        # a zero (or negative) upper width must map to an open slice end, since "-0" would yield an empty slice
        iteration_slice = [slice(ghost_layers[i][0],
                                 -ghost_layers[i][1] if ghost_layers[i][1] > 0 else None)
                           for i in range(dim)]

    reference_field = list(fields_without_buffers)[0]
    indexing = indexing_creator(field=reference_field, iteration_slice=iteration_slice)
    thread_coordinates = indexing.coordinates

    # define the loop counter symbols in terms of the coordinates provided by the indexing object
    cell_idx_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i, _ in enumerate(thread_coordinates)]
    cell_idx_assignments = [SympyAssignment(symbol, value)
                            for symbol, value in zip(cell_idx_symbols, thread_coordinates)]
    assignments = cell_idx_assignments + assignments

    block = Block(assignments)
    block = indexing.guard(block, common_shape)
    unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)

    ast = KernelFunction(block,
                         'gpu',
                         'gpucuda',
                         make_python_function,
                         ghost_layers,
                         function_name,
                         assignments=assignments)
    ast.global_variables.update(indexing.index_variables)

    implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)

    base_pointer_spec = [['spatialInner0']]
    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
                                                         f.spatial_dimensions, f.index_dimensions)
                         for f in all_fields}

    # every field is addressed through the loop counter symbols defined above
    field_to_coordinates = {f.name: cell_idx_symbols for f in all_fields}

    # read the shape only now, i.e. after unify_shape_symbols has been applied
    loop_strides = reference_field.shape

    if any(FieldType.is_buffer(f) for f in all_fields):
        resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, loop_strides), read_only_fields)

    resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                           field_to_fixed_coordinates=field_to_coordinates)

    # If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity),
    # they are defined here
    undefined_loop_counters = {LoopOverCoordinate.is_loop_counter_symbol(s): s for s in ast.body.undefined_symbols
                               if LoopOverCoordinate.is_loop_counter_symbol(s) is not None}
    for i, loop_counter in undefined_loop_counters.items():
        ast.body.insert_front(SympyAssignment(loop_counter, indexing.coordinates[i]))

    # add the function which determines #blocks and #threads as additional member to KernelFunction node
    # this is used by the jit
    ast.indexing = indexing
    return ast

103 

104 

def created_indexed_cuda_kernel(assignments,
                                index_fields,
                                function_name="kernel",
                                type_info=None,
                                coordinate_names=('x', 'y', 'z'),
                                indexing_creator=BlockIndexing,
                                use_textures_for_interpolation=True):
    """Build a ``KernelFunction`` AST for a GPU kernel that iterates over index fields.

    The spatial coordinates used to address the non-index fields are not taken from the
    indexing object directly; instead they are read from struct elements of the index fields
    (one element per entry of ``coordinate_names``).

    Args:
        assignments: assignments making up the kernel body; typed via ``add_types`` with
            ``check_independence_condition=False``.
        index_fields: fields to iterate over. Each one gets its ``field_type`` set to
            ``FieldType.INDEXED``, has to be one-dimensional and needs a struct data type.
        function_name: name handed to the ``KernelFunction`` node.
        type_info: forwarded unchanged to ``add_types``.
        coordinate_names: struct element names holding the spatial coordinates; only the first
            ``d`` names are used, where ``d`` is the number of spatial dimensions shared by
            all non-index fields.
        indexing_creator: callable invoked as ``indexing_creator(field=..., iteration_slice=...)``
            with the first index field. The returned object has to offer ``coordinates``,
            ``guard(block, shape)`` and ``index_variables``.
        use_textures_for_interpolation: forwarded to ``implement_interpolations`` as
            ``implement_by_texture_accesses``.

    Returns:
        The ``KernelFunction`` node; the indexing object is attached as ``ast.indexing``.

    Raises:
        ValueError: if a required coordinate name is not an element of any index field.
        AssertionError: if an index field is not 1D / not struct typed, or if the non-index
            fields disagree on their number of spatial dimensions.
    """
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {field.name for field in fields_read - fields_written}

    # mark the passed fields as index fields and validate them
    for idx_f in index_fields:
        idx_f.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(idx_f)
        assert idx_f.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [field for field in all_fields if field not in index_fields]
    dimensionalities = {field.spatial_dimensions for field in non_index_fields}
    assert len(dimensionalities) == 1, "Non-index fields do not have the same number of spatial coordinates"
    num_spatial_dims = list(dimensionalities)[0]

    # for each needed coordinate, load it from the first index field whose struct provides it
    coordinate_symbol_assignments = []
    for coord_name in coordinate_names[:num_spatial_dims]:
        for candidate in index_fields:
            assert isinstance(candidate.dtype, StructType), "Index fields have to have a struct data type"
            struct_type = candidate.dtype
            if not struct_type.has_element(coord_name):
                continue
            loaded_value = candidate[0](coord_name)
            target_symbol = TypedSymbol(coord_name, BasicType(struct_type.get_element_type(coord_name)))
            coordinate_symbol_assignments.append(SympyAssignment(target_symbol, loaded_value))
            break
        else:
            raise ValueError(f"Index {coord_name} not found in any of the passed index fields")
    coordinate_typed_symbols = [assignment.lhs for assignment in coordinate_symbol_assignments]

    # the indexing covers the full extent of the first index field
    first_index_field = list(index_fields)[0]
    full_extent = [slice(None, None, None)] * len(first_index_field.spatial_shape)
    indexing = indexing_creator(field=first_index_field, iteration_slice=full_extent)

    unguarded_body = Block(coordinate_symbol_assignments + assignments)
    function_body = indexing.guard(unguarded_body, get_common_shape(index_fields))
    ast = KernelFunction(function_body,
                         'gpu',
                         'gpucuda',
                         make_python_function,
                         None,
                         function_name,
                         assignments=assignments)
    ast.global_variables.update(indexing.index_variables)

    implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)

    thread_coordinates = indexing.coordinates
    base_pointer_spec = [['spatialInner0']]
    base_pointer_info = {}
    for field in all_fields:
        base_pointer_info[field.name] = parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
                                                                field.spatial_dimensions, field.index_dimensions)

    # index fields are addressed with the indexing coordinates,
    # all other fields with the coordinates loaded from the index fields
    field_to_coordinates = {}
    for field in index_fields:
        field_to_coordinates[field.name] = thread_coordinates
    for field in non_index_fields:
        field_to_coordinates[field.name] = coordinate_typed_symbols

    resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=field_to_coordinates,
                           field_to_base_pointer_info=base_pointer_info)

    # add the function which determines #blocks and #threads as additional member to KernelFunction node
    # this is used by the jit
    ast.indexing = indexing
    return ast