TORCH02-02: Common Machine Learning Functions in Torch

The core of Torch is differentiation, and differentiation depends on how operations are defined, so function operations are one of the key topics in Torch. Some functions follow the Function interface convention, which makes differentiation more efficient. This topic surveys the Torch operations related to machine learning:
  1. operations implemented on Tensor
  2. the various functions wrapped in torch.nn.functional


Note: the activation functions among them are very important for training machine learning models


Tensor operation functions

  • The three functions below can be used in two ways (see the sketch after this list):
    • as global functions of the torch module
    • as member functions of the Tensor class
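A minimal sketch of the two call styles, using mm as the example; both styles produce the same result:

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

out1 = torch.mm(a, b)    # global function in the torch module
out2 = a.mm(b)           # member function of the Tensor class
print(torch.allclose(out1, out2))    # True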

mm and addmm

mm function

  • mm wraps matrix–matrix multiplication. The function only does matrix multiplication, so its source code checks the dimensionality of its inputs: both must be 2-dimensional.
    torch.mm(input, mat2, out=None) → Tensor
  • Parameters:
    • input: the first matrix
    • mat2: the second matrix

addmm function

  • The formula it computes is
    • \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1} \mathbin{@} \text{mat2})
    addmm(beta=1, input, alpha=1, mat1, mat2, out=None) -> Tensor
  • Parameters:
    • beta: \beta in the formula above
    • input: \text{input} in the formula above
    • alpha: \alpha in the formula above
    • mat1: \text{mat1} in the formula above
    • mat2: \text{mat2} in the formula above

Examples

  1. mm example
import torch

m_input = torch.Tensor(
    [
        [1, 2],
        [3, 4],
        [5, 6]
    ]
)

mat2 = torch.Tensor(
    [
        [1, 2, 3],
        [4, 5, 6]
    ]
)
# vec2 = torch.Tensor(
#     [1,2]
# )

out = torch.mm(m_input, mat2)
print(out)
# out = torch.mm(m_input, vec2)
# print(out)
tensor([[ 9., 12., 15.],
        [19., 26., 33.],
        [29., 40., 51.]])
  2. addmm example
import torch

m_input = torch.Tensor(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ]
)

mat1 = torch.Tensor(
    [
        [1, 2],
        [3, 4],
        [5, 6]
    ]
)

mat2 = torch.Tensor(
    [
        [1, 2, 3],
        [4, 5, 6]
    ]
)

out = torch.addmm(input=m_input, mat1=mat1, mat2=mat2, beta=1, alpha=1)
print(out)
tensor([[10., 14., 18.],
        [23., 31., 39.],
        [36., 48., 60.]])

mv and addmv

  • mv/addmv are similar to mm/addmm, except that they operate between a matrix and a vector

    • mv is the matrix–vector product
    • addmv computes: \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec})
  • Note:

    • a matrix–vector operation outputs a vector

mv function

    torch.mv(input, vec, out=None) -> Tensor

addmv function

    torch.addmv(beta=1, input, alpha=1, mat, vec, out=None) -> Tensor

Examples

  1. mv example
import torch

m_input = torch.Tensor(
    [
        [1, 2, 3],
        [4, 5, 6]
    ]
)

vec = torch.Tensor(
    [1, 2, 3]
)

out = torch.mv(input=m_input, vec=vec)
print(out)   # the output is a vector
tensor([14., 32.])
  2. addmv example
import torch

m_input = torch.Tensor(
    [1, 2, 3]
)

mat = torch.Tensor(
    [
        [1, 2],
        [3, 4],
        [5, 6]
    ]
)

vec = torch.Tensor(
    [1, 2]
)

out = torch.addmv(input=m_input, mat=mat, vec=vec, beta=1, alpha=1)
print(out)    # the output is a vector
tensor([ 6., 13., 20.])

addbmm function

  • addbmm is a batched operation; the formula it wraps is:

    • out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i)

    • Here b is the batch dimension: batch1 and batch2 are 3-dimensional tensors whose first dimension is the batch size.

  • This function is convenient for image computations, because images carry a color-depth dimension. A numerical check of the formula follows the examples below.

addbmm function description

    torch.addbmm(beta=1, input, alpha=1, batch1, batch2, out=None) -> Tensor

Examples

import torch

M = torch.randn(3, 5)
batch1 = torch.randn(10, 3, 4)    # 10 is the batch size
batch2 = torch.randn(10, 4, 5)   # 10 is the batch size

out = torch.addbmm(M, batch1, batch2)
print(out)
tensor([[  3.3873,  -1.3027,   2.8728,  -0.7336,   7.9643],
        [ -3.1864,  -0.2114,  -2.7413,  -5.1329,   0.3058],
        [ -9.4950,  -2.7834, -13.0682,  -6.0314,   3.7112]])
import torch

m_input = torch.Tensor(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ]
)

mat1 = torch.Tensor(
    [
        [
            [1, 2],
            [3, 4],
            [5, 6]
        ],   # batch =1
    ]
)

mat2 = torch.Tensor(
    [
        [
            [1, 2, 3],
            [4, 5, 6]
        ], # batch =1
    ]
)

out = torch.addbmm(input=m_input, batch1=mat1, batch2=mat2, beta=1, alpha=1)    # batch1 and batch2 must be 3-dimensional
print(out)
tensor([[10., 14., 18.],
        [23., 31., 39.],
        [36., 48., 60.]])
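A quick numerical check of the addbmm formula, assuming the defaults beta=1 and alpha=1: the result should equal input plus the sum over the batch of the per-batch matrix products computed with torch.bmm.

import torch

inp = torch.randn(3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

out = torch.addbmm(inp, batch1, batch2)
manual = inp + torch.bmm(batch1, batch2).sum(dim=0)    # sum of the 10 per-batch products
print(torch.allclose(out, manual, atol=1e-6))           # True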

Functions in the torch.nn module

Convolution functions

  • 2D convolution is the main tool for image processing, so only the 2D convolution is covered below.

Function description

    torch.nn.functional.conv2d(
        input,             # the 2D data to process
        weight,           # the 2D convolution kernels (shared weights)
        bias=None,     # the bias term
        stride=1,         # stride of the convolution: an integer (same stride for height and width) or a tuple (different strides)
        padding=0,     # padding size (this has to be computed and passed in by hand)
        dilation=1,       # spacing between kernel elements: an integer (same in both directions) or a tuple
        groups=1) → Tensor
  • Key parameter formats (a shape-only sketch follows):
    • input: a 4-dimensional Tensor with dimensions (batch size, depth, height, width)
    • weight: a 4-dimensional Tensor with dimensions (output channels, depth, height, width)
      • if groups is specified, the depth is split into groups and the shape becomes (output channels, depth/groups, height, width)
    • bias: the bias term of the convolution, a 1-dimensional tensor of shape (output channels)
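A minimal shape-only sketch of these conventions (the channel counts here are made up for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)            # (batch, depth, height, width)

# groups=1: weight shape is (out_channels, in_channels, kH, kW)
w1 = torch.randn(6, 4, 3, 3)
print(F.conv2d(x, w1, padding=1).shape)              # torch.Size([1, 6, 8, 8])

# groups=2: weight shape is (out_channels, in_channels/groups, kH, kW)
w2 = torch.randn(6, 2, 3, 3)
print(F.conv2d(x, w2, padding=1, groups=2).shape)    # torch.Size([1, 6, 8, 8])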

Examples

  • The example uses an image for illustration
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution function expects depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width

# define the convolution kernel
kernel = torch.DoubleTensor(
     [
         [
             [-1, -1,  0],
             [-1,  0,  1],
             [ 0,   1,  1]
         ],
        [
             [-1, -1,  0],
             [-1,  0,  1],
             [ 0,   1,  1]
         ],
        [
             [-1, -1,  0],
             [-1,  0,  1],
             [ 0,   1,  1]
         ]
     ]
 ).view(1, 3, 3 ,3)   # one output image: kernel depth 3, height 3, width 3
# print(input.shape)
# print(kernel.shape)
# Note: with a kernel of size 3, padding is 1; with a kernel of size 5, padding is 2; with padding=0 the output image loses 2 pixels in height and width
out_image = torch.conv2d(input, kernel, padding=1, bias=torch.DoubleTensor([0]))
# print(out_image)

ax1 = plt.subplot(121)
# ax1.imshow(out_image.numpy()[0][0])    # image depth 1, i.e. a grayscale image
ax1.imshow(out_image.numpy()[0][0], cmap='gray')
ax2 = plt.subplot(122)
ax2.imshow(org_imgs, cmap='gray')    # image depth 3, i.e. a color image
# print( out_image.numpy()[0][0].shape)

<matplotlib.image.AxesImage at 0x120ce5438>
Figure: convolution result

Pooling functions

  • Pooling is mainly used for downsampling; it is an operation used in convolutional neural networks.
  • There are generally two kinds of pooling:
    • average pooling
    • max pooling
  • The size of the pooling kernel determines how much the input is downsampled:
    • a 2 × 2 pooling kernel: the output is half the size of the input.

Function descriptions

  1. Average pooling
    torch.nn.functional.avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) → Tensor
  • Important parameters:
    • input: (minibatch, in_channels, iH, iW), the same layout as the convolution input;
    • kernel_size: the pooling window size (the pooling kernel carries no weights), an integer or a tuple;
    • padding: padding size (so the pooling windows cover the border);
    • count_include_pad: whether the zero padding is counted when computing the average;
    • ceil_mode: whether to use ceil or floor when computing the output shape;
  2. Max pooling
    torch.nn.functional.max_pool2d(*args, **kwargs)
  • Notes:
    • same as average pooling except for the reduction used (see the numeric sketch below):
      • average pooling: the output is the mean of each window
      • max pooling: the output is the maximum of each window
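A minimal numeric sketch of the two reductions, on a tiny 4×4 input with a 2×2 kernel:

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)    # one image, one channel

print(F.avg_pool2d(x, kernel_size=2))   # tensor([[[[ 2.5,  4.5], [10.5, 12.5]]]])
print(F.max_pool2d(x, kernel_size=2))   # tensor([[[[ 5.,  7.], [13., 15.]]]])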

Examples

%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution-style layout wants depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width

out = torch.nn.functional.avg_pool2d(input=input, kernel_size=5)
# print(out.shape)

out_image = out[0].byte().numpy()
out_image = out_image.transpose([1, 2, 0])
# out_image[out_image>=255] =255
# out_image[out_image<=0] =0
ax1 = plt.subplot(121, title="after average pooling")
ax1.imshow(out_image)

ax2 = plt.subplot(122, title="original image")
ax2.imshow(org_imgs)    # image depth 3, i.e. a color image
<matplotlib.image.AxesImage at 0x11febf3c8>
Figure: average pooling result
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution-style layout wants depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width

out = torch.nn.functional.max_pool2d(input=input, kernel_size=5)   # adjust the pooling kernel size here
# print(out.shape)

out_image = out[0].byte().numpy()
out_image = out_image.transpose([1, 2, 0])
# out_image[out_image>=255] =255
# out_image[out_image<=0] =0
ax1 = plt.subplot(121, title="after max pooling")
ax1.imshow(out_image)

ax2 = plt.subplot(122, title="original image")
ax2.imshow(org_imgs)    # image depth 3, i.e. a color image
<matplotlib.image.AxesImage at 0x11fbb1c88>
Figure: max pooling result

dropout functions

  • There are several dropout functions:
    • dropout: randomly masks elements during training;
    • alpha_dropout: randomly masks elements using a Bernoulli distribution while keeping the data self-normalizing;
    • dropout2d: random masking for 2D data;
    • dropout3d: random masking for 3D data;

Function descriptions

  1. dropout
    torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
  • Parameters:
    • p: the probability of an element being masked (masking an element means setting it to zero)
  2. alpha_dropout
    torch.nn.functional.alpha_dropout(input, p=0.5, training=False, inplace=False)
  3. dropout2d
    torch.nn.functional.dropout2d(input, p=0.5, training=True, inplace=False)
  • Parameters:
    • input: the same layout as for conv2d and avg_pool2d above;
  4. dropout3d
    torch.nn.functional.dropout3d(input, p=0.5, training=True, inplace=False)

Examples

  1. dropout example
    • The example below shows that the mean stays (roughly) unchanged after dropout: the surviving elements are rescaled by 1/(1-p), so the expected value is preserved while the variance grows (this property needs to be considered when choosing between the different dropout functions). A small numeric sketch of the rescaling comes first.
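A minimal sketch of the 1/(1-p) rescaling, assuming p=0.5:

import torch
import torch.nn.functional as F

x = torch.ones(8)
y = F.dropout(x, p=0.5, training=True)
print(y)          # surviving elements become 1/(1-0.5) = 2.0, the rest are 0
print(y.mean())   # close to 1.0 on average over many draws

The image example below shows the same effect on real data.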
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution-style layout wants depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width
print(input.mean(), input.var())
out = torch.nn.functional.dropout(input=input, p=0.1, training=True)   # apply dropout with p=0.1
print(out.mean(), out.var())

out_image = out[0].byte().numpy()
out_image = out_image.transpose([1, 2, 0])
# print(out_image)
ax1 = plt.subplot(121, title="after dropout")
ax1.imshow(out_image, cmap="gray")

ax2 = plt.subplot(122, title="original image")
ax2.imshow(org_imgs, cmap="gray")    # image depth 3, i.e. a color image
tensor(192.1158, dtype=torch.float64) tensor(8295.0039, dtype=torch.float64)
tensor(192.1322, dtype=torch.float64) tensor(13311.1657, dtype=torch.float64)





<matplotlib.image.AxesImage at 0x125382320>
Figure: dropout result
  2. alpha_dropout example

    • the alpha_dropout function is designed to keep the variance unchanged
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution-style layout wants depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width
print(input.mean(), input.var())
out = torch.nn.functional.alpha_dropout(input=input, p=0.1, training=True)   # apply alpha_dropout with p=0.1
print(out.mean(), out.var())

out_image = out[0].byte().numpy()
out_image = out_image.transpose([1, 2, 0])
# print(out_image)
ax1 = plt.subplot(121, title="after alpha_dropout")
ax1.imshow(out_image, cmap="gray")

ax2 = plt.subplot(122, title="original image")
ax2.imshow(org_imgs, cmap="gray")    # image depth 3, i.e. a color image
tensor(192.1158, dtype=torch.float64) tensor(8295.0039, dtype=torch.float64)
tensor(159.2412, dtype=torch.float64) tensor(9212.0347, dtype=torch.float64)





<matplotlib.image.AxesImage at 0x125788be0>
Figure: alpha_dropout result
  3. dropout2d example
    • dropout2d acts on whole channels.
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
num_imgs = org_imgs.transpose(2, 0, 1)    # rearrange the dimensions (the convolution-style layout wants depth first)

# batch (number of images) 1, depth 3, height 300, width 300
input = torch.from_numpy(num_imgs).view(
    1,     # batch
    num_imgs.shape[0],    # depth
    num_imgs.shape[1],    # height
    num_imgs.shape[2]).double()    # width
print(input.mean(), input.var())
out = torch.nn.functional.dropout2d(input=input, p=0.5, training=True)   # apply dropout2d with p=0.5
print(out.mean(), out.var())

out_image = out[0].byte().numpy()
out_image = out_image.transpose([1, 2, 0])
# print(out_image)
ax1 = plt.subplot(121, title="after dropout2d")
ax1.imshow(out_image)

ax2 = plt.subplot(122, title="original image")
ax2.imshow(org_imgs,)    # image depth 3, i.e. a color image
tensor(192.1158, dtype=torch.float64) tensor(8295.0039, dtype=torch.float64)
tensor(384.2315, dtype=torch.float64) tensor(33180.0157, dtype=torch.float64)





<matplotlib.image.AxesImage at 0x1288b2da0>
Figure: dropout2d result

Difference between dropout and dropout2d

  • alpha_dropout differs from dropout/dropout2d in the normalization it applies, while dropout and dropout2d differ in how elements are dropped.

    • dropout drops individual elements at random;
    • dropout2d drops whole pixels at random here (the 3 channels of a pixel are either all zeroed or all kept);
  • The examples below illustrate this:

  1. dropout
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
input = torch.from_numpy(org_imgs).double()

out = torch.nn.functional.dropout(input=input, p=0.5, training=True)   # only works on floating-point data
# -----------------
ax1 = plt.subplot(121, title="original image")
ax1.imshow(org_imgs)    # image depth 3, i.e. a color image
ax2 = plt.subplot(122, title="after dropout")
ax2.imshow(out.byte().numpy())


<matplotlib.image.AxesImage at 0x11442f400>
Figure: dropout drops individual elements across the color channels
  2. dropout2d
    • The third dimension is used as the channel dimension. The result is that each pixel is either black or red, because the operation works on whole pixels, zeroing a pixel at a time.
%matplotlib inline
import torch
import matplotlib.pyplot as plt

org_imgs = plt.imread("datasets/hi.jpeg")
input = torch.from_numpy(org_imgs).double()

out = torch.nn.functional.dropout2d(input=input, p=0.5, training=True)   # only works on floating-point data
# -----------------
ax1 = plt.subplot(121, title="original image")
ax1.imshow(org_imgs)    # image depth 3, i.e. a color image
ax2 = plt.subplot(122, title="after dropout2d")
ax2.imshow(out.byte().numpy())
<matplotlib.image.AxesImage at 0x116c8aa20>
Figure: dropout2d drops whole pixels (all color channels together)

Linear functions

  • The linear function works on inputs of any number of dimensions, but the common case is still a 2D matrix.

linear function

  • y = xA^T + b

bilinear function

  • y = x_1^T A x_2 + b

linear examples

  1. First, look at the shapes
import torch

x = torch.randn(2, 4)    # 
A = torch.randn(3, 4)
b = torch.randn(2, 3)

out = torch.nn.functional.linear(input=x, weight=A, bias=b)
print(out)
tensor([[ 2.6883,  1.2959,  0.0287],
        [-0.4045,  0.4580, -2.5265]])
  2. Then look at the computation rule (matrix product).
import torch

x = torch.LongTensor(  # 2*4
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8]
    ]
)
A = torch.LongTensor(   # 3*4
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [1, 1, 1, 1]
    ]
)
b = torch.LongTensor(   # 2 * 3
    [
        [1, 2, 3],
        [5, 6, 7]
    ]
)

out = torch.nn.functional.linear(input=x, weight=A, bias=b)
print(out)
tensor([[ 31,  72,  13],
        [ 75, 180,  33]])
  3. The multi-dimensional case
    • Notes on the dimensions:
      • Input: (N, *, in_features) where * means any number of extra dimensions
      • Weight: (out_features, in_features)
      • Bias: (out_features)
      • Output: (N, *, out_features)
import torch

x = torch.randn(2, 3, 5, 4)    # N=2, *=(3, 5), in_features=4
A = torch.randn(3, 4)          # out_features=3, in_features=4
b = torch.randn(3)              # out_features=3
b = torch.Tensor([88, 888, 8888])

out = torch.nn.functional.linear(input=x, weight=A, bias=b)
print(out.shape)    # N=2, *=(3, 5), out_features=3
print(out)
torch.Size([2, 3, 5, 3])
tensor([[[[  87.6084,  888.8325, 8889.0488],
          [  89.4753,  889.6735, 8886.3359],
          [  88.0262,  890.7296, 8890.6123],
          [  89.7283,  890.1259, 8886.5527],
          [  86.9266,  891.3016, 8888.4502]],

         [[  87.8339,  890.9385, 8889.3877],
          [  88.5911,  889.2175, 8889.3818],
          [  85.6238,  885.1033, 8887.4404],
          [  88.8761,  887.0152, 8885.6064],
          [  86.9773,  885.6030, 8887.9336]],

         [[  88.2102,  889.2548, 8889.2451],
          [  89.2577,  885.4579, 8883.8271],
          [  87.3587,  888.9655, 8887.9131],
          [  88.3702,  888.2765, 8886.4570],
          [  87.8151,  886.6364, 8887.6582]]],


        [[[  86.8274,  888.8477, 8889.2188],
          [  87.7344,  888.5159, 8889.4189],
          [  88.3876,  887.8715, 8888.4531],
          [  90.0545,  885.2833, 8887.1396],
          [  90.2269,  889.3709, 8885.2266]],

         [[  87.8483,  889.3448, 8888.0000],
          [  89.7446,  889.5375, 8888.2314],
          [  84.5236,  885.5975, 8887.9775],
          [  89.0541,  887.0644, 8885.5928],
          [  91.5662,  889.8510, 8885.7021]],

         [[  87.5693,  891.6489, 8890.9561],
          [  88.2609,  887.5566, 8888.3799],
          [  87.9360,  890.4555, 8888.3887],
          [  87.0874,  885.4625, 8887.5732],
          [  87.1704,  884.4638, 8889.4834]]]])
import torch

x = torch.randn(2, 4)    # N=2, in_features=4
A = torch.randn(3, 4)          # out_features=3, in_features=4
b = torch.randn(3)              # out_features=3
b = torch.Tensor([88, 888, 8888])

out = torch.nn.functional.linear(input=x, weight=A, bias=b)
print(out.shape)    # N=2, out_features=3
print(out)
torch.Size([2, 3])
tensor([[  88.3087,  889.9603, 8887.1494],
        [  87.2756,  886.3957, 8887.5176]])

bilinear examples

  • Mainly pay attention to the shapes (a numeric check of the formula follows the shape example below).

    • Input1: (N, *, in1_features)
    • Input2: (N, *, in2_features)
    • Output: (N, *, out_features)
    • weight: (out_features, in1_features, in2_features)
    • bias: (out_features)
  • Notes

    • the * dimensions in the middle of the two inputs must be the same batch-like dimensions.
    • despite the similar name, this bilinear transformation is not the bilinear interpolation used to resize images (when enlarging an image, interpolated sampling looks much better than simply turning one pixel into 4 equal copies); image interpolation is provided by torch.nn.functional.interpolate instead.
import torch

input1 = torch.randn(3,2,4,3)      # the middle dims (2, 4) must match input2's middle dims
input2 = torch.randn(3,2,4,4)     # the middle dims (2, 4) must match input1's middle dims
weight = torch.randn(5,3, 4)      # must be 3-dimensional
bias = torch.randn(5)

out = torch.nn.functional.bilinear(input1=input1, input2=input2, weight=weight, bias=bias)
print(out.shape)
torch.Size([3, 2, 4, 5])
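A minimal numeric check of the bilinear formula (a sketch with made-up shapes): each output element should equal x1^T A_o x2 plus the bias.

import torch
import torch.nn.functional as F

x1 = torch.randn(2, 3)
x2 = torch.randn(2, 4)
W = torch.randn(5, 3, 4)           # (out_features, in1_features, in2_features)
b = torch.randn(5)

out = F.bilinear(x1, x2, W, b)
manual = torch.einsum('ni,oij,nj->no', x1, W, x2) + b    # x1^T A_o x2 + b_o for each sample
print(torch.allclose(out, manual, atol=1e-6))             # True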

Activation functions

  • The activation functions below generally come in two versions (see the sketch after this list):
    • a version that returns a new tensor
    • an in-place version (the function name ends with _)
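A minimal sketch of the two versions, using relu as the example; the in-place variant modifies its argument instead of allocating a new tensor:

import torch

x = torch.tensor([-1.0, 0.5, 2.0])
y = torch.relu(x)       # returns a new tensor, x is unchanged
print(x, y)

torch.relu_(x)          # in-place version (trailing underscore), x itself is modified
print(x)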

threshold function

  • A thresholding function; essentially a more general version of relu.
    torch.nn.functional.threshold(input, threshold, value, inplace=False)

  • Parameters:

    • input: the data to process, any shape
    • threshold: the threshold value
    • value: the value that replaces elements not above the threshold.
  • Example code


%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.threshold(input=x, threshold=2, value=4)
plt.plot(x, y, color=(1, 0, 0, 1))

[<matplotlib.lines.Line2D at 0x11759f5c0>]
Figure: threshold function

relu function

  • The Rectified Linear Unit (ReLU), also called the rectifier;
    torch.nn.functional.relu(input, inplace=False) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.relu(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x11799bac8>]
Figure: relu function

hardtanh function

  • A hard version of the hyperbolic tangent, an extension of tanh
    • the output is clamped between a minimum and a maximum value, [-1, 1] by default.
    torch.nn.functional.hardtanh(input, min_val=-1., max_val=1., inplace=False) → Tensor
  • Parameters:

    • min_val=-1.
    • max_val=1.
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.hardtanh(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x117bba160>]
Figure: hardtanh function

relu6

  • A special version of the relu function:
    • the output is limited to the range 0 to 6
    torch.nn.functional.relu6(input, inplace=False) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.relu6(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x117d634a8>]
Figure: relu6 function

elu function

  • The exponential linear unit;
    • f(x) = \begin{cases} x&x \ge 0\\ \alpha ( e^x - 1 )&x < 0\\ \end{cases}
    torch.nn.functional.elu(input, alpha=1.0, inplace=False)
  • Parameters:

    • input: the data to process
    • alpha=1.0: \alpha in the formula
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.elu(input=x, alpha=1.0)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x117f55d68>]
Figure: elu function

selu function

  • The selu formula is:
    • f(x) = \begin{cases} \lambda x&x \ge 0\\ \lambda \alpha ( e^x - 1 )&x < 0\\ \end{cases}

    • \lambda=1.0507009873554804934193349852946

    • \alpha=1.6732632423543772848170429916717

    torch.nn.functional.selu(input, inplace=False) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.selu(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1180959e8>]
Figure: selu function

celu function

  • The formula is:
    • f(x) = \begin{cases} x&x \ge 0\\ \alpha ( e^{\frac{x}{\alpha}} - 1 )&x < 0\\ \end{cases}
    torch.nn.functional.celu(input, alpha=1., inplace=False) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.celu(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1180fa6a0>]
Figure: celu function

leaky_relu function

  • The formula is:
    • f(x) = \begin{cases} x&x \ge 0\\ \dfrac{x}{\alpha}&x < 0\\ \end{cases}
    torch.nn.functional.leaky_relu(input, negative_slope=0.01, inplace=False) → Tensor
  • Parameters:

    • negative_slope: \dfrac{1}{\alpha} in the formula
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.leaky_relu(input=x, negative_slope=0.01)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1181cb1d0>]
Figure: leaky_relu function

prelu function

  • A parameterized version of leaky_relu (weight must be a tensor); the multiplication in the formula below is element-wise.
  • The formula is:
    • f(x) = \begin{cases} x&x \ge 0\\ w * x&x < 0\\ \end{cases}
    torch.nn.functional.prelu(input, weight) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.prelu(input=x, weight=torch.Tensor([0.01]))
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x118318a20>]
Figure: prelu function

rrelu function

  • The randomized rectified linear unit:
  • The formula is:
    • f(x) = \begin{cases} x&x \ge 0\\ \alpha x&x < 0\\ \end{cases}

    • where \alpha is drawn uniformly at random from [lower, upper).

    torch.nn.functional.rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) → Tensor
  • Parameters:

    • the two parameters that specify the range of the uniform distribution:
      • lower=1./8
      • upper=1./3
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.rrelu(input=x, lower=0.4,  upper=0.6)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1184ab898>]
Figure: rrelu function

glu function

  • The gated linear unit
  • Formula: GLU(a,b)=a \otimes \sigma(b)
    • where \sigma is the sigmoid function (introduced separately below)
    • \otimes is element-wise multiplication
    • a and b are the two halves obtained by splitting the input matrix along the specified dimension.
  • Note:
    • input must be at least a 2D matrix, and the dimension being split must have an even size
    torch.nn.functional.glu(input, dim=-1) → Tensor
  • Example code
import torch
x = torch.Tensor(
    [
        [1, 2, 3, 4],
        [4, 5, 6, 7],
        [7, 8, 9, 10]
    ]
) 

y =  torch.nn.functional.glu(x, dim=-1)
print(y)
tensor([[0.9526, 1.9640],
        [3.9901, 4.9954],
        [6.9991, 7.9996]])

gelu function

  • The formula is:
    • GeLU(x) = x * \Phi (x)
      • where \Phi(x) is the cumulative distribution function of the standard normal (Gaussian) distribution.
    torch.nn.functional.gelu(input) → Tensor
  • Note

    • any function defined on scalars also works on matrices: it is simply applied element-wise.
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.gelu(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x11874e438>]
Figure: gelu function

logsigmoid function

  • The logarithm of the sigmoid function; the formula is:
    • f(x) = \log(\dfrac{1}{1+e^{-x}})
    torch.nn.functional.logsigmoid(input) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.logsigmoid(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1188fa940>]
Figure: logsigmoid function

hardshrink function

  • The hard shrinkage function; the formula is:
    • \begin{split}f(x) = \begin{cases} x&\text{if } x > \lambda \\ x&\text{if } x < -\lambda \\ 0&\text{otherwise } \end{cases}\end{split}
    torch.nn.functional.hardshrink(input, lambd=0.5) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.hardshrink(input=x, lambd=3.0)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x118415e48>]
Figure: hardshrink function

tanhshrink

  • The tanh shrinkage function; the formula is:
    • Tanhshrink(x) = x - Tanh(x)
    torch.nn.functional.tanhshrink(input) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.tanhshrink(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1189c9c88>]
Figure: tanhshrink function

softsign function

  • A smooth version of the sign function; the formula is:
    • f(x) = \dfrac{x}{1 +| x |}
    torch.nn.functional.softsign(input) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.softsign(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1187a0c88>]
Figure: softsign function

softplus function

  • softplus can be seen as a smooth approximation of ReLU; its formula is:
    • f(x) = \ln(1 + e^{x})
    torch.nn.functional.softplus(input, beta=1, threshold=20) → Tensor
  • The derivative of softplus is the sigmoid (logistic) function. With the beta parameter the function becomes \frac{1}{\beta}\ln(1 + e^{\beta x}), and for numerical stability it reverts to the linear function x once \beta x exceeds threshold.

  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.softplus(input=x, beta=1, threshold=20)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1204d4390>]
Figure: softplus function
import torch
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)

y = 1 + x.exp()
y = y.log()
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x12060bef0>]
Figure: manual softplus computation

softmin function

  • The counterpart of the softmax function
    • softmin(x) = softmax(-x) (a quick numerical check follows the example below)
    • softmax is a normalized exponential
    torch.nn.functional.softmin(input, dim=None, _stacklevel=3, dtype=None)
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.softmin(input=x, dim=0)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x12074c2e8>]
Figure: softmin function
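A quick check of the identity softmin(x) = softmax(-x) (a sketch):

import torch
import torch.nn.functional as F

x = torch.randn(6)
print(torch.allclose(F.softmin(x, dim=0), F.softmax(-x, dim=0)))   # True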

softmax function

  • The formula is: f(x_i) = \dfrac{e^{x_i}}{\sum \limits _j e^{x_j}}

  • Multi-dimensional inputs are supported, so the dimension along which to sum can be specified.

    torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None)
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.softmax(input=x, dim=0)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x120892da0>]
Figure: softmax function
  • Manual implementation of softmax
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
s = 0
for xi in x:
    s += xi.exp()
y = x.exp()/s
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x120900a20>]
Figure: manual softmax computation

softshrink function

  • The formula is:
    • \begin{split}out = \begin{cases} x - \lambda, \text{if } x > \lambda \\ x + \lambda, \text{if } x < -\lambda \\ 0, \text{otherwise} \end{cases}\end{split}
    torch.nn.functional.softshrink(input, lambd=0.5) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 1000)
y = torch.nn.functional.softshrink(input=x, lambd=3)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x120b29320>]
Figure: softshrink function

gumbel_softmax function

  • Samples from the Gumbel-Softmax distribution; the Gumbel density is
    • p(x)=\frac{1}{\beta}e^{-z-e^{-z}}
      • z=\frac{x-\mu}{\beta}
    torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)
  • Parameters:

    • hard: if True, the returned samples are discretized to one-hot vectors
  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(0, 1, 100)
y = torch.nn.functional.gumbel_softmax(logits=x, tau=1, hard=False)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x120eef198>]
Figure: gumbel_softmax function
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(0, 1, 100)
y = torch.nn.functional.gumbel_softmax(logits=x, tau=1, hard=True)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x121020e10>]
Figure: gumbel_softmax with hard=True

log_softmax function

  • The logarithm of the softmax
    torch.nn.functional.log_softmax(input, dim=None, _stacklevel=3, dtype=None)
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 100)
y = torch.nn.functional.log_softmax(input=x, dim=-1)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x1212e0390>]
Figure: log_softmax function

tanh function

  • The hyperbolic tangent: tanh(x)=\dfrac{\sinh(x)}{\cosh(x)}=\dfrac{e^x-e^{-x}}{e^x+e^{-x}}
  • A function whose derivative can be written in terms of the function itself (checked with autograd after the plot below):
    • \begin{align*} \tanh'(x)&=\left((e^x-e^{-x})(e^x+e^{-x})^{-1}\right)' \\ &=(e^x+e^{-x})(e^x+e^{-x})^{-1}-(e^x-e^{-x})(e^x+e^{-x})^{-2}(e^x-e^{-x})\\ &=1-\frac{(e^x-e^{-x})^{2}}{(e^x+e^{-x})^{2}}\\ &=1-\tanh^2(x) \end{align*}
    torch.nn.functional.tanh(input) → Tensor
  • Example code
%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 100)
# y = torch.nn.functional.tanh(input=x)
y = torch.tanh(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x121504518>]
Figure: tanh function
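A quick autograd check of the identity above (a sketch):

import torch

x = torch.linspace(-3, 3, 7, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()                                   # fills x.grad with dtanh/dx
print(torch.allclose(x.grad, 1 - y.detach()**2))     # True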

sigmoid function

  • Formula: f(x) = \dfrac{1}{1 + e ^ {-x}}

  • Example code

%matplotlib inline
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-10, 10, 100)
y = torch.sigmoid(input=x)
plt.plot(x, y, color=(1, 0, 0, 1))
[<matplotlib.lines.Line2D at 0x121568dd8>]
Figure: sigmoid function

Normalization functions

  • Five normalization functions are provided:
    • normalize: normalization
    • batch_norm: batch normalization
    • instance_norm: instance normalization
    • layer_norm: layer normalization
    • local_response_norm: local response normalization

normalize function

    torch.nn.functional.normalize(
        input,    # the data to process
        p=2,     # the norm used for normalization
        dim=1,   # the dimension to normalize over
        eps=1e-12,    # small constant (avoids a zero denominator during normalization)
        out=None)
  • The normalization formula:
    • v = \dfrac{v}{\max(\lVert v \rVert_p, \epsilon)}
import torch
x = torch.Tensor([1,2,3,4])

y = torch.nn.functional.normalize(x, p=2, dim=0)
print(y)
print(y.mean(), y.var())
print(x.norm())
print(x/x.norm())
tensor([0.1826, 0.3651, 0.5477, 0.7303])
tensor(0.4564) tensor(0.0556)
tensor(5.4772)
tensor([0.1826, 0.3651, 0.5477, 0.7303])
# computing the norm: the normalized vector lies at distance 1 from the origin (in the chosen norm).
import torch
x = torch.Tensor([1,2,3,4])
sum = 0
for x_ in x:
    sum += x_ * x_

print(sum)
print(sum.sqrt())
tensor(30.)
tensor(5.4772)

batch_norm function

  • Deep learning generally requires normalizing the data

    1. A deep neural network tries to learn the distribution of the training data and generalize well to the test set; if every input batch has a different distribution, training clearly becomes harder.
    2. As the data passes through layer after layer, its distribution keeps changing; this phenomenon is called Internal Covariate Shift, and batchnorm is meant to address this shifting distribution.
  • On normalizing every layer of a neural network:

    • Suppose every layer's output were simply normalized to zero mean and unit variance; then every layer would see a standard normal distribution and could not express the features it has learned, because the learned feature distribution would be normalized away. Directly normalizing every layer is therefore unreasonable.
    • With a small modification, adding trainable parameters to the normalization, we get what BatchNorm actually implements.
  • Batchnorm is itself also a form of regularization and can replace other regularizers such as dropout;

    torch.nn.functional.batch_norm(
        input,                      # the data to process
        running_mean,          # running mean (used at prediction time: test data are normalized with this mean and variance)
        running_var,              # running variance (used at prediction time)
        weight=None,            # scale factor
        bias=None,                # shift factor
        training=False, 
        momentum=0.1,          # momentum (used during training to update running_mean and running_var)
        eps=1e-05)                # small constant that avoids division by zero
import torch
x = torch.Tensor(
    [
        [1],
        [2],
        [3],
        [4]
    ]
)

m = torch.Tensor([0.0])
# m = x.mean(dim=0)
v = torch.Tensor([1.0])

w = torch.Tensor([1.0])     # w and b are learnable: normalization tends to wash out what has already been learned, and these parameters let the layer keep it.
w.requires_grad=True
b = torch.Tensor([0.0])
b.requires_grad=True

y = torch.nn.functional.batch_norm(
    input=x, 
    running_mean=m,      # the mean used for normalization
    running_var=v,           # the variance used for normalization
    weight=w,                 # the learnable scale
    bias=b,                     # the learnable shift
    momentum=0.1)
print(y)   
print(m,v)

# the momentum parameter is not used here because training=False
tensor([[1.0000],
        [2.0000],
        [3.0000],
        [4.0000]], grad_fn=<NativeBatchNormBackward>)
tensor([0.]) tensor([1.])
  • The usage above is still not the most typical one; the following is the most natural use in a real scenario.
    • Below, x has depth 3, so m and v must each have 3 elements.
import torch
x= torch.randn(4, 3, 5, 5)   # 4 images with 3 channels, 5x5 each

m = x.mean(dim=(0,2,3))   
v = x.var(dim=(0,2,3))
print(m,v)
w = torch.Tensor([1.0, 1.0, 1.0])     # w and b are learnable: normalization tends to wash out what has already been learned, and these parameters let the layer keep it.
w.requires_grad=True
b = torch.Tensor([0.0, 0.0, 0.0])
b.requires_grad=True

y = torch.nn.functional.batch_norm(
    input=x, 
    running_mean=m,      # the mean used for normalization
    running_var=v,           # the variance used for normalization
    weight=w,                 # the learnable scale
    bias=b,                     # the learnable shift
    momentum=0.1)
print(y)   
print(m,v)

# the momentum parameter is not used here because training=False
tensor([-0.1099, -0.1201,  0.0899]) tensor([1.1323, 1.0076, 1.1978])
tensor([[[[ 0.6039, -0.5582,  0.4959, -0.6427, -0.5870],
          [-0.3643,  0.5706, -0.5763,  1.4945, -1.4511],
          [ 0.1614,  0.5990,  1.9546, -0.3642,  1.4048],
          [ 0.7057, -0.9043,  0.1523,  0.9368, -0.2871],
          [ 0.8770, -0.9029,  0.2469, -0.6572, -0.2440]],

         [[ 0.7871, -2.0743,  0.7447,  1.3778,  0.0870],
          [ 0.3714,  0.2540, -0.3870, -0.9554, -0.9491],
          [ 2.3781, -0.0893, -0.0280,  1.3762,  0.4648],
          [-1.1302, -1.2012,  1.0125,  1.7782,  0.1390],
          [-0.0499, -0.1794,  0.4153, -0.4797, -1.5406]],

         [[-0.3106,  0.3042, -2.6128, -0.6955,  0.1389],
          [ 0.6331,  0.5983, -0.3249, -1.0146, -0.3954],
          [ 0.7133, -1.6992, -0.0674, -0.0610, -0.1267],
          [ 2.2547,  1.1713, -1.6066,  1.9786,  1.5711],
          [ 0.9024, -0.9147,  0.6853,  0.2226, -0.0866]]],


        [[[-0.1178,  0.2178, -0.4152, -0.2686,  1.8236],
          [-1.4309, -1.3974, -0.2394,  0.4349, -1.5596],
          [ 0.6755,  0.3908, -0.9385,  0.2810,  0.5441],
          [ 0.7135,  0.4403, -0.0172, -0.0238,  1.2088],
          [ 0.7274,  1.1456, -0.8432, -2.3800, -2.0214]],

         [[-1.1411, -2.1205, -0.2992,  0.0374,  0.3259],
          [-0.1580, -0.8872, -0.0905,  0.2920,  0.3672],
          [ 0.6656,  0.5359, -0.9617, -1.1993,  1.9091],
          [-0.9763, -1.2720,  1.6294, -0.8532,  0.8935],
          [ 1.4322,  0.5476,  2.3484,  0.3687,  0.4523]],

         [[-0.3341,  1.3806,  1.0831,  0.4029, -0.8603],
          [-0.0656,  0.0882,  1.1263, -0.3221, -1.2238],
          [ 0.9934, -0.7986, -0.4356,  0.1994,  0.5767],
          [ 0.7159,  0.7112, -2.4808, -0.3498,  0.1422],
          [ 0.1981, -0.5962, -0.4935, -0.9984,  2.4134]]],


        [[[ 0.1386, -0.4204,  0.0246, -1.2291,  0.9505],
          [ 0.8621, -0.9966,  0.7689,  1.3110, -0.4741],
          [ 0.6832,  0.3085,  1.0335, -0.7198,  0.9119],
          [-0.8902, -0.5828, -0.5656,  0.1060, -0.7015],
          [ 0.7235, -2.8109, -1.7042, -0.1866, -1.0658]],

         [[-0.9406, -0.8973,  0.8829, -1.2684, -0.2581],
          [ 2.1365,  0.0305,  0.4322, -0.1121,  0.8947],
          [ 0.0875,  0.0843, -1.1963,  0.5078, -0.8428],
          [-0.3173, -1.0045, -1.4108,  1.7931,  0.2055],
          [ 0.1410, -0.9936,  0.4887, -0.7456, -1.8184]],

         [[ 0.5326, -0.0776, -0.2944,  1.2970,  0.3175],
          [ 1.0699,  0.5990,  0.6703, -0.7066, -0.8497],
          [-0.4794,  0.3368, -0.3240, -0.0839, -0.6898],
          [-0.4844,  0.6520, -0.9095,  0.6653, -0.4540],
          [-0.5475,  1.8319, -1.6593, -1.9774, -0.3549]]],


        [[[-0.0032, -0.0484, -0.2990,  0.7075, -0.6719],
          [ 1.5305,  0.8488,  0.4503,  2.0581,  0.4679],
          [ 0.0354, -1.0883,  1.6254, -2.1237,  0.4148],
          [-2.1199, -0.0477, -0.9596,  0.8312, -0.3866],
          [ 1.6314,  1.9948, -0.2880,  0.4639, -0.1125]],

         [[ 0.2379,  0.2749, -1.1225, -0.4368,  1.3602],
          [ 1.4522,  0.2326, -0.6224, -1.1355,  0.3085],
          [ 0.7286,  0.2441, -1.4076,  0.5067, -0.7492],
          [ 0.7412, -0.1160,  1.2024,  0.9920, -1.0792],
          [-1.1923,  0.6291, -1.0011, -0.9160,  1.0205]],

         [[ 0.2387, -0.9847,  0.7370,  1.5279, -0.9204],
          [ 0.8992,  1.4813,  0.6839, -0.2267, -1.3877],
          [ 0.2700,  0.4006,  0.8214, -0.4204, -0.7890],
          [-0.7622,  1.1744,  0.9095, -0.7127, -1.2161],
          [-0.9994, -2.5458,  1.1390, -0.1247,  0.3969]]]],
       grad_fn=<NativeBatchNormBackward>)
tensor([-0.1099, -0.1201,  0.0899]) tensor([1.1323, 1.0076, 1.1978])
  • The following snippet nicely illustrates what batch_norm means

    def Batchnorm_simple_for_train(x, gamma, beta, bn_param):
        """
        param x        : input data, with shape (B, L)
        param gamma    : scale factor  γ
        param beta     : shift factor  β
        param bn_param : parameters needed by batchnorm
            eps          : a small number close to 0 that keeps the denominator away from 0
            momentum     : momentum, typically 0.9, 0.99 or 0.999
            running_mean : running (moving-average) mean, updated during training for use on test data
            running_var  : running (moving-average) variance, updated during training for use on test data
        """
        running_mean = bn_param['running_mean']  # shape = [L]
        running_var = bn_param['running_var']    # shape = [L]
        eps = bn_param['eps']
        momentum = bn_param['momentum']

        x_mean = x.mean(axis=0)                              # mean of x over the batch
        x_var = x.var(axis=0)                                # variance over the batch
        x_normalized = (x - x_mean) / np.sqrt(x_var + eps)   # normalize
        results = gamma * x_normalized + beta                # scale and shift

        running_mean = momentum * running_mean + (1 - momentum) * x_mean
        running_var = momentum * running_var + (1 - momentum) * x_var

        # store the updated running statistics
        bn_param['running_mean'] = running_mean
        bn_param['running_var'] = running_var

        return results, bn_param


instance_norm function

  • A function similar to batch_norm and layer_norm; the difference lies in:

    • how the mean is computed, i.e. which dimensions are averaged over
  • Writing the input image shape as [N, C, H, W], the main differences between these methods are:

    • Batch Norm normalizes over N, H, W, i.e. each single channel is normalized across the whole batch; this works poorly for small batch sizes;
    • Layer Norm normalizes over C, H, W, i.e. each sample is normalized over all of its channels; it is mainly effective for RNNs;
    • Instance Norm normalizes over H, W, i.e. each channel of each image is normalized over its own height and width; it is used in style transfer;
    • Group Norm splits the channels into groups, somewhat like LN except that the channels are further partitioned, and then normalizes within each group;
    • Switchable Norm combines BN, LN and IN with learned weights, letting the network decide which normalization to use.
  • instance_norm function description:
    torch.nn.functional.instance_norm(
        input, 
        running_mean=None,     # may be None
        running_var=None, 
        weight=None, 
        bias=None, 
        use_input_stats=True, momentum=0.1, eps=1e-05)
import torch
x= torch.randn(4, 3, 5, 5)   # 4 images with 3 channels, 5x5 each

m = x.mean(dim=(0,2,3), keepdim=False)   # dims: N(0), C(1), H(2), W(3)
v = x.var(dim=(0,2,3), keepdim=False)
print(m.shape)
w = torch.ones(3)     # w and b are learnable: normalization tends to wash out what has already been learned, and these parameters let the layer keep it.
w.requires_grad=True
b = torch.zeros(3)
b.requires_grad=True

y = torch.nn.functional.instance_norm(
    input=x, 
    running_mean=m,    
    running_var=v,
    weight=w, bias=b
)
print(y)   
# compare the batch_norm result with the instance_norm result
y = torch.nn.functional.batch_norm(
    input=x, 
    running_mean=m,    
    running_var=v,
    weight=w, bias=b
)
print(y)   
torch.Size([3])
tensor([[[[-1.3300e+00, -1.0276e+00,  3.0897e-01,  9.5767e-02,  3.2338e-01],
          [-4.1176e-01,  7.4911e-01, -6.9988e-01,  2.9321e-01,  7.2987e-01],
          [ 1.3905e+00,  5.4397e-01,  1.2827e-02,  3.2809e-01,  7.4175e-01],
          [ 3.9900e-01, -2.5250e+00,  6.4288e-01,  1.8031e+00, -6.0249e-02],
          [ 1.0510e+00, -5.5906e-01,  2.5902e-01, -2.1903e+00, -8.6867e-01]],

         [[ 4.4713e-01,  8.9750e-01, -9.6316e-01,  3.1501e-01,  1.5213e+00],
          [-3.3827e-01, -7.3737e-01, -9.3462e-01,  1.7493e+00,  1.0172e+00],
          [ 4.7460e-01, -2.6951e-01, -1.1510e+00, -7.1324e-01, -1.1702e+00],
          [ 1.0441e+00, -1.7336e+00,  1.2532e-01, -7.4906e-01, -8.4268e-01],
          [ 1.6433e-01,  1.9831e+00, -6.9415e-01, -6.0049e-01,  1.1584e+00]],

         [[-2.3036e+00,  7.9183e-01, -2.5915e-01,  1.3214e+00, -1.5231e+00],
          [ 4.1672e-02,  5.2620e-01,  3.9970e-01,  4.4207e-01, -2.3129e-01],
          [-4.2022e-01, -8.8784e-01,  8.0396e-01,  1.7451e+00, -6.7207e-02],
          [ 9.6801e-01,  8.7536e-01,  1.4504e+00, -4.8293e-01, -1.4823e+00],
          [ 3.8419e-01,  3.7225e-01, -3.8346e-01, -4.2881e-01, -1.6523e+00]]],


        [[[-5.7823e-01, -3.3143e-01,  7.1775e-01,  2.1239e-01,  3.6560e-01],
          [-1.7156e-01, -9.9428e-01,  1.9243e-01,  8.9780e-01, -1.5584e+00],
          [-1.8803e+00, -1.2125e-02, -7.6846e-01, -3.1888e-01, -4.6432e-01],
          [ 7.6454e-01,  2.7620e-02, -1.4467e+00,  6.5572e-01, -9.6305e-01],
          [-2.7112e-01,  1.7859e+00,  1.9314e+00,  2.0155e+00,  1.9219e-01]],

         [[-1.7664e+00,  1.5730e+00, -8.9284e-01,  1.1217e+00, -1.0072e+00],
          [ 8.7060e-01, -4.5303e-01,  1.1819e+00, -3.9239e-01, -7.1941e-02],
          [-1.8391e+00, -1.2180e+00, -9.6703e-01,  1.1305e+00,  3.5450e-01],
          [ 6.3686e-01,  1.0018e+00, -5.7442e-01, -8.5710e-01, -9.9438e-01],
          [-1.9725e-01,  9.4404e-01,  9.4887e-01,  1.0814e+00,  3.8601e-01]],

         [[ 4.1552e-01, -6.7624e-01,  1.1387e+00,  2.0360e+00,  2.3363e-01],
          [ 1.0976e+00, -1.5366e-01,  4.5983e-02, -4.4905e-01,  3.7788e-01],
          [-3.0292e-01, -5.5900e-01, -2.4216e+00,  3.2742e-01, -1.7846e+00],
          [ 1.0120e+00, -7.8423e-01, -6.4370e-01, -1.2993e+00,  3.6975e-01],
          [ 9.7618e-01, -6.3158e-01, -3.9939e-01,  1.3506e+00,  7.2390e-01]]],


        [[[ 1.0768e+00,  8.2844e-02, -2.2382e-01,  2.1644e+00, -1.5242e+00],
          [-1.7984e-01, -5.6175e-01,  2.4898e-01, -5.0420e-01, -8.5796e-01],
          [ 1.5355e+00,  8.2799e-01,  6.5538e-01, -1.5339e+00, -9.7581e-01],
          [ 6.1572e-02, -6.4004e-01,  4.7675e-01, -1.4959e+00, -1.0269e+00],
          [ 1.0670e+00,  1.4524e+00, -6.9493e-01, -3.3628e-01,  9.0577e-01]],

         [[-9.0222e-01,  9.3811e-01, -1.0737e+00, -5.2899e-01, -1.5509e+00],
          [ 5.9876e-01, -5.6182e-01, -1.3897e+00, -1.9862e+00, -5.7414e-01],
          [ 1.9665e+00,  1.4921e+00, -6.4958e-01,  1.1793e-01,  3.5676e-01],
          [ 2.9828e-01,  4.0369e-01,  1.0687e+00,  1.9175e-01,  7.0528e-02],
          [ 9.6439e-01, -7.0567e-01,  1.6306e+00, -5.3084e-01,  3.5556e-01]],

         [[-7.4705e-01,  2.0131e-01,  9.2345e-01, -2.1607e+00, -9.0636e-01],
          [ 1.9722e-01,  1.4994e+00, -2.2590e-01,  4.1395e-01, -7.3280e-02],
          [-5.7261e-01,  9.1539e-01, -8.1471e-01,  3.0324e-01,  1.4799e+00],
          [-9.1599e-01, -1.5451e+00,  9.5537e-01,  1.8283e+00, -7.3733e-01],
          [-1.1077e+00, -8.3597e-01,  7.4975e-01,  2.5025e-01,  9.2512e-01]]],


        [[[ 4.3015e-02, -1.1650e+00,  1.7200e-02, -1.4198e+00,  1.6581e+00],
          [ 1.8173e-01, -8.0867e-02,  3.9920e-01, -4.2110e-01,  1.6406e+00],
          [ 7.3747e-01,  9.6860e-01, -1.3032e+00, -1.1416e+00,  7.8118e-01],
          [ 2.5386e+00,  1.9666e-01,  4.3760e-01,  5.3153e-02, -4.3334e-01],
          [-1.3547e+00, -1.1957e+00, -6.1150e-02, -5.3436e-01, -5.4225e-01]],

         [[ 1.6825e+00, -5.6726e-01,  7.8480e-01,  8.8852e-02, -9.9347e-01],
          [-1.0618e+00, -1.7799e+00,  1.7843e-01,  5.7744e-01,  7.4290e-01],
          [ 6.0726e-01, -1.3402e+00,  2.0730e-01, -2.3312e+00, -2.5911e-02],
          [-3.1424e-01, -1.2857e-03,  2.2741e-01,  1.0183e+00,  3.0547e-01],
          [-8.4714e-01,  1.1635e+00, -5.9346e-01,  1.8763e+00,  3.9539e-01]],

         [[ 9.8080e-01,  1.2992e-01, -9.5844e-01, -1.7077e+00, -5.6559e-01],
          [ 7.4964e-01,  4.8532e-01, -1.3676e-01, -3.8980e-01, -1.4458e-01],
          [-5.6245e-02,  1.3231e+00, -1.0872e+00,  2.4868e+00,  1.4366e+00],
          [ 2.5883e-01, -7.0243e-01, -8.8365e-02,  2.2908e-01, -2.3850e+00],
          [-9.2475e-02, -7.5162e-01,  6.7078e-01,  5.5210e-01, -2.3674e-01]]]],
       grad_fn=<ViewBackward>)
tensor([[[[-1.5079, -1.2248,  0.0263, -0.1733,  0.0398],
          [-0.6484,  0.4383, -0.9181,  0.0115,  0.4203],
          [ 1.0386,  0.2462, -0.2509,  0.0442,  0.4314],
          [ 0.1105, -2.6264,  0.3388,  1.4249, -0.3193],
          [ 0.7209, -0.7862, -0.0205, -2.3132, -1.0761]],

         [[ 0.6112,  1.0795, -0.8553,  0.4738,  1.7282],
          [-0.2055, -0.6205, -0.8256,  1.9652,  1.2040],
          [ 0.6398, -0.1340, -1.0506, -0.5954, -1.0706],
          [ 1.2319, -1.6564,  0.2766, -0.6327, -0.7300],
          [ 0.3171,  2.2084, -0.5756, -0.4782,  1.3508]],

         [[-2.4660,  1.0380, -0.1517,  1.6375, -1.5826],
          [ 0.1888,  0.7373,  0.5941,  0.6421, -0.1202],
          [-0.3341, -0.8634,  1.0517,  2.1171,  0.0656],
          [ 1.2374,  1.1325,  1.7835, -0.4051, -1.5364],
          [ 0.5765,  0.5630, -0.2924, -0.3438, -1.7287]]],


        [[[-0.5347, -0.3214,  0.5852,  0.1485,  0.2809],
          [-0.1833, -0.8942,  0.1313,  0.7408, -1.3817],
          [-1.6598, -0.0455, -0.6990, -0.3106, -0.4362],
          [ 0.6257, -0.0111, -1.2851,  0.5316, -0.8672],
          [-0.2693,  1.5082,  1.6340,  1.7067,  0.1311]],

         [[-2.0110,  1.5980, -1.0669,  1.1103, -1.1904],
          [ 0.8389, -0.5915,  1.1754, -0.5260, -0.1797],
          [-2.0895, -1.4182, -1.1470,  1.1198,  0.2812],
          [ 0.5863,  0.9807, -0.7227, -1.0282, -1.1766],
          [-0.3151,  0.9183,  0.9235,  1.0667,  0.3152]],

         [[ 0.3790, -0.6094,  1.0336,  1.8460,  0.2143],
          [ 0.9964, -0.1363,  0.0444, -0.4037,  0.3449],
          [-0.2714, -0.5032, -2.1894,  0.2992, -1.6127],
          [ 0.9190, -0.7071, -0.5799, -1.1733,  0.3375],
          [ 0.8865, -0.5689, -0.3587,  1.2254,  0.6581]]],


        [[[ 1.2406,  0.1963, -0.1259,  2.3834, -1.4923],
          [-0.0797, -0.4810,  0.3708, -0.4205, -0.7922],
          [ 1.7225,  0.9792,  0.7978, -1.5024, -0.9161],
          [ 0.1739, -0.5633,  0.6102, -1.4625, -0.9697],
          [ 1.2304,  1.6353, -0.6209, -0.2441,  1.0609]],

         [[-0.8363,  0.5258, -0.9632, -0.5601, -1.3164],
          [ 0.2746, -0.5844, -1.1971, -1.6386, -0.5935],
          [ 1.2869,  0.9359, -0.6493, -0.0813,  0.0955],
          [ 0.0522,  0.1302,  0.6225, -0.0266, -0.1163],
          [ 0.5452, -0.6908,  1.0383, -0.5614,  0.0946]],

         [[-0.6772,  0.1907,  0.8516, -1.9708, -0.8229],
          [ 0.1870,  1.3787, -0.2002,  0.3853, -0.0606],
          [-0.5175,  0.8442, -0.7391,  0.2840,  1.3608],
          [-0.8318, -1.4075,  0.8808,  1.6797, -0.6683],
          [-1.0072, -0.7585,  0.6926,  0.2355,  0.8531]]],


        [[[ 0.2341, -1.0415,  0.2069, -1.3105,  1.9396],
          [ 0.3806,  0.1033,  0.6102, -0.2559,  1.9211],
          [ 0.9674,  1.2115, -1.1874, -1.0168,  1.0136],
          [ 2.8693,  0.3964,  0.6508,  0.2448, -0.2689],
          [-1.2418, -1.0739,  0.1241, -0.3755, -0.3839]],

         [[ 1.8749, -0.4660,  0.9409,  0.2167, -0.9095],
          [-0.9806, -1.7278,  0.3099,  0.7251,  0.8973],
          [ 0.7561, -1.2703,  0.3399, -2.3015,  0.0973],
          [-0.2027,  0.1229,  0.3609,  1.1838,  0.4421],
          [-0.7572,  1.3349, -0.4933,  2.0766,  0.5357]],

         [[ 0.8154, -0.0229, -1.0953, -1.8335, -0.7082],
          [ 0.5876,  0.3272, -0.2857, -0.5350, -0.2934],
          [-0.2064,  1.1527, -1.2222,  2.2992,  1.2645],
          [ 0.1041, -0.8430, -0.2380,  0.0748, -2.5008],
          [-0.2421, -0.8915,  0.5100,  0.3930, -0.3842]]]],
       grad_fn=<NativeBatchNormBackward>)

layer_norm function

  • Normalizes over the last few dimensions of the input, as specified by normalized_shape
    torch.nn.functional.layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05)
import torch
x= torch.randn(4, 3, 5, 5)   # 4 images with 3 channels, 5x5 each

w = torch.ones(3, 5, 5)     # normalize over the last three dimensions
w.requires_grad=True
b = torch.zeros(3, 5,5)
b.requires_grad=True

y = torch.nn.functional.layer_norm(
    input=x, 
    normalized_shape=(3, 5, 5),
    weight=w, bias=b
)
print(y[0,0,:,:]) 
tensor([[ 0.0917,  0.2445,  0.1294,  0.0352,  0.5847],
        [ 1.7303, -1.4383, -1.1723,  0.3757,  0.1029],
        [ 0.1677, -0.1781,  1.3524, -0.0983, -0.3715],
        [ 1.7907, -0.6600,  2.0479, -0.0905, -0.1329],
        [ 0.6516,  0.1149,  0.6535, -0.3212,  0.4425]],
       grad_fn=<SliceBackward>)

Distance functions

  • Torch provides three ways to compute distances:

    • the Euclidean distance (2-norm) between two vectors: pairwise_distance
    • the similarity distance (the cosine of the angle between two vectors): cosine_similarity
      • this is really a correlation measure between two matrices, row by row (orthogonal vectors are uncorrelated: 90° angle, cosine 0; linearly dependent vectors: 0° angle, cosine 1)
    • the Euclidean distance between every pair of rows of a matrix: pdist
  • Note:

    • mathematically there are many ways to define a distance, depending on the application.
  • The examples below illustrate them:

pairwise_distance function

    torch.nn.functional.pairwise_distance(x1, x2, p=2.0, eps=1e-06, keepdim=False)
import torch
# distance between two points
t1 = torch.Tensor([[0, 1]])
t2 = torch.Tensor([[1, 0]])

print(torch.nn.functional.pairwise_distance(t1, t2))
tensor([1.4142])
import torch
# distance between corresponding rows: the two inputs need the same number of rows (1-to-1), otherwise one of them must have a single row that is broadcast (1-to-many).
t1 = torch.Tensor(
    [
        [0, 1],
        [0, 0]
    ]
)
t2 = torch.Tensor(
    [
        [1, 0],
        [0,0]
    ]
)

print(torch.nn.functional.pairwise_distance(t1, t2))
tensor([1.4142e+00, 1.4142e-06])

cosine_similarity function

    torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) → Tensor
t1 = torch.Tensor(
    [
        [0, 1],
        [0, 1]
    ]
)
t2 = torch.Tensor(
    [
        [1, 0],
        [0,1]
    ]
)

print(torch.nn.functional.cosine_similarity(t1, t2, dim=1))   # similarity computed row by row
tensor([0., 1.])

pdist function

    torch.nn.functional.pdist(input, p=2) → Tensor
t1 = torch.Tensor(
    [
        [0, 1],
        [1, 1],
        [0.5, 0.5],
        [0,0]
    ]
)

print(torch.nn.functional.pdist(t1))   # computed for every pair of rows
tensor([1.0000, 0.7071, 1.0000, 0.7071, 1.4142, 0.7071])
