Coverage for bim2sim/tasks/base.py: 73%
101 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
"""Module containing the ITask base class and a Playground to execute ITasks.

All tasks should inherit from ITask.
"""
5from __future__ import annotations
7import inspect
8import logging
9from typing import Generator, Tuple, List, Type, TYPE_CHECKING
11from bim2sim.kernel import log
12from bim2sim.kernel.decision import DecisionBunch
14if TYPE_CHECKING:
15 from bim2sim import Project
class TaskFailed(Exception):
    """Raised by the Playground when a task cannot be completed.

    Used both when a task's ``run`` raises and when the values it returns
    do not match its declared ``touches``.
    """
class ITask:
    """Base class for interactive tasks.

    Subclasses implement :meth:`run` and declare their data dependencies
    through the class attributes below.

    Attributes:
        reads: names of the arguments :meth:`run` requires. The values are
            taken from the playground state, i.e. they are outputs of
            previously executed tasks.
        touches: names under which the items of the tuple returned by
            :meth:`run` are stored in the playground state (one name per
            returned item). The Playground treats the string ``'__reset__'``
            specially and clears its state and history instead.
        final: flag that indicates termination of the project run after this
            task.
        single_use: if True (the default), the task may run at most once per
            playground; :meth:`requirements_met` returns False once an
            instance of the class is in the history.
    """

    reads: Tuple[str, ...] = ()
    touches: Tuple[str, ...] = ()
    final = False
    single_use = True

    def __init__(self, playground):
        self.name = self.__class__.__name__
        self.logger = log.get_user_logger(f"{__name__}.{self.name}")
        # Set by the playground right before run() is called.
        self.paths = None
        self.prj_name = None
        self.playground = playground

    def run(self, **kwargs):
        """Run the task.

        Called with one keyword argument per name in ``reads``. Must return
        one value per name in ``touches``. It may be a generator function
        that yields decisions. Subclasses must override this.
        """
        raise NotImplementedError

    @classmethod
    def requirements_met(cls, state, history) -> bool:
        """Check whether this task can be executed now.

        Args:
            state: state of the playground (mapping of names to values).
            history: tasks already executed in the playground.

        Returns:
            False if the task is single-use and an instance of exactly this
            class is already in ``history``. Otherwise, True if every name in
            ``reads`` is present in ``state``.
        """
        if cls.single_use and any(task.__class__ is cls for task in history):
            return False
        return all(name in state for name in cls.reads)

    def __repr__(self):
        return f"<Task ({self.name})>"
class Playground:
    """Playground for executing ITasks.

    Holds the state that tasks read from and write to, the history of
    executed tasks, and the current elements and graph of the run.
    """

    def __init__(self, project: Project):
        self.project = project
        # Settings of the project's plugin, updated from the project config.
        self.sim_settings = project.plugin_cls.sim_settings()
        self.sim_settings.update_from_config(config=project.config)
        self.state = {}
        self.history = []
        self.elements = {}
        self.elements_updated = False
        self.graph = None
        self.graph_updated = False
        self.logger = logging.getLogger("bim2sim.Playground")

    @staticmethod
    def all_tasks() -> List[Type[ITask]]:
        """Return all direct subclasses of ITask."""
        return list(ITask.__subclasses__())  # TODO: from workflow?

    def available_tasks(self) -> List[Type[ITask]]:
        """Return the tasks whose requirements are met in the current state."""
        return [task for task in self.all_tasks()
                if task.requirements_met(self.state, self.history)]

    def run_task(self, task: ITask) -> Generator[DecisionBunch, None, None]:
        """Execute *task* and merge its results into the playground.

        ``task.run`` is called with the state entries named in ``task.reads``.
        If ``run`` is a generator function, everything it yields is passed
        through to the consumer of this generator.

        Args:
            task: task instance to execute.

        Raises:
            AssertionError: if the task's requirements are not met.
            TaskFailed: if ``task.run`` raises, or if the number of returned
                items does not match ``task.touches``.
        """
        if not task.requirements_met(self.state, self.history):
            raise AssertionError("%s requirements not met." % task)

        self.logger.info("Starting Task '%s'", task)
        read_state = {k: self.state[k] for k in task.reads}
        try:
            task.paths = self.project.paths
            task.prj_name = self.project.name
            if inspect.isgeneratorfunction(task.run):
                result = yield from task.run(**read_state)
            else:
                # no decisions
                result = task.run(**read_state)
        except Exception as ex:
            self.logger.exception("Task '%s' failed!", task)
            raise TaskFailed(str(task)) from ex
        self.logger.info("Successfully finished Task '%s'", task)

        if task.touches == '__reset__':
            # special case: drop everything produced so far
            self.state.clear()
            self.history.clear()
        else:
            # normal case
            self._store_results(task, result)

        self.history.append(task)
        self.logger.info("%s done", task)

    def _store_results(self, task: ITask, result) -> None:
        """Validate *result* against ``task.touches`` and publish it.

        Writes each returned item into the state under its name in
        ``task.touches``. Then mirrors the ``elements`` and ``graph``
        entries into the corresponding playground attributes.

        Raises:
            TaskFailed: if the number of returned items differs from the
                number of names in ``task.touches``.
        """
        touches = task.touches
        n_res = len(result) if result is not None else 0
        # Validate before touching any playground attribute so a malformed
        # result never leaves the playground partially updated.
        if len(touches) != n_res:
            raise TaskFailed(
                "Mismatch in '%s' result. Required items: %d (%s), returned "
                "items: %d. Please make sure that run() returns one value per "
                "name in touches." % (task, len(touches), touches, n_res))

        # assign results to state
        if n_res:
            for key, sub_state in zip(touches, result):
                self.state[key] = sub_state

        # Mirror from the state (exact name match) so the attributes always
        # equal state['elements'] / state['graph']; for duplicated names the
        # last item wins, just like in the loop above.
        for key, update in (('elements', self.update_elements),
                            ('graph', self.update_graph)):
            if key not in touches:
                continue
            if list(touches).count(key) > 1:
                self.logger.info(
                    "Found more than one %s entry in touches, using the last "
                    "one to update %s", key, key)
            update(self.state[key])

    def update_elements(self, elements):
        """Replace the elements of the current run.

        Calling this is only necessary to update elements manually. When a
        task touches ``elements``, they are updated automatically after the
        task has finished.
        """
        self.elements = elements
        self.elements_updated = True
        self.logger.info("Updated elements based on tasks results.")

    def update_graph(self, graph):
        """Replace the graph of the current run.

        Calling this is only necessary to update the graph manually. When a
        task touches ``graph``, it is updated automatically after the task
        has finished.
        """
        self.graph = graph
        self.graph_updated = True
        self.logger.info("Updated graph based on tasks results.")