1import ctypes 

2from collections import defaultdict 

3from functools import partial 

4from typing import Tuple 

5 

6import numpy as np 

7import sympy as sp 

8import sympy.codegen.ast 

9from sympy.core.cache import cacheit 

10from sympy.logic.boolalg import Boolean, BooleanFunction 

11 

12import pystencils 

13from pystencils.cache import memorycache, memorycache_if_hashable 

14from pystencils.utils import all_equal 

15 

16try: 

17 import llvmlite.ir as ir 

18except ImportError as e: 

19 ir = None 

20 _ir_importerror = e 

21 

22 

def typed_symbols(names, dtype, *args):
    """Create :class:`TypedSymbol` s of type ``dtype`` from a ``sp.symbols``-style specification.

    Args:
        names: anything accepted by ``sp.symbols`` (e.g. ``"x, y"``, ``"a:3"`` or a list of such strings)
        dtype: data type given to every created symbol
        *args: forwarded positionally to ``sp.symbols``

    Returns:
        A single TypedSymbol if ``sp.symbols`` produced a single symbol; otherwise a container of the
        same kind (tuple or list, possibly nested) holding the TypedSymbols in the same order.
    """
    def to_typed(obj):
        # sp.symbols may return a Symbol, a tuple, or (for sequence input) a list that can nest tuples
        if isinstance(obj, (tuple, list)):
            return type(obj)(to_typed(item) for item in obj)
        return TypedSymbol(str(obj), dtype)

    return to_typed(sp.symbols(names, *args))

29 

30 

def type_all_numbers(expr, dtype):
    """Return ``expr`` with every numeric literal wrapped in a :class:`cast_func` to ``dtype``."""
    replacements = {}
    for number in expr.atoms(sp.Number):
        replacements[number] = cast_func(number, dtype)
    return expr.subs(replacements)

34 

35 

def matrix_symbols(names, dtype, rows, cols):
    """Create one ``rows`` x ``cols`` sympy Matrix of TypedSymbols per base name.

    Args:
        names: comma separated string (spaces are ignored) or an iterable of base names
        dtype: data type of every matrix entry
        rows: number of rows
        cols: number of columns

    Returns:
        tuple with one matrix per name; the entries of the matrix for base name ``n`` are the symbols
        ``n0, n1, ...`` laid out in row-major order
    """
    if isinstance(names, str):
        names = names.replace(' ', '').split(',')

    def build(base_name):
        entries = typed_symbols("%s:%i" % (base_name, rows * cols), dtype)
        return sp.Matrix(rows, cols, lambda r, c: entries[r * cols + c])

    return tuple(build(base_name) for base_name in names)

46 

47 

def assumptions_from_dtype(dtype):
    """Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype

    Integer dtypes yield ``integer=True``, unsigned integer dtypes additionally ``negative=False``,
    and integer or floating point dtypes yield ``real=True``. Anything numpy cannot interpret
    yields no assumptions at all.

    Args:
        dtype (BasicType, np.dtype): a Numpy data type, or any object exposing ``numpy_dtype``
    Returns:
        A dict of SymPy assumptions
    """
    dtype = getattr(dtype, 'numpy_dtype', dtype)

    flags = {}
    try:
        is_integer = np.issubdtype(dtype, np.integer)
        if is_integer:
            flags['integer'] = True
        if np.issubdtype(dtype, np.unsignedinteger):
            flags['negative'] = False
        if is_integer or np.issubdtype(dtype, np.floating):
            flags['real'] = True
    except Exception:
        # best effort: e.g. plain type-name strings numpy does not know simply get no assumptions
        pass
    return flags

75 

76 

77# noinspection PyPep8Naming 

class address_of(sp.Function):
    """Symbolic address-of operation applied to its single argument.

    Its ``dtype`` is a restrict pointer to the argument's dtype, or to ``'void'`` if the argument
    has no dtype.
    """
    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        target = self.args[0]
        if not hasattr(target, 'canonical'):
            raise NotImplementedError()
        return target.canonical

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    @property
    def dtype(self):
        target = self.args[0]
        pointee = target.dtype if hasattr(target, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)

