set_seed(23)
BERT Vectors Correction Data
BertVectorsCorrectionDataset
BertVectorsCorrectionDataset (data:pandas.core.frame.DataFrame, split_name:str, bert_vectors_file:Optional[pathlib.Path]=None, max_len:int=11, hidden_size:int=768, look_up_bert_vectors:bool=True)
*An abstract class representing a :class:Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:__getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:__len__, which is expected to return the size of the dataset by many :class:~torch.utils.data.Sampler implementations and the default options of :class:~torch.utils.data.DataLoader.
.. note:: :class:~torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.*
The sample BERT vectors have been generated using python src/stages/create-bert-vectors.py --seed 1234 --dataset-in ../ocrpostcorrection/nbs/data/correction/dataset.csv --model-dir models/error-detection/ --model-name bert-base-multilingual-cased --batch-size 1 --out-file ../ocrpostcorrection/nbs/data/correction/bert-vectors.hdf5 (from ocrpostcorrection-notebooks, model from commit 9099e78).
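When look_up_bert_vectors=True, the dataset reads these precomputed vectors from the HDF5 file at bert_vectors_file. Below is a rough sketch of what such a lookup could look like; the internal layout of the HDF5 file, the key scheme, and the padding to a fixed (max_len, hidden_size) shape are assumptions, not the library's implementation.
import h5py
import numpy as np
import torch

def load_bert_vector_sketch(bert_vectors_file, key, max_len=11, hidden_size=768):
    # Hypothetical lookup: read one precomputed vector sequence from the HDF5 file.
    with h5py.File(bert_vectors_file, "r") as f:
        vec = torch.from_numpy(np.asarray(f[key]))  # assumed shape: (seq_len, hidden_size)
    # Pad or truncate to the fixed (max_len, hidden_size) shape the model expects.
    out = torch.zeros(max_len, hidden_size)
    out[: min(vec.shape[0], max_len)] = vec[:max_len]
    return out
With look_up_bert_vectors=False, as in the second example below, bert_vectors_file can be None and no lookup happens.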
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
split_name = "test"
dataset = BertVectorsCorrectionDataset(
data=data.query(f"dataset == '{split_name}'"),
bert_vectors_file=bert_vectors_file,
split_name=split_name,
max_len=11,
hidden_size=768,
look_up_bert_vectors=True,
)
split_name = "test"
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
dataset_no_look_up = BertVectorsCorrectionDataset(
data=data.query(f"dataset == '{split_name}'"),
bert_vectors_file=None,
split_name=split_name,
max_len=11,
hidden_size=768,
look_up_bert_vectors=False
)
collate_fn
collate_fn (text_transform)
Function to collate data samples into batch tensors
collate_fn_with_text_transform
collate_fn_with_text_transform (text_transform, batch)
Function to collate data samples into batch tensors; to be used as a partial with an instantiated text_transform
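The two functions work as a pair: collate_fn_with_text_transform does the collating and collate_fn binds an instantiated text_transform to it, so the result can be handed to a DataLoader as its collate_fn. A minimal sketch of that pattern, with hypothetical names and an assumed sample layout of (ocr_string, gs_string, bert_vector):
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_with_text_transform_sketch(text_transform, batch):
    # Hypothetical sketch: each sample is assumed to be (ocr_string, gs_string, bert_vector).
    src_batch, tgt_batch, vec_batch = [], [], []
    for ocr, gs, vec in batch:
        src_batch.append(text_transform["ocr"](ocr))  # assumed to return a 1D LongTensor
        tgt_batch.append(text_transform["gs"](gs))
        vec_batch.append(torch.as_tensor(vec))
    src = pad_sequence(src_batch, batch_first=True)
    tgt = pad_sequence(tgt_batch, batch_first=True)
    return src, tgt, torch.stack(vec_batch)  # vectors assumed to share a fixed shape

def collate_fn_sketch(text_transform):
    # Bind text_transform so the DataLoader only needs to pass the batch itself.
    return partial(collate_fn_with_text_transform_sketch, text_transform)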
Training
validate_model
validate_model (model, dataloader, device)
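validate_model has no docstring, but from the usage below it evaluates the model on a dataloader without gradient tracking and returns an average validation loss. A rough sketch under those assumptions; the forward interface of SimpleCorrectionSeq2seq used here is hypothetical:
import torch

def validate_model_sketch(model, dataloader, device):
    # Hypothetical sketch: average the per-batch loss over the validation set.
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for src, tgt, encoder_hidden in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            encoder_hidden = encoder_hidden.to(device)
            # Assumption: the seq2seq model returns a scalar loss for a batch.
            loss = model(src, encoder_hidden, tgt)
            total_loss += loss.item()
            num_batches += 1
    model.train()
    return total_loss / num_batches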
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2
hidden_size = 768
dropout = 0.1
max_token_len = 10
model = SimpleCorrectionSeq2seq(
len(vocab_transform["ocr"]),
hidden_size,
len(vocab_transform["gs"]),
dropout,
max_token_len,
teacher_forcing_ratio=0.5,
device=device,
)
model.to(device)
SimpleCorrectionSeq2seq(
(encoder): EncoderRNN(
(embedding): Embedding(46, 768)
(gru): GRU(768, 768, batch_first=True)
)
(decoder): AttnDecoderRNN(
(embedding): Embedding(44, 768)
(attn): Linear(in_features=1536, out_features=11, bias=True)
(attn_combine): Linear(in_features=1536, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
(gru): GRU(768, 768)
(out): Linear(in_features=768, out_features=44, bias=True)
)
)
split_name = "val"
data_csv = Path(os.getcwd()) / "data" / "correction" / "dataset.csv"
data = pd.read_csv(data_csv, index_col=0)
data.fillna("", inplace=True)
bert_vectors_file = Path(os.getcwd()) / "data" / "correction" / "bert-vectors.hdf5"
val = BertVectorsCorrectionDataset(
data=data.query(f"dataset == '{split_name}'"),
bert_vectors_file=bert_vectors_file,
split_name=split_name,
max_len=11,
hidden_size=768,
look_up_bert_vectors=True
)
val_dataloader = DataLoader(val, batch_size=5, collate_fn=collate_fn(text_transform))
loss = validate_model(model, val_dataloader, device)
loss
24.875640021430122
data
|   | ocr | gs | ocr_aligned | gs_aligned | start | len_ocr | language | subset | dataset | len_gs | diff |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | test- AAA | test-.AAA | test- AAA | test-.AAA | 0 | 9 | fr | fr_sample | train | 9 | 0 |
| 1 | test-BBB | test- BBB | test@-BBB | test- BBB | 10 | 8 | fr | fr_sample | train | 9 | -1 |
| 2 | test-CCC | test- CCC | test-@CCC | test- CCC | 19 | 8 | fr | fr_sample | train | 9 | -1 |
| 3 | -DDD | DDD | -DDD | DDD | 33 | 4 | fr | fr_sample | train | 3 | 1 |
| 4 | test- EEE | test-EEE | test- EEE | test-@EEE | 38 | 9 | fr | fr_sample | train | 8 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 75 | species! | species. | species! | species. | 111 | 8 | en | eng_sample | test | 8 | 0 |
| 76 | Test -hyhen | Testhyhen | Test -hyhen | Test@@hyhen | 120 | 11 | en | eng_sample | test | 9 | 2 |
| 77 | error | errors | error@ | errors | 137 | 5 | en | eng_sample | test | 6 | -1 |
| 78 | C | CCC | C@@ | CCC | 151 | 1 | en | eng_sample | test | 3 | -2 |
| 79 | 34 | 3 4 | 3@4 | 3 4 | 153 | 2 | en | eng_sample | test | 3 | -1 |
80 rows × 11 columns
train_model
train_model (train_dl:torch.utils.data.dataloader.DataLoader[int], val_dl:torch.utils.data.dataloader.DataLoader[int], model:ocrpostcorrection.error_correction.SimpleCorrectionSeq2seq, optimizer:torch.optim.optimizer.Optimizer, num_epochs:int=5, valid_niter:int=5000, model_save_path:pathlib.Path=Path('model.rar'), max_num_patience:int=5, max_num_trial:int=5, lr_decay:float=0.5, device:torch.device=device(type='cpu'))
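The parameters hint at the training schedule: every valid_niter iterations the model is validated, an improving checkpoint is written to model_save_path, and when validation loss stops improving for max_num_patience validations the learning rate is multiplied by lr_decay and the best checkpoint is reloaded, up to max_num_trial times. A condensed sketch of that early-stopping and decay logic; the names and details are assumptions, not the actual library code:
import torch

def training_schedule_sketch(model, optimizer, train_dl, val_dl, device,
                             num_epochs=2, valid_niter=5, model_save_path="model.rar",
                             max_num_patience=5, max_num_trial=5, lr_decay=0.5):
    # Hypothetical sketch of the schedule suggested by train_model's parameters.
    best_val_loss = float("inf")
    patience = trial = iteration = 0
    for epoch in range(num_epochs):
        for batch in train_dl:
            iteration += 1
            ...  # forward pass, loss, backward pass, optimizer step
            if iteration % valid_niter == 0:
                val_loss = validate_model(model, val_dl, device)
                if val_loss < best_val_loss:
                    best_val_loss, patience = val_loss, 0
                    torch.save({"model": model.state_dict(),
                                "optimizer": optimizer.state_dict()}, model_save_path)
                else:
                    patience += 1
                    if patience >= max_num_patience:
                        trial += 1
                        if trial >= max_num_trial:
                            return  # early stopping
                        # Decay the learning rate and reload the best checkpoint.
                        for group in optimizer.param_groups:
                            group["lr"] *= lr_decay
                        model.load_state_dict(torch.load(model_save_path)["model"])
                        patience = 0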
split_name = "train"
train = BertVectorsCorrectionDataset(
data=data.query(f"dataset == '{split_name}'"),
bert_vectors_file=bert_vectors_file,
split_name=split_name,
)
train_dataloader = DataLoader(
train, batch_size=2, collate_fn=collate_fn(text_transform), shuffle=True
)
split_name = "val"
val = BertVectorsCorrectionDataset(
data=data.query(f"dataset == '{split_name}'"),
bert_vectors_file=bert_vectors_file,
split_name=split_name,
)
val_dataloader = DataLoader(val, batch_size=3, collate_fn=collate_fn(text_transform))
hidden_size = 768
model = SimpleCorrectionSeq2seq(
len(vocab_transform["ocr"]),
hidden_size,
len(vocab_transform["gs"]),
0.1,
10,
teacher_forcing_ratio=0.0,
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
msp = Path(os.getcwd()) / "data" / "model_bert_vectors.rar"
train_log = train_model(
train_dl=train_dataloader,
val_dl=val_dataloader,
model=model,
optimizer=optimizer,
model_save_path=msp,
num_epochs=2,
valid_niter=5,
max_num_patience=5,
max_num_trial=5,
lr_decay=0.5,
)
os.remove(msp)
train_log
31%|███ | 4/13 [00:00<00:01, 8.61it/s]2023-09-03 19:00:06.994 | INFO | __main__:train_model:58 - Epoch 1, iter 5, avg. train loss 27.373350143432617, avg. val loss 24.627723693847656
2023-09-03 19:00:06.995 | INFO | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
69%|██████▉ | 9/13 [00:01<00:00, 8.23it/s]2023-09-03 19:00:07.644 | INFO | __main__:train_model:58 - Epoch 1, iter 10, avg. train loss 24.103273010253908, avg. val loss 24.043284098307293
2023-09-03 19:00:07.644 | INFO | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
100%|██████████| 13/13 [00:01<00:00, 7.17it/s]
8%|▊ | 1/13 [00:00<00:01, 8.18it/s]2023-09-03 19:00:08.502 | INFO | __main__:train_model:58 - Epoch 2, iter 15, avg. train loss 19.344550323486327, avg. val loss 19.952612982855904
2023-09-03 19:00:08.502 | INFO | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
38%|███▊ | 5/13 [00:00<00:01, 6.30it/s]2023-09-03 19:00:09.292 | INFO | __main__:train_model:58 - Epoch 2, iter 20, avg. train loss 22.214738655090333, avg. val loss 18.82603581746419
2023-09-03 19:00:09.293 | INFO | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
77%|███████▋ | 10/13 [00:01<00:00, 6.35it/s]2023-09-03 19:00:10.207 | INFO | __main__:train_model:58 - Epoch 2, iter 25, avg. train loss 21.808086776733397, avg. val loss 18.29765616522895
2023-09-03 19:00:10.208 | INFO | __main__:train_model:68 - Saving model and optimizer to /Users/janneke/code/ocrpostcorrection/nbs/data/model_bert_vectors.rar
100%|██████████| 13/13 [00:02<00:00, 5.57it/s]
|   | train_loss | val_loss |
|---|---|---|
| 0 | 27.373350 | 24.627724 |
| 1 | 24.103273 | 24.043284 |
| 2 | 19.344550 | 19.952613 |
| 3 | 22.214739 | 18.826036 |
| 4 | 21.808087 | 18.297656 |
Inference / prediction
https://pytorch.org/tutorials/beginner/chatbot_tutorial.html?highlight=greedy%20decoding
GreedySearchDecoder
GreedySearchDecoder (model)
*Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call :meth:to, etc.
.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool*
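GreedySearchDecoder follows the greedy decoding idea from the chatbot tutorial linked above: at every step the highest-scoring token is chosen and fed back as the next decoder input, up to the maximum output length. A minimal sketch of that loop; the per-step decoder interface shown here is an assumption, not the actual SimpleCorrectionSeq2seq API:
import torch

def greedy_decode_sketch(decoder_step, decoder_hidden, sos_index, max_len, batch_size, device):
    # Hypothetical per-step interface: decoder_step(tokens, hidden) -> (scores, hidden),
    # where scores has shape (batch_size, vocab_size).
    decoder_input = torch.full((batch_size,), sos_index, dtype=torch.long, device=device)
    all_tokens = []
    for _ in range(max_len):
        scores, decoder_hidden = decoder_step(decoder_input, decoder_hidden)
        # Greedy choice: take the highest-scoring token and feed it back in.
        decoder_input = scores.argmax(dim=-1)
        all_tokens.append(decoder_input)
    return torch.stack(all_tokens, dim=1)  # shape (batch_size, max_len)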
decoder = GreedySearchDecoder(model)
test_dataloader = DataLoader(dataset, batch_size=5, collate_fn=collate_fn(text_transform))
with torch.no_grad():
for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
predicted_indices = decoder(src, encoder_hidden, tgt)
if i == 0:
print(predicted_indices)
else:
print(predicted_indices.size())
tensor([[ 4, 5, 6, 4, 8, 3, 3, 3, 3, 3, 0],
[ 4, 5, 6, 4, 8, 4, 3, 3, 3, 3, 0],
[ 4, 4, 5, 6, 4, 3, 3, 3, 3, 3, 0],
[14, 13, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[ 4, 5, 6, 4, 8, 3, 3, 3, 3, 3, 0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
decoder = GreedySearchDecoder(model)
test_dataloader = DataLoader(dataset_no_look_up, batch_size=5, collate_fn=collate_fn(text_transform))
with torch.no_grad():
for i, (src, tgt, encoder_hidden) in enumerate(test_dataloader):
predicted_indices = decoder(src, encoder_hidden, tgt)
if i == 0:
print(predicted_indices)
else:
print(predicted_indices.size())
tensor([[ 4, 5, 6, 4, 8, 6, 4, 3, 3, 3, 0],
[ 4, 5, 6, 4, 8, 4, 3, 3, 3, 3, 0],
[ 4, 4, 5, 6, 4, 3, 4, 3, 3, 3, 0],
[14, 13, 3, 3, 3, 3, 3, 3, 3, 3, 0],
[ 4, 5, 6, 4, 8, 3, 3, 3, 3, 3, 0]])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
torch.Size([5, 11])
predict_and_convert_to_str
predict_and_convert_to_str (model, dataloader, bert_model, dataloader_bert_vectors, tgt_vocab, device)
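predict_and_convert_to_str ties the pieces together: bert_model computes encoder vectors for the batches from dataloader_bert_vectors, the seq2seq model decodes them greedily, and the predicted index sequences are mapped back to strings with tgt_vocab. A hedged sketch of that last conversion step, assuming tgt_vocab is a torchtext Vocab with lookup_tokens and that index 3 acts as the end-of-sequence marker (both assumptions):
EOS_IDX = 3  # assumed end-of-sequence index

def indices_to_string_sketch(predicted_indices, tgt_vocab):
    # Hypothetical conversion: cut the sequence at the (assumed) end-of-sequence
    # token and map the remaining indices back to characters.
    tokens = []
    for idx in predicted_indices.tolist():
        if idx == EOS_IDX:
            break
        tokens.append(idx)
    return "".join(tgt_vocab.lookup_tokens(tokens))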
model_name = "bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)
bert_model.eval()
dataset_bert_vectors = HFDataset.from_pandas(test_dataloader.dataset.ds.ocr.to_frame())
tokenized_dataset = dataset_bert_vectors.map(
lambda sample: tokenizer(sample["ocr"], truncation=True),
batched=True,
)
tokenized_dataset = tokenized_dataset.remove_columns(
["ocr"]
)
collator = DataCollatorWithPadding(tokenizer)
test_dataloader_bert_vectors = DataLoader(
tokenized_dataset, batch_size=5, collate_fn=collator
)
predictions = predict_and_convert_to_str(
model=model,
dataloader=test_dataloader,
bert_model=bert_model,
dataloader_bert_vectors=test_dataloader_bert_vectors,
tgt_vocab=vocab_transform["gs"],
device=device,
)
predictions[:3]
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
0it [00:00, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
7it [00:00, 11.23it/s]
['test-', 'test-t', 'ttest']