Error Correction

Functionality for error correction (task 2)

Dataset Creation

A dataset for token correction consists of the OCR text and the gold standard (GS) text of AlignedTokens. These can be extracted from the Text objects with the get_tokens_with_OCR_mistakes function. This function also adds data properties, such as token lengths and the language, subset, and dataset split of each token, that can be used to calculate statistics about the data.


get_tokens_with_OCR_mistakes

 get_tokens_with_OCR_mistakes (data:Dict[str,ocrpostcorrection.icdar_data.Text],
                               data_test:Dict[str,ocrpostcorrection.icdar_data.Text],
                               val_files:List[str])

Return a pandas DataFrame with all OCR mistakes from the train, val, and test sets

The following code example shows how to use this function. For simplicity, the example below uses the data dictionary (which contains <file name>: Text pairs) both as the train/val set and as the test set.

import os
from pathlib import Path

data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"
data, md = generate_data(data_dir)
val_files = ["en/eng_sample/2.txt"]

token_data = get_tokens_with_OCR_mistakes(data, data, val_files)
print(token_data.shape)
token_data.head()
2it [00:00, 1508.20it/s]
(80, 12)
   ocr       gs         ocr_aligned  gs_aligned  start  len_ocr  key                  language  subset      dataset  len_gs  diff
0  In                   In           ##          0      2        en/eng_sample/1.txt  en        eng_sample  test     0       2
1  troe      tree       troe         tree        13     4        en/eng_sample/1.txt  en        eng_sample  test     4       0
2  peremial  perennial  perem@ial    perennial   23     8        en/eng_sample/1.txt  en        eng_sample  test     9       -1
3  eLngated  elongated  eL@ngated    elongated   46     8        en/eng_sample/1.txt  en        eng_sample  test     9       -1
4  stein,    stem,      stein,       stem@,      55     6        en/eng_sample/1.txt  en        eng_sample  test     5       1
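The data properties can be used directly with pandas. The minimal sketch below uses the five rows printed above, typed in by hand. In these rows, diff equals len_ocr - len_gs, and removing the alignment character "@" and the "#" characters from the aligned strings gives back the plain ocr and gs strings. Both relationships are read off the example output, not taken from the library's code, and the helper strip_alignment is hypothetical.

```python
import pandas as pd

# The five rows printed by token_data.head() above, typed in by hand.
rows = pd.DataFrame(
    {
        "ocr": ["In", "troe", "peremial", "eLngated", "stein,"],
        "gs": ["", "tree", "perennial", "elongated", "stem,"],
        "ocr_aligned": ["In", "troe", "perem@ial", "eL@ngated", "stein,"],
        "gs_aligned": ["##", "tree", "perennial", "elongated", "stem@,"],
        "len_ocr": [2, 4, 8, 8, 6],
        "len_gs": [0, 4, 9, 9, 5],
        "diff": [2, 0, -1, -1, 1],
    }
)


def strip_alignment(s: str) -> str:
    """Hypothetical helper: drop '@' and '#' (an assumption based on the rows above)."""
    return s.replace("@", "").replace("#", "")


# diff is the length of the OCR token minus the length of the gold standard token.
assert (rows["diff"] == rows["len_ocr"] - rows["len_gs"]).all()
assert (rows["len_ocr"] == rows["ocr"].str.len()).all()

# Removing the alignment characters recovers the unaligned strings.
assert (rows["ocr_aligned"].map(strip_alignment) == rows["ocr"]).all()
assert (rows["gs_aligned"].map(strip_alignment) == rows["gs"]).all()

# Example statistic: does the OCR mistake make the token longer, shorter, or keep its length?
change = rows["diff"].map(lambda d: "longer" if d > 0 else ("shorter" if d < 0 else "same length"))
print(change.value_counts())
```

On the full token_data DataFrame, the same kind of expression can be grouped by language, subset, or dataset.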

Get the context of an OCR mistake, i.e., the text that surrounds it (about offset characters on each side).


get_OCR_mistakes_in_context

 get_OCR_mistakes_in_context (data:Dict[str,ocrpostcorrection.icdar_data.Text],
                              data_test:Dict[str,ocrpostcorrection.icdar_data.Text],
                              ocr_mistakes:pandas.core.frame.DataFrame,
                              offset:int)

get_context_for_dataset

 get_context_for_dataset (data:Dict[str,ocrpostcorrection.icdar_data.Text],
                          ocr_mistakes:pandas.core.frame.DataFrame,
                          offset:int)

get_closest_value

 get_closest_value (lst:List, value:int)

token_data2 = get_OCR_mistakes_in_context(data, data, token_data, offset=20)
print(token_data2.shape)
token_data2.head()
100%|██████████| 4/4 [00:00<00:00, 873.54it/s]
100%|██████████| 4/4 [00:00<00:00, 913.24it/s]
(80, 15)
   ocr       gs         ocr_aligned  gs_aligned  start  len_ocr  key                  language  subset      dataset  len_gs  diff  context_before          context_after           len_mistake_in_context
0  In                   In           ##          0      2        en/eng_sample/1.txt  en        eng_sample  train    0       2                             botany, a troe is a     22
1  troe      tree       troe         tree        13     4        en/eng_sample/1.txt  en        eng_sample  train    4       0     In botany, a            is a peremial plant     37
2  peremial  perennial  perem@ial    perennial   23     8        en/eng_sample/1.txt  en        eng_sample  train    9       -1    botany, a troe is a     plant with an eLngated  51
3  eLngated  elongated  eL@ngated    elongated   46     8        en/eng_sample/1.txt  en        eng_sample  train    9       -1    peremial plant with an  stein, or trunk,        48
4  stein,    stem,      stein,       stem@,      55     6        en/eng_sample/1.txt  en        eng_sample  train    5       1     plant with an eLngated  or trunk, suppor ing    50
token_data2.tail()
    ocr         gs          ocr_aligned  gs_aligned  start  len_ocr  key                 language  subset     dataset  len_gs  diff  context_before         context_after         len_mistake_in_context
35  test-FFF    test- FFF   test-@FFF    test- FFF   48     8        fr/fr_sample/2.txt  fr        fr_sample  test     9       -1    test -DDD test- EEE    test-GGG test - HHH   48
36  test-GGG    test -GGG   test@-GGG    test -GGG   57     8        fr/fr_sample/2.txt  fr        fr_sample  test     9       -1    test- EEE test-FFF     test - HHH test-III   47
37  test - HHH  test-HHH    test - HHH   test@-@HHH  66     10       fr/fr_sample/2.txt  fr        fr_sample  test     8       2     EEE test-FFF test-GGG  test-III test - JJJ   52
38  test-III    test - III  test@-@III   test - III  77     8        fr/fr_sample/2.txt  fr        fr_sample  test     10      -2    test-GGG test - HHH    test - JJJ blablabla  49
39  ?           !           ?            !           107    1        fr/fr_sample/2.txt  fr        fr_sample  test     1       0     test - JJJ blablabla                         22
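A simplified sketch of what "context" means here, assuming a plain window of offset characters on each side of the mistake with surrounding whitespace stripped. The function name char_window_context is hypothetical, and the sample text is reconstructed from the example output (its token positions match the start column above). For the "troe" row the sketch reproduces context_before and context_after. The library's output differs elsewhere: for example, the context after "peremial" ends with the complete token "eLngated" although a 20-character window would cut it off, which suggests that the library also moves the window edges to token boundaries (possibly with the help of get_closest_value; this is a guess).

```python
from typing import Tuple


def char_window_context(text: str, start: int, length: int, offset: int) -> Tuple[str, str]:
    """Hypothetical helper: up to `offset` characters before and after text[start:start + length].

    Plain character window, without snapping to token boundaries.
    """
    before = text[max(0, start - offset):start]
    after = text[start + length:start + length + offset]
    return before.strip(), after.strip()


# Beginning of the sample text, as far as it can be read from the example output above.
text = "In botany, a troe is a peremial plant with an eLngated stein, or trunk,"

before, after = char_window_context(text, start=13, length=4, offset=20)  # the "troe" row
print(repr(before), repr(after))  # 'In botany, a' 'is a peremial plant'

# Length of the window including the mistake: 37, the len_mistake_in_context of that row.
print(len(text[max(0, 13 - 20):13 + 4 + 20]))
```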

Create vocabularies

https://pytorch.org/tutorials/beginner/translation_transformer.html

Define the special symbols and their indices, and make sure the symbols are listed in the order of their indices so that they are inserted into the vocabulary correctly.


yield_tokens

 yield_tokens (data, col)

Helper function to create a vocabulary of characters


generate_vocabs

 generate_vocabs (train)

Generate the OCR and GS vocabularies from the training set

Use the training set to create the OCR and GS vocabularies:

vocab_transform = generate_vocabs(token_data.query('dataset == "train"'))
len(vocab_transform["ocr"]), len(vocab_transform["gs"])
(46, 44)
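A dependency-free sketch of a character vocabulary with special symbols, assuming the index layout of the linked PyTorch translation tutorial (<unk>=0, <pad>=1, <bos>=2, <eos>=3). This layout is consistent with the example outputs further down in this document: the unknown character "x" maps to 0, every sequence ends with 3, and 1 is used for padding. generate_vocabs appears to return torchtext vocabularies (they provide get_itos()); build_char_vocab is a hypothetical stand-in, and ordering the remaining characters by descending frequency with ties broken alphabetically is an assumption about how indices are assigned.

```python
from collections import Counter
from typing import Dict, Iterable, List

# Special symbols in the order of their indices (assumed, following the PyTorch tutorial).
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
SPECIAL_SYMBOLS = ["<unk>", "<pad>", "<bos>", "<eos>"]


def yield_chars(tokens: Iterable[str]) -> Iterable[List[str]]:
    """Yield each token as a list of characters (like yield_tokens, but over plain strings)."""
    for token in tokens:
        yield list(token)


def build_char_vocab(tokens: Iterable[str]) -> Dict[str, int]:
    """Hypothetical helper: map special symbols and characters to indices."""
    counts = Counter(ch for chars in yield_chars(tokens) for ch in chars)
    # Special symbols come first, so their indices match the constants above.
    itos = list(SPECIAL_SYMBOLS)
    # Remaining characters: most frequent first, ties broken alphabetically.
    itos += [ch for ch, _ in sorted(counts.items(), key=lambda item: (-item[1], item[0]))]
    return {symbol: idx for idx, symbol in enumerate(itos)}


vocab = build_char_vocab(["test", "tree"])
print(vocab)
# {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'e': 4, 't': 5, 'r': 6, 's': 7}
print(vocab.get("x", UNK_IDX))  # characters missing from the vocabulary fall back to <unk>
```

Putting the special symbols first is what makes their indices independent of the training data; only the character indices depend on the character frequencies.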

Collation

The character sequences need to be transformed into vectors.

Source: https://pytorch.org/tutorials/beginner/translation_transformer.html


get_text_transform

 get_text_transform (vocab_transform)

Returns text transforms to convert raw strings into tensors of indices


tensor_transform

 tensor_transform (token_ids:List[int])

Function to add BOS/EOS and create tensor for input sequence indices


sequential_transforms

 sequential_transforms (*transforms)

Helper function to chain together sequential operations
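How the three helpers fit together can be sketched without torch, as plain function composition. sequential_transforms below mirrors the helper of the same name in the linked PyTorch tutorial; vocab_lookup, append_eos, and the toy vocabulary (built from the "ocr" indices in the example that follows) are hypothetical, and the real transforms return a torch tensor rather than a list. In the example outputs, each sequence ends with index 3 and no start index is prepended, so the sketch only appends EOS, assuming 3 is the EOS index as in the tutorial.

```python
from typing import Callable, Dict, List

UNK_IDX, EOS_IDX = 0, 3  # assumed index layout


def sequential_transforms(*transforms: Callable) -> Callable:
    """Chain transforms: the output of each one is the input of the next."""
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func


def vocab_lookup(vocab: Dict[str, int]) -> Callable[[List[str]], List[int]]:
    """Map characters to indices; unknown characters map to UNK_IDX."""
    return lambda chars: [vocab.get(ch, UNK_IDX) for ch in chars]


def append_eos(token_ids: List[int]) -> List[int]:
    """Hypothetical stand-in for tensor_transform: append EOS and return a list."""
    return token_ids + [EOS_IDX]


# Toy vocabulary with the 'ocr' indices that appear in the example.
toy_vocab = {"t": 4, "e": 5, "s": 6, "-": 7, " ": 10, "A": 13}
# `list` splits a string into characters (and copies an already split list).
ocr_transform = sequential_transforms(list, vocab_lookup(toy_vocab), append_eos)

print(ocr_transform("test- AAA"))  # [4, 5, 6, 4, 7, 10, 13, 13, 13, 3]
print(ocr_transform("tex"))        # [4, 5, 0, 3]: 'x' is not in the toy vocabulary
```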

text_transform = get_text_transform(vocab_transform)

text_transform["ocr"](["t", "e", "s", "t", "-", " ", "A", "A", "A"])
tensor([ 4,  5,  6,  4,  7, 10, 13, 13, 13,  3])

print(text_transform["ocr"](["e", "x", "a", "m", "p", "l", "e"]))
print(text_transform["gs"](["e", "x", "a", "m", "p", "l", "e"]))
tensor([ 5,  0, 21, 34, 22, 33,  5,  3])
tensor([ 5,  0, 21, 27, 23, 26,  5,  3])
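To form a batch, sequences of different lengths also have to be padded to the same length. The sketch below uses torch.nn.utils.rnn.pad_sequence, as the collate function in the linked tutorial does, and assumes padding index 1. With the default batch_first=False the result has shape (max_seq_len, batch), which matches the layout of the input tensor in the model example further down, where each column is one sequence. How this document's collate_fn pads is not shown here, so treat this as an illustration only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 1  # assumed padding index

# Two transformed sequences of different lengths (each ends with EOS = 3).
seqs = [torch.tensor([6, 22, 0, 18, 11, 3]), torch.tensor([4, 30, 6, 4, 3])]

# batch_first=False (the default) gives shape (max_seq_len, batch).
batch = pad_sequence(seqs, padding_value=PAD_IDX)
print(batch.shape)     # torch.Size([6, 2])
print(batch.tolist())  # [[6, 4], [22, 30], [0, 6], [18, 4], [11, 3], [3, 1]]
```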

Neural network

Network: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html


EncoderRNN

 EncoderRNN (input_size, hidden_size)

(This class has no docstring of its own; the description below is inherited from torch.nn.Module.)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.
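A minimal sketch of an encoder with the layers shown in the printed model further down (an Embedding followed by a GRU with batch_first=True), modeled on the encoder of the linked PyTorch tutorial. The class name EncoderSketch and its forward signature are assumptions, not the library's code. The example input holds the two sequences of the model example below, which appear to be stored as columns there; here they are transposed to (batch, seq_len).

```python
import torch
import torch.nn as nn


class EncoderSketch(nn.Module):
    """Character encoder: embedding followed by a single-layer GRU (batch first)."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor):
        # input: (batch, seq_len) character indices; hidden: (1, batch, hidden_size)
        embedded = self.embedding(input)  # (batch, seq_len, hidden_size)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden  # (batch, seq_len, hidden_size), (1, batch, hidden_size)

    def initHidden(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(1, batch_size, self.hidden_size, device=device)


encoder = EncoderSketch(input_size=46, hidden_size=5)  # sizes of the printed example model
x = torch.tensor([[6, 22, 0, 18, 11, 3], [4, 30, 6, 4, 3, 1]])  # 2 sequences of length 6
output, hidden = encoder(x, encoder.initHidden(batch_size=2, device=torch.device("cpu")))
print(output.shape, hidden.shape)  # torch.Size([2, 6, 5]) torch.Size([1, 2, 5])
```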


AttnDecoderRNN

 AttnDecoderRNN (hidden_size, output_size, dropout_p=0.1, max_length=11)

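A batched sketch of one decoding step of an attention decoder with the same layers as the printed AttnDecoderRNN (attn: Linear(2 * hidden_size, max_length), attn_combine: Linear(2 * hidden_size, hidden_size), a GRU that is not batch first, and out: Linear(hidden_size, output_size)), following the attention mechanism of the linked PyTorch tutorial. The tutorial decodes one example at a time; the batching, the class name AttnDecoderSketch, and the assumption that encoder outputs are padded to max_length belong to this sketch, not necessarily to the library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnDecoderSketch(nn.Module):
    def __init__(self, hidden_size: int, output_size: int, dropout_p: float = 0.1, max_length: int = 11):
        super().__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(hidden_size * 2, max_length)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        # input: (batch,) previous characters; hidden: (1, batch, hidden_size)
        # encoder_outputs: (batch, max_length, hidden_size), padded to max_length
        embedded = self.dropout(self.embedding(input))  # (batch, hidden_size)
        # Attention weights over the max_length encoder positions.
        attn_weights = F.softmax(self.attn(torch.cat((embedded, hidden[0]), dim=1)), dim=1)
        # Weighted sum of encoder outputs: (batch, 1, max_length) @ (batch, max_length, hidden_size)
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        output = F.relu(self.attn_combine(torch.cat((embedded, attn_applied), dim=1)))
        output, hidden = self.gru(output.unsqueeze(0), hidden)  # sequence length 1 comes first
        log_probs = F.log_softmax(self.out(output[0]), dim=1)  # (batch, output_size)
        return log_probs, hidden, attn_weights


decoder = AttnDecoderSketch(hidden_size=5, output_size=44)
log_probs, hidden, attn_weights = decoder(
    torch.tensor([2, 2]),   # start symbols for a batch of two (index 2 is <bos> under the assumed layout)
    torch.zeros(1, 2, 5),   # decoder hidden state
    torch.zeros(2, 11, 5),  # encoder outputs
)
print(log_probs.shape, hidden.shape, attn_weights.shape)
# torch.Size([2, 44]) torch.Size([1, 2, 5]) torch.Size([2, 11])
```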


SimpleCorrectionSeq2seq

 SimpleCorrectionSeq2seq (input_size, hidden_size, output_size, dropout,
                          max_length, teacher_forcing_ratio, device='cpu')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 5
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)

input = torch.tensor([[6, 4], [22, 30], [0, 6], [18, 4], [11, 3], [3, 1]])
encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

target = torch.tensor([[6, 4], [23, 5], [16, 6], [16, 4], [11, 4], [3, 1]])

losses, _ = model(input, encoder_hidden, target)
losses
tensor([-23.0017, -19.0353], grad_fn=<SumBackward1>)
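The model returns one value per sequence (two in this batch). In the input and target tensors, each column appears to be one sequence: it ends with index 3 (EOS under the tutorial's layout), followed by padding index 1 for the shorter one. The values are negative, as a sum of log-probabilities would be. A minimal sketch of such a per-sequence score, assuming decoder log-probabilities of shape (seq_len, batch, vocab_size), targets of shape (seq_len, batch), and padding index 1; sequence_log_likelihood is hypothetical and not the library's loss computation.

```python
import math

import torch


def sequence_log_likelihood(log_probs: torch.Tensor, target: torch.Tensor, pad_idx: int = 1) -> torch.Tensor:
    """Sum the log-probabilities of the target characters per sequence, ignoring padding.

    log_probs: (seq_len, batch, vocab_size); target: (seq_len, batch). Returns (batch,).
    """
    # Log-probability assigned to each target character.
    picked = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)  # (seq_len, batch)
    mask = (target != pad_idx).float()
    return (picked * mask).sum(dim=0)


target = torch.tensor([[6, 4], [23, 5], [16, 6], [16, 4], [11, 4], [3, 1]])  # from the example above
vocab_size = 44
# A model that predicts every character with equal probability.
uniform = torch.full((target.shape[0], target.shape[1], vocab_size), -math.log(vocab_size))
scores = sequence_log_likelihood(uniform, target)
print(scores)  # about -6 * ln(44) and -5 * ln(44): 6 and 5 non-padding positions
```

Negating such a score gives the usual negative log-likelihood, which is the quantity an optimizer would minimize.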

Evaluation

model_save_path = Path(os.getcwd()) / "data" / "model.rar"

checkpoint = torch.load(model_save_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.eval()
SimpleCorrectionSeq2seq(
  (encoder): EncoderRNN(
    (embedding): Embedding(46, 5)
    (gru): GRU(5, 5, batch_first=True)
  )
  (decoder): AttnDecoderRNN(
    (embedding): Embedding(44, 5)
    (attn): Linear(in_features=10, out_features=11, bias=True)
    (attn_combine): Linear(in_features=10, out_features=5, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (gru): GRU(5, 5)
    (out): Linear(in_features=5, out_features=44, bias=True)
  )
)

indices2string

 indices2string (indices, itos)
indices = torch.tensor(
    [
        [20, 34, 22, 6, 1, 1, 1, 1, 1, 1],
        [22, 6, 1, 1, 1, 1, 1, 1, 1, 1],
        [21, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [4, 5, 6, 4, 1, 1, 1, 1, 1, 1],
        [29, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ]
)
indices2string(indices, vocab_transform["gs"].get_itos())
['This', 'is', 'a', 'test', '!']
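A minimal sketch of converting rows of indices back to strings, assuming that everything from the first padding (1) or EOS (3) index onward is dropped. indices_to_strings and this stop rule are assumptions, and the itos list below is a toy vocabulary rather than vocab_transform["gs"].get_itos().

```python
from typing import List, Sequence


def indices_to_strings(indices: Sequence[Sequence[int]], itos: List[str], stop_ids=(1, 3)) -> List[str]:
    """Convert each row of character indices to a string, stopping at padding or EOS."""
    strings = []
    for row in indices:
        chars = []
        for idx in row:
            idx = int(idx)  # also accepts elements of a torch tensor
            if idx in stop_ids:
                break
            chars.append(itos[idx])
        strings.append("".join(chars))
    return strings


toy_itos = ["<unk>", "<pad>", "<bos>", "<eos>", "t", "e", "s"]
print(indices_to_strings([[4, 5, 6, 4, 1, 1], [6, 5, 3, 1, 1, 1]], toy_itos))  # ['test', 'se']
```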
decoder = GreedySearchDecoder(model)

max_len = 10

test = SimpleCorrectionDataset(token_data.query('dataset == "test"'), max_len=max_len)
test_dataloader = DataLoader(test, batch_size=5, collate_fn=collate_fn(text_transform))


output_strings = predict_and_convert_to_str(
    model, test_dataloader, vocab_transform["gs"], device
)
100%|██████████| 7/7 [00:00<00:00, 163.36it/s]
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
torch.Size([1, 5, 5])
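Greedy decoding picks the most probable next character at each step and feeds it back as the next decoder input, until EOS or a maximum length is reached. GreedySearchDecoder and predict_and_convert_to_str are not documented in this section, so the following is only a framework-free sketch of the idea: greedy_decode and toy_step are hypothetical, and the start and end indices follow the assumed tutorial layout.

```python
from typing import Callable, List, Sequence, Tuple

BOS_IDX, EOS_IDX = 2, 3  # assumed index layout


def greedy_decode(step: Callable[[int, object], Tuple[Sequence[float], object]], state: object, max_len: int) -> List[int]:
    """Repeatedly take the highest-scoring next index until EOS or max_len."""
    output: List[int] = []
    prev = BOS_IDX
    for _ in range(max_len):
        scores, state = step(prev, state)
        prev = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        if prev == EOS_IDX:
            break
        output.append(prev)
    return output


def toy_step(prev_idx: int, t: int) -> Tuple[List[float], int]:
    """Toy step function: follows the fixed path 4 -> 5 -> 6 -> EOS, ignoring its input."""
    target = [4, 5, 6, EOS_IDX][t]
    scores = [0.0] * 7
    scores[target] = 1.0
    return scores, t + 1


print(greedy_decode(toy_step, 0, max_len=10))  # [4, 5, 6]
```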
max_len = 10
test_data = (
    token_data.query('dataset == "test"')
    .query(f"len_ocr <= {max_len}")
    .query(f"len_gs <= {max_len}")
    .copy()
)

test_data["pred"] = output_strings

Performance measure: mean normalized edit distance

  • Mean (normalized) edit distance.
    • Option: ignore hyphens (-)
    • Option: ignore case
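A dependency-free sketch of the measure. edlib.align(query, target)["editDistance"], used below, returns the Levenshtein distance in edlib's default global mode; levenshtein re-implements it with dynamic programming. normalized_ed is not shown in this section, so dividing by the length of the longer string is an assumption (it keeps values between 0 and 1, consistent with the ranges reported below). The two options from the list are implemented as hypothetical flags.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of insertions, deletions, and substitutions to turn a into b."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution (free if equal)
        prev = curr
    return prev[-1]


def normalized_edit_distance(ocr: str, gs: str, ignore_hyphens: bool = False, ignore_case: bool = False) -> float:
    """Edit distance divided by the length of the longer string (assumed normalization)."""
    if ignore_hyphens:
        ocr, gs = ocr.replace("-", ""), gs.replace("-", "")
    if ignore_case:
        ocr, gs = ocr.lower(), gs.lower()
    longest = max(len(ocr), len(gs))
    return levenshtein(ocr, gs) / longest if longest > 0 else 0.0


pairs = [("troe", "tree"), ("peremial", "perennial"), ("stein,", "stem,"), ("In", "")]
scores = [normalized_edit_distance(o, g) for o, g in pairs]
print(scores)  # [0.25, 0.2222222222222222, 0.3333333333333333, 1.0]
print(sum(scores) / len(scores))  # mean normalized edit distance of these pairs
```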
test_data["ed"] = test_data.apply(
    lambda row: edlib.align(row.ocr, row.gs)["editDistance"], axis=1
)
test_data.ed.describe()
count    35.000000
mean      1.971429
std       1.773758
min       1.000000
25%       1.000000
50%       1.000000
75%       2.000000
max       8.000000
Name: ed, dtype: float64
test_data["ed_norm"] = test_data.apply(
    lambda row: normalized_ed(row.ed, row.ocr, row.gs), axis=1
)
test_data.ed_norm.describe()
count    35.000000
mean      0.390952
std       0.326909
min       0.100000
25%       0.125000
50%       0.250000
75%       0.583333
max       1.000000
Name: ed_norm, dtype: float64
test_data["ed_pred"] = test_data.apply(
    lambda row: edlib.align(row.pred, row.gs)["editDistance"], axis=1
)
test_data.ed_pred.describe()
count    35.000000
mean      8.057143
std       3.253440
min       1.000000
25%       7.000000
50%      10.000000
75%      10.000000
max      11.000000
Name: ed_pred, dtype: float64
test_data["ed_norm_pred"] = test_data.apply(
    lambda row: normalized_ed(row.ed_pred, row.pred, row.gs), axis=1
)
test_data.ed_norm_pred.describe()
count    35.000000
mean      0.989351
std       0.030110
min       0.900000
25%       1.000000
50%       1.000000
75%       1.000000
max       1.000000
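Name: ed_norm_pred, dtype: float64

On this test set, the predictions of the example model have a much higher mean normalized edit distance (0.989) than the uncorrected OCR tokens (0.391), so the predicted tokens are further from the gold standard than the OCR input.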