网易首页 > 网易号 > 正文 申请入驻

利用PyTorch的三元组损失Hard Triplet Loss进行嵌入模型微调

0
分享至

本文介绍如何使用 PyTorch 和三元组边缘损失 (Triplet Margin Loss) 微调嵌入模型,并重点阐述实现细节和代码示例。三元组损失是一种对比损失函数,通过缩小锚点与正例间的距离,同时扩大锚点与负例间的距离来优化模型。

数据集准备与处理

一般的嵌入模型都会使用Sentence Transformer ,其中的 encode() 方法可以直接处理文本输入。但是为了进行微调,我们需要采用 Transformer 库,所以就要将文本转换为模型可接受的 token IDs 和 attention masks。Token IDs 代表模型词汇表中的词或字符,attention masks 用于防止模型关注填充 tokens。

本文使用 thenlper/gte-base 模型,需要对应的 tokenizer 对文本进行预处理。该模型基于 BertModel 架构:

BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)

利用 Transformers 库的 AutoTokenizer 和 AutoModel 可以简化模型加载过程,无需手动处理底层架构和配置细节。

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
# 获取文本并进行标记
train_texts = [df_train.loc[i]['content'] for i in range(df_train.shape[0])]
dev_texts = [df_dev.loc[i]['content'] for i in range(df_dev.shape[0])]
test_texts = [df_test.loc[i]['content'] for i in range(df_test.shape[0])]
train_tokens = []
train_attention_masks = []
dev_tokens = []
dev_attention_masks = []
test_tokens = []
test_attention_masks = []
for sent in tqdm(train_texts):
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
train_tokens.append(encoding['input_ids'].squeeze(0))
train_attention_masks.append(encoding['attention_mask'].squeeze(0))
for sent in tqdm(dev_texts):
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
dev_tokens.append(encoding['input_ids'].squeeze(0))
dev_attention_masks.append(encoding['attention_mask'].squeeze(0))
for sent in tqdm(test_texts):
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')
test_tokens.append(encoding['input_ids'].squeeze(0))
test_attention_masks.append(encoding['attention_mask'].squeeze(0))

获取 token IDs 和 attention masks 后,需要将其存储并创建一个自定义的 PyTorch 数据集。

import random
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader, Sampler, SequentialSampler
class CustomTripletDataset(Dataset):
def __init__(self, tokens, attention_masks, labels):
self.tokens = tokens
self.attention_masks = attention_masks
self.labels = torch.Tensor(labels)
self.label_dict = defaultdict(list)
for i in range(len(tokens)):
self.label_dict[int(self.labels[i])].append(i)
self.unique_classes = list(self.label_dict.keys())
def __len__(self):
return len(self.tokens)
def __getitem__(self, index):
ids = self.tokens[index].to(device)
ams = self.attention_masks[index].to(device)
y = self.labels[index].to(device)
return ids, ams, y

由于采用三元组损失,需要从数据集中采样正例和负例。label_dict 字典用于存储每个类别及其对应的数据索引,方便随机采样。DataLoader 用于加载数据集:

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)

其中 train_batch_sampler 是自定义的批次采样器:

class CustomBatchSampler(SequentialSampler):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
self.unique_classes = sorted(dataset.unique_classes)
self.label_dict = dataset.label_dict
self.num_batches = len(self.dataset) // self.batch_size
self.class_size = self.batch_size // 4
def __iter__(self):
total_samples_used = 0
weights = np.repeat(1, len(self.unique_classes))
while total_samples_used < len(self.dataset):
batch = []
classes = []
for _ in range(4):
next_selected_class = self._select_class(weights)
while next_selected_class in classes:
next_selected_class = self._select_class(weights)
weights[next_selected_class] += 1
classes.append(next_selected_class)
new_choices = self.label_dict[next_selected_class]
remaining_samples = list(np.random.choice(new_choices, min(self.class_size, len(new_choices)), replace=False))
batch.extend(remaining_samples)
total_samples_used += len(batch)
yield batch
def _select_class(self, weights):
dist = 1/weights
dist = dist/np.sum(dist)
selected = int(np.random.choice(self.unique_classes, p=dist))
return selected
def __len__(self):
return self.num_batches

自定义批次采样器控制训练批次的构成,本文的实现确保每个批次包含 4 个类别,每个类别包含 8 个数据点。验证采样器则确保验证集批次在不同 epoch 间保持一致。

模型构建

