From 7d5790fffba4e7e6b2396f1e536e1405a3b441f5 Mon Sep 17 00:00:00 2001 From: yyt Date: Tue, 4 Nov 2025 09:35:22 +0000 Subject: [PATCH 1/5] implement vae dp for AutoencoderKL and AutoencoderKLWan --- .../models/autoencoders/autoencoder_kl.py | 300 ++++++++++++++++- .../models/autoencoders/autoencoder_kl_wan.py | 316 ++++++++++++++++++ 2 files changed, 615 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 1a72aa3cfeb3..e6655a908860 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -15,11 +15,12 @@ import torch import torch.nn as nn +import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import deprecate +from ...utils import deprecate, logging from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -35,6 +36,9 @@ from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -127,6 +131,7 @@ def __init__( self.use_slicing = False self.use_tiling = False + self.use_dp = False # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size @@ -214,9 +219,58 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. 
" \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + self.spatial_compression_ratio = 2 ** (len(self.config.block_out_channels) - 1) + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape + if self.use_dp: + return self._tiled_encode(x) if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): return self._tiled_encode(x) @@ -256,6 +310,8 @@ def encode( return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) @@ -497,6 +553,248 @@ def forward( return DecoderOutput(sample=dec) + def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + + _, _, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * tile_sample_stride_height + patch_height_end = patch_height_start + tile_sample_min_height + patch_width_start = w_idx * tile_sample_stride_width + patch_width_end = patch_width_start + tile_sample_min_width + + tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + + local_tiles.append(tile.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bc_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bc_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = 
rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, height, width = z.shape + device = z.device + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * height) + overlap_latent_width = int(self.overlap_ratio * width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + # Convert min/stride to sample space + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_height = tile_sample_min_height - tile_sample_stride_height + blend_width = tile_sample_min_width - tile_sample_stride_width + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
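+        # NOTE: the all_gather calls below are collectives over vae_dp_group, so every rank
+        # must reach this method with the same latent shape; invoking it on only a subset of
+        # ranks would deadlock.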
+ # Each rank computes only tiles assigned to it based on tile_idxs_per_rank + local_tiles = [] # List to store tiles computed by this rank + local_hw_shapes = [] # List to store shapes of tiles by this rank + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + patch_height_start = h_idx * tile_latent_stride_height + patch_height_end = patch_height_start + tile_latent_min_height + patch_width_start = w_idx * tile_latent_stride_width + patch_width_end = patch_width_start + tile_latent_min_width + + tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + + local_tiles.append(decoded.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank + local_hw_shapes.append(torch.Tensor([*decoded.shape[-2:]]).to(device).int()) # record hw for futher unflatten + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2)[:, :, :sample_height, :sample_width] + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index f8bdfeb75524..3fd96a755a7f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1113,6 +1113,52 @@ def enable_tiling( self.tile_sample_stride_height = tile_sample_stride_height or 
self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. " \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] @@ -1393,6 +1439,276 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ local_tiles = [] + local_hw_shapes = [] + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + self.clear_cache() + patch_height_start = h_idx * tile_sample_stride_height + patch_height_end = patch_height_start + tile_sample_min_height + patch_width_start = w_idx * tile_sample_stride_width + patch_width_end = patch_width_start + tile_sample_min_width + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + patch_height_start : patch_height_end, + patch_width_start : patch_width_end, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + time = torch.cat(time, dim=2) + local_tiles.append(time.flatten(-2, -1)) + local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) + self.clear_cache() + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( + -1, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 
+ + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + device = z.device + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * height) + overlap_latent_width = int(self.overlap_ratio * width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + # Convert min/stride to sample space + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = tile_sample_min_height - tile_sample_stride_height + blend_width = tile_sample_min_width - tile_sample_stride_width + + # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] + num_tile_rows = self.h_split + num_tile_cols = self.w_split + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
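+        # The causal-conv feature cache is tile-local: clear_cache() runs before and after
+        # each tile, so the frame-by-frame decoding below does not leak state across tiles.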
+ # Each rank computes only tiles assigned to it based on tile_idxs_per_rank + local_tiles = [] # List to store tiles computed by this rank + local_hw_shapes = [] # List to store shapes of tiles by this rank + + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: + self.clear_cache() + patch_height_start = h_idx * tile_latent_stride_height + patch_height_end = patch_height_start + tile_latent_min_height + patch_width_start = w_idx * tile_latent_stride_width + patch_width_end = patch_width_start + tile_latent_min_width + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.post_quant_conv(tile) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) + time.append(decoded) + time = torch.cat(time, dim=2) + local_tiles.append(time.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank + local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) # record hw for futher unflatten + self.clear_cache() + + # concat all tiles on local rank + local_tiles = torch.cat(local_tiles, dim=-1) + local_hw_shapes = torch.stack(local_hw_shapes) + + # get all hw shapes for each rank (perhaps has different shapes for last tile) + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) + for num_tiles in self.num_tiles_per_rank] + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) + + # gather tiles on all ranks + bcn_ = local_tiles.shape[:-1] + gathered_tiles = [ + torch.empty( + (*bcn_, tiles_shape.prod(dim=1).sum().item()), + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list + ] + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) + + # put tiles in rows based on tile_idxs_per_rank + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): + if not tile_idxs: + continue + rank_tile_hw_shapes = gathered_shape_list[rank_idx] + hw_start_idx = 0 + # perhaps has more than one tile in each rank, get each by hw_shapes + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw + rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( + 3, rank_tile_hw_shape.tolist()) # unflatten hw dim + hw_start_idx = hw_end_idx + + # combine all tiles, same as tiled decode + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, From 6c61cd0e355c830f352a04100581e9afcc259856 Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 
Nov 2025 03:25:45 +0000 Subject: [PATCH 2/5] extract same code in vae dp func --- .../models/autoencoders/autoencoder_kl.py | 275 ++++++------------ .../models/autoencoders/autoencoder_kl_wan.py | 223 +++++--------- src/diffusers/models/autoencoders/vae.py | 78 ++++- 3 files changed, 238 insertions(+), 338 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index e6655a908860..33841b2dae06 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -523,33 +523,41 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if not return_dict: return (dec,) - return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.Tensor]: - r""" - Args: - sample (`torch.Tensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) - if not return_dict: - return (dec,) + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width + + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width return DecoderOutput(sample=dec) @@ -575,86 +583,24 @@ def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: latent_height = height // 
self.spatial_compression_ratio latent_width = width // self.spatial_compression_ratio - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * latent_height) - overlap_latent_width = int(self.overlap_ratio * latent_width) - - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - - blend_height = tile_latent_min_height - tile_latent_stride_height - blend_width = tile_latent_min_width - tile_latent_stride_width - - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split - - local_tiles = [] - local_hw_shapes = [] - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: - patch_height_start = h_idx * tile_sample_stride_height - patch_height_end = patch_height_start + tile_sample_min_height - patch_width_start = w_idx * tile_sample_stride_width - patch_width_end = patch_width_start + tile_sample_min_width + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] tile = self.encoder(tile) if self.config.use_quant_conv: tile = self.quant_conv(tile) + return tile - local_tiles.append(tile.flatten(-2, -1)) - local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int()) - - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bc_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bc_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = 
[[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device + ) result_rows = [] for i, row in enumerate(rows): @@ -663,9 +609,9 @@ def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -686,95 +632,30 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - _, _, height, width = z.shape + _, _, latent_height, latent_width = z.shape device = z.device - sample_height = height * self.spatial_compression_ratio - sample_width = width * self.spatial_compression_ratio - - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * height) - overlap_latent_width = int(self.overlap_ratio * width) - - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - - # Convert min/stride to sample space - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - blend_height = tile_sample_min_height - tile_sample_stride_height - blend_width = tile_sample_min_width - tile_sample_stride_width + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split + 
tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - local_tiles = [] # List to store tiles computed by this rank - local_hw_shapes = [] # List to store shapes of tiles by this rank - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: - patch_height_start = h_idx * tile_latent_stride_height - patch_height_end = patch_height_start + tile_latent_min_height - patch_width_start = w_idx * tile_latent_stride_width - patch_width_end = patch_width_start + tile_latent_min_width + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] if self.config.use_post_quant_conv: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) + return decoded - local_tiles.append(decoded.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank - local_hw_shapes.append(torch.Tensor([*decoded.shape[-2:]]).to(device).int()) # record hw for futher unflatten - - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device + ) result_rows = [] for i, row in enumerate(rows): @@ -783,9 +664,9 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) if j > 0: - 
tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -795,6 +676,34 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni return DecoderOutput(sample=dec) + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 3fd96a755a7f..960c3a9d87f2 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1439,21 +1439,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return (dec,) return DecoderOutput(sample=dec) - def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: - r"""Encode a batch of images using a tiled encoder. - - Args: - x (`torch.Tensor`): Input batch of videos. - - Returns: - `torch.Tensor`: - The latent representation of the encoded videos. - """ - _, _, num_frames, height, width = x.shape - device = x.device - latent_height = height // self.spatial_compression_ratio - latent_width = width // self.spatial_compression_ratio - + def calculate_tiled_parallel_size(self, latent_height, latent_width): # Calculate stride based on h_split and w_split tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) @@ -1473,29 +1459,55 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: tile_latent_min_height = tile_latent_stride_height + overlap_latent_height tile_latent_min_width = tile_latent_stride_width + overlap_latent_width - blend_height = tile_latent_min_height - tile_latent_stride_height - blend_width = tile_latent_min_width - tile_latent_stride_width - tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width - # Split x into overlapping tiles and encode them separately. 
- # The tiles have an overlap to avoid seams between tiles. - local_tiles = [] - local_hw_shapes = [] + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_sample_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_sample_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width + + def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, sample_height, sample_width = x.shape + device = x.device + latent_height = sample_height // self.spatial_compression_ratio + latent_width = sample_width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() - patch_height_start = h_idx * tile_sample_stride_height - patch_height_end = patch_height_start + tile_sample_min_height - patch_width_start = w_idx * tile_sample_stride_width - patch_width_end = patch_width_start + tile_sample_min_width time = [] frame_range = 1 + (num_frames - 1) // 4 for k in range(frame_range): @@ -1514,42 +1526,14 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: tile = self.quant_conv(tile) time.append(tile) time = torch.cat(time, dim=2) - local_tiles.append(time.flatten(-2, -1)) - local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) self.clear_cache() + return time - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on 
tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( - -1, rank_tile_hw_shape.tolist()) # unflatten hw dim - hw_start_idx = hw_end_idx + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device, + num_frames=num_frames + ) result_rows = [] for i, row in enumerate(rows): @@ -1558,9 +1542,9 @@ def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput: # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_height) + tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_width) + tile = self.blend_h(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) @@ -1581,63 +1565,22 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - _, _, num_frames, height, width = z.shape + _, _, num_frames, latent_height, latent_width = z.shape device = z.device - sample_height = height * self.spatial_compression_ratio - sample_width = width * self.spatial_compression_ratio + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio - # Calculate stride based on h_split and w_split - tile_latent_stride_height = int((height + self.h_split - 1) / self.h_split) - tile_latent_stride_width = int((width + self.w_split - 1) / self.w_split) - - # Calculate overlap in latent space - overlap_latent_height = 3 - overlap_latent_width = 3 - if self.overlap_pixels is not None: - overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio - overlap_latent_height = overlap_latent - overlap_latent_width = overlap_latent - elif self.overlap_ratio is not None: - overlap_latent_height = int(self.overlap_ratio * height) - overlap_latent_width = int(self.overlap_ratio * width) + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) - # Calculate minimum tile size in latent space - tile_latent_min_height = tile_latent_stride_height + overlap_latent_height - tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: - # Convert min/stride to sample space - tile_sample_min_height = 
tile_latent_min_height * self.spatial_compression_ratio - tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio - tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio - tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio - - if self.config.patch_size is not None: - sample_height = sample_height // self.config.patch_size - sample_width = sample_width // self.config.patch_size - tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size - tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size - blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height - blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width - else: - blend_height = tile_sample_min_height - tile_sample_stride_height - blend_width = tile_sample_min_width - tile_sample_stride_width - - # Determine tile grid dimensions - patch_ranks shape is [h_split, w_split] - num_tile_rows = self.h_split - num_tile_cols = self.w_split - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - # Each rank computes only tiles assigned to it based on tile_idxs_per_rank - local_tiles = [] # List to store tiles computed by this rank - local_hw_shapes = [] # List to store shapes of tiles by this rank - - for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: self.clear_cache() - patch_height_start = h_idx * tile_latent_stride_height - patch_height_end = patch_height_start + tile_latent_min_height - patch_width_start = w_idx * tile_latent_stride_width - patch_width_end = patch_width_start + tile_latent_min_width + time = [] for k in range(num_frames): self._conv_idx = [0] @@ -1648,42 +1591,14 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni ) time.append(decoded) time = torch.cat(time, dim=2) - local_tiles.append(time.flatten(-2, -1)) # flatten h,w dim for concate all tiles in one rank - local_hw_shapes.append(torch.Tensor([*time.shape[-2:]]).to(device).int()) # record hw for futher unflatten self.clear_cache() + return time - # concat all tiles on local rank - local_tiles = torch.cat(local_tiles, dim=-1) - local_hw_shapes = torch.stack(local_hw_shapes) - - # get all hw shapes for each rank (perhaps has different shapes for last tile) - gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) - for num_tiles in self.num_tiles_per_rank] - dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) - - # gather tiles on all ranks - bcn_ = local_tiles.shape[:-1] - gathered_tiles = [ - torch.empty( - (*bcn_, tiles_shape.prod(dim=1).sum().item()), - dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list - ] - dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) - - # put tiles in rows based on tile_idxs_per_rank - rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] - for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): - if not tile_idxs: - continue - rank_tile_hw_shapes = gathered_shape_list[rank_idx] - hw_start_idx = 0 - # perhaps has more than one tile in each rank, get each by hw_shapes - for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): - rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] - hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattend hw - rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., 
hw_start_idx:hw_end_idx].unflatten(
-                    3, rank_tile_hw_shape.tolist())  # unflatten hw dim
-                hw_start_idx = hw_end_idx
+        rows = self.run_vae_tile_parallel(
+            z, vae_decode_op,
+            tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device,
+            num_frames=num_frames
+        )
 
         # combine all tiles, same as tiled decode
         result_rows = []
@@ -1693,9 +1608,9 @@ def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Uni
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                    tile = self.blend_h(row[j - 1], tile, blend_sample_width)
                 result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=-1))
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index 9c6031a988f9..d798711ec240 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 
 from ...utils import BaseOutput
 from ...utils.torch_utils import randn_tensor
@@ -926,3 +927,78 @@ def disable_slicing(self):
         decoding in one step.
         """
         self.use_slicing = False
+
+    def enable_dp(self):
+        r"""
+        Enable data-parallel tiled VAE encoding/decoding. When this option is enabled, the VAE splits the input
+        tensor into tiles, distributes the tiles across the ranks of the VAE data-parallel group, and gathers and
+        blends the per-rank results, which keeps memory use constant and speeds up processing of large inputs.
+        """
+        if not hasattr(self, "use_tiling"):
+            raise NotImplementedError(f"Tile parallelism doesn't seem to be implemented for {self.__class__.__name__}.")
+        self.use_dp = True
+
+    def disable_dp(self):
+        r"""
+        Disable data-parallel tiled VAE encoding/decoding. If `enable_dp` was previously enabled, this method goes
+        back to computing encoding and decoding in one step. 
+        """
+        self.use_dp = False
+
+    def run_vae_tile_parallel(
+            self,
+            input: torch.Tensor,
+            vae_op,
+            min_height,
+            min_width,
+            stride_height,
+            stride_width,
+            device,
+            **kwargs) -> List[List[torch.Tensor]]:
+
+        local_tiles = []
+        local_hw_shapes = []
+
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            patch_height_start = h_idx * stride_height
+            patch_height_end = patch_height_start + min_height
+            patch_width_start = w_idx * stride_width
+            patch_width_end = patch_width_start + min_width
+            tile = vae_op(input, patch_height_start, patch_height_end, patch_width_start, patch_width_end, **kwargs)
+            local_tiles.append(tile.flatten(-2, -1))
+            local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int())
+
+        # concatenate all tiles computed on the local rank
+        local_tiles = torch.cat(local_tiles, dim=-1)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
+        # gather the hw shape of every tile on every rank (the last tile in a row/column may be smaller)
+        gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+                               for num_tiles in self.num_tiles_per_rank]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        # gather tiles on all ranks
+        tile_shape_first = local_tiles.shape[:-1]
+        gathered_tiles = [
+            torch.empty(
+                (*tile_shape_first, tiles_shape.prod(dim=1).sum().item()),
+                dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # put tiles in rows based on tile_idxs_per_rank
+        rows = [[None] * self.w_split for _ in range(self.h_split)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            if not tile_idxs:
+                continue
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            # a rank may hold more than one tile; slice each one out of the flattened buffer by its hw shape
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # flattened h*w size
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten(
+                    -1, rank_tile_hw_shape.tolist())  # unflatten the hw dim
+                hw_start_idx = hw_end_idx
+
+        return rows
\ No newline at end of file

From 4aeeeb98698e7f51ddb96718cd5dad5d768a0e6b Mon Sep 17 00:00:00 2001
From: yyt
Date: Wed, 5 Nov 2025 09:04:13 +0000
Subject: [PATCH 3/5] optimize blend method in tiled vae

---
 .../models/autoencoders/autoencoder_kl.py     | 26 ++++++++++++++-----
 .../models/autoencoders/autoencoder_kl_wan.py | 22 +++++++++++++---
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 33841b2dae06..3517dde44f97 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -366,6 +366,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, :, None].to(a.dtype) + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype) + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio + return b + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. @@ -523,6 +537,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod if not return_dict: return (dec,) + return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): # Calculate stride based on h_split and w_split tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) @@ -559,8 +575,6 @@ def calculate_tiled_parallel_size(self, latent_height, latent_width): tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width - return DecoderOutput(sample=dec) - def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. @@ -609,9 +623,9 @@ def vae_encode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_latent_width) + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) @@ -664,9 +678,9 @@ def vae_decode_op( # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height) + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_sample_width) + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=3)) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 960c3a9d87f2..8bd8a12403eb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1295,6 +1295,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. 
             )
         return b
 
+    def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)  # blend along the height dim of (B, C, T, H, W)
+        y = torch.arange(0, blend_extent, device=a.device)
+        blend_ratio = (y / blend_extent)[None, None, None, :, None].to(a.dtype)
+        b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, :, y, :] * blend_ratio
+        return b
+
+    def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)  # blend along the width dim of (B, C, T, H, W)
+        x = torch.arange(0, blend_extent, device=a.device)
+        blend_ratio = (x / blend_extent)[None, None, None, None, :].to(a.dtype)
+        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, :, x] * blend_ratio
+        return b
+
     def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
@@ -1542,9 +1556,9 @@ def vae_encode_op(
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_latent_height)
+                    tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_latent_width)
+                    tile = self.blend_h_(row[j - 1], tile, blend_latent_width)
                 result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
             result_rows.append(torch.cat(result_row, dim=-1))
 
@@ -1608,9 +1622,9 @@ def vae_decode_op(
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
-                    tile = self.blend_v(rows[i - 1][j], tile, blend_sample_height)
+                    tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height)
                 if j > 0:
-                    tile = self.blend_h(row[j - 1], tile, blend_sample_width)
+                    tile = self.blend_h_(row[j - 1], tile, blend_sample_width)
                 result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=-1))
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

From 8cfad756b8f286b1faff50113d7ba4114526e913 Mon Sep 17 00:00:00 2001
From: yyt
Date: Wed, 5 Nov 2025 12:49:56 +0000
Subject: [PATCH 4/5] fix world_size 1 bug when init parallel tiling

---
 src/diffusers/models/autoencoders/autoencoder_kl.py     | 2 +-
 src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 3517dde44f97..f11c7db25386 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -229,7 +229,7 @@ def enable_dp(
         r"""
         """
         if world_size is None:
-            world_size = dist.get_world_size()
+            world_size = dist.get_world_size() if dist.is_initialized() else 1
 
         if world_size <= 1 or world_size > dist.get_world_size():
             logger.warning(
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 8bd8a12403eb..0252f1e22580 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1123,7 +1124,7 @@ def enable_dp( 
r""" """ if world_size is None: - world_size = dist.get_world_size() + world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size <= 1 or world_size > dist.get_world_size(): logger.warning( From 25115150a16ee85b31df44566ac31fbfab67e576 Mon Sep 17 00:00:00 2001 From: yyt Date: Wed, 5 Nov 2025 12:56:40 +0000 Subject: [PATCH 5/5] fix bug in vae_kl_wan --- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 0252f1e22580..f034cebde12b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1074,6 +1074,8 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + self.use_dp = False + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) @@ -1177,6 +1179,9 @@ def _encode(self, x: torch.Tensor): if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + if self.use_dp: + return self.tiled_encode_with_dp(x) + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -1229,6 +1234,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict)