FIFO-Diffusion with VC2 release
jjihwan committed May 25, 2024
1 parent 98661a2 · commit d72a993
Showing 26 changed files with 6,551 additions and 24 deletions.
15 changes: 15 additions & 0 deletions .gitignore
@@ -0,0 +1,15 @@
# results
results/

# checkpoints
videocrafter_models/
zeroscope_models/

# venvs
.fifo
.sora

# others
taming
.DS_Store
__pycache__
77 changes: 53 additions & 24 deletions README.md
@@ -2,7 +2,7 @@
<div align="center">

<p>
💾 <b> VRAM < 10GB </b> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
🚀 <b> Infinitely Long Videos</b> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
⭐️ <b> Tuning-free</b>
</p>
@@ -12,14 +12,32 @@

</div>

## 📽️ See more video samples in our <a href="https://jjihwan.github.io/projects/FIFO-Diffusion"> project page</a>!
<div align="center">

<img src="https://github.com/jjihwan/FIFO-Diffusion_public/assets/63445348/aafafa52-5ddf-4093-9d29-681fe469e447">

"An astronaut floating in space, high quality, 4K resolution."

100 frames, 320 × 512 resolution

<img src="https://github.com/jjihwan/FIFO-Diffusion_public/assets/63445348/b198c5bb-5104-4a57-a433-ddadfa7ec713">

"A colony of penguins waddling on an Antarctic ice sheet, 4K, ultra HD."

100 frames, 320 × 512 resolution
</div>


## News 📰
**[2024.05.25]** 🥳🥳🥳 We are thrilled to present the official PyTorch implementation of FIFO-Diffusion, starting with code based on VideoCrafter2.

**[2024.05.19]** Our paper, *FIFO-Diffusion: Generating Infinite Videos from Text without Training*, is now available on arXiv.
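For context, the core idea: FIFO-Diffusion maintains a queue of latent frames at increasing noise levels and denoises them diagonally, so the fully denoised frame at the head is dequeued as output while fresh noise is enqueued at the tail; the paper refines this with latent partitioning and lookahead denoising. Below is a minimal conceptual sketch of that loop only; `denoise_fn` and all shapes are placeholders, not this repository's actual API.

```python
import torch

def fifo_generate(denoise_fn, timesteps, num_frames, latent_shape):
    # `timesteps` is ordered from almost-clean (head) to pure noise (tail);
    # `denoise_fn(latents, timesteps)` stands in for one diagonal denoising
    # pass of the base video diffusion model.
    queue = torch.randn(len(timesteps), *latent_shape)  # one latent per noise level
    frames = []
    for _ in range(num_frames):
        queue = denoise_fn(queue, timesteps)       # advance every frame one noise level
        frames.append(queue[0])                    # head frame is fully denoised: dequeue it
        fresh = torch.randn(1, *latent_shape)      # fresh noise enters at the tail
        queue = torch.cat([queue[1:], fresh], dim=0)
    return torch.stack(frames)
```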

## Clone our repository
```
git clone [email protected]:jjihwan/FIFO-Diffusion_public.git
cd FIFO-Diffusion_public
```

## ☀️ Start with <a href="https://github.com/AILab-CVC/VideoCrafter">VideoCrafter</a>
@@ -34,43 +34,52 @@
pip install -r requirements.txt

### 2.1 Download the models from Hugging Face🤗
|Model|Resolution|Checkpoint
|:----|:---------|:---------
|VideoCrafter2 (Text2Video)|320x512|[Hugging Face](https://huggingface.co/VideoCrafter/VideoCrafter2/blob/main/model.ckpt)
|VideoCrafter1 (Text2Video)|320x512|[Hugging Face](https://huggingface.co/VideoCrafter/Text2Video-512/blob/main/model.ckpt)
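If you would rather script the download, here is a sketch using `huggingface_hub` (an extra dependency, `pip install huggingface_hub`; the target directories follow the layout in 2.2 below):

```python
from huggingface_hub import hf_hub_download

# VideoCrafter2 -> videocrafter_models/base_512_v2/model.ckpt
hf_hub_download(repo_id="VideoCrafter/VideoCrafter2", filename="model.ckpt",
                local_dir="videocrafter_models/base_512_v2")

# VideoCrafter1 -> videocrafter_models/base_512_v1/model.ckpt
hf_hub_download(repo_id="VideoCrafter/Text2Video-512", filename="model.ckpt",
                local_dir="videocrafter_models/base_512_v1")
```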

### 2.2 Set file structure
Store them in the following structure:
```
cd FIFO-Diffusion_public
.
└── videocrafter_models
    ├── base_512_v2
    │   └── model.ckpt          # VideoCrafter2 checkpoint
    └── base_512_v1
        └── model.ckpt          # VideoCrafter1 checkpoint
```

### 3.1. Run with VideoCrafter2 (Single GPU)
Requires less than **9GB of VRAM** on a Titan XP.
```
python3 videocrafter_main.py --save_frames
```

### 3.2. Distributed Parallel inference with VideoCrafter2 (Multiple GPUs)
May consume slightly more memory than single-GPU inference (**11GB** on a Titan XP).
Please note that our implementation of parallel inference may not be optimal.
Pull requests are welcome! 🤓

```
python3 videocrafter_main_mp.py --num_gpus 8 --save_frames
```

### 3.3. Run with VideoCrafter1
```
python3 videocrafter_main.py -ver=1
```

## ☀️ Start with <a href="https://github.com/PKU-YuanGroup/Open-Sora-Plan">Open-Sora Plan</a> (Coming Soon)

### 1. Environment Setup ⚙️ (python==3.10.14 recommended)
```
cd FIFO-Diffusion_public
git clone [email protected]:PKU-YuanGroup/Open-Sora-Plan.git
python -m venv .sora
@@ -85,7 +99,7 @@
pip install -e .
sh scripts/opensora_fifo_ddpm.sh
```

## ☀️ Start with <a href="https://huggingface.co/cerspense/zeroscope_v2_576w">zeroscope</a> (Coming Soon)

### 1. Environment Setup ⚙️ (python==3.10.14 recommended)
```
@@ -95,8 +109,23 @@
source .fifo/bin/activate
pip install -r requirements.txt
```

### 2. Run with zeroscope
```
mkdir zeroscope_models # directory where the model will be stored
python3 zeroscope_main.py
```
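For reference, zeroscope is distributed as a diffusers pipeline; what `zeroscope_main.py` sets up presumably amounts to something like the sketch below (an assumption on our part, not the script's actual code; `cache_dir` mirrors the `zeroscope_models` directory created above):

```python
import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w",  # model id from the link above
    torch_dtype=torch.float16,
    cache_dir="zeroscope_models",
)
```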

## 😆 Citation
```
@article{kim2024fifo,
    title   = {FIFO-Diffusion: Generating Infinite Videos from Text without Training},
    author  = {Jihwan Kim and Junoh Kang and Jinyoung Choi and Bohyung Han},
    journal = {arXiv preprint arXiv:2405.11473},
    year    = {2024},
}
```


## 🤓 Acknowledgements
Our codebase builds on [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter), [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), and [zeroscope](https://huggingface.co/cerspense/zeroscope_v2_576w).
Thanks to the authors for sharing their awesome codebases!
77 changes: 77 additions & 0 deletions configs/inference_t2v_512_v2.0.yaml
@@ -0,0 +1,77 @@
model:
  target: lvdm.models.ddpm3d.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: false
    conditioning_key: crossattn
    image_size:
    - 40
    - 64
    channels: 4
    scale_by_std: false
    scale_factor: 0.18215
    use_ema: false
    uncond_type: empty_seq
    use_scale: true
    scale_b: 0.7
    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        temporal_length: 16
        addition_attention: true
        fps_cond: true
    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: penultimate
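For orientation, a config like this is typically consumed in the standard latent-diffusion way: load it with OmegaConf and build the model through the repo's `utils.utils.instantiate_from_config` helper. A minimal sketch, assuming that convention (the checkpoint's key layout is an assumption and may vary):

```python
import torch
from omegaconf import OmegaConf
from utils.utils import instantiate_from_config  # helper shipped in this repo

config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
model = instantiate_from_config(config.model)  # builds lvdm.models.ddpm3d.LatentDiffusion

state = torch.load("videocrafter_models/base_512_v2/model.ckpt", map_location="cpu")
model.load_state_dict(state.get("state_dict", state), strict=False)  # key layout may vary
model.eval()
```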
100 changes: 100 additions & 0 deletions lvdm/basics.py
@@ -0,0 +1,100 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!

import torch.nn as nn
from utils.utils import instantiate_from_config


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def nonlinearity(type='silu'):
    if type == 'silu':
        return nn.SiLU()
    elif type == 'leaky_relu':
        return nn.LeakyReLU()
    raise ValueError(f"unsupported activation: {type}")  # fail loudly instead of returning None


class GroupNormSpecific(nn.GroupNorm):
    def forward(self, x):
        # run GroupNorm in float32 for numerical stability, then cast back to the input dtype
        return super().forward(x.float()).type(x.dtype)


def normalization(channels, num_groups=32):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNormSpecific(num_groups, channels)


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
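A quick usage sketch for the helpers above (illustrative, not from the repo): they compose into the GroupNorm → SiLU → zero-initialized convolution pattern used in diffusion UNet residual blocks, where the zeroed conv makes a newly added branch start out as a no-op.

```python
import torch
from lvdm.basics import conv_nd, nonlinearity, normalization, zero_module

block = torch.nn.Sequential(
    normalization(64),                              # GroupNorm(32, 64), computed in fp32
    nonlinearity("silu"),
    zero_module(conv_nd(2, 64, 64, 3, padding=1)),  # 2-D conv with zeroed weights
)
x = torch.randn(1, 64, 32, 32)
print(block(x).abs().max())  # tensor(0.) -- the zero-initialized branch is a no-op at first
```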