102 

103 

104# noinspection PyPep8Naming 

class cast_func(sp.Function):
    """Symbolic cast of ``args[0]`` to the data type ``args[1]``.

    Subclasses (e.g. :class:`vector_memory_access`) may pass further arguments after these two.
    The assumption predicates (``is_integer``, ``is_negative``, ``is_nonnegative``, ``is_real``)
    are refined from the target dtype whenever it carries a numpy dtype.
    """
    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # Fewer than two arguments raise ValueError from the unpacking below.
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        # NOTE(review): this also replaces subclasses (e.g. reinterpret_cast_func) by boolean_cast_func when the
        # argument is boolean - confirm that is intended.
        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # Evaluates the wrapped expression only: the cast is dropped and the requested precision is not forwarded.
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target data type of the cast."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        ``False`` for unsigned integer target dtypes, otherwise SymPy's default.
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        ``True`` whenever :attr:`is_negative` is ``False``, otherwise SymPy's default.
        See :func:`.TypedSymbol.is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        ``True`` for integer and floating point target dtypes, otherwise SymPy's default.
        See :func:`.TypedSymbol.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real

189 

190 

191# noinspection PyPep8Naming 

class boolean_cast_func(cast_func, Boolean):
    """Boolean flavour of :class:`cast_func`.

    :class:`cast_func` switches to this class when the casted expression is boolean, so the cast can be
    used where SymPy expects a Boolean (e.g. as a ``sp.Piecewise`` condition).
    """

194 

195 

196# noinspection PyPep8Naming 

class vector_memory_access(cast_func):
    """Vectorized load/store expressed as a cast.

    Exactly six arguments: read/write expression, type, aligned, nontemporal, mask (or none), stride.
    """
    nargs = (6,)

201 

202# noinspection PyPep8Naming 

class reinterpret_cast_func(cast_func):
    """Cast marked as a reinterpretation rather than a value conversion.

    NOTE(review): presumably printed like a bit-level reinterpretation by the code generator - confirm there.
    """

205 

206 

207# noinspection PyPep8Naming 

class pointer_arithmetic_func(sp.Function, Boolean):
    """Symbolic pointer arithmetic expression; its canonical form is that of its first argument."""

    @property
    def canonical(self):
        first = self.args[0]
        if not hasattr(first, 'canonical'):
            raise NotImplementedError()
        return first.canonical

215 

216 

class TypedSymbol(sp.Symbol):
    """A SymPy Symbol that additionally carries a data type (``dtype``).

    SymPy assumptions (integer / negative / real) are derived from the dtype via
    :func:`assumptions_from_dtype`; explicitly passed keyword assumptions override them.
    Construction goes through SymPy's ``cacheit``, so identical constructor calls are memoized.
    """

    def __new__(cls, *args, **kwds):
        return TypedSymbol.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, name, dtype, **kwargs):
        symbol_assumptions = assumptions_from_dtype(dtype)
        symbol_assumptions.update(kwargs)  # explicit assumptions win over derived ones
        instance = super(TypedSymbol, cls).__xnew__(cls, name, **symbol_assumptions)
        try:
            instance._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # dtype cannot be converted into a Type object: keep it exactly as given (e.g. a plain string)
            instance._dtype = dtype
        return instance

    # uncached and cached constructor entry points
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        return self._dtype

    def _hashable_content(self):
        # the dtype is part of the identity: equal names with different dtypes are different symbols
        symbol_content = super()._hashable_content()
        return symbol_content, hash(self._dtype)

    def __getnewargs__(self):
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        return (self.name, self.dtype), self.assumptions0

    @property
    def canonical(self):
        return self

    @property
    def reversed(self):
        return self

    @property
    def headers(self):
        """Extra C/C++ headers needed by this symbol: ``"cuda_complex.hpp"`` for complex dtypes."""
        required = []
        # inspect the dtype itself and, for wrapper types such as pointers, its base type
        for numpy_dtype_of in (lambda t: t.numpy_dtype, lambda t: t.base_type.numpy_dtype):
            try:
                if np.issubdtype(numpy_dtype_of(self.dtype), np.complexfloating):
                    required.append('"cuda_complex.hpp"')
            except Exception:
                # types without a (base) numpy dtype need no extra header
                pass
        return required

272 

273 

def create_type(specification):
    """Creates a subclass of Type according to a string or an object of subclass Type.

    Args:
        specification: Type or StructType object (returned unchanged), or anything accepted by
            ``np.dtype`` (string, numpy scalar type, numpy dtype)

    Returns:
        the given type object, or a new non-const BasicType (non-structured dtype) or
        StructType (structured dtype)

    Raises:
        TypeError: if ``np.dtype`` cannot interpret ``specification``
    """
    # StructType does not derive from Type but is a complete type object as well; without this check
    # it would be handed to np.dtype(), which cannot interpret it.
    if isinstance(specification, (Type, StructType)):
        return specification
    numpy_dtype = np.dtype(specification)
    if numpy_dtype.fields is None:
        return BasicType(numpy_dtype, const=False)
    return StructType(numpy_dtype, const=False)

291 

292 

@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The string is read in C declaration order: a base type with an optional ``const``, followed by
    any number of ``*`` pointer levels, each optionally followed by ``const`` and/or ``restrict``
    qualifying that pointer level, e.g. ``"const double * restrict"`` or ``"int64* const *"``.
    A ``*`` may be written with or without surrounding spaces.

    Args:
        specification: Specification string (case-insensitive)

    Returns:
        Type object (BasicType, or nested PointerTypes around a BasicType)
    """
    # Make '*' a token of its own even when glued to a neighbour ("double*", "*restrict"); otherwise
    # the qualifiers following a glued star were attributed to the base type.
    tokens = specification.lower().replace('*', ' * ').split()
    parts = []
    current = []
    for token in tokens:
        if token == '*':
            parts.append(current)
            current = [token]
        else:
            current.append(token)
    if current:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type

338 

339 

def get_base_type(data_type):
    """Follow the chain of ``base_type`` links and return the innermost type (the one whose ``base_type`` is None)."""
    inner = data_type.base_type
    while inner is not None:
        data_type, inner = inner, inner.base_type
    return data_type

344 

345 

def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes
    :param data_type: Subclass of Type (PointerType, StructType or a type with a ``numpy_dtype``)
    :return: ctypes type object; struct types are handled byte-wise as ``POINTER(c_uint8)``
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]

358 

359 

# numpy dtype -> ctypes scalar type, used by to_ctypes for non-pointer, non-struct types
to_ctypes.map = {
    np.dtype(np.int8): ctypes.c_int8,
    np.dtype(np.int16): ctypes.c_int16,
    np.dtype(np.int32): ctypes.c_int32,
    np.dtype(np.int64): ctypes.c_int64,

    np.dtype(np.uint8): ctypes.c_uint8,
    np.dtype(np.uint16): ctypes.c_uint16,
    np.dtype(np.uint32): ctypes.c_uint32,
    np.dtype(np.uint64): ctypes.c_uint64,

    np.dtype(np.float32): ctypes.c_float,
    np.dtype(np.float64): ctypes.c_double,

    # BasicType supports 'bool' as well; without this entry to_ctypes raised KeyError for it
    np.dtype(np.bool_): ctypes.c_bool,
}

374 

375 

def ctypes_from_llvm(data_type):
    """Translate an llvmlite IR type into the corresponding ctypes type.

    Integer IR types map to *signed* ctypes integers of the same width. ``void`` maps to ``None``,
    and a pointer to ``void`` maps to ``ctypes.c_void_p``.

    Raises:
        the ImportError recorded at import time if llvmlite is unavailable
        ValueError: integer widths other than 8, 16, 32 or 64
        NotImplementedError: any other IR type
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee is None else ctypes.POINTER(pointee)

    if isinstance(data_type, ir.IntType):
        signed_by_width = {8: ctypes.c_int8, 16: ctypes.c_int16, 32: ctypes.c_int32, 64: ctypes.c_int64}
        int_ctype = signed_by_width.get(data_type.width)
        if int_ctype is None:
            raise ValueError("Int width %d is not supported" % data_type.width)
        return int_ctype

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet')

