基于Qwen2.5-VL 文本⇄视觉注意力热力图的逐帧、逐层可视化:

目标一句话:打开 output_attentions,提取 LLM 各层注意力,对齐图像 token,生成“文本⇄视觉”的逐帧热力图,帮助我们判断模型看对了哪里、理解了哪些词

在这里插入图片描述


0. 实验目标与场景

  • 实验标签"show_llm_hotmap": true

  • 核心目标:实现 compute_text_vision_heatmaps 及其工具函数,读取模型 attentions,分析文本 token图像 token 的相关性,并把每帧、每层的注意力可视化

  • 基线模型:Qwen2.5-VL(可迁移到其它 VLM;迁移时注意 tokenizer 格式、视觉序列切分与注意力实现)。

  • 示例输入(多帧+指令):

    '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n'
    '<|im_start|>user\n<image-1><image-2><image-3>pick the bottle<|im_end|>\n'
    '<|im_start|>assistant\n'
    

1. 方法总览(三步走)

  1. 在模型前向中接出 attentions

    • 打开 output_attentions=True
    • 通过开关 show_llm_hotmap 触发可视化。
  2. 把注意力折算为“文本查询→视觉块”的分数向量

    • 跨 head 平均;
    • 对文本区域做平均池化作为 query(或回退到末 token);
    • 按帧段切图像 token,逐段归一化为 [0,1]
    • 还原成近方形网格,得到热力图。
  3. 逐帧合并全部层的热力图并保存

    • 每帧导出一张“大图”(N×M 小图网格);
    • 每张小图含“原图|叠加热力图”,底部标注 Layer k

1.1 在模型里接线(开关 + 前向)

from utils.heatmaps import compute_text_vision_heatmaps

if self.expt_backbone_config.get("show_llm_hotmap", True):
    outputs = self.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,               # False
        output_attentions=True,            # ★ 必开
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    attns = outputs.attentions

    tv_heatmaps, tv_global_score, tv_matrix = compute_text_vision_heatmaps(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_sizes=[[224,224],[224,224],[224,224]],  # 便于推断帧数T
        image_token_index=selected,                   # 图像 token 的 bool 掩码
        input_embeds=input_embeds,
        attns=attns,
        # out_dir="/data/out_dir",
    )

产物含义

  • tv_heatmaps: [T_total, Hg, Wg];顺序为「frame0的全部层,frame1的全部层,…」
  • tv_global_score: [T_total];每张小图的全局得分(0–1)
  • tv_matrix: [T_total, P_max];展开后的注意力向量(右侧补零到统一长度)

1.2 热力图主函数


