jiku log

JTCのデータサイエンス中間管理職の学び

「大規模言語モデル入門」を読む〜第5章〜

第5章 大規模言語モデルのファインチューニング

本章では,Transformerをもちいて自然言語の様々なタスクを解くための方法を,具体的なコードとともに紹介している。

5.1 日本語ベンチマーク : JGLUE

機械学習のモデル構築には,データセットが必要になる。アルゴリズムの精度向上を目指すためには,モデル性能が評価できるようにベンチマークを準備することが重要である。

大規模言語モデルベンチマークの先駆けは,GLUE(General Language Understanding Evaluation)である。
JGLUEは,言語モデルの日本語理解能力を評価するためのベンチマークである。JGLUEに収録されているタスクおよびデータセットは,下表の通りである。

タスク データセット
文書分類 MARC-ja
JCoLA
文ペア関係予測 JSTS
JNLI
質問応答 JSQuAD
JCommonsenseQA
文書分類

MARC-Jaは,Amazonの商品レビューについて,5段階評価が1・2の文章を否定的("negative"),4・5の文章を肯定的("positive"),3の文章を除外して作成したデータセットである。タスクとしては2値分類になる。

JCoLAは文章が文法的に正しいか・間違っているというアノテーションがなされており,タスクとしては2値分類になる。

文ペア関係予測

JSTSは,意味的類似度計算のデータセットであり,文ペアの意味的類似度が0(意味が完全に異なる)から5(意味が完全に同じ)というスコアが付与されている。

JNLIは,自然言語推論用のデータセットであり,含意(entailment),矛盾(contradiction),中立(neutral)の3種類のラベルが付与されている。

質問応答

JSQuADは,質問文とその答えを含む可能性があるパッセージが与えられている。タスクとしては,パッセージの中から質問文の回答箇所を抜き出すというものになっている。

JCommonsenseQAは,常識推論能力を評価するためのデータセットで,クラウドワーカーが作成しているものである。

5.2 感情分析モデルの実装

この節では,MARC-jaデータセットとTransformerを用いて感情分析モデルの実装を行なう方法を説明している(評価方法については,次節で説明している)。

データの準備や前処理の方法について,コードを交えて説明しているが,特に興味深かったのは以下の2点である。

  • ミニバッチ学習 : Transformerを用いてテキストのような可変長の系列データを学習する際に,長さを揃えるためにダミートークンを加える(パディング)。
  • 学習率スケジューラ : 学習のはじめのころは学習率を徐々に大きくし(ウォームアップ),その後は学習率を小さくする。

※註:はじめのうちは学習率を大きくすることで,局所解から抜けやすくしているものと想定している。

5.3 感情分析モデルのエラー分析

5.2節で開発したモデルについて混同行列を用いて評価を行ない,精度を高めるための工夫として個別の事例を確認することについて説明している。

精度を高めるうえでの注意点として,モデルのショートカットについて説明している。これは特定のフレーズ,たとえば「☆4つ」といったフレーズに影響を受けてラベルが決まってしまう,という現象である。対策として,このようなフレーズを除外することが挙げられていた。

5.4 自然言語推論・意味的類似度計算・多肢選択式応答モデルの実装

本節では,これらのタスクに関するコード例を紹介している。

5.5 メモリ効率の良いファインチューニング

大規模言語モデルを用いた学習を行なう際には,大きなメモリ容量のGPUが必要になるが,このようなハードウェアは必ずしも簡単に用意できるわけではない。
本節では,メモリを節約して学習行なうためのテクニックについて説明している。

自動混合精度計算は,単精度浮動小数点数(FP32)半精度浮動小数点数(FP16)を組合わせる方法である。
前向き計算や誤差逆伝搬などはFP16を用いて,数値の桁数を必要とするパラメータ更新の際にはFP32を用いるというものである。

メモリを節約するための方法として,他には損失スケーリング勾配チェックポインティングLoRAチューニングなどの方法が説明されていた。

まとめと感想

感想 : 精度向上のための手順

コードを実装して動かしてみて,いざ精度検証をしてみると,精度が高くないことがある。
精度を下げる原因としてモデルのショートカットが説明されており,原因を探るための観点として参考になった。
ただ,実際にモデルのショートカットを引き起こす原因(トークン)を探るためには,SHAPなどを使うのかもしれない。


本記事を最後まで読んでくださり,どうもありがとうございました。