GPT2 파인튜닝으로 친근한 대화형 챗봇 만들기

목록으로 돌아가기

GPT2 모델

이미지 설명

GPT-2는 “Generative Pre-trained Transformer 2”의 약자로, 트랜스포머 아키텍처를 기반인 생성형 AI 모델이다. 2019년도에 출시 됐기 때문에 현 시점에서 GPT-2보다 좋은 모델은 당연히 존재한다. 하지만 조금… 문제는… 다들 너무 큰 파라미터를 가지기 때문에 개인, 학생들이랑 사용하기란 무리 요소가 매우 많다. 본 글은 위피독스를 참고하여 작성된 글이다.

데이터

이미지 설명

Aihub에서 제공하는 주제별 텍스트 일상 대화 데이터를 전처리하여 학습에 사용하였다.

Aihub 주제별 텍스트 일상 대화 데이터 다운로드


데이터 전처리

각 폴더를 순회하면 txt 파일를 활용하여 csv 데이터 파일을 만들었다.
각 컬럼은 id, Q, A 로 구성된다
여기서는 단순하게 첫번째 열이 질문 다음열이 답변으로 취급한 뒤 저장하였다
본 코드로 생성시 결측치가 생기면 이는 5개 정도이기 때문에 해당 행 삭제후 진행하였다

import os
from datasets import Dataset
from tqdm import tqdm
import time
from datasets import load_dataset
import pandas as pd

data = {"Q": [], "A": [], "id": []}


folder = [
    "./TS_01. KAKAO(1)",
    "./TS_01. KAKAO(2)",
    "./TS_01. KAKAO(3)",
    "./TS_01. KAKAO(4)",
    "./TS_02. FACEBOOK",
    "./TS_03. INSTAGRAM",
    "./TS_04. BAND",
    "./TS_05. NATEON",
]

