vectornet代碼復(fù)現(xiàn)

數(shù)據(jù):
先來(lái)看下丟到模型里面的x棍丐,下面是直接將x當(dāng)作散點(diǎn)圖可視化误辑,每個(gè)polyline用不同的顏色表示,紅線是需要預(yù)測(cè)的agent的歷史軌跡


x可視化

下面是官方的api可視化


image.png

模型結(jié)構(gòu):

class HGNN(nn.Module):
    def forward(self, data):
        time_step_len = int(data[0].time_step_len[0]) #83
        valid_lens = data[0].valid_len # 78
        sub_graph_out = self.subgraph(data)
        x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape) 
        out = self.self_atten_layer(x, valid_lens)
        pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))
        return pred

核心代碼就四行:

1. sub_graph_out = self.subgraph(data)

2. x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)

3. out = self.self_atten_layer(x, valid_lens)

4. pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))

首先看1
subGraph的forward如下

class SubGraph(nn.Module):
    """
    Subgraph that computes all vectors in a polyline, and get a polyline-level feature
    """
    def __init__(self, in_channels, num_subgraph_layres=3, hidden_unit=64):
        super(SubGraph, self).__init__()
        self.num_subgraph_layres = num_subgraph_layres
        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layres):
            self.layer_seq.add_module(
                f'glp_{i}', GraphLayerProp(in_channels, hidden_unit))
            in_channels *= 2

    def forward(self, sub_data):
        x, edge_index = sub_data.x, sub_data.edge_index # x 8310,8 edge_index 2,66852
        for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)
        sub_data.x = x # 8310歌逢,64
        out_data = max_pool(sub_data.cluster, sub_data) # 1162巾钉,64
        assert out_data.x.shape[0] % int(sub_data.time_step_len[0]) == 0
        out_data.x = out_data.x / out_data.x.norm(dim=0)
        return out_data

subgraph的核心代碼有三步

1.1

 for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)

1.2 out_data = max_pool(sub_data.cluster, sub_data)

1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)

先來(lái)看1.1
subgraph的forward中首先過(guò)了三層GraphLayerProp

for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)

self.layer_seq.named_modules()如下:

(glp_0): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=8, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=8, bias=True)
    )
  )
  (glp_1): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=16, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=16, bias=True)
    )
  )
  (glp_2): GraphLayerProp(
    (mlp): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=32, bias=True)
    )
  )

但是我們發(fā)現(xiàn)(3)linear的out_features 不等于下一層的in_features
因?yàn)?3)linear后面還有個(gè)contact的操作(具體看GraphLayerProp里面的update),讓out_features翻倍了秘案,實(shí)際上應(yīng)該是:
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)
現(xiàn)在咱們來(lái)具體看下GraphLayerProp

class GraphLayerProp(MessagePassing):
    """
    Message Passing mechanism for infomation aggregation
    """
    def __init__(self, in_channels, hidden_unit=64, verbose=False):
        super(GraphLayerProp, self).__init__(
            aggr='max')  # MaxPooling aggragation
        self.verbose = verbose
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, in_channels)
        )

    def forward(self, x, edge_index):
        if self.verbose:
            print(f'x before mlp: {x}')
        x = self.mlp(x)
        if self.verbose:
            print(f"x after mlp: {x}")
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out, x):
        if self.verbose:
            print(f"x after mlp: {x}")
            print(f"aggr_out: {aggr_out}")
        return torch.cat([x, aggr_out], dim=1)

GraphLayerProp中主要有三步:

1.1.1 encoder

1.1.2 aggregate

1.1.3 contact

subgraph

結(jié)合圖片來(lái)看:

1.1.1 encoder:

forward中x = self.mlp(x) 先對(duì)feature做一次mlp 砰苍,即x :(8310,8) -> (8310,64) -> x (8310,8)

x = self.mlp(x)

1.1.2 aggregate:

做一次max的gnn 的aggregate

super(GraphLayerProp, self).__init__(
            aggr='max')  # MaxPooling aggragation

1.1.3 contact:

將max出來(lái)的feature 和 max前的feature 做一次concat 潦匈,所以feature維度在這翻倍

torch.cat([x, aggr_out], dim=1) 

上述1.1.1-1.1.3是一層GraphLayerProp,subgraph的forward中過(guò)了三層赚导,即:
(8310,8)-> (8310,16)
(8310,16)-> (8310,32)
(8310,32)-> (8310,64)

現(xiàn)在過(guò)完三次GraphLayerProp茬缩,x : (8310,64)

1.2 out_data = max_pool(sub_data.cluster, sub_data) # 1162,64

回到1.2:對(duì)每個(gè)polyline subgraph做maxpooling

sub_data.cluster 里面類似[0,0,0,0,1,1,1,1,2,2,2,3,3....1161,1161]
這里面0000吼旧,1111凰锡,222分別是不同id的車道線、車輛等的子圖,即論文中的polyline subgraphs

例如:
0黍少,0寡夹,0,0表示id為0的子圖有四個(gè)時(shí)間刻