# @torch.no_grad()
def compute_text_vision_heatmaps(
    
    input_ids: torch.LongTensor,             # [B, N]
    pixel_values: torch.FloatTensor,         # 仅用于推断帧数/可视化,不参与注意力计算
    image_sizes: Optional[torch.Tensor],     # [[H,W], ...] 或 [B,T,2];可为 None
    image_token_index,                       # ✅ bool mask [S] / [B,S];也兼容 int 起点(等分)
    input_embeds: Optional[torch.Tensor],    # [B, S, D],直接送入 language_model
    attns: torch.FloatTensor,
    out_dir: Optional[Union[str, Path]] = None,
    ):
    """
    返回:
        heatmaps:      torch.FloatTensor [T_total, Hg_max, Wg_max]
                    # 注意:这里的 T_total = 帧数(T) * 有效层数(L_eff)
                    # 顺序为:frame0 的所有层,frame1 的所有层,...
        global_scores: torch.FloatTensor [T_total]          # 每张小图(某帧某层)的 vec01 均值(0~1)
        attn_matrix:   torch.FloatTensor [T_total, P_max]   # 每张小图的展平注意力(0~1),右侧补零到 P_max
    """
   

    device = input_embeds.device
    B, S, _ = input_embeds.shape

    

    if not isinstance(attns, (list, tuple)) or len(attns) == 0:
        if base_config.get("debug", True):
            print("attentions Tensor is none")
        return torch.empty(0, 1, 1), torch.empty(0), torch.empty(0, 1)

    # 如果图片张数不确定,需要推断输入了几张图,但eagle 的 T_expected = 3
    T_expected = heatmap_tools.infer_T(pixel_values, image_sizes)

    
    per_b_segments = []
    text_starts = []
    TEXT_START = 0

    if isinstance(image_token_index, torch.Tensor) and image_token_index.dtype == torch.bool:
        # 如果 image_token_index 是一个布尔类型的 Tensor
        img_mask = image_token_index
        if img_mask.dim() == 1:
            img_mask = img_mask.unsqueeze(0).expand(B, -1)  # [B,S]
        else:
            assert img_mask.shape == (B, S), f"image_token_index 形状 {tuple(img_mask.shape)} 与 [B,S]=[{B},{S}] 不匹配"
        
        for b in range(B):

            runs = heatmap_tools.get_segments_from_mask(img_mask[b])
            
            text_starts.append(int(runs[-1][0]) if runs else S)
            if T_expected > 0 and len(runs) > T_expected: 
                runs = sorted(sorted(runs, key=lambda x: (x[1]-x[0]+1), reverse=True)[:T_expected], key=lambda x: x[0])
            per_b_segments.append(runs)
        TEXT_START = min(text_starts)  if text_starts else S
    else:
        raise ValueError("image_token_index 必须是 bool mask([S]/[B,S]) ")


    TEXT_START = 786  # ==== 文本查询掩码: 默认取 [786, S) 当作文本,是根据eagle text input 格式确定的 ====
    
    
    # 去除来自 tokenizer.config 的额外特殊 token
    text_mask = torch.zeros((B, S), dtype=torch.bool, device=device)
    if TEXT_START < S:
        text_mask[:, TEXT_START:S] = True
    special_mask = heatmap_tools.build_special_mask(input_ids)
    text_mask = text_mask & (~special_mask)  # [B,S]
    has_text = text_mask.any(dim=-1)        # [B]
    
    

    # ====  遍历“全部层 × 全部帧段”生成热力图 ====
    heatmaps_list = []
    scores_list   = []
    vec_list      = []
    Hg_max = Wg_max = P_max = 1

    for layer_idx, att in enumerate(attns):
        # 跨 head 平均 -> [B,S,S]
        attn_mean = att.mean(dim=1)

        # 预先计算每个 batch 的 head_avg(文本 query 聚合)
        head_avg_list = []
        for b in range(B):
            if has_text[b]:
                
                count = text_mask[b].sum().clamp(min=1)
                head_avg_b = (attn_mean[b] * text_mask[b].to(attn_mean.dtype).unsqueeze(-1)).sum(dim=0) / count
            else:
                q_idx = S - 1        # eles final token
                head_avg_b = attn_mean[b, q_idx, :]
            head_avg_list.append(head_avg_b)
        head_avg = torch.stack(head_avg_list, dim=0)  # [B,S]

        # 针对该层,逐帧段生成热图
        for b in range(B):
            for (s0, e0) in per_b_segments[b]:
                s = max(0, min(s0, S - 1))
                e = max(0, min(e0, S - 1))
                if e < s:
                    continue

                vec = head_avg[b, s:e+1]  # [Pi]
                Pi = int(vec.numel())
                if Pi == 0:
                    continue

                # 归一化到 [0,1]
                vmin, vmax = vec.min(), vec.max()
                denom = (vmax - vmin).clamp(min=1e-6)
                vec01 = (vec - vmin) / denom

                # 近方形网格
                Hg = int(math.sqrt(Pi)) or 1
                Wg = int(math.ceil(Pi / max(Hg, 1)))
                P  = Hg * Wg

                if P > Pi:
                    pad = torch.zeros(P - Pi, dtype=vec01.dtype, device=vec01.device)
                    vec01_pad = torch.cat([vec01, pad], dim=0)
                else:
                    vec01_pad = vec01

                grid = vec01_pad.view(Hg, Wg).detach().cpu().float()
                heatmaps_list.append(grid)
                scores_list.append(vec01.mean().item())
                vec_list.append(vec01.detach().cpu().float())

                Hg_max = max(Hg_max, Hg)
                Wg_max = max(Wg_max, Wg)
                P_max  = max(P_max, Pi)

    if len(heatmaps_list) == 0:
        return torch.empty(0, 1, 1), torch.empty(0), torch.empty(0, 1)

    # 对齐尺寸并堆叠
    padded_grids = []
    for g in heatmaps_list:
        pad_h = Hg_max - g.shape[0]
        pad_w = Wg_max - g.shape[1]
        g_pad = F.pad(g, (0, pad_w, 0, pad_h), value=0.0)
        padded_grids.append(g_pad)
    heatmaps = torch.stack(padded_grids, dim=0)  # [T_total, Hg_max, Wg_max]

    padded_vecs = []
    for v in vec_list:
        if v.numel() < P_max:
            pad = torch.zeros(P_max - v.numel(), dtype=v.dtype)
            v = torch.cat([v, pad], dim=0)
        padded_vecs.append(v)
    attn_matrix   = torch.stack(padded_vecs, dim=0)                # [T_total, P_max]
    global_scores = torch.tensor(scores_list, dtype=torch.float32) # [T_total]
    
    
    
    
    heatmap_tools.save_text_vision_heatmap_images(
            pixel_values=pixel_values,
            heatmaps = heatmaps,
            out_dir=out_dir,
            file_prefix="eagle_llm_attn",
            alpha=0.4,                                 # 叠加透明度
        )





    return heatmaps, global_scores, attn_matrix




