1import itertools 

2import operator 

3import warnings 

4from collections import Counter, defaultdict 

5from functools import partial, reduce 

6from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union 

7 

8import sympy as sp 

9from sympy.functions import Abs 

10from sympy.core.numbers import Zero 

11 

12from pystencils.assignment import Assignment 

13from pystencils.data_types import cast_func, get_type_of_expression, PointerType, VectorType 

14from pystencils.kernelparameters import FieldPointerSymbol 

15 

# Generic element type shared by the helpers prod, scalar_product and fast_subs
T = TypeVar("T")

17 

18 

def prod(seq: Iterable[T]) -> T:
    """Return the product of all elements of ``seq``.

    The elements are multiplied left to right starting from the integer 1,
    so an empty sequence yields 1.

    >>> prod([2, 3, 4])
    24
    """
    accumulated = 1
    for factor in seq:
        accumulated = accumulated * factor
    return accumulated

22 

23 

def remove_small_floats(expr, threshold):
    """Removes all sp.Float objects whose absolute value is smaller than threshold

    Every such float is replaced by sympy's zero and the expression tree is rebuilt
    bottom-up via ``expr.func(*args)``, so sympy re-evaluates the affected nodes.

    >>> expr = sp.sympify("x + 1e-15 * y")
    >>> remove_small_floats(expr, 1e-14)
    x

    Args:
        expr: sympy expression
        threshold: floats whose absolute value is strictly below this value are removed

    Returns:
        new sympy expression (``sp.S.Zero`` if ``expr`` itself is a small float)
    """
    if isinstance(expr, sp.Float) and sp.Abs(expr) < threshold:
        # return sympy's zero (not the Python int 0) so the result is always a sympy object
        return sp.S.Zero
    new_args = [remove_small_floats(c, threshold) for c in expr.args]
    return expr.func(*new_args) if new_args else expr

36 

37 

def is_integer_sequence(sequence: Iterable) -> bool:
    """Checks if all elements of the passed sequence can be cast to integers

    An element counts as castable if ``int(element)`` succeeds. The function returns
    False in these cases:

    * non-numeric objects (TypeError);
    * strings that do not spell an integer (ValueError);
    * NaN (ValueError);
    * infinite floats (OverflowError);
    * a non-iterable argument.

    An empty sequence yields True.

    >>> is_integer_sequence([1, 2.5, "3"])
    True
    >>> is_integer_sequence([1, "x"])
    False
    """
    try:
        for element in sequence:
            int(element)
    except (TypeError, ValueError, OverflowError):
        return False
    return True

46 

47 

def scalar_product(a: Iterable[T], b: Iterable[T]) -> T:
    """Return the scalar (dot) product of two sequences.

    Elements are multiplied pairwise and the products are summed starting from 0.
    If the sequences differ in length, the surplus elements of the longer one are ignored.
    """
    return sum(map(operator.mul, a, b))

51 

52 

def kronecker_delta(*args):
    """Generalised Kronecker delta: 1 if all arguments are equal, otherwise 0.

    Calling it with no arguments or a single argument returns 1.
    """
    # args[0] is only evaluated inside the generator, so the empty call is safe
    return 0 if any(a != args[0] for a in args) else 1

59 

60 