404 

405 

def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type
    :param data_type: Subclass of Type (a PointerType, or a type with a ``numpy_dtype`` listed in ``to_llvm_type.map``)
    :param nvvm_target: if True, the outermost pointer level is created in address space 1 instead of 0
    :return: llvmlite type object
    :raises KeyError: for scalar dtypes missing from ``to_llvm_type.map``
    """
    if not ir:
        raise _ir_importerror
    if isinstance(data_type, PointerType):
        # NOTE(review): nvvm_target is not forwarded to the recursive call, so nested pointer levels always
        # use address space 0 - confirm this is intended.
        return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
    else:
        return to_llvm_type.map[data_type.numpy_dtype]

418 

419 

if ir:
    # numpy dtype -> llvmlite scalar type; signed and unsigned integers share the same IR integer type
    to_llvm_type.map = {
        np.dtype(f'{sign_prefix}int{bits}'): ir.IntType(bits)
        for sign_prefix in ('', 'u')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()

435 

436 

def peel_off_type(dtype, type_to_peel_off):
    """Strip consecutive wrapper layers whose exact class is ``type_to_peel_off`` (subclasses are not stripped).

    Returns the first ``base_type`` in the chain that is not exactly of that class.
    """
    if type(dtype) is not type_to_peel_off:
        return dtype
    return peel_off_type(dtype.base_type, type_to_peel_off)

441 

442 

def _has_numpy_kind(data_type, kind):
    """True if ``data_type`` has a ``numpy_dtype`` that is a sub-dtype of ``kind``; False for types without one."""
    numpy_dtype = getattr(data_type, 'numpy_dtype', None)
    return numpy_dtype is not None and np.issubdtype(numpy_dtype, kind)


def collate_types(types,
                  forbid_collation_to_complex=False,
                  forbid_collation_to_float=False,
                  default_float_type='float64',
                  default_int_type='int64'):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: iterable of type objects
        forbid_collation_to_complex: drop complex types first; if nothing remains, ``default_float_type`` is returned
        forbid_collation_to_float: drop floating point types first; if nothing remains, ``default_int_type`` is returned
        default_float_type: see ``forbid_collation_to_complex``
        default_int_type: see ``forbid_collation_to_float``

    Returns:
        - the pointer type, if one pointer type is combined only with integer types (pointer arithmetic)
        - otherwise numpy's result type of the scalar types, wrapped into a VectorType if any input is a
          vector. If any input is floating point, only the floating point inputs are collated,
          e.g. (int64, float32) -> float32.

    Raises:
        ValueError: two pointer types, a pointer combined with a non-integer type, or vectors of different widths
    """
    # Materialize once: the input is traversed several times below, which would exhaust a generator.
    types = tuple(types)

    # Types without a numpy dtype (pointers, vectors) are never dropped by these filters; previously
    # they raised AttributeError here.
    if forbid_collation_to_complex:
        types = tuple(t for t in types if not _has_numpy_kind(t, np.complexfloating))
        if not types:
            return create_type(default_float_type)

    if forbid_collation_to_float:
        types = tuple(t for t in types if not _has_numpy_kind(t, np.floating))
        if not types:
            return create_type(default_int_type)

    # Pointer arithmetic case i.e. pointer + integer is allowed
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for t in types:
            if type(t) is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = t
            elif type(t) is BasicType:
                if not (t.is_int() or t.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            else:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # peel of vector types, if at least one vector type occurred the result will also be the vector type
    vector_type = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_type):
        raise ValueError("Collation failed because of vector types with different width")
    types = [peel_off_type(t, VectorType) for t in types]

    # now we should have a list of basic types - struct types are not yet supported
    assert all(type(t) is BasicType for t in types)

    if any(t.is_float() for t in types):
        types = tuple(t for t in types if t.is_float())
    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
    result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
    result = BasicType(result_numpy_type)
    if vector_type:
        result = VectorType(result, vector_type[0].width)
    return result

