1import graphviz 

2from graphviz import Digraph, lang 

3from sympy.printing.printer import Printer 

4 

5 

6# noinspection PyPep8Naming 

class DotPrinter(Printer):
    """
    Printer that renders an AST as a graph in DOT (graph description language).

    Each visited node is registered in ``self.dot`` under the key ``str(id(node))``
    and connected to its children by edges that use the same kind of key.
    """

    def __init__(self, node_to_str_function, **kwargs):
        """
        :param node_to_str_function: callable that maps an AST node to its label text
        :param kwargs: forwarded unchanged to ``graphviz.Digraph``
        """
        super().__init__()
        self._node_to_str_function = node_to_str_function
        self.dot = Digraph(**kwargs)
        # Swap the graph's edge-quoting hook for graphviz's general quoting function.
        self.dot.quote_edge = lang.quote

    def _print_KernelFunction(self, func):
        """Add the kernel function node, print its body and link function -> body."""
        func_id = str(id(func))
        func_label = self._node_to_str_function(func)
        self.dot.node(func_id, style='filled', fillcolor='#a056db', label=func_label)

        body = func.body
        self._print(body)
        self.dot.edge(func_id, str(id(body)))

    def _print_LoopOverCoordinate(self, loop):
        """Add the loop node, print its body and link loop -> body."""
        loop_id = str(id(loop))
        loop_label = self._node_to_str_function(loop)
        self.dot.node(loop_id, style='filled', fillcolor='#3498db', label=loop_label)

        body = loop.body
        self._print(body)
        self.dot.edge(loop_id, str(id(body)))

    def _print_Block(self, block):
        """Print every child of the block, then add the block node and one edge per child."""
        # Children are emitted before the block node itself; the order is kept
        # because it determines the order of statements in the DOT source.
        for child in block.args:
            self._print(child)

        # Blocks are always labelled with repr(), not with the configurable label function.
        block_id = str(id(block))
        self.dot.node(block_id, style='filled', fillcolor='#dbc256', label=repr(block))
        for child in block.args:
            self.dot.edge(block_id, str(id(child)))

    def _print_SympyAssignment(self, assignment):
        """Add a node for the assignment; no children are visited."""
        label = self._node_to_str_function(assignment)
        self.dot.node(str(id(assignment)), style='filled', fillcolor='#56db7f', label=label)

    def _print_Conditional(self, expr):
        """Add the conditional node and link it to its true block and, if truthy, its false block."""
        cond_id = str(id(expr))
        cond_label = self._node_to_str_function(expr)
        self.dot.node(cond_id, style='filled', fillcolor='#56bd7f', label=cond_label)

        true_branch = expr.true_block
        self._print(true_branch)
        self.dot.edge(cond_id, str(id(true_branch)))

        # The false branch is only drawn when it evaluates as truthy.
        if expr.false_block:
            false_branch = expr.false_block
            self._print(false_branch)
            self.dot.edge(cond_id, str(id(false_branch)))

    def doprint(self, expr):
        """Walk ``expr`` and return the DOT source text accumulated in ``self.dot``."""
        self._print(expr)
        return self.dot.source

50 

51 

def __shortened(node):
    """
    Build a compact label for an AST node.

    :param node: a LoopOverCoordinate, KernelFunction, SympyAssignment or Conditional
    :return: short text describing the node
    :raises NotImplementedError: for any other node type
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import - confirm).
    from pystencils.astnodes import LoopOverCoordinate, KernelFunction, SympyAssignment, Conditional

    if isinstance(node, LoopOverCoordinate):
        return "Loop over dim %d" % (node.coordinate_to_loop_over,)

    if isinstance(node, KernelFunction):
        params = node.get_parameters()
        # Field pointers are listed first by field name, followed by the
        # symbol names of all parameters that are not field parameters.
        field_names = [p.field_name for p in params if p.is_field_pointer]
        other_names = [p.symbol.name for p in params if not p.is_field_parameter]
        param_names = field_names + other_names
        return f"Func: {node.function_name} ({','.join(param_names)})"

    if isinstance(node, SympyAssignment):
        # Only the left-hand side is shown for assignments.
        return repr(node.lhs)

    if isinstance(node, Conditional):
        return repr(node)

    raise NotImplementedError(f"Cannot handle node type {type(node)}")

67 

68 

def print_dot(node, view=False, short=False, **kwargs):
    """
    Convert an AST into its DOT graph description.

    :param node: root node of the ast for which the graph is generated
    :param view: if true, the DOT text is wrapped in a ``graphviz.Source`` object and that is returned
    :param short: label nodes with the __shortened output instead of ``repr``
    :param kwargs: is directly passed to the DotPrinter class: http://graphviz.readthedocs.io/en/latest/api.html#digraph
    :return: string in DOT format, or a ``graphviz.Source`` built from it when ``view`` is true
    """
    label_function = __shortened if short else repr
    dot_text = DotPrinter(label_function, **kwargs).doprint(node)
    return graphviz.Source(dot_text) if view else dot_text