hp-dial/src/train.py

from torchtext.data import Field, BucketIterator, TabularDataset
import torch
from torchtext import data
from model import Seq2Seq, Encoder, Decoder, Attention
import math
import time
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from spacy.tokenizer import Tokenizer
from spacy.lang.eu import Basque
nlp = Basque()


def tokenizer(s):
    # Keep only the surface forms produced by the spaCy Basque pipeline
    return list(map(lambda x: x.text, nlp(s)))


text_field = Field(
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    tokenize=tokenizer,
    tokenizer_language="eu",
)
fields = [("query", text_field), ("answer", text_field)]
train_data = TabularDataset(
    path="../data/eu_train.tsv", format="tsv", fields=fields
)
text_field.build_vocab(train_data, min_freq=5)
print("Vocabulary has been built")
print("Vocab len is {}".format(len(text_field.vocab)))
# Save the text field for testing
torch.save(text_field, "../model/text_field.Field")
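# At inference time this field is presumably reloaded with
# torch.load("../model/text_field.Field") so that the same vocabulary and
# tokenizer are reused (the corresponding test script is not shown here).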
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
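# Using interleave_keys as sort_key groups examples with similar query and
# answer lengths together, so each batch needs as little padding as possible.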
train_iterator = BucketIterator(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    sort_key=lambda x: data.interleave_keys(len(x.query), len(x.answer)),
    device=device,
)

# Adjust the sizes to your needs
INPUT_DIM = len(text_field.vocab)
OUTPUT_DIM = len(text_field.vocab)
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
ENC_HID_DIM = 256
DEC_HID_DIM = 256
ATTN_DIM = 32
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
attn = Attention(ENC_HID_DIM, DEC_HID_DIM, ATTN_DIM)
dec = Decoder(
    OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn
)
model = Seq2Seq(enc, dec, device).to(device)


def init_weights(m: nn.Module):
    # Initialise all weights from N(0, 0.01) and all biases to zero
    for name, param in m.named_parameters():
        if "weight" in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)
optimizer = optim.Adam(model.parameters())

def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"The model has {count_parameters(model):,} trainable parameters")
PAD_IDX = text_field.vocab.stoi["<pad>"]


def bp(c, r):
    # BLEU-style brevity penalty: 1 when the candidate is longer than the
    # reference, exp(1 - r / c) otherwise
    if c > r:
        res = 1
    else:
        res = math.exp(1 - r / c)
    return res


def normalDistribution(mu, sigma):
    return lambda x: math.exp(-0.5 * math.pow((x - mu) / sigma, 2)) / (
        sigma * math.sqrt(2 * math.pi)
    )


def rayleighDistribution(sigma2):
    return lambda x: x / sigma2 * math.exp(-math.pow(x, 2) / (2 * sigma2))


def criterion(x, y):
    # Cross-entropy scaled by a Rayleigh-shaped factor of the brevity penalty;
    # 0.1356243 is approximately the peak of the Rayleigh pdf with
    # sigma^2 = 20, so the scaled factor tops out at 1
    f = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    ray = lambda x: rayleighDistribution(20)(x * 5.8 + 0.2)
    return (1 - ray(bp(x.shape[0], y.shape[0])) / 0.1356243) * f(x, y)
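

# Optional sanity check (illustrative sketch, assumes the vocabulary built
# above): the custom criterion should return a finite scalar for a small
# dummy batch of logits and targets.
_dummy_logits = torch.randn(8, len(text_field.vocab))
_dummy_targets = torch.randint(0, len(text_field.vocab), (8,))
assert torch.isfinite(criterion(_dummy_logits, _dummy_targets))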


def train(
    model: nn.Module,
    iterator: BucketIterator,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    clip: float,
):
    model.train()
    epoch_loss = 0
    for _, batch in tqdm(enumerate(iterator), total=len(iterator)):
        src = batch.query
        trg = batch.answer
        optimizer.zero_grad()
        output = model(src, trg)
        # Drop the <sos> position and flatten predictions/targets for the loss
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)


def epoch_time(start_time: int, end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


N_EPOCHS = 10
CLIP = 1
best_valid_loss = float("inf")

for epoch in tqdm(range(N_EPOCHS)):
    start_time = time.time()
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f"Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s")
    print(
        f"\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}"
    )
    # Save checkpoint
    torch.save(model, "../model/model.pt")
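
# Rough inference sketch (illustrative only): the artifacts saved above can be
# reloaded elsewhere along these lines, e.g.
#     field = torch.load("../model/text_field.Field")
#     net = torch.load("../model/model.pt", map_location=device)
#     net.eval()
# before tokenising and numericalising queries with the restored field.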