Coverage for bim2sim/tasks/base.py: 73%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-12 17:09 +0000

1"""Module containing the ITask base class and Playground to execute ITasks. 

2 

3All Tasks should inherit from ITask 

4""" 

5from __future__ import annotations 

6 

7import inspect 

8import logging 

9from typing import Generator, Tuple, List, Type, TYPE_CHECKING 

10 

11from bim2sim.kernel import log 

12from bim2sim.kernel.decision import DecisionBunch 

13 

14if TYPE_CHECKING: 

15 from bim2sim import Project 

16 

17 

class TaskFailed(Exception):
    """Signal that a task could not be completed successfully."""

20 

21 

class ITask:
    """Baseclass for interactive Tasks.

    Attributes:
        reads: names of the arguments the run() method requires. The arguments
            are outputs from previous tasks.
        touches: names that are assigned to the return value tuple of method
            run().
        final: flag that indicates termination of project run after this task.
        single_use: flag that indicates that this task may be run only once
            in the same Playground. Set it to False to allow repeated runs.
    """

    reads: Tuple[str, ...] = tuple()
    touches: Tuple[str, ...] = tuple()
    final = False
    single_use = True

    def __init__(self, playground):
        """Create the task and bind it to the playground that executes it.

        Args:
            playground: the Playground instance this task belongs to.
        """
        self.name = self.__class__.__name__
        self.logger = log.get_user_logger("%s.%s" % (__name__, self.name))
        # paths and prj_name are filled in by Playground.run_task() right
        # before run() is called.
        self.paths = None
        self.prj_name = None
        self.playground = playground

    def run(self, **kwargs):
        """Run the task. Must be implemented by subclasses.

        Args:
            **kwargs: the values of the state entries named in ``reads``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def requirements_met(cls, state, history) -> bool:
        """Check if all requirements for this task are met.

        Args:
            state: state of playground
            history: history of playground (tasks that already ran)

        Returns:
            False if the task is single_use and an instance of exactly this
            class is already in the history, otherwise True only if every
            name in ``reads`` is present in ``state``.
        """
        if cls.single_use:
            # exact class match on purpose: subclasses count as other tasks
            for task in history:
                if task.__class__ is cls:
                    return False
        return all(r in state for r in cls.reads)

    def __repr__(self):
        return "<Task (%s)>" % self.name

70 

71 

class Playground:
    """Playground for executing ITasks.

    Keeps the shared ``state`` (results of finished tasks, keyed by the names
    in their ``touches``), the ``history`` of executed tasks and the current
    ``elements`` and ``graph`` of the run.
    """

    def __init__(self, project: Project):
        """Set up an empty playground for the given project.

        Args:
            project: project whose plugin provides the simulation settings
                and whose config is used to update them.
        """
        self.project = project
        self.sim_settings = project.plugin_cls.sim_settings()
        self.sim_settings.update_from_config(config=project.config)
        self.state = {}
        self.history = []
        self.elements = {}
        self.elements_updated = False
        self.graph = None
        self.graph_updated = False
        self.logger = logging.getLogger("bim2sim.Playground")

    @staticmethod
    def all_tasks() -> List[Type[ITask]]:
        """Returns list of all tasks (direct subclasses of ITask)."""
        return list(ITask.__subclasses__())  # TODO: from workflow?

    def available_tasks(self) -> List[Type[ITask]]:
        """Returns list of tasks whose requirements are currently met."""
        return [task for task in self.all_tasks()
                if task.requirements_met(self.state, self.history)]

    def _last_touch_index(self, touches, fragment: str, target: str) -> int:
        """Return the index of the last entry in touches containing fragment.

        Args:
            touches: the ``touches`` names of the finished task.
            fragment: substring to look for in each name.
            target: what gets updated, only used for the log message.
        """
        indices = [i for i, touch in enumerate(touches) if fragment in touch]
        if len(indices) > 1:
            self.logger.info(
                "Found more than one %s entry in touches"
                ", using the last one to update %s", fragment, target)
        return indices[-1]

    def run_task(self, task: ITask) -> Generator[DecisionBunch, None, None]:
        """Generator executing tasks with arguments specified in tasks.reads.

        Yields whatever ``task.run`` yields if it is a generator function.
        Afterwards the results are stored in ``state`` under the names given
        in ``task.touches`` and the task is appended to ``history``.

        Args:
            task: the task instance to execute.

        Raises:
            AssertionError: if the requirements of the task are not met.
            TaskFailed: if ``task.run`` raises, or if the number of returned
                items does not match the number of names in ``task.touches``.
        """
        if not task.requirements_met(self.state, self.history):
            raise AssertionError("%s requirements not met." % task)

        self.logger.info("Starting Task '%s'", task)
        read_state = {k: self.state[k] for k in task.reads}
        try:
            task.paths = self.project.paths
            task.prj_name = self.project.name
            if inspect.isgeneratorfunction(task.run):
                result = yield from task.run(**read_state)
            else:
                # no decisions
                result = task.run(**read_state)
        except Exception as ex:
            self.logger.exception("Task '%s' failed!", task)
            raise TaskFailed(str(task)) from ex
        else:
            self.logger.info("Successfully finished Task '%s'", task)

        if task.touches == '__reset__':
            # special case
            self.state.clear()
            self.history.clear()
        else:
            # normal case
            # Validate before indexing into the result, so a missing or
            # too short result is reported as TaskFailed and the playground
            # is left untouched.
            n_res = len(result) if result is not None else 0
            n_touches = len(task.touches)
            if n_touches != n_res:
                raise TaskFailed(
                    "Mismatch in '%s' result. Required items: %d (%s), "
                    "got %d. Please make sure that required"
                    " inputs (reads) are created in previous tasks."
                    % (task, n_touches, task.touches, n_res))

            # update elements in playground based on tasks results
            if 'elements' in task.touches:
                index = self._last_touch_index(
                    task.touches, 'element', 'elements')
                self.elements = result[index]
                self.elements_updated = True
                self.logger.info("Updated elements based on tasks results.")

            # update graph in playground based on tasks results
            if 'graph' in task.touches:
                index = self._last_touch_index(
                    task.touches, 'graph', 'graph')
                self.graph = result[index]
                self.graph_updated = True
                self.logger.info("Updated graph based on tasks results.")

            # assign results to state
            if n_res:
                for key, sub_state in zip(task.touches, result):
                    self.state[key] = sub_state

        self.history.append(task)
        self.logger.info("%s done", task)

    def update_elements(self, elements):
        """Updates the elements of the current run.

        This only has to be done if you want to update elements manually,
        if a task touches elements, they will be updated automatically after
        the task is finished.
        """
        self.elements = elements
        self.elements_updated = True
        self.logger.info("Updated elements based on tasks results.")

    def update_graph(self, graph):
        """Updates the graph of the current run.

        This only has to be done if you want to update the graph manually,
        if a task touches graph, it will be updated automatically after
        the task is finished.
        """
        self.graph = graph
        self.graph_updated = True
        self.logger.info("Updated graph based on tasks results.")