499 

500 

@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Infer the pystencils data type of a (sympifiable) expression.

    Args:
        expr: expression; passed through ``sp.sympify`` first
        default_float_type: type of float / rational literals and of non-integer powers;
            ``'float'`` is treated as ``'float32'``
        default_int_type: type of integer literals
        symbol_type_dict: maps names of untyped ``sp.Symbol`` s to their types. When it is empty or
            not given, untyped symbols raise ``ValueError``.

    Returns:
        the inferred type object (e.g. BasicType, VectorType, PointerType)

    Raises:
        ValueError: for untyped symbols without a non-empty ``symbol_type_dict``
        NotImplementedError: if the expression kind is not handled
    """
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if default_float_type == 'float':
        default_float_type = 'float32'

    if not symbol_type_dict:
        # NOTE(review): the default factory is never used - this fresh defaultdict is empty, so the
        # ``if symbol_type_dict`` check in the Symbol branch below is False and untyped symbols raise.
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif expr.is_real is False:
        # complex counterpart of the default float type
        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        arg_types = [get_type(a) for a in expr.args]  # each argument's type is computed once
        vec_args = [t for t in arg_types if isinstance(t, VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        if expr.exp.is_integer:
            return base_type
        else:
            return collate_types([create_type(default_float_type), base_type])
    elif isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        expr: sp.Expr
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            return collate_types(
                types,
                forbid_collation_to_complex=expr.is_real is True,
                forbid_collation_to_float=expr.is_integer is True,
                default_float_type=default_float_type,
                default_int_type=default_int_type)
        else:
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))

583 

584 

sympy_version = sp.__version__.split('.')


def _sympy_version_number(component):
    """Leading integer of a version component, e.g. ``'10rc1'`` -> 10; 0 if it has no leading digits."""
    digits = ''
    for char in component:
        if not char.isdigit():
            break
        digits += char
    return int(digits) if digits else 0


_sympy_major, _sympy_minor = (_sympy_version_number(c) for c in (sympy_version + ['0'])[:2])

# Guarded so that SymPy releases whose Basic does not define __getstate__ itself do not fail at import
# (``del`` would raise AttributeError there).
if (_sympy_major, _sympy_minor) >= (1, 9) and '__getstate__' in vars(sp.Basic):
    # Keep SymPy's ``Basic.__getstate__`` for ``Number`` only and remove it from ``Basic``, so the other
    # Basic subclasses fall back to Python's default state handling - presumably so instance attributes
    # set in this module (e.g. ``_dtype``) survive pickling; confirm.
    sp.Number.__getstate__ = sp.Basic.__getstate__
    del sp.Basic.__getstate__

590 

591 

class Type(sp.Atom):
    """Common base class of the pystencils data types.

    Deriving from ``sp.Atom`` lets type objects appear as arguments of SymPy expressions (as in
    ``cast_func``); SymPy's printer renders them via ``str()``.
    """

    def __new__(cls, *args, **kwargs):
        # constructor arguments are consumed by the subclasses' __init__, not by sp.Basic
        return sp.Basic.__new__(cls)

    def _sympystr(self, *args, **kwargs):
        return str(self)

598 

599 

class BasicType(Type):
    """Scalar data type wrapping a non-structured numpy dtype, optionally ``const`` qualified."""

    @staticmethod
    def numpy_name_to_c(name):
        """Return the C type name for a numpy dtype name, e.g. ``'float64'`` -> ``'double'``.

        Raises:
            NotImplementedError: for dtype names without a C mapping (e.g. ``'float16'``)
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name == 'complex64':
            return 'ComplexFloat'
        elif name == 'complex128':
            return 'ComplexDouble'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError(f"Cannot map numpy to C name for {name}")

    def __init__(self, dtype, const=False):
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The kind checks use np.issubdtype: np.sctypes, used previously, was removed in NumPy 2.0.
    def is_int(self):
        """True for signed and unsigned integer dtypes."""
        return np.issubdtype(self.numpy_dtype, np.integer)

    def is_float(self):
        return np.issubdtype(self.numpy_dtype, np.floating)

    def is_uint(self):
        return np.issubdtype(self.numpy_dtype, np.unsignedinteger)

    def is_complex(self):
        return np.issubdtype(self.numpy_dtype, np.complexfloating)

    def is_other(self):
        """True for the non-numeric dtypes (bool, object, bytes, str, void)."""
        return self.numpy_dtype.kind in 'bOSUV'

    @property
    def base_name(self):
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        # Hash exactly what __eq__ compares; going through str() failed for dtypes without a C name.
        return hash((self._dtype, self.const))

