BERT Vectors Correction Data

Functionality for error correction with the BERT vectors correction dataset
set_seed(23)

BertVectorsCorrectionDataset

 BertVectorsCorrectionDataset (data:pandas.core.frame.DataFrame,
                               split_name:str, bert_vectors_file:Optional[
                               pathlib.Path]=None, max_len:int=11,
                               hidden_size:int=768,
                               look_up_bert_vectors:bool=True)

*An abstract class representing a :class:Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__, which is expected to return the size of the dataset by many :class:~torch.utils.data.Sampler implementations and the default options of :class:~torch.utils.data.DataLoader.

.. note:: :class:~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.*
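
For orientation, a map-style dataset only needs __getitem__ (and, for the default samplers, __len__). A minimal toy example, not part of this library:

from torch.utils.data import Dataset

class ToyMapStyleDataset(Dataset):
    """Hypothetical toy dataset wrapping a list of (ocr, gs) string pairs."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        # Dataset size; used by samplers and the default DataLoader options.
        return len(self.pairs)

    def __getitem__(self, idx):
        # Fetch the sample for a given integer key.
        return self.pairs[idx]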

The sample BERT vectors were generated with `python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5` (from ocrpostcorrection-notebooks, model from commit 9099e78).
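
The resulting HDF5 file can be inspected with h5py; its exact layout is defined by create-bert-vectors.py and is not documented here, so this snippet simply prints whatever groups and dataset shapes the file contains.

import h5py

# List every group/dataset in the file together with its shape (if it has one).
with h5py.File("data/correction/bert-vectors.hdf5", "r") as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))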

data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
split_name = "test"

dataset = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=True,
)
split_name = "test"
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"

dataset_no_look_up = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=None,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=False
)

collate_fn

 collate_fn (text_transform)

Function to collate data samples into batch tensors


collate_fn_with_text_transform

 collate_fn_with_text_transform (text_transform, batch)

Function to collate data samples into batch tensors; to be used as a partial with an instantiated text_transform
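
In other words, collate_fn binds an instantiated text_transform and hands the DataLoader a one-argument collate function. A minimal sketch of that pattern, with a hypothetical make_collate_fn standing in for the library's implementation:

from functools import partial

def make_collate_fn(text_transform):
    # Bind the instantiated text_transform; the DataLoader then calls the
    # resulting function with a batch of samples only.
    return partial(collate_fn_with_text_transform, text_transform)

The cells below use collate_fn(text_transform) in exactly this way when constructing DataLoaders.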

Training


validate_model

 validate_model (model, dataloader, device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 2
hidden_size = 768
dropout = 0.1
max_token_len = 10

model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    dropout,
    max_token_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
model.to(device)
SimpleCorrectionSeq2seq(
  (encoder): EncoderRNN(
    (embedding): Embedding(46, 768)
    (gru): GRU(768, 768, batch_first=True)
  )
  (decoder): AttnDecoderRNN(
    (embedding): Embedding(44, 768)
    (attn): Linear(in_features=1536, out_features=11, bias=True)
    (attn_combine): Linear(in_features=1536, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (gru): GRU(768, 768)
    (out): Linear(in_features=768, out_features=44, bias=True)
  )
)
split_name = "val"
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"

val = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
    max_len=11,
    hidden_size=768,
    look_up_bert_vectors=True
)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))

loss = validate_model(model, val_dataloader, device)
loss
24.875640021430122
data
|     | ocr         | gs        | ocr_aligned | gs_aligned  | start | len_ocr | language | subset     | dataset | len_gs | diff |
|-----|-------------|-----------|-------------|-------------|-------|---------|----------|------------|---------|--------|------|
| 0   | test- AAA   | test-.AAA | test- AAA   | test-.AAA   | 0     | 9       | fr       | fr_sample  | train   | 9      | 0    |
| 1   | test-BBB    | test- BBB | test@-BBB   | test- BBB   | 10    | 8       | fr       | fr_sample  | train   | 9      | -1   |
| 2   | test-CCC    | test- CCC | test-@CCC   | test- CCC   | 19    | 8       | fr       | fr_sample  | train   | 9      | -1   |
| 3   | -DDD        | DDD       | -DDD        | DDD         | 33    | 4       | fr       | fr_sample  | train   | 3      | 1    |
| 4   | test- EEE   | test-EEE  | test- EEE   | test-@EEE   | 38    | 9       | fr       | fr_sample  | train   | 8      | 1    |
| ... | ...         | ...       | ...         | ...         | ...   | ...     | ...      | ...        | ...     | ...    | ...  |
| 75  | species!    | species.  | species!    | species.    | 111   | 8       | en       | eng_sample | test    | 8      | 0    |
| 76  | Test -hyhen | Testhyhen | Test -hyhen | Test@@hyhen | 120   | 11      | en       | eng_sample | test    | 9      | 2    |
| 77  | error       | errors    | error@      | errors      | 137   | 5       | en       | eng_sample | test    | 6      | -1   |
| 78  | C           | CCC       | C@@         | CCC         | 151   | 1       | en       | eng_sample | test    | 3      | -2   |
| 79  | 34          | 3 4       | 3@4         | 3 4         | 153   | 2       | en       | eng_sample | test    | 3      | -1   |

80 rows × 11 columns


train_model

 train_model (train_dl:torch.utils.data.dataloader.DataLoader[int],
              val_dl:torch.utils.data.dataloader.DataLoader[int],
              model:ocrpostcorrection.error_correction.SimpleCorrectionSeq2seq,
              optimizer:torch.optim.optimizer.Optimizer, num_epochs:int=5,
              valid_niter:int=5000,
              model_save_path:pathlib.Path=Path('model.rar'),
              max_num_patience:int=5, max_num_trial:int=5,
              lr_decay:float=0.5, device:torch.device=device(type='cpu'))
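
The max_num_patience, max_num_trial, and lr_decay arguments suggest a patience-based learning-rate decay with a limited number of restart trials. The sketch below illustrates that scheme under those assumptions; it is not the actual train_model implementation.

import torch

def handle_validation_result(model, optimizer, val_loss, state, model_save_path,
                             max_num_patience=5, max_num_trial=5, lr_decay=0.5):
    """Hypothetical patience/trial bookkeeping after one validation pass.

    `state` holds 'best_val_loss', 'patience', and 'num_trial'; returns True
    when training should stop early.
    """
    if val_loss < state["best_val_loss"]:
        # Improvement: remember it, reset patience, save a checkpoint.
        state["best_val_loss"] = val_loss
        state["patience"] = 0
        torch.save(
            {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
            model_save_path,
        )
        return False
    state["patience"] += 1
    if state["patience"] < max_num_patience:
        return False
    state["num_trial"] += 1
    if state["num_trial"] >= max_num_trial:
        return True  # out of trials: stop early
    # Decay the learning rate and restart from the best checkpoint so far.
    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    for group in optimizer.param_groups:
        group["lr"] *= lr_decay
    state["patience"] = 0
    return False
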
split_name = "train"
train = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
train_dataloader = DataLoader(
    train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)

split_name = "val"
val = BertVectorsCorrectionDataset(
    data=data.query(f"dataset == '{split_name}'"),
    bert_vectors_file=bert_vectors_file,
    split_name=split_name,
)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))

hidden_size = 768
model = SimpleCorrectionSeq2seq(
    len(vocab_transform["ocr"]),
    hidden_size,
    len(vocab_transform["gs"]),
    0.1,
    10,
    teacher_forcing_ratio=0.0,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

msp = Path(os.getcwd()) / "data" / "model_bert_vectors.rar"

train_log = train_model(
    train_dl=train_dataloader,
    val_dl=val_dataloader,
    model=model,
    optimizer=optimizer,
    model_save_path=msp,
    num_epochs=2,
    valid_niter=5,
    max_num_patience=5,
    max_num_trial=5,
    lr_decay=0.5,
)
os.remove(msp)
train_log
 31%|███       | 4/13 [00:00<00:01,  8.61it/s]2023-09-03 19:00:06.994 | INFO     | __main__:train_model:58 - Epoch 1, iter 5, avg. train loss 27.373350143432617, avg. val loss 24.627723693847656
2023-09-03 19:00:06.995 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
 69%|██████▉   | 9/13 [00:01<00:00,  8.23it/s]2023-09-03 19:00:07.644 | INFO     | __main__:train_model:58 - Epoch 1, iter 10, avg. train loss 24.103273010253908, avg. val loss 24.043284098307293
2023-09-03 19:00:07.644 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
100%|██████████| 13/13 [00:01<00:00,  7.17it/s]
  8%|▊         | 1/13 [00:00<00:01,  8.18it/s]2023-09-03 19:00:08.502 | INFO     | __main__:train_model:58 - Epoch 2, iter 15, avg. train loss 19.344550323486327, avg. val loss 19.952612982855904
2023-09-03 19:00:08.502 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
 38%|███▊      | 5/13 [00:00<00:01,  6.30it/s]2023-09-03 19:00:09.292 | INFO     | __main__:train_model:58 - Epoch 2, iter 20, avg. train loss 22.214738655090333, avg. val loss 18.82603581746419
2023-09-03 19:00:09.293 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
 77%|███████▋  | 10/13 [00:01<00:00,  6.35it/s]2023-09-03 19:00:10.207 | INFO     | __main__:train_model:58 - Epoch 2, iter 25, avg. train loss 21.808086776733397, avg. val loss 18.29765616522895
2023-09-03 19:00:10.208 | INFO     | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
100%|██████████| 13/13 [00:02<00:00,  5.57it/s]
|   | train_loss | val_loss  |
|---|------------|-----------|
| 0 | 27.373350  | 24.627724 |
| 1 | 24.103273  | 24.043284 |
| 2 | 19.344550  | 19.952613 |
| 3 | 22.214739  | 18.826036 |
| 4 | 21.808087  | 18.297656 |

Inference / prediction

https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding


GreedySearchDecoder

 GreedySearchDecoder (model)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*
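
As in the linked tutorial, greedy search takes the argmax of the decoder's output distribution at every step and feeds that token back in as the next input. A minimal, hypothetical sketch of that loop (the actual GreedySearchDecoder wraps the trained seq2seq model and its encoder hidden states):

import torch

def greedy_decode(step_fn, first_input, hidden, max_len):
    """Sketch of greedy decoding; step_fn maps (input, hidden) -> (logits, hidden)."""
    decoder_input = first_input
    predicted = []
    for _ in range(max_len):
        logits, hidden = step_fn(decoder_input, hidden)
        # Greedy choice: the most probable token at this step.
        next_token = logits.argmax(dim=-1)
        predicted.append(next_token)
        # Feed the prediction back in as the next decoder input.
        decoder_input = next_token
    return torch.stack(predicted, dim=-1)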

decoder = GreedySearchDecoder(model)

test_dataloader = DataLoader(dataset, batch_size=5, collate_fn=collate_fn(text_transform))

with torch.no_grad():
    for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
        predicted_indices = decoder(src, encoder_hidden, tgt)
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())
tensor([[ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  4,  3,  3,  3,  3,  0],
        [ 4,  4,  5,  6,  4,  3,  3,  3,  3,  3,  0],
        [14, 13,  3,  3,  3,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
decoder = GreedySearchDecoder(model)

test_dataloader = DataLoader(dataset_no_look_up, batch_size=5, collate_fn=collate_fn(text_transform))

with torch.no_grad():
    for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
        predicted_indices = decoder(src, encoder_hidden, tgt)
        if i == 0:
            print(predicted_indices)
        else:
            print(predicted_indices.size())
tensor([[ 4,  5,  6,  4,  8,  6,  4,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  4,  3,  3,  3,  3,  0],
        [ 4,  4,  5,  6,  4,  3,  4,  3,  3,  3,  0],
        [14, 13,  3,  3,  3,  3,  3,  3,  3,  3,  0],
        [ 4,  5,  6,  4,  8,  3,  3,  3,  3,  3,  0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])

predict_and_convert_to_str

 predict_and_convert_to_str (model, dataloader, bert_model,
                             dataloader_bert_vectors, tgt_vocab, device)

model_name = "bert-base-multilingual-cased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
bert_model.eval()

dataset_bert_vectors = HFDataset.from_pandas(test_dataloader.dataset.ds.ocr.to_frame())
tokenized_dataset = dataset_bert_vectors.map(
    lambda sample: tokenizer(sample["ocr"], truncation=True),
    batched=True,
)
tokenized_dataset = tokenized_dataset.remove_columns(
    ["ocr"]
)

collator = DataCollatorWithPadding(tokenizer)
test_dataloader_bert_vectors = DataLoader(
    tokenized_dataset, batch_size=5, collate_fn=collator
)

predictions = predict_and_convert_to_str(
    model=model,
    dataloader=test_dataloader,
    bert_model=bert_model,
    dataloader_bert_vectors=test_dataloader_bert_vectors,
    tgt_vocab=vocab_transform["gs"],
    device=device,
)
predictions[:3]
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
0it [00:00, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
7it [00:00, 11.23it/s]
['test-', 'test-t', 'ttest']