def tanh_step_function_approximation(x, step_location, kind='right', steepness=0.0001):
    """Approximation of step function by a tanh function

    >>> tanh_step_function_approximation(1.2, step_location=1.0, kind='right')
    1.00000000000000
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='right')
    0
    >>> tanh_step_function_approximation(1.1, step_location=1.0, kind='left')
    0
    >>> tanh_step_function_approximation(0.9, step_location=1.0, kind='left')
    1.00000000000000
    >>> tanh_step_function_approximation(0.5, step_location=(0, 1), kind='middle')
    1

    Args:
        x: position at which the approximation is evaluated (number or sympy expression)
        step_location: location of the step; for kind='middle' a pair (x1, x2) with x1 < x2
        kind: one of
            * 'right': approx. 1 for x > step_location, approx. 0 otherwise
            * 'left': approx. 1 for x < step_location, approx. 0 otherwise
            * 'middle': approx. 1 for x1 < x < x2, approx. 0 otherwise
        steepness: width of the tanh transition; smaller values give a sharper step

    Raises:
        ValueError: if kind is not 'left', 'right' or 'middle'
    """
    if kind == 'left':
        return (1 - sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'right':
        return (1 + sp.tanh((x - step_location) / steepness)) / 2
    elif kind == 'middle':
        x1, x2 = step_location
        # 1 minus (step down at x1 + step up at x2) leaves a plateau between x1 and x2
        return 1 - (tanh_step_function_approximation(x, x1, 'left', steepness)
                    + tanh_step_function_approximation(x, x2, 'right', steepness))
    # previously an unknown kind silently returned None
    raise ValueError(f"Unknown kind {kind!r}; expected 'left', 'right' or 'middle'")

83 

84 

def multidimensional_sum(i, dim):
    """Multidimensional summation

    Returns an iterator over all index tuples of length ``i`` whose entries run over
    ``range(dim)``, i.e. the index set of an ``i``-fold nested sum over ``dim`` values.

    Example:
        >>> list(multidimensional_sum(2, dim=3))
        [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    """
    axis = range(dim)
    return itertools.product(*(axis for _ in range(i)))

94 

95 

def normalize_product(product: sp.Expr) -> List[sp.Expr]:
    """Split a sympy expression that can be read as a product into a flat list of factors.

    Powers with a positive integer exponent are unrolled: ``x**3`` contributes
    ``[x, x, x]``. All other powers are kept as a single factor.

    Returns:
        * for a Mul node: its factors ('args'), with positive-integer powers unrolled
        * for a Pow node with positive integer exponent: the base repeated 'exponent' times
        * for any other node: ``[product]``
    """
    def expand_power(power):
        exponent = power.exp
        unrollable = exponent.is_integer and exponent.is_number and exponent > 0
        if not unrollable:
            return [power]
        return [power.base] * exponent

    if isinstance(product, sp.Pow):
        return expand_power(product)
    if isinstance(product, sp.Mul):
        factors = []
        for arg in product.args:
            factors.extend(expand_power(arg) if arg.func == sp.Pow else [arg])
        return factors
    return [product]

124 

125 

def symmetric_product(*args, with_diagonal: bool = True) -> Iterable:
    """Like itertools.product, but only yields combinations whose index tuple is ascending.

    With ``with_diagonal=True`` indices may repeat (non-decreasing, i.e. on and above the
    diagonal); with ``with_diagonal=False`` they must be strictly increasing.

    Examples:
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c']))
        [(1, 'a'), (1, 'b'), (1, 'c'), (2, 'b'), (2, 'c'), (3, 'c')]
        >>> list(symmetric_product([1, 2, 3], ['a', 'b', 'c'], with_diagonal=False))
        [(1, 'b'), (1, 'c'), (2, 'c')]
    """
    in_order = operator.le if with_diagonal else operator.lt
    index_ranges = [range(len(sequence)) for sequence in args]
    for indices in itertools.product(*index_ranges):
        # every pair of neighbouring indices must respect the requested ordering
        if all(in_order(previous, following) for previous, following in zip(indices, indices[1:])):
            yield tuple(sequence[k] for sequence, k in zip(args, indices))

144 

145 

def fast_subs(expression: T, substitutions: Dict,
              skip: Optional[Callable[[sp.Expr], bool]] = None) -> T:
    """Similar to sympy subs function.

    Only exact (structural) matches are replaced: every node of the expression tree is
    looked up in ``substitutions``. A replaced node is not searched further. Objects
    that provide their own ``fast_subs`` method delegate the substitution to it.
    For a ``sp.Matrix`` the substitution (including ``skip``) is applied element-wise
    to a copy.

    Args:
        expression: expression where parts should be substituted
        substitutions: dict defining substitutions by mapping from old to new terms
        skip: function that marks expressions to be skipped (if True is returned) - that means that in these skipped
              expressions no substitutions are done

    This version is much faster for big substitution dictionaries than sympy version
    """
    if type(expression) is sp.Matrix:
        # forward 'skip' as well, otherwise it would be silently ignored for matrix entries
        return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions, skip=skip))

    def visit(expr):
        if skip and skip(expr):
            return expr
        if hasattr(expr, "fast_subs"):
            return expr.fast_subs(substitutions, skip)
        if expr in substitutions:
            return substitutions[expr]
        if not hasattr(expr, 'args'):
            return expr
        param_list = [visit(a) for a in expr.args]
        return expr if not param_list else expr.func(*param_list)

    if len(substitutions) == 0:
        return expression
    return visit(expression)

