all-MiniLM-L6-v2编程实践:构建批量文本处理脚本

1. 引言:为什么你需要一个批量文本处理脚本?

想象一下这个场景:你手头有几百份产品描述、客户反馈或者新闻稿,你想快速找出哪些内容在讨论相似的主题,或者想为这些文本建立一个智能的搜索系统。如果手动处理,这可能需要几天甚至几周的时间。而今天,我们就要用 all-MiniLM-L6-v2 这个轻量级但强大的模型,来构建一个自动化脚本,帮你把这项繁琐的工作变得轻松简单。

all-MiniLM-L6-v2 是一个专门为句子生成“语义指纹”(专业术语叫“嵌入向量”)而设计的模型。简单来说,它能把一段文字转换成一串数字,这串数字能代表这段文字的意思。意思相近的文字,转换出来的数字串也会很相似。我们的脚本,就是利用这个特性,来批量处理文本,完成相似度计算、文本聚类等任务。

本文将带你从零开始,完成以下几步:

  1. 快速部署 all-MiniLM-L6-v2 模型服务。
  2. 编写一个Python脚本,实现批量文本的嵌入向量生成。
  3. 基于生成的向量,实现文本相似度计算和简单聚类。
  4. 提供完整的代码和实用技巧,让你能直接上手应用到自己的项目中。

无论你是数据分析师、产品经理,还是开发者,只要你有批量处理文本的需求,这篇教程都能给你一个清晰、可落地的解决方案。

2. 环境准备与模型服务部署

在开始编写脚本之前,我们需要先让模型“跑起来”。这里我们选择使用 Ollama 来部署,因为它非常简单,几乎是一键式的。

2.1 安装Ollama

Ollama 是一个本地运行大模型的工具。首先,访问 Ollama官网 并根据你的操作系统(Windows, macOS, Linux)下载对应的安装包。安装过程就像安装普通软件一样简单。

安装完成后,打开你的终端(或命令提示符/PowerShell),运行以下命令来拉取 all-MiniLM-L6-v2 模型:

ollama pull all-minilm

这个命令会从模型库中下载模型文件。由于 all-MiniLM-L6-v2 非常轻量(约22.7MB),下载会很快。

2.2 启动模型服务

下载完成后,我们需要以API服务的形式运行它,这样我们的Python脚本才能调用。在终端中运行:

ollama run all-minilm

默认情况下,Ollama会在 http://localhost:11434 提供一个API接口。你可以打开浏览器,访问 http://localhost:11434,如果看到Ollama的相关信息,说明服务已经成功启动。

为了验证嵌入功能是否正常,我们可以用 curl 命令快速测试一下(如果你没有 curl,也可以跳过,我们后面用Python测试):

curl http://localhost:11434/api/embeddings -d '{
  "model": "all-minilm",
  "prompt": "Hello, world!"
}'

如果返回一串很长的数字(向量),就说明一切就绪了!

3. 核心脚本编写:批量生成文本嵌入

模型服务跑起来后,我们就可以开始编写核心的Python脚本了。这个脚本主要做三件事:读取文本、调用模型API获取嵌入向量、保存结果。

首先,确保你安装了必要的Python库:

pip install requests pandas numpy

接下来,我们创建一个名为 batch_embedding.py 的Python文件。

3.1 脚本基础框架

我们先搭建脚本的骨架,定义好配置和主函数。

import requests
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any
import time
import logging

# 配置
OLLAMA_API_URL = "http://localhost:11434/api/embeddings"
MODEL_NAME = "all-minilm"
BATCH_SIZE = 10  # 每次请求处理的文本数量,避免单次请求过大
SLEEP_TIME = 0.1  # 请求间隔,避免对服务端造成压力

# 设置日志,方便查看运行过程
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def main():
    """主函数"""
    # 1. 加载待处理的文本数据
    texts = load_texts("input_texts.txt")  # 假设文本保存在这个文件,每行一段
    
    # 2. 批量生成嵌入向量
    embeddings = batch_generate_embeddings(texts)
    
    # 3. 保存结果
    save_results(texts, embeddings, "output_embeddings.csv")
    
    logger.info("批量文本嵌入处理完成!")

if __name__ == "__main__":
    main()

3.2 实现文本加载函数

我们需要一个函数来从文件(或数据库)中读取文本。这里以简单的文本文件为例,每行一段文本。

def load_texts(file_path: str) -> List[str]:
    """
    从文本文件中加载文本,每行作为一条独立文本。
    
    参数:
        file_path: 文本文件路径
        
    返回:
        文本列表
    """
    texts = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()  # 去除首尾空白字符
                if line:  # 忽略空行
                    texts.append(line)
        logger.info(f"成功从 {file_path} 加载了 {len(texts)} 条文本。")
    except FileNotFoundError:
        logger.error(f"文件 {file_path} 未找到,请检查路径。")
        # 这里我们提供一个示例数据,方便直接运行脚本测试
        texts = [
            "机器学习是人工智能的一个分支。",
            "深度学习利用神经网络进行特征学习。",
            "Python是一种流行的编程语言。",
            "Java广泛应用于企业级开发。",
            "今天的天气非常晴朗。",
            "我打算周末去公园散步。"
        ]
        logger.info(f"已使用内置示例数据,共 {len(texts)} 条。")
    return texts

3.3 实现核心的嵌入生成函数

这是脚本的核心。我们将文本分批发送给Ollama API,并处理返回结果。

def get_single_embedding(text: str) -> List[float]:
    """
    获取单条文本的嵌入向量。
    
    参数:
        text: 输入文本
        
    返回:
        嵌入向量(浮点数列表)
    """
    payload = {
        "model": MODEL_NAME,
        "prompt": text
    }
    
    try:
        response = requests.post(OLLAMA_API_URL, json=payload, timeout=30)
        response.raise_for_status()  # 如果响应状态码不是200,抛出异常
        result = response.json()
        return result.get("embedding", [])
    except requests.exceptions.RequestException as e:
        logger.error(f"请求失败 for text: '{text[:50]}...'。错误: {e}")
        return []  # 返回空列表,外部函数需要处理这种情况
    except json.JSONDecodeError as e:
        logger.error(f"解析响应JSON失败 for text: '{text[:50]}...'。错误: {e}")
        return []

def batch_generate_embeddings(texts: List[str]) -> List[List[float]]:
    """
    批量生成文本嵌入向量,支持分批处理。
    
    参数:
        texts: 文本列表
        
    返回:
        嵌入向量列表,与输入文本顺序一一对应
    """
    all_embeddings = []
    total = len(texts)
    
    for i in range(0, total, BATCH_SIZE):
        batch = texts[i:i+BATCH_SIZE]
        logger.info(f"正在处理第 {i//BATCH_SIZE + 1}/{(total-1)//BATCH_SIZE + 1} 批,本批 {len(batch)} 条文本。")
        
        batch_embeddings = []
        for text in batch:
            embedding = get_single_embedding(text)
            if embedding:  # 只有成功获取到向量才加入
                batch_embeddings.append(embedding)
            else:
                # 如果某条失败,用一个零向量占位,后续可根据需要处理
                logger.warning(f"文本 '{text[:50]}...' 获取嵌入失败,使用零向量填充。")
                # all-MiniLM-L6-v2 的向量维度是384
                batch_embeddings.append([0.0] * 384)
            time.sleep(SLEEP_TIME)  # 短暂停顿,友好访问
        
        all_embeddings.extend(batch_embeddings)
    
    logger.info(f"所有 {total} 条文本的嵌入向量已生成。")
    return all_embeddings

3.4 实现结果保存函数

将文本和其对应的嵌入向量保存起来,方便后续使用。这里我们用CSV格式保存,同时用NumPy保存向量数组以便高效读取。

def save_results(texts: List[str], embeddings: List[List[float]], csv_path: str):
    """
    将文本和嵌入向量保存到CSV和NPZ文件。
    
    参数:
        texts: 原始文本列表
        embeddings: 嵌入向量列表
        csv_path: 输出的CSV文件路径
    """
    if len(texts) != len(embeddings):
        logger.error("文本数量和嵌入向量数量不匹配,无法保存。")
        return
    
    # 保存到CSV(文本和向量分开,向量可以只存路径或序列化字符串,这里存前5维示例)
    data = []
    for idx, (text, emb) in enumerate(zip(texts, embeddings)):
        # 为了在CSV中可读,我们只存储向量的前几个维度作为示例
        sample_emb = emb[:5] if len(emb) >= 5 else emb
        data.append({
            "id": idx,
            "text": text,
            "embedding_sample": str(sample_emb)  # 将列表转为字符串存储
        })
    
    df = pd.DataFrame(data)
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
    logger.info(f"文本和嵌入向量示例已保存到 {csv_path}")
    
    # 将完整的向量数组保存为NumPy二进制文件,便于后续机器学习任务快速加载
    np_embeddings = np.array(embeddings)
    npz_path = csv_path.replace('.csv', '_embeddings.npz')
    np.savez(npz_path, embeddings=np_embeddings, texts=texts)
    logger.info(f"完整的嵌入向量数组已保存到 {npz_path}")

4. 进阶应用:基于嵌入向量的文本处理

生成了嵌入向量,数据就变成了计算机能理解的“数字形式”。现在,我们可以玩点更高级的了。我们在同一个目录下创建另一个脚本 text_analysis.py

4.1 计算文本相似度

这是最直接的应用。我们可以计算任意两段文本的语义相似度。

# text_analysis.py
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_embeddings(npz_path: str):
    """加载之前保存的嵌入向量和文本"""
    data = np.load(npz_path, allow_pickle=True)
    return data['embeddings'], data['texts']

def find_similar_texts(query_text: str, all_texts: List[str], all_embeddings: np.ndarray, top_k: int = 3):
    """
    给定一段查询文本,找出数据集中最相似的top_k个文本。
    这里假设query_text的嵌入向量已经预先计算好,或者需要实时调用3.3节的函数获取。
    为了示例,我们假设query_embedding是传入的。
    """
    # 假设我们已经有了查询文本的嵌入向量 query_embedding
    # query_embedding = get_single_embedding(query_text) # 需要调用之前的函数
    # 为了演示,我们随机选一个向量作为查询(实际应用中需计算)
    if len(all_embeddings) == 0:
        return []
    
    # 这里演示:随机选择一个现有向量作为“查询”
    query_idx = 0
    query_embedding = all_embeddings[query_idx].reshape(1, -1)
    query_text = all_texts[query_idx]
    
    logger.info(f"查询文本: '{query_text}'")
    
    # 计算余弦相似度
    similarities = cosine_similarity(query_embedding, all_embeddings).flatten()
    
    # 获取相似度最高的top_k个索引(排除自己)
    top_indices = np.argsort(similarities)[-top_k-1:-1][::-1]  # 取前top_k个,排除自身(如果query在集合中)
    
    results = []
    for idx in top_indices:
        results.append({
            "text": all_texts[idx],
            "similarity": similarities[idx]
        })
        logger.info(f"相似度 {similarities[idx]:.4f}: {all_texts[idx][:60]}...")
    
    return results

# 示例用法
if __name__ == "__main__":
    emb, txts = load_embeddings("output_embeddings_embeddings.npz")
    print("相似度查找示例:")
    find_similar_texts("", txts, emb, top_k=3)

4.2 简单文本聚类

我们还可以用这些向量把文本分成不同的组。

# 继续在 text_analysis.py 中添加
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def cluster_texts(embeddings: np.ndarray, texts: List[str], n_clusters: int = 3):
    """
    对文本嵌入向量进行聚类。
    
    参数:
        embeddings: 嵌入向量矩阵
        texts: 对应的文本列表
        n_clusters: 预设的聚类数量
    """
    if len(embeddings) < n_clusters:
        logger.warning(f"文本数量 {len(embeddings)} 少于聚类数 {n_clusters},无法聚类。")
        return None
    
    # 使用K-Means聚类
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(embeddings)
    
    # 打印每个簇的文本示例
    for cluster_id in range(n_clusters):
        cluster_text_indices = np.where(cluster_labels == cluster_id)[0]
        logger.info(f"\n--- 簇 {cluster_id} (共 {len(cluster_text_indices)} 条) ---")
        # 打印该簇的前3条文本
        for idx in cluster_text_indices[:3]:
            logger.info(f"  - {texts[idx][:80]}...")
        if len(cluster_text_indices) > 3:
            logger.info(f"  ... 以及另外 {len(cluster_text_indices) - 3} 条")
    
    # 可视化(降维到2D以便绘图)
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(embeddings)
    
    plt.figure(figsize=(10, 6))
    scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=cluster_labels, cmap='viridis', alpha=0.7)
    plt.colorbar(scatter, label='簇标签')
    plt.title('文本嵌入向量聚类可视化 (PCA降维)')
    plt.xlabel('主成分 1')
    plt.ylabel('主成分 2')
    plt.tight_layout()
    plt.savefig('text_clusters_visualization.png')
    logger.info("聚类可视化图已保存为 'text_clusters_visualization.png'")
    # plt.show() # 如果在Jupyter或支持的环境中可以显示
    
    return cluster_labels

# 示例用法
if __name__ == "__main__":
    # ... 之前的加载代码 ...
    print("\n聚类分析示例:")
    labels = cluster_texts(emb, txts, n_clusters=2)

5. 脚本优化与实用技巧

基本的脚本跑通后,我们可以从工程化角度做一些优化,让它更健壮、更高效。

5.1 增加错误处理与重试机制

网络请求可能失败,我们需要让脚本更稳定。

# 修改 batch_embedding.py 中的 get_single_embedding 函数
def get_single_embedding_robust(text: str, max_retries: int = 3) -> List[float]:
    """
    带重试机制的单条文本嵌入获取。
    """
    for attempt in range(max_retries):
        try:
            embedding = get_single_embedding(text)  # 调用原来的函数
            if embedding:
                return embedding
            else:
                logger.warning(f"第{attempt+1}次尝试获取嵌入失败(返回空),文本: '{text[:50]}...'")
        except Exception as e:
            logger.warning(f"第{attempt+1}次尝试获取嵌入失败,文本: '{text[:50]}...',错误: {e}")
        
        if attempt < max_retries - 1:
            wait_time = 2 ** attempt  # 指数退避
            logger.info(f"等待 {wait_time} 秒后重试...")
            time.sleep(wait_time)
    
    logger.error(f"经过 {max_retries} 次重试仍失败,文本: '{text[:50]}...',返回零向量。")
    return [0.0] * 384

5.2 添加进度显示

处理大量文本时,一个进度条能让人安心。

# 修改 batch_generate_embeddings 函数,加入进度显示
from tqdm import tqdm  # 需要安装: pip install tqdm

def batch_generate_embeddings_with_progress(texts: List[str]) -> List[List[float]]:
    """带进度条的批量生成函数"""
    all_embeddings = []
    total = len(texts)
    
    # 使用tqdm创建进度条
    with tqdm(total=total, desc="生成嵌入向量") as pbar:
        for i in range(0, total, BATCH_SIZE):
            batch = texts[i:i+BATCH_SIZE]
            for text in batch:
                embedding = get_single_embedding_robust(text)
                all_embeddings.append(embedding)
                pbar.update(1)
                time.sleep(SLEEP_TIME)
    
    return all_embeddings

5.3 处理长文本

all-MiniLM-L6-v2 最大支持256个token。对于更长的文本,我们需要进行分割。

def split_long_text(text: str, max_tokens: int = 256, overlap_tokens: int = 50) -> List[str]:
    """
    简单按空格分割长文本,并允许重叠以避免在句子中间切断。
    这是一个简单示例,生产环境建议使用更专业的分词器。
    """
    words = text.split()
    chunks = []
    start = 0
    
    while start < len(words):
        end = min(start + max_tokens, len(words))
        chunk = ' '.join(words[start:end])
        chunks.append(chunk)
        start = start + max_tokens - overlap_tokens  # 重叠一部分
        
        if start >= len(words):
            break
    
    return chunks

def get_embedding_for_long_text(long_text: str) -> List[float]:
    """
    获取长文本的嵌入向量(通过对各分块向量取平均)。
    """
    chunks = split_long_text(long_text)
    if not chunks:
        return [0.0] * 384
    
    chunk_embeddings = []
    for chunk in chunks:
        emb = get_single_embedding_robust(chunk)
        if emb:
            chunk_embeddings.append(emb)
    
    if not chunk_embeddings:
        return [0.0] * 384
    
    # 计算平均向量
    avg_embedding = np.mean(chunk_embeddings, axis=0).tolist()
    return avg_embedding

6. 总结

通过这篇教程,我们完成了一个从模型部署到脚本开发,再到实际应用的完整闭环。让我们回顾一下关键步骤和收获:

  1. 快速部署:利用 Ollama,我们几乎零配置地在本机跑起了 all-MiniLM-L6-v2 嵌入模型服务。
  2. 核心脚本:我们编写了一个健壮的Python脚本,能够批量读取文本、调用API生成语义向量、并妥善保存结果。脚本包含了错误重试、进度提示等工程化细节。
  3. 进阶应用:我们展示了如何利用生成的嵌入向量进行文本相似度计算和简单聚类,让数据产生实际价值。
  4. 优化技巧:我们探讨了处理长文本、增加稳定性等实用技巧,让脚本更适合真实生产环境。

这个脚本可以成为你处理文本数据的“瑞士军刀”。你可以将它用于:

  • 智能文档管理:自动为文档库建立语义索引,实现“意思搜索”而非“关键词搜索”。
  • 客户反馈分析:将海量反馈自动分类(如产品功能、服务质量、价格问题),快速洞察核心议题。
  • 内容去重与推荐:发现相似的文章或帖子,用于内容聚合或个性化推荐。
  • 问答系统基础:作为检索式问答系统的第一步,快速找到与问题最相关的文档段落。

all-MiniLM-L6-v2 以其小巧的体积和不错的效果,在效率与性能之间取得了很好的平衡,非常适合作为批量文本处理任务的起点。希望这个脚本能为你打开语义文本处理的大门。


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

更多推荐