提示:这里专门写得是针对适用于YOLO的注意力模块写法,对于所有人工智能框架在原理层面是没有任何的,但是代码方面仅使用于YOLO

一、【注意力】简单概念

        由于我自己也是小白,非人工智能科班生临时学的,所以我也不深究其中的深层原理,大致了解一下是个什么东西就行

        首先【注意力】顾名思义,就是让计算机专注于图像中我们关心的目标物体,而忽略周边的无关像素,而做法就是要调各个像素的参数权重,让专注的像素的参数权重高、其他的参数权重低

这就涉及了两大类注意力:【Transformer 类注意力机制】VS【卷积类注意力模块】

1、【Transformer 类注意力机制】

        它的原理就是用下图这个公式得到结果:它的核心思想是通过对输入的每一个元素(Q、K、V)进行加权和计算,进而获得最相关的部分进行重点关注。

  • 这个公式里涉及了3个参数:Q、K、V
  • 大致结构如图,看一眼就够了,我只是想显得专业一点而已.....

        至于我为什么要介绍这个【Transformer 类注意力机制】,是因为它跟我们YOLO目标检测要用到的注意力模块暂时无需关心,他最常用于NLP大语言模型,虽然现在计算机视觉领域也开始称王,但是作为初学者,像我这种SB很有可能会花时间去了解它,而到最后啥也没学会

        这里我只是为了告诉各位【Transformer 类注意力机制】和【卷积类注意力模块】是两个不一样的东西,前者我们先不用,后者才是我们要用的

2、【卷积类注意力模块】

        卷积类注意力通常是指在卷积神经网络中通过局部感受野进行关注。与Transformer不同,卷积操作是基于局部区域的特征进行提取,因此它对局部空间的注意力更加专注,而不是全局的信息交互。卷积类注意力的计算方式更偏向于局部感知,适用于空间结构数据(如图像),而Transformer类注意力则更擅长处理长距离依赖的全局信息。当然这也决定了卷积类注意力只适用于计算机视觉,而非NLP大语言模型。

3、二者对比

  • Transformer类注意力
    • 全局关注:可以通过【全局】信息建模,适合处理NLP语言的【长距离依赖】。
    • 计算量大:由于需要计算所有输入元素之间的关系,计算开销相对较高。
    • 灵活性强:能够处理任何类型的输入序列数据(如文本、语音等)。
  • 卷积类注意力
    • 局部关注:主要聚焦于【局部】区域特征,只适用于【图像等】数据。
    • 计算效率高:由于卷积操作具有局部性,计算效率比Transformer高。
    • 不善处理长距离依赖:卷积结构不适合直接捕捉远距离信息,可能需要多层卷积来近似全局关系。

二、YOLO适用的卷积类注意力

1、第一代:SEnet模块

  • 结构:【平均池化压缩】+【2全连接层 + Sigmoid()函数】+【结果乘上输入权重】
  • 特点:
    • ✔ 只做通道注意力
    • ✔ 结构极简
    • ❌ 不知道空间位置
  • 效果:我自己对一个数据集进行了一次【不加SE】和【加了SE】训练
    • 【不加SE】得效果
      • 整体效果虽然不错,但是mAP50在69%、mAP50~95在49%左右,都还不够高
    • 【加了SE】得效果
      • 很顶级:明显mAP50提高了很多!!说明IoU高了,也就是预测框的定位更准
      • 另外注意看 验证集的损失值 也降低了很多!!

2、第二代:CBAM模块(重点,实用)

        那么YOLO最常见的注意力模块就是【CBAM注意力模块】,是因为对于微小目标检测,它更轻量、即插即用,专门强化小目标特征!!!        

        CBAM 的全称是 Convolutional Block Attention Module(卷积块注意力模块),是专门为卷积网络设计的轻量注意力模块(算力消耗低,适合大部分电脑)。
        它的核心逻辑是:从两个维度给特征 “加权”—— 先选对【通道】再选对【位置】,两步操作都针对小目标优化。

1)通道注意力(Channel Attention)—— 选 “有用的特征通道”

2)空间注意力(Spatial Attention)—— 选 “有用的像素位置”

        还是以“玉米雄穗”为例子,经过通道注意力后,模型已经聚焦到 “绿色纹理” 这类特征,但还不知道这些纹理在图片的哪个位置(是雄穗的位置,还是叶片的位置?)

        空间注意力的作用:给雄穗所在的像素位置加权重,给背景位置降权重 —— 让模型精准定位小目标的位置。

总结:

以玉米雄穗数据集为例子,大致流程就是:

  • 输入特征图(抓极小雄穗) ↓
  • 通道注意力:选“绿色纹理”通道 → 加权 ↓
  • 空间注意力:选“雄穗位置”像素 → 加权 ↓
  • 输出加权后的特征图(雄穗特征被放大,背景被抑制)
3)比【SE】优秀的原因,为什么我们要选他

CBAM融合了 “通道注意力 + 空间注意力”,并构成了一段包含池化、全连接、卷积的代码。对于他提出的论文里也和SE做了详细对比:

  • 在加入【平均池化】+【最大池化】后,对比效果最好
  • 加入【通道注意力】+【空间注意力】
    • 而且先【通道】再【空间】的顺序,对比效果最好
  • 以及热力图效果对比

  • 另外,YOLO官方还提供了拆开的【通道注意力ChannelAttention】、【空间注意力SpatialAttention】
  • 效果:但是我经过亲自实验,【单独用这两个】注意力模块的效果比【整个CBAM的:先“通道” 再“空间”】的效果差!!!
    • 单独用【空间注意力SpatialAttention】的效果
      • 甚至还不如【SE模块】的效果。。。。拉完了
    • 整个用【CBAM】的效果:
      • 夯爆了:明显这次mAP比【单独用空间注意力】、或用【SE模块】得高!!!另外看损失指标也降低了很多(看图表纵坐标数据,别只看曲线)

3、第三代:ECA模块

        SENet中降维会给通道注意力机制带来副作用,并且捕获所有通道之间的关系效率不高,且非必要。ECA注意力机制模块直接在全局平均池化层后使用1*1卷积,去除全连接层。避免了维度锐减,有效的捕获了跨通道交互。

  •  实现步骤如下:
    • (1)将输入特征图经过全局平均池化,特征图从[h,w,c]的矩阵变成[1,1,c]的向量。
    • (2)根据特征图的通道数计算得到自适应的一维卷积核大小kernelsize。
    • (3)将kernelsize用于一维卷积中,得到对于特征图的每个通道的权重。
    • 将归一化权重和原输入特征图逐通道相乘,生成加权后的特征图。
      • 说白了就两个人话:
        • 1、不降维
        • 2、专注【局部】通道特征交互融合
          • 特点:
            • ✔ 更轻量
            • ✔ 参数更少
            • ✔ 效果接近甚至超过CBAM
  • 实际效果:
    • 那么经过我换成【ECA模块】后得到的实验结果,竟然和【CBAM】一模一样,基本数据完全一样,那就说明【ECA】≈【CBAM】基本没区别

4、第四代:Coordinate Attention模块

        现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。Coordinate Attention模块作者期望在引入通道注意力机制的同时引入空间注意力机制,作者提出直接将注意力机制将位置信息嵌入到了通道注意力中

  • 步骤如下:

    • 将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。

      • 假设输入进来的特征层的形状为[C,H,W],在经过宽方向的平均池化后,获得的特征层shape为[C,H,1],此时我们将特征映射到了高维度上;

      • 在经过高方向的平均池化后,获得的特征层shape为[C,1,W],此时我们将特征映射到了宽维度上。

    • 然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C,1,H+W],利用卷积+标准化+激活函数获得特征。

    • 之后再次分开为两个并行阶段,再将宽高分开成为:[C,1,H]和[C,1,W],之后进行转置。获得两个特征层[C,H,]和[C,1,w]。

    • 然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况。乘上原有的特征就是CA注意力机制。​​​​​​​

      • 特点:
        • ✔ 结合位置信息
        • ✔ 更适合检测任务
  • 实际效果:
    • 我个人的实验证明,效果跟【CBAM】和【ECA】没有太大区别
    • 而且略微在小数级别比他两低一点,给到一个人上人吧

5、总结汇总

        这里还有一个C2PSA,yolo加入的一个官方注意力模块,我好奇试了一下,效果就是一坨屎,后来我才知道这玩意只适用在backbone处加,而我们一般加注意力、调整结构就调head就好了,所以不推荐C2PSA,也不过多介绍它

三、代码实操

首先讲一下写一个模块的套路(不管是注意力模块、C3k2、上采样下采样...等),无非就是:

规定模板套路:

  • def __init__(self, ......):初始化这个模块
    • 先继承父类:super().__init__()
    • 然后分别针对不同模块,创建内部每一个细节结构
  • def forward(self, ......):规定前向传播的流程
    • 根据前面__init__定义的结构,返回权重结果w
      • 一般全连接层的流程是【嵌套函数】,比如A—>B—>C,代码就是C ( B( A( w ) ) )

然后yolo的【nn】提供的一些便捷函数:

  • nn.AdaptiveAvgPool2d(要压缩成的维度):平均池化
  • nn.Conv2d():创建卷积
    • 一般就6个参数(输入通道, 输出通道, kernel_size, stride, padding, bias)
      • - 输入通道数(输入特征图有几层“通道”)。
      • - 输出通道数(卷积后得到几层通道)。
      • - 卷积核大小(如 1 表示 1×1,只做通道混合、不改变 H,W;3 或 7 表示 3×3、7×7)。
      • - stride:步长,默认 1(不缩小尺寸)。
      • - padding:四周补零圈数,通常取 (kernel_size-1) // 2 使 H、W 不变。
      • - bias:是否加偏置,True/False
  • nn.Linear():全连接层(线性变换)
    • 对每个样本做「输入维 → 输出维」的矩阵乘法
    • 一般三个函数(输入通道, 输出通道, bias)
      nn.Linear() nn.Conv2d()
      输入形状 (B, in_channels),没有 H、W (B, 2, H, W),有高、宽
      输出形状 (B, out_channels) (B, 1, H, W),还是 H×W
      在做什么 对整根向量做线性组合(全局、无空间) 对每个位置用一个小窗口做线性组合(局部、有空间)
      “2 / 1” 的含义 不适用 2 个通道进 → 1 个通道
  • nn.损失函数名字:调用现成写好的各种损失函数
    • 例如:nn.ReLU()、nn.SiLU().......
  • nn.Sequential:就是一个按顺序执行的模块容器:
    • 输入依次通过第 1 个、第 2 个、第 3 个.....子模块,得到输出。不用自己写 forward。
    • 例如:
      • out = 输入

        out = 第1层(out)

        out = 第2层(out)

        out = 第3层(out)

        return out

      • nn.Sequential = “把好几层按顺序串起来,当一个小网络用”;

1、SEnet模块

1)导入【SE】模块

        由于yolo11官方代码没有写SE模块,不知道是觉得没用废除了还是为啥,这里我们只能自己手写一个,记住SE的结构:

  • 1、全局平均池化:把每个通道C对应的2维数据用平均数压缩成一个1维数字
  • 2、2个卷积夹一个损失函数:fc1 —> 损失函数 —> fc2
    • 注意损失函数可以是ReLU也可以是SiLU,SiLU计算精度更高
    • 另外注意:1×1 卷积等价于“按通道做线性变换”,大家也习惯叫 “全连接”
  • 3、最后接一个Sigmoid()函数

(具体代码我会放在在最后,可以直接复制拿去直接用)

【注意:一定要做!!!!】

  • 自定义模块还需要从【train.py】前面注入:
    • 对应task.py那,你注册的当前train.py路径为全局环境变量“XXX”,然后task.py才能获取XXX的时候读到train.py路径
      import sys
      import os
      from pathlib import Path
      # 必须最先做:把项目根加入 path,后面 from utils.custom_modules 才能找到
      # 前面的 _ 不是语法要求,只是命名习惯,用来表达“内部用 / 不想被当公开接口”的意思
      _你设置的全局环境变量 = Path(__file__).resolve().parent
      if str(_你设置的全局环境变量) not in sys.path:
          sys.path.insert(0, str(_你设置的全局环境变量))
      # 让 DDP 子进程继承到
      os.environ["你设置的全局环境变量"] = str(_你设置的全局环境变量)
  • task.py能读到train.py后,再把我们自定义的模块注入​​​​​​​
    # 让 yaml 里能使用自定义的 SE、ECA、CoordAtt(必须在 import YOLO 之前注入到 tasks)
    import ultralytics.nn.tasks as _ultra_tasks
    from utils.custom_modules import SE, ECA, CoordAtt
    _ultra_tasks.SE = SE
    _ultra_tasks.ECA = ECA
    _ultra_tasks.CoordAtt = CoordAtt

2)yaml网络结构插入

提醒:【注意力模块】不改变通道数!!!!!!!

那么我们应该插到哪?怎么插呢?

  • 1、插入到【捕捉最大尺寸】检测头的【上采样 + concat + C3k2】的【前面】
    • 在这插注意力,就是在 “分流” 给检测头之前,先给 “水流” 做一次 “过滤”和“加压”。如果最大尺寸层的特征被精炼了,那么下游的检测头拿到的东西质量都会变高。
    • 说人话就是:我们要在最大尺寸的图片中用【注意力】捕捉细节!!!
    • 比如:你当前是3个检测层,那就插入到P3的【上采样 + concat + C3k2】的【前面】(就是第14层)
    • 比如:你当前是4个检测层,那就插入到P2的【上采样 + concat + C3k2】的【前面】(就是第17层)
  • 2、SE后面只用填一个参数:[16]
    • 代表的是【压缩比:reduction=16】,不用管为什么,这就是默认的

2、CBAM模块

        因为YOLO官方有相关代码,只不过没有引入,所以我们可以用官方的代码引入、也可以自己写,两种方式使用!

1)导入【CBAM】模块

【直接用官方给的CBAM模块(推荐)】

去到你的ultralytics安装路径下:

  • windows本地路径:
    • 去你pycharm下面点开 “外部库 / site-packages / ultralytics / nn / modules”,
    • 或者在我的电脑的 “你的conda安装目录 / envs / 你当前虚拟环境目录 / Lib / site-packages / ultralytics / nn / modules
  • 如果用autodl服务器的linux系统下路径:“ / root / miniconda3 / envs / 你的虚拟环境目录 / lib / pythonx.x / site-packages / ultralytics / nn / modules

  • 查找【conv.py】有无该模块
    • 找到ultralytics的 nn / models目录后,可以看到一个【conv.py】文件,这个文件就是定义了各个模块,我们可以【Ctrl + F】查找一下有没有CBAM这个模块
  • 如果有CBAM模块,那么我们继续在ultralytics的 nn目录下找到【task.py】文件
    • 然后在代码前面 from ultralytics.nn.modules import (......)里面加上【CBAM】,这样task.py才可以用到ultralytics的 nn / models / conv.py的CBAM模块
    • 然后还是在task.py,【Ctrl + F】查找一下 “elif m is”就能找到下图这么一块地方
      • 然后插入一段代码,一定要补上不然会报错!!!:
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
【自定义手写CBAM模块(想懂原理可以试试)】 

有的ultralytics包可能特殊点,conv.py里没有CBAM这玩意,那就只能我们自己手写了

        回到上面我给的CBAM结构图,只要记住了CBAM的结构,就可以创建一个CBAM模块了(就是背 “前人” 定义好的结构就行):

  • 1、通道注意力:
    • 平均池化 + 最大池化 + 2个全连接层夹1个损失函数
  • 2、空间注意力:
    • 1个卷积 + sigmoid激活函数
  • 最后,在【train.py】代码注入CBAM
    • 因为我们写的自定义CBAM模块不在ultralytics源码里,我们要加载它,就需要先在运行train.py的时候,把它注入task.py里

(具体代码我会放在在最后,可以直接复制拿去直接用)

然后自定义代码模块注入记得做,前面SE模块写过了。。。。。

2)yaml网络结构插入

 提醒:【注意力模块】不改变通道数!!!!!!!

那么我们应该插到哪?怎么插呢?————根【SE】一样!

  • 1、插入到【捕捉最大尺寸】检测头的【上采样 + concat + C3k2】的【前面】
    • 比如:你当前是3个检测层,那就插入到P3的【上采样 + concat + C3k2】的【前面】(就是第14层)
    • SpatialAttention和CBAM同理
  • 2、后面的参数怎么填
    • CBAM后面参数:[? , ?]
      • 第一个代表【通道数】,那么上一行是多少他就照抄就行,因为【注意力模块】不改变通道数!!
      • 第二个代表【卷积核数量】,一般默认7
    • SpatialAttention后面参数[?]
      • 代表【卷积核数量】,一般默认7

3)拓展:单独拆出【通道注意力】和【空间注意力】

        这里我当时咨询chatGPT的时候,我说的是想要单独加强【空间定位】的注意力,他一开始是把【空间注意力:SpatialAttention】拆出来了来训练的,所以我单独研究了一下,CBAM的【通道注意力】和【空间注意力】是可以单独分为两个注意力模块使用的

【使用方法】:

  • 1、yolo官方在【conv.py】写了这两个注意力模块的​​​​​​​
  • 2、那么直接像CBAM一样插入【task.py】就行了,只需要在开头导入这一步就够了
  • 然后如果要手动自己写代码,结构上会有少许区别:
    • 【通道注意力ChannelAttention】:
      • CBAM的通道注意力部分采用标准的论文结构:双路池化 + 两层全连接层 + 相加 + Sigmoid
      • 而YOLO单独拆出来的ChannelAttention是 “轻量版”:单路池化 + 一层卷积 + Sigmoid
    • 【空间注意力SpatialAttention】:
      • 结构上倒是没有太大区别,只不过写法有点区别
  • 但是这里我没有仔细介绍是因为,前面说过,不管是单独使用【通道注意力ChannelAttention】、还是单独使用【空间注意力SpatialAttention】,效果都并没有明显变好!!!!只有像CBAM那样【先“ChannelAttention” 再“SpatialAttention”】的注意力效果才有用!!!!

3、ECA模块

1)导入【ECA】模块

YOLO依旧没有这块代码,需要我们自己手写:

  • 1、先在自定义模块脚本写他的代码:
    • 解析模块有点烦了,各位自己根据图片和注释自己看吧

 2、【注意:一定要做!!!!】

  • 和SE一样,自定义模块还需要从【train.py】前面注入:
    • 对应task.py那,你注册的当前train.py路径为全局环境变量“XXX”,然后task.py才能获取XXX的时候读到train.py路径
      import sys
      import os
      from pathlib import Path
      # 必须最先做:把项目根加入 path,后面 from utils.custom_modules 才能找到
      # 前面的 _ 不是语法要求,只是命名习惯,用来表达“内部用 / 不想被当公开接口”的意思
      _你设置的全局环境变量 = Path(__file__).resolve().parent
      if str(_你设置的全局环境变量) not in sys.path:
          sys.path.insert(0, str(_你设置的全局环境变量))
      # 让 DDP 子进程继承到
      os.environ["你设置的全局环境变量"] = str(_你设置的全局环境变量)
  • task.py能读到train.py后,再把我们自定义的模块注入
    # 让 yaml 里能使用自定义的 SE、ECA、CoordAtt(必须在 import YOLO 之前注入到 tasks)
    import ultralytics.nn.tasks as _ultra_tasks
    from utils.custom_modules import SE, ECA, CoordAtt
    _ultra_tasks.SE = SE
    _ultra_tasks.ECA = ECA
    _ultra_tasks.CoordAtt = CoordAtt

2)yaml网络结构插入

提醒:【注意力模块】不改变通道数!!!!!!!

那么我们应该插到哪?怎么插呢?

  • 还是跟上面模块原理一样
  • 后面的参数怎么填
    • ECA后面参数:[3]
      • 代表【1D局部卷积核数量】,一般默认3

4、Coordinate Attention模块

1)导入【ECA】模块

YOLO依旧没有这块代码,需要我们自己手写:

  • 1、先在自定义模块脚本写他的代码:
    • 解析模块有点烦了,各位自己根据图片和注释自己看吧

 2、【注意:一定要做!!!!】

  • 和SE一样,自定义模块还需要从【train.py】前面注入:
    • 对应task.py那,你注册的当前train.py路径为全局环境变量“XXX”,然后task.py才能获取XXX的时候读到train.py路径
      import sys
      import os
      from pathlib import Path
      # 必须最先做:把项目根加入 path,后面 from utils.custom_modules 才能找到
      # 前面的 _ 不是语法要求,只是命名习惯,用来表达“内部用 / 不想被当公开接口”的意思
      _你设置的全局环境变量 = Path(__file__).resolve().parent
      if str(_你设置的全局环境变量) not in sys.path:
          sys.path.insert(0, str(_你设置的全局环境变量))
      # 让 DDP 子进程继承到
      os.environ["你设置的全局环境变量"] = str(_你设置的全局环境变量)
  • task.py能读到train.py后,再把我们自定义的模块注入
    # 让 yaml 里能使用自定义的 SE、ECA、CoordAtt(必须在 import YOLO 之前注入到 tasks)
    import ultralytics.nn.tasks as _ultra_tasks
    from utils.custom_modules import SE, ECA, CoordAtt
    _ultra_tasks.SE = SE
    _ultra_tasks.ECA = ECA
    _ultra_tasks.CoordAtt = CoordAtt

2)yaml网络结构插入

 提醒:【注意力模块】不改变通道数!!!!!!!

那么我们应该插到哪?怎么插呢?

  • 还是跟上面模块原理一样
  • 后面的参数怎么填
    • ECA后面参数:[16]
      •  代表的是【压缩比:reduction=16】,不用管为什么,这就是默认的

总结

关于注入模块

  • 目前yolo11只有ChannelAttention、SptailAttention、CBAM这三个注意力模块可以直接加入到task.py使用
  • 其他的注意力模块需要我们手动编写,并在train.py开头注入

关于【yaml】部分插入

  • 1、位置你就记着要么插14行、要么17行,就这么简单就完事了
  • 2、参数部分:
    • 要填【通道数】得模块直接抄上一行的、
    • 压缩比例reduction固定16、
    • 1d卷积固定3、
    • 2d卷积固定7就行了

四、完整代码:

1、我的train.py代码:

import sys
import os
from pathlib import Path
# 必须最先做:把项目根加入 path,后面 from utils.custom_modules 才能找到
# 前面的 _ 不是语法要求,只是命名习惯,用来表达“内部用 / 不想被当公开接口”的意思
_MyProject_ROOT = Path(__file__).resolve().parent
if str(_MyProject_ROOT) not in sys.path:
    sys.path.insert(0, str(_MyProject_ROOT))
# 让 DDP 子进程继承到
os.environ["MyProject_ROOT"] = str(_MyProject_ROOT)


import torch
from ultralytics import YOLO

# 让 yaml 里能使用自定义的 SE、ECA、CoordAtt(必须在 import YOLO 之前注入到 tasks)
import ultralytics.nn.tasks as _ultra_tasks
from utils.custom_modules import SE, ECA, CoordAtt
_ultra_tasks.SE = SE
_ultra_tasks.ECA = ECA
_ultra_tasks.CoordAtt = CoordAtt


# ⚠️ 注意:AutoDL是Linux环境,路径要改!把数据上传到 /root/autodl-tmp/ 下(强烈建议放这里,读写快)
# 我们的数据集yaml
DATA_YAML_PATH = "F:\我自己的毕设\YOLO_study\CZM_NewRS\CZM_NewRS.yaml"
# 我们自定义的网络模型yaml
MODEL_YAML = "my_yaml/11/my_yolo11-obb.yaml"
# 我们的load加载的模型权重
MODEL_PATH = "yolo11n-obb.pt"


if __name__ == '__main__':
    # Linux下通常不需要 freeze_support,但留着也不报错
    torch.multiprocessing.freeze_support()

    # 加载模型,建议load换成 s (Small) 或 m (Medium) 模型,AutoDL显卡完全跑得动,精度比n高很多
    model = YOLO(MODEL_YAML).load(MODEL_PATH)

    results = model.train(
        data=DATA_YAML_PATH,
        epochs=100,  # 多跑一点,100轮对于从头练可能刚收敛
        patience=60,  # 60轮不提升就停,省点钱
        imgsz=1240,  # 1280也可以,但1024是标准倍数,训练更稳

        # 【速度拉满配置】
        device=0,  # ❗关键:开启双卡并行训练!
        batch=-1,  # 如果设batch=16,则双卡合计32;设-1自动适配
        workers=6,  # 有40核CPU,大胆给!16-24都可以,数据加载飞快
        # cache="ram",  # ❗有180G内存,直接把数据全读进RAM,速度起飞!
        amp=True,  # 混合精度训练,速度快显存占用少

        # 【增强参数微调:适当即可,拒绝卡通画】
        augment=True,
        hsv_h=0.015,  # 色相微调,保持不变
        hsv_s=0.3,  # ❗降下来!原0.7太高,导致色彩过饱和像动画片
        hsv_v=0.3,  # ❗降下来!原0.4太高,导致对比度太强
        # 【几何增强】
        degrees=5.0,  # 稍微给点旋转,OBB任务很需要
        translate=0.05,  # 稍微平移一下
        scale=0.5,  # 尺度缩放
        perspective=0.0,  # ✅ 必须0!透视太大,图片都变形了,0.0001都算大的
        # 【Mosaic 策略】
        mosaic=0.2,  # 开启马赛克增强
        mixup=0,  # 给一点点混合,不要太多
        copy_paste=0.0,   # ✅ 先关掉(后面再开)
        close_mosaic=10,  # ❗最后10轮关闭Mosaic即可,50轮太早了,浪费了增强效果
        
        # 📉【损失函数权重】提高定位相关权重,不降分类(老师:定位还不够好)
        # 默认:box=7.5, cls=0.5, dfl=1.5。OBB 的 angle 已含在 box loss 里,无单独权重
        box=10.0,   # 提高边界框/定位损失权重(默认 7.5)
        dfl=2.0,     # 提高 DFL 损失权重,利于框边更精细(默认 1.5)
        cls=0.5,     # 分类权重保持不变,不降低
        
        # 📉【优化器】
        optimizer='SGD',  # 这种大数据集,SGD通常比AdamW后期泛化更好
        lr0=0.01,
        lrf=0.01,
        cos_lr=True,  # 余弦退火学习率,训练更丝滑
    )

2、我的yolo11.yaml代码

3层检测目标层情况(仅head部分)

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]                                                # 11
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4                                               # 12
  - [-1, 2, C3k2, [512, False]]                                                               # 13
  - [-1, 1, ECA, [3]]                                                                         # 14
#  - [-1, 1, CBAM, [512, 7]]                                                                   # 14
#  - [-1, 1, SpatialAttention, [7]] # 17 空间注意力,kernel_size=7                                # 14
#  - [-1, 1, SE, [16]]                                                                         # 14
#  - [-1, 1, CoordAtt, [16]]                                                                   # 14
  - [-1, 1, C2PSA, [512]]                                                                     # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]                                               # 15
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3                                              # 16
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)                                            # 17

  - [-1, 1, Conv, [256, 3, 2]]                                                               # 18
  - [[-1, 13], 1, Concat, [1]] # cat head P4                                                 # 19
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)                                          # 20

  - [-1, 1, Conv, [512, 3, 2]]                                                               # 21
  - [[-1, 10], 1, Concat, [1]] # cat head P5                                                 # 22
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)                                           # 23

  - [[16, 19, 22], 1, OBB, [nc, 1]] # Detect(P3, P4, P5)                                     # 24

4层检测目标层情况(仅head部分)

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]                                                # 11
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4                                               # 12
  - [-1, 2, C3k2, [512, False]]                                                               # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]                                                # 14
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3                                               # 15
  - [-1, 3, C3k2, [256, False]] # 16 (P3/8-small)                                             # 16
  - [-1, 1, SE, [16]]                                                                         # 17
#  - [-1, 1, SpatialAttention, [7]] # 17 空间注意力,kernel_size=7                                # 17
#  - [-1, 1, CBAM, [256, 7]]                                                                   # 17
#  - [-1, 1, C2PSA, [256]]                                                                     # 17
#  - [-1, 1, ECA, [3]]                                                                         # 17
  - [-1, 1, CoordAtt, [16]]                                                                   # 17

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]                                                # 18
  - [[-1, 2], 1, Concat, [1]] # cat backbone P2 (layer 2 = C3k2 输出)                          # 19
  - [-1, 3, C3k2, [256, False]] # 20 (P2/4-small)                                             # 20

  - [-1, 1, Conv, [256, 3, 2]]                                                                # 21
  - [[-1, 16], 1, Concat, [1]] # cat head P3                                                  # 22
  - [-1, 2, C3k2, [512, False]] # 23 (P3/8-medium)                                            # 23

  - [-1, 1, Conv, [256, 3, 2]]                                                                # 24
  - [[-1, 13], 1, Concat, [1]] # cat head P4                                                  # 25
  - [-1, 2, C3k2, [512, False]] # 26 (P4/16-medium)                                           # 26

  - [-1, 1, Conv, [512, 3, 2]]                                                                # 27
  - [[-1, 10], 1, Concat, [1]] # cat head P5                                                  # 28
  - [-1, 2, C3k2, [1024, True]] # 29 (P5/32-large)                                            # 29

  - [[16, 20, 23, 26, 29], 1, OBB, [nc, 1]] # Detect(P3,P2,P3-med,P4,P5),插入 SA 后索引+1      # 30

3、我的自定义注意力模块代码(包含所有注意力模块)

"""
自定义注意力模块(SE / ECA / CoordAtt / Channel / Spatial / CBAM)
用于 YOLO 等检测网络的 backbone 或 head,增强通道或空间上的重要特征。

【常用名词速查】
- in_channels:输入通道数(本文件已统一用此名)。c_ / mid_channels:中间压缩后的通道数(如 in_channels//r),不是输入
- // :整除(向下取整),例如 256//16=16,用来做“压缩比”
- self.fc / self.gate:可学习的子网络;gate 一般指“门控”,输出 0~1 的权重
- squeeze(dim):去掉大小为 1 的维度;transpose:交换维度顺序;view:拉成指定形状

【若自己按逻辑写】
- SEnet:先全局平均池化得到每通道一个数 → 两层“全连接”(降维再升维)→ Sigmoid 得到权重 → 乘回原图。
- ChannelAttention:池化后一个 1×1 卷积 + Sigmoid,直接得到每通道权重(无瓶颈)。
- SpatialAttention:沿通道做 mean 和 max 得到两幅 (1,H,W),拼成 (2,H,W) → 卷积得到 (1,H,W) 权重图 → 乘回原图。

【nn.Conv2d 参数速查】
  nn.Conv2d(输入通道, 输出通道, kernel_size, stride=1, padding=0, bias=...)
- 第1个参数:输入通道数(输入特征图有几层“通道”)。
- 第2个参数:输出通道数(卷积后得到几层通道)。
- 第3个参数:卷积核大小(如 1 表示 1×1,只做通道混合、不改变 H,W;3 或 7 表示 3×3、7×7)。
- stride:步长,默认 1(不缩小尺寸)。
- padding:四周补零圈数,通常取 (kernel_size-1)//2 使 H、W 不变。
- bias:是否加偏置,True/False。
"""

# ---------- 下面这些你可能会碰到的写法简要说明 ----------
# a // b     :整除,结果向下取整。如 7//2=3,256//16=16。用来算“压缩后通道数”。
# self.fc    :通常指全连接/线性层(或 1×1 卷积),这里用来学“通道权重”或做降维。
# self.gate  :门控,一般接 Sigmoid,输出 (0,1),表示“保留多少”。
# .squeeze(-1):去掉最后一维且大小为 1 的维度;(B,C,1,1) → (B,C,1)。
# .transpose(1,2):交换第 1、2 维;(B,C,1) → (B,1,C),方便 Conv1d 在 C 维上卷积。
# .view(b,c) :把张量拉成 (b,c) 形状,不改变元素总数;常用于池化后喂给 Linear。
# .permute(0,1,3,2):按给定顺序重排维度,这里相当于把 dim=2 和 dim=3 互换。

import torch
import torch.nn as nn


# ============================ 1. SE(Squeeze-and-Excitation)============================
# 思路:对每个通道做“全局池化 → 两层全连接 → 得到该通道的权重”再乘回原特征
class SE(nn.Module):
    """Squeeze-and-Excitation:通道注意力,为每个通道学一个 0~1 的权重。"""
    # r 是“压缩比”(reduction ratio),用来算中间层的通道数
    def __init__(self, in_channels, r=16):
        super().__init__()
        # in_channels // r:中间层通道数(压缩比)。
        # max(1, ...)  第一个参数1,是为了防止r过大时mid_channels变成 0,至少也得是1,参数里取最大
        mid_channels = max(1, in_channels // r)

        self.pool = nn.AdaptiveAvgPool2d(1)   # 全局平均池化:(B,C,H,W)→(B,C,1,1),每个通道一个数
        self.fc1 = nn.Conv2d(in_channels, mid_channels, 1, bias=True)   # 降维 in_channels → mid_channels
        self.act = nn.SiLU() # SiLU或ReLU都可以
        self.fc2 = nn.Conv2d(mid_channels, in_channels, 1, bias=True)   # 升回 mid_channels → in_channels
        self.gate = nn.Sigmoid()   # 激活函数:把输出压到 (0,1)

    def forward(self, x):
        w = self.pool(x)           # (B,C,H,W)→(B,C,1,1)
        w = self.fc2(self.act(self.fc1(w)))   # 瓶颈结构:C → mid_channels → C
        return x * self.gate(w)    # 逐通道缩放,重要通道权重大


# ======================= 2.1 ChannelAttention(通道注意力,轻量版)=======================
# 思路:全局池化后用一个 1×1 卷积 + Sigmoid,直接得到每通道一个权重(无瓶颈)
class ChannelAttention(nn.Module):
    """轻量通道注意力:池化后 1×1 卷积 + Sigmoid,为每通道学一个标量权重。"""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)   # (B,C,H,W)→(B,C,1,1)
        # 1×1 卷积:输入输出都是 in_channels,学通道间关系
        self.fc = nn.Conv2d(in_channels, in_channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()   # 输出 0~1 作为门控

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


# ======================= 2.2 SpatialAttention(空间注意力)=======================
# 思路:先沿通道做“平均”和“最大”得到两幅图 (B,2,H,W),再卷积得到 (B,1,H,W) 的空间权重图
class SpatialAttention(nn.Module):
    """空间注意力:为每个像素位置学一个 0~1 的权重(哪里重要哪里权重大)。"""

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1   # 保持 H,W 不变
        # 输入 2 通道(avg+max),输出 1 通道(该位置的权重)
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # torch.mean(x, 1, keepdim=True):沿通道维求平均 → (B,1,H,W)
        # torch.max(x, 1, keepdim=True)[0]:沿通道维求最大 → (B,1,H,W)
        # cat(..., dim=1):拼成 (B,2,H,W),再卷积得 (B,1,H,W),乘回原图
        return x * self.act(
            self.cv1(
                torch.cat(
                    [torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]],
                    dim=1,
                )
            )
        )


