set_seed(23)
Simple Correction Data
data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
val_files = ["en/eng_sample/2.txt"]
token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))
2it [00:00, 798.15it/s]
SimpleCorrectionDataset
SimpleCorrectionDataset (data, max_len=10)
*An abstract class representing a :class:Dataset
.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__
, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__
, which is expected to return the size of the dataset by many :class:~torch.utils.data.Sampler
implementations and the default options of :class:~torch.utils.data.DataLoader
.
.. note:: :class:~torch.utils.data.DataLoader
by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.*
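For orientation, a map-style dataset only needs __len__ and __getitem__. The sketch below is a hypothetical minimal variant along the lines of SimpleCorrectionDataset: it filters tokens on max_len and returns each sample as a pair of character lists (the ocr/gs column names and the len_ocr/len_gs filter are assumptions based on the rest of this page, not the library source).

from torch.utils.data import Dataset

class CharPairDataset(Dataset):
    """Hypothetical map-style dataset of (OCR characters, gold standard characters) pairs."""

    def __init__(self, data, max_len=10):
        # Keep only tokens that fit within max_len (column names are assumptions)
        self.ds = data.query(f"len_ocr <= {max_len} and len_gs <= {max_len}").reset_index(drop=True)
        self.max_len = max_len

    def __len__(self):
        return self.ds.shape[0]

    def __getitem__(self, idx):
        sample = self.ds.loc[idx]
        return list(sample.ocr), list(sample.gs)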
To create a SimpleCorrectionDataset
with a maximum token length of 10, do:
dataset = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
The first sample looks like this:
dataset[0]
(['t', 'e', 's', 't', '-', ' ', 'A', 'A', 'A'],
['t', 'e', 's', 't', '-', '.', 'A', 'A', 'A'])
collate_fn
collate_fn (text_transform)
Function to collate data samples into batch tensors
collate_fn_with_text_transform
collate_fn_with_text_transform (text_transform, batch)
Function to collate data samples into batch tensors, to be used as partial with instantiated text_transform
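In other words, collate_fn(text_transform) presumably just binds the transform with functools.partial so the DataLoader gets the one-argument collate callable it expects. A minimal sketch of that relationship, assuming collate_fn_with_text_transform is in scope (the helper name make_collate is hypothetical):

from functools import partial

def make_collate(text_transform):
    # Bind the instantiated text transform; the result takes only a batch,
    # which is what DataLoader(collate_fn=...) calls it with.
    return partial(collate_fn_with_text_transform, text_transform)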
text_transform = get_text_transform(vocab_transform)

train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=5, collate_fn=collate_fn(text_transform)
)
Training
validate_model
validate_model (model, dataloader, device)
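The example below assumes validate_model runs the model over the validation loader without gradient tracking and returns an average loss. A hedged stand-in sketch (the name run_validation is hypothetical, and whether the model's forward returns the loss directly is an assumption):

import torch

def run_validation(model, dataloader, device):
    # Hypothetical stand-in: average the per-batch loss over the validation set
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for src, tgt in dataloader:
            loss = model(src.to(device), tgt.to(device))
            total += loss.item()
            n_batches += 1
    model.train()
    return total / max(n_batches, 1)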
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

loss = validate_model(model, val_dataloader, device)
loss
25.545663621690537
train_model
train_model (train_dl, val_dl, model=None, optimizer=None, num_epochs=5, valid_niter=5000, model_save_path='model.rar', max_num_patience=5, max_num_trial=5, lr_decay=0.5, device='cpu')
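The signature suggests periodic validation every valid_niter iterations, checkpointing to model_save_path, and patience-based learning-rate decay with early stopping. A hedged sketch of how max_num_patience, max_num_trial, and lr_decay commonly interact (the exact bookkeeping inside train_model may differ):

def after_validation(val_loss, state, optimizer, lr_decay=0.5,
                     max_num_patience=5, max_num_trial=5):
    """Hypothetical early-stopping bookkeeping; returns True when training should stop."""
    if val_loss < state["best"]:
        state["best"], state["patience"] = val_loss, 0  # improvement: save a checkpoint here
        return False
    state["patience"] += 1
    if state["patience"] < max_num_patience:
        return False
    # Patience exhausted: decay the learning rate and count a trial
    state["trial"] += 1
    state["patience"] = 0
    for group in optimizer.param_groups:
        group["lr"] *= lr_decay
    return state["trial"] >= max_num_trial

In this sketch, state would start as {"best": float("inf"), "patience": 0, "trial": 0}.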
train = SimpleCorrectionDataset(token_data.query('dataset == "train"'), max_len=10)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

val = SimpleCorrectionDataset(token_data.query('dataset == "val"'), max_len=10)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 5
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
)
model.to(device)

optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model.rar"

train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
)
Epoch 1, iter 5, avg. train loss 25.21373109817505, avg. val loss 25.264954460991753
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 1, iter 10, avg. train loss 27.308312225341798, avg. val loss 25.19587156507704
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 15, avg. train loss 25.64889602661133, avg. val loss 25.134972466362846
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 20, avg. train loss 26.240159034729004, avg. val loss 25.078634050157333
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Epoch 2, iter 25, avg. train loss 22.31423110961914, avg. val loss 25.014130486382378
Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model.rar
Inference / prediction
https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding
GreedySearchDecoder
GreedySearchDecoder (model)
*Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to
, etc.
.. note:: As per the example above, an __init__()
call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*
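The tutorial linked above describes greedy decoding: at each step take the argmax over the output vocabulary and feed it back in as the next input. A minimal sketch under assumed names (decoder_step and its signature are hypothetical, not the GreedySearchDecoder internals):

import torch

def greedy_decode(decoder_step, hidden, sos_idx, max_len, device):
    # decoder_step(prev_tokens, hidden) -> (logits over the target vocab, new hidden)
    tokens = torch.full((1,), sos_idx, dtype=torch.long, device=device)
    outputs = []
    for _ in range(max_len):
        logits, hidden = decoder_step(tokens, hidden)
        tokens = logits.argmax(dim=-1)  # greedy: keep only the most likely next character
        outputs.append(tokens)
    return torch.stack(outputs)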
decoder = GreedySearchDecoder(model)

max_len = 10

test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))
with torch.no_grad():
    for i, (src, tgt) in enumerate(test_dataloader):
        predicted_indices = decoder(src, tgt)
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())
torch.Size([1, 5, 5])
tensor([[27, 27, 27, 17, 7, 17, 7, 7, 17, 27, 0],
[18, 27, 27, 27, 27, 27, 17, 17, 27, 17, 0],
[18, 3, 18, 27, 17, 26, 27, 27, 27, 27, 0],
[18, 26, 27, 18, 27, 27, 27, 27, 27, 27, 0],
[ 6, 27, 27, 27, 27, 17, 17, 7, 17, 7, 0]])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
torch.Size([1, 5, 5])
torch.Size([5, 11])
predict_and_convert_to_str
predict_and_convert_to_str (model, dataloader, tgt_vocab, device)
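Presumably this runs the decoder over each batch and maps the predicted indices back to characters with the target vocab. A hedged sketch of that conversion step (the lookup_tokens method follows torchtext's Vocab API, and the special-token names are assumptions):

def indices_to_string(indices, tgt_vocab, specials=("<unk>", "<pad>", "<bos>", "<eos>")):
    # Map predicted indices back to characters and join them, dropping special tokens
    chars = tgt_vocab.lookup_tokens(list(indices))
    return "".join(c for c in chars if c not in specials)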
output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
output_strings[0:3]
100%|██████████| 7/7 [00:00<00:00, 352.93it/s]
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
['mmmmmmmmmm', 'Fmmmmmmmmm', 'Fmmmmmmmmm']
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)
test_data["pred"] = output_strings