1# -*- coding: utf-8 -*- 

2# 

3# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> 

4# 

5# Distributed under terms of the GPLv3 license. 

6 

7""" 

8 

9""" 

10 

11import hashlib 

12import itertools 

13from enum import Enum 

14from typing import Set 

15 

16import sympy as sp 

17from sympy.core.cache import cacheit 

18 

19import pystencils 

20from pystencils.astnodes import Node 

21from pystencils.data_types import TypedSymbol, cast_func, create_type 

22 

23try: 

24 import pycuda.driver 

25except Exception: 

26 pass 

27 

28_hash = hashlib.md5 

29 

30 

class InterpolationMode(str, Enum):
    """Interpolation schemes for :class:`Interpolator`; members compare equal to their string values."""
    NEAREST_NEIGHBOR = "nearest_neighbour"
    NN = "nearest_neighbour"  # same value as NEAREST_NEIGHBOR, hence an alias of that member
    LINEAR = "linear"
    CUBIC_SPLINE = "cubic_spline"

36 

37 

class _InterpolationSymbol(TypedSymbol):
    """
    Placeholder symbol carrying a field together with the interpolator that operates on it.

    The value passed to ``TypedSymbol`` as its type is the string ``'dummy_symbol_carrying_field' + field.name``.
    Construction goes through sympy's ``cacheit``, keyed on ``(name, field, interpolator)``.
    """

    def __new__(cls, name, field, interpolator):
        return cls.__xnew_cached_(cls, name, field, interpolator)

    def __new_stage2__(cls, name, field, interpolator):
        dummy_dtype = 'dummy_symbol_carrying_field' + field.name
        instance = super().__xnew__(cls, name, dummy_dtype)
        instance.field = field
        instance.interpolator = interpolator
        return instance

    def __getnewargs__(self):
        # arguments for re-creating the symbol when unpickling
        return self.name, self.field, self.interpolator

    def __getnewargs_ex__(self):
        return self.__getnewargs__(), {}

    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

60 

61 

