PointNet++ Official Code Walkthrough (TensorFlow)

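Because PointNet processes the entire point cloud and then applies max pooling to obtain a global feature, it does not take local features into account. PointNet++ is an improvement aimed mainly at this problem. It first partitions the point cloud into overlapping subsets and then uses PointNet to extract features from each subset. It aggregates those features, repeating until a feature for the whole point cloud is obtained. In essence, PointNet++ adds a hierarchical processing structure on top of PointNet. The resulting embedded features can represent the semantic information of the complete point cloud. They are further used for cls (classification) of the whole cloud and for point-level seg (semantic segmentation).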
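For contrast, here is a minimal NumPy sketch of the PointNet idea described above: one shared MLP is applied to every point, followed by max pooling over all points. The function name, the random weights and the layer widths are illustrative assumptions, not code from the repo. Because the max runs over the whole cloud, the result does not depend on point order, but nothing in it looks at local neighborhoods.

```python
import numpy as np

def pointnet_global_feature(points, weights, biases):
    """Shared per-point MLP followed by max pooling (PointNet-style sketch).

    points: (N, 3) array; weights/biases: per-layer parameters (illustrative).
    Returns one (C_out,) global feature vector for the whole cloud.
    """
    h = points
    for w, b in zip(weights, biases):
        h = np.maximum(h @ w + b, 0.0)  # the same layer is applied to every point, then ReLU
    return h.max(axis=0)  # symmetric max over the point dimension

rng = np.random.default_rng(0)
cloud = rng.normal(size=(1024, 3))
dims = [3, 64, 128, 1024]
weights = [rng.normal(scale=0.1, size=(i, o)) for i, o in zip(dims[:-1], dims[1:])]
biases = [np.zeros(o) for o in dims[1:]]

feat = pointnet_global_feature(cloud, weights, biases)
shuffled = pointnet_global_feature(rng.permutation(cloud), weights, biases)  # rows shuffled
print(feat.shape)                   # (1024,)
print(np.allclose(feat, shuffled))  # True: point order does not matter
```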

PointNet++ as a whole has to solve two problems:

  1. How to partition the complete point cloud
  2. How to abstract a point set and extract local features

https://github.com/charlesq34/pointnet2

Code walkthrough

The core files are in the models folder.
pointnet_cls_basic.py is the basic PointNet framework.
pointnet2_cls_ssg.py and pointnet2_cls_msg.py contain the single-scale grouping (SSG) and multi-scale grouping (MSG) code, respectively.

Core shared module

First, let's look at pointnet_sa_module, the core module shared by cls and seg. It is defined in ./utils/pointnet_util.py.

pointnet_sa_module (PointNet Set Abstraction Layer)
def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius: float32 -- search radius in local region
            nsample: int32 -- how many points in each local region
            mlp: list of int32 -- output size for MLP on each point
            mlp2: list of int32 -- output size for MLP on each region
            group_all: bool -- group all points into one PC if set true, OVERRIDE
                npoint, radius and nsample settings
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- indices for local regions
    '''

The docstring above gives the meaning of each input and output. The function body starts like this:

    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)

This block uses the boolean input group_all to decide whether to run sample_and_group_all or sample_and_group.

sample_and_group
def sample_and_group(npoint, radius, nsample, xyz, points, knn=False, use_xyz=True):
    '''
    Input:
        npoint: int32
        radius: float32
        nsample: int32
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor, if None will just use xyz as points
        knn: bool, if True use kNN instead of radius search
        use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    Output:
        new_xyz: (batch_size, npoint, 3) TF tensor
        new_points: (batch_size, npoint, nsample, 3+channel) TF tensor
        idx: (batch_size, npoint, nsample) TF tensor, indices of local points as in ndataset points
        grouped_xyz: (batch_size, npoint, nsample, 3) TF tensor, normalized point XYZs
            (subtracted by seed point XYZ) in local regions
    '''

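    '''
    Use the farthest point sampling (FPS) algorithm to pick npoint points from the input xyz and return their indices,
    then gather_point collects the corresponding points and returns the subset point cloud new_xyz (batch_size, npoint, 3).
    '''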
    new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz))
    '''
    Next, for every point in new_xyz, find its local-region neighbors in xyz.
    The result depends on the grouping algorithm: for each point in new_xyz, knn returns the nsample nearest points in metric space,
    while query_ball_point is limited by both nsample and radius and finds pts_cnt points (which may be fewer than nsample).
    idx is a (batch_size, npoint, nsample) int32 tensor holding, for each center, indices into the input points in xyz.
    '''
    if knn:
        _,idx = knn_point(nsample, xyz, new_xyz)
    else:
        idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
    '''
    group_point uses the indices in idx to gather the points of each local region into a (batch_size, npoint, nsample, 3) tensor.
    '''
    grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization
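    '''
    points: (batch_size, ndataset, channel) holds the features of all points under consideration; channel is the feature dimension.
    grouped_points: (batch_size, npoint, nsample, channel) gathers those features according to idx.
    When use_xyz==True, coordinates and features are concatenated and the output is (batch_size, npoint, nsample, 3+channel);
    otherwise only the features are output, without the coordinates.
    If points is None, only the (normalized) coordinates grouped_xyz are output.
    '''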
    if points is not None:
        grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel)
        if use_xyz:
            new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nsample, 3+channel)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz
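    '''
    new_xyz is the set of center points of the newly generated subsets.
    new_points may hold the 3D coordinates of all participating points, their features, or coordinates + features.
    idx holds the indices in xyz of all participating points.
    grouped_xyz: (batch_size, npoint, nsample, 3) holds the 3D coordinates of all participating points, relative to their center point.
    '''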
    return new_xyz, new_points, idx, grouped_xyz
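The sampling and grouping steps above can be mimicked on a single point cloud with plain NumPy. This is a sketch, not the repo's implementation. The helpers are named after the repo's TF ops but drop the batch dimension, and `knn_indices` is a hypothetical name. Two behaviors are assumptions about what the CUDA kernels do: FPS starts from point 0, and short ball queries are padded with the first neighbor found. The loops favor readability over speed.

```python
import numpy as np

def farthest_point_sample(npoint, xyz):
    """Greedy FPS on one cloud xyz (N, 3); returns (npoint,) indices."""
    min_dist = np.full(xyz.shape[0], np.inf)  # squared distance to the already-chosen set
    idx = np.zeros(npoint, dtype=np.int64)
    farthest = 0                               # start from point 0 (assumption)
    for i in range(npoint):
        idx[i] = farthest
        d2 = np.sum((xyz - xyz[farthest]) ** 2, axis=1)
        min_dist = np.minimum(min_dist, d2)
        farthest = int(np.argmax(min_dist))    # next: the point farthest from all chosen ones
    return idx

def query_ball_point(radius, nsample, xyz, new_xyz):
    """Up to nsample neighbor indices within radius of each center; pad with the first hit."""
    idx = np.zeros((new_xyz.shape[0], nsample), dtype=np.int64)
    pts_cnt = np.zeros(new_xyz.shape[0], dtype=np.int64)
    for j, center in enumerate(new_xyz):
        hits = np.nonzero(np.sum((xyz - center) ** 2, axis=1) < radius ** 2)[0][:nsample]
        pts_cnt[j] = len(hits)
        if len(hits) > 0:
            idx[j, :] = hits[0]        # fill every slot with the first neighbor...
            idx[j, :len(hits)] = hits  # ...then overwrite the front with the real hits
    return idx, pts_cnt

def knn_indices(k, xyz, new_xyz):
    """Hypothetical helper: indices of the k nearest points in xyz for every center."""
    d2 = np.sum((new_xyz[:, None, :] - xyz[None, :, :]) ** 2, axis=-1)  # (npoint, N)
    return np.argsort(d2, axis=1)[:, :k]

def sample_and_group(npoint, radius, nsample, xyz, points=None, knn=False, use_xyz=True):
    new_xyz = xyz[farthest_point_sample(npoint, xyz)]             # (npoint, 3)
    if knn:
        idx = knn_indices(nsample, xyz, new_xyz)
    else:
        idx, _ = query_ball_point(radius, nsample, xyz, new_xyz)  # (npoint, nsample)
    grouped_xyz = xyz[idx] - new_xyz[:, None, :]                  # translation normalization
    if points is None:
        return new_xyz, grouped_xyz, idx, grouped_xyz
    grouped_points = points[idx]                                  # (npoint, nsample, C)
    if use_xyz:
        new_points = np.concatenate([grouped_xyz, grouped_points], axis=-1)
    else:
        new_points = grouped_points
    return new_xyz, new_points, idx, grouped_xyz

rng = np.random.default_rng(0)
xyz = rng.uniform(-1.0, 1.0, size=(1024, 3))
feats = rng.normal(size=(1024, 16))
new_xyz, new_points, idx, grouped_xyz = sample_and_group(512, 0.2, 32, xyz, feats)
print(new_xyz.shape, new_points.shape, idx.shape)  # (512, 3) (512, 32, 19) (512, 32)
```

Padding keeps every region exactly nsample long, so the grouped tensor stays rectangular even when pts_cnt < nsample.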

sample_and_group_all is equivalent to sample_and_group with npoint=1, radius=inf, and (0,0,0) as the centroid. This corresponds to PointNet's global processing.
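The sketch below shows that equivalence in NumPy, assuming inputs shaped (B, N, 3) and (B, N, C). The function is a stand-in written for illustration, not copied from pointnet_util.py. With the centroid fixed at the origin, translation normalization does nothing, so the grouped coordinates are simply the input reshaped into one region.

```python
import numpy as np

def sample_and_group_all(xyz, points=None, use_xyz=True):
    """Group every point into a single region centered at the origin.

    xyz: (B, N, 3); points: (B, N, C) or None.
    """
    b, n, _ = xyz.shape
    new_xyz = np.zeros((b, 1, 3), dtype=xyz.dtype)           # centroid fixed at (0,0,0)
    idx = np.tile(np.arange(n).reshape(1, 1, n), (b, 1, 1))  # (B, 1, N): every index
    grouped_xyz = xyz.reshape(b, 1, n, 3)                    # subtracting (0,0,0) is a no-op
    if points is None:
        new_points = grouped_xyz
    elif use_xyz:
        new_points = np.concatenate([grouped_xyz, points.reshape(b, 1, n, -1)], axis=-1)
    else:
        new_points = points.reshape(b, 1, n, -1)
    return new_xyz, new_points, idx, grouped_xyz

xyz = np.random.default_rng(0).normal(size=(2, 128, 3))
feats = np.ones((2, 128, 256))
new_xyz, new_points, idx, _ = sample_and_group_all(xyz, feats)
print(new_xyz.shape, new_points.shape, idx.shape)  # (2, 1, 3) (2, 1, 128, 259) (2, 1, 128)
```

Max-pooling new_points over axis 2 afterwards yields one vector per cloud, which is PointNet's global feature.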

        # Point Feature Embedding
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format) 
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])
        # Pooling in Local Regions
        if pooling=='max':
            new_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
        elif pooling=='avg':
            new_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
        elif pooling=='weighted_avg':
            with tf.variable_scope('weighted_avg'):
                dists = tf.norm(grouped_xyz,axis=-1,ord=2,keep_dims=True)
                exp_dists = tf.exp(-dists * 5)
                weights = exp_dists/tf.reduce_sum(exp_dists,axis=2,keep_dims=True) # (batch_size, npoint, nsample, 1)
                new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
                new_points = tf.reduce_sum(new_points, axis=2, keep_dims=True)
        elif pooling=='max_and_avg':
            max_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
            avg_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
            new_points = tf.concat([avg_points, max_points], axis=-1)
        # [Optional] Further Processing 
        if mlp2 is not None:
            if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
            for i, num_out_channel in enumerate(mlp2):
                new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                            padding='VALID', stride=[1,1],
                                            bn=bn, is_training=is_training,
                                            scope='conv_post_%d'%(i), bn_decay=bn_decay,
                                            data_format=data_format) 
            if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

        new_points = tf.squeeze(new_points, [2]) # (batch_size, npoint, mlp[-1] or mlp2[-1])
        return new_xyz, new_points, idx

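The remaining three parts (point feature embedding, pooling in local regions, and the optional further processing) are fairly straightforward and need little explanation. Note the return values. new_xyz holds the coordinates of the new subset centers, and new_points holds the corresponding features. idx holds the indices of all participating points in each local region, with shape (batch_size, npoint, nsample).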
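The four pooling modes in the code above can be compared on a toy region with a NumPy sketch. The function name is hypothetical. The exp(-5 * distance) weighting mirrors the weighted_avg branch, where grouped_xyz holds offsets from the center, so nearer neighbors receive larger weights.

```python
import numpy as np

def pool_local_regions(new_points, grouped_xyz, pooling="max"):
    """Reduce the nsample axis (axis 2) of (B, npoint, nsample, C) features."""
    if pooling == "max":
        return new_points.max(axis=2, keepdims=True)
    if pooling == "avg":
        return new_points.mean(axis=2, keepdims=True)
    if pooling == "weighted_avg":
        dists = np.linalg.norm(grouped_xyz, axis=-1, keepdims=True)  # (B, npoint, nsample, 1)
        exp_dists = np.exp(-dists * 5)
        weights = exp_dists / exp_dists.sum(axis=2, keepdims=True)   # sums to 1 over nsample
        return (new_points * weights).sum(axis=2, keepdims=True)
    if pooling == "max_and_avg":
        # Same order as the TF code: average first, then max, along the channel axis.
        return np.concatenate([new_points.mean(axis=2, keepdims=True),
                               new_points.max(axis=2, keepdims=True)], axis=-1)
    raise ValueError(f"unknown pooling mode: {pooling}")

# Toy region: B=1, npoint=1, nsample=2, C=1.
feats = np.array([[[[1.0], [3.0]]]])
offsets = np.array([[[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]])  # 2nd neighbor is 1 unit from the center
for mode in ("max", "avg", "weighted_avg", "max_and_avg"):
    print(mode, pool_local_regions(feats, offsets, mode).ravel())
```

With these numbers, max gives 3.0 and avg gives 2.0. weighted_avg stays close to 1.0 because the distant neighbor's weight is only exp(-5)/(1+exp(-5)), about 0.0067.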

Classification task

Core SSG model

The core SSG code for cls is as follows:

def get_model(point_cloud, is_training, bn_decay=None):
    """ Classification PointNet, input is BxNx3, output Bx40 """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    end_points = {}
    l0_xyz = point_cloud
    l0_points = None
    end_points['l0_xyz'] = l0_xyz

    # Set abstraction layers
    # Note: When using NCHW for layer 2, we see increased GPU memory usage (in TF1.4).
    # So we only use NCHW for layer 1 until this issue can be resolved.
    
    """
    Select 512 points from the original point cloud; around each one, take at most 32 points as its local region.
    l1_xyz : (batch_size, 512, 3)
    l1_points: (batch_size, 512, 128)
    l1_indices:(batch_size, 512, 32)
    """
    l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32, mlp=[64,64,128], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer1', use_nchw=True)
    """
    Select 128 of these 512 points; around each one, take at most 64 points as its local region.
    l2_xyz : (batch_size, 128, 3)
    l2_points: (batch_size, 128, 256)
    l2_indices:(batch_size, 128, 64)
    """
    l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64, mlp=[128,128,256], mlp2=None, group_all=False, is_training=is_training, bn_decay=bn_decay, scope='layer2')
    """
    group_all over the 128 points: all of them form a single local region.
    l3_xyz : (batch_size, 1, 3)
    l3_points: (batch_size, 1, 1024)
    l3_indices:(batch_size, 1, 128)
    """
    l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None, mlp=[256,512,1024], mlp2=None, group_all=True, is_training=is_training, bn_decay=bn_decay, scope='layer3')

    # Fully connected layers
    net = tf.reshape(l3_points, [batch_size, -1])
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
    net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
    net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

    return net, end_points

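As you can see, the data flow essentially passes through pointnet_sa_module three times. It then feeds the resulting 1024-dimensional global feature into fully connected layers for classification.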
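The tensor shapes in get_model can be traced with a little bookkeeping. This sketch is illustrative only. Batch size 16 and 1024 input points are assumed values. It also assumes mlp2=None with max pooling, so each SA layer outputs mlp[-1] channels. It shows that layer3 produces a 1024-channel feature per cloud before the fully connected layers.

```python
def trace_ssg_shapes(batch_size, num_point, configs):
    """Return the (xyz, points, indices) shapes produced by each SA layer."""
    shapes = []
    prev_npoint = num_point
    for cfg in configs:
        if cfg.get("group_all", False):
            npoint, nsample = 1, prev_npoint  # one region containing every remaining point
        else:
            npoint, nsample = cfg["npoint"], cfg["nsample"]
        shapes.append({
            "xyz": (batch_size, npoint, 3),
            "points": (batch_size, npoint, cfg["mlp"][-1]),  # max pooling keeps mlp[-1] channels
            "indices": (batch_size, npoint, nsample),
        })
        prev_npoint = npoint
    return shapes

ssg_configs = [
    {"npoint": 512, "nsample": 32, "mlp": [64, 64, 128]},    # layer1, radius 0.2
    {"npoint": 128, "nsample": 64, "mlp": [128, 128, 256]},  # layer2, radius 0.4
    {"group_all": True, "mlp": [256, 512, 1024]},            # layer3
]
shapes = trace_ssg_shapes(16, 1024, ssg_configs)
for i, s in enumerate(shapes, start=1):
    print(f"l{i}", s)
b, npoint, channels = shapes[-1]["points"]
fc_input = (b, npoint * channels)  # what tf.reshape(l3_points, [batch_size, -1]) yields
print("fc1 input:", fc_input)      # (16, 1024)
```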

pointnet_fp_module
def pointnet_fp_module(xyz1, xyz2, points1, points2, mlp, is_training, bn_decay, scope, bn=True):
    ''' PointNet Feature Propogation (FP) Module
        Input:                                                                                                      
            xyz1: (batch_size, ndataset1, 3) TF tensor                                                              
            xyz2: (batch_size, ndataset2, 3) TF tensor, sparser than xyz1                                           
            points1: (batch_size, ndataset1, nchannel1) TF tensor                                                   
            points2: (batch_size, ndataset2, nchannel2) TF tensor
            mlp: list of int32 -- output size for MLP on each point                                                 
        Return:
            new_points: (batch_size, ndataset1, mlp[-1]) TF tensor
    '''
    with tf.variable_scope(scope) as sc:
        dist, idx = three_nn(xyz1, xyz2)
        dist = tf.maximum(dist, 1e-10)
        norm = tf.reduce_sum((1.0/dist),axis=2,keep_dims=True)
        norm = tf.tile(norm,[1,1,3])
        weight = (1.0/dist) / norm
        interpolated_points = three_interpolate(points2, idx, weight)

        if points1 is not None:
            new_points1 = tf.concat(axis=2, values=[interpolated_points, points1]) # B,ndataset1,nchannel1+nchannel2
        else:
            new_points1 = interpolated_points
        new_points1 = tf.expand_dims(new_points1, 2)
        for i, num_out_channel in enumerate(mlp):
            new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1,1],
                                         padding='VALID', stride=[1,1],
                                         bn=bn, is_training=is_training,
                                         scope='conv_%d'%(i), bn_decay=bn_decay)
        new_points1 = tf.squeeze(new_points1, [2]) # B,ndataset1,mlp[-1]
        return new_points1
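pointnet_fp_module interpolates features from the sparse set xyz2 back onto the denser xyz1. It uses each point's three nearest neighbors, weighted by inverse distance, then concatenates the skip features points1 and applies a per-point MLP. The NumPy sketch below covers only the interpolation, for a single cloud without the batch dimension, and the function name is hypothetical. It assumes the repo's three_nn op returns squared distances, as I understand it does, so 1.0/dist acts as an inverse squared-distance weight.

```python
import numpy as np

def three_nn_interpolate(xyz1, xyz2, points2, eps=1e-10):
    """Interpolate points2 (M, C), defined on sparse xyz2 (M, 3), onto dense xyz1 (N, 3)."""
    d2 = np.sum((xyz1[:, None, :] - xyz2[None, :, :]) ** 2, axis=-1)   # (N, M) squared distances
    idx = np.argsort(d2, axis=1)[:, :3]                                # (N, 3) three nearest
    dist = np.maximum(np.take_along_axis(d2, idx, axis=1), eps)        # like tf.maximum(dist, 1e-10)
    weight = (1.0 / dist) / np.sum(1.0 / dist, axis=1, keepdims=True)  # each row sums to 1
    return np.sum(points2[idx] * weight[..., None], axis=1)            # (N, C)

xyz2 = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 5.0, 5.0]])
points2 = np.array([[1.0], [2.0], [3.0], [100.0]])
xyz1 = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.0]])
out = three_nn_interpolate(xyz1, xyz2, points2)
print(out.ravel())  # close to [1.0, 2.0]
```

The far point with feature 100 never contributes, because it is not among the three nearest neighbors of either query point. The first query coincides with a sparse point, so the eps clamp gives that neighbor an overwhelming weight.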

Segmentation task

© Copyright belongs to the author. For reprints or content collaboration, please contact the author.