177 

178 

def is_constant(expr):
    """Return True if ``expr`` contains no free symbols.

    A simple replacement for sympy's own ``is_constant()``, which also works for
    piecewise defined functions, see: https://github.com/sympy/sympy/issues/16662
    """
    return not expr.free_symbols

185 

186 

def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
                  required_match_replacement: Optional[Union[int, float]] = 0.5,
                  required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
    """Transformation for replacing a given subexpression inside a sum.

    Examples:
        The next example demonstrates the advantage of subs_additive compared to sympy.subs:
        >>> x, y, z, k = sp.symbols("x y z k")
        >>> subs_additive(3*x + 3*y, replacement=k, subexpression=x + y)
        3*k

        Terms that don't match completely can be substituted at the cost of additional terms.
        This trade-off is managed using the required_match parameters.
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=1.0)
        3*x + 3*y + z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=0.5)
        3*k - 2*z
        >>> subs_additive(3*x + 3*y + z, replacement=k, subexpression=x+y+z, required_match_original=2)
        3*k - 2*z

    Args:
        expr: input expression
        replacement: expression that is inserted for subexpression (if found)
        subexpression: expression to replace
        required_match_replacement:
            * if float: the fraction (0..1) of terms of the subexpression that has to be matched in order to replace
            * if integer: the total number of terms that has to be matched in order to replace
            * None: is equal to integer 1
            * if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND)
        required_match_original:
            * if float: the fraction (0..1) of terms that has to be matched, relative to the longer of
              the visited sum and the subexpression
            * if integer: the total number of terms that has to be matched in order to replace
            * None: is equal to integer 1

    Returns:
        new expression with replacement

    Raises:
        ValueError: if a match parameter is not None/int/float or lies outside its valid range
    """
    def normalize_match_parameter(match_parameter, expression_length):
        """Convert a match parameter into the absolute number of terms that must match."""
        if match_parameter is None:
            return 1
        if isinstance(match_parameter, float):
            if not 0 <= match_parameter <= 1:
                raise ValueError(f"Float match parameter must lie in [0, 1], got {match_parameter}")
            # a fraction of the terms, but at least one term
            return max(int(match_parameter * expression_length), 1)
        if isinstance(match_parameter, int):
            if not match_parameter > 0:
                raise ValueError(f"Integer match parameter must be positive, got {match_parameter}")
            return match_parameter
        raise ValueError("Invalid parameter")

    # loop invariants: the subexpression does not change while the tree is traversed
    subexpression_coefficient_dict = subexpression.as_coefficients_dict()
    normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))

    def visit(current_expr):
        if current_expr.is_Add:
            expr_max_length = max(len(current_expr.args), len(subexpression.args))
            normalized_current_expr_match = normalize_match_parameter(required_match_original, expr_max_length)
            required_matches = max(normalized_replacement_match, normalized_current_expr_match)
            expr_coefficients = current_expr.as_coefficients_dict()
            intersection = set(subexpression_coefficient_dict).intersection(expr_coefficients)
            if len(intersection) >= required_matches:
                # find common factor: the most frequent ratio of coefficients (expr / subexpression)
                factors = defaultdict(int)
                for common_symbol in subexpression_coefficient_dict:
                    if common_symbol not in expr_coefficients:
                        continue
                    factor = expr_coefficients[common_symbol] / subexpression_coefficient_dict[common_symbol]
                    factors[sp.simplify(factor)] += 1

                common_factor = max(factors.items(), key=operator.itemgetter(1))[0]
                if factors[common_factor] >= required_matches:
                    return current_expr - common_factor * subexpression + common_factor * replacement

        # if no subexpression was found, descend into the children
        param_list = [visit(a) for a in current_expr.args]
        if not param_list:
            return current_expr
        if current_expr.func == sp.Mul and Zero() in param_list:
            # evaluate=False below would keep a product with a zero factor, so collapse it here
            return Zero()
        return current_expr.func(*param_list, evaluate=False)

    return visit(expr)