class Interpolator(object):
    """
    Implements non-integer accesses on fields using interpolation.

    The interpolation scheme is selected by an :class:`InterpolationMode`
    (see e.g. :class:`LinearInterpolator` and :class:`NearestNeightborInterpolator`).

    On GPU, this interpolator can be implemented by a :class:`.TextureCachedField` for hardware acceleration.

    The address mode selects the boundary handling. The possible choices are the same as for CUDA textures:

    **CLAMP**

    The signal c[k] is continued outside k=0,...,M-1 so that c[k] = c[0] for k < 0, and c[k] = c[M-1] for k >= M.

    **BORDER**

    The signal c[k] is continued outside k=0,...,M-1 so that c[k] = 0 for k < 0 and for k >= M.

    Now, to describe the last two address modes, we are forced to consider normalized coordinates,
    so that the 1D input signal samples are assumed to be c[k / M], with k=0,...,M-1.

    **WRAP**

    The signal c[k / M] is continued outside k=0,...,M-1 so that it is periodic with period equal to M.
    In other words, c[(k + p * M) / M] = c[k / M] for any (positive, negative or vanishing) integer p.

    **MIRROR**

    The signal c[k / M] is continued outside k=0,...,M-1 so that it is periodic with period equal to 2 * M - 2.
    In other words, c[l / M] = c[k / M] for any l and k such that (l + k) mod (2 * M - 2) = 0.

    Explanations from https://stackoverflow.com/questions/19020963/the-different-addressing-modes-of-cuda-textures

    Args:
        parent_field: the field to interpolate. Note that its ``field_type`` is set to ``FieldType.CUSTOM``.
        interpolation_mode: the :class:`InterpolationMode` to use
        address_mode: boundary handling, one of the address modes described above
        use_normalized_coordinates: flag stored on the interpolator; part of its identity (hash)
        allow_textures: if False, :meth:`TextureCachedField.from_interpolator` returns this interpolator unchanged
    """

    # Global declarations a kernel using this interpolator needs. Empty here (shared, never appended to);
    # TextureCachedField replaces it by a per-instance list.
    required_global_declarations = []

    def __init__(self,
                 parent_field,
                 interpolation_mode: InterpolationMode,
                 address_mode='BORDER',
                 use_normalized_coordinates=False,
                 allow_textures=True):
        super().__init__()

        self.field = parent_field
        # NOTE(review): presumably required so that field.absolute_access (used by InterpolatorAccess) is allowed
        self.field.field_type = pystencils.field.FieldType.CUSTOM
        self.address_mode = address_mode
        self.use_normalized_coordinates = use_normalized_coordinates
        self.interpolation_mode = interpolation_mode
        self.hash_str = _hash(
            f'{self.field}_{address_mode}_{self.field.dtype}_{interpolation_mode}'.encode()).hexdigest()
        # str(self) depends on the attributes above, so the symbol must be created last
        self.symbol = _InterpolationSymbol(str(self), parent_field, self)
        self.allow_textures = allow_textures

    @property
    def ndim(self):
        return self.field.ndim

    @property
    def _hashable_contents(self):
        return (str(self.address_mode),
                str(type(self)),
                self.hash_str,
                self.use_normalized_coordinates)

    @staticmethod
    def _sympify_offsets(offset):
        """Convert ``offset`` (an iterable of coordinates or a single coordinate) to a list of sympy objects."""
        try:
            coordinates = iter(offset)
        except TypeError:
            # a single, non-iterable coordinate, e.g. ``interpolator[x]`` for a 1D field
            coordinates = iter((offset,))
        return [sp.S(o) for o in coordinates]

    def at(self, offset):
        """Return an :class:`InterpolatorAccess` at the (possibly non-integer) position ``offset``."""
        return InterpolatorAccess(self.symbol, *self._sympify_offsets(offset))

    def __getitem__(self, offset):
        return self.at(offset)

    def __str__(self):
        return f'{self.field.name}_interpolator_{self.reproducible_hash}'

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self._hashable_contents)

    def __eq__(self, other):
        # Compare the identifying contents instead of hashes: equal hashes do not imply equality and
        # hashing ``other`` would fail for unhashable objects.
        if not isinstance(other, Interpolator):
            return NotImplemented
        return self._hashable_contents == other._hashable_contents

    @property
    def reproducible_hash(self):
        """Hash of the identifying contents that, unlike ``hash()``, is stable across interpreter runs."""
        return _hash(str(self._hashable_contents).encode()).hexdigest()

146 

147 

class LinearInterpolator(Interpolator):
    """:class:`Interpolator` that uses :attr:`InterpolationMode.LINEAR`."""

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        super().__init__(parent_field=parent_field,
                         interpolation_mode=InterpolationMode.LINEAR,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)

158 

159 

class NearestNeightborInterpolator(Interpolator):
    """:class:`Interpolator` that uses :attr:`InterpolationMode.NN` (nearest neighbor)."""

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        super().__init__(parent_field=parent_field,
                         interpolation_mode=InterpolationMode.NN,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)

170 

171 