id_counter = 1
for folder_path in folder:
    for file in tqdm(os.listdir(folder_path)):
        if file.endswith(".txt"): 
            file_path = os.path.join(folder_path, file)
            with open(file_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
                prev_line = ""
                prev_label = None
                for line in lines:
                    line = line.strip()
                    if line:
                        try:
                            label = int(line.split(":")[0])
                            text = line[len(str(label)) + 3 :].strip()
                            if prev_line == "":
                                prev_line = text
                            else:
                                data["id"].append(id_counter)
                                data["Q"].append(prev_line)
                                data["A"].append(text)
                                id_counter += 1 
                                prev_line = ""
                        except ValueError:
                            continue


# 데이터셋 생성
dataset = Dataset.from_dict(data)

# 데이터셋을 CSV 파일로 저장
df = pd.DataFrame(data)
csv_filename = "makedata.csv"
df.to_csv(csv_filename, index=False, encoding="utf-8")

print(dataset[0])



Train 코드

학습은 먼저 skt에서 공개한 koGPT 모델을 활용하여 파인튜닝 할 것이다.
본 모델은 허깅페이스에서 쉽게 불러올 수 있다.

학습을 진행하기전 나의 환경 리스트이다

Package                      Version
---------------------------- ------------
absl-py                      1.4.0
accelerate                   0.22.0
aiofiles                     23.2.1
aiohttp                      3.8.5
aiosignal                    1.3.1
albumentations               1.3.1
alembic                      1.12.0
altair                       5.1.1
anyio                        3.7.1
argon2-cffi                  21.3.0
argon2-cffi-bindings         21.2.0
arrow                        1.2.3
asttokens                    2.2.1
astunparse                   1.6.3
async-timeout                4.0.3
attrs                        23.1.0
autotrain-advanced           0.6.31
backcall                     0.2.0
beautifulsoup4               4.12.2
bitsandbytes                 0.41.1
bleach                       6.0.0
cachetools                   5.3.1
certifi                      2023.7.22
cffi                         1.15.1
charset-normalizer           3.2.0
click                        8.1.7
cmaes                        0.10.0
cmake                        3.25.0
codecarbon                   2.2.3
colorlog                     6.7.0
comm                         0.1.3
contourpy                    1.1.0
cycler                       0.11.0
datasets                     2.14.4
debugpy                      1.6.7
decorator                    5.1.1
defusedxml                   0.7.1
diffusers                    0.20.2
dill                         0.3.7
einops                       0.6.1
evaluate                     0.3.0
exceptiongroup               1.1.2
executing                    1.2.0
fastapi                      0.103.1
fastjsonschema               2.17.1
ffmpy                        0.3.1
filelock                     3.9.0
fire                         0.5.0
flatbuffers                  23.5.26
fonttools                    4.42.1
frozenlist                   1.4.0
fsspec                       2023.9.0
future                       0.18.3
fuzzywuzzy                   0.18.0
gast                         0.4.0
google-auth                  2.22.0
google-auth-oauthlib         1.0.0
google-pasta                 0.2.0
gradio                       3.39.0
gradio_client                0.5.0
graphsurgeon                 0.4.6
greenlet                     2.0.2
grpcio                       1.56.0
h11                          0.14.0
h5py                         3.9.0
httpcore                     0.17.3
httpx                        0.24.1
huggingface-hub              0.16.4
idna                         3.4
imageio                      2.31.3
importlib-metadata           6.8.0
importlib-resources          6.0.0
invisible-watermark          0.2.0
ipadic                       1.0.0
ipykernel                    6.24.0
ipython                      8.12.2
ipython-genutils             0.2.0
ipywidgets                   8.0.7
jax                          0.4.13
jedi                         0.18.2
Jinja2                       3.1.2
jiwer                        3.0.2
joblib                       1.3.1
jsonschema                   4.18.3
jsonschema-specifications    2023.6.1
jupyter                      1.0.0
jupyter_client               8.3.0
jupyter-console              6.6.3
jupyter_core                 5.3.1
jupyter-events               0.6.3
jupyter_server               2.7.0
jupyter_server_terminals     0.4.4
jupyterlab-pygments          0.2.2
jupyterlab-widgets           3.0.8
keras                        2.12.0
kiwisolver                   1.4.5
lazy_loader                  0.3
libclang                     16.0.6
lightning-lite               1.8.0
lightning-utilities          0.9.0
linkify-it-py                2.0.2
lit                          15.0.7
loguru                       0.7.0
Mako                         1.2.4
Markdown                     3.4.3
markdown-it-py               2.2.0
MarkupSafe                   2.1.3
matplotlib                   3.7.2
matplotlib-inline            0.1.6
mdit-py-plugins              0.3.3
mdurl                        0.1.2
mistune                      3.0.1
ml-dtypes                    0.2.0
mpmath                       1.2.1
multidict                    6.0.4
multiprocess                 0.70.15
nbclassic                    1.0.0
nbclient                     0.8.0
nbconvert                    7.7.0
nbformat                     5.9.1
nest-asyncio                 1.5.6
networkx                     3.0
notebook                     6.5.4
notebook_shim                0.2.3
numpy                        1.23.5
oauthlib                     3.2.2
onnx                         1.14.0
onnx-graphsurgeon            0.3.12
opencv-python                4.8.0.76
opencv-python-headless       4.8.0.76
opt-einsum                   3.3.0
optuna                       3.3.0
orjson                       3.9.5
overrides                    7.3.1
packaging                    23.1
pandas                       2.0.3
pandocfilters                1.5.0
parso                        0.8.3
peft                         0.5.0
pickleshare                  0.7.5
Pillow                       10.0.0
pip                          23.2.1
pkgutil_resolve_name         1.3.10
platformdirs                 3.9.1
prometheus-client            0.17.1
prompt-toolkit               3.0.39
protobuf                     4.23.4
psutil                       5.9.5
ptyprocess                   0.7.0
pure-eval                    0.2.2
py-cpuinfo                   9.0.0
pyarrow                      13.0.0
pyasn1                       0.5.0
pyasn1-modules               0.3.0
pycparser                    2.21
pydantic                     1.10.11
pyDeprecate                  0.3.1
pydub                        0.25.1
Pygments                     2.15.1
pynvml                       11.5.0
pyparsing                    3.0.9
python-dateutil              2.8.2
python-json-logger           2.0.7
python-multipart             0.0.6
pytorch-lightning            1.9.0
pytz                         2023.3.post1
PyWavelets                   1.4.1
PyYAML                       6.0.1
pyzmq                        25.1.0
qtconsole                    5.4.3
QtPy                         2.3.1
qudida                       0.0.4
rapidfuzz                    2.13.7
referencing                  0.29.1
regex                        2023.8.8
requests                     2.31.0
requests-oauthlib            1.3.1
responses                    0.18.0
rfc3339-validator            0.1.4
rfc3986-validator            0.1.1
rpds-py                      0.8.11
rsa                          4.9
sacremoses                   0.0.53
safetensors                  0.3.3
scikit-image                 0.21.0
scikit-learn                 1.3.0
scipy                        1.10.1
semantic-version             2.10.0
Send2Trash                   1.8.2
sentencepiece                0.1.99
setuptools                   68.0.0
six                          1.16.0
sniffio                      1.3.0
soupsieve                    2.4.1
SQLAlchemy                   2.0.20
stack-data                   0.6.2
starlette                    0.27.0
sympy                        1.11.1
tensorboard                  2.12.3
tensorboard-data-server      0.7.1
tensorflow                   2.12.0
tensorflow-estimator         2.12.0
tensorflow-io-gcs-filesystem 0.32.0
tensorrt                     8.6.1
tensorrt-dispatch            8.6.1
tensorrt-lean                8.6.1
termcolor                    2.3.0
terminado                    0.17.1
threadpoolctl                3.2.0
tifffile                     2023.7.10
tiktoken                     0.4.0
tinycss2                     1.2.1
tokenizers                   0.13.3
toolz                        0.12.0
torch                        2.0.1+cu118
torchaudio                   2.0.2+cu118
torchmetrics                 1.1.1
torchvision                  0.15.2+cu118
tornado                      6.3.2
tqdm                         4.65.0
traitlets                    5.9.0
transformers                 4.33.0
triton                       2.0.0
trl                          0.7.1
typing_extensions            4.7.1
tzdata                       2023.3
uc-micro-py                  1.0.2
uff                          0.6.9
urllib3                      1.26.16
uvicorn                      0.23.2
wcwidth                      0.2.6
webencodings                 0.5.1
websocket-client             1.6.1
websockets                   11.0.3
Werkzeug                     2.3.6
wheel                        0.38.4
widgetsnbextension           4.0.8
wrapt                        1.14.1
xgboost                      1.7.6
xxhash                       3.3.0
yarl                         1.9.2
zipp                         3.16.2

아래는 훈련코드이다. 최종적으로 학습에 활용되는 인풋은 아래와 같다.

(Q토큰) + 질문데이터 + (S토큰) + (A토큰) + 답변데이터 + (E토큰)

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from torch.utils.data import DataLoader, Dataset
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel
import re, os
from tqdm import tqdm


Q_TKN = "<usr>"
A_TKN = "<sys>"
BOS = "</s>"
EOS = "</s>"
MASK = "<unused0>"
SENT = "<unused1>"
PAD = "<pad>"

save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)


