FLUX.1 is a large text-to-image model released by Black Forest Labs, a company founded by core members of the original Stable Diffusion team.

Running FLUX.1 on 16GB of GPU memory

With 16GB of GPU memory and 8-bit quantization, FLUX.1 can run: the T5 text encoder takes roughly 8GB and the FluxTransformer roughly 14GB, but because the encoder is freed before the transformer is loaded (see the flush() calls below), peak memory stays around 14GB.
Only the two large components, T5 and the FluxTransformer, are quantized; CLIP and the VAE are small and can be left unquantized.
To fit into even less memory, you can try 4-bit or NF4 quantization.
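For example, a minimal NF4 sketch (assuming the same diffusers/transformers BitsAndBytesConfig API as the 8-bit configs used below):

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# NF4 drop-in replacement for the load_in_8bit=True configs below
nf4_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # normal-float 4-bit quantization
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for the actual matmuls
)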

When using Flux, negative_prompt is usually left empty and the CFG value is set to 1; the corresponding parameter is "true_cfg_scale".
The "guidance_scale" parameter is usually set above 1: the larger the value, the more closely the generated image follows the prompt, at the cost of some image quality. Note that in FLUX.1-dev this is not classifier-free guidance; it is fed into the transformer as a guidance embedding (the guidance argument in the denoising call below).

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel,FluxPipeline
from transformers import T5EncoderModel
import torch

import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()

model_id = "/disk2/modelscope/hub/black-forest-labs/FLUX___1-dev"

# 8-bit config for the T5 text encoder (transformers-side BitsAndBytesConfig)
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Stage 1: load only the text encoders (no transformer / VAE) to compute embeddings
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=None,
    vae=None,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

with torch.no_grad():
    prompt = "A cat holding a sign that says hello world"
    negative_prompt = ""
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt, prompt_2=None)
    negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(negative_prompt, prompt_2=None)

# Free the text encoder before loading the transformer to keep peak memory low
del text_encoder_2_8bit
del pipe
flush()

# 8-bit config for the transformer (diffusers-side BitsAndBytesConfig)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Stage 2: load the transformer and VAE; the text encoders are no longer needed
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer_8bit,
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.float16,
    device_map="balanced",
)

pipe_kwargs = {
    "prompt_embeds": prompt_embeds,
    "negative_prompt_embeds": negative_prompt_embeds,
    "pooled_prompt_embeds": pooled_prompt_embeds,
    "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 20,
    "max_sequence_length": 512,
}

image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
image.save("flux_cat.jpg")
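To check the memory figures quoted above, you can print the peak allocation after generation:

# Peak GPU memory allocated during the whole run (encode + denoise + decode)
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")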

FLUX architecture diagram
FLUX.1 uses the MM-DiT structure: after modulation, the latent stream and the prompt-encoder stream each produce their own Q, K, V; the two sets are concatenated into one joint Q, K, V for the attention computation, then split back into two streams, each passing through its own MLP and modulation before entering the next block (the concatenation is visible in FluxAttnProcessor2_0 below).
The Single-DiT blocks instead concatenate the latent and prompt streams first and then run them through a single DiT block.

encode_prompt

If prompt_2 is None, it defaults to prompt.
When prompt_2 and prompt conflict, the result tends to follow prompt_2; prompt_2 has a stronger influence on the output than prompt.
prompt goes through CLIP to produce pooled_prompt_embeds, which, together with the timestep and guidance embeddings, drives the AdaLayerNormZero modulation.
prompt_2 goes through T5 to produce prompt_embeds, which enter the transformer's attention computation directly; a usage sketch follows the code excerpt below.

prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
    prompt=prompt,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
)  # (1, 768)
prompt_embeds = self._get_t5_prompt_embeds(
    prompt=prompt_2,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    device=device,
)  # (1, 512, 4096)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)  # (512, 3)
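A usage sketch with distinct prompts for the two encoders (hypothetical prompts, pipe loaded as in stage 1 above):

# prompt feeds CLIP (pooled vector -> modulation);
# prompt_2 feeds T5 (sequence tokens -> attention)
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a cat",
    prompt_2="a photorealistic tabby cat holding a cardboard sign",
)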

prepare_latents

In Flux, the VAE still applies 8x spatial compression, but the latent has 16 channels.
For a 1024x1024 image the latent has shape (1, 16, 128, 128); it is then patchified into 2x2 patches, changing the shape to (1, 4096, 64).
latent_image_ids has 3 columns: column 0 is all zeros, column 1 holds the height index, and column 2 the width index; it is likewise reshaped to (4096, 3). A shape check of the packing follows the code below.

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)  #(1,16,128,128)
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) #(1,4096,64)
        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

        return latents, latent_image_ids
        
    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents
        
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        return latent_image_ids.to(device=device, dtype=dtype)
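A standalone shape check of the 2x2 packing described above (mirrors _pack_latents on a dummy tensor):

import torch

latents = torch.randn(1, 16, 128, 128)       # (B, C, H, W) latent for a 1024x1024 image
packed = latents.view(1, 16, 64, 2, 64, 2)   # split H and W into 2x2 patches
packed = packed.permute(0, 2, 4, 1, 3, 5)    # (B, H/2, W/2, C, 2, 2)
packed = packed.reshape(1, 64 * 64, 16 * 4)  # fold each 2x2 patch into the channel dim
print(packed.shape)                          # torch.Size([1, 4096, 64])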

Prepare timesteps

        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
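calculate_shift computes mu from the latent sequence length, so larger images get a more shifted (noisier) sigma schedule. A sketch matching the diffusers implementation at the time of writing:

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
    # linearly interpolate mu between base_shift and max_shift in the sequence length
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b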

Denoising

The overall flow is easiest to follow in the architecture diagram above.

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,  # CLIP
                    encoder_hidden_states=prompt_embeds,  # T5
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

Below is a simplified version of FluxTransformer2DModel's forward function, with the if branches removed:

hidden_states = self.x_embedder(hidden_states)  # (1, 4096, 64) => (1, 4096, 3072)
# timestep, guidance, and pooled_projections are each projected to the same
# dimension and summed: CombinedTimestepGuidanceTextProjEmbeddings (see below)
temb = self.time_text_embed(timestep, guidance, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)  # T5: (1, 512, 4096) => (1, 512, 3072)
ids = torch.cat((txt_ids, img_ids), dim=0)  # (512+4096, 3)
image_rotary_emb = self.pos_embed(ids)  # RoPE

# 19 MM-DiT blocks with separate text / image streams
for block in self.transformer_blocks:  # FluxTransformerBlock x19
    encoder_hidden_states, hidden_states = block(
        hidden_states=hidden_states,  # (1, 4096, 3072)
        encoder_hidden_states=encoder_hidden_states,  # (1, 512, 3072)
        temb=temb,  # (1, 3072)
        image_rotary_emb=image_rotary_emb,  # ((4608, 128), (4608, 128))
        joint_attention_kwargs=joint_attention_kwargs,
    )

# concatenate text and image tokens into a single stream
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# 38 single-stream DiT blocks
for block in self.single_transformer_blocks:  # FluxSingleTransformerBlock x38
    hidden_states = block(
        hidden_states=hidden_states,  # (1, 4608, 3072)
        temb=temb,  # (1, 3072)
        image_rotary_emb=image_rotary_emb,
        joint_attention_kwargs=joint_attention_kwargs,
    )

# drop the text tokens, keep only the image tokens
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)  # (1, 4096, 3072) => (1, 4096, 64)
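The output is still packed, (1, 4096, 64); the pipeline then reverses the 2x2 packing with _unpack_latents to recover (1, 16, 128, 128) before the VAE decodes it to pixels.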
   

self.time_text_embed

class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        guidance_proj = self.time_proj(guidance)
        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        time_guidance_emb = timesteps_emb + guidance_emb

        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = time_guidance_emb + pooled_projections

        return conditioning
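A quick shape check (a sketch; the import assumes diffusers' internal module path):

import torch
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings

embed = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=3072, pooled_projection_dim=768)
temb = embed(
    timestep=torch.tensor([1.0]),           # one scalar timestep per sample
    guidance=torch.tensor([3.5]),           # distilled guidance value
    pooled_projection=torch.randn(1, 768),  # pooled CLIP embedding
)
print(temb.shape)  # torch.Size([1, 3072])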

FluxTransformerBlock

        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}
        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.

        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
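norm1 and norm1_context above are AdaLayerNormZero layers: from temb they regress a shift, scale, and gate for the attention sub-layer plus a shift, scale, and gate for the MLP sub-layer. A minimal sketch of the pattern (not the exact diffusers class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLayerNormZeroSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, 6 * dim)  # 6 modulation params regressed from temb
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.linear(
            F.silu(emb)
        ).chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp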

FluxAttnProcessor2_0

class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
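For reference, a sketch of how the rotary embedding is applied above (the cos/sin path of diffusers' apply_rotary_emb; treat the exact signature as an assumption):

import torch

def apply_rotary_emb_sketch(x, freqs_cis):
    # x: (B, heads, S, head_dim); freqs_cis: (cos, sin), each (S, head_dim)
    cos, sin = freqs_cis
    # pair up even/odd channels and rotate each pair: (x0, x1) -> (-x1, x0)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)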

FluxSingleTransformerBlock

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states
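Note the design choice: the single blocks run attention and the MLP in parallel on the same normalized input, concatenate the two results, and fuse them with a single proj_out, instead of applying them sequentially as the MM-DiT blocks do.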

Reference:
深入浅出完整解析Stable Diffusion 3(SD 3)和FLUX.1系列核心基础知识
