Coverage for bim2sim/elements/graphs/hvac_graph.py: 54%

426 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-12 17:09 +0000

1""" This module represents the elements of a HVAC system in form of a network 

2graph where each node represents a hvac-component. 

3""" 

4from __future__ import annotations 

5 

6import itertools 

7import logging 

8import os 

9import shutil 

10from pathlib import Path 

11from typing import Set, Iterable, Type, List, Union 

12import json 

13 

14import networkx as nx 

15from networkx import json_graph 

16 

17from bim2sim.elements.base_elements import ProductBased, ElementEncoder 

18 

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

20 

21 

class HvacGraph(nx.Graph):
    """HVAC related graph manipulations based on ports.

    The nodes of this graph are ports of HVAC elements. Edges are either
    connections between ports (taken from ``port.connection``) or inner
    connections of an element (taken from ``element.inner_connections``).
    Use ``element_graph`` for a view with elements as nodes.
    """
    # TODO 246 HvacGraph init should only be called one based on IFC as it works
    # with port.connection and therefore is not reliable after changes are made
    # to the graph
    def __init__(self, elements=None, **attr):
        """Create the graph and optionally populate it from elements.

        Args:
            elements: Iterable of elements. Every port of these elements
                that has a ``connection`` becomes a node. If falsy, an
                empty graph is created.
            **attr: Graph attributes forwarded to ``nx.Graph``.
        """
        super().__init__(incoming_graph_data=None, **attr)
        if elements:
            self._update_from_elements(elements)

    def _update_from_elements(self, elements):
        """Update graph based on ports of elements.

        Every port with a ``connection`` is added as node. Edges are added
        between each such port and its connection, and for all inner
        connections of the elements.
        """
        nodes = [port for instance in elements for port in instance.ports
                 if port.connection]
        inner_edges = [connection for instance in elements
                       for connection in instance.inner_connections]
        # ``nodes`` only holds ports with a connection, no second filter needed
        edges = [(port, port.connection) for port in nodes]

        self.update(nodes=nodes, edges=edges + inner_edges)

    @staticmethod
    def _contract_ports_into_elements(graph, port_nodes):
        """Contract port nodes into the nodes of their parent elements.

        Each port is merged into ``port.parent`` via ``nx.contracted_nodes``.
        Information about the merged ports stays accessible through the
        ``'contraction'`` node attribute (see ``get_contractions``). The
        passed graph itself is not modified.

        Args:
            graph: Graph containing the port nodes.
            port_nodes: Ports to contract into their parents.

        Returns:
            A new graph with the ports contracted into their parents.
        """
        new_graph = graph.copy()
        logger.info("Contracting ports into elements ...")
        for port in port_nodes:
            new_graph = nx.contracted_nodes(new_graph, port.parent, port)
        logger.info("Contracted the ports into node elements, this"
                    " leads to %d nodes.",
                    new_graph.number_of_nodes())
        # return the contracted graph (the input graph is left untouched)
        return new_graph

    @property
    def element_graph(self) -> nx.Graph:
        """View of graph with elements instead of ports.

        Nodes are the parents of the ports. Edges connect parents of ports
        that are connected and belong to different elements.
        """
        graph = nx.Graph()
        nodes = {ele.parent for ele in self.nodes if ele}
        edges = {(con[0].parent, con[1].parent) for con in self.edges
                 if con[0].parent is not con[1].parent}
        graph.update(nodes=nodes, edges=edges)
        return graph

    @property
    def elements(self):
        """List of elements present in graph (parents of the port nodes)."""
        nodes = {ele.parent for ele in self.nodes if ele}
        return list(nodes)

    @staticmethod
    def get_not_contracted_neighbors(graph, node):
        """Return neighbours of ``node`` that were not contracted into it.

        Contractions are read from the ``'contraction'`` node attribute that
        ``nx.contracted_nodes`` writes (the same attribute used by
        ``get_contractions``). Nodes without that attribute are treated as
        having no contractions.

        Args:
            graph: Graph containing ``node``.
            node: Node whose neighbours are requested.

        Returns:
            List of neighbours, excluding contracted nodes and ``node``.
        """
        # ``graph.node`` was removed in networkx 2.4, use ``graph.nodes``
        contracted = graph.nodes[node].get('contraction', {})
        neighbors = list(
            set(nx.all_neighbors(graph, node)) -
            set(contracted) -
            {node}
        )
        return neighbors

    def get_contractions(self, node):
        """Returns a list of contracted nodes for the passed node.

        Args:
            node: node in whose connections you are interested

        Returns:
            List of the keys of the node's ``'contraction'`` attribute, or an
            empty list if the node has no such attribute.
        """
        node_data = self.nodes[node]
        if 'contraction' not in node_data:
            return []
        return list(node_data['contraction'].keys())

    def get_cycles(self):
        """Find cycles in the graph.

        Returns:
            Cycle basis entries whose ports belong to more than one parent
            element (cycles inside a single element are ignored).
        """
        logger.info("Searching for cycles in hvac network ...")
        base_cycles = nx.cycle_basis(self)
        cycles = [cycle for cycle in base_cycles if len(
            {port.parent for port in cycle}) > 1]
        logger.info("Found %d cycles", len(cycles))
        return cycles

    # TODO #246 delete because not needed anymore
    @staticmethod
    def get_type_chains(
            element_graph: nx.Graph,
            types: Iterable[Type[ProductBased]],
            include_singles: bool = False):
        """Get lists of consecutive elements of the given types. Elements are
        ordered in the same way as they are connected.

        Args:
            element_graph: Graph object with elements as nodes.
            types: Items the chains are built of.
            include_singles: If True, connected components that do not have
                exactly two end nodes are added as unordered node lists.

        Returns:
            chain_lists: Lists of consecutive elements.
        """
        undirected_graph = element_graph
        nodes_degree2 = [v for v, d in undirected_graph.degree() if 1 <= d <= 2
                         and type(v) in types]
        subgraph_aggregations = nx.subgraph(undirected_graph, nodes_degree2)

        chain_lists = []
        # order elements as connected
        for component in nx.connected_components(subgraph_aggregations):
            subgraph = subgraph_aggregations.subgraph(component).copy()
            end_nodes = [v for v, d in subgraph.degree() if d == 1]

            if len(end_nodes) != 2:
                if include_singles:
                    chain_lists.append(list(subgraph.nodes))
                continue
            # TODO more efficient
            elements = nx.shortest_path(subgraph, *end_nodes)
            chain_lists.append(elements)

        return chain_lists

    def merge(self, mapping: dict, inner_connections: list,
              add_connections=None):
        """Merge port nodes in graph.

        According to the mapping dict, port nodes are either removed
        ({port: None}) or replaced ({port: new_port}) keeping their
        connections. Also adds the inner connections to the graph and, if
        passed, additional connections.

        WARNING: connections from removed port nodes are also removed

        :param add_connections: additional connections to add
        :param mapping: replacement dict. ports as keys and replacement ports
            or None as values
        :param inner_connections: connections to add"""
        replace = {k: v for k, v in mapping.items() if v is not None}
        remove = [k for k, v in mapping.items() if v is None]

        nx.relabel_nodes(self, replace, copy=False)
        self.remove_nodes_from(remove)
        self.add_edges_from(inner_connections)
        if add_connections:
            self.add_edges_from(add_connections)

    def get_connections(self):
        """Returns connections between different parent elements"""
        return [edge for edge in self.edges
                if edge[0].parent is not edge[1].parent]

    def plot(self, path: Path = None, ports: bool = False, dpi: int = 400,
             use_pyvis=False):
        """Plot graph and either display or save as pdf file.

        Args:
            path: If provided, the graph is saved there as pdf file or html
                if use_pyvis=True.
            ports: If True, the port graph is plotted, else the element graph.
            dpi: dots per inch, increase for higher quality (takes longer to
                render)
            use_pyvis: exports graph to interactive html
        """
        # importing matplotlib is slow and plotting is optional
        import matplotlib.pyplot as plt

        # https://plot.ly/python/network-graphs/
        edge_colors_flow_side = {
            1: dict(edge_color='red'),
            -1: dict(edge_color='blue'),
            0: dict(edge_color='grey'),
            None: dict(edge_color='grey'),
        }
        node_colors_flow_direction = {
            1: dict(node_color='white', edgecolors='blue'),
            -1: dict(node_color='blue', edgecolors='black'),
            0: dict(node_color='grey', edgecolors='black'),
            None: dict(node_color='grey', edgecolors='black'),
        }

        kwargs = {}
        if ports:
            # set port (nodes) colors based on flow direction
            graph = self
            kwargs['node_color'] = [
                node_colors_flow_direction[
                    port.flow_direction]['node_color'] for port in self]
            kwargs['edgecolors'] = [
                node_colors_flow_direction[
                    port.flow_direction]['edgecolors'] for port in self]
            kwargs['edge_color'] = 'grey'
        else:
            kwargs['node_color'] = 'blue'
            kwargs['edgecolors'] = 'black'
            # set connection colors (edges) based on flow side
            graph = self.element_graph
            edge_color_map = []
            for edge in graph.edges:
                sides0 = {port.flow_side for port in edge[0].ports}
                sides1 = {port.flow_side for port in edge[1].ports}
                side = None
                # element with multiple sides is usually a consumer / generator
                # (or result of conflicts) hence side of definite element is
                # used
                if len(sides0) == 1:
                    side = sides0.pop()
                elif len(sides1) == 1:
                    side = sides1.pop()
                edge_color_map.append(edge_colors_flow_side[side]['edge_color'])
            kwargs['edge_color'] = edge_color_map
        if use_pyvis:
            # pyvis is only needed for the interactive html export, so it is
            # not imported (and not required) for plain matplotlib plots
            from pyvis.network import Network

            # convert all nodes to strings to use dynamic plotting via pyvis
            graph_cp = graph.copy()
            nodes = graph_cp.nodes()
            replace = {}
            for node in nodes:
                # use guid because str must be unique to prevent overrides
                replace[node] = str(node) + ' ' + str(node.guid)

            # Todo Remove temp code. This is for Abschlussbericht Plotting only!
            # start of temp plotting code
            bypass_nodes_guids = []
            small_pump_guids = []
            parallel_pump_guids = []
            for node in nodes:
                try:
                    if node.length.m == 34:
                        bypass_nodes_guids.append(node.guid)
                except AttributeError:
                    pass
                try:
                    if node.rated_power.m == 0.6:
                        small_pump_guids.append(node.guid)
                except AttributeError:
                    pass
                try:
                    if node.rated_power.m == 1:
                        parallel_pump_guids.append(node.guid)
                except AttributeError:
                    pass
            # end of temp plotting code

            nx.relabel_nodes(graph_cp, replace, copy=False)
            net = Network(height='1000', width='1000', notebook=False,
                          bgcolor='white', font_color='black', layout=False)
            net.barnes_hut(gravity=-17000, spring_length=55)
            pyvis_json = Path(__file__).parent.parent.parent / \
                'assets/configs/pyvis/pyvis_options.json'
            # context manager ensures the options file is closed again
            with open(pyvis_json, encoding='utf-8') as options_file:
                net.options = json.load(options_file)

            net.from_nx(graph_cp, default_node_size=50)
            for node in net.nodes:
                try:
                    node['label'] = node['label'].split('<')[1]
                except IndexError:
                    # label without '<' is kept unchanged
                    pass
                node['label'] = node['label'].split('(ports')[0]
                label = node['label'].lower()
                if 'agg' in label:
                    node['label'] = node['label'].split('Agg0')[0]
                if 'storage' in label:
                    node['color'] = 'purple'
                if 'distributor' in label:
                    node['color'] = 'gray'
                if 'pump' in label:
                    node['color'] = 'blue'
                if 'spaceheater' in label:
                    node['color'] = 'purple'
                if 'pipestrand' in label:
                    node['color'] = 'blue'
                if any(ele in label for ele in [
                        'parallelpump',
                        'boiler',
                        'generatoronefluid',
                        'heatpump',
                ]):
                    node['color'] = 'yellow'
                # bypass color for parallelpump test
                node_guid = node['id'].split('> ')[-1]
                if node_guid in bypass_nodes_guids:
                    node['color'] = 'green'
                if node_guid in small_pump_guids:
                    node['color'] = 'purple'
                if node_guid in parallel_pump_guids:
                    node['color'] = 'red'

        else:
            plt.figure(dpi=dpi)
            nx.draw(graph, node_size=10, font_size=5, linewidths=0.5, alpha=0.7,
                    with_labels=True, **kwargs)
            plt.draw()
        if path:
            if use_pyvis:
                name = "%s_graph_pyvis.html" % ("port" if ports else "element")
                try:
                    net.save_graph(name)
                    shutil.move(name, path)
                except Exception as ex:
                    logger.error("Unable to save plot of graph (%s)", ex)
            else:
                name = "%s_graph.pdf" % ("port" if ports else "element")
                try:
                    plt.savefig(
                        os.path.join(path, name),
                        bbox_inches='tight')
                except IOError as ex:
                    logger.error("Unable to save plot of graph (%s)", ex)
        else:
            if use_pyvis:
                name = "graph.html"
                try:
                    net.show(name)
                except Exception as ex:
                    logger.error("Unable to show plot of graph (%s)", ex)
            else:
                plt.show()
        plt.clf()

    def dump_to_cytoscape_json(self, path: Path, ports: bool = True):
        """Dumps the current state of the graph to a json file in cytoscape
        format.

        Args:
            path: Pathlib path to where to dump the JSON file.
            ports: if True the ports graph will be serialized, else the
                element_graph.
        """
        if ports:
            export_graph = self
            name = 'port_graph_cytoscape.json'
        else:
            export_graph = self.element_graph
            name = 'element_graph_cytoscape.json'
        with open(path / name, 'w') as fp:
            json.dump(json_graph.cytoscape_data(export_graph), fp,
                      cls=ElementEncoder)

    def to_serializable(self):
        """Returns a json serializable object (networkx adjacency data)."""
        return json_graph.adjacency_data(self)

    @classmethod
    def from_serialized(cls, data):
        """Create a graph from data produced by ``to_serializable``.

        The deserialized networkx graph is copied into a new, otherwise
        empty instance. It is not passed to ``__init__``, because
        ``__init__`` expects elements, not graph nodes.
        """
        graph = cls()
        graph.update(json_graph.adjacency_graph(data))
        return graph

    @staticmethod
    def remove_not_wanted_nodes(
            graph: nx.Graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None):
        """Removes not wanted and not inert nodes from the given graph.

        Args:
            graph: element_graph
            wanted: set of all elements that are wanted and should persist in
                graph
            inert: set all inert elements. Are treated the same as wanted.

        Returns:
            A copy of ``graph`` without the blocking nodes.

        Raises:
            AssertionError: if a given type is not a ProductBased subclass.
        """
        if inert is None:
            inert = set()
        if not all(map(
                lambda item: issubclass(item, ProductBased), wanted | inert)):
            raise AssertionError("Invalid type")
        _graph = graph.copy()
        # remove blocking nodes
        remove = {node for node in _graph.nodes
                  if type(node) not in wanted | inert}
        _graph.remove_nodes_from(remove)
        return _graph

    @staticmethod
    def find_bypasses_in_cycle(graph: nx.Graph, cycle, wanted):
        """ Detects bypasses in the given cycle of the given graph.

        Bypasses are any direct connections between edge elements which don't
        hold wanted elements.

        Args:
            graph: The graph in which the cycle belongs.
            cycle: A list of nodes representing a cycle in the graph.
            wanted: A list of classes of the desired node type.

        Returns:
            List: A list of bypasses, where each bypass is a list of elements in
                the bypass.

        Raises:
            None
        """
        bypasses = []
        # get wanted guids in the cycle
        wanted_guids_cycle = {node.guid for node in
                              cycle if type(node) in wanted}

        # check that it's not a parallel connection of wanted elements
        if len(wanted_guids_cycle) < 2:
            # get edge_elements (elements with more than two ports)
            edge_elements = [
                node for node in cycle if len(node.ports) > 2]

            # get direct connections between the edges
            subgraph = graph.subgraph(cycle)
            dir_connections = HvacGraph.get_dir_paths_between(
                subgraph, edge_elements)

            # remove strands without wanted items
            for dir_connection in dir_connections:
                if not any(type(node) in wanted for node in dir_connection):
                    bypasses.append(dir_connection)
        return bypasses

    @staticmethod
    def get_all_cycles_with_wanted(graph, wanted):
        """Returns cycles with wanted elements in them, grouped by element.

        Returns:
            dict mapping each wanted element of ``graph`` to the list of
            unique cycles (node lists sorted by guid) containing it.
        """
        # todo how to handle cascaded boilers

        directed = graph.to_directed()
        simple_cycles = list(nx.simple_cycles(directed))
        # filter cycles: at least three nodes and at least one wanted node
        cycles = [cycle for cycle in simple_cycles
                  if len(cycle) > 2
                  and any(type(node) in wanted for node in cycle)]

        # remove duplicate cycles with only different orientation
        # (sort copy by guid so equal node sets compare equal)
        cycles_sorted = [sorted(cycle, key=lambda x: x.guid, reverse=True)
                         for cycle in cycles]
        # remove duplicates
        unique_cycles = [list(x) for x in set(tuple(x) for x in cycles_sorted)]

        # group cycles by wanted elements
        wanted_elements = [node for node in graph.nodes if type(node) in wanted]
        cycles_dict = {}
        for wanted_element in wanted_elements:
            cycles_dict[wanted_element] = [
                cycle for cycle in unique_cycles if wanted_element in cycle]

        return cycles_dict

    @staticmethod
    def detect_bypasses_to_wanted(graph, wanted, inert, blockers):
        """
        Returns a list of nodes which build a bypass to the wanted elements
        and blockers. E.g. used to find bypasses between generator and
        distributor.

        Raises:
            NotImplementedError: always, the implementation below is kept
                for later reuse but is currently unreachable.
        :returns: list of nodes
        """
        # todo currently not working, this might be reused later
        raise NotImplementedError
        pot_edge_elements = inert - blockers - wanted

        cycles = HvacGraph.get_all_cycles_with_wanted(graph, wanted)

        # get cycle with blocker (can't hold bypass if has wanted and blocker)
        blocker_cycles = [cycle for cycle in cycles
                          if any(type(node) == block for block in
                                 blockers for node in cycle)]
        for blocker_cycle in blocker_cycles:
            cycles.remove(blocker_cycle)

        pot_bypass_nodes = []
        for cycle in cycles:
            # get edge_elements
            edge_elements = [node for node in cycle if
                             len(list(nx.all_neighbors(graph, node))) > 2 and
                             type(node) in pot_edge_elements]
            # get direct connections between edge_elements
            dir_connections = HvacGraph.get_dir_paths_between(
                graph, edge_elements)
            # filter connections, that has no wanted nodes
            for dir_connection in dir_connections:
                if not any(type(node) == want for want in wanted for node in
                           dir_connection):
                    pot_bypass_nodes.extend(dir_connection)
        # filter the potential bypass nodes for the once not in blocker cycles
        bypass_nodes = [pot_bypass_node
                        for pot_bypass_node in pot_bypass_nodes
                        for blocker_cycle in blocker_cycles
                        if pot_bypass_node not in blocker_cycle]
        # remove duplicates
        bypass_nodes = list(set(bypass_nodes))

        return bypass_nodes

    @staticmethod
    def get_parallels(
            graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None,
            grouping=None, grp_threshold=None):
        """ Detect parallel occurrences of wanted items.

        All graph nodes not in inert or wanted are counted as blocking.
        Grouping can hold additional arguments like only same size.

        Args:
            graph: element graph to search in.
            wanted: element classes whose parallel occurrences are searched.
            inert: element classes that may be part of the parallel
                structures without being wanted.
            grouping: dict with parameter to be grouped and condition, e.g.
                ``{'rated_power': 'equal'}``.
            grp_threshold: only groups with more elements than this value
                are included; ``None`` disables this filter.

        Returns:
            list of none overlapping subgraphs
        """
        if inert is None:
            inert = set()
        if grouping is None:
            grouping = {}
        _graph = HvacGraph.remove_not_wanted_nodes(graph, wanted, inert)

        # detect simple cycles
        basis_cycles = nx.cycle_basis(_graph)
        graph_changed = False
        # remove bypasses which prevent correct finding of parallel pump cycles
        for basis_cycle in basis_cycles:
            bypasses = HvacGraph.find_bypasses_in_cycle(
                _graph, basis_cycle, wanted)
            if bypasses:
                graph_changed = True
                for bypass in bypasses:
                    _graph.remove_nodes_from(bypass)

        if graph_changed:
            # update graph after removing bypasses
            basis_cycles = nx.cycle_basis(_graph)

        basis_cycle_sets = [frozenset(node.guid for node in basis_cycle)
                            for basis_cycle in basis_cycles]  # hashable
        wanted_guids = {node.guid
                        for node in _graph.nodes if type(node) in wanted}

        # cycle -> wanted guids in it and wanted guid -> cycles containing it
        occurrence_cycles = {}
        cycle_occurrences = {}
        for cycle in basis_cycle_sets:
            wanteds = frozenset(
                guid_node for guid_node in cycle if guid_node in wanted_guids)
            if len(wanteds) > 1:
                cycle_occurrences[cycle] = wanteds
                for item in wanteds:
                    occurrence_cycles.setdefault(item, []).append(cycle)

        # detect connected cycles
        def related_cycles(item, known):
            sub_cycles = occurrence_cycles[item]
            for cycle in sub_cycles:
                if cycle not in known:
                    known.append(cycle)
                    sub_items = cycle_occurrences[cycle]
                    for sub_item in sub_items:
                        related_cycles(sub_item, known)

        cycle_sets = []
        known_items = set()
        for item in occurrence_cycles:
            if item not in known_items:
                known = []
                related_cycles(item, known)
                cycle_sets.append(known)
                known_items = known_items | {
                    oc for k in known for oc in cycle_occurrences[k]}

        def group_parallels(graph, group_attr, cond, threshold=None):
            """ group a graph of parallel items by conditions. Currently only
            equal grouping is implemented, which will return only parallel
            items with equal group_attr. If a threshold is given, only groups
            with number of elements > this threshold value will be included in
            result.
            """
            if cond != 'equal':
                raise NotImplementedError()

            graphs = []
            nodes = [node for node in graph.nodes if type(node) in
                     wanted]

            # group elements by group_attr
            grouped = {}
            for node in nodes:
                grouped.setdefault(getattr(node, group_attr), []).append(node)

            # check if more than one grouped element exist
            if len(grouped.keys()) == 1:
                graphs.append(graph)
                return graphs

            for parallel_eles in grouped.values():
                # only groups > threshold will be included in result; a
                # missing threshold (None) means no size filter
                if threshold is not None and len(parallel_eles) <= threshold:
                    continue
                subgraph_nodes = []
                for parallel_ele in parallel_eles:
                    # get strands with the wanted items
                    strand = HvacGraph.get_path_without_junctions(
                        graph, parallel_ele, True)
                    subgraph_nodes.extend(strand)
                graphs.append(graph.subgraph(subgraph_nodes))
            return graphs

        # merge cycles to get multi parallel items
        node_dict = {node.guid: node for node in _graph.nodes}
        graphs = []
        for cycle_set in cycle_sets:
            nodes = [node_dict[guid] for guids in cycle_set for guid in guids]
            # separate name so every grouping criterion works on the cycle
            # graph instead of a previously returned result graph
            cycle_graph = graph.subgraph(nodes)
            # apply filter if used
            if grouping:
                for group_attr, cond in grouping.items():
                    # filtering might return multiple graphs
                    graphs.extend(group_parallels(
                        cycle_graph, group_attr, cond, grp_threshold))
            else:
                graphs.append(cycle_graph)
        return graphs

    def recurse_set_side(self, port, side, known: dict = None,
                         raise_error=True):
        """Recursive set flow_side to connected ports.

        The suggested ``side`` is propagated to all neighbours. It is negated
        when passing between two ports of the same consumer / generator.

        Args:
            port: port to start from.
            side: flow side suggested for ``port``.
            known: dict port -> side collected so far (created if None).
            raise_error: if True raise on conflicting sides, else log the
                conflict and store None for the port.

        Returns:
            dict port -> suggested side.

        Raises:
            AssertionError: on a conflicting flow_side if ``raise_error``.
        """
        if known is None:
            known = {}

        # set side suggestion
        is_known = port in known
        current_side = known.get(port, port.flow_side)
        if not is_known:
            known[port] = side
        elif current_side == side:
            return known
        else:
            # conflict
            if raise_error:
                raise AssertionError("Conflicting flow_side in %r" % port)
            else:
                logger.error("Conflicting flow_side in %r", port)
                known[port] = None
                return known

        # call neighbours
        for neigh in self.neighbors(port):
            if (neigh.parent.is_consumer() or neigh.parent.is_generator()) \
                    and port.parent is neigh.parent:
                # switch flag over consumers / generators
                self.recurse_set_side(neigh, -side, known, raise_error)
            else:
                self.recurse_set_side(neigh, side, known, raise_error)

        return known

    def recurse_set_unknown_sides(self, port, visited: list = None,
                                  masters: list = None):
        """Recursive checks neighbours flow_side.

        :returns tuple of
            common flow_side (None if conflict)
            list of checked ports
            list of ports on which flow_side s are determined"""
        if visited is None:
            visited = []
        if masters is None:
            masters = []

        # mark as visited to prevent deadloops
        visited.append(port)

        if port.flow_side in (-1, 1):
            # use port with known flow_side as master
            masters.append(port)
            return port.flow_side, visited, masters

        # call neighbours
        neighbour_sides = {}
        for neigh in self.neighbors(port):
            if neigh not in visited:
                side, _, _ = self.recurse_set_unknown_sides(
                    neigh, visited, masters)
                if (neigh.parent.is_consumer() or neigh.parent.is_generator()) \
                        and port.parent is neigh.parent:
                    # switch flag over consumers / generators; a conflict
                    # (None) stays a conflict instead of failing on -None
                    if side is not None:
                        side = -side
                neighbour_sides[neigh] = side

        sides = set(neighbour_sides.values())
        if not sides:
            return port.flow_side, visited, masters
        elif len(sides) == 1:
            # all neighbours have same site
            side = sides.pop()
            return side, visited, masters
        elif len(sides) == 2 and 0 in sides:
            side = (sides - {0}).pop()
            return side, visited, masters
        else:
            # conflict
            return None, visited, masters

    @staticmethod
    def get_dir_paths_between(graph, nodes, include_edges=False):
        """Get direct connections between a list of nodes in a graph.

        For each pair of ``nodes`` all simple paths are checked. A path is a
        direct connection if none of its inner nodes has more than two
        ports.

        Args:
            graph: graph to search in.
            nodes: nodes between which connections are searched.
            include_edges: if True the end nodes are kept in the returned
                paths and direct two-node connections are returned too.

        Returns:
            list of paths (lists of nodes).
        """
        dir_connections = []

        for node1, node2 in itertools.combinations(nodes, 2):
            all_paths = list(nx.all_simple_paths(graph, node1, node2))

            for path in all_paths:
                if not any(len(ele.ports) > 2 for ele in path[1:-1]):

                    if len(path) > 2:
                        # remove edge items if not wanted
                        if not include_edges:
                            path.pop(0)
                            path.pop(-1)
                        dir_connections.append(path)
                    elif len(path) == 2 and include_edges:
                        dir_connections.append(path)
        return dir_connections

    @staticmethod
    def get_path_without_junctions(graph, root, include_edges=False):
        """Get the not orientated list of nodes of the strand through root.

        Starting at ``root`` the graph is walked in every direction until a
        junction is reached. A node counts as junction when it has more than
        one not yet collected neighbour. The junctions (edges of the strand)
        are only included if ``include_edges`` is True. A node at which a
        walk ends without further neighbours still belongs to the strand.

        # todo make this correct! Junction detection depends on the order in
        #  which nodes are collected.

        Args:
            graph: graph to walk in.
            root: element which must be in path.
            include_edges: include edges (junctions) of path or not.

        Returns:
            list of nodes, starting with ``root``.
        """
        nodes = [root]
        # materialize, ``all_neighbors`` returns an iterator
        for start in list(nx.all_neighbors(graph, root)):
            current = start
            # stop when reaching an already collected node (e.g. loops)
            while current not in nodes:
                next_nodes = [candidate for candidate in
                              nx.all_neighbors(graph, current)
                              if candidate not in nodes]
                if len(next_nodes) > 1:
                    # junction reached
                    if include_edges:
                        nodes.append(current)
                    break
                nodes.append(current)
                if not next_nodes:
                    # end of the strand reached
                    break
                current = next_nodes[0]
        return nodes

    @staticmethod
    def get_connections_between(
            graph,
            wanted: Set[Type[ProductBased]],
            inert: Set[Type[ProductBased]] = None):
        """Detect simple connections between wanted items.

        All graph nodes not in inert or wanted are counted as blocking.

        Args:
            graph: element graph to search in.
            wanted: element classes to connect.
            inert: element classes allowed on the connections (default:
                none).

        Returns:
            list of none overlapping subgraphs

        Raises:
            AssertionError: if a given type is not a ProductBased subclass.
        """
        if inert is None:
            inert = set()
        if not all(map(lambda item: issubclass(
                item, ProductBased), wanted | inert)):
            raise AssertionError("Invalid type")
        _graph = HvacGraph.remove_not_wanted_nodes(graph, wanted, inert)

        # get connections between the wanted items
        wanted_nodes = {node for node in _graph.nodes
                        if type(node) in wanted}

        cons = HvacGraph.get_dir_paths_between(_graph, wanted_nodes, True)
        return [nx.subgraph(_graph, con) for con in cons]

    def subgraph_from_elements(self, elements: list):
        """ Returns a subgraph of the current graph containing only the ports
        associated with the provided elements.

        Args:
            elements: A list of elements to include in the subgraph.

        Returns:
            A subgraph of the current graph that contains only the ports
            associated with the provided elements.

        Raises:
            AssertionError: If the provided elements are not part of the graph.

        """
        if not set(elements).issubset(set(self.elements)):
            # format the message; exception args are not %-interpolated
            raise AssertionError(
                'The elements %s are not part of this graph.' % (elements,))
        return self.subgraph((port for ele in elements for port in ele.ports))

    @staticmethod
    def remove_classes_from(graph: nx.Graph,
                            classes_to_remove: Set[Type[ProductBased]]
                            ) -> Union[nx.Graph, HvacGraph]:
        """ Removes nodes from a given graph based on their class.

        For an HvacGraph all ports of elements of the given classes are
        removed, for any other graph the nodes of these classes.

        Args:
            graph: The graph to remove nodes from.
            classes_to_remove: A set of classes to remove from the graph.

        Returns:
            The modified graph as a new instance.
        """
        _graph = graph.copy()
        if not isinstance(_graph, HvacGraph):
            nodes_to_remove = {node for node in _graph.nodes if
                               node.__class__ in classes_to_remove}
        else:
            elements_to_remove = {ele for ele in _graph.elements if
                                  ele.__class__ in classes_to_remove}
            nodes_to_remove = [port for ele in elements_to_remove
                               for port in ele.ports]
        _graph.remove_nodes_from(nodes_to_remove)
        return _graph