1# -*- coding: utf-8 -*- 

2# 

3# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> 

4# 

5# Distributed under terms of the GPLv3 license. 

6 

7""" 

8 

9""" 

10from typing import Union 

11 

12import numpy as np 

13 

14try: 

15 import pycuda.driver as cuda 

16 from pycuda import gpuarray 

17 import pycuda 

18except Exception: 

19 pass 

20 

21 

def ndarray_to_tex(tex_ref,  # type: Union[cuda.TextureReference, cuda.SurfaceReference]
                   ndarray,
                   address_mode=None,
                   filter_mode=None,
                   use_normalized_coordinates=False,
                   read_as_integer=False):
    """Copy ``ndarray`` into a CUDA array and bind it to ``tex_ref``.

    Args:
        tex_ref: PyCUDA reference the array is bound to. It must provide
            ``set_array``, ``set_address_mode``, ``set_filter_mode``,
            ``get_flags`` and ``set_flags``.
            NOTE(review): the type comment also admits ``SurfaceReference``;
            confirm it provides these methods before relying on that.
        ndarray: ``numpy.ndarray`` (host) or ``pycuda.gpuarray.GPUArray``
            (device) holding the data; it is copied in C order.
        address_mode: a member of ``pycuda.driver.address_mode`` or its name
            as a case-insensitive string (e.g. ``'wrap'``). ``None`` selects
            ``BORDER``. The mode is applied to each of the first
            ``min(ndarray.ndim, 3)`` axes (axis 0 is always configured).
        filter_mode: a member of ``pycuda.driver.filter_mode`` or its name
            as a case-insensitive string (e.g. ``'point'``). ``None`` selects
            ``LINEAR``.
        use_normalized_coordinates: set ``TRSF_NORMALIZED_COORDINATES`` when
            true, clear it when false.
        read_as_integer: set ``TRSF_READ_AS_INTEGER`` when true, clear it
            when false.

    Raises:
        TypeError: if ``ndarray`` is neither a numpy array nor a GPUArray.
        AttributeError: if a string mode does not name a known mode.
    """
    if isinstance(address_mode, str):
        address_mode = getattr(cuda.address_mode, address_mode.upper())
    if address_mode is None:
        address_mode = cuda.address_mode.BORDER
    if isinstance(filter_mode, str):
        filter_mode = getattr(cuda.filter_mode, filter_mode.upper())
    if filter_mode is None:
        filter_mode = cuda.filter_mode.LINEAR

    if isinstance(ndarray, np.ndarray):
        cu_array = cuda.np_to_array(ndarray, 'C')
    elif isinstance(ndarray, gpuarray.GPUArray):
        cu_array = cuda.gpuarray_to_array(ndarray, 'C')
    else:
        raise TypeError(
            'ndarray must be numpy.ndarray or pycuda.gpuarray.GPUArray')

    tex_ref.set_array(cu_array)

    # Axis 0 is always configured; further axes up to the third follow ndim.
    tex_ref.set_address_mode(0, address_mode)
    for axis in range(1, min(ndarray.ndim, 3)):
        tex_ref.set_address_mode(axis, address_mode)
    tex_ref.set_filter_mode(filter_mode)

    # Set or clear each flag explicitly so that passing True actually
    # enables it (merely clearing on False left True a no-op).
    flags = tex_ref.get_flags()
    if use_normalized_coordinates:
        flags |= cuda.TRSF_NORMALIZED_COORDINATES
    else:
        flags &= ~cuda.TRSF_NORMALIZED_COORDINATES
    if read_as_integer:
        flags |= cuda.TRSF_READ_AS_INTEGER
    else:
        flags &= ~cuda.TRSF_READ_AS_INTEGER
    tex_ref.set_flags(flags)