Coverage for bim2sim/elements/graphs/hvac_graph.py: 54%
426 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
1""" This module represents the elements of a HVAC system in form of a network
2graph where each node represents a hvac-component.
3"""
4from __future__ import annotations
6import itertools
7import logging
8import os
9import shutil
10from pathlib import Path
11from typing import Set, Iterable, Type, List, Union
12import json
14import networkx as nx
15from networkx import json_graph
17from bim2sim.elements.base_elements import ProductBased, ElementEncoder
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class HvacGraph(nx.Graph):
    """HVAC related graph manipulations based on ports.

    Nodes of this graph are the ports of HVAC elements. Edges are either
    connections between ports (``port.connection``) or inner connections of
    an element (``element.inner_connections``). An element based view of the
    graph is available via :attr:`element_graph`.
    """
    # TODO 246 HvacGraph init should only be called once based on IFC as it
    #  works with port.connection and therefore is not reliable after changes
    #  are made to the graph

    def __init__(self, elements=None, **attr):
        """Create the graph and populate it from ``elements`` if given.

        Args:
            elements: iterable of elements whose connected ports, port
                connections and inner connections are added to the graph.
            **attr: graph attributes passed on to ``nx.Graph``.
        """
        super().__init__(incoming_graph_data=None, **attr)
        if elements:
            self._update_from_elements(elements)

    def _update_from_elements(self, elements):
        """Update graph based on ports of elements.

        Every port with a ``connection`` becomes a node. Edges are the pairs
        (port, connected port) plus all inner connections of the elements.

        Args:
            elements: iterable of elements providing ``ports`` and
                ``inner_connections``.
        """
        # Materialize once. The elements are iterated twice below, which
        # would silently drop the inner connections for one-shot iterators.
        elements = list(elements)
        nodes = [port for element in elements for port in element.ports
                 if port.connection]
        inner_edges = [connection for element in elements
                       for connection in element.inner_connections]
        # Every port in ``nodes`` is known to have a connection.
        edges = [(port, port.connection) for port in nodes]
        self.update(nodes=nodes, edges=edges + inner_edges)

    @staticmethod
    def _contract_ports_into_elements(graph, port_nodes):
        """Contract the port nodes into their parent element nodes.

        This gives easier handling. The information about the ports is
        still accessible via the get_contractions function (the
        ``'contraction'`` node attribute written by ``nx.contracted_nodes``).

        Args:
            graph: graph holding the port nodes.
            port_nodes: ports to merge into their ``parent``.

        Returns:
            A new graph with the ports contracted. ``graph`` is not modified.
        """
        new_graph = graph.copy()
        logger.info("Contracting ports into elements ...")
        for port in port_nodes:
            new_graph = nx.contracted_nodes(new_graph, port.parent, port)
        logger.info("Contracted the ports into node elements, this"
                    " leads to %d nodes.",
                    new_graph.number_of_nodes())
        # Return the contracted graph. Previously the unmodified input graph
        # was returned and all the work was discarded.
        return new_graph

    @property
    def element_graph(self) -> nx.Graph:
        """Graph with elements instead of ports.

        Returns a new ``nx.Graph``. Its nodes are the parents of the port
        nodes. Its edges connect the parents of linked ports; edges between
        ports of the same element are omitted.
        """
        graph = nx.Graph()
        nodes = {port.parent for port in self.nodes if port}
        edges = {(port_a.parent, port_b.parent)
                 for port_a, port_b in self.edges
                 if port_a.parent is not port_b.parent}
        graph.update(nodes=nodes, edges=edges)
        return graph

    @property
    def elements(self):
        """List of elements present in graph (parents of the port nodes)."""
        return list({port.parent for port in self.nodes if port})

    @staticmethod
    def get_not_contracted_neighbors(graph, node):
        """Return the neighbours of ``node`` that were not contracted into it.

        Args:
            graph: graph containing ``node``.
            node: node whose neighbours are requested.

        Returns:
            List of neighbours, excluding ``node`` itself and all nodes
            stored in its ``'contraction'`` attribute (as written by
            ``nx.contracted_nodes``). The legacy key ``'contracted_nodes'``
            is honoured as well.
        """
        # ``graph.node`` was removed in networkx 2.4; ``graph.nodes`` is
        # the supported accessor.
        node_attrs = graph.nodes[node]
        contracted = set(node_attrs.get('contraction', ())) | set(
            node_attrs.get('contracted_nodes', ()))
        return list(set(nx.all_neighbors(graph, node)) - contracted - {node})

    def get_contractions(self, node):
        """Return the nodes that were contracted into the passed node.

        Args:
            node: node in whose connections you are interested.

        Returns:
            List of contracted nodes (the keys of the node's
            ``'contraction'`` attribute). Empty if there are none.
        """
        return list(self.nodes[node].get('contraction', {}))

    def get_cycles(self):
        """Find cycles in the graph.

        Returns:
            cycles: cycles of ``nx.cycle_basis`` (lists of port nodes) that
                span more than one parent element. Cycles built only from
                ports of a single element are dropped.
        """
        logger.info("Searching for cycles in hvac network ...")
        base_cycles = nx.cycle_basis(self)
        cycles = [cycle for cycle in base_cycles
                  if len({port.parent for port in cycle}) > 1]
        logger.info("Found %d cycles", len(cycles))
        return cycles

    # TODO #246 delete because not needed anymore
    @staticmethod
    def get_type_chains(
            element_graph: nx.Graph,
            types: Iterable[Type[ProductBased]],
            include_singles: bool = False):
        """Get lists of consecutive elements of the given types.

        Elements are ordered in the same way as they are connected.

        Args:
            element_graph: Graph object with elements as nodes.
            types: Items the chains are built of.
            include_singles: If True, connected components that do not form
                a chain with exactly two ends are added as unordered node
                lists. This covers a single element or a ring.

        Returns:
            chain_lists: Lists of consecutive elements.
        """
        nodes_degree2 = [v for v, d in element_graph.degree()
                         if 1 <= d <= 2 and type(v) in types]
        subgraph_aggregations = nx.subgraph(element_graph, nodes_degree2)

        chain_lists = []
        # order elements as connected
        for component in nx.connected_components(subgraph_aggregations):
            subgraph = subgraph_aggregations.subgraph(component).copy()
            end_nodes = [v for v, d in subgraph.degree() if d == 1]

            if len(end_nodes) != 2:
                if include_singles:
                    chain_lists.append(list(subgraph.nodes))
                continue
            # TODO more efficient
            chain_lists.append(nx.shortest_path(subgraph, *end_nodes))

        return chain_lists

    def merge(self, mapping: dict, inner_connections: list,
              add_connections=None):
        """Merge port nodes in graph.

        According to ``mapping``, port nodes are either removed
        (``{port: None}``) or replaced (``{port: new_port}``). Replaced
        nodes keep their connections. Also adds the inner connections and,
        if passed, additional connections.

        WARNING: connections from removed port nodes are also removed

        Args:
            mapping: replacement dict. Ports as keys and replacement ports
                or None as values.
            inner_connections: connections to add.
            add_connections: additional connections to add.
        """
        replace = {k: v for k, v in mapping.items() if v is not None}
        remove = [k for k, v in mapping.items() if v is None]

        nx.relabel_nodes(self, replace, copy=False)
        self.remove_nodes_from(remove)
        self.add_edges_from(inner_connections)
        if add_connections:
            self.add_edges_from(add_connections)

    def get_connections(self):
        """Returns connections between different parent elements."""
        return [edge for edge in self.edges
                if edge[0].parent is not edge[1].parent]

    def plot(self, path: Path = None, ports: bool = False, dpi: int = 400,
             use_pyvis=False):
        """Plot graph and either display or save it.

        Args:
            path: If provided, the graph is saved into this directory as a
                pdf file, or as an html file if use_pyvis=True.
            ports: If True, the port graph is plotted, else the element graph.
            dpi: dots per inch, increase for higher quality (takes longer to
                render). Only used for the matplotlib plot.
            use_pyvis: exports graph to interactive html
        """
        # https://plot.ly/python/network-graphs/
        edge_colors_flow_side = {
            1: dict(edge_color='red'),
            -1: dict(edge_color='blue'),
            0: dict(edge_color='grey'),
            None: dict(edge_color='grey'),
        }
        node_colors_flow_direction = {
            1: dict(node_color='white', edgecolors='blue'),
            -1: dict(node_color='blue', edgecolors='black'),
            0: dict(node_color='grey', edgecolors='black'),
            None: dict(node_color='grey', edgecolors='black'),
        }

        kwargs = {}
        if ports:
            # set port (nodes) colors based on flow direction
            graph = self
            kwargs['node_color'] = [
                node_colors_flow_direction[port.flow_direction]['node_color']
                for port in self]
            kwargs['edgecolors'] = [
                node_colors_flow_direction[port.flow_direction]['edgecolors']
                for port in self]
            kwargs['edge_color'] = 'grey'
        else:
            kwargs['node_color'] = 'blue'
            kwargs['edgecolors'] = 'black'
            # set connection colors (edges) based on flow side
            graph = self.element_graph
            edge_color_map = []
            for edge in graph.edges:
                sides0 = {port.flow_side for port in edge[0].ports}
                sides1 = {port.flow_side for port in edge[1].ports}
                side = None
                # An element with multiple sides is usually a consumer /
                # generator (or the result of conflicts). Hence the side of
                # the definite element is used.
                if len(sides0) == 1:
                    side = sides0.pop()
                elif len(sides1) == 1:
                    side = sides1.pop()
                edge_color_map.append(
                    edge_colors_flow_side[side]['edge_color'])
            kwargs['edge_color'] = edge_color_map

        if use_pyvis:
            # pyvis is optional; only import it when it is actually used
            from pyvis.network import Network

            # convert all nodes to strings to use dynamic plotting via pyvis
            graph_cp = graph.copy()
            nodes = graph_cp.nodes()
            # use guid because str must be unique to prevent overrides
            replace = {node: str(node) + ' ' + str(node.guid)
                       for node in nodes}

            # Todo Remove temp code. This is for the final report
            #  (Abschlussbericht) plotting only!
            # start of temp plotting code
            bypass_nodes_guids = []
            small_pump_guids = []
            parallel_pump_guids = []
            for node in nodes:
                try:
                    if node.length.m == 34:
                        bypass_nodes_guids.append(node.guid)
                except AttributeError:
                    pass
                try:
                    if node.rated_power.m == 0.6:
                        small_pump_guids.append(node.guid)
                except AttributeError:
                    pass
                try:
                    if node.rated_power.m == 1:
                        parallel_pump_guids.append(node.guid)
                except AttributeError:
                    pass
            # end of temp plotting code

            nx.relabel_nodes(graph_cp, replace, copy=False)
            net = Network(height='1000', width='1000', notebook=False,
                          bgcolor='white', font_color='black', layout=False)
            net.barnes_hut(gravity=-17000, spring_length=55)
            pyvis_json = Path(__file__).parent.parent.parent / \
                'assets/configs/pyvis/pyvis_options.json'
            # context manager closes the options file (it was leaked before)
            with open(pyvis_json) as options_file:
                net.options = json.load(options_file)

            net.from_nx(graph_cp, default_node_size=50)
            yellow_keys = ('parallelpump', 'boiler', 'generatoronefluid',
                           'heatpump')
            for node in net.nodes:
                try:
                    node['label'] = node['label'].split('<')[1]
                except IndexError:
                    # label contains no '<': keep it unchanged
                    pass
                node['label'] = node['label'].split('(ports')[0]
                if 'agg' in node['label'].lower():
                    node['label'] = node['label'].split('Agg0')[0]
                # Evaluate the colour rules on the shortened label. Later
                # rules override earlier ones.
                label = node['label'].lower()
                if 'storage' in label:
                    node['color'] = 'purple'
                if 'distributor' in label:
                    node['color'] = 'gray'
                if 'pump' in label:
                    node['color'] = 'blue'
                if 'spaceheater' in label:
                    node['color'] = 'purple'
                if 'pipestrand' in label:
                    node['color'] = 'blue'
                if any(key in label for key in yellow_keys):
                    node['color'] = 'yellow'
                # bypass color for parallelpump test
                guid = node['id'].split('> ')[-1]
                if guid in bypass_nodes_guids:
                    node['color'] = 'green'
                if guid in small_pump_guids:
                    node['color'] = 'purple'
                if guid in parallel_pump_guids:
                    node['color'] = 'red'

            if path:
                name = "%s_graph_pyvis.html" % ("port" if ports else "element")
                try:
                    net.save_graph(name)
                    shutil.move(name, path)
                except Exception as ex:
                    logger.error("Unable to save plot of graph (%s)", ex)
            else:
                name = "graph.html"
                try:
                    net.show(name)
                except Exception as ex:
                    logger.error("Unable to show plot of graph (%s)", ex)
        else:
            # importing matplotlib is slow and plotting is optional
            import matplotlib.pyplot as plt

            fig = plt.figure(dpi=dpi)
            try:
                nx.draw(graph, node_size=10, font_size=5, linewidths=0.5,
                        alpha=0.7, with_labels=True, **kwargs)
                plt.draw()
                if path:
                    name = "%s_graph.pdf" % ("port" if ports else "element")
                    try:
                        plt.savefig(
                            os.path.join(path, name),
                            bbox_inches='tight')
                    except IOError as ex:
                        logger.error("Unable to save plot of graph (%s)", ex)
                else:
                    plt.show()
            finally:
                # release the figure instead of only clearing it
                plt.close(fig)

    def dump_to_cytoscape_json(self, path: Path, ports: bool = True):
        """Dumps the current state of the graph to a json file in cytoscape
        format.

        Args:
            path: directory (``Path`` or str) where the JSON file is written.
            ports: if True the ports graph will be serialized, else the
                element_graph.
        """
        if ports:
            export_graph = self
            name = 'port_graph_cytoscape.json'
        else:
            export_graph = self.element_graph
            name = 'element_graph_cytoscape.json'
        with open(Path(path) / name, 'w') as fp:
            json.dump(json_graph.cytoscape_data(export_graph), fp,
                      cls=ElementEncoder)

    def to_serializable(self):
        """Returns a json serializable object (networkx adjacency data)."""
        return json_graph.adjacency_data(self)

    @classmethod
    def from_serialized(cls, data):
        """Creates a graph from serialized (adjacency) data.

        Args:
            data: adjacency data as returned by :meth:`to_serializable`.

        Returns:
            New instance of ``cls`` holding the nodes, edges and graph
            attributes of ``data``.
        """
        # Passing the graph to ``cls()`` would treat it as ``elements`` and
        # access ``.ports`` on every node. Copy the deserialized graph into
        # an empty instance instead.
        graph = cls()
        graph.update(json_graph.adjacency_graph(data))
        return graph

    @staticmethod
    def remove_not_wanted_nodes(
            graph: nx.Graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None):
        """Removes not wanted and not inert nodes from the given graph.

        Args:
            graph: element_graph
            wanted: set of all elements that are wanted and should persist in
                graph
            inert: set all inert elements. Are treated the same as wanted.

        Returns:
            A copy of ``graph`` containing only nodes whose exact type is in
            ``wanted`` or ``inert``.

        Raises:
            AssertionError: if a given class is no ProductBased subclass.
        """
        if inert is None:
            inert = set()
        keep = wanted | inert
        if not all(issubclass(item, ProductBased) for item in keep):
            raise AssertionError("Invalid type")
        _graph = graph.copy()
        # remove blocking nodes
        _graph.remove_nodes_from(
            {node for node in _graph.nodes if type(node) not in keep})
        return _graph

    @staticmethod
    def find_bypasses_in_cycle(graph: nx.Graph, cycle, wanted):
        """Detects bypasses in the given cycle of the given graph.

        Bypasses are any direct connections between edge elements which don't
        hold wanted elements.

        Args:
            graph: The graph in which the cycle belongs.
            cycle: A list of nodes representing a cycle in the graph.
            wanted: A collection of classes of the desired node type.

        Returns:
            List: A list of bypasses, where each bypass is a list of elements
            in the bypass. Empty if the cycle holds two or more wanted
            elements.
        """
        bypasses = []
        # get wanted guids in the cycle
        wanted_guids_cycle = {node.guid for node in cycle
                              if type(node) in wanted}

        # two or more wanted elements form a parallel connection, not a bypass
        if len(wanted_guids_cycle) >= 2:
            return bypasses

        # edge elements are elements with more than two ports (junctions)
        edge_elements = [node for node in cycle if len(node.ports) > 2]

        # get direct connections between the edges
        subgraph = graph.subgraph(cycle)
        dir_connections = HvacGraph.get_dir_paths_between(
            subgraph, edge_elements)

        # strands without wanted items are bypasses
        bypasses.extend(
            dir_connection for dir_connection in dir_connections
            if not any(type(node) in wanted for node in dir_connection))
        return bypasses

    @staticmethod
    def get_all_cycles_with_wanted(graph, wanted):
        """Return the cycles that contain wanted elements, per wanted element.

        Args:
            graph: undirected graph to search.
            wanted: classes of the wanted elements.

        Returns:
            dict mapping every node of a wanted type to the list of cycles
            (with more than two nodes) that contain it. Each cycle is a list
            of nodes sorted by guid in descending order, not in cycle order.
        """
        # todo how to handle cascaded boilers
        directed = graph.to_directed()
        # filter cycles: more than two nodes and at least one wanted node
        cycles = [cycle for cycle in nx.simple_cycles(directed)
                  if len(cycle) > 2
                  and any(type(node) in wanted for node in cycle)]

        # Remove duplicate cycles that only differ in orientation. Sorting
        # the nodes by guid gives a canonical, hashable representation.
        unique_cycles = [list(cycle) for cycle in {
            tuple(sorted(cycle, key=lambda x: x.guid, reverse=True))
            for cycle in cycles}]

        # group cycles by wanted elements
        cycles_dict = {}
        for wanted_element in graph.nodes:
            if type(wanted_element) in wanted:
                cycles_dict[wanted_element] = [
                    cycle for cycle in unique_cycles
                    if wanted_element in cycle]
        return cycles_dict

    @staticmethod
    def detect_bypasses_to_wanted(graph, wanted, inert, blockers):
        """
        Returns a list of nodes which build a bypass to the wanted elements
        and blockers. E.g. used to find bypasses between generator and
        distributor.

        Currently not implemented: always raises NotImplementedError.

        :returns: list of nodes
        """
        # todo currently not working, this might be reused later
        # NOTE(review): the code below is unreachable. It would also need
        #  rework before reuse, because get_all_cycles_with_wanted returns a
        #  dict, not a list (``cycles.remove`` would fail).
        raise NotImplementedError
        pot_edge_elements = inert - blockers - wanted

        cycles = HvacGraph.get_all_cycles_with_wanted(graph, wanted)

        # get cycle with blocker (can't hold bypass if has wanted and blocker)
        blocker_cycles = [cycle for cycle in cycles
                          if any(type(node) == block for block in
                                 blockers for node in cycle)]
        for blocker_cycle in blocker_cycles:
            cycles.remove(blocker_cycle)

        pot_bypass_nodes = []
        for cycle in cycles:
            # get edge_elements
            edge_elements = [node for node in cycle if
                             len(list(nx.all_neighbors(graph, node))) > 2 and
                             type(node) in pot_edge_elements]
            # get direct connections between edge_elements
            dir_connections = HvacGraph.get_dir_paths_between(
                graph, edge_elements)
            # filter connections, that has no wanted nodes
            for dir_connection in dir_connections:
                if not any(type(node) == want for want in wanted for node in
                           dir_connection):
                    pot_bypass_nodes.extend(dir_connection)
        # filter the potential bypass nodes for the once not in blocker cycles
        bypass_nodes = [pot_bypass_node
                        for pot_bypass_node in pot_bypass_nodes
                        for blocker_cycle in blocker_cycles
                        if pot_bypass_node not in blocker_cycle]
        # remove duplicates
        bypass_nodes = list(set(bypass_nodes))

        return bypass_nodes

    @staticmethod
    def get_parallels(
            graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None,
            grouping=None, grp_threshold=None):
        """Detect parallel occurrences of wanted items.

        All graph nodes not in inert or wanted are counted as blocking.
        Grouping can hold additional arguments like only same size.

        Args:
            graph: element graph to search.
            wanted: classes of the elements whose parallel occurrences are
                searched.
            inert: classes that are kept in the graph but are not wanted.
            grouping: dict with parameter to be grouped and condition, e.g.
                ``{'rated_power': 'equal'}``. Only ``'equal'`` is supported.
            grp_threshold: minimum group size. Only groups with more
                elements than this value are returned. ``None`` disables
                the filter.

        Returns:
            list of subgraphs of ``graph`` that hold parallel wanted items.
        """
        if inert is None:
            inert = set()
        if grouping is None:
            grouping = {}
        _graph = HvacGraph.remove_not_wanted_nodes(graph, wanted, inert)

        # detect simple cycles
        basis_cycles = nx.cycle_basis(_graph)
        graph_changed = False
        # remove bypasses which prevent correct finding of parallel pump cycles
        for basis_cycle in basis_cycles:
            bypasses = HvacGraph.find_bypasses_in_cycle(
                _graph, basis_cycle, wanted)
            for bypass in bypasses:
                graph_changed = True
                _graph.remove_nodes_from(bypass)

        if graph_changed:
            # update graph after removing bypasses
            basis_cycles = nx.cycle_basis(_graph)

        # frozensets of guids are hashable and can be used as dict keys
        basis_cycle_sets = [frozenset(node.guid for node in basis_cycle)
                            for basis_cycle in basis_cycles]
        wanted_guids = {node.guid
                        for node in _graph.nodes if type(node) in wanted}

        # Build the maps wanted guid -> cycles and cycle -> wanted guids.
        # Only cycles holding more than one wanted element are considered.
        occurrence_cycles = {}
        cycle_occurrences = {}
        for cycle in basis_cycle_sets:
            wanteds = cycle & wanted_guids
            if len(wanteds) > 1:
                cycle_occurrences[cycle] = wanteds
                for item in wanteds:
                    occurrence_cycles.setdefault(item, []).append(cycle)

        def related_cycles(item, known):
            """Collect into ``known`` all cycles transitively connected to
            ``item`` via shared wanted elements."""
            for cycle in occurrence_cycles[item]:
                if cycle not in known:
                    known.append(cycle)
                    for sub_item in cycle_occurrences[cycle]:
                        related_cycles(sub_item, known)

        # detect connected cycles
        cycle_sets = []
        known_items = set()
        for item in occurrence_cycles:
            if item not in known_items:
                known = []
                related_cycles(item, known)
                cycle_sets.append(known)
                known_items |= {
                    oc for k in known for oc in cycle_occurrences[k]}

        def group_parallels(parallel_graph, group_attr, cond, threshold=None):
            """Group a graph of parallel items by conditions.

            Currently only equal grouping is implemented, which returns
            only parallel items with equal group_attr. If a threshold is
            given, only groups with more elements than this threshold are
            included in the result.
            """
            if cond != 'equal':
                raise NotImplementedError()

            # group wanted elements by group_attr
            grouped = {}
            for node in parallel_graph.nodes:
                if type(node) in wanted:
                    grouped.setdefault(
                        getattr(node, group_attr), []).append(node)

            # only one group: all parallel items are equal
            if len(grouped) == 1:
                return [parallel_graph]

            graphs = []
            for parallel_eles in grouped.values():
                # only groups > threshold will be included in result
                if threshold is not None and len(parallel_eles) <= threshold:
                    continue
                subgraph_nodes = []
                for parallel_ele in parallel_eles:
                    # get strands with the wanted items
                    subgraph_nodes.extend(
                        HvacGraph.get_path_without_junctions(
                            parallel_graph, parallel_ele, True))
                graphs.append(parallel_graph.subgraph(subgraph_nodes))
            return graphs

        # merge cycles to get multi parallel items
        node_dict = {node.guid: node for node in _graph.nodes}
        graphs = []
        for cycle_set in cycle_sets:
            nodes = [node_dict[guid] for guids in cycle_set for guid in guids]
            parallel_graph = graph.subgraph(nodes)
            if not grouping:
                graphs.append(parallel_graph)
                continue
            # apply filter; filtering might return multiple graphs
            for group_attr, cond in grouping.items():
                grouped_graphs = group_parallels(
                    parallel_graph, group_attr, cond, grp_threshold)
                graphs.extend(grouped_graphs)
                # NOTE(review): with several grouping attributes, the next
                #  one is applied to the last graph of the previous grouping
                #  (kept from the original implementation).
                if grouped_graphs:
                    parallel_graph = grouped_graphs[-1]
        return graphs

    def recurse_set_side(self, port, side, known: dict = None,
                         raise_error=True):
        """Recursively propagate ``side`` as flow_side to connected ports.

        The side is inverted when passing from one port of a consumer or
        generator to another port of the same element.

        Args:
            port: port to start from.
            side: flow_side suggestion for ``port``.
            known: dict port -> suggested side collected so far. Created if
                None. A conflicting port is stored with ``None``.
            raise_error: if True, a conflict raises AssertionError. If
                False, it is logged.

        Returns:
            the ``known`` dict.

        Raises:
            AssertionError: on conflicting sides if ``raise_error`` is True.
        """
        if known is None:
            known = {}

        # set side suggestion
        if port not in known:
            known[port] = side
        elif known[port] == side:
            return known
        else:
            # conflict
            if raise_error:
                raise AssertionError("Conflicting flow_side in %r" % port)
            logger.error("Conflicting flow_side in %r", port)
            known[port] = None
            return known

        # call neighbours
        for neigh in self.neighbors(port):
            if (neigh.parent.is_consumer() or neigh.parent.is_generator()) \
                    and port.parent is neigh.parent:
                # switch flag over consumers / generators
                self.recurse_set_side(neigh, -side, known, raise_error)
            else:
                self.recurse_set_side(neigh, side, known, raise_error)

        return known

    def recurse_set_unknown_sides(self, port, visited: list = None,
                                  masters: list = None):
        """Recursively check the neighbours' flow_side.

        Returns:
            tuple of
                common flow_side (None if conflict),
                list of checked ports,
                list of ports on which flow_sides are determined
        """
        if visited is None:
            visited = []
        if masters is None:
            masters = []

        # mark as visited to prevent deadloops
        visited.append(port)

        if port.flow_side in (-1, 1):
            # use port with known flow_side as master
            masters.append(port)
            return port.flow_side, visited, masters

        # call neighbours
        neighbour_sides = {}
        for neigh in self.neighbors(port):
            if neigh in visited:
                continue
            switch_side = (
                (neigh.parent.is_consumer() or neigh.parent.is_generator())
                and port.parent is neigh.parent)
            side, _, _ = self.recurse_set_unknown_sides(
                neigh, visited, masters)
            if switch_side and side is not None:
                # Switch flag over consumers / generators. A conflict (None)
                # stays a conflict; negating it raised a TypeError before.
                side = -side
            neighbour_sides[neigh] = side

        sides = set(neighbour_sides.values())
        if not sides:
            return port.flow_side, visited, masters
        elif len(sides) == 1:
            # all neighbours have the same side
            return sides.pop(), visited, masters
        elif len(sides) == 2 and 0 in sides:
            return (sides - {0}).pop(), visited, masters
        else:
            # conflict
            return None, visited, masters

    @staticmethod
    def get_dir_paths_between(graph, nodes, include_edges=False):
        """Get direct connections between a list of nodes in a graph.

        A direct connection is a simple path between two of the given nodes
        whose inner elements all have at most two ports.

        Args:
            graph: graph to search in.
            nodes: nodes between which the connections are searched.
            include_edges: if True, the end nodes are kept in the paths and
                paths without inner nodes are returned too.

        Returns:
            list of paths (lists of nodes).
        """
        dir_connections = []

        for node1, node2 in itertools.combinations(nodes, 2):
            for path in nx.all_simple_paths(graph, node1, node2):
                # skip paths that pass another junction element
                if any(len(ele.ports) > 2 for ele in path[1:-1]):
                    continue
                if len(path) > 2:
                    # remove edge items if not wanted
                    if not include_edges:
                        path = path[1:-1]
                    dir_connections.append(path)
                elif len(path) == 2 and include_edges:
                    dir_connections.append(path)
        return dir_connections

    @staticmethod
    def get_path_without_junctions(graph, root, include_edges=False):
        """Get the not orientated list of nodes of the strand through root.

        Starting at ``root``, the graph is walked in every direction until a
        junction is reached. Here a junction is a node with more than one
        neighbour that is not yet collected. Junctions are only included if
        ``include_edges`` is True.

        Args:
            graph: graph to search in.
            root: element which must be in the path.
            include_edges: include the junctions ending the path or not.

        Returns:
            list of nodes of the path, starting with ``root``.
        """
        # todo make this correct!
        # NOTE(review): a walk that stops because no unvisited neighbour is
        #  left (dead end, or closing a ring) does not add its last node.
        #  Junction detection also depends on the visiting order. Both are
        #  kept unchanged to preserve current results.
        nodes = [root]
        # ``nx.all_neighbors`` returns an iterator, which is always truthy.
        # A list is needed for the emptiness check to work.
        neighbors_root = list(nx.all_neighbors(graph, root))
        if not neighbors_root:
            return nodes
        # loop through neighbors until next junction
        for current in neighbors_root:
            while True:
                unvisited = [neighbor for neighbor in
                             nx.all_neighbors(graph, current)
                             if neighbor not in nodes]
                if not unvisited:
                    break
                if len(unvisited) > 1:
                    if include_edges:
                        nodes.append(current)
                    break
                nodes.append(current)
                current = unvisited[0]
        return nodes

    @staticmethod
    def get_connections_between(
            graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None):
        """Detect simple connections between wanted items.

        All graph nodes not in inert or wanted are counted as blocking.

        Args:
            graph: element graph to search.
            wanted: classes of the items to connect.
            inert: classes kept in the graph but not wanted. Defaults to
                an empty set.

        Returns:
            list of subgraphs, one per direct connection (including its end
            nodes) between wanted items.

        Raises:
            AssertionError: if a given class is no ProductBased subclass.
        """
        if inert is None:
            inert = set()
        if not all(issubclass(item, ProductBased) for item in wanted | inert):
            raise AssertionError("Invalid type")
        _graph = HvacGraph.remove_not_wanted_nodes(graph, wanted, inert)

        # get connections between the wanted items
        wanted_nodes = {node for node in _graph.nodes
                        if type(node) in wanted}

        cons = HvacGraph.get_dir_paths_between(_graph, wanted_nodes, True)
        return [nx.subgraph(_graph, con) for con in cons]

    def subgraph_from_elements(self, elements: list):
        """Returns a subgraph of the current graph containing only the ports
        associated with the provided elements.

        Args:
            elements: A list of elements to include in the subgraph.

        Returns:
            A subgraph of the current graph that contains only the ports
            associated with the provided elements.

        Raises:
            AssertionError: If the provided elements are not part of the graph.
        """
        if not set(elements).issubset(self.elements):
            # format the message; passing the args separately left it raw
            raise AssertionError(
                'The elements %s are not part of this graph.' % (elements,))
        return self.subgraph(port for ele in elements for port in ele.ports)

    @staticmethod
    def remove_classes_from(graph: nx.Graph,
                            classes_to_remove: Set[Type[ProductBased]]
                            ) -> Union[nx.Graph, HvacGraph]:
        """Removes nodes from a given graph based on their class.

        For an ``HvacGraph`` all ports of matching elements are removed. For
        any other graph, the matching nodes themselves are removed.

        Args:
            graph: The graph to remove nodes from.
            classes_to_remove: A set of classes to remove from the graph.

        Returns:
            The modified graph as a new instance.
        """
        _graph = graph.copy()
        if isinstance(_graph, HvacGraph):
            elements_to_remove = {ele for ele in _graph.elements
                                  if ele.__class__ in classes_to_remove}
            nodes_to_remove = [port for ele in elements_to_remove
                               for port in ele.ports]
        else:
            nodes_to_remove = {node for node in _graph.nodes
                               if node.__class__ in classes_to_remove}
        _graph.remove_nodes_from(nodes_to_remove)
        return _graph