본문 바로가기

Engineering

GPT로 날짜 만들어보기

사용한 패키지는 torch와 transformers만 사용했다.

python3 -m venv venv
source venv/bin/activate
pip install torch transformers

GPT2로 날짜를 생성해보는 학습을 한번 진행해봤다.

1. 1990년도부터 2030년도에 해당하는 날짜를 모두 만들어 학습 데이터를 생성한다.

import calendar
from datetime import datetime

texts = []
for year in range(1990, 2030):
    for month in range(1, 13):
        for day in range(1, calendar.monthrange(year, month)[1]+1):
            texts.append(datetime(year, month, day).strftime("%Y년 %m월 %d일"))
            texts.append(datetime(year, month, day).strftime("%Y-%m-%d"))

2. 이 학습 데이터를 기반으로 학습 데이터로더를 구성한다.

학습 배치를 구성할 때, 텍스트를 토크나이징하고 패딩을 구성한다.

패딩에 해당하는 부분에서 손실값을 계산하지 않게 하기 위해서 패딩 부분에 ignore_index로 -100을 부여한다.

from transformers import GPT2Tokenizer
from torch.utils.data import DataLoader

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

def collate_fn(examples):
    encodings = tokenizer(
        [tokenizer.bos_token+example+tokenizer.eos_token for example in examples],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encodings["labels"] = torch.where(
    	encodings["attention_mask"].bool(),
        encodings["input_ids"],
        -100,
    )
    return encodings
    
dataloader = DataLoader(texts, shuffle=True, batch_size=128, collate_fn=collate_fn)

3. 생성된 텍스트가 날짜 형식을 갖추는지 평가하기 위한 함수를 만든다.

def check(text):
    flag = False
    try:
        datetime.fromisoformat(text)
        flag = True
    except ValueError:
        pass
    try:
        datetime.strptime(text, "%Y년 %m월 %d일")
        flag = True
    except ValueError:
        pass
    return flag

4. 모델을 만들고 학습을 진행한다.

from transformers import GPT2LMHeadModel
import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.0)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=5*len(dataloader)
)
step = 0
for epoch in range(5):
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad(set_to_none=True)
        output = model(**batch)
        output.loss.backward()
        optimizer.step()
        scheduler.step()
        step += 1
        if step % 10 == 0:
            print(
                f"step=[{step}/{5*len(dataloader)}], "
                f"loss={output.loss.item():.4f}, "
                f"lr={optimizer.param_groups[0]['lr']:.6f}"
            )
            model.eval()
            generated_ids = model.generate(
                do_sample=True,
                max_new_tokens=32,
                num_return_sequences=8,
                pad_token_id=tokenizer.eos_token_id,
            )
            generated_texts = tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )
            print(f"acc={sum(check(x) for x in generated_texts)/8:.2f}")
            model.train()

8개를 생성했을 때, 1개를 제외하고 다 맞는 결과를 얻을 수 있었다..!

마지막에 생성된 문장은 '2002년 06월 20일' 이었다.