class InterpolatorAccess(TypedSymbol):
    """
    Symbolic access to an interpolated field at (possibly non-integer) ``offsets``.

    Instances are usually created via :meth:`Interpolator.at`. The access is a :class:`TypedSymbol` with the
    dtype of the interpolated field; :meth:`implementation_with_stencils` expands it into absolute field accesses.
    """

    def __new__(cls, field, *offsets):
        # ``field`` is the interpolator's ``_InterpolationSymbol`` (parameter name kept for compatibility)
        return InterpolatorAccess.__xnew_cached_(cls, field, *offsets)

    def __new_stage2__(cls, symbol, *offsets):
        # the name hashes the offsets so that accesses at different positions get different names
        name = '%s_interpolator_%s' % (symbol.field.name, _hash(str(tuple(offsets)).encode()).hexdigest())
        obj = super().__xnew__(cls, name, symbol.field.dtype)
        obj.offsets = offsets
        obj.symbol = symbol
        obj.field = symbol.field
        obj.interpolator = symbol.interpolator
        return obj

    def _hashable_content(self):
        # sympy bases ``==`` and ``hash`` on this method. Including the interpolator distinguishes accesses
        # to the same field at the same offsets through different interpolators (e.g. other address modes).
        return super()._hashable_content() + ((self.symbol, self.field, tuple(self.offsets), self.symbol.interpolator))

    def _hashable_contents(self):
        """Misspelled alias of :meth:`_hashable_content`, kept for backwards compatibility."""
        return self._hashable_content()

    def __str__(self):
        return f"{self.field.name}_interpolator({', '.join(str(o) for o in self.offsets)})"

    def __repr__(self):
        return self.__str__()

    def _latex(self, printer, *_):
        name = self.field.latex_name if self.field.latex_name else self.field.name
        coordinates = ", ".join(str(printer.doprint(o)) for o in self.offsets)
        return f'{name}_{{interpolator}}\\left({coordinates}\\right)'

    @property
    def ndim(self):
        return len(self.offsets)

    @property
    def is_texture(self):
        return isinstance(self.interpolator, TextureCachedField)

    def atoms(self, *types):
        """
        Return the objects of the given ``types`` contained in this access (itself and its offsets).

        As in sympy, ``types`` may contain classes or instances and defaults to atomic expressions.
        """
        if types:
            types = tuple(t if isinstance(t, type) else type(t) for t in types)
        else:
            types = (sp.Atom,)

        result = {self} if isinstance(self, types) else set()
        for o in self.offsets:
            if isinstance(o, types):
                result.add(o)
            if hasattr(o, 'atoms'):
                result.update(o.atoms(*types))
        return result

    def neighbor(self, coord_id, offset):
        """Return the access shifted by ``offset`` along coordinate ``coord_id``."""
        offset_list = list(self.offsets)
        offset_list[coord_id] += offset
        return self.interpolator.at(tuple(offset_list))

    @property
    def free_symbols(self):
        """Free symbols of the offsets (the access itself is not included)."""
        symbols = set()
        for o in self.offsets or ():
            if hasattr(o, 'free_symbols'):
                symbols.update(o.free_symbols)
        return symbols

    @property
    def required_global_declarations(self):
        """
        Global declarations required by the interpolator of this access.

        Side effect: if there are any, this access is added to the symbols defined by the first one.
        """
        declarations = self.symbol.interpolator.required_global_declarations
        if declarations:
            declarations[0]._symbols_defined.add(self)
        return declarations

    @property
    def args(self):
        # sympy expects a tuple such that ``self.func(*self.args)`` rebuilds the access
        return (self.symbol, *self.offsets)

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        return {self}

    @property
    def interpolation_mode(self):
        return self.interpolator.interpolation_mode

    @property
    def _diff_interpolation_vec(self):
        return sp.Matrix([DiffInterpolatorAccess(self.symbol, i, *self.offsets)
                          for i in range(len(self.offsets))])

    def diff(self, *symbols, **kwargs):
        """Chain rule: gradient of the interpolated field times the derivative of the offsets."""
        if symbols == (self,):
            return 1
        result = self._diff_interpolation_vec.T * sp.Matrix(self.offsets).diff(*symbols, **kwargs)
        if result.shape == (1, 1):
            result = result[0, 0]
        return result

    def implementation_with_stencils(self):
        """
        Express this access by absolute accesses to the integer cells of :attr:`field`.

        With a :class:`TextureCachedField` interpolator, boundary handling is left to the hardware:
        nearest-neighbor accesses are kept as they are and linear accesses become weighted texture accesses.
        Otherwise, boundary handling is done in software according to the interpolator's address mode.

        Returns:
            a sympy expression, or a :class:`sympy.Matrix` with one entry per channel if the field has
            index dimensions

        Raises:
            NotImplementedError: for cubic spline interpolation and for unknown address modes
        """
        field = self.field
        use_textures = isinstance(self.interpolator, TextureCachedField)
        if use_textures:
            def absolute_access(x, _):
                return self.symbol.interpolator.at((o for o in x))
        else:
            absolute_access = field.absolute_access

        # Channels are enumerated along the first *index* dimension
        # (``field.shape[0]`` would be the extent of the first spatial dimension).
        num_channels = field.index_shape[0] if field.index_dimensions else 1
        address_mode = str(self.interpolator.address_mode).lower()
        offsets = self.offsets
        rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)

        channel_sums = [0] * num_channels
        for channel_idx in range(num_channels):
            # fields without index dimensions are accessed without an index
            index_arg = channel_idx if field.index_dimensions else ()

            if self.interpolation_mode == InterpolationMode.NN:
                if use_textures:
                    channel_sums[channel_idx] = self
                else:
                    channel_sums[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], index_arg)

            elif self.interpolation_mode == InterpolationMode.LINEAR:
                # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
                # sum over the 2**dim surrounding cells (floor(x) or floor(x) + 1 per coordinate)
                for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
                    weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
                    index = [f(offset) for (f, offset) in zip(c, offsets)]
                    if use_textures:
                        # hardware boundary handling on GPU
                        channel_sums[channel_idx] += weight * absolute_access(index, index_arg)
                    else:
                        channel_sums[channel_idx] += self._software_boundary_term(weight, index, index_arg,
                                                                                  address_mode)

            elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
                raise NotImplementedError("only works with HW interpolation for float32")

        channel_sums = [sp.factor(s) for s in channel_sums]

        if field.index_dimensions:
            return sp.Matrix(channel_sums)
        else:
            return channel_sums[0]

    def _software_boundary_term(self, weight, index, index_arg, address_mode):
        """Return ``weight`` times the field value at integer position ``index`` with software boundary handling.

        ``address_mode`` is the lower-cased address mode of the interpolator.
        """
        field = self.field
        default_int_type = create_type('int64')

        if address_mode == 'border':
            # zero outside of the field
            is_inside_field = sp.And(
                *itertools.chain([i >= 0 for i in index],
                                 [idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
            index = [cast_func(i, default_int_type) for i in index]
            return sp.Piecewise(
                (weight * field.absolute_access(index, index_arg), is_inside_field),
                (sp.simplify(0), True)
            )

        if address_mode == 'clamp':
            index = [sp.Min(sp.Max(0, cast_func(i, default_int_type)), field.spatial_shape[dim] - 1)
                     for (dim, i) in enumerate(index)]
            return weight * field.absolute_access(index, index_arg)

        if address_mode == 'wrap':
            # NOTE(review): with M = field.shape[dim] this wraps with period M - 1 and maps index 0 to M - 1,
            # whereas the Interpolator documentation describes period M -- confirm which is intended.
            index = [sp.Mod(cast_func(i, default_int_type), field.shape[dim] - 1)
                     for (dim, i) in enumerate(index)]
            # presumably guards against negative remainders of the generated code's ``%`` -- confirm
            index = [cast_func(sp.Piecewise((i, i > 0),
                                            (sp.Abs(cast_func(field.shape[dim] - 1 + i, default_int_type)),
                                             True)), default_int_type)
                     for (dim, i) in enumerate(index)]
            return weight * field.absolute_access(index, index_arg)

        if address_mode == 'mirror':
            # NOTE(review): this reflects with period 2 * M (edge sample repeated at the upper border only,
            # negative indices are mirrored via Abs without repeating index 0), whereas the Interpolator
            # documentation describes period 2 * M - 2 -- confirm which is intended.
            def triangle_fun(x, half_period):
                saw_tooth = cast_func(sp.Abs(cast_func(x, 'int32')), 'int32') % (
                    cast_func(2 * half_period, create_type('int32')))
                return sp.Piecewise((saw_tooth, saw_tooth < half_period),
                                    (2 * half_period - 1 - saw_tooth, True))

            index = [cast_func(triangle_fun(i, field.shape[dim]), default_int_type)
                     for (dim, i) in enumerate(index)]
            return weight * field.absolute_access(index, index_arg)

        raise NotImplementedError()

    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.symbol, *self.offsets)

    def __getnewargs_ex__(self):
        return (self.symbol, *self.offsets), {}

362 

363 

class DiffInterpolatorAccess(InterpolatorAccess):
    """
    Symbolic derivative of an interpolated field along coordinate ``diff_coordinate_idx`` at ``offsets``.

    For linear interpolators no such symbol is created: ``__new__`` instead returns
    ``Discretization2ndOrder(1)`` applied to ``Diff`` of the corresponding interpolator access.
    """

    def __new__(cls, symbol, diff_coordinate_idx, *offsets):
        if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
            from pystencils.fd import Diff, Discretization2ndOrder
            return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
        return DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets)

    def __new_stage2__(cls, symbol: sp.Symbol, diff_coordinate_idx, *offsets):
        obj = super().__xnew__(cls, symbol, *offsets)
        obj.diff_coordinate_idx = diff_coordinate_idx
        return obj

    def _hashable_content(self):
        # sympy's ``==`` uses this; include the derivative direction so that it agrees with ``__hash__``
        return super()._hashable_content() + (self.diff_coordinate_idx,)

    def __hash__(self):
        return hash((self.symbol, self.field, self.diff_coordinate_idx, tuple(self.offsets), self.interpolator))

    def __str__(self):
        return '%s_diff%i_interpolator(%s)' % (self.field.name, self.diff_coordinate_idx,
                                               ', '.join(str(o) for o in self.offsets))

    def __repr__(self):
        return str(self)

    @property
    def args(self):
        # sympy expects a tuple such that ``self.func(*self.args)`` rebuilds the access
        return (self.symbol, self.diff_coordinate_idx, *self.offsets)

    # noinspection SpellCheckingInspection
    __xnew__ = staticmethod(__new_stage2__)
    # noinspection SpellCheckingInspection
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.symbol, self.diff_coordinate_idx, *self.offsets)

    def __getnewargs_ex__(self):
        return (self.symbol, self.diff_coordinate_idx, *self.offsets), {}