compute_text_vision_heatmaps(...) 的核心步骤:

  1. 推断帧数 T:优先从 image_sizes,其次从 pixel_values 形状(支持 [B,T,C,H,W])。

  2. 切分图像 token 段:从 image_token_index(bool 掩码)提取连续的 True 区间 (s,e),即每帧对应的 token 范围。

  3. 确定文本区域

    • 代码里默认 TEXT_START=786(基于 eagle 输入格式);
    • 结合 tokenizer_config.json 过滤特殊符号,得到 text_mask
  4. 跨 head 平均:把 attns[layer][B, H, S, S] 聚合为 [B, S, S]

  5. 构造“文本查询→全序列”的注意力分布

    • 若有文本:对 text_mask 区域做均值池化得到 query;
    • 若无文本:回退到 S-1(最后一个 token)。
  6. 逐帧段取出视觉向量、归一化、铺成网格

    • 每段得到 vec01 ∈ [0,1]
    • 近方形铺砖(Hg×Wg),不足补零;
    • 记录全局得分(vec01.mean())。
  7. 尺寸对齐 & 叠堆:把所有小图堆成 heatmaps,向量堆成 attn_matrix

  8. 保存可视化:调用 save_text_vision_heatmap_images(...),按每帧导出大图

小提示:目前 TEXT_START 被硬编码为 786。若迁移到其它 VLM,建议把它做成入参或从 tokenizer 自动推断,避免错位。


1.3 注意力实现的关键配置(拿不到 attentions 的元凶)

Transformers 会根据配置自动选择注意力内核。若:

  • attn_implementation: null
  • use_flash_attention: true(或自动走 FA2/SDPA)

可能不会返回每头注意力(只返回上下文),导致 outputs.attentions 为空。

解决办法(固定为 eager):

config = Eagle2_5_VLConfig.from_pretrained(DEFAULT_EAGLE_PATH)

if self.expt_backbone_config.get("show_llm_hotmap", True):
    setattr(config, "_attn_implementation", "eager")       # ★ 强制常规实现
    setattr(config, "_attn_implementation_autoset", False)  # ★ 禁止自动切回 FA/SDPA

self.eagle_model = Eagle2_5_VLForConditionalGeneration(config)

一图流理解

  • FlashAttention/SDPA:快、省显存;默认不保留每头 attn_probs
  • eager(matmul+softmax):标准实现,保留并返回注意力权重。

延伸阅读:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022)


2. 结果解读(看什么、怎么用)

  • 大图 = 单帧 × 全层:每张大图对应一帧;网格中每个 tile 是 原图 | 叠加热力图
  • 颜色强度:越红说明文本查询(如 “open”“door”“microwave” 等)对该视觉区域越关注。
  • 层级差异:浅层关注局部细节,高层更偏向语义区域;可以快速定位“模型究竟把哪块区域当成‘门把手/门框/开门动作’的证据”。
  • tv_global_score:可用作帧-层级的全局相关性评分(0–1),辅助筛选“最有代表性”的头图或层图。
  • tv_matrix:保留了各帧-各层的展平注意力向量,便于后续做统计、聚类或检索(例如,找到“看门把手”的层/帧集合)。

3. 工具函数概览(模块化职责)