271 

272 

def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
                                  positive: Optional[bool] = None,
                                  replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
    """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).

    This makes the term longer - simplify usually is undoing these - however this
    transformation can be done to find more common sub-expressions

    Args:
        expr: input expression
        search_symbols: symbols that are searched for
                        for example, given [x,y,z] terms like x*y, x*z, z*y are replaced
        positive: there are two ways to do this substitution, either with term
                 (x+y)**2 or (x-y)**2 . if positive=True the first version is done,
                 if positive=False the second version is done, if positive=None the
                 sign is determined by the sign of the remaining factors of the mixed term
                 (with all their symbols set to 1)
        replace_mixed: if a list is passed here, the expr x+y or x-y is replaced by a special new symbol
                       and the replacement equation is added to the list

    Returns:
        expression with the mixed second order products rewritten

    Raises:
        ValueError: if positive is None and the sign of a mixed term cannot be determined
    """
    # materialize once: membership tests and recursive calls must not exhaust a one-shot iterator
    search_symbols = tuple(search_symbols)
    mixed_symbols_replaced = {e.lhs for e in replace_mixed} if replace_mixed is not None else set()

    if expr.is_Mul:
        distinct_search_symbols = set()
        nr_of_search_terms = 0
        other_factors = sp.Integer(1)
        for t in expr.args:
            if t in search_symbols:
                nr_of_search_terms += 1
                distinct_search_symbols.add(t)
            else:
                other_factors *= t
        # only products of exactly two different search symbols (each to the first power) qualify
        if len(distinct_search_symbols) == 2 and nr_of_search_terms == 2:
            u, v = sorted(distinct_search_symbols, key=lambda symbol: symbol.name)
            if positive is None:
                other_factors_without_symbols = other_factors
                for s in other_factors.atoms(sp.Symbol):
                    other_factors_without_symbols = other_factors_without_symbols.subs(s, 1)
                positive = other_factors_without_symbols.is_positive
                if positive is None:
                    raise ValueError(f"Cannot determine the sign of the factor {other_factors}; "
                                     f"pass 'positive' explicitly")
            sign = 1 if positive else -1
            if replace_mixed is not None:
                new_symbol_str = 'P' if positive else 'M'
                mixed_symbol_name = u.name + new_symbol_str + v.name
                mixed_symbol = sp.Symbol(mixed_symbol_name.replace("_", ""))
                if mixed_symbol not in mixed_symbols_replaced:
                    mixed_symbols_replaced.add(mixed_symbol)
                    replace_mixed.append(Assignment(mixed_symbol, u + sign * v))
            else:
                mixed_symbol = u + sign * v
            # (u +- v)**2 - u**2 - v**2 == +-2*u*v, hence the factor sign/2
            return sp.Rational(1, 2) * sign * other_factors * (mixed_symbol ** 2 - u ** 2 - v ** 2)

    param_list = [replace_second_order_products(a, search_symbols, positive, replace_mixed) for a in expr.args]
    return expr.func(*param_list, evaluate=False) if param_list else expr

327 

328 