# ============================ 2.3 CBAM(通道+空间串联)============================
# 思路:先做通道注意力再做空间注意力,两段都用“avg+max → 小网络 → sigmoid”的形式
class CBAM(nn.Module):
    """CBAM = Channel + Spatial:先通道再空间,两段注意力串联。"""

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        mid_channels = max(1, in_channels // reduction)

        # 通道注意力(轻量版)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)    # 平均池化,尺寸:[b,c,1,1]
        self.max_pool = nn.AdaptiveMaxPool2d(1)    # 最大池化,尺寸:[b,c,1,1]
        self.fc = nn.Sequential(
            nn.Linear(in_channels, mid_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(mid_channels, in_channels, bias=False),
        )

        # 空间注意力(轻量版)
        self.spatial = nn.Sequential(              # 2层卷积、后面跟一个sigmoid激活函数
            nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False),
            nn.Sigmoid()
        )
        self.sigmoid = nn.Sigmoid()

    # 开始【前向传播】学习
    def forward(self, x):
        b, c, _, _ = x.shape
        # 通道注意力(获取平均池化值、最大池化值、求sigmoid函数值)
        # view(b, c):把 (B,C,1,1) 拉成 (B,C) 才能喂给 Linear
        avg_out = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        max_out = self.fc(self.max_pool(x).view(b, c)).view(b, c, 1, 1)
        x = x * self.sigmoid(avg_out + max_out) # 【求和】并用【sigmoid函数激活】

        # 空间注意力(获取平均池化值、最大池化值、求sigmoid函数值)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = x * self.sigmoid(self.spatial(torch.cat([avg_out, max_out], dim=1))) # 最后一次sigmoid激活
        return x


# ============================ 3. ECA(Efficient Channel Attention)============================
# 思路:不用全连接,用 1D 卷积在“通道维”上做局部交互,参数量小;需把 (B,C,1,1) 变成 (B,1,C) 才能用 Conv1d
class ECA(nn.Module):
    """ECA:用一维卷积在通道维做局部建模,替代 SE 的全连接,更省参数。"""

    def __init__(self, in_channels, k=3):
        super().__init__()
        k = int(k)
        if k % 2 == 0:
            k += 1   # Conv1d 用奇数 kernel 方便两边 padding 相同
        self.pool = nn.AdaptiveAvgPool2d(1)   # (B,C,H,W)→(B,C,1,1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.gate = nn.Sigmoid()

    def forward(self, x):
        y = self.pool(x)   # (B, C, 1, 1)
        # squeeze(-1):去掉最后一维 (B,C,1,1)→(B,C,1);transpose(1,2):1和2维互换→(B,1,C)
        # 这样 C 这一维变成 Conv1d 的“长度”,1 是通道维,才能用 nn.Conv1d(1,1,k)
        y = y.squeeze(-1).transpose(1, 2)   # (B, 1, C)
        y = self.conv(y)   # (B, 1, C)
        # 再变回 (B,C,1,1):transpose(1,2)→(B,C,1),unsqueeze(-1)→(B,C,1,1)
        y = self.gate(y).transpose(1, 2).unsqueeze(-1)
        return x * y


# ============================ 4. CoordAtt(Coordinate Attention)============================
# 思路:沿 H、W 两个方向分别池化,得到“垂直方向”和“水平方向”的编码,再生成两路注意力图 a_h、a_w,乘回原图
class CoordAtt(nn.Module):
    """坐标注意力:沿高、宽方向分别池化并生成方向感知的注意力,适合长条/细长目标。"""

    def __init__(self, in_channels, r=16):
        super().__init__()
        mid_channels = max(8, in_channels // r)
        # (None, 1):高方向保留、宽压成 1 → (B,C,H,1);(1, None):宽保留、高压成 1 → (B,C,1,W)
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.act = nn.SiLU()

        self.conv_h = nn.Conv2d(mid_channels, in_channels, 1, bias=False)   # 生成“高方向”权重
        self.conv_w = nn.Conv2d(mid_channels, in_channels, 1, bias=False)   # 生成“宽方向”权重
        self.gate = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.shape
        x_h = self.pool_h(x)   # (B, C, H, 1)
        # permute(0,1,3,2):把 (B,C,1,W) 变成 (B,C,W,1),方便后面和 x_h 在 dim=2 上 cat
        x_w = self.pool_w(x).permute(0, 1, 3, 2)   # (B, C, W, 1)

        y = torch.cat([x_h, x_w], dim=2)   # (B, C, H+W, 1)
        y = self.act(self.bn1(self.conv1(y)))

        y_h, y_w = torch.split(y, [h, w], dim=2)   # 再拆回高、宽两段
        y_w = y_w.permute(0, 1, 3, 2)   # (B, mid_channels, 1, W)

        a_h = self.gate(self.conv_h(y_h))   # (B, C, H, 1)
        a_w = self.gate(self.conv_w(y_w))   # (B, C, 1, W)
        return x * a_h * a_w   # 广播相乘,得到带高、宽注意力的特征

4、我的task.py代码

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import contextlib
import pickle
import re
import types
from copy import deepcopy
from pathlib import Path

import torch
import torch.nn as nn

# 自定义注意力 SE/ECA/CoordAtt:需让 Python 找到你项目下的 utils(项目根必须在 sys.path 里)
# 方式1:train.py 最开头 sys.path.insert(0, 项目根) 再 import ultralytics
# 方式2:设置环境变量 CZM_NEWRS_ROOT=项目根(如 F:\我自己的毕设\YOLO_study\CZM_NewRS),便于 DDP 子进程也能找到
import sys as _sys
import os as _os
_MyProject_ROOT = _os.environ.get("MyProject_ROOT")
if _MyProject_ROOT and _MyProject_ROOT not in _sys.path:
    _sys.path.insert(0, _MyProject_ROOT)
try:
    from utils.custom_modules import SE, ECA, CoordAtt
    _CUSTOM_ATTN = {SE, ECA, CoordAtt}
except ImportError:
    _CUSTOM_ATTN = set()

from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import (
    AIFI,
    C1,
    C2,
    C2PSA,
    C3,
    C3TR,
    ELAN1,
    OBB,
    PSA,
    SPP,
    SPPELAN,
    SpatialAttention,
    SPPF,
    A2C2f,
    AConv,
    ADown,
    Bottleneck,
    BottleneckCSP,
    ChannelAttention,
    C2f,
    C2fAttn,
    C2fCIB,
    C2fPSA,
    C3Ghost,
    C3k2,
    C3x,
    CBAM,
    CBFuse,
    CBLinear,
    Classify,
    Concat,
    Conv,
    Conv2,
    ConvTranspose,
    Detect,
    DWConv,
    DWConvTranspose2d,
    Focus,
    GhostBottleneck,
    GhostConv,
    HGBlock,
    HGStem,
    ImagePoolingAttn,
    Index,
    LRPCHead,
    Pose,
    RepC3,
    RepConv,
    RepNCSPELAN4,
    RepVGGDW,
    ResNetLayer,
    RTDETRDecoder,
    SCDown,
    Segment,
    TorchVision,
    WorldDetect,
    YOLOEDetect,
    YOLOESegment,
    v10Detect,
)
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, YAML, colorstr, emojis
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import (
    E2EDetectLoss,
    v8ClassificationLoss,
    v8DetectionLoss,
    v8OBBLoss,
    v8PoseLoss,
    v8SegmentationLoss,
)
from ultralytics.utils.ops import make_divisible
from ultralytics.utils.patches import torch_load
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
    fuse_conv_and_bn,
    fuse_deconv_and_bn,
    initialize_weights,
    intersect_dicts,
    model_info,
    scale_img,
    smart_inference_mode,
    time_sync,
)


class BaseModel(torch.nn.Module):
    """Base class for all YOLO models in the Ultralytics family.

    This class provides common functionality for YOLO models including forward pass handling, model fusion, information
    display, and weight loading capabilities.

    Attributes:
        model (torch.nn.Module): The neural network model.
        save (list): List of layer indices to save outputs from.
        stride (torch.Tensor): Model stride values.

    Methods:
        forward: Perform forward pass for training or inference.
        predict: Perform inference on input tensor.
        fuse: Fuse Conv2d and BatchNorm2d layers for optimization.
        info: Print model information.
        load: Load weights into the model.
        loss: Compute loss for training.

    Examples:
        Create a BaseModel instance
        >>> model = BaseModel()
        >>> model.info()  # Display model information
    """

    def forward(self, x, *args, **kwargs):
        """Perform forward pass of the model for either training or inference.

        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

        Args:
            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            augment (bool): Augment image during prediction.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize, embed)

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt, embeddings = [], [], []  # outputs
        embed = frozenset(embed) if embed is not None else {-1}
        max_idx = max(embed)
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference."""
        LOGGER.warning(
            f"{self.__class__.__name__} does not support 'augment=True' prediction. "
            f"Reverting to single-scale prediction."
        )
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """Profile the computation time and FLOPs of a single layer of the model on a given input.

        Args:
            m (torch.nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to store the computation time of the layer.
        """
        try:
            import thop
        except ImportError:
            thop = None  # conda support without 'ultralytics-thop' installed

        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")

    def fuse(self, verbose=True):
        """Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
        efficiency.

        Returns:
            (torch.nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepVGGDW):
                    m.fuse()
                    m.forward = m.forward_fuse
                if isinstance(m, v10Detect):
                    m.fuse()  # remove one2many head
            self.info(verbose=verbose)

        return self

    def is_fused(self, thresh=10):
        """Check if the model has less than a certain threshold of BatchNorm layers.

        Args:
            thresh (int, optional): The threshold number of BatchNorm layers.

        Returns:
            (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
        """
        bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, detailed=False, verbose=True, imgsz=640):
        """Print model information.

        Args:
            detailed (bool): If True, prints out detailed information about the model.
            verbose (bool): If True, prints out the model information.
            imgsz (int): The size of the image that the model will be trained on.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """Apply a function to all tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): The function to apply to the model.

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(
            m, Detect
        ):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect, YOLOESegment
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """Load weights into the model.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress.
        """
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        updated_csd = intersect_dicts(csd, self.state_dict())  # intersect
        self.load_state_dict(updated_csd, strict=False)  # load
        len_updated_csd = len(updated_csd)
        first_conv = "model.0.conv.weight"  # hard-coded to yolo models for now
        # mostly used to boost multi-channel training
        state_dict = self.state_dict()
        if first_conv not in updated_csd and first_conv in state_dict:
            c1, c2, h, w = state_dict[first_conv].shape
            cc1, cc2, ch, cw = csd[first_conv].shape
            if ch == h and cw == w:
                c1, c2 = min(c1, cc1), min(c2, cc2)
                state_dict[first_conv][:c1, :c2] = csd[first_conv][:c1, :c2]
                len_updated_csd += 1
        if verbose:
            LOGGER.info(f"Transferred {len_updated_csd}/{len(self.model.state_dict())} items from pretrained weights")

    def loss(self, batch, preds=None):
        """Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"])
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion for the BaseModel."""
        raise NotImplementedError("compute_loss() needs to be implemented by task heads")


class DetectionModel(BaseModel):
    """YOLO detection model.

    This class implements the YOLO detection architecture, handling model initialization, forward pass, augmented
    inference, and loss computation for object detection tasks.

    Attributes:
        yaml (dict): Model configuration dictionary.
        model (torch.nn.Sequential): The neural network model.
        save (list): List of layer indices to save outputs from.
        names (dict): Class names dictionary.
        inplace (bool): Whether to use inplace operations.
        end2end (bool): Whether the model uses end-to-end detection.
        stride (torch.Tensor): Model stride values.

    Methods:
        __init__: Initialize the YOLO detection model.
        _predict_augment: Perform augmented inference.
        _descale_pred: De-scale predictions following augmented inference.
        _clip_augmented: Clip YOLO augmented inference tails.
        init_criterion: Initialize the loss criterion.

    Examples:
        Initialize a detection model
        >>> model = DetectionModel("yolo11n.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):
        """Initialize the YOLO detection model with the given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
        if self.yaml["backbone"][0][2] == "Silence":
            LOGGER.warning(
                "YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
                "Please delete local *.pt file and re-download the latest model checkpoint."
            )
            self.yaml["backbone"][0][2] = "nn.Identity"

        # Define model
        self.yaml["channels"] = ch  # save channels
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        self.inplace = self.yaml.get("inplace", True)
        self.end2end = getattr(self.model[-1], "end2end", False)

        # Build strides
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, YOLOEDetect, YOLOESegment
            s = 256  # 2x min stride
            m.inplace = self.inplace

            def _forward(x):
                """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
                if self.end2end:
                    return self.forward(x)["one2many"]
                return self.forward(x)[0] if isinstance(m, (Segment, YOLOESegment, Pose, OBB)) else self.forward(x)

            self.model.eval()  # Avoid changing batch statistics until training begins
            m.training = True  # Setting it to True to properly return strides
            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
            self.stride = m.stride
            self.model.train()  # Set model back to training(default) mode
            m.bias_init()  # only run once
        else:
            self.stride = torch.Tensor([32])  # default stride, e.g., RTDETR

        # Init weights, biases
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info("")

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference and train outputs.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            (torch.Tensor): Augmented inference output.
        """
        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
            LOGGER.warning("Model does not support 'augment=True', reverting to single-scale prediction.")
            return self._predict_once(x)
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = super().predict(xi)[0]  # forward
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, -1), None  # augmented inference, train

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """De-scale predictions following augmented inference (inverse operation).

        Args:
            p (torch.Tensor): Predictions tensor.
            flips (int): Flip type (0=none, 2=ud, 3=lr).
            scale (float): Scale factor.
            img_size (tuple): Original image size (height, width).
            dim (int): Dimension to split at.

        Returns:
            (torch.Tensor): De-scaled predictions.
        """
        p[:, :4] /= scale  # de-scale
        x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:
            y = img_size[0] - y  # de-flip ud
        elif flips == 3:
            x = img_size[1] - x  # de-flip lr
        return torch.cat((x, y, wh, cls), dim)

    def _clip_augmented(self, y):
        """Clip YOLO augmented inference tails.

        Args:
            y (list[torch.Tensor]): List of detection tensors.

        Returns:
            (list[torch.Tensor]): Clipped detection tensors.
        """
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4**x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices
        y[0] = y[0][..., :-i]  # large
        i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][..., i:]  # small
        return y

    def init_criterion(self):
        """Initialize the loss criterion for the DetectionModel."""
        return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)


class OBBModel(DetectionModel):
    """YOLO Oriented Bounding Box (OBB) model.

    This class extends DetectionModel to handle oriented bounding box detection tasks, providing specialized loss
    computation for rotated object detection.

    Methods:
        __init__: Initialize YOLO OBB model.
        init_criterion: Initialize the loss criterion for OBB detection.

    Examples:
        Initialize an OBB model
        >>> model = OBBModel("yolo11n-obb.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLO OBB model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the model."""
        return v8OBBLoss(self)


class SegmentationModel(DetectionModel):
    """YOLO segmentation model.

    This class extends DetectionModel to handle instance segmentation tasks, providing specialized loss computation for
    pixel-level object detection and segmentation.

    Methods:
        __init__: Initialize YOLO segmentation model.
        init_criterion: Initialize the loss criterion for segmentation.

    Examples:
        Initialize a segmentation model
        >>> model = SegmentationModel("yolo11n-seg.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
        """Initialize Ultralytics YOLO segmentation model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the SegmentationModel."""
        return v8SegmentationLoss(self)


class PoseModel(DetectionModel):
    """YOLO pose model.

    This class extends DetectionModel to handle human pose estimation tasks, providing specialized loss computation for
    keypoint detection and pose estimation.

    Attributes:
        kpt_shape (tuple): Shape of keypoints data (num_keypoints, num_dimensions).

    Methods:
        __init__: Initialize YOLO pose model.
        init_criterion: Initialize the loss criterion for pose estimation.

    Examples:
        Initialize a pose model
        >>> model = PoseModel("yolo11n-pose.yaml", ch=3, nc=1, data_kpt_shape=(17, 3))
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """Initialize Ultralytics YOLO Pose model.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            data_kpt_shape (tuple): Shape of keypoints data.
            verbose (bool): Whether to display model information.
        """
        if not isinstance(cfg, dict):
            cfg = yaml_model_load(cfg)  # load model YAML
        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg["kpt_shape"] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the PoseModel."""
        return v8PoseLoss(self)


class ClassificationModel(BaseModel):
    """YOLO classification model.

    This class implements the YOLO classification architecture for image classification tasks, providing model
    initialization, configuration, and output reshaping capabilities.

    Attributes:
        yaml (dict): Model configuration dictionary.
        model (torch.nn.Sequential): The neural network model.
        stride (torch.Tensor): Model stride values.
        names (dict): Class names dictionary.

    Methods:
        __init__: Initialize ClassificationModel.
        _from_yaml: Set model configurations and define architecture.
        reshape_outputs: Update model to specified class count.
        init_criterion: Initialize the loss criterion.

    Examples:
        Initialize a classification model
        >>> model = ClassificationModel("yolo11n-cls.yaml", ch=3, nc=1000)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
        """Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__()
        self._from_yaml(cfg, ch, nc, verbose)

    def _from_yaml(self, cfg, ch, nc, verbose):
        """Set Ultralytics YOLO model configurations and define the model architecture.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

        # Define model
        ch = self.yaml["channels"] = self.yaml.get("channels", ch)  # input channels
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        elif not nc and not self.yaml.get("nc", None):
            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.stride = torch.Tensor([1])  # no stride constraints
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """Update a TorchVision classification model to class count 'n' if required.

        Args:
            model (torch.nn.Module): Model to update.
            nc (int): New number of classes.
        """
        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
                m.linear = torch.nn.Linear(m.linear.in_features, nc)
        elif isinstance(m, torch.nn.Linear):  # ResNet, EfficientNet
            if m.out_features != nc:
                setattr(model, name, torch.nn.Linear(m.in_features, nc))
        elif isinstance(m, torch.nn.Sequential):
            types = [type(x) for x in m]
            if torch.nn.Linear in types:
                i = len(types) - 1 - types[::-1].index(torch.nn.Linear)  # last torch.nn.Linear index
                if m[i].out_features != nc:
                    m[i] = torch.nn.Linear(m[i].in_features, nc)
            elif torch.nn.Conv2d in types:
                i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d)  # last torch.nn.Conv2d index
                if m[i].out_channels != nc:
                    m[i] = torch.nn.Conv2d(
                        m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None
                    )

    def init_criterion(self):
        """Initialize the loss criterion for the ClassificationModel."""
        return v8ClassificationLoss()


class RTDETRDetectionModel(DetectionModel):
    """RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

    This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
    the training and inference processes. RTDETR is an object detection and tracking model that extends from the
    DetectionModel base class.

    Attributes:
        nc (int): Number of classes for detection.
        criterion (RTDETRDetectionLoss): Loss function for training.

    Methods:
        __init__: Initialize the RTDETRDetectionModel.
        init_criterion: Initialize the loss criterion.
        loss: Compute loss for training.
        predict: Perform forward pass through the model.

    Examples:
        Initialize an RTDETR model
        >>> model = RTDETRDetectionModel("rtdetr-l.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
        """Initialize the RTDETRDetectionModel.

        Args:
            cfg (str | dict): Configuration file name or path.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Print additional information during initialization.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def _apply(self, fn):
        """Apply a function to all tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): The function to apply to the model.

        Returns:
            (RTDETRDetectionModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]
        m.anchors = fn(m.anchors)
        m.valid_mask = fn(m.valid_mask)
        return self

    def init_criterion(self):
        """Initialize the loss criterion for the RTDETRDetectionModel."""
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

    def loss(self, batch, preds=None):
        """Compute the loss for the given batch of data.

        Args:
            batch (dict): Dictionary containing image and label data.
            preds (torch.Tensor, optional): Precomputed model predictions.

        Returns:
            loss_sum (torch.Tensor): Total loss value.
            loss_items (torch.Tensor): Main three losses in a tensor.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        img = batch["img"]
        # NOTE: preprocess gt_bbox and gt_labels to list.
        bs = img.shape[0]
        batch_idx = batch["batch_idx"]
        gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
        targets = {
            "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
            "bboxes": batch["bboxes"].to(device=img.device),
            "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
            "gt_groups": gt_groups,
        }

        if preds is None:
            preds = self.predict(img, batch=targets)
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
        if dn_meta is None:
            dn_bboxes, dn_scores = None, None
        else:
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion(
            (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
        )
        # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
        return sum(loss.values()), torch.as_tensor(
            [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
        )

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            batch (dict, optional): Ground truth data for evaluation.
            augment (bool): If True, perform data augmentation during inference.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
        embed = frozenset(embed) if embed is not None else {-1}
        max_idx = max(embed)
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x


class WorldModel(DetectionModel):
    """YOLOv8 World Model.

    This class implements the YOLOv8 World model for open-vocabulary object detection, supporting text-based class
    specification and CLIP model integration for zero-shot detection capabilities.

    Attributes:
        txt_feats (torch.Tensor): Text feature embeddings for classes.
        clip_model (torch.nn.Module): CLIP model for text encoding.

    Methods:
        __init__: Initialize YOLOv8 world model.
        set_classes: Set classes for offline inference.
        get_text_pe: Get text positional embeddings.
        predict: Perform forward pass with text features.
        loss: Compute loss with text features.

    Examples:
        Initialize a world model
        >>> model = WorldModel("yolov8s-world.yaml", ch=3, nc=80)
        >>> model.set_classes(["person", "car", "bicycle"])
        >>> results = model.predict(image_tensor)
    """

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOv8 world model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
        self.clip_model = None  # CLIP model placeholder
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def set_classes(self, text, batch=80, cache_clip_model=True):
        """Set classes in advance so that model could do offline-inference without clip model.

        Args:
            text (list[str]): List of class names.
            batch (int): Batch size for processing text tokens.
            cache_clip_model (bool): Whether to cache the CLIP model.
        """
        self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
        self.model[-1].nc = len(text)

    def get_text_pe(self, text, batch=80, cache_clip_model=True):
        """Get text positional embeddings for offline inference without CLIP model.

        Args:
            text (list[str]): List of class names.
            batch (int): Batch size for processing text tokens.
            cache_clip_model (bool): Whether to cache the CLIP model.

        Returns:
            (torch.Tensor): Text positional embeddings.
        """
        from ultralytics.nn.text_model import build_text_model

        device = next(self.model.parameters()).device
        if not getattr(self, "clip_model", None) and cache_clip_model:
            # For backwards compatibility of models lacking clip_model attribute
            self.clip_model = build_text_model("clip:ViT-B/32", device=device)
        model = self.clip_model if cache_clip_model else build_text_model("clip:ViT-B/32", device=device)
        text_token = model.tokenize(text)
        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
        return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
        """Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            txt_feats (torch.Tensor, optional): The text features, use it if it's given.
            augment (bool): If True, perform data augmentation during inference.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
        if txt_feats.shape[0] != x.shape[0] or self.model[-1].export:
            txt_feats = txt_feats.expand(x.shape[0], -1, -1)
        ori_txt_feats = txt_feats.clone()
        y, dt, embeddings = [], [], []  # outputs
        embed = frozenset(embed) if embed is not None else {-1}
        max_idx = max(embed)
        for m in self.model:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if isinstance(m, C2fAttn):
                x = m(x, txt_feats)
            elif isinstance(m, WorldDetect):
                x = m(x, ori_txt_feats)
            elif isinstance(m, ImagePoolingAttn):
                txt_feats = m(x, txt_feats)
            else:
                x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def loss(self, batch, preds=None):
        """Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
        return self.criterion(preds, batch)


class YOLOEModel(DetectionModel):
    """YOLOE detection model.

    This class implements the YOLOE architecture for efficient object detection with text and visual prompts, supporting
    both prompt-based and prompt-free inference modes.

    Attributes:
        pe (torch.Tensor): Prompt embeddings for classes.
        clip_model (torch.nn.Module): CLIP model for text encoding.

    Methods:
        __init__: Initialize YOLOE model.
        get_text_pe: Get text positional embeddings.
        get_visual_pe: Get visual embeddings.
        set_vocab: Set vocabulary for prompt-free model.
        get_vocab: Get fused vocabulary layer.
        set_classes: Set classes for offline inference.
        get_cls_pe: Get class positional embeddings.
        predict: Perform forward pass with prompts.
        loss: Compute loss with prompts.

    Examples:
        Initialize a YOLOE model
        >>> model = YOLOEModel("yoloe-v8s.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor, tpe=text_embeddings)
    """

    def __init__(self, cfg="yoloe-v8s.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOE model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    @smart_inference_mode()
    def get_text_pe(self, text, batch=80, cache_clip_model=False, without_reprta=False):
        """Get text positional embeddings for offline inference without CLIP model.

        Args:
            text (list[str]): List of class names.
            batch (int): Batch size for processing text tokens.
            cache_clip_model (bool): Whether to cache the CLIP model.
            without_reprta (bool): Whether to return text embeddings without reprta module processing.

        Returns:
            (torch.Tensor): Text positional embeddings.
        """
        from ultralytics.nn.text_model import build_text_model

        device = next(self.model.parameters()).device
        if not getattr(self, "clip_model", None) and cache_clip_model:
            # For backwards compatibility of models lacking clip_model attribute
            self.clip_model = build_text_model("mobileclip:blt", device=device)

        model = self.clip_model if cache_clip_model else build_text_model("mobileclip:blt", device=device)
        text_token = model.tokenize(text)
        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
        txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
        if without_reprta:
            return txt_feats

        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)
        return head.get_tpe(txt_feats)  # run auxiliary text head

    @smart_inference_mode()
    def get_visual_pe(self, img, visual):
        """Get visual embeddings.

        Args:
            img (torch.Tensor): Input image tensor.
            visual (torch.Tensor): Visual features.

        Returns:
            (torch.Tensor): Visual positional embeddings.
        """
        return self(img, vpe=visual, return_vpe=True)

    def set_vocab(self, vocab, names):
        """Set vocabulary for the prompt-free model.

        Args:
            vocab (nn.ModuleList): List of vocabulary items.
            names (list[str]): List of class names.
        """
        assert not self.training
        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)

        # Cache anchors for head
        device = next(self.parameters()).device
        self(torch.empty(1, 3, self.args["imgsz"], self.args["imgsz"]).to(device))  # warmup

        # re-parameterization for prompt-free model
        self.model[-1].lrpc = nn.ModuleList(
            LRPCHead(cls, pf[-1], loc[-1], enabled=i != 2)
            for i, (cls, pf, loc) in enumerate(zip(vocab, head.cv3, head.cv2))
        )
        for loc_head, cls_head in zip(head.cv2, head.cv3):
            assert isinstance(loc_head, nn.Sequential)
            assert isinstance(cls_head, nn.Sequential)
            del loc_head[-1]
            del cls_head[-1]
        self.model[-1].nc = len(names)
        self.names = check_class_names(names)

    def get_vocab(self, names):
        """Get fused vocabulary layer from the model.

        Args:
            names (list): List of class names.

        Returns:
            (nn.ModuleList): List of vocabulary modules.
        """
        assert not self.training
        head = self.model[-1]
        assert isinstance(head, YOLOEDetect)
        assert not head.is_fused

        tpe = self.get_text_pe(names)
        self.set_classes(names, tpe)
        device = next(self.model.parameters()).device
        head.fuse(self.pe.to(device))  # fuse prompt embeddings to classify head

        vocab = nn.ModuleList()
        for cls_head in head.cv3:
            assert isinstance(cls_head, nn.Sequential)
            vocab.append(cls_head[-1])
        return vocab

    def set_classes(self, names, embeddings):
        """Set classes in advance so that model could do offline-inference without clip model.

        Args:
            names (list[str]): List of class names.
            embeddings (torch.Tensor): Embeddings tensor.
        """
        assert not hasattr(self.model[-1], "lrpc"), (
            "Prompt-free model does not support setting classes. Please try with Text/Visual prompt models."
        )
        assert embeddings.ndim == 3
        self.pe = embeddings
        self.model[-1].nc = len(names)
        self.names = check_class_names(names)

    def get_cls_pe(self, tpe, vpe):
        """Get class positional embeddings.

        Args:
            tpe (torch.Tensor, optional): Text positional embeddings.
            vpe (torch.Tensor, optional): Visual positional embeddings.

        Returns:
            (torch.Tensor): Class positional embeddings.
        """
        all_pe = []
        if tpe is not None:
            assert tpe.ndim == 3
            all_pe.append(tpe)
        if vpe is not None:
            assert vpe.ndim == 3
            all_pe.append(vpe)
        if not all_pe:
            all_pe.append(getattr(self, "pe", torch.zeros(1, 80, 512)))
        return torch.cat(all_pe, dim=1)

    def predict(
        self, x, profile=False, visualize=False, tpe=None, augment=False, embed=None, vpe=None, return_vpe=False
    ):
        """Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool): If True, profile the computation time for each layer.
            visualize (bool): If True, save feature maps for visualization.
            tpe (torch.Tensor, optional): Text positional embeddings.
            augment (bool): If True, perform data augmentation during inference.
            embed (list, optional): A list of feature vectors/embeddings to return.
            vpe (torch.Tensor, optional): Visual positional embeddings.
            return_vpe (bool): If True, return visual positional embeddings.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
        b = x.shape[0]
        embed = frozenset(embed) if embed is not None else {-1}
        max_idx = max(embed)
        for m in self.model:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if isinstance(m, YOLOEDetect):
                vpe = m.get_vpe(x, vpe) if vpe is not None else None
                if return_vpe:
                    assert vpe is not None
                    assert not self.training
                    return vpe
                cls_pe = self.get_cls_pe(m.get_tpe(tpe), vpe).to(device=x[0].device, dtype=x[0].dtype)
                if cls_pe.shape[0] != b or m.export:
                    cls_pe = cls_pe.expand(b, -1, -1)
                x = m(x, cls_pe)
            else:
                x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def loss(self, batch, preds=None):
        """Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if not hasattr(self, "criterion"):
            from ultralytics.utils.loss import TVPDetectLoss

            visual_prompt = batch.get("visuals", None) is not None  # TODO
            self.criterion = TVPDetectLoss(self) if visual_prompt else self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
        return self.criterion(preds, batch)


class YOLOESegModel(YOLOEModel, SegmentationModel):
    """YOLOE segmentation model.

    This class extends YOLOEModel to handle instance segmentation tasks with text and visual prompts, providing
    specialized loss computation for pixel-level object detection and segmentation.

    Methods:
        __init__: Initialize YOLOE segmentation model.
        loss: Compute loss with prompts for segmentation.

    Examples:
        Initialize a YOLOE segmentation model
        >>> model = YOLOESegModel("yoloe-v8s-seg.yaml", ch=3, nc=80)
        >>> results = model.predict(image_tensor, tpe=text_embeddings)
    """

    def __init__(self, cfg="yoloe-v8s-seg.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOE segmentation model with given config and parameters.

        Args:
            cfg (str | dict): Model configuration file path or dictionary.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes.
            verbose (bool): Whether to display model information.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def loss(self, batch, preds=None):
        """Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if not hasattr(self, "criterion"):
            from ultralytics.utils.loss import TVPSegmentLoss

            visual_prompt = batch.get("visuals", None) is not None  # TODO
            self.criterion = TVPSegmentLoss(self) if visual_prompt else self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], tpe=batch.get("txt_feats", None), vpe=batch.get("visuals", None))
        return self.criterion(preds, batch)


class Ensemble(torch.nn.ModuleList):
    """Ensemble of models.

    This class allows combining multiple YOLO models into an ensemble for improved performance through model averaging
    or other ensemble techniques.

    Methods:
        __init__: Initialize an ensemble of models.
        forward: Generate predictions from all models in the ensemble.

    Examples:
        Create an ensemble of models
        >>> ensemble = Ensemble()
        >>> ensemble.append(model1)
        >>> ensemble.append(model2)
        >>> results = ensemble(image_tensor)
    """

    def __init__(self):
        """Initialize an ensemble of models."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Generate the YOLO network's final layer.

        Args:
            x (torch.Tensor): Input tensor.
            augment (bool): Whether to augment the input.
            profile (bool): Whether to profile the model.
            visualize (bool): Whether to visualize the features.

        Returns:
            y (torch.Tensor): Concatenated predictions from all models.
            train_out (None): Always None for ensemble inference.
        """
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)
        return y, None  # inference, train output


# Functions ------------------------------------------------------------------------------------------------------------


@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
    """Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've
    moved a module from one location to another, but you still want to support the old import paths for backwards
    compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.
        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.

    Examples:
        >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
        >>> import old.module  # this will now import new.module
        >>> from old.module import attribute  # this will now import new.module.attribute

    Notes:
        The changes are only in effect inside the context manager and are undone once the context manager exits.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    if modules is None:
        modules = {}
    if attributes is None:
        attributes = {}
    import sys
    from importlib import import_module

    try:
        # Set attributes in sys.modules under their old name
        for old, new in attributes.items():
            old_module, old_attr = old.rsplit(".", 1)
            new_module, new_attr = new.rsplit(".", 1)
            setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))

        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            sys.modules[old] = import_module(new)

        yield
    finally:
        # Remove the temporary module paths
        for old in modules:
            if old in sys.modules:
                del sys.modules[old]


class SafeClass:
    """A placeholder class to replace unknown classes during unpickling."""

    def __init__(self, *args, **kwargs):
        """Initialize SafeClass instance, ignoring all arguments."""
        pass

    def __call__(self, *args, **kwargs):
        """Run SafeClass instance, ignoring all arguments."""
        pass


class SafeUnpickler(pickle.Unpickler):
    """Custom Unpickler that replaces unknown classes with SafeClass."""

    def find_class(self, module, name):
        """Attempt to find a class, returning SafeClass if not among safe modules.

        Args:
            module (str): Module name.
            name (str): Class name.

        Returns:
            (type): Found class or SafeClass.
        """
        safe_modules = (
            "torch",
            "collections",
            "collections.abc",
            "builtins",
            "math",
            "numpy",
            # Add other modules considered safe
        )
        if module in safe_modules:
            return super().find_class(module, name)
        else:
            return SafeClass


def torch_safe_load(weight, safe_only=False):
    """Attempt to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches
    the error, logs a warning message, and attempts to install the missing module via the check_requirements()
    function. After installation, the function again attempts to load the model using torch.load().

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.

    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally
    try:
        with temporary_modules(
            modules={
                "ultralytics.yolo.utils": "ultralytics.utils",
                "ultralytics.yolo.v8": "ultralytics.models.yolo",
                "ultralytics.yolo.data": "ultralytics.data",
            },
            attributes={
                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
            },
        ):
            if safe_only:
                # Load via custom pickle module
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as f:
                    ckpt = torch_load(f, pickle_module=safe_pickle)
            else:
                ckpt = torch_load(file, map_location="cpu")

    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                )
            ) from e
        elif e.name == "numpy._core":
            raise ModuleNotFoundError(
                emojis(
                    f"ERROR ❌️ {weight} requires numpy>=1.26.1, however numpy=={__import__('numpy').__version__} is installed."
                )
            ) from e
        LOGGER.warning(
            f"{weight} appears to require '{e.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
        )
        check_requirements(e.name)  # install missing module
        ckpt = torch_load(file, map_location="cpu")

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file


def load_checkpoint(weight, device=None, inplace=True, fuse=False):
    """Load a single model weights.

    Args:
        weight (str | Path): Model weight path.
        device (torch.device, optional): Device to load model to.
        inplace (bool): Whether to do inplace operations.
        fuse (bool): Whether to fuse model.

    Returns:
        model (torch.nn.Module): Loaded model.
        ckpt (dict): Model checkpoint dictionary.
    """
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args
    model = (ckpt.get("ema") or ckpt["model"]).float()  # FP32 model

    # Model compatibility updates
    model.args = args  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    model.task = getattr(model, "task", guess_model_task(model))
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    model = (model.fuse() if fuse and hasattr(model, "fuse") else model).eval().to(device)  # model in eval mode

    # Module updates
    for m in model.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt


def parse_model(d, ch, verbose=True):
    """Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary.
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        model (torch.nn.Sequential): PyTorch model.
        save (list): Sorted list of output layers.
    """
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    scale = d.get("scale")
    if scales:
        if not scale:
            scale = next(iter(scales.keys()))
            LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    repeat_modules = frozenset(  # modules with 'repeat' arguments
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = (
            getattr(torch.nn, m[3:])
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])
            if "torchvision.ops." in m
            else globals()[m]
        )  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 != nc (e.g., Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            args = [c1, c2, *args[1:]]
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale in "mlx":
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
            if m is C2fCIB:
                legacy = False
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in frozenset({HGStem, HGBlock}):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is torch.nn.BatchNorm2d:
            args = [ch[f]]
        elif m in _CUSTOM_ATTN:
            c1 = ch[f]
            c2 = ch[f]  # 注意力不改变通道数
            args = [c1, *args]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in frozenset(
            {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
        ):
            args.append([ch[x] for x in f])
            if m is Segment or m is YOLOESegment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBAM:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m in frozenset({TorchVision, Index}):
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]
        else:
            c2 = ch[f]

        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{f!s:>20}{n_:>3}{m_.np:10.0f}  {t:<45}{args!s:<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return torch.nn.Sequential(*layers), sorted(save)


def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file.

    Args:
        path (str | Path): Path to the YAML file.

    Returns:
        (dict): Model dictionary.
    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = YAML.load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def guess_model_scale(model_path):
    """Extract the size character n, s, m, l, or x of the model's scale from the model path.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale (n, s, m, l, or x).
    """
    try:
        return re.search(r"yolo(e-)?[v]?\d+([nslmx])", Path(model_path).stem).group(2)
    except AttributeError:
        return ""


def guess_model_task(model):
    """Guess the task of a PyTorch model from its architecture or configuration.

    Args:
        model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if "segment" in m:
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            return cfg2task(model)
    # Guess from PyTorch model
    if isinstance(model, torch.nn.Module):  # PyTorch model
        for x in "model.args", "model.model.args", "model.model.model.args":
            with contextlib.suppress(Exception):
                return eval(x)["task"]  # nosec B307: safe eval of known attribute paths
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))  # nosec B307: safe eval of known attribute paths
        for m in model.modules():
            if isinstance(m, (Segment, YOLOESegment)):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, YOLOEDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect

五、拓展、理解小目标层P3、中目标层P4、大目标层P5....等

【如果你使用CPU跑模型】

import torch
import os
from ultralytics import YOLO

# 解决OMP冲突(Windows必加)
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. 加载你的YOLO11s模型(带CBAM的也可以)
model = YOLO("yolo11s.pt")
# 把模型设为eval模式,避免BatchNorm干扰
model.eval()

# 2. 创建一个和你训练时一样尺寸的测试张量(1张图,3通道,1280×1280)
dummy_input = torch.randn(1, 3, 1280, 1280).to("cuda" if torch.cuda.is_available() else "cpu")

# 3. 遍历模型每一层,打印输出尺寸(关键!)
print("=== 各层输出尺寸(尺寸越大=小目标层)===")
x = dummy_input
for i, layer in enumerate(model.model.model):
    x = layer(x)
    # 只打印4维特征图(排除分类/回归头的一维输出)
    if len(x.shape) == 4:
        # 修改点:使用三引号 """ 将字符串包裹起来,允许内部换行
        print(f"""层{i}:输出尺寸 = {x.shape} → 对应目标类型:{
            '极小目标(P2)' if x.shape[2] == 160 else
            '中小目标(P3)' if x.shape[2] == 80 else
            '大目标(P4)' if x.shape[2] == 40 else '其他层'
        }""")

【如果你使用GPU跑模型】

import torch
import os
from ultralytics import YOLO

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. 加载模型
model = YOLO("yolo11s.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.model.to(device)  # 移动模型到 GPU

# 2. 定义字典存储输出
outputs = {}


# 3. 定义钩子函数
def get_output(idx):
    def hook(module, input, output):
        # 只保存4维特征图
        if isinstance(output, torch.Tensor) and len(output.shape) == 4:
            outputs[idx] = output

    return hook


# 4. 【关键修改】直接遍历 model.model.model 列表
# model.model 是 DetectionModel
# model.model.model 才是真正的 nn.Sequential 列表
try:
    # 尝试获取模型中的Sequential列表
    model_list = model.model.model
    if not isinstance(model_list, torch.nn.Sequential):
        # 如果获取的不是Sequential,尝试直接遍历model.model(虽然通常不行,但做个兼容)
        print("警告:未能直接找到 model.model.model 列表,尝试直接遍历 model.model...")
        model_list = model.model

    for i, layer in enumerate(model_list):
        # 注册钩子,使用索引 i 作为 key
        layer.register_forward_hook(get_output(i))

except Exception as e:
    print(f"遍历模型层出错: {e}")
    print("尝试打印模型结构以调试:")
    print(model.model)

# 5. 运行推理
dummy_input = torch.randn(1, 3, 1280, 1280).to(device)
with torch.no_grad():
    _ = model.model(dummy_input)

# 6. 格式化打印
print("-" * 80)
print(f"{'Layer Index':<12} | {'Output Shape (C, H, W)':<25} | {'Target Type'}")
print("-" * 80)

if not outputs:
    print("未捕获到任何输出,请检查模型结构或钩子注册逻辑。")
else:
    # 按索引排序
    for idx in sorted(outputs.keys()):
        out = outputs[idx]
        c, h, w = out.shape[1], out.shape[2], out.shape[3]

        target_type = (
            '极小目标层' if h == 160 else
            '中小目标层' if h == 80 else
            '大目标层' if h == 40 else '其他层'
        )

        print(f"{idx:<12} | {str((c, h, w)):<25} | {target_type}")

print("-" * 80)

更多推荐