1 Transfer Learning and Fine-Tuning¶
Transfer learning refers to applying knowledge gained from training on one task to learning another, related task. It generally proceeds in the following steps; a minimal code sketch of the workflow follows.
Pre-training: Pre-trained models are models that have been trained on large-scale datasets. Typically, a large amount of data (for example, text collected from the internet) is used to train the model in advance. At this stage the model captures general-purpose features and structures of language.
Fine-tuning: The pre-trained model is then fine-tuned on data for a specific task (for example, sentiment analysis or question answering). Because the pre-trained model has already learned general-purpose features, it is known to achieve high accuracy even when the available training data is small.
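The following is a minimal sketch of this two-step workflow, shown only for orientation; the checkpoint ("distilbert-base-uncased") and dataset ("imdb") named here are illustrative assumptions, and the rest of this chapter walks through a fuller version of the same procedure.
# Sketch: load a pre-trained checkpoint (step 1) and fine-tune it on a small labelled dataset (step 2).
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset

checkpoint = "distilbert-base-uncased"      # model pre-trained on large general corpora
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

dataset = load_dataset("imdb")              # comparatively small task-specific labelled data
encoded = dataset.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="transfer-demo", num_train_epochs=1),
    train_dataset=encoded["train"].shuffle(seed=0).select(range(1000)),
    tokenizer=tokenizer,
)
trainer.train()                             # fine-tuning on the small labelled set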

2 Implementing Sentiment Analysis¶
!nvidia-smi
2.1 Dataset¶
2.1.1 Getting Sample Data from Hugging Face¶
The Hugging Face Hub provides a wide variety of datasets. Here we use a multilingual sentiment dataset as an example; it includes English and Japanese subsets.
!pip install datasets
!pip install wandb
import wandb
wandb.login()
import os
os.environ["WANDB_PROJECT"]="sentiment_analysis"
from datasets import load_dataset
#dataset = load_dataset("tyqiangz/multilingual-sentiments", "japanese")
dataset = load_dataset("tyqiangz/multilingual-sentiments", "english")2.1.2サンプルデータの確認¶
Let's check the contents of the dataset we just obtained.
The dataset is split into train, validation, and test splits, and each example carries the fields ['text', 'source', 'label'].
dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 1839
    })
    validation: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 324
    })
    test: Dataset({
        features: ['text', 'source', 'label'],
        num_rows: 870
    })
})
dataset.set_format(type="pandas")
train_df = dataset["train"][:]
train_df.head(5)
dataset["train"].features
{'text': Value(dtype='string', id=None),
 'source': Value(dtype='string', id=None),
 'label': ClassLabel(names=['positive', 'neutral', 'negative'], id=None)}
import matplotlib.pyplot as plt
train_df["label"].value_counts(ascending=True).plot(kind="barh", title="Train Dataset")
2.1.3 Checking the Texts¶
Transformer models have a maximum input sequence length, known as the maximum context size.
Text longer than the model's context size has to be truncated, and if the truncated part contains important information, this can hurt performance.
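As a small illustration of this point (reusing the distilbert-base-uncased checkpoint introduced later in this section; the long dummy string is made up), you can inspect a tokenizer's maximum context size and confirm that longer inputs are cut down to it:
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
print(tok.model_max_length)                        # maximum context size (512 for this checkpoint)

long_text = "word " * 1000                         # deliberately longer than the context size
ids = tok(long_text, truncation=True)["input_ids"]
print(len(ids))                                    # capped at model_max_length; the rest is dropped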
train_df["text_length"]=train_df["text"].str.len()train_df.boxplot(column="text_length", by="label", figsize=(12, 6))
2.2 Tokenization¶
Computers cannot take raw strings as input. Instead, the text is expected to be tokenized and encoded as numerical vectors.
Tokenization is the step of splitting a string into the smallest units handled by the model.
The Transformers library provides a convenient AutoTokenizer class that lets you quickly load the tokenizer associated with a pre-trained model.
2.2.1 Checking How the Tokenizer Works¶
The tokenizer converts text into a numerical form (tokens):
It splits the input text into tokens.
Special tokens are added automatically.
The tokens are converted into token IDs.
from transformers import AutoTokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
train_df["text"][0]'okay i\\u2019m sorry but TAYLOR SWIFT LOOKS NOTHING LIKE JACKIE O SO STOP COMPARING THE TWO. c\\u2019mon America aren\\u2019t you sick of her yet? (sorry) 'sample_text_encoded = tokenizer(train_df["text"][0])
sample_text_encoded{'input_ids': [101, 3100, 1045, 1032, 23343, 24096, 2683, 2213, 3374, 2021, 4202, 9170, 3504, 2498, 2066, 9901, 1051, 2061, 2644, 13599, 1996, 2048, 1012, 1039, 1032, 23343, 24096, 2683, 8202, 2637, 4995, 1032, 23343, 24096, 2683, 2102, 2017, 5305, 1997, 2014, 2664, 1029, 1006, 3374, 1007, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}結果にinput_idsとattention_maskが含まれます。
input_ids: the tokens encoded as numbers
attention_mask: a mask that tells the model which tokens are valid. For invalid tokens (for example, PAD), the attention_mask is set to 0.
Within each batch, the input sequences are padded up to the length of the longest sequence in the batch, as the small example below illustrates.
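As a small illustration of this padding behaviour (reusing the tokenizer loaded above; the two sentences are made up), tokenizing two inputs of different lengths with padding=True pads the shorter one and marks the padded positions with 0 in attention_mask:
batch = tokenizer(["a short text",
                   "a much longer text that needs quite a few more tokens"],
                  padding=True)
for ids, mask in zip(batch["input_ids"], batch["attention_mask"]):
    print(ids)   # the shorter sequence is filled up with the [PAD] token id (0)
    print(mask)  # 1 for real tokens, 0 for padding positions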

Because the tokenizer's output is encoded as numbers, we use convert_ids_to_tokens to recover the token strings.
The special token [CLS] marks the start of the sentence and [SEP] marks its end.
tokens = tokenizer.convert_ids_to_tokens(sample_text_encoded.input_ids)
print(tokens)
['[CLS]', 'okay', 'i', '\\', 'u2', '##01', '##9', '##m', 'sorry', 'but', 'taylor', 'swift', 'looks', 'nothing', 'like', 'jackie', 'o', 'so', 'stop', 'comparing', 'the', 'two', '.', 'c', '\\', 'u2', '##01', '##9', '##mon', 'america', 'aren', '\\', 'u2', '##01', '##9', '##t', 'you', 'sick', 'of', 'her', 'yet', '?', '(', 'sorry', ')', '[SEP]']
2.2.2 Tokenizing the Entire Dataset¶
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

dataset.reset_format()
dataset_encoded = dataset.map(tokenize, batched=True, batch_size=None)
import pandas as pd
sample_encoded = dataset_encoded["train"][0]
pd.DataFrame(
    [sample_encoded["input_ids"],
     sample_encoded["attention_mask"],
     tokenizer.convert_ids_to_tokens(sample_encoded["input_ids"])],
    ["input_ids", "attention_mask", "tokens"]
).T
2.3 Implementing the Classifier¶
2.3.1 Loading the Pre-trained Model¶
The Transformers library provides the AutoModel class for working with pre-trained models.
The AutoModel class converts the token encodings into embeddings, passes them through the encoder stack, and returns the last hidden states.
import torch
from transformers import AutoModel
# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_ckpt).to(device)
First, we need to encode the string and convert the tokens into PyTorch tensors.
The resulting tensor has the shape [batch_size, n_tokens].
text = "this is a test"
inputs = tokenizer(text, return_tensors="pt")
print(f"Input tensor shape: {inputs['input_ids'].size()}")Input tensor shape: torch.Size([1, 6])
We pass the resulting tensor to the model as input.
The tensor must be placed on the same device (GPU or CPU) as the model.
To reduce the memory used by the computation, we disable automatic gradient calculation with torch.no_grad(). The output contains the hidden states and, depending on the configuration, loss and attention objects.
inputs = {k:v.to(device) for k,v in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
print(outputs)
BaseModelOutput(last_hidden_state=tensor([[[-0.1565, -0.1862,  0.0528,  ..., -0.1188,  0.0662,  0.5470],
[-0.3575, -0.6484, -0.0618, ..., -0.3040, 0.3508, 0.5221],
[-0.2772, -0.4459, 0.1818, ..., -0.0948, -0.0076, 0.9958],
[-0.2841, -0.3917, 0.3753, ..., -0.2151, -0.1173, 1.0526],
[ 0.2661, -0.5094, -0.3180, ..., -0.4203, 0.0144, -0.2149],
[ 0.9441, 0.0112, -0.4714, ..., 0.1439, -0.7288, -0.1619]]],
device='cuda:0'), hidden_states=None, attentions=None)
Looking at the hidden state tensor, we can see that its shape is [batch_size, n_tokens, hidden_dim]. In other words, a 768-dimensional vector is returned for each of the 6 input tokens.
outputs.last_hidden_state.size()
torch.Size([1, 6, 768])
For classification tasks, common practice is to use the hidden state associated with the [CLS] token as the input feature. Since this token appears at the start of every sequence, we can extract it simply by indexing into outputs.last_hidden_state as follows.
outputs.last_hidden_state[:,0].size()
torch.Size([1, 768])
Now that we know how to obtain the last hidden state, we wrap the steps so far in a function so that they can be applied to the whole dataset.
We then apply it to the entire dataset and extract the hidden states for all the texts.
def extract_hidden_states(batch):
    # Place model inputs on the GPU
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Extract last hidden states
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return vector for [CLS] token
    return {"hidden_state": last_hidden_state[:,0].cpu().numpy()}

dataset_encoded.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
dataset_hidden = dataset_encoded.map(extract_hidden_states, batched=True, batch_size=16)
2.3.2 Training the Classifier¶
The preprocessed dataset now contains all the information needed to train a classifier.
Specifically, using the hidden states as input features and the labels as targets, we can apply a wide range of classification algorithms.
Here we train a logistic regression model.
import numpy as np
X_train = np.array(dataset_hidden["train"]["hidden_state"])
X_valid = np.array(dataset_hidden["validation"]["hidden_state"])
y_train = np.array(dataset_hidden["train"]["label"])
y_valid = np.array(dataset_hidden["validation"]["label"])
X_train.shape, X_valid.shape
((1839, 768), (324, 768))
from sklearn.linear_model import LogisticRegression
lr_clf = LogisticRegression(max_iter=3000)
lr_clf.fit(X_train, y_train)
lr_clf.score(X_valid, y_valid)
0.5987654320987654
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
def plot_confusion_matrix(y_preds, y_true, labels):
    cm = confusion_matrix(y_true, y_preds, normalize="true")
    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()

y_preds = lr_clf.predict(X_valid)
plot_confusion_matrix(y_preds, y_valid, ["positive","neutral","negative"])
2.3.3 Fine-tuning with AutoModelForSequenceClassification¶
The Transformers library provides task-specific APIs for fine-tuning.
For classification tasks, we use AutoModelForSequenceClassification instead of AutoModel.
AutoModelForSequenceClassification adds a classification head on top of the pre-trained model's outputs, which makes setting up the model much simpler.
from transformers import AutoModelForSequenceClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_labels = 3
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model
DistilBertForSequenceClassification(
(distilbert): DistilBertModel(
(embeddings): Embeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(transformer): Transformer(
(layer): ModuleList(
(0-5): 6 x TransformerBlock(
(attention): MultiHeadSelfAttention(
(dropout): Dropout(p=0.1, inplace=False)
(q_lin): Linear(in_features=768, out_features=768, bias=True)
(k_lin): Linear(in_features=768, out_features=768, bias=True)
(v_lin): Linear(in_features=768, out_features=768, bias=True)
(out_lin): Linear(in_features=768, out_features=768, bias=True)
)
(sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(ffn): FFN(
(dropout): Dropout(p=0.1, inplace=False)
(lin1): Linear(in_features=768, out_features=3072, bias=True)
(lin2): Linear(in_features=3072, out_features=768, bias=True)
(activation): GELUActivation()
)
(output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
)
)
)
(pre_classifier): Linear(in_features=768, out_features=768, bias=True)
(classifier): Linear(in_features=768, out_features=3, bias=True)
(dropout): Dropout(p=0.2, inplace=False)
)
inputs = tokenizer("I purchased these boots to use both for everyday wear and when riding my motorcycle.", return_tensors="pt")  # specify return_tensors="pt" to get PyTorch tensors
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
print(outputs)
2.3.4 Preparing for Training¶
Since we need to provide evaluation metrics during training, we define them as a function ahead of time.
from sklearn.metrics import accuracy_score, f1_score
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}
To make training easier to manage, we use the Trainer API from the Transformers library.
When initializing the Trainer class, we pass a TrainingArguments object, a collection of settings related to training, which gives fine-grained control over the training configuration.
from transformers import TrainingArguments
batch_size = 16
logging_steps = len(dataset_encoded["train"]) // batch_size
model_name = "sample-text-classification-bert"
training_args = TrainingArguments(
    output_dir=model_name,
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    disable_tqdm=False,
    logging_steps=logging_steps,
    push_to_hub=False,
    log_level="error"
)
We then run training with the Trainer class.
Checking the results, we can see that accuracy improves over the feature-based approach (this is confirmed numerically after the confusion matrix below).
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset_encoded["train"],
    eval_dataset=dataset_encoded["validation"],
    tokenizer=tokenizer
)
trainer.train()
TrainOutput(global_step=230, training_loss=0.8717699584753617, metrics={'train_runtime': 7.9795, 'train_samples_per_second': 460.93, 'train_steps_per_second': 28.824, 'total_flos': 74225497893768.0, 'train_loss': 0.8717699584753617, 'epoch': 2.0})
preds_output = trainer.predict(dataset_encoded["test"])
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
y_preds = np.argmax(preds_output.predictions, axis=1)
y_valid = np.array(dataset_encoded["test"]["label"])
labels = dataset_encoded["train"].features["label"].names
def plot_confusion_matrix(y_preds, y_true, labels):
    cm = confusion_matrix(y_true, y_preds, normalize="true")
    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()
plot_confusion_matrix(y_preds, y_valid, labels)
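To check the claimed improvement numerically, one simple option (a small addition, not part of the original notebook) is to print the metrics attached to the prediction output; trainer.predict runs compute_metrics on the test split and prefixes the resulting keys with test_:
print(preds_output.metrics)   # e.g. test_loss, test_accuracy, test_f1 (exact values depend on the run)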
2.4.2 Saving the Model¶
id2label = {}
for i in range(dataset["train"].features["label"].num_classes):
    id2label[i] = dataset["train"].features["label"].int2str(i)

label2id = {}
for i in range(dataset["train"].features["label"].num_classes):
    label2id[dataset["train"].features["label"].int2str(i)] = i

trainer.model.config.id2label = id2label
trainer.model.config.label2id = label2id
trainer.save_model(f"./Data/sample-text-classification-bert")
2.4.3 Loading the Trained Model¶
new_tokenizer = AutoTokenizer\
.from_pretrained(f"./Data/sample-text-classification-bert")
new_model = (AutoModelForSequenceClassification
             .from_pretrained(f"./Data/sample-text-classification-bert")
             .to(device))
We check the inference results on some sample texts.
def id2label(x):
    label_dict = {0:"positive", 1:"neutral", 2:"negative"}
    return label_dict[x]

text1="this week is not going as i had hoped"
text2="awe i love you too!!!! 1 am here i miss you"
inputs = new_tokenizer(text1, return_tensors="pt")
new_model.eval()
with torch.no_grad():
    outputs = new_model(
        inputs["input_ids"].to(device),
        inputs["attention_mask"].to(device),
    )
outputs.logits
y_preds = np.argmax(outputs.logits.to('cpu').detach().numpy().copy(), axis=1)
y_preds = [id2label(x) for x in y_preds]
y_preds
['negative']
inputs = new_tokenizer(text2, return_tensors="pt")
new_model.eval()
with torch.no_grad():
    outputs = new_model(
        inputs["input_ids"].to(device),
        inputs["attention_mask"].to(device),
    )
outputs.logits
y_preds = np.argmax(outputs.logits.to('cpu').detach().numpy().copy(), axis=1)
y_preds = [id2label(x) for x in y_preds]
y_preds
['positive']