1import socket 

2import time 

3from typing import Dict, Iterator, Sequence 

4 

5import blitzdb 

6 

7from pystencils.cpu.cpujit import get_compiler_config 

8 

9 

class Database:
    """NoSQL database for storing simulation results.

    Two backends are supported:
        * `blitzdb`: simple file-based solution, similar to what sqlite is for SQL databases; stores json files.
          No server setup is required, but it is slow for larger collections.
        * `mongodb`: mongodb backend via `pymongo`

    A simulation result is stored as an object consisting of
        * params: dict with simulation parameters
        * result: dict with results
        * env: information about the machine, compiler configuration and time

    Args:
        file: database identifier. For blitzdb pass a directory name here; the database folder is created if it
              doesn't exist yet. For larger collections use mongodb: pass "mongo://<database name>",
              e.g. "mongo://simulations". Everything after "mongo://" is used as the database *name* on a
              default-constructed ``pymongo.MongoClient()`` (pymongo's default host and port); a host or port
              cannot be given through this string.

    Example:
        >>> from tempfile import TemporaryDirectory
        >>> with TemporaryDirectory() as tmp_dir:
        ...     db = Database(tmp_dir)  # create database in temporary folder
        ...     params = {'method': 'finite_diff', 'dx': 1.5}  # some hypothetical simulation parameters
        ...     db.save(params, result={'error': 1e-6})  # save simulation parameters together with hypothetical results
        ...     assert db.was_already_simulated(params)  # search for parameters in database
        ...     assert next(db.filter_params(params))['params'] == params  # get data set, keys are 'params', 'result'
        ...     # and 'env'
        ...     # get a pandas object with all results matching a query
        ...     df = db.to_pandas({'dx': 1.5}, remove_prefix=True)
        ...     # order columns alphabetically (just for doctest output)
        ...     df.reindex(sorted(df.columns), axis=1)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
                                                   dx     error       method
        pk
        ...                                       1.5  0.000001  finite_diff
    """

    class SimulationResult(blitzdb.Document):
        # blitzdb document type under which every simulation result of this database is stored.
        pass

    def __init__(self, file: str) -> None:
        """Opens the database identified by ``file`` (see the class docstring for the accepted formats)."""
        if file.startswith("mongo://"):
            # Imported lazily so that pymongo is only required when the mongodb backend is used.
            from pymongo import MongoClient
            db_name = file[len("mongo://"):]
            # The client is default-constructed: the text after "mongo://" selects the database name only,
            # it is not interpreted as host/port.
            client = MongoClient()
            self.backend = blitzdb.MongoBackend(client[db_name])
        else:
            self.backend = blitzdb.FileBackend(file)

        self.backend.autocommit = True

    def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
        """Stores a simulation result in the database.

        Args:
            params: dict of simulation parameters
            result: dict of simulation results
            env: optional environment - if None, a default environment (see `get_environment`) with compiler
                 configuration, machine info and time is used. An explicitly passed (even empty) dict is stored
                 as given.
            **kwargs: additional top-level entries of the stored object. They are applied last and can therefore
                      also override 'params', 'result' or 'env'.
        """
        document_dict = {
            'params': params,
            'result': result,
            # Explicit None check so that an intentionally empty env dict is not replaced by the default.
            'env': env if env is not None else self.get_environment(),
        }
        document_dict.update(kwargs)
        document = Database.SimulationResult(document_dict, backend=self.backend)
        document.save()
        # Explicit commit in addition to the autocommit flag that __init__ enables.
        self.backend.commit()

    def filter_params(self, parameter_query: Dict, *args, **kwargs) -> Iterator['Database.SimulationResult']:
        """Query using simulation parameters.

        See blitzdb documentation for filter

        Args:
            parameter_query: blitzdb filter dict using only simulation parameters (keys without "params." prefix)
            *args: arguments passed to blitzdb filter
            **kwargs: arguments passed to blitzdb filter

        Returns:
            generator of SimulationResult, which is a dict-like object with keys 'params', 'result' and 'env'
        """
        # Each parameter key is addressed as a field of the stored 'params' entry, e.g. 'dx' -> 'params.dx'.
        query = {'params.' + k: v for k, v in parameter_query.items()}
        return self.filter(query, *args, **kwargs)

    def filter(self, *args, **kwargs):
        """blitzdb filter on SimulationResult, not only simulation parameters.

        Can be used to filter for results or environment options.
        The filter dictionary has to have prefixes "params." , "env." or "result."
        """
        return self.backend.filter(Database.SimulationResult, *args, **kwargs)

    def was_already_simulated(self, parameters):
        """Checks if there is at least one simulation result matching the passed parameters.

        Unlike `filter_params`, this filters on the complete 'params' entry rather than on the individual
        'params.<key>' fields.
        """
        # NOTE(review): how a dict value is matched is backend-defined (MongoDB, for instance, compares embedded
        # documents exactly, including key order) - confirm this is the intended semantics.
        return len(self.filter({'params': parameters})) > 0

    # Columns with these prefixes are not included in pandas result
    pandas_columns_to_ignore = ['changedParams.', 'env.']

    def to_pandas(self, parameter_query, remove_prefix=True, drop_constant_columns=False):
        """Queries for simulations with given parameters and returns them in a pandas data frame.

        Args:
            parameter_query: dict of simulation parameters, see `filter_params`
            remove_prefix: if True the names of the pandas columns are not prefixed with "params." or "result."
            drop_constant_columns: if True, all columns are dropped that have the same value in all rows

        Returns:
            pandas data frame indexed by the document key 'pk', or None if no simulation result matches the query
        """
        from pandas import json_normalize

        query_result = self.filter_params(parameter_query)
        attributes = [e.attributes for e in query_result]
        if not attributes:
            return None
        # Nested dicts are flattened into dotted column names, e.g. 'params.dx' or 'result.error'.
        df = json_normalize(attributes)
        df.set_index('pk', inplace=True)

        if self.pandas_columns_to_ignore:
            remove_columns_by_prefix(df, self.pandas_columns_to_ignore, inplace=True)
        if remove_prefix:
            # NOTE(review): a parameter and a result with the same name end up as duplicate column names here.
            remove_prefix_in_column_name(df, inplace=True)
        if drop_constant_columns:
            df, _ = remove_constant_columns(df)

        return df

    @staticmethod
    def get_environment():
        """Returns a dict describing the environment a simulation was run in.

        Keys: 'timestamp' (POSIX time in seconds), 'hostname', 'cpuCompilerConfig' (pystencils CPU compiler
        configuration) and - if GitPython is installed and a git repository with at least one commit is found
        from the current directory - 'git_hash' with the commit HEAD points to.
        """
        result = {
            # time.time() is already the POSIX timestamp. time.mktime(time.gmtime()) would be wrong, because mktime
            # interprets the UTC struct_time as local time and shifts the result by the machine's UTC offset.
            'timestamp': time.time(),
            'hostname': socket.gethostname(),
            'cpuCompilerConfig': get_compiler_config(),
        }
        try:
            from git import Repo, InvalidGitRepositoryError
        except ImportError:
            # GitPython is optional; without it no git information is recorded.
            return result

        try:
            repo = Repo(search_parent_directories=True)
            result['git_hash'] = str(repo.head.commit)
        except (InvalidGitRepositoryError, ValueError):
            # Not inside a git repository, or the repository has no commit yet (HEAD cannot be resolved).
            pass

        return result