410 

411 

#######################################################################################################
# GPU-specific fast specializations (for higher precision, GPUs can also use the nodes/symbols above) #
#######################################################################################################

415 

416 

class TextureCachedField(Interpolator):
    """
    :class:`Interpolator` implemented by a CUDA texture (hardware interpolation and boundary handling).

    Args:
        parent_field: the field read through the texture
        address_mode: boundary handling (see :class:`Interpolator`); ``None`` means ``'border'``
        filter_mode: texture filter mode; ``None`` means ``pycuda.driver.filter_mode.LINEAR`` (requires pycuda)
        interpolation_mode: the :class:`InterpolationMode` to use
        use_normalized_coordinates: see :class:`Interpolator`
        read_as_integer: stored as ``read_as_integer`` (presumably consumed when the texture is
            declared/printed -- confirm)
    """

    def __init__(self, parent_field,
                 address_mode=None,
                 filter_mode=None,
                 interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
                 use_normalized_coordinates=False,
                 read_as_integer=False
                 ):
        # Resolve the defaults before initializing the base class so that ``self.address_mode`` and the
        # reproducible hash reflect the effective address mode.
        if address_mode is None:
            address_mode = 'border'
        if filter_mode is None:
            filter_mode = pycuda.driver.filter_mode.LINEAR

        super().__init__(parent_field, interpolation_mode, address_mode, use_normalized_coordinates)

        self.filter_mode = filter_mode
        self.read_as_integer = read_as_integer
        self.required_global_declarations = [TextureDeclaration(self)]

    @classmethod
    def from_interpolator(cls, interpolator: LinearInterpolator):
        """
        Create a texture-based equivalent of ``interpolator``.

        ``interpolator`` is returned unchanged if it already is an instance of ``cls`` or if it
        disallows textures (``allow_textures`` is False).
        """
        if (isinstance(interpolator, cls)
                or (hasattr(interpolator, 'allow_textures') and not interpolator.allow_textures)):
            return interpolator
        return cls(interpolator.field,
                   interpolator.address_mode,
                   interpolation_mode=interpolator.interpolation_mode,
                   use_normalized_coordinates=getattr(interpolator, 'use_normalized_coordinates', False))

    def __str__(self):
        return f'{self.field.name}_texture_{self.reproducible_hash}'

    def __repr__(self):
        return self.__str__()

