223 changes: 222 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -15,11 +15,12 @@

import torch
import torch.nn as nn
import torch.distributed as dist

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils import deprecate, logging
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -35,6 +36,9 @@
from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -127,6 +131,7 @@ def __init__(

self.use_slicing = False
self.use_tiling = False
self.use_dp = False

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
@@ -214,9 +219,58 @@ def set_default_attn_processor(self):

self.set_attn_processor(processor)

def enable_dp(
self,
world_size: Optional[int] = None,
hw_splits: Optional[Tuple[int, int]] = None,
overlap_ratio: Optional[float] = None,
overlap_pixels: Optional[int] = None,
) -> None:
r"""
Enable data-parallel tiled VAE: the sample is split into `hw_splits[0] x hw_splits[1]` overlapping tiles that are assigned round-robin to `world_size` ranks.

Args:
world_size (`int`, *optional*): Number of ranks to parallelize over. Defaults to the initialized
process-group size (or 1).
hw_splits (`Tuple[int, int]`, *optional*): Number of tile splits along height and width. Defaults to
`(1, world_size)`.
overlap_ratio (`float`, *optional*): Tile overlap as a fraction of the latent height/width.
overlap_pixels (`int`, *optional*): Tile overlap in pixel space; takes precedence over `overlap_ratio`.
"""
if world_size is None:
world_size = dist.get_world_size() if dist.is_initialized() else 1

if not dist.is_initialized() or world_size <= 1 or world_size > dist.get_world_size():
max_world_size = dist.get_world_size() if dist.is_initialized() else 1
logger.warning(
f"Supported world_size for VAE data parallelism is between 2 and {max_world_size}, but got {world_size}. "
f"Falling back to the non-parallel VAE."
)
return

if hw_splits is None:
hw_splits = (1, int(world_size))

assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"

h_split, w_split = map(int, hw_splits)

self.use_dp = True
self.h_split, self.w_split = h_split, w_split
self.world_size = world_size
self.overlap_ratio = overlap_ratio
self.overlap_pixels = overlap_pixels
self.spatial_compression_ratio = 2 ** (len(self.config.block_out_channels) - 1)

dp_ranks = list(range(0, world_size))
self.vae_dp_group = dist.new_group(ranks=dp_ranks)
self.rank = dist.get_rank()
# Assign tiles to ranks in round-robin order.
self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
self.num_tiles_per_rank = [0] * self.world_size
tile_idx = 0
for h_idx in range(self.h_split):
for w_idx in range(self.w_split):
rank_idx = tile_idx % self.world_size
self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
self.num_tiles_per_rank[rank_idx] += 1
tile_idx += 1
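
For intuition, here is a standalone reproduction of the round-robin tile-to-rank assignment computed above (illustrative only; it re-implements the loop outside the class with assumed example values h_split=2, w_split=2, world_size=3):

h_split, w_split, world_size = 2, 2, 3
tile_idxs_per_rank = [[] for _ in range(world_size)]
tile_idx = 0
for h_idx in range(h_split):
    for w_idx in range(w_split):
        tile_idxs_per_rank[tile_idx % world_size].append((h_idx, w_idx))
        tile_idx += 1
print(tile_idxs_per_rank)
# [[(0, 0), (1, 1)], [(0, 1)], [(1, 0)]]  -> rank 0 gets two tiles, ranks 1 and 2 get one each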

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape

if self.use_dp:
return self._tiled_encode_with_dp(x)
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)

@@ -256,6 +310,8 @@ def encode(
return AutoencoderKLOutput(latent_dist=posterior)

def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_dp:
return self.tiled_decode_with_dp(z, return_dict=return_dict)
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)

@@ -310,6 +366,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
# Vectorized counterpart of `blend_v`: linearly blend the bottom `blend_extent` rows of `a`
# into the top `blend_extent` rows of `b` (modifies `b` in place).
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
y = torch.arange(0, blend_extent, device=a.device)
blend_ratio = (y / blend_extent)[None, None, :, None].to(a.dtype)
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, y, :] * blend_ratio
return b

def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
# Vectorized counterpart of `blend_h`: linearly blend the right `blend_extent` columns of `a`
# into the left `blend_extent` columns of `b` (modifies `b` in place).
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
x = torch.arange(0, blend_extent, device=a.device)
blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype)
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio
return b
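
A quick self-contained check (illustrative, not part of the PR) that the vectorized helper matches the loop-based `blend_h` already in this file:

import torch

def blend_h_loop(a, b, blend_extent):
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

def blend_h_vec(a, b, blend_extent):
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    x = torch.arange(0, blend_extent, device=a.device)
    blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype)
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio
    return b

a, b = torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16)
assert torch.allclose(blend_h_loop(a.clone(), b.clone(), 6), blend_h_vec(a.clone(), b.clone(), 6), atol=1e-6)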

def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.

@@ -469,6 +539,157 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:

return DecoderOutput(sample=dec)

def calculate_tiled_parallel_size(self, latent_height, latent_width):
r"""Compute tile, stride, and blend sizes (in latent and sample space) for data-parallel tiling."""
# Calculate stride based on h_split and w_split (ceil division so the tiles cover the whole input)
tile_latent_stride_height = (latent_height + self.h_split - 1) // self.h_split
tile_latent_stride_width = (latent_width + self.w_split - 1) // self.w_split

# Calculate overlap in latent space
overlap_latent_height = 3
overlap_latent_width = 3
if self.overlap_pixels is not None:
overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
overlap_latent_height = overlap_latent
overlap_latent_width = overlap_latent
elif self.overlap_ratio is not None:
overlap_latent_height = int(self.overlap_ratio * latent_height)
overlap_latent_width = int(self.overlap_ratio * latent_width)

# Calculate minimum tile size in latent space
tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
tile_latent_min_width = tile_latent_stride_width + overlap_latent_width

tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio

blend_latent_height = tile_latent_min_height - tile_latent_stride_height
blend_latent_width = tile_latent_min_width - tile_latent_stride_width

blend_sample_height = tile_sample_min_height - tile_sample_stride_height
blend_sample_width = tile_sample_min_width - tile_sample_stride_width

return \
tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \
tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \
blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width
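
A worked example of the sizing math above, under assumed values (a 1024x1024 image, 8x spatial compression, h_split=1, w_split=4, overlap_pixels=64); none of these numbers come from the PR itself:

spatial_compression_ratio = 8
latent_height = latent_width = 1024 // spatial_compression_ratio                    # 128
h_split, w_split = 1, 4

tile_latent_stride_height = (latent_height + h_split - 1) // h_split                # 128
tile_latent_stride_width = (latent_width + w_split - 1) // w_split                  # 32
overlap_latent = (64 + spatial_compression_ratio - 1) // spatial_compression_ratio  # 8

tile_latent_min_height = tile_latent_stride_height + overlap_latent                 # 136
tile_latent_min_width = tile_latent_stride_width + overlap_latent                   # 40
blend_latent_width = tile_latent_min_width - tile_latent_stride_width               # 8 latent pixels blended at each vertical seam
tile_sample_min_width = tile_latent_min_width * spatial_compression_ratio           # 320 decoded pixels per tile (vs. a 256-pixel stride)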

def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.

When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.

Args:
x (`torch.Tensor`): Input batch of images.

Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""

_, _, height, width = x.shape
device = x.device
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio

tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \
tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \
blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \
self.calculate_tiled_parallel_size(latent_height, latent_width)

def vae_encode_op(
x, patch_height_start, patch_height_end, patch_width_start, patch_width_end
) -> torch.Tensor:
tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
return tile

rows = self.run_vae_tile_parallel(
x, vae_encode_op,
tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device
)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height)
if j > 0:
tile = self.blend_h_(row[j - 1], tile, blend_latent_width)
result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=3))

enc = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
return enc
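
`run_vae_tile_parallel` is referenced above (and in `tiled_decode_with_dp` below) but defined in a hunk that is not shown here. The following is a hypothetical sketch of what such a helper could do, based only on the call sites and the round-robin assignment from `enable_dp`; the function name, the per-tile `broadcast_object_list` exchange, and the CPU round-trip are assumptions, not the PR's actual implementation:

import torch.distributed as dist

def run_vae_tile_parallel_sketch(vae, x, op, tile_min_h, tile_min_w, stride_h, stride_w, device):
    # Each rank applies `op` (the encode or decode closure) only to its assigned tiles.
    local_tiles = {}
    for h_idx, w_idx in vae.tile_idxs_per_rank[vae.rank]:
        h0, w0 = h_idx * stride_h, w_idx * stride_w
        local_tiles[(h_idx, w_idx)] = op(x, h0, h0 + tile_min_h, w0, w0 + tile_min_w)

    # Exchange tiles so every rank can assemble the full grid (one object broadcast per tile, for simplicity).
    rows = [[None] * vae.w_split for _ in range(vae.h_split)]
    for owner_rank, idxs in enumerate(vae.tile_idxs_per_rank):
        for h_idx, w_idx in idxs:
            tile = local_tiles.get((h_idx, w_idx))
            obj = [tile.cpu() if tile is not None else None]  # keep the object broadcast device-agnostic
            dist.broadcast_object_list(obj, src=owner_rank, group=vae.vae_dp_group)
            rows[h_idx][w_idx] = obj[0].to(device)
    return rows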

def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of latent vectors using a tiled decoder, with the tiles distributed across data-parallel ranks.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, latent_height, latent_width = z.shape
device = z.device
sample_height = latent_height * self.spatial_compression_ratio
sample_width = latent_width * self.spatial_compression_ratio

tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \
tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \
blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \
self.calculate_tiled_parallel_size(latent_height, latent_width)

def vae_decode_op(
z, patch_height_start, patch_height_end, patch_width_start, patch_width_end
) -> torch.Tensor:
tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
if self.config.use_post_quant_conv:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
return decoded

rows = self.run_vae_tile_parallel(
z, vae_decode_op,
tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device
)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height)
if j > 0:
tile = self.blend_h_(row[j - 1], tile, blend_sample_width)
result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=3))

dec = torch.cat(result_rows, dim=2)[:, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)
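
A hedged end-to-end usage sketch of the data-parallel decode path, based only on the methods shown in this diff; the checkpoint name, `torchrun` launch line, backend, and latent shape are illustrative assumptions:

# torchrun --nproc_per_node=4 vae_dp_decode.py
import torch
import torch.distributed as dist

from diffusers import AutoencoderKL

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to("cuda")  # assumed checkpoint
# One row of world_size overlapping tiles, one tile per rank, with a 64-pixel overlap for blending.
vae.enable_dp(world_size=dist.get_world_size(), hw_splits=(1, dist.get_world_size()), overlap_pixels=64)

latents = torch.randn(1, 4, 128, 128, device="cuda")  # latents for a 1024x1024 image with an 8x VAE
with torch.no_grad():
    image = vae.decode(latents).sample  # routed through tiled_decode_with_dp because use_dp is True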

def forward(
self,
sample: torch.Tensor,