嵌入模型通常基于 Transformer 架构,输出每个 token 的嵌入。为了获得句子嵌入,需要对 token 嵌入进行汇总。常用的方法包括 CLS 池化和平均池化。本文使用的 gte-base 模型采用平均池化,需要从模型输出中提取 token 嵌入并计算平均值。

import torch.nn.functional as F
import torch.nn as nn
class EmbeddingModel(nn.Module):
def __init__(self, base_model):
super().__init__()
self.base_model = base_model
def average_pool(self, last_hidden_states, attention_mask):
# 平均 token 嵌入
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def forward(self, input_ids, attention_mask):
outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden_state = outputs.last_hidden_state
pooled_output = self.average_pool(last_hidden_state, attention_mask)
normalized_output = F.normalize(pooled_output, p=2, dim=1)
return normalized_output
base_model = AutoModel.from_pretrained("thenlper/gte-base")
model = EmbeddingModel(base_model)

EmbeddingModel 类封装了 Hugging Face 模型,并实现了平均池化和嵌入归一化。

模型训练

训练循环中需要动态计算每个锚点的最难正例和最难负例。

import numpy as np
def train(model, train_loader, criterion, optimizer, scheduler):
model.train()
epoch_train_losses = []
for idx, (ids, attention_masks, labels) in enumerate(train_loader):
optimizer.zero_grad()
embeddings = model(ids, attention_masks)
distance_matrix = torch.cdist(embeddings, embeddings, p=2) # 创建方形距离矩阵
anchors = []
positives = []
negatives = []
for i in range(len(labels)):
anchor_label = labels[i].item()
anchor_distance = distance_matrix[i] # 锚点与所有其他点之间的距离
# 最难的正例(同一类别中最远的)
hardest_positive_idx = (labels == anchor_label).nonzero(as_tuple=True)[0] # 所有同类索引
hardest_positive_idx = hardest_positive_idx[hardest_positive_idx != i] # 排除自己的标签
hardest_positive = hardest_positive_idx[anchor_distance[hardest_positive_idx].argmax()] # 最远同类的标签
# 最难的负例(不同类别中最近的)
hardest_negative_idx = (labels != anchor_label).nonzero(as_tuple=True)[0] # 所有不同类索引
hardest_negative = hardest_negative_idx[anchor_distance[hardest_negative_idx].argmin()] # 最近不同类的标签
# 加载选择的
anchors.append(embeddings[i])
positives.append(embeddings[hardest_positive])
negatives.append(embeddings[hardest_negative])
# 将列表转换为张量
anchors = torch.stack(anchors)
positives = torch.stack(positives)
negatives = torch.stack(negatives)
# 计算损失
loss = criterion(anchors, positives, negatives)
epoch_train_losses.append(loss.item())
# 反向传播和优化
loss.backward()
optimizer.step()
# 更新学习率
scheduler.step()
return np.mean(epoch_train_losses)

训练过程中使用 torch.cdist() 计算嵌入间的距离矩阵,并根据距离选择最难正例和最难负例。PyTorch 的 TripletMarginLoss 用于计算损失。

结论与讨论

实践表明,Batch Hard Triplet Loss 在某些情况下并非最优选择。例如,当正例样本内部差异较大时,强制其嵌入相似可能适得其反。

本文的重点在于 PyTorch 中自定义批次采样和动态距离计算的实现。

对于某些任务,直接在分类任务上微调嵌入模型可能比使用三元组损失更有效。

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
罗马诺:国米正在谈先租后买3500万签迪亚比,球员已同意加盟

罗马诺:国米正在谈先租后买3500万签迪亚比,球员已同意加盟

懂球帝
2026-01-29 00:08:23
成龙女儿吴卓林结婚现场曝光,紧握爱人的手,洋溢着甜蜜笑容!

成龙女儿吴卓林结婚现场曝光,紧握爱人的手,洋溢着甜蜜笑容!

娱乐团长
2026-01-13 15:39:28
文联春晚太真实!宋轶朝天鼻辣眼,陈妍希肥头大耳,章子怡脸馒化

文联春晚太真实!宋轶朝天鼻辣眼,陈妍希肥头大耳,章子怡脸馒化

无人倾听无人倾听
2026-01-28 08:41:48
2026异地就医大调整,不用备案直接报,这三类人享专属福利

2026异地就医大调整,不用备案直接报,这三类人享专属福利

复转这些年
2026-01-29 03:00:03
9500万人口的东北,去年生了38万,死亡91万!情况比想象中更严重

9500万人口的东北,去年生了38万,死亡91万!情况比想象中更严重

狐狸先森讲升学规划
2025-08-01 18:30:03
调查:市价不到两百的护眼灯校园采购价却高达七八百,背后藏何猫腻?

调查:市价不到两百的护眼灯校园采购价却高达七八百,背后藏何猫腻?

澎湃新闻
2026-01-28 08:04:03
突然发现孩子真的很平庸,难以接受怎么办?网友分享引起万千共鸣

突然发现孩子真的很平庸,难以接受怎么办?网友分享引起万千共鸣

另子维爱读史
2026-01-27 20:48:17
刚刚,夜晚15家公司出现重大利空消息,有没有与你相关的个股?

刚刚,夜晚15家公司出现重大利空消息,有没有与你相关的个股?

股市皆大事
2026-01-28 20:23:17
双buff叠满!陪读妈妈+女留学生四轮围猎,牢A要凉?

双buff叠满!陪读妈妈+女留学生四轮围猎,牢A要凉?

步论天下事
2026-01-26 16:22:39
18岁伊斯兰少女直播拒戴头巾,被冷血父亲荣誉处决。

18岁伊斯兰少女直播拒戴头巾,被冷血父亲荣誉处决。

环球趣闻分享
2026-01-07 13:30:09
硬核警告!敢给台海送军火?直接“没收”。天降帮手,操作太顶了

硬核警告!敢给台海送军火?直接“没收”。天降帮手,操作太顶了

阿凫爱吐槽
2026-01-20 06:36:27
双胞胎人妻下海拍片!姐姐怂恿婚姻不顺的妹妹:不如一起出道!

双胞胎人妻下海拍片!姐姐怂恿婚姻不顺的妹妹:不如一起出道!

小飞爱生活1987
2026-01-26 12:48:35
从30万跌到16万,这四款豪华B级车腰斩甩卖,谁买谁是大赢家

从30万跌到16万,这四款豪华B级车腰斩甩卖,谁买谁是大赢家

西莫的艺术宫殿
2026-01-28 17:38:38
乌克兰名将多次阴阳中国!连赢俄罗斯选手画两个叉 为乌克兰而战

乌克兰名将多次阴阳中国!连赢俄罗斯选手画两个叉 为乌克兰而战

念洲
2026-01-28 08:02:24
允许了!天才9号秀!马刺可能离队第一人

允许了!天才9号秀!马刺可能离队第一人

篮球实战宝典
2026-01-28 22:36:43
马上,70万亿美元!

马上,70万亿美元!

路财主
2026-01-10 19:59:20
张雨绮代孕风波升级!双胞胎学籍分割成铁证,杨天真口碑崩塌

张雨绮代孕风波升级!双胞胎学籍分割成铁证,杨天真口碑崩塌

让生活充满温暖
2026-01-27 10:37:07
今天,A股涨到4170,做好准备,明天,1月29日,大概率会这样走

今天,A股涨到4170,做好准备,明天,1月29日,大概率会这样走

有范又有料
2026-01-28 15:03:36
山西一副省长被免!王震、雷健坤、韩珍堂、张晓峰履新...

山西一副省长被免!王震、雷健坤、韩珍堂、张晓峰履新...

无比
2026-01-28 17:21:23
6个习惯降低全身炎症让你养出健康长寿体质

6个习惯降低全身炎症让你养出健康长寿体质

吃练双修指南
2026-01-26 14:00:09
2026-01-29 04:55:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1904文章数 1445关注度
往期回顾 全部

科技要闻

它是神也是毒!Clawdbot改名卷入千万诈骗

头条要闻

俄总统助理:泽连斯基若愿与普京会晤 可来莫斯科

头条要闻

俄总统助理:泽连斯基若愿与普京会晤 可来莫斯科

体育要闻

没天赋的CBA第一小前锋,秘诀只有一个字

娱乐要闻

金子涵拉黑蔡徐坤,蔡徐坤工作室回应

财经要闻

从万科退休20天后,郁亮疑似失联

汽车要闻

新手必看!冰雪路面不敢开?记住这4点 关键时刻真能保命

态度原创

房产
艺术
本地
数码
公开课

房产要闻

实景兑现在即!绿城,在海棠湾重新定义终极旅居想象!

艺术要闻

沙特醒悟,“全球最大单体建筑”停止施工!

本地新闻

云游中国|拨开云雾,巫山每帧都是航拍大片

数码要闻

荣耀平板新春版本今启推送,升级计划公布

公开课

李玫瑾:为什么性格比能力更重要?

无障碍浏览 进入关怀版