457 

458 

class TextureDeclaration(Node):
    """
    A global declaration of a texture. Visible both for device and host code.

    .. code:: cpp

        // This Node represents the following global declaration
        texture<float, cudaTextureType2D, cudaReadModeElementType> x_texture_5acc9fced7b0dc3e;

        __device__ kernel(...) {
            // kernel accesses x_texture_5acc9fced7b0dc3e with tex2D(...)
        }

        __host__ launch_kernel(...) {
            // Host needs to bind the texture
            cudaBindTexture(0, x_texture_5acc9fced7b0dc3e, buffer, N*sizeof(float));
        }

    This has been deprecated by CUDA in favor of :class:`.TextureObject`.
    But texture objects are not yet supported by PyCUDA (https://github.com/inducer/pycuda/pull/174)

    Args:
        parent_texture: the texture interpolator (e.g. a :class:`TextureCachedField`) being declared
    """

    def __init__(self, parent_texture):
        # let the Node base class initialize its own state
        super().__init__()
        self.texture = parent_texture
        # Intentionally mutable: InterpolatorAccess.required_global_declarations adds accesses to this set.
        self._symbols_defined = {self.texture.symbol}

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        return self._symbols_defined

    @property
    def args(self) -> Set[sp.Symbol]:
        # a declaration has no child nodes
        return set()

    @property
    def headers(self):
        """Headers for the generated code; adds ``cubicTex<ndim>D.cu`` for cubic spline interpolation."""
        headers = ['"pycuda-helpers.hpp"']
        if self.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
            headers.append('"cubicTex%iD.cu"' % self.texture.ndim)
        return headers

    def __str__(self):
        from pystencils.backends.cuda_backend import CudaBackend
        return CudaBackend()(self)

    def __repr__(self):
        return str(self)

