1. ホーム
  2. ピトーチ

pytorchのConv1dの詳細説明

2022-02-28 08:07:58

ブロガーは転載を歓迎しますが、必ず出典を明記した上でオリジナルのリンクを貼ってください ありがとうございます。

pytorchのnn.Conv1dについて説明します。

(最初にまとめたときはこんなに多くの人に読んでもらえるとは思っていなかったし、質問も結構ありました。いつもCSDNにログインしていないため、コメントへの返信がタイムリーにできないことがあります。Zhihuをお使いの方は、プライベートメッセージで下記までご連絡ください。 Sunny.Xiaの個人ページ メッセージを見たら必ず返信します。メッセージを見たら、必ず間に合うように返信します。何ヶ月も返信なしでコメントしてる人には申し訳ないけど、ナイフで切らないでね。(本当に頻繁にログインしていません。記事を書いている時だけです)

テキスト分類のpytorchを勉強しているときに、1次元畳み込みを使い、その仕組みを理解するのに時間がかかり、Web上で詳しく説明しているブログがないので、記録しておくことにします。

Conv1d

class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True).

  • in_channels( int ) - 入力信号のチャンネル.テキスト分類では、これは単語ベクトルの次元です。
  • out_channels( int ) - 畳み込みによって生成されるチャンネル.1次元の畳み込みと同じ数の out_channels が存在します.
  • カーネルサイズ( int  または  tuple ) - 畳み込みカーネルのサイズ.畳み込みカーネルのサイズは (k,) で,2番目の次元は in_channels によって決定されるので,畳み込みサイズは実際には kernel_size*in_channels となります.
  • ストライド( int  または  tuple optional ) - コンボリューションステップ
  • パディング ( int  または  tuple optional ) - 0で補完された入力の各辺の層数
  • ダイレーション( int  または  tuple オプション``) - 畳み込みカーネルの要素間の間隔.
  • グループ( int optional ) - 入力チャンネルから出力チャンネルへのブロッキング接続の数.
  • バイアス( bool optional ) - もし bias=True を追加し、バイアスをかける。

一例です。

conv1 = nn.Conv1d(in_channels=256, out_channels=100,kernel_size=2)
input = torch.randn(32,35,256)
# batch_size x text_len x embedding_size -> batch_size x embedding_size x text_len
input = input.permute(0,2,1)
out = conv1(input)
print(out.size())


ここで32はbatch_size、35は最大文長、256は単語ベクトルである

1次元の畳み込みを再度入力する場合、32*35*256を32*256*35に変換する必要があるが、これは1次元の畳み込みが最後の次元で掃き出されるためで、最終的に出てくる大きさは 32*100*(35-2+1)=32*100*34

1次元畳み込みがどのように使用されるかを視覚化した図を添付します:。

画像の出典はこちら Kerasのテキスト分類の実装

図の入力単語ベクトルは5次元で、入力サイズは7*5、1次元の畳み込みカーネルはサイズが2、3、4でそれぞれ2つ、合計6つの特徴量を持つ。

k=4の場合、図中の赤の大きな行列を参照、コンボリューションカーネルサイズは4*5、ステップサイズは1。ここでは入力に対して上から下へスイープし、出力ベクトルのサイズは((7-4)/1+1)*1=4*1、最終的にコンボリューションカーネルサイズ4から1値でmax_poolingを経ることになります。最終的に6つの値が得られ、スプライシングされ、完全連結層を経て、2つのカテゴリの確率が出力される。

詳しく説明するためのコードを添付します。

ここで、embedding_size=256, feature_size=100, window_sizes=[3,4,5,6], max_text_len=35

class TextCNN(nn.Module):
    def __init__(self, config):
        super(TextCNN, self). __init__()
        self.is_training = True
        self.dropout_rate = config.dropout_rate
        self.num_class = config.num_class
        self.use_element = config.use_element
        self.config = config

        self.embedding = nn.Embedding(num_embeddings=config.vocab_size, 
                                embedding_dim=config.embedding_size)
        self.convs = nn.ModuleList([
                nn.Sequential(nn.Conv1d(in_channels=config.embedding_size, 
                                        out_channels=config.feature_size, 
                                        kernel_size=h),
# nn.BatchNorm1d(num_features=config.feature_size), 
                              nn.ReLU(),
                              nn.MaxPool1d(kernel_size=config.max_text_len-h+1))
                     for h in config.window_sizes
                    ])
        self.fc = nn.Linear(in_features=config.feature_size*len(config.window_sizes),
                            out_features=config.num_class)
        if os.path.exists(config.embedding_path) and config.is_training and config.is_pretrain:
            print("Loading pretrain embedding... ")
            self.embedding.weight.data.copy_(torch.from_numpy(np.load(config.embedding_path)))    
    
    def forward(self, x):
        embed_x = self.embedding(x)
        
        # print('embed size 1',embed_x.size()) # 32*35*256
# batch_size x text_len x embedding_size -> batch_size x embedding_size x text_len
        embed_x = embed_x.permute(0, 2, 1)
        # print('embed size 2',embed_x.size()) # 32*256*35
        out = [conv(embed_x) for conv in self.convs] #out[i]:batch_size x feature_size*1
        #for o in out:
        # print('o',o.size()) # 32*100*1
        out = torch.cat(out, dim=1) # corresponding to the second dimension (row) spliced together, e.g. 5*2*1,5*3*1 spliced into 5*5*1
        # print(out.size(1)) # 32*400*1
        out = out.view(-1, out.size(1)) 
        # print(out.size()) # 32*400 
        if not self.use_element:
            out = F.dropout(input=out, p=self.dropout_rate)
            out = self.fc(out)
        return out


embed_xは32*35*256で始まり、batch_sizeは32。permute後は32*256*35になり、カスタムネットワークに入力後、各要素は32*100*1、合計4要素になる。dim=1次元でstitch後、32*400*1、view後、32*400、最後に400*num_classサイズの完全連結行列を経て、32*2になる。