为简洁起见,只列职责与接口,实现与主函数放在同一 heatmaps.py

  • heatmap_tools.infer_T(pixel_values, image_sizes)
    推断帧数 T(优先 image_sizes 其后看张量维度)。

  • heatmap_tools.get_segments_from_mask(mask_1d)
    [S] 的 bool 掩码转成若干连续 True 段 [(s,e), ...]

  • heatmap_tools.build_special_mask(input_ids, cfg_path=None)
    读取 tokenizer_config.jsonadded_tokens_decoder,过滤特殊 token。

  • heatmap_tools._pixel_values_to_rgb_images(pixel_values)
    pixel_values 规范化为 List[H×W×3, uint8, RGB],支持 [B,T,C,H,W] / [C,H,W]等多形态。

  • heatmap_tools.save_text_vision_heatmap_images(pixel_values, heatmaps, out_dir, file_prefix, alpha)
    每帧导出一张大图(内含该帧全部层的小图),小图底部标注 Layer k
    默认输出目录:DEFAULT_OUT_DIR = <PROJECT_ROOT>/vis_attns


4. 常见坑位与规避

  • 拿不到 attentions
    99% 是因为自动走了 FA/SDPA。按 1.3 的配置固定为 eager,并设置 _attn_implementation_autoset=False

  • image_token_index 形状不对
    需要 bool 掩码,支持 [S][B,S]。若是 [S] 会自动扩展到批维。

  • 文本起始位 TEXT_START 不匹配
    代码里默认 786(eagle 格式)。迁移到其它 VLM,请从 tokenizer 的实际拼接规则自动推断或显式传入,避免文本/视觉错位。

  • 特殊符号未过滤导致文本池化偏移
    确认 TOKENIZER_CONFIG_PATH 指向正确的 tokenizer_config.json,并让 build_special_mask 生效。

  • 输出过大
    导图时已对 tile 做缩放;如仍太大,可适当调低 thumb_w 或减少每帧展示的层数。


5. 迁移到其它 VLM 的要点清单

  • ✅ 能打开 output_attentions=True
  • ✅ 明确图像 token 的划分规则(多帧如何串接、每帧 token 长度);
  • ✅ 明确文本段起始TEXT_START 或自动推断);
  • 关闭/绕过 FlashAttention/SDPA,改用 eager
  • ✅ 更新 tokenizer_config.json 路径与格式以支持特殊符号过滤

6, 工具函数代码补充(可直接套用)

这部分代码和可视化函数共同放在一个heatmap.py 里,通过在forward 函数里调用 即可实现可视化。

```python
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from __future__ import annotations

import math
import json
import os
from math import ceil, sqrt
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import cv2


from experiment.expt_config import ExptConfig

base_config = ExptConfig().base_config()

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.realpath(projetname.__file__))) 
TOKENIZER_CONFIG_PATH = os.path.join(PROJECT_ROOT ,"model", "backbone", "eagle2_hg_model","tokenizer_config.json")
DEFAULT_OUT_DIR  =  os.path.join(PROJECT_ROOT, "vis_attns")






class heatmap_tools():
    # ==== 2) 推断帧数(仅用于切分 image token 段) ====
    @staticmethod
    def infer_T(_pixel_values, _image_sizes):
        if _image_sizes is not None:
            if isinstance(_image_sizes, torch.Tensor):
                if _image_sizes.dim() == 3:   # [B,T,2]
                    return int(_image_sizes.shape[1])
                if _image_sizes.dim() == 2:   # [T,2]
                    return int(_image_sizes.shape[0])
            elif isinstance(_image_sizes, (list, tuple)) and len(_image_sizes) > 0:
                if isinstance(_image_sizes[0], (list, tuple)):
                    return len(_image_sizes)
        if isinstance(_pixel_values, torch.Tensor):
            if _pixel_values.dim() == 5:      # [B,T,C,H,W]
                return int(_pixel_values.shape[1])
            if _pixel_values.dim() in (3, 4): # [C,H,W] / [B,C,H,W] / [T,C,H,W]
                return 1
        return 1


    # ====  图像 token 掩码 -> 连续 True 段 ====
    @staticmethod
    def get_segments_from_mask(mask_1d: torch.Tensor):
        """给定 [S] 的 bool 掩码,返回连续 True 段 [(s,e), ...](闭区间)。"""
        m = mask_1d.to(torch.bool)
        if m.numel() == 0:
            return []
        mb = m.to(torch.uint8)
        pad = F.pad(mb, (1, 1), value=0)
        diff = pad[1:] - pad[:-1]
        starts = torch.nonzero(diff == 1, as_tuple=False).squeeze(1)
        ends   = torch.nonzero(diff == -1, as_tuple=False).squeeze(1) - 1
        runs = [(int(s.item()), int(e.item())) for s, e in zip(starts, ends)]
        return runs

    @staticmethod
    def build_special_mask(input_ids, cfg_path=None):
        
        p = cfg_path or os.getenv("TOKENIZER_CONFIG_PATH") 
        try:
            cfg = json.load(open(p, "r", encoding="utf-8"))
            ids = [int(k) for k, v in cfg.get("added_tokens_decoder", {}).items() if v.get("special", False)]
        except Exception:
            ids = []
        return torch.isin(input_ids, torch.tensor(sorted(set(ids)), device=input_ids.device)) if ids else torch.zeros_like(input_ids, dtype=torch.bool)
    

    # ==== 将 pixel_values 转成 List[np.ndarray(H, W, 3)] (RGB, uint8) ====
    @staticmethod
    def _pixel_values_to_rgb_images(pixel_values):
        """
        将 pixel_values 转成 List[np.ndarray(H, W, 3)] (RGB, uint8)
        支持输入:
        - torch.Tensor 或 np.ndarray
        - 形状 [B,C,H,W], [B,T,C,H,W], [T,C,H,W], [C,H,W]
        - 值域 任意(逐图 min-max 到 [0,255],避免全黑)
        """
        
        if isinstance(pixel_values, np.ndarray):
            t = torch.from_numpy(pixel_values)
        else:
            t = pixel_values
        t = t.detach().to(dtype=torch.float32, device="cpu")

        # 标准化形状为 [N,C,H,W]
        if t.ndim == 5:        # [B,T,C,H,W] -> [B*T,C,H,W]
            B, T, C, H, W = t.shape
            t = t.reshape(B * T, C, H, W)
        elif t.ndim == 4:      # [B,C,H,W] 或 [T,C,H,W]
            pass
        elif t.ndim == 3:      # [C,H,W] -> [1,C,H,W]
            C, H, W = t.shape
            t = t.reshape(1, C, H, W)
        else:
            raise ValueError(f"Unsupported pixel_values shape: {tuple(t.shape)}")

        # [N,H,W,C]
        arr = t.permute(0, 2, 3, 1).contiguous()  # float32

        # 逐图 min-max 到 [0,255]
        vmin = arr.amin(dim=(1, 2, 3), keepdim=True)
        vmax = arr.amax(dim=(1, 2, 3), keepdim=True)
        denom = (vmax - vmin).clamp_min(1e-6)
        arr = (arr - vmin) / denom * 255.0
        arr = arr.clamp(0, 255).to(torch.uint8).numpy()  # uint8

        imgs = [arr[i] for i in range(arr.shape[0])]  # RGB
        return imgs


    # ========== 保存热图(同一帧的不同层自动合并成 1 张大图,并在每张小图下方标注层号)==========
    @staticmethod
    def save_text_vision_heatmap_images(
        
        pixel_values,
        heatmaps,
        out_dir: Optional[Union[str, Path]] = None,
        file_prefix: str = "eagle_llm_attn",
        alpha: float = 0.5,
    ):
        """
        将全部层的热力图叠加到原图上,按“同一帧的所有层”为单位合成网格大图保存。
        约定:compute_text_vision_heatmaps 返回的 heatmaps 顺序是
            [frame0_layer0, frame0_layer1, ..., frame0_layerL-1,
            frame1_layer0, ..., frame1_layerL-1, ...]
        参数:
            pixel_values: 同模型输入,可为 torch.Tensor 或 np.ndarray。
            heatmaps    : torch.Tensor/np.ndarray,形状 [T_total,Hg,Wg] 或 [Hg,Wg]。
            out_dir     : 输出目录。
            file_prefix : 文件名前缀。
            alpha       : 叠加透明度(0~1)。
        """
        out_dir_effective = Path(out_dir) if out_dir else Path(DEFAULT_OUT_DIR)
        os.makedirs(out_dir_effective, exist_ok=True)

        # 统一热图形状为 [T_total,Hg,Wg]
        if isinstance(heatmaps, np.ndarray):
            h = torch.from_numpy(heatmaps)
        else:
            h = heatmaps
        h = h.detach().cpu().float()
        if h.dim() == 2:
            h = h.unsqueeze(0)
        elif h.dim() != 3:
            raise ValueError(f"[tvattn] heatmaps 形状不支持: {tuple(h.shape)}(期望 [T,H,W] 或 [H,W])")

        # 原图转 RGB uint8 列表
        imgs = heatmap_tools._pixel_values_to_rgb_images(pixel_values)# List[np.uint8(H,W,3)]
    
        num_frames = len(imgs)
        T_total = h.shape[0]
        if num_frames == 0 or T_total == 0:
            return

        # 推断每帧层数(整数除法;若除不尽则最后一帧按可用数量)
        layers_per_frame = max(1, T_total // num_frames)

        ts = datetime.now(timezone(timedelta(hours=8))).strftime("%Y%m%d_%H%M%S")

        for i in range(num_frames):
            img_rgb = imgs[i]                        # (H,W,3), RGB, uint8
            H, W = img_rgb.shape[:2]
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

            # 取出该帧的全部层热图
            start = i * layers_per_frame
            end   = min((i + 1) * layers_per_frame, T_total)
            this_layers = h[start:end]               # [L_i, Hg, Wg]
            L_i = this_layers.shape[0]

            tiles = []
            for l in range(L_i):
                heat = this_layers[l].numpy()        # (Hg,Wg), float
                # 归一化到 [0,1]
                hmin, hmax = float(heat.min()), float(heat.max())
                denom = (hmax - hmin) if (hmax - hmin) > 1e-6 else 1e-6
                heat01 = (heat - hmin) / denom

                # resize 到原图大小
                heat_resized = cv2.resize(heat01, (W, H), interpolation=cv2.INTER_CUBIC)
                heat_u8 = (heat_resized * 255.0).clip(0, 255).astype(np.uint8)

                # 伪彩色并叠加
                heat_color_bgr = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
                overlay_bgr = cv2.addWeighted(img_bgr, 1.0 - float(alpha), heat_color_bgr, float(alpha), 0.0)

                # 生成“原图|叠加”的小图(更直观)
                side_bgr = np.concatenate([img_bgr, overlay_bgr], axis=1)   # [H, 2W, 3]

                # 在底部加一条白色字幕条并写入层号
                caption_h = max(28, H // 20)
                caption = np.full((caption_h, side_bgr.shape[1], 3), 255, dtype=np.uint8)
                label = f"Layer {l}"
                cv2.putText(caption, label, (12, caption_h - 8), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (0, 0, 0), 2, lineType=cv2.LINE_AA)

                tile = np.concatenate([side_bgr, caption], axis=0)          # [H+cap, 2W, 3]
                tiles.append(tile)

            # 网格排版
            cols = int(ceil(sqrt(L_i)))
            rows = int(ceil(L_i / cols))

            # 缩放每个 tile,避免超大图(按宽度到 640 像素)
            thumb_w = 640
            scaled_tiles = []
            for t in tiles:
                h_t, w_t = t.shape[:2]
                scale = thumb_w / float(w_t)
                t_resz = cv2.resize(t, (thumb_w, max(1, int(h_t * scale))), interpolation=cv2.INTER_AREA)
                scaled_tiles.append(t_resz)

            # 补空白填满网格
            tile_h = max(t.shape[0] for t in scaled_tiles)
            tile_w = max(t.shape[1] for t in scaled_tiles)
            blank  = np.full((tile_h, tile_w, 3), 255, dtype=np.uint8)

            grid = []
            idx = 0
            for r in range(rows):
                row_imgs = []
                for c in range(cols):
                    if idx < len(scaled_tiles):
                        t = scaled_tiles[idx]
                        # 居中贴到标准 tile 尺寸上
                        pad = blank.copy()
                        y0 = (tile_h - t.shape[0]) // 2
                        x0 = (tile_w - t.shape[1]) // 2
                        pad[y0:y0+t.shape[0], x0:x0+t.shape[1]] = t
                        row_imgs.append(pad)
                        idx += 1
                    else:
                        row_imgs.append(blank.copy())
                grid.append(np.concatenate(row_imgs, axis=1))
            big_img = np.concatenate(grid, axis=0)

            out_path = os.path.join(out_dir_effective, f"{file_prefix}_{ts}_frame{i+1}_L{L_i}.png")
            cv2.imwrite(out_path, big_img)
            if base_config.get("debug", True):
                print(f"[vis_attn] saved: {out_path}")


Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