現(xiàn)在將每個(gè)物體抽象成了一個(gè)64維向量厂置,即,將所有時(shí)間刻的向量池化為一個(gè)時(shí)間刻的向量

做maxpooling 后x:(1162魂角,64)= (14*83 昵济,64)
即有14個(gè)場(chǎng)景中,每個(gè)場(chǎng)景83個(gè)車道和車輛單一時(shí)刻的vector

polyline subgraphs

1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)

除以均值

2 x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)

接下來(lái)reshape一下

time_step_len = 83 (83包含了1個(gè)agent野揪,41個(gè)左車道線和41個(gè)右車道線)

x(1162,64) -> x(14,83,64)

這里14表示有14個(gè)預(yù)測(cè)場(chǎng)景访忿,每個(gè)場(chǎng)景有83個(gè)polyline,每個(gè)polyline的feature是64維的向量

3 out = self.self_atten_layer(x, valid_lens) #14,83,64

通過(guò)self attention計(jì)算每個(gè)polyline直接的注意力斯稳,再aggregate一下海铆。

self_atten_layer的初始化:

self.self_atten_layer = SelfAttentionLayer(
            self.polyline_vec_shape,
            global_graph_width, 
            need_scale=False) #64  64
self attention
def forward(self, x, valid_len):
        query = self.q_lin(x) # 14,83,64 
        key = self.k_lin(x)
        value = self.v_lin(x)
        scores = torch.bmm(query, key.transpose(1, 2)) # 14,83,83
        attention_weights = masked_softmax(scores, valid_len)
        return torch.bmm(attention_weights, value)

4 pred = self.traj_pred_mlp(out[:, [0]].squeeze(1)) #14,60

traj_pred_mlp的初始化

self.traj_pred_mlp = TrajPredMLP(
            global_graph_width, out_channels, traj_pred_mlp_width) # 64 60 64

最后一步直接把(14,83,64) -> (14,60)
60的向量由30個(gè)x坐標(biāo)值和30個(gè)y坐標(biāo)值組成,即預(yù)測(cè)的后30個(gè)時(shí)間片的軌跡坐標(biāo)

class TrajPredMLP(nn.Module):
    """Predict one feature trajectory, in offset format"""

    def __init__(self, in_channels, out_channels, hidden_unit):
        super(TrajPredMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, out_channels)
        )

    def forward(self, x):
        return self.mlp(x)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末挣惰,一起剝皮案震驚了整個(gè)濱河市卧斟,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌憎茂,老刑警劉巖珍语,帶你破解...
    沈念sama閱讀 217,277評(píng)論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異竖幔,居然都是意外死亡板乙,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評(píng)論 3 393
  • 文/潘曉璐 我一進(jìn)店門拳氢,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)募逞,“玉大人,你說(shuō)我怎么就攤上這事馋评》沤樱” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 163,624評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵栗恩,是天一觀的道長(zhǎng)透乾。 經(jīng)常有香客問(wèn)我洪燥,道長(zhǎng),這世上最難降的妖魔是什么乳乌? 我笑而不...
    開(kāi)封第一講書人閱讀 58,356評(píng)論 1 293
  • 正文 為了忘掉前任捧韵,我火速辦了婚禮,結(jié)果婚禮上汉操,老公的妹妹穿的比我還像新娘再来。我一直安慰自己,他們只是感情好磷瘤,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,402評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布芒篷。 她就那樣靜靜地躺著,像睡著了一般采缚。 火紅的嫁衣襯著肌膚如雪针炉。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書人閱讀 51,292評(píng)論 1 301
  • 那天扳抽,我揣著相機(jī)與錄音篡帕,去河邊找鬼。 笑死贸呢,一個(gè)胖子當(dāng)著我的面吹牛镰烧,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播楞陷,決...
    沈念sama閱讀 40,135評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼怔鳖,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了固蛾?” 一聲冷哼從身側(cè)響起结执,我...
    開(kāi)封第一講書人閱讀 38,992評(píng)論 0 275
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎魏铅,沒(méi)想到半個(gè)月后昌犹,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,429評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡览芳,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,636評(píng)論 3 334
  • 正文 我和宋清朗相戀三年斜姥,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片沧竟。...
    茶點(diǎn)故事閱讀 39,785評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡铸敏,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出悟泵,到底是詐尸還是另有隱情杈笔,我是刑警寧澤,帶...
    沈念sama閱讀 35,492評(píng)論 5 345
  • 正文 年R本政府宣布糕非,位于F島的核電站蒙具,受9級(jí)特大地震影響球榆,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜禁筏,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,092評(píng)論 3 328
  • 文/蒙蒙 一持钉、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧篱昔,春花似錦每强、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,723評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至穗椅,卻和暖如春辨绊,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背房待。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,858評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工邢羔, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人桑孩。 一個(gè)月前我還...
    沈念sama閱讀 47,891評(píng)論 2 370
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像框冀,于是被迫代替她去往敵國(guó)和親流椒。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,713評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容