Simple Correction Data

Functionality for error correction with the simple dataset
set_seed(23)
data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
val_files = ["en/eng_sample/2.txt"]

token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))
2it [00:00, 798.15it/s]

SimpleCorrectionDataset

 SimpleCorrectionDataset (data, max_len=10)

A map-style PyTorch Dataset of OCR tokens paired with their gold-standard (GS) versions. Only tokens whose OCR and GS strings are both at most max_len characters long are kept. Each sample is a pair of character lists: (OCR characters, GS characters).

To create a SimpleCorrectionDataset with a maximum token length of 10, do:

dataset = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)

The first sample looks like this:

dataset[0]
(['t', 'e', 's', 't', '-', ' ', 'A', 'A', 'A'],
 ['t', 'e', 's', 't', '-', '.', 'A', 'A', 'A'])
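torch.utils.data.DataLoader only needs __len__ and __getitem__ from a map-style dataset. Below is a minimal, framework-free sketch of a dataset that behaves like the sample above. It assumes each row has hypothetical "ocr" and "gs" string fields and uses plain dicts instead of a pandas DataFrame. The class name CharPairDataset is illustrative and not part of the library.

```python
from typing import Iterable, List, Mapping, Tuple


class CharPairDataset:
    """Hypothetical map-style dataset of (OCR, gold-standard) character lists.

    Only pairs whose OCR and gold-standard strings are both at most
    max_len characters long are kept.
    """

    def __init__(self, rows: Iterable[Mapping[str, str]], max_len: int = 10):
        self.pairs = [
            (row["ocr"], row["gs"])
            for row in rows
            if len(row["ocr"]) <= max_len and len(row["gs"]) <= max_len
        ]

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[List[str], List[str]]:
        ocr, gs = self.pairs[idx]
        # Split each string into a list of characters, as in dataset[0] above.
        return list(ocr), list(gs)
```

Filtering happens once, in the constructor, so __len__ and index lookups stay consistent with each other.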

collate_fn

 collate_fn (text_transform)

Function to collate data samples into batch tensors


collate_fn_with_text_transform

 collate_fn_with_text_transform (text_transform, batch)

Function to collate data samples into batch tensors; meant to be used via functools.partial with an instantiated text_transform.

text_transform = get_text_transform(vocab_transform)
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=5, collate_fn=collate_fn(text_transform)
)
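In the snippet above, collate_fn(text_transform) is handed to DataLoader as a one-argument callable that receives a list of dataset samples. The following framework-free sketch illustrates that partial-application pattern. It makes two hypothetical assumptions: text_transform is a dict of per-side functions keyed "ocr" and "gs" (mirroring vocab_transform), and index 1 is padding. The real functions return batch tensors, where torch.nn.utils.rnn.pad_sequence would typically do the padding; this sketch pads plain lists instead.

```python
from functools import partial
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical padding index; the real one depends on vocab_transform.
PAD_IDX = 1

Transform = Callable[[List[str]], List[int]]


def pad_batch(seqs: Sequence[List[int]], pad_idx: int = PAD_IDX) -> List[List[int]]:
    """Right-pad every index sequence to the length of the longest one."""
    longest = max(len(s) for s in seqs)
    return [list(s) + [pad_idx] * (longest - len(s)) for s in seqs]


def collate_with_transform(
    text_transform: Dict[str, Transform],
    batch: Sequence[Tuple[List[str], List[str]]],
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical stand-in for collate_fn_with_text_transform."""
    src = [text_transform["ocr"](ocr) for ocr, _ in batch]
    tgt = [text_transform["gs"](gs) for _, gs in batch]
    return pad_batch(src), pad_batch(tgt)


def make_collate_fn(text_transform: Dict[str, Transform]):
    """Hypothetical stand-in for collate_fn: fix text_transform, leave batch open."""
    return partial(collate_with_transform, text_transform)
```

Binding text_transform with partial keeps the collate function picklable, which matters when DataLoader uses worker processes; a lambda would not be.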

Training


validate_model

 validate_model (model, dataloader, device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
model.to(device)  # keep parameters on the same device as the batches

encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)
val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

loss = validate_model(model, val_dataloader, device)
loss
25.545663621690537
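SimpleCorrectionSeq2seq receives a teacher_forcing_ratio (0.5 above, 0.0 in the training example below). Conventionally, this is the probability that the decoder is fed the gold token, rather than its own previous prediction, as its next input. The sketch below is a framework-free illustration of that convention, not the library's implementation. The hypothetical step_fn stands in for one decoder step that maps the previous token to a prediction. Implementations differ on whether the coin flip happens once per batch or at every step; this sketch flips at every step.

```python
import random
from typing import Callable, List, Sequence


def decode_with_teacher_forcing(
    step_fn: Callable[[int], int],
    start_token: int,
    target: Sequence[int],
    teacher_forcing_ratio: float,
    rng: random.Random,
) -> List[int]:
    """Run one decoding step per target token and collect the predictions.

    After each step, the next decoder input is the gold token with probability
    teacher_forcing_ratio, and the model's own prediction otherwise.
    """
    prev = start_token
    predictions = []
    for gold in target:
        pred = step_fn(prev)
        predictions.append(pred)
        prev = gold if rng.random() < teacher_forcing_ratio else pred
    return predictions
```

With a ratio of 1.0, every step is conditioned on the gold prefix. With 0.0, as in the training example below, the decoder always consumes its own output, which matches how it runs at inference time.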

train_model

 train_model (train_dl, val_dl, model=None, optimizer=None, num_epochs=5,
              valid_niter=5000, model_save_path='model.rar',
              max_num_patience=5, max_num_trial=5, lr_decay=0.5,
              device='cpu')
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 5
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
    device=device,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model.rar"

train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
    device=device,
)
Epoch 1, iter 5, avg. train loss 25.21373109817505, avg. val loss 25.264954460991753
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 1, iter 10, avg. train loss 27.308312225341798, avg. val loss 25.19587156507704
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 15, avg. train loss 25.64889602661133, avg. val loss 25.134972466362846
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 20, avg. train loss 26.240159034729004, avg. val loss 25.078634050157333
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 25, avg. train loss 22.31423110961914, avg. val loss 25.014130486382378
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
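In the log above, validation runs every valid_niter=5 iterations and the model is saved whenever the validation loss improves. The parameter names max_num_patience, max_num_trial and lr_decay suggest a common patience-based scheme, though that is an assumption about train_model, not something this page confirms. In that scheme, each validation without improvement uses up one unit of patience. When patience runs out, a "trial" is counted and the learning rate is multiplied by lr_decay. Training stops once max_num_trial trials are used. The PatienceSchedule class below is a hypothetical, framework-free sketch of that bookkeeping.

```python
from dataclasses import dataclass


@dataclass
class PatienceSchedule:
    """Hypothetical sketch of patience-based early stopping with LR decay."""

    lr: float
    max_num_patience: int = 5
    max_num_trial: int = 5
    lr_decay: float = 0.5
    best_val_loss: float = float("inf")
    patience: int = 0
    num_trial: int = 0

    def step(self, val_loss: float) -> str:
        """Process one validation result and return the action to take."""
        if val_loss < self.best_val_loss:
            # Improvement: remember it; a real loop would save a checkpoint here.
            self.best_val_loss = val_loss
            self.patience = 0
            return "save"
        self.patience += 1
        if self.patience < self.max_num_patience:
            return "wait"
        # Patience exhausted: count a trial, then either stop or decay the LR.
        self.num_trial += 1
        self.patience = 0
        if self.num_trial >= self.max_num_trial:
            return "stop"
        self.lr *= self.lr_decay
        return "decay"
```

In a real loop, a "decay" action would typically also reload the best checkpoint and write the new rate into each entry of optimizer.param_groups.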

Inference / prediction

Greedy decoding is based on the PyTorch chatbot tutorial: https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding


GreedySearchDecoder

 GreedySearchDecoder (model)

A torch.nn.Module that wraps a trained SimpleCorrectionSeq2seq model and decodes greedily: at every step it picks the highest-scoring output token and feeds it back as the next decoder input. Called on a batch (src, tgt), it returns a tensor of predicted token indices with one row per sample.
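In the chatbot tutorial linked above, greedy decoding runs the decoder for a fixed number of steps from a start token and feeds the argmax token back as the next input. The function below is a framework-free sketch of that loop. The hypothetical step_fn stands in for one decoder step that returns per-token scores and an updated hidden state; it is not GreedySearchDecoder's actual code.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

S = TypeVar("S")


def greedy_decode(
    step_fn: Callable[[int, S], Tuple[Sequence[float], S]],
    start_token: int,
    init_state: S,
    max_len: int,
) -> List[int]:
    """Decode for exactly max_len steps, always taking the highest-scoring token."""
    prev, state = start_token, init_state
    tokens: List[int] = []
    for _ in range(max_len):
        scores, state = step_fn(prev, state)
        # Greedy choice: index of the largest score (the first one on ties).
        prev = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(prev)
    return tokens
```

A tensor version would do the same for a whole batch at once, taking torch.max over the vocabulary dimension at each step.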

decoder = GreedySearchDecoder(model)

max_len = 10

test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))

with torch.no_grad():
    for i, (src, tgt) in enumerate(test_dataloader):
        predicted_indices = decoder(src.to(device), tgt.to(device))
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())
torch.Size([1, 5, 5])
tensor([[27, 27, 27, 17,  7, 17,  7,  7, 17, 27,  0],
        [18, 27, 27, 27, 27, 27, 17, 17, 27, 17,  0],
        [18,  3, 18, 27, 17, 26, 27, 27, 27, 27,  0],
        [18, 26, 27, 18, 27, 27, 27, 27, 27, 27,  0],
        [ 6, 27, 27, 27, 27, 17, 17,  7, 17,  7,  0]])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])

predict_and_convert_to_str

 predict_and_convert_to_str (model, dataloader, tgt_vocab, device)
output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
output_strings[0:3]
100%|██████████| 7/7 [00:00<00:00, 352.93it/s]
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
['mmmmmmmmmm', 'Fmmmmmmmmm', 'Fmmmmmmmmm']
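predict_and_convert_to_str turns the decoder's index tensors into strings using tgt_vocab (here vocab_transform["gs"]). The function below sketches that final conversion. It assumes a hypothetical index-to-character list itos; if tgt_vocab is a torchtext Vocab, get_itos() returns such a list. The special-token indices in the usage check are chosen only for illustration.

```python
from typing import Collection, Iterable, List


def indices_to_string(
    indices: Iterable[int],
    itos: List[str],
    eos_idx: int,
    skip_idxs: Collection[int] = (),
) -> str:
    """Map predicted indices back to characters and join them into one string."""
    chars = []
    for idx in indices:
        if idx == eos_idx:
            break  # ignore everything after end-of-sequence
        if idx in skip_idxs:
            continue  # e.g. beginning-of-sequence or padding
        chars.append(itos[idx])
    return "".join(chars)
```

Applied to each row of the decoder output, this yields one string per test sample, which is what gets attached to test_data below.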
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)

test_data["pred"] = output_strings