LLM 大语言模型实战 (六)-模型训练和训练优化
详细介绍 HF Trainer(Hugging Face Trainer),我们可以从训练超参数、优化器等方面入手。
HF Trainer 概述
HF Trainer 是 Hugging Face 提供的用于简化模型训练的工具,支持多种模型和任务。主要组成部分包括:
-
训练超参数:
learning_rate
: 学习率,控制模型更新速度。per_device_train_batch_size
: 每个设备的训练批次大小。num_train_epochs
: 训练的轮数。weight_decay
: 权重衰减,防止过拟合。evaluation_strategy
: 评估策略(如epoch
或steps
)。
-
优化器:
- 默认使用 AdamW 优化器,可以自定义。
示例代码
下面是一个可在 Jupyter Notebook 中运行的示例,使用 Hugging Face Transformers 库进行文本分类任务。
# 导入必要的库
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
# 加载数据集(示例使用 IMDB 数据集)
dataset = load_dataset("imdb")
train_texts, val_texts, train_labels, val_labels = train_test_split(
dataset['train']['text'], dataset['train']['label'], test_size=0.1
)
# 初始化 Tokenizer 和模型
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# 数据预处理
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
# 创建数据集类
import torch
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
# 设置训练参数
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=16,
per_device_eval_batch_size=64,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
evaluation_strategy="epoch"
)
# 初始化 Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset
)
# 开始训练
trainer.train()
说明
- 数据集:使用 IMDB 数据集作为示例,进行情感分类。
- 模型:使用 DistilBERT 进行文本分类。
- 训练参数:可以根据需要调整超参数。
运行上述代码前,请确保安装了必要的库:
pip install transformers datasets torch
这个示例将帮助你理解 HF Trainer 的基本使用方法!如果你需要更深入的细节,请告诉我。
相关文章:
Hugging Face | 🤗 Transformers简介
Huggingface 超详细介绍
为者常成,行者常至
自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)