1from collections import namedtuple 

2from typing import Any, Callable, Optional, Sequence 

3 

4import sympy as sp 

5 

6from pystencils.simp.assignment_collection import AssignmentCollection 

7 

8 

class SimplificationStrategy:
    """A simplification strategy is an ordered collection of simplification rules.

    Each simplification is a function taking an assignment collection, and returning a new simplified
    assignment collection. The strategy can nicely print intermediate simplification stages and results
    to Jupyter notebooks.
    """

    def __init__(self):
        # Rules in insertion order; every consumer below applies them in this order.
        self._rules = []

    def add(self, rule: Callable[[AssignmentCollection], AssignmentCollection]) -> None:
        """Adds the given simplification rule to the end of the collection.

        Args:
            rule: function that rewrites/simplifies an assignment collection
        """
        self._rules.append(rule)

    @property
    def rules(self):
        """The rules in application order (the internal list itself, not a copy)."""
        return self._rules

    @staticmethod
    def _rule_name(rule) -> str:
        """Returns a display name for ``rule``.

        Uses ``rule.__name__`` when available. Callables without one -- e.g. ``functools.partial``
        objects -- fall back to the name of their wrapped ``func`` and finally to ``repr(rule)``,
        so that reports and ``repr`` do not raise ``AttributeError`` for them.
        """
        name = getattr(rule, '__name__', None)
        if name is None:
            name = getattr(getattr(rule, 'func', None), '__name__', None)
        return name if name is not None else repr(rule)

    def apply(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Runs all rules on the given assignment collection.

        Each rule receives the output of the previous one; the result of the last rule is returned
        (or the input itself if no rules were added).
        """
        for t in self._rules:
            assignment_collection = t(assignment_collection)
        return assignment_collection

    def __call__(self, assignment_collection: AssignmentCollection) -> AssignmentCollection:
        """Same as apply"""
        return self.apply(assignment_collection)

    def create_simplification_report(self, assignment_collection: AssignmentCollection) -> Any:
        """Creates a report to be displayed as HTML in a Jupyter notebook.

        The simplification report contains the number of operations at each simplification stage together
        with the run-time the simplification took.

        Args:
            assignment_collection: the collection to simplify; it and every rule result are expected to
                expose ``operation_count`` as a mapping with the keys 'adds', 'muls' and 'divs'.

        Returns:
            a report object whose ``elements`` holds one row per stage (the original term first). It renders
            via ``str()`` (a tabulate table if the package is installed, comma-separated text otherwise)
            and via ``_repr_html_`` (an HTML table).
        """
        import timeit

        ReportElement = namedtuple('ReportElement', ['simplificationName', 'runtime', 'adds', 'muls', 'divs', 'total'])
        # Column titles, in the same order as the ReportElement fields.
        headers = ['Name', 'Runtime', 'Adds', 'Muls', 'Divs', 'Total']

        class Report:
            def __init__(self):
                self.elements = []

            def add(self, element):
                self.elements.append(element)

            def __str__(self):
                try:
                    # Import the function, not the module: the module object itself is not callable.
                    from tabulate import tabulate
                except ImportError:
                    result = ", ".join(headers) + "\n"
                    for e in self.elements:
                        result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
                    return result
                return tabulate(self.elements, headers=headers)

            def _repr_html_(self):
                import html
                html_table = '<table style="border:none">'
                html_table += "<tr>" + "".join(f"<th>{h}</th>" for h in headers) + "</tr>"
                line = "<tr><td>{simplificationName}</td>" \
                       "<td>{runtime}</td> <td>{adds}</td> <td>{muls}</td> <td>{divs}</td> <td>{total}</td> </tr>"

                for e in self.elements:
                    # Escape every cell: rule names such as '<lambda>' would otherwise be parsed as tags.
                    # noinspection PyProtectedMember
                    cells = {key: html.escape(str(value)) for key, value in e._asdict().items()}
                    html_table += line.format(**cells)
                html_table += "</table>"
                return html_table

        def make_element(name, runtime, collection):
            """Builds one report row from the operation counts of ``collection``."""
            op = collection.operation_count
            total = op['adds'] + op['muls'] + op['divs']
            return ReportElement(name, runtime, op['adds'], op['muls'], op['divs'], total)

        report = Report()
        report.add(make_element("OriginalTerm", '-', assignment_collection))
        for t in self._rules:
            start_time = timeit.default_timer()
            assignment_collection = t(assignment_collection)
            end_time = timeit.default_timer()
            # Operation counting happens after the timer stops, so it is not included in the runtime.
            time_str = f"{(end_time - start_time) * 1000:.2f} ms"
            report.add(make_element(self._rule_name(t), time_str, assignment_collection))
        return report

    def show_intermediate_results(self, assignment_collection: AssignmentCollection,
                                  symbols: Optional[Sequence[sp.Symbol]] = None) -> Any:
        """Shows the assignment collection after the application of each rule as HTML report for Jupyter notebook.

        The rules are re-applied lazily each time the returned object is rendered (``str()`` or
        ``_repr_html_``).

        Args:
            assignment_collection: the collection to apply the rules to
            symbols: if given (and non-empty), only the assignments are shown that have one of these symbols
                     as left hand side
        """
        class IntermediateResults:
            def __init__(self, strategy, collection, restrict_symbols):
                self.strategy = strategy
                self.assignment_collection = collection
                self.restrict_symbols = restrict_symbols

            def _stages(self):
                """Yields (title, collection) for the input and for the result of each rule in order."""
                collection = self.assignment_collection
                yield "Initial Version", collection
                for rule in self.strategy.rules:
                    collection = rule(collection)
                    # noinspection PyProtectedMember
                    yield self.strategy._rule_name(rule), collection

            def __str__(self):
                def print_assignment_collection(title, c):
                    if self.restrict_symbols:
                        body = "\n".join([str(e) for e in c.new_filtered(self.restrict_symbols).main_assignments])
                    else:
                        body = " " * 3 + (" " * 3).join(str(c).splitlines(True))
                    # Put the title on its own line and terminate each section so sections don't run together.
                    body = body.rstrip("\n")
                    return f"{title}\n{body}\n"

                return "".join(print_assignment_collection(title, c) for title, c in self._stages())

            def _repr_html_(self):
                import html

                def print_assignment_collection(title, c):
                    text = f'<h5 style="padding-bottom:10px">{html.escape(title)}</h5> <div style="padding-left:20px;">'
                    if self.restrict_symbols:
                        text += "\n".join(["$$" + sp.latex(e) + '$$'
                                           for e in c.new_filtered(self.restrict_symbols).main_assignments])
                    else:
                        # noinspection PyProtectedMember
                        text += c._repr_html_()
                    text += "</div>"
                    return text

                return "".join(print_assignment_collection(title, c) for title, c in self._stages())

        return IntermediateResults(self, assignment_collection, symbols)

    def __repr__(self):
        return "Simplification Strategy:\n" + "".join(f" - {self._rule_name(t)}\n" for t in self._rules)