VAE detailed structure: convolution, ResNet, sample, Attention, mean, std
A mind map of the detailed VAE (AutoencoderKL) structure, covering the convolutions, ResNet blocks, sampling, Attention, and mean/std. It assumes an input image of shape (1, 3, 1024, 1024) and annotates the tensor dimensions at each step.
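The module dump below can be reproduced from diffusers. A minimal sketch, assuming any AutoencoderKL checkpoint with latent_channels=16 (for example the SD3/Flux VAE; the repository name here is illustrative):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint: any AutoencoderKL with latent_channels=16
# produces the module structure and shapes shown in this map.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
print(vae)  # prints the Encoder / Decoder modules listed below

x = torch.randn(1, 3, 1024, 1024)               # the assumed input size
with torch.no_grad():
    posterior = vae.encode(x).latent_dist       # DiagonalGaussianDistribution
    z = posterior.sample()                      # torch.Size([1, 16, 128, 128])
    recon = vae.decode(z).sample                # torch.Size([1, 3, 1024, 1024])
print(posterior.mean.shape, z.shape, recon.shape)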
AutoencoderKL(
(encoder): Encoder(
(conv_in):
Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(down_blocks): ModuleList(
(0): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) # The 128 input channels are split into 32 groups of 128 / 32 = 4 channels each (see the check after the encoder dump). (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(downsamplers): ModuleList(
(0): Downsample2D( (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
(1): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
(1): ResnetBlock2D( (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(downsamplers): ModuleList(
(0): Downsample2D( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
(2): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(downsamplers): ModuleList(
(0): Downsample2D( (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))
(3): DownEncoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): Attention( (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True) (to_q): Linear(in_features=512, out_features=512, bias=True) (to_k): Linear(in_features=512, out_features=512, bias=True) (to_v): Linear(in_features=512, out_features=512, bias=True) (to_out): ModuleList( (0): Linear(in_features=512, out_features=512, bias=True) (1): Dropout(p=0.0, inplace=False)
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # Each of the 32 output kernels spans all 512 input channels with a 3x3 spatial extent, so each kernel holds 3 * 3 * 512 = 4608 weights (weight tensor shape: 32 x 512 x 3 x 3).
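A quick check of the two inline comments above (the GroupNorm grouping and the conv_out kernel size), as a standalone sketch using plain torch.nn modules rather than the diffusers classes:

import torch
import torch.nn as nn

# GroupNorm(32, 128): the 128 channels are split into 32 groups of 4;
# normalization statistics are computed per group and per sample.
gn = nn.GroupNorm(32, 128, eps=1e-6, affine=True)
h = gn(torch.randn(1, 128, 8, 8))    # small spatial size; the grouping is unaffected

# Encoder conv_out: 32 kernels, each spanning all 512 input channels with a
# 3x3 spatial extent, i.e. 3 * 3 * 512 = 4608 weights per kernel (plus 1 bias).
conv_out = nn.Conv2d(512, 32, kernel_size=3, padding=1)
print(conv_out.weight.shape)         # torch.Size([32, 512, 3, 3])
print(conv_out.weight[0].numel())    # 4608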
DiagonalGaussianDistribution
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)  # torch.Size([1, 16, 128, 128])
self.var = torch.exp(self.logvar)  # torch.Size([1, 16, 128, 128])
self.mean  # torch.Size([1, 16, 128, 128])
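Why mean and std have 16 channels rather than 32: the encoder's conv_out emits 32 moment channels and torch.chunk halves them along dim=1. A standalone check with plain torch (no diffusers needed):

import torch

parameters = torch.randn(1, 32, 128, 128)         # stand-in for the encoder conv_out output
mean, logvar = torch.chunk(parameters, 2, dim=1)  # 32 channels -> 16 + 16
logvar = torch.clamp(logvar, -30.0, 20.0)         # numerical-stability clamp
std = torch.exp(0.5 * logvar)
var = torch.exp(logvar)
print(mean.shape, std.shape, var.shape)           # torch.Size([1, 16, 128, 128]) three times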
(decoder): Decoder(
(conv_in): Conv2d(16, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(up_blocks): ModuleList(
(0): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(2): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(upsamplers): ModuleList(
(0): Upsample2D( (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(2): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(upsamplers): ModuleList(
(0): Upsample2D( (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(1): ResnetBlock2D( (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(2): ResnetBlock2D( (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(upsamplers): ModuleList(
(0): Upsample2D( (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): UpDecoderBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU() (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(1): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(2): ResnetBlock2D( (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): Attention( (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True) (to_q): Linear(in_features=512, out_features=512, bias=True) (to_k): Linear(in_features=512, out_features=512, bias=True) (to_v): Linear(in_features=512, out_features=512, bias=True) (to_out): ModuleList( (0): Linear(in_features=512, out_features=512, bias=True) (1): Dropout(p=0.0, inplace=False)
(resnets): ModuleList(
(0): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(1): ResnetBlock2D( (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) (dropout): Dropout(p=0.0, inplace=False) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (nonlinearity): SiLU()
(conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    sample = randn_tensor(
        self.mean.shape,
        generator=generator,
        device=self.parameters.device,
        dtype=self.parameters.dtype,
    )
    x = self.mean + self.std * sample
    return x
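sample() implements the reparameterization trick: z = mean + std * eps with eps ~ N(0, I), so the sample is differentiable with respect to mean and std. A self-contained sketch of the same step, with torch.randn standing in for diffusers' randn_tensor helper and placeholder statistics:

import torch

mean = torch.zeros(1, 16, 128, 128)               # placeholder posterior mean
std = torch.ones(1, 16, 128, 128)                 # placeholder posterior std
generator = torch.Generator().manual_seed(0)      # optional, for reproducibility
eps = torch.randn(mean.shape, generator=generator)
z = mean + std * eps                              # torch.Size([1, 16, 128, 128]), fed to the decoder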
Encoder dimension changes for the (1, 3, 1024, 1024) input:
input: torch.Size([1, 3, 1024, 1024])
conv_in: torch.Size([1, 128, 1024, 1024])
down_blocks[0]: torch.Size([1, 128, 512, 512])
down_blocks[1]: torch.Size([1, 256, 256, 256])
down_blocks[2]: torch.Size([1, 512, 128, 128])
down_blocks[3]: torch.Size([1, 512, 128, 128])
mid_block: torch.Size([1, 512, 128, 128])
conv_out (moments): torch.Size([1, 32, 128, 128])
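The per-stage encoder shapes above can be reproduced with forward hooks. A sketch, assuming the same illustrative checkpoint as in the first snippet:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
enc = vae.encoder

def log_shape(name):
    def hook(module, args, output):
        print(f"{name:>15s} -> {tuple(output.shape)}")
    return hook

# Hook the top-level encoder stages; the hooks fire in execution order.
stages = {"conv_in": enc.conv_in, "mid_block": enc.mid_block, "conv_out": enc.conv_out}
stages.update({f"down_blocks[{i}]": block for i, block in enumerate(enc.down_blocks)})
for name, module in stages.items():
    module.register_forward_hook(log_shape(name))

with torch.no_grad():
    enc(torch.randn(1, 3, 1024, 1024))            # prints the shapes listed above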