def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order: int = 3) -> sp.Expr:
    """Removes all terms that contain more than 'order' factors of given 'symbols'

    The expression is expanded first, then every summand is inspected separately.
    A power ``s**n`` of one of the symbols counts as ``n`` factors, a bare symbol as one.

    Example:
        >>> x, y = sp.symbols("x y")
        >>> term = x**2 * y + y**2 * x + y**3 + x + y ** 2
        >>> remove_higher_order_terms(term, order=2, symbols=[x, y])
        x + y**2

    Args:
        expr: sympy expression
        symbols: symbols whose factors are counted
        order: maximal number of symbol factors a summand may contain to be kept

    Returns:
        sum of the summands with at most 'order' factors (sympy zero if none remain)
    """
    def count_symbol_factors(term):
        """Number of factors of 'symbols' in a single summand."""
        if term in symbols:
            # a lone symbol is one factor (previously counted as zero)
            return 1
        if type(term) is sp.Pow:
            return term.args[1] if term.args[0] in symbols else 0
        if type(term) is sp.Mul:
            factor_count = 0
            for factor in term.args:
                if type(factor) is sp.Pow:
                    if factor.args[0] in symbols:
                        factor_count += factor.args[1]
                elif factor in symbols:
                    factor_count += 1
            return factor_count
        return 0

    expr = expr.expand()

    if type(expr) is not sp.Add:
        # a single summand: keep or drop it as a whole
        return expr if count_symbol_factors(expr) <= order else sp.S.Zero

    # build the sum in one step instead of repeated (quadratic) sympy additions
    return sp.Add(*[term for term in expr.args if count_symbol_factors(term) <= order])

371 

372 

def complete_the_square(expr: sp.Expr, symbol_to_complete: sp.Symbol,
                        new_variable: sp.Symbol) -> Tuple[sp.Expr, Optional[Tuple[sp.Symbol, sp.Expr]]]:
    """Complete the square of a quadratic polynomial in ``symbol_to_complete``.

    The linear term is eliminated by substituting the shifted variable ``new_variable``.
    ``expr`` is converted with ``sp.Poly``, so it must be polynomial in ``symbol_to_complete``.
    If it is not of degree two, it is returned unchanged together with ``None``.

    Examples:
        >>> a, b, c, s, n = sp.symbols("a b c s n")
        >>> expr = a * s**2 + b * s + c
        >>> completed_expr, substitution = complete_the_square(expr, symbol_to_complete=s, new_variable=n)
        >>> completed_expr
        a*n**2 + c - b**2/(4*a)
        >>> substitution
        (n, s + b/(2*a))

    Returns:
        (replaced_expr, tuple to pass to subs, such that old expr comes out again)
    """
    coeffs = sp.Poly(expr, symbol_to_complete).all_coeffs()
    if len(coeffs) != 3:
        return expr, None
    quadratic, linear = coeffs[0], coeffs[1]
    shift = linear / (2 * quadratic)
    shifted_expr = expr.subs(symbol_to_complete, new_variable - shift)
    return sp.simplify(shifted_expr), (new_variable, symbol_to_complete + shift)

396 

397 

def complete_the_squares_in_exp(expr: sp.Expr, symbols_to_complete: Sequence[sp.Symbol]):
    """Completes squares in arguments of exponential which makes them simpler to integrate.

    In the argument of every ``sp.exp`` node, the square is completed for each symbol in turn
    (see ``complete_the_square``). Afterwards the shifted variable is renamed back to the
    original symbol, so the result is written in the original symbols.

    Very useful for integrating Maxwell-Boltzmann equilibria and its moment generating function
    """
    placeholders = [sp.Dummy() for _ in symbols_to_complete]

    def rewrite(node):
        if node.func == sp.exp:
            exponent = node.args[0]
            for symbol, placeholder in zip(symbols_to_complete, placeholders):
                exponent = complete_the_square(exponent, symbol, placeholder)[0]
            return sp.exp(sp.expand(exponent))
        children = [rewrite(child) for child in node.args]
        return node.func(*children) if children else node

    rewritten = rewrite(expr)
    for symbol, placeholder in zip(symbols_to_complete, placeholders):
        rewritten = rewritten.subs(placeholder, symbol)
    return rewritten