690 

691 

class VectorType(Type):
    """SIMD vector of ``width`` elements of ``base_type``.

    While the class attribute ``instruction_set`` is None, ``str()`` gives ``"<base>[<width>]"``.
    Otherwise ``instruction_set`` is looked up with the key 'int', 'double', 'float' or 'bool',
    chosen from the base type.
    """
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        element_size = self.base_type.item_size
        return self.width * element_size

    def __eq__(self, other):
        return isinstance(other, VectorType) and \
            (self.base_type, self.width) == (other.base_type, other.width)

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%d]" % (self.base_type, self.width)
        # (matching base type names, instruction_set key), tested in this order
        for candidate_names, key in ((("int64", "int32"), 'int'),
                                     (("float64",), 'double'),
                                     (("float32",), 'float'),
                                     (("bool",), 'bool')):
            if any(self.base_type == create_type(name) for name in candidate_names):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

    def __getnewargs_ex__(self):
        return (self._base_type, self.width), {}

736 

737 

class PointerType(Type):
    """Pointer to ``base_type``.

    ``restrict`` (default True) and ``const`` are printed after the ``*``, i.e. they qualify the pointer itself.
    """

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    def __getnewargs_ex__(self):
        return self.__getnewargs__(), {}

    @property
    def alias(self):
        """True if the pointer may alias, i.e. it is not marked restrict."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        return self.base_type.item_size

    def __eq__(self, other):
        return isinstance(other, PointerType) and \
            (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)

    def __str__(self):
        qualifiers = [word for word, enabled in (('RESTRICT', self.restrict), ("const", self.const)) if enabled]
        return " ".join([str(self.base_type), '*', *qualifiers])

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))

781 

782 

class StructType:
    """Data type described by a structured numpy dtype.

    Printed as ``uint8_t`` (plus ``const``) because structs are handled byte-wise; field offsets
    and element types come from the numpy dtype's ``fields``.
    """

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return self.__getnewargs__(), {}

    @property
    def base_type(self):
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Byte offset of field ``element_name``."""
        _, offset = self.numpy_dtype.fields[element_name][:2]
        return offset

    def get_element_type(self, element_name):
        """BasicType of field ``element_name``, inheriting this struct's constness."""
        field_dtype = self.numpy_dtype.fields[element_name][0]
        return BasicType(field_dtype, self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        return isinstance(other, StructType) and \
            (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __str__(self):
        # structs are handled byte-wise
        suffix = " const" if self.const else ""
        return "uint8_t" + suffix

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))

834 

835 

class TypedImaginaryUnit(TypedSymbol):
    """The imaginary unit as a typed symbol named ``_i`` with the SymPy assumption ``imaginary=True``.

    Code using it needs the ``cuda_complex.hpp`` header.
    """
    headers = ['"cuda_complex.hpp"']

    def __new__(cls, *args, **kwds):
        return TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)

    def __new_stage2__(cls, dtype):
        typed_symbol_factory = super(TypedImaginaryUnit, cls).__xnew__
        return typed_symbol_factory(cls, "_i", dtype, imaginary=True)

    # uncached and cached constructor entry points
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        return (self.dtype,)

    def __getnewargs_ex__(self):
        return (self.dtype,), {}