506 

507 

class TextureObject(TextureDeclaration):
    """
    A CUDA texture object.

    In contrast to a :class:`.TextureDeclaration`, it is not declared globally; instead it is passed to the
    kernel as a function argument. It inherits everything else from :class:`.TextureDeclaration`
    (including the symbols it defines) and is meant to differ only in its printed representation.
    """

517 

518 

def dtype_supports_textures(dtype):
    """
    Returns whether CUDA natively supports texture fetches with this numpy dtype.

    The maximum word size for a texture fetch is four bytes.

    With this trick also larger dtypes can be fetched:
    https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp

    Args:
        dtype: a pystencils type (anything with a ``numpy_dtype`` attribute), a numpy dtype, a numpy scalar
            type such as ``numpy.float32``, or a dtype name such as ``'float32'``

    Returns:
        True if one element of ``dtype`` occupies at most four bytes
    """
    if hasattr(dtype, 'numpy_dtype'):
        dtype = dtype.numpy_dtype

    if isinstance(dtype, str):
        # dtype names like 'float32' carry no itemsize; let numpy resolve them
        import numpy as np
        dtype = np.dtype(dtype)

    if isinstance(dtype, type):
        # the itemsize of a scalar type (e.g. numpy.float32) has to be read from an instance
        return dtype().itemsize <= 4

    return dtype.itemsize <= 4