422 

423 

def extract_most_common_factor(term):
    """Split a sum into (common_factor, term / common_factor).

    The common factor is the coefficient magnitude that occurs most often among the
    summands. If no magnitude occurs more than once and 1 is among them, 1 is chosen,
    i.e. nothing is factored out.
    """
    magnitudes = Counter(Abs(coefficient) for coefficient in term.as_coefficients_dict().values())
    factor, count = max(magnitudes.items(), key=operator.itemgetter(1))
    if count == 1 and 1 in magnitudes:
        factor = 1
    return factor, term / factor

432 

433 

def recursive_collect(expr, symbols, order_by_occurences=False):
    """Apply sympy's collect recursively over a list of symbols.

    The expression is first collected with respect to the first symbol. Each resulting
    coefficient is then collected with respect to the remaining symbols, and so on.

    Args:
        expr: A sympy expression
        symbols: A sequence of symbols
        order_by_occurences: If True, only symbols that actually appear in the (sub)expression
                             are used, and at each level of the recursion the symbol occurring
                             most often is collected first.
    """
    if order_by_occurences:
        present = expr.atoms(sp.Symbol) & set(symbols)
        symbols = sorted(present, key=expr.count, reverse=True)
    if not symbols:
        return expr
    head, rest = symbols[0], symbols[1:]
    poly = sp.Poly(expr.collect(head), head)
    total = 0
    # coefficients are listed from the highest power down, so reverse them to pair with head**power
    for power, coefficient in enumerate(reversed(poly.all_coeffs())):
        total += head**power * recursive_collect(coefficient, rest, order_by_occurences)
    return total

454 

455 

