サムネがコーヒーの記事は書きかけです。

無向グラフの最短経路探索テンプレート NetworkX

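エクセルで隣接行列を作成し、CSV ファイル（例: Book1.csv）として保存するだけで、各ノード間の最短経路（辺の重みは考慮せず、経由する辺の数が最小の経路）を探索できるプログラムを書いたので、いつでも使えるようにテンプレートとして記録しておきます。行列は上三角部分（対角より右上）だけを入力します。プログラム側で転置を足して対称化し、無向グラフとして扱います。重みが 0 より大きく 8 未満のセルが辺になります。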

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np 


class Pn:
    """A graph node: its own label plus the labels of the nodes it links to."""

    def __init__(self, name, next) -> None:
        # ``next`` keeps its original parameter name for keyword callers,
        # even though it shadows the builtin inside this method.
        self.name, self.next = name, next

    def __repr__(self) -> str:
        # Only the neighbour container is shown, e.g. "['P2', 'P3']".
        return f"{self.next}"

    def get_name(self):
        """Return this node's label."""
        return self.name

    def get_next(self):
        """Return the neighbour container passed to the constructor."""
        return self.next

class Data:
    """Load a square weight matrix and its node labels from a CSV file.

    Layout read below: the first line holds the node labels in columns 2..;
    every following line holds a row label in column 1 and integer weights in
    columns 2.., where an empty cell means 0. Only the upper triangle is
    used: the last data row is ignored (left as zeros) and the matrix is made
    symmetric by adding its transpose.
    """

    def __init__(self, file) -> None:
        """Parse *file* into ``self.arr`` (n x n ints), ``self.Pn`` (labels) and ``self.n``."""
        # A single, deterministically closed handle (the old code opened the
        # file twice and closed neither).
        with open(file) as fp:
            raw_lines = fp.readlines()
        # Kept for compatibility: callers may read the raw lines as self.lines[0].
        self.lines = [raw_lines]
        # One header line, then one line per node.
        self.n = len(raw_lines) - 1
        # Preallocating the full matrix also makes the single-node case work
        # (np.append on an empty 1-D array used to fail there).
        self.arr = np.zeros((self.n, self.n), dtype=int)
        # Data lines 1..n-1 fill rows 0..n-2; the final row stays zero because
        # an upper-triangular matrix has nothing to the right of its last
        # diagonal cell.
        for row_index in range(1, self.n):
            cells = raw_lines[row_index].split(",")[1:]
            # strip() so an empty trailing cell ("\n") counts as 0 instead of
            # making int() raise ValueError.
            self.arr[row_index - 1] = [int(cell) if cell.strip() else 0 for cell in cells]
        self.Pn = [label.replace("\n", "") for label in raw_lines[0].split(",")[1:]]
        # Mirror the upper triangle into the lower one -> undirected weights.
        self.arr += np.transpose(self.arr)

    def get_array(self):
        """Return the symmetric weight matrix."""
        return self.arr


class Graph(Data):
    """Undirected NetworkX graph built from the CSV weight matrix loaded by ``Data``.

    Nodes are the labels from the CSV header. Two nodes are joined by an edge
    when their matrix weight ``w`` satisfies ``0 < w < lim``. Weights are not
    stored on the edges, so shortest paths are counted in hops.
    """

    def __init__(self, file, lim=8) -> None:
        """Load *file*, add its nodes, and precompute each node's neighbours.

        Args:
            file: Path of the CSV file handed to ``Data.__init__``.
            lim: Exclusive upper bound on a weight for it to count as an edge;
                8 (the value previously hard-coded) remains the default.

        Prints the loaded weight matrix to stdout.
        """
        # Load first so an unreadable file does not leave an orphan figure behind.
        super().__init__(file)
        self.fig = plt.figure(figsize=[9, 8])
        self.G = nx.Graph()
        self.G.add_nodes_from(self.Pn)
        print(self.arr)
        self.lim = lim
        # Neighbour labels come from the header (self.Pn) instead of being
        # rebuilt as f"P{j+1}", so nodes need not be named P1..Pn in order.
        self.P = [
            Pn(self.Pn[i], [self.Pn[j] for j in range(self.n) if 0 < self.arr[i][j] < self.lim])
            for i in range(self.n)
        ]

    def draw(self) -> None:
        """Draw the full graph onto this object's own figure, ``self.fig``."""
        # nx.draw targets the *current* figure; get_shortest_paths() makes a
        # different figure current, so re-activate ours explicitly.
        plt.figure(self.fig.number)
        nx.draw(self.G, with_labels=True, node_color="red", edge_color="black", node_size=300, width=2, font_size=8)

    def show(self) -> None:
        """Display all open matplotlib figures via ``plt.show()``."""
        plt.show()

    def save(self, name, dpi=500) -> None:
        """Save ``self.fig`` to *name* at *dpi* dots per inch."""
        self.fig.savefig(name, dpi=dpi)

    def set_edges(self) -> None:
        """Add an edge from every node to each neighbour listed in ``self.P``."""
        # Iterate the Pn objects directly instead of parsing int(label[1:]),
        # which only worked for labels of the form "P<k>" in header order.
        self.G.add_edges_from(
            (node.get_name(), neighbour) for node in self.P for neighbour in node.get_next()
        )

    @staticmethod
    def flat(l):
        """Flatten arbitrarily nested sequences into one list of leaves.

        ``int`` and ``str`` values are leaves; anything else is iterated.
        Uses an explicit stack, so long or deep inputs neither hit the
        recursion limit nor pay the quadratic cost of repeated slicing.
        """
        flattened = []
        pending = [l]
        while pending:
            item = pending.pop()
            if isinstance(item, (int, str)):
                flattened.append(item)
            else:
                # Push children reversed so they are popped in original order.
                pending.extend(reversed(list(item)))
        return flattened

    def get_shortest_paths(self, src, target, name="shortest_paths.png", dpi=500):
        """Draw every hop-count shortest path from *src* to *target* and save it.

        Args:
            src: Label of the start node.
            target: Label of the end node.
            name: Output image path (defaults to the previously hard-coded name).
            dpi: Output resolution (defaults to the previously hard-coded 500).

        Returns:
            The list of paths, each a list of node labels; it is also printed.

        Any exception raised by ``nx.all_shortest_paths`` propagates unchanged.
        """
        # Compute paths before creating a figure so a failed search leaves none behind.
        paths = list(nx.all_shortest_paths(self.G, source=src, target=target))
        print(paths)
        fig_shortest_paths = plt.figure()
        G = nx.Graph()
        # dict.fromkeys de-duplicates while keeping first-seen order (set() did not);
        # adding nodes explicitly also covers the single-node path src == target.
        G.add_nodes_from(dict.fromkeys(self.flat(paths)))
        G.add_edges_from(pair for path in paths for pair in zip(path, path[1:]))
        nx.draw(G, with_labels=True, node_color="red", edge_color="black", node_size=300, width=2, font_size=8)
        fig_shortest_paths.savefig(name, dpi=dpi)
        return paths



# Demo run: load Book1.csv, connect the nodes and draw the whole graph, save the
# shortest routes from P21 to P22 (shortest_paths.png), then save the full graph.
graph = Graph(file="Book1.csv")
graph.set_edges()
graph.draw()

graph.get_shortest_paths(src="P21", target="P22")
graph.save(name="result.png")

コメントを残す

メールアドレスが公開されることはありません。* が付いている欄は必須項目です。