"""Transformations using integer sets based on the ISL library."""

2 

3import islpy as isl 

4import sympy as sp 

5 

6import pystencils.astnodes as ast 

7from pystencils.transformations import parents_of_type 

8 

9 

def remove_brackets(s):
    """Return ``s`` with every ``[`` and ``]`` character dropped.

    All other characters are kept in their original order.
    """
    return ''.join(ch for ch in s if ch not in '[]')

12 

13 

def _degrees_of_freedom_as_string(expr):
    """Collect the names of the free quantities occurring in ``expr``.

    ``expr`` is sympified first. Every ``sp.Symbol`` atom is collected, except
    symbols that serve as the label of an indexed base; each ``sp.Indexed``
    access is collected as a whole instead. Names are returned as a set of
    strings with square brackets stripped (e.g. ``a[i]`` becomes ``ai``),
    presumably so they can be used as identifiers in ISL set descriptions.
    """
    sympified = sp.sympify(expr)
    indexed_accesses = sympified.atoms(sp.Indexed)
    # Labels of indexed bases are represented by their full Indexed access instead.
    base_labels = {access.base.args[0] for access in indexed_accesses}
    free_quantities = [sym for sym in sympified.atoms(sp.Symbol) if sym not in base_labels]
    free_quantities.extend(indexed_accesses)
    return {remove_brackets(str(quantity)) for quantity in free_quantities}

21 

22 

def isl_iteration_set(node: ast.Node):
    """Build an ISL set describing the iteration space enclosing ``node``.

    Every enclosing ``ast.LoopOverCoordinate`` (as found by ``parents_of_type``)
    contributes the constraint ``start <= counter < stop``. The loop counters and
    all quantities occurring in the loop bounds become the dimensions of the set.

    Args:
        node: AST node whose enclosing loops define the iteration space.

    Returns:
        Tuple ``(degrees_of_freedom, basic_set)``. ``degrees_of_freedom`` is the
        set of dimension names (brackets stripped). ``basic_set`` is an
        ``isl.BasicSet`` whose dimensions follow the iteration order of
        ``degrees_of_freedom``. If ``node`` has no enclosing loops, the set is
        zero-dimensional and unconstrained.

    Raises:
        NotImplementedError: if an enclosing loop has a step other than 1.
    """
    conditions = []
    degrees_of_freedom = set()

    for loop in parents_of_type(node, ast.LoopOverCoordinate):
        if loop.step != 1:
            raise NotImplementedError("Loops with strides != 1 are not yet supported.")

        for part in (loop.loop_counter_symbol, loop.start, loop.stop):
            degrees_of_freedom.update(_degrees_of_freedom_as_string(part))

        ctr_name = loop.loop_counter_name
        # Strip brackets once over the whole constraint so indexed accesses match
        # the bracket-free names collected in degrees_of_freedom.
        conditions.append(remove_brackets(
            f"{ctr_name} >= {loop.start} and {ctr_name} < {loop.stop}"))

    # Keep the set's own iteration order: callers join the returned set the same
    # way to build sets whose dimensions line up with this one.
    symbol_names = ','.join(degrees_of_freedom)
    if conditions:
        condition_str = ' and '.join(conditions)
        set_description = f"{{ [{symbol_names}] : {condition_str} }}"
    else:
        # No enclosing loops: an empty constraint after ':' is not a well-formed
        # ISL set description, so describe the unconstrained set instead.
        set_description = f"{{ [{symbol_names}] }}"
    return degrees_of_freedom, isl.BasicSet(set_description)

46 

47 

def simplify_loop_counter_dependent_conditional(conditional):
    """Remove a conditional whose outcome is fixed over the enclosing iteration space.

    Only conditions whose free quantities are all dimensions of the enclosing
    iteration set (loop counters and loop-bound quantities) are analysed.
    Otherwise the conditional is left untouched. The conditional is replaced:

    * by its false block if the condition is unsatisfiable, or never holds
      inside the iteration space;
    * by its true block if it holds everywhere in the iteration space.

    Args:
        conditional: conditional AST node. It must provide ``condition_expr``,
            ``replace_by_false_block()`` and ``replace_by_true_block()``.
    """
    dofs_in_condition = _degrees_of_freedom_as_string(conditional.condition_expr)
    dofs_in_loops, iteration_set = isl_iteration_set(conditional)
    if not dofs_in_condition.issubset(dofs_in_loops):
        # The condition depends on quantities the iteration set does not describe.
        return

    # Join the same set object isl_iteration_set joined, so both sets use the
    # same dimension order.
    symbol_names = ','.join(dofs_in_loops)
    condition_str = remove_brackets(str(conditional.condition_expr))
    condition_set = isl.BasicSet(f"{{ [{symbol_names}] : {condition_str} }}")

    if condition_set.is_empty():
        # Unsatisfiable condition. Stop here: the node has already been replaced,
        # and falling through would replace it a second time.
        conditional.replace_by_false_block()
        return

    intersection = iteration_set.intersect(condition_set)
    if intersection.is_empty():
        conditional.replace_by_false_block()
    elif intersection == iteration_set:
        conditional.replace_by_true_block()