1from typing import Any, Dict, Optional, Union 

2 

3import sympy as sp 

4 

5from pystencils.astnodes import KernelFunction 

6from pystencils.kernel_wrapper import KernelWrapper 

7 

8 

def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
    """Render a sympy expression or a pystencils AST node as a graphviz graph.

    Args:
        expr: a sympy expression, or a pystencils AST ``Node``.
        graph_style: graph attributes forwarded as ``graph_attr`` to the dot printer;
            ``None`` means no extra attributes.
        short: forwarded to the pystencils dot printer (only used for AST nodes).

    Returns:
        A ``graphviz.Source`` wrapping the generated dot text.
    """
    from pystencils.astnodes import Node
    import graphviz

    attributes = graph_style if graph_style is not None else {}

    if not isinstance(expr, Node):
        # Plain sympy expression: use sympy's own dot printer.
        from sympy.printing.dot import dotprint
        return graphviz.Source(dotprint(expr, graph_attr=attributes))

    # pystencils AST node: use the pystencils-specific dot printer.
    from pystencils.backends.dot import print_dot
    return graphviz.Source(print_dot(expr, short=short, graph_attr=attributes))

21 

22 

def highlight_cpp(code: str):
    """Return *code* syntax-highlighted as C/C++ in an IPython ``HTML`` object.

    Side effect: the pygments stylesheet for the ``.highlight`` CSS class is first
    emitted through IPython's ``display()`` so the returned markup is coloured.
    """
    from IPython.display import HTML, display
    from pygments import highlight
    # noinspection PyUnresolvedReferences
    from pygments.formatters import HtmlFormatter
    # noinspection PyUnresolvedReferences
    from pygments.lexers import CppLexer

    stylesheet = HtmlFormatter().get_style_defs('.highlight')
    display(HTML("<style>" + stylesheet + "</style>"))

    highlighted_markup = highlight(code, CppLexer(), HtmlFormatter())
    return HTML(highlighted_markup)

36 

37 

def get_code_obj(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
    """Return an object that displays the generated code (C/C++, CUDA or OpenCL) of *ast*.

    The returned object renders as syntax-highlighted HTML in Jupyter notebooks
    (via ``_repr_html_``) and as the plain source code under ``str()``/``repr()``.
    The code is regenerated every time the object is rendered.

    Args:
        ast: the kernel to print; a ``KernelWrapper`` is unwrapped to its ``.ast``.
        custom_backend: optional printer, passed through unchanged to ``generate_c``.
    """
    from pystencils.backends.cbackend import generate_c

    if isinstance(ast, KernelWrapper):
        ast = ast.ast

    # Map the kernel's backend to the code-generation dialect; anything else is plain C.
    if ast.backend == 'gpucuda':
        dialect = 'cuda'
    elif ast.backend == 'opencl':
        dialect = 'opencl'
    else:
        dialect = 'c'

    class CodeDisplay:
        """Code listing that is generated on demand, with plain-text and HTML reprs."""

        def __init__(self, ast_input):
            self.ast = ast_input

        def _generate(self):
            # Single place where the source text is produced for every representation.
            return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)

        def _repr_html_(self):
            # Return self-contained markup (stylesheet + highlighted code). Going through
            # highlight_cpp would call IPython's display() while this repr is computed;
            # outside a running kernel that prints a stray
            # "<IPython.core.display.HTML object>" line and leaves the markup unstyled.
            from pygments import highlight
            # noinspection PyUnresolvedReferences
            from pygments.formatters import HtmlFormatter
            # noinspection PyUnresolvedReferences
            from pygments.lexers import CppLexer

            formatter = HtmlFormatter()
            css = formatter.get_style_defs('.highlight')
            return f"<style>{css}</style>" + highlight(self._generate(), CppLexer(), formatter)

        def __str__(self):
            return self._generate()

        # repr() shows the raw code as well, e.g. when echoed in a plain REPL.
        __repr__ = __str__

    return CodeDisplay(ast)

68 

69 

def get_code_str(ast, custom_backend=None):
    """Return the generated source code of *ast* as a plain string."""
    display_obj = get_code_obj(ast, custom_backend=custom_backend)
    return str(display_obj)

72 

73 

74def _isnotebook(): 

75 try: 

76 shell = get_ipython().__class__.__name__ 

77 if shell == 'ZMQInteractiveShell': 

78 return True # Jupyter notebook or qtconsole 

79 elif shell == 'TerminalInteractiveShell': 

80 return False # Terminal running IPython 

81 else: 

82 return False # Other type (?) 

83 except NameError: 

84 return False 

85 

86 

def show_code(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
    """Show the generated code of *ast* to the user.

    Inside a Jupyter notebook the code object is handed to IPython's ``display()``.
    Elsewhere it is pretty-printed with ``rich`` (C++ highlighting, line numbers) when
    that package is importable, and printed as plain text otherwise.
    """
    code = get_code_obj(ast, custom_backend)

    if not _isnotebook():
        try:
            import rich.syntax
            import rich.console
            syntax = rich.syntax.Syntax(str(code), "c++", theme="monokai", line_numbers=True)
            rich.console.Console().print(syntax)
        except ImportError:
            # rich is optional; fall back to plain output.
            print(code)
        return

    from IPython.display import display
    display(code)