156 

157# ----------------------------------------- Helper Functions ----------------------------------------------------------- 

158 

159 

def remove_constant_columns(df):
    """Removes all columns of a pandas data frame that have the same value in all rows.

    Missing values count as a value of their own, so a column containing both NaN and a number is not constant.

    Args:
        df: pandas data frame

    Returns:
        tuple (remaining_df, constants): the data frame without the constant columns, and a Series mapping each
        removed column name to its constant value (NaN for every column if ``df`` has no rows)
    """
    import pandas as pd

    # Count NaN as a distinct value; otherwise a column that is only set in some rows would be treated
    # as constant, dropped, and reported with whatever value its first row happens to hold.
    is_constant = df.nunique(dropna=False) <= 1
    remaining_df = df.loc[:, ~is_constant]
    constant_df = df.loc[:, is_constant]

    if len(df) == 0:
        # No row to take the constant values from.
        constants = pd.Series(index=constant_df.columns, dtype=object)
    else:
        constants = constant_df.iloc[0]
    return remaining_df, constants

166 

167 

def remove_columns_by_prefix(df, prefixes: Sequence[str], inplace: bool = False):
    """Remove all columns from a pandas data frame whose name starts with one of the given prefixes.

    Args:
        df: pandas data frame
        prefixes: column name prefixes to remove; a single string is treated as one prefix
        inplace: if True, ``df`` itself is modified, otherwise a modified copy is returned

    Returns:
        the data frame without the matching columns (``df`` itself if ``inplace`` is True)
    """
    if not inplace:
        df = df.copy()

    # A bare string is itself a Sequence[str]; iterating it would turn every character into a prefix.
    prefix_tuple = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes)
    # Collect first and drop once, so a column matching several prefixes is removed exactly once.
    # Non-string column names cannot carry a prefix and are kept.
    columns_to_drop = [name for name in df.columns
                       if isinstance(name, str) and name.startswith(prefix_tuple)]
    if columns_to_drop:
        df.drop(columns=columns_to_drop, inplace=True)
    return df

178 

179 

def remove_prefix_in_column_name(df, inplace: bool = False):
    """Strips the first dotted component from every pandas column name.

    Everything up to and including the first dot is removed, e.g. 'result.finite_diff.dx' becomes
    'finite_diff.dx'. Column names without a dot are left unchanged.

    Args:
        df: pandas data frame
        inplace: if True the columns of ``df`` itself are renamed, otherwise a renamed copy is returned

    Returns:
        the data frame with the shortened column names (``df`` itself if ``inplace`` is True)
    """
    target = df if inplace else df.copy()

    def _strip_first_component(name):
        # Split only at the first dot so that deeper nesting ('a.b.c' -> 'b.c') is preserved.
        return name.split('.', 1)[1] if '.' in name else name

    target.columns = [_strip_first_component(name) for name in target.columns]
    return target