Data:
First, let's look at the x that is fed into the model. Below, x is plotted directly as a scatter plot: each polyline is drawn in a different color, and the red line is the historical trajectory of the agent whose future we need to predict.
Below is the visualization from the official API.
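For reference, a minimal sketch of such a scatter-plot visualization; it hypothetically assumes the first two feature columns are x/y coordinates and that a cluster array holds each vector's polyline id (the real feature layout depends on the preprocessing):

import matplotlib.pyplot as plt
import numpy as np

def plot_polylines(x, cluster, agent_id=0):
    # x: (N, 8) vector features; ASSUMED: cols 0/1 are coordinates
    # cluster: (N,) polyline id of each vector; agent_id: polyline to highlight
    x, cluster = np.asarray(x), np.asarray(cluster)
    for pid in np.unique(cluster):
        pts = x[cluster == pid]
        if pid == agent_id:
            plt.scatter(pts[:, 0], pts[:, 1], s=4, c='red')  # agent history in red
        else:
            plt.scatter(pts[:, 0], pts[:, 1], s=2, color=plt.cm.tab20(int(pid) % 20))
    plt.gca().set_aspect('equal')
    plt.show()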
Model structure:
class HGNN(nn.Module):
    def forward(self, data):
        time_step_len = int(data[0].time_step_len[0])  # 83
        valid_lens = data[0].valid_len  # 78
        sub_graph_out = self.subgraph(data)
        x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
        out = self.self_atten_layer(x, valid_lens)
        pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))
        return pred
The core code is just four lines:
1. sub_graph_out = self.subgraph(data)
2. x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
3. out = self.self_atten_layer(x, valid_lens)
4. pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))
Let's start with step 1.
SubGraph's forward is as follows:
class SubGraph(nn.Module):
    """
    Subgraph that computes all vectors in a polyline and gets a polyline-level feature
    """

    def __init__(self, in_channels, num_subgraph_layres=3, hidden_unit=64):
        super(SubGraph, self).__init__()
        self.num_subgraph_layres = num_subgraph_layres
        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layres):
            self.layer_seq.add_module(
                f'glp_{i}', GraphLayerProp(in_channels, hidden_unit))
            in_channels *= 2

    def forward(self, sub_data):
        x, edge_index = sub_data.x, sub_data.edge_index  # x: (8310, 8), edge_index: (2, 66852)
        for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, GraphLayerProp):
                x = layer(x, edge_index)
        sub_data.x = x  # (8310, 64)
        out_data = max_pool(sub_data.cluster, sub_data)  # (1162, 64)
        assert out_data.x.shape[0] % int(sub_data.time_step_len[0]) == 0
        out_data.x = out_data.x / out_data.x.norm(dim=0)
        return out_data
The core code of subgraph has three steps:
1.1
for name, layer in self.layer_seq.named_modules():
    if isinstance(layer, GraphLayerProp):
        x = layer(x, edge_index)
1.2 out_data = max_pool(sub_data.cluster, sub_data)
1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)
Let's look at 1.1 first.
In subgraph's forward, x first passes through three GraphLayerProp layers:
for name, layer in self.layer_seq.named_modules():
    if isinstance(layer, GraphLayerProp):
        x = layer(x, edge_index)
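A side note on why the isinstance filter is needed: named_modules() yields the Sequential container itself and recurses into every child (including each GraphLayerProp's inner mlp), so the check keeps only the three GraphLayerProp layers. A quick illustration:

import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
for name, m in seq.named_modules():
    print(repr(name), type(m).__name__)
# '' Sequential
# '0' Linear
# '1' ReLU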
self.layer_seq.named_modules() is as follows:
(glp_0): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=8, bias=True)
  )
)
(glp_1): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=16, bias=True)
  )
)
(glp_2): GraphLayerProp(
  (mlp): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=32, bias=True)
  )
)
But notice that the out_features of Linear (3) does not equal the in_features of the next layer.
That is because after Linear (3) there is also a concat operation (see update inside GraphLayerProp) that doubles the feature dimension, so the actual progression is:
(8310, 8) -> (8310, 16)
(8310, 16) -> (8310, 32)
(8310, 32) -> (8310, 64)
Now let's look at GraphLayerProp in detail.
class GraphLayerProp(MessagePassing):
    """
    Message passing mechanism for information aggregation
    """

    def __init__(self, in_channels, hidden_unit=64, verbose=False):
        super(GraphLayerProp, self).__init__(aggr='max')  # max-pooling aggregation
        self.verbose = verbose
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, in_channels)
        )

    def forward(self, x, edge_index):
        if self.verbose:
            print(f'x before mlp: {x}')
        x = self.mlp(x)  # node-wise encoder
        if self.verbose:
            print(f"x after mlp: {x}")
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        return x_j  # pass neighbor features unchanged

    def update(self, aggr_out, x):
        if self.verbose:
            print(f"x after mlp: {x}")
            print(f"aggr_out: {aggr_out}")
        # concat each node's own feature with the max-aggregated neighbor
        # feature, doubling the feature dimension
        return torch.cat([x, aggr_out], dim=1)
GraphLayerProp has three main steps:
1.1.1 encoder
1.1.2 aggregate
1.1.3 concat
Referring to the figure:
1.1.1 encoder:
In forward, x = self.mlp(x) first passes the features through an MLP, i.e. x: (8310, 8) -> (8310, 64) -> (8310, 8)
x = self.mlp(x)
1.1.2 aggregate:
Run one max aggregation step of the GNN:
super(GraphLayerProp, self).__init__(
    aggr='max')  # max-pooling aggregation
1.1.3 concat:
Concatenate the max-aggregated feature with the feature before aggregation, so the feature dimension doubles here:
torch.cat([x, aggr_out], dim=1)
Steps 1.1.1-1.1.3 above make up one GraphLayerProp layer; subgraph's forward stacks three of them, i.e.:
(8310, 8) -> (8310, 16)
(8310, 16) -> (8310, 32)
(8310, 32) -> (8310, 64)
So after the three GraphLayerProp layers, x: (8310, 64). A toy run verifying the doubling is sketched below.
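As a sanity check, here is a toy run of the GraphLayerProp class above on a tiny fully connected subgraph (in VectorNet the vectors of one polyline are fully connected); the node count and random features are made up for illustration, and torch_geometric is assumed to be installed:

import torch

n = 4
idx = torch.arange(n)
row, col = idx.repeat_interleave(n), idx.repeat(n)
mask = row != col  # drop self-loops
edge_index = torch.stack([row[mask], col[mask]])  # (2, 12): fully connected

x = torch.randn(n, 8)
out = GraphLayerProp(in_channels=8)(x, edge_index)
print(out.shape)  # torch.Size([4, 16]): concat doubles 8 -> 16

# stacking three layers the way SubGraph does: 8 -> 16 -> 32 -> 64
for c in (16, 32):
    out = GraphLayerProp(in_channels=c)(out, edge_index)
print(out.shape)  # torch.Size([4, 64])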
1.2 out_data = max_pool(sub_data.cluster, sub_data)  # (1162, 64)
Back to 1.2: max-pool each polyline subgraph.
sub_data.cluster looks like [0,0,0,0,1,1,1,1,2,2,2,3,3,...,1161,1161],
where the runs 0,0,0,0 / 1,1,1,1 / 2,2,2 are subgraphs belonging to different lane/vehicle ids, i.e. the polyline subgraphs from the paper.
For example, 0,0,0,0 means the subgraph with id 0 has four time steps.
Each object is now abstracted into a single 64-dim vector; in other words, the vectors from all time steps are pooled into a vector for a single time step.
After max-pooling, x: (1162, 64) = (14*83, 64),
i.e. 14 scenes, each containing 83 single-moment vectors for lanes and vehicles (see the pooling sketch below).
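A minimal sketch of what the pooling does to x, illustrated with torch_scatter's scatter_max (which torch_geometric's pooling builds on); note the real max_pool also coarsens edge_index, which is omitted here:

import torch
from torch_scatter import scatter_max

x = torch.randn(9, 64)                               # 9 vectors
cluster = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2])  # 3 polylines
pooled, _ = scatter_max(x, cluster, dim=0)
print(pooled.shape)  # torch.Size([3, 64]): one vector per polyline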
1.3 out_data.x = out_data.x / out_data.x.norm(dim=0)
Normalize: each feature column is divided by its L2 norm along dim 0, i.e. across all 1162 polyline vectors (note this is a norm, not a mean).
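A tiny numeric check of this column-wise normalization:

import torch

x = torch.tensor([[3., 0.],
                  [4., 1.]])
print(x.norm(dim=0))      # tensor([5., 1.]): per-column L2 norm
print(x / x.norm(dim=0))  # tensor([[0.6000, 0.0000], [0.8000, 1.0000]])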
2 x = sub_graph_out.x.view(-1, time_step_len, self.polyline_vec_shape)
Next comes a reshape:
time_step_len = 83 (the 83 polylines are 1 agent, 41 left lane lines, and 41 right lane lines)
x: (1162, 64) -> (14, 83, 64)
Here 14 means there are 14 prediction scenes; each scene has 83 polylines, and each polyline's feature is a 64-dim vector.
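The reshape itself is just a view; a quick shape check with the numbers above:

import torch

x = torch.randn(1162, 64)
x = x.view(-1, 83, 64)  # time_step_len=83, polyline_vec_shape=64
print(x.shape)          # torch.Size([14, 83, 64])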
3 out = self.self_atten_layer(x, valid_lens)  # (14, 83, 64)
Self-attention computes the attention between the polylines and aggregates their features.
Initialization of self_atten_layer:
self.self_atten_layer = SelfAttentionLayer(
    self.polyline_vec_shape,
    global_graph_width,
    need_scale=False)  # 64, 64
Its forward:
def forward(self, x, valid_len):
    query = self.q_lin(x)  # (14, 83, 64)
    key = self.k_lin(x)
    value = self.v_lin(x)
    scores = torch.bmm(query, key.transpose(1, 2))  # (14, 83, 83)
    attention_weights = masked_softmax(scores, valid_len)
    return torch.bmm(attention_weights, value)
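The q_lin/k_lin/v_lin layers and masked_softmax are not shown in the snippet, so here is a hypothetical reconstruction, not the repo's exact code: a minimal __init__ implied by the forward, plus a d2l-style masked_softmax, assuming valid_len is a (B,) tensor holding the number of valid polylines per scene:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionLayer(nn.Module):
    def __init__(self, in_channels, global_graph_width, need_scale=False):
        super(SelfAttentionLayer, self).__init__()
        # need_scale presumably toggles 1/sqrt(d) score scaling; unused in this sketch
        self.q_lin = nn.Linear(in_channels, global_graph_width)
        self.k_lin = nn.Linear(in_channels, global_graph_width)
        self.v_lin = nn.Linear(in_channels, global_graph_width)

def masked_softmax(scores, valid_len):
    # scores: (B, N, N); valid_len: (B,) -- assumed shapes
    B, N, _ = scores.shape
    # mask key positions at index >= valid_len so padded polylines get ~0 weight
    mask = torch.arange(N, device=scores.device)[None, :] >= valid_len[:, None]
    scores = scores.masked_fill(mask[:, None, :], float('-inf'))
    return F.softmax(scores, dim=-1)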
4 pred = self.traj_pred_mlp(out[:, [0]].squeeze(1))  # (14, 60)
Initialization of traj_pred_mlp:
self.traj_pred_mlp = TrajPredMLP(
    global_graph_width, out_channels, traj_pred_mlp_width)  # 64, 60, 64
The last step keeps only polyline 0 (the agent) out of (14, 83, 64), giving (14, 64), and maps it through the MLP to (14, 60).
The 60-dim vector consists of 30 x coordinates and 30 y coordinates, i.e. the predicted trajectory coordinates for the next 30 time steps.
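Under the layout just stated (first 30 values x, last 30 values y), decoding the prediction into 30 (x, y) points would look like the sketch below; the exact layout depends on the preprocessing, so treat this as an assumption:

import torch

pred = torch.randn(14, 60)            # stand-in for the model output
xs, ys = pred[:, :30], pred[:, 30:]   # (14, 30) each
traj = torch.stack([xs, ys], dim=-1)  # (14, 30, 2) predicted points
print(traj.shape)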
class TrajPredMLP(nn.Module):
    """Predict one feature trajectory, in offset format"""

    def __init__(self, in_channels, out_channels, hidden_unit):
        super(TrajPredMLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU(),
            nn.Linear(hidden_unit, out_channels)
        )

    def forward(self, x):
        return self.mlp(x)
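A quick shape check of TrajPredMLP with the numbers from this walkthrough:

import torch

mlp = TrajPredMLP(in_channels=64, out_channels=60, hidden_unit=64)
agent_feat = torch.randn(14, 64)  # the polyline-0 (agent) feature per scene
print(mlp(agent_feat).shape)      # torch.Size([14, 60])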