def count_operations(term: Union[sp.Expr, List[sp.Expr]],
                     only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts the number of additions, multiplications and division.

    Args:
        term: a sympy expression (term, assignment) or sequence of sympy objects
        only_type: 'real' or 'int' to count only operations on these types, or None for all

    Returns:
        dict with the keys 'adds', 'muls', 'divs', 'sqrts', 'fast_sqrts', 'fast_inv_sqrts' and 'fast_div'
    """
    from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division

    result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
              'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
    if isinstance(term, Sequence):
        # accumulate the counts of all elements
        for element in term:
            r = count_operations(element, only_type)
            for operation_name in result.keys():
                result[operation_name] += r[operation_name]
        return result
    elif isinstance(term, Assignment):
        # only the right-hand side performs operations
        term = term.rhs

    def check_type(e):
        """Return True if operations on expression e should be counted for 'only_type'."""
        if only_type is None:
            return True
        if isinstance(e, FieldPointerSymbol) and only_type == "real":
            # field pointers are never counted as real-valued
            return False

        try:
            base_type = get_type_of_expression(e)
        except ValueError:
            return False
        if isinstance(base_type, VectorType):
            return False
        if isinstance(base_type, PointerType):
            return only_type == 'int'
        if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
            return True
        if only_type == 'real' and base_type.is_float():
            return True
        return base_type == only_type

    def visit(t):
        visit_children = True
        if t.func is sp.Add:
            if check_type(t):
                result['adds'] += len(t.args) - 1
        elif t.func in [sp.Or, sp.And]:
            pass
        elif t.func is sp.Mul:
            if check_type(t):
                result['muls'] += len(t.args) - 1
                # multiplications by +1/-1 are only sign changes and are not counted
                for a in t.args:
                    if a == 1 or a == -1:
                        result['muls'] -= 1
        elif isinstance(t, (sp.Float, sp.Rational)):
            pass
        elif isinstance(t, (sp.Symbol, sp.Indexed)):
            visit_children = False
        elif t.is_integer:
            pass
        elif isinstance(t, cast_func):
            # the cast itself is not counted, only the expression inside it
            visit_children = False
            visit(t.args[0])
        elif t.func is fast_sqrt:
            result['fast_sqrts'] += 1
        elif t.func is fast_inv_sqrt:
            result['fast_inv_sqrts'] += 1
        elif t.func is fast_division:
            result['fast_div'] += 1
        elif t.func is sp.Pow:
            # powers whose base is not of the counted type are skipped silently, like Add/Mul nodes
            # of another type (previously a misleading "only integer exponents" warning was emitted)
            if check_type(t.args[0]):
                if t.exp.is_integer and t.exp.is_number:
                    if t.exp >= 0:
                        # x**n takes n - 1 multiplications
                        result['muls'] += int(t.exp) - 1
                    else:
                        # heuristic: x**-n is a division; presumably the enclosing Mul (visited
                        # before its children) already counted this factor as a multiplication
                        if result['muls'] > 0:
                            result['muls'] -= 1
                        result['divs'] += 1
                        result['muls'] += (-int(t.exp)) - 1
                elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
                    result['sqrts'] += 1
                elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
                    result["sqrts"] += 1
                    result["divs"] += 1
                else:
                    warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
        elif t.func is sp.Piecewise:
            # only the values are counted, the conditions are ignored
            for child_term, _ in t.args:
                visit(child_term)
            visit_children = False
        elif isinstance(t, sp.Rel):
            pass
        else:
            warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")

        if visit_children:
            for a in t.args:
                visit(a)

    visit(term)
    return result

567 

568 

def count_operations_in_ast(ast, only_type: Optional[str] = 'real') -> Dict[str, int]:
    """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`

    The tree is traversed through the nodes' ``args``. For every SympyAssignment found,
    the operations of its right-hand side are added up.

    Args:
        ast: root node of the abstract syntax tree
        only_type: forwarded to :func:`count_operations` ('real', 'int' or None for all types)

    Returns:
        mapping from operation name to count (a defaultdict(int), so missing keys read as 0)
    """
    from pystencils.astnodes import SympyAssignment
    result = defaultdict(int)

    def visit(node):
        if isinstance(node, SympyAssignment):
            for operation_name, count in count_operations(node.rhs, only_type).items():
                result[operation_name] += count
        else:
            for arg in node.args:
                visit(arg)

    visit(ast)
    return result

584 

585 

def common_denominator(expr: sp.Expr) -> sp.Expr:
    """Return the least common multiple of the denominators of all rational numbers in ``expr``.

    Only numeric ``sp.Rational`` atoms are considered; symbolic denominators such as 1/x are ignored.
    """
    return sp.lcm([rational.q for rational in expr.atoms(sp.Rational)])

590 

591 

def get_symmetric_part(expr: sp.Expr, symbols: Iterable[sp.Symbol]) -> sp.Expr:
    r"""
    Returns the symmetric part of a sympy expression.

    All given symbols are negated at once via ``subs``, and the original and
    negated expressions are averaged.

    Args:
        expr: sympy expression, labeled here as :math:`f`
        symbols: sequence of symbols which are considered as degrees of freedom, labeled here as :math:`x_0, x_1,...`

    Returns:
        :math:`\frac{1}{2} [ f(x_0, x_1, ..) + f(-x_0, -x_1, ..) ]`
    """
    # raw docstring above: in a normal string "\f" of "\frac" would be turned into a form feed
    substitution_dict = {e: -e for e in symbols}
    return sp.Rational(1, 2) * (expr + expr.subs(substitution_dict))

605 

606 

class SymbolCreator:
    """Turns attribute access into sympy symbols: ``SymbolCreator().x`` equals ``sp.Symbol('x')``.

    Every attribute looked up on an instance is intercepted, including names such as
    ``__class__``, so instances have no usable regular attributes.
    """

    def __getattribute__(self, name):
        symbol = sp.Symbol(name)
        return symbol