1from typing import List, Union 

2 

3import sympy as sp 

4 

5from pystencils.astnodes import Node 

6from pystencils.simp import AssignmentCollection 

7 

8 

9# noinspection PyPep8Naming 

class fast_division(sp.Function):
    """Unevaluated two-argument sympy function ``fast_division(numerator, denominator)``.

    ``insert_fast_divisions`` builds it in place of ``numerator / denominator``.
    How it is finally lowered (e.g. to an approximate division) is not shown
    here — NOTE(review): confirm against the code generator.
    """

    nargs = (2,)

12 

13 

14# noinspection PyPep8Naming 

class fast_sqrt(sp.Function):
    """Unevaluated one-argument sympy function standing for ``sqrt(x)``.

    ``insert_fast_sqrts`` substitutes ``fast_sqrt(x) ** p`` for ``x ** (p/2)``
    when ``p`` is non-negative. Its lowering to a fast/approximate square root
    is presumably done by a backend — NOTE(review): not visible in this file.
    """

    nargs = (1,)

17 

18 

19# noinspection PyPep8Naming 

class fast_inv_sqrt(sp.Function):
    """Unevaluated one-argument sympy function standing for ``1 / sqrt(x)``.

    ``insert_fast_sqrts`` substitutes ``fast_inv_sqrt(x) ** -p`` for
    ``x ** (p/2)`` when ``p`` is negative. Its lowering to a fast/approximate
    reciprocal square root is presumably done by a backend — NOTE(review):
    not visible in this file.
    """

    nargs = (1,)

22 

23 

def _run(term, visitor):
    """Apply ``visitor`` to ``term``, descending into supported containers.

    * ``AssignmentCollection``: its main assignments and then its
      subexpressions are processed, and ``term.copy(...)`` is returned with
      the rewritten lists.
    * ``list``: each element is processed recursively; a new list is returned.
    * anything else: passed straight to ``visitor`` and its result returned.
    """
    if isinstance(term, AssignmentCollection):
        rewritten_main = _run(term.main_assignments, visitor)
        rewritten_subs = _run(term.subexpressions, visitor)
        return term.copy(rewritten_main, rewritten_subs)
    if not isinstance(term, list):
        return visitor(term)
    return [_run(item, visitor) for item in term]

33 

34 

def insert_fast_sqrts(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace half-integer powers in ``term`` by ``fast_sqrt`` / ``fast_inv_sqrt``.

    Every power ``x ** (p/2)`` (a ``sp.Rational`` exponent with denominator 2)
    is rewritten as ``fast_inv_sqrt(x) ** -p`` when ``p < 0`` and as
    ``fast_sqrt(x) ** p`` otherwise. The base ``x`` is rewritten as well, so
    nested roots such as ``sqrt(1 + sqrt(y))`` are replaced at every level.

    Args:
        term: a sympy expression, a list of expressions, or an
            ``AssignmentCollection``; containers are traversed by ``_run``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
    """
    def visit(expr):
        # AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Pow and isinstance(expr.exp, sp.Rational) and expr.exp.q == 2:
            power = expr.exp.p
            # Rewrite the base too; previously roots nested inside it were
            # left as exact sqrt calls.
            base = visit(expr.base)
            if power < 0:
                return fast_inv_sqrt(base) ** (-power)
            return fast_sqrt(base) ** power
        new_args = [visit(a) for a in expr.args]
        # Atoms (no args) are returned as-is instead of being re-constructed.
        return expr.func(*new_args) if new_args else expr

    return _run(term, visit)

49 

50 

def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollection]):
    """Replace divisions in ``term`` by ``fast_division(numerator, denominator)``.

    In a product, every factor ``b ** e`` with a negative integer exponent ``e``
    goes into the denominator as ``b ** -e``, and the remaining factors form the
    numerator. A stand-alone ``b ** e`` with negative integer ``e`` becomes
    ``fast_division(1, b ** -e)``. Sub-expressions are rewritten recursively.

    An exponent counts as a division only if sympy can establish that it is a
    negative integer. A symbolic integer exponent of unknown sign is left as an
    ordinary power instead of raising.

    Args:
        term: a sympy expression, a list of expressions, or an
            ``AssignmentCollection``; containers are traversed by ``_run``.

    Returns:
        An object of the same kind as ``term`` with the replacements applied.
    """
    def is_reciprocal_power(expr):
        # ``is_negative`` is three-valued (True/False/None). The former
        # ``exp < 0`` yielded an unevaluated Relational for symbols of unknown
        # sign, and using that in an ``and`` raised TypeError.
        return (expr.func == sp.Pow
                and bool(expr.exp.is_integer)
                and bool(expr.exp.is_negative))

    def visit(expr):
        # AST nodes are passed through untouched.
        if isinstance(expr, Node):
            return expr
        if expr.func == sp.Mul:
            div_args = []
            other_args = []
            for a in expr.args:
                if is_reciprocal_power(a):
                    div_args.append(visit(a.base) ** (-a.exp))
                else:
                    other_args.append(visit(a))
            if div_args:
                return fast_division(sp.Mul(*other_args), sp.Mul(*div_args))
            return sp.Mul(*other_args)
        if is_reciprocal_power(expr):
            return fast_division(1, visit(expr.base) ** (-expr.exp))
        new_args = [visit(a) for a in expr.args]
        # Atoms (no args) are returned as-is instead of being re-constructed.
        return expr.func(*new_args) if new_args else expr

    return _run(term, visit)