set_seed(23)
Simple Correction Data
data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
val_files = ["en/eng_sample/2.txt"]
token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))
2it [00:00, 798.15it/s]
SimpleCorrectionDataset
SimpleCorrectionDataset (data, max_len=10)
*An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__, to speed up batched sample loading. This method accepts a list of indices of samples of a batch and returns a list of samples.
Note: DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.*
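As a rough illustration of that map-style contract, here is a minimal sketch (a toy class, not the actual SimpleCorrectionDataset implementation) that exposes only __len__ and __getitem__:
from torch.utils.data import Dataset

class CharPairDataset(Dataset):
    """Toy map-style dataset yielding (ocr_chars, gs_chars) pairs."""

    def __init__(self, pairs, max_len=10):
        # keep only pairs whose strings fit within max_len characters
        self.pairs = [(o, g) for o, g in pairs if len(o) <= max_len and len(g) <= max_len]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        ocr, gs = self.pairs[idx]
        return list(ocr), list(gs)

toy = CharPairDataset([("test- AAA", "test-.AAA")])
toy[0]  # mirrors the (ocr, gs) character-list sample shown below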
To create a SimpleCorrectionDataset with a maximum token length of 10, do:
dataset = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
The first sample looks like this:
dataset[0]
(['t', 'e', 's', 't', '-', ' ', 'A', 'A', 'A'],
 ['t', 'e', 's', 't', '-', '.', 'A', 'A', 'A'])
collate_fn
collate_fn (text_transform)
Function to collate data samples into batch tensors
collate_fn_with_text_transform
collate_fn_with_text_transform (text_transform, batch)
Function to collate data samples into batch tensors, to be used as partial with instantiated text_transform
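Given that description, collate_fn presumably just binds text_transform onto collate_fn_with_text_transform via functools.partial, so the result has the single-argument signature DataLoader expects. A sketch of that pattern (an assumption about the wiring, not a verified copy of the implementation):
from functools import partial

def collate_fn_sketch(text_transform):
    # returns a callable that takes only `batch`, as DataLoader's collate_fn requires
    return partial(collate_fn_with_text_transform, text_transform)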
text_transform = get_text_transform(vocab_transform)
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=5, collate_fn=collate_fn(text_transform)
)
Training
validate_model
validate_model (model, dataloader, device)
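Conceptually, validate_model runs the model over the validation batches without gradient tracking and returns an average loss. A sketch of that idea (the forward call and the averaging are assumptions; the real function may differ):
def validate_model_sketch(model, dataloader, device):
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for src, tgt in dataloader:
            # assumption: the seq2seq forward returns a loss for a (src, tgt) batch
            loss = model(src.to(device), tgt.to(device))
            total_loss += loss.item()
            num_batches += 1
    return total_loss / max(num_batches, 1)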
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)
val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))
loss = validate_model(model, val_dataloader, device)
loss
25.545663621690537
train_model
train_model (train_dl, val_dl, model=None, optimizer=None, num_epochs=5, valid_niter=5000, model_save_path='model.rar', max_num_patience=5, max_num_trial=5, lr_decay=0.5, device='cpu')
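The valid_niter, max_num_patience, max_num_trial, and lr_decay arguments suggest validation-based early stopping with learning-rate decay: validate every valid_niter iterations, save when the validation loss improves, otherwise increase patience; once patience runs out, decay the learning rate and count a trial; after max_num_trial trials training stops. The decision logic sketched below is an assumption drawn from those parameter names, not the library's actual loop:
def early_stopping_step(val_loss, state, optimizer, lr_decay, max_num_patience, max_num_trial):
    """Hypothetical helper: update patience/trial counters after one validation pass."""
    if val_loss < state["best_val_loss"]:
        state.update(best_val_loss=val_loss, patience=0)
        return "save"  # caller saves model and optimizer to model_save_path
    state["patience"] += 1
    if state["patience"] >= max_num_patience:
        state["patience"] = 0
        state["num_trial"] += 1
        if state["num_trial"] >= max_num_trial:
            return "stop"
        for group in optimizer.param_groups:  # decay the learning rate
            group["lr"] *= lr_decay
        return "decay"  # caller would typically also reload the best checkpoint
    return "wait"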
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)
val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))
hidden_size = 5
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
msp = Path(os.getcwd()) / "data" / "model.rar"
train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
)
Epoch 1, iter 5, avg. train loss 25.21373109817505, avg. val loss 25.264954460991753
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 1, iter 10, avg. train loss 27.308312225341798, avg. val loss 25.19587156507704
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 15, avg. train loss 25.64889602661133, avg. val loss 25.134972466362846
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 20, avg. train loss 26.240159034729004, avg. val loss 25.078634050157333
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 25, avg. train loss 22.31423110961914, avg. val loss 25.014130486382378
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Inference / prediction
https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding
GreedySearchDecoder
GreedySearchDecoder (model)
*Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.*
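Following the linked tutorial, greedy search decoding takes the argmax over the target vocabulary at every step and feeds that token back in as the next decoder input. The sketch below shows the idea for a single sequence (the decoder_step callable and its return values are assumptions; the actual GreedySearchDecoder operates on batched src/tgt tensors, as shown next):
import torch

def greedy_decode_sketch(decoder_step, hidden, sos_idx, max_len):
    """Hypothetical single-sequence greedy decoder."""
    tokens = []
    prev = torch.tensor([sos_idx])
    for _ in range(max_len):
        # decoder_step: (previous token, hidden state) -> (log-probs over target vocab, new hidden state)
        log_probs, hidden = decoder_step(prev, hidden)
        prev = log_probs.argmax(dim=-1)  # greedy choice: most probable token
        tokens.append(int(prev))
    return tokens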
decoder = GreedySearchDecoder(model)
max_len = 10
test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))
with torch.no_grad():
    for i, (src, tgt) in enumerate(test_dataloader):
        predicted_indices = decoder(src, tgt)
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())
torch.Size([1, 5, 5])
tensor([[27, 27, 27, 17, 7, 17, 7, 7, 17, 27, 0],
[18, 27, 27, 27, 27, 27, 17, 17, 27, 17, 0],
[18, 3, 18, 27, 17, 26, 27, 27, 27, 27, 0],
[18, 26, 27, 18, 27, 27, 27, 27, 27, 27, 0],
[ 6, 27, 27, 27, 27, 17, 17, 7, 17, 7, 0]])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
predict_and_convert_to_str
predict_and_convert_to_str (model, dataloader, tgt_vocab, device)
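Presumably this helper greedily decodes every batch in the dataloader and maps the predicted indices back to characters via the target vocabulary before joining them into strings. A sketch under that assumption (the lookup_tokens call and the special-token filtering are guesses, not the verified implementation):
def predict_and_convert_to_str_sketch(model, dataloader, tgt_vocab, device):
    decoder = GreedySearchDecoder(model)
    strings = []
    with torch.no_grad():
        for src, tgt in dataloader:
            predicted_indices = decoder(src.to(device), tgt.to(device))
            for seq in predicted_indices:
                # map indices back to characters; keep single characters, drop special tokens
                chars = tgt_vocab.lookup_tokens([int(i) for i in seq])
                strings.append("".join(c for c in chars if len(c) == 1))
    return strings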
output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
output_strings[0:3]
100%|██████████| 7/7 [00:00<00:00, 352.93it/s]
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
['mmmmmmmmmm', 'Fmmmmmmmmm', 'Fmmmmmmmmm']
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)
test_data["pred"] = output_strings