print("start1")


class ChatbotDataset(Dataset):
    def __init__(self, chats, max_len=40):  # 데이터셋의 전처리를 해주는 부분
        self._data = chats
        self.max_len = max_len
        self.q_token = Q_TKN
        self.a_token = A_TKN
        self.sent_token = SENT
        self.eos = EOS
        self.mask = MASK
        self.tokenizer = koGPT2_TOKENIZER

    def __len__(self):  # chatbotdata 의 길이를 리턴한다.
        return len(self._data)

    def __getitem__(self, idx):  # 로드한 챗봇 데이터를 차례차례 DataLoader로 넘겨주는 메서드
        turn = self._data.iloc[idx]
        q = turn["Q"]  # 질문을 가져온다.
        q = re.sub(r"([?.!,])", r" ", q)  # 구둣점들을 제거한다.

        a = turn["A"]  # 답변을 가져온다.
        a = re.sub(r"([?.!,])", r" ", a)  # 구둣점들을 제거한다.

        q_toked = self.tokenizer.tokenize(self.q_token + q + self.sent_token)
        q_len = len(q_toked)

        a_toked = self.tokenizer.tokenize(self.a_token + a + self.eos)
        a_len = len(a_toked)

        # 질문의 길이가 최대길이보다 크면
        if q_len > self.max_len:
            a_len = self.max_len - q_len  # 답변의 길이를 최대길이 - 질문길이
            if a_len <= 0:  # 질문의 길이가 너무 길어 질문만으로 최대 길이를 초과 한다면
                q_toked = q_toked[-(int(self.max_len / 2)) :]  # 질문길이를 최대길이의 반으로
                q_len = len(q_toked)
                a_len = self.max_len - q_len  # 답변의 길이를 최대길이 - 질문길이
            a_toked = a_toked[:a_len]
            a_len = len(a_toked)

        # 질문의 길이 + 답변의 길이가 최대길이보다 크면
        if q_len + a_len > self.max_len:
            a_len = self.max_len - q_len  # 답변의 길이를 최대길이 - 질문길이
            if a_len <= 0:  # 질문의 길이가 너무 길어 질문만으로 최대 길이를 초과 한다면
                q_toked = q_toked[-(int(self.max_len / 2)) :]  # 질문길이를 최대길이의 반으로
                q_len = len(q_toked)
                a_len = self.max_len - q_len  # 답변의 길이를 최대길이 - 질문길이
            a_toked = a_toked[:a_len]
            a_len = len(a_toked)

        # 답변 labels = [mask, mask, ...., mask, ..., <bos>,..답변.. <eos>, <pad>....]
        labels = [
            self.mask,
        ] * q_len + a_toked[1:]

        # mask = 질문길이 0 + 답변길이 1 + 나머지 0
        mask = [0] * q_len + [1] * a_len + [0] * (self.max_len - q_len - a_len)
        # 답변 labels을 index 로 만든다.
        labels_ids = self.tokenizer.convert_tokens_to_ids(labels)
        # 최대길이만큼 PADDING
        while len(labels_ids) < self.max_len:
            labels_ids += [self.tokenizer.pad_token_id]

        # 질문 + 답변을 index 로 만든다.
        token_ids = self.tokenizer.convert_tokens_to_ids(q_toked + a_toked)
        # 최대길이만큼 PADDING
        while len(token_ids) < self.max_len:
            token_ids += [self.tokenizer.pad_token_id]

        # 질문+답변, 마스크, 답변
        return (token_ids, np.array(mask), labels_ids)


def collate_batch(batch):
    data = [item[0] for item in batch]
    mask = [item[1] for item in batch]
    label = [item[2] for item in batch]
    return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)


koGPT2_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(
    "skt/kogpt2-base-v2",
    bos_token=BOS,
    eos_token=EOS,
    unk_token="<unk>",
    pad_token=PAD,
    mask_token=MASK,
)
model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")

dataname = "totaldata.csv"
Chatbot_Data = pd.read_csv("./"+ dataname)
Chatbot_Data.dropna(subset=["A"], inplace=True)
Chatbot_Data.dropna(subset=["Q"], inplace=True)
Chatbot_Data.dropna(subset=["id"], inplace=True)


print("start3")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = ChatbotDataset(Chatbot_Data, max_len=40)
train_dataloader = DataLoader(
    train_set,
    batch_size=32,
    num_workers=0,
    shuffle=True,
    collate_fn=collate_batch,
)

model.to(device)

learning_rate = 3e-5
criterion = torch.nn.CrossEntropyLoss(reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

epoch = 10
Sneg = -1e18


for epoch in range(epoch):
    dataloader = tqdm(train_dataloader, desc=f"Epoch {epoch}")
    for batch_idx, samples in enumerate(dataloader):
        optimizer.zero_grad()
        token_ids, mask, label = samples
        token_ids, mask, label = token_ids.to(device), mask.to(device), label.to(device)
        out = model(token_ids)
        out = out.logits
        mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
        mask_out = torch.where(mask_3d == 1, out, Sneg * torch.ones_like(out))
        loss = criterion(mask_out.transpose(2, 1), label)
        avg_loss = loss.sum() / mask.sum()
        avg_loss.backward()
        optimizer.step()


model_save_path = os.path.join(save_dir, "chatbot_model.pth")
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    },
    model_save_path,
)

print("Model saved at:", model_save_path)



Predict 코드

예측할 때는 저장된 모델을 먼저 불러온다. 이전에 불러온 사전 훈련된 모델 model에 대해, checkpoint에서 불러온 모델 가중치를 복원한 뒤 예측을 진행하였다. 예측할 때는 아래와 같은 인풋 구조를 가진다.

(Q토큰) + 질문데이터 + (S토큰) + (A토큰)

이렇게 해서 질문 기반 다음 답변을 생각하다가 (E 토큰)을 만나면 생성을 종료하는 구조이다.

import torch
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel

# Define your special tokens and tokenizer
Q_TKN = "<usr>"
A_TKN = "<sys>"
SENT = "<unused1>"
EOS = "</s>"
BOS = "</s>"

# Initialize your tokenizer
koGPT2_TOKENIZER = PreTrainedTokenizerFast.from_pretrained(
    "skt/kogpt2-base-v2",
    bos_token=BOS,
    eos_token=EOS,
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<unused0>",
)

# Load your trained model
model_path = "saved_models/chatbot_model.pth"  # Adjust the path to your saved model
model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Interaction loop with the chatbot
with torch.no_grad():
    while 1:
        q = input("user > ").strip()
        if q == "quit":
            break
        a = ""
        while 1:
            input_ids = torch.LongTensor(
                koGPT2_TOKENIZER.encode(Q_TKN + q + SENT + A_TKN + a)
            ).unsqueeze(dim=0)
            pred = model(input_ids)
            pred = pred.logits
            gen = koGPT2_TOKENIZER.convert_ids_to_tokens(
                torch.argmax(pred, dim=-1).squeeze().numpy().tolist()
            )[-1]
            if gen == EOS:
                break
            a += gen.replace("▁", " ")
        print("Chatbot > {}".format(a.strip()))



성능

이미지 설명

음…. 애매하다… 먼가 잘 되는 것 같긴 하지만 정해진 답변을 하는 느낌이 강하다. 데이터가 70만건이라 대화를 능숙하게 하기에는 매우 적은 데이터이긴 하다. 그래도 매뉴얼적인 답변을 하는 테스크를 원한다면 괜찮은 성능일거 같다. 역시 openAI에 돈주고 서비스 이용하는 것이 고성능적인 면에서는 괜찮을거 같다. 끝!

author-profile
Written by 유찬영

댓글