Error correction using T5

Functionality for error correction with T5.
import os
from pathlib import Path

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from ocrpostcorrection.error_correction import get_tokens_with_OCR_mistakes, get_context_for_dataset
from ocrpostcorrection.icdar_data import generate_data
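
The examples below operate on tdata, a dataframe with one row per OCR mistake. In the full pipeline it is produced from the ICDAR data by the helpers imported above (generate_data, get_tokens_with_OCR_mistakes, get_context_for_dataset); as a hypothetical stand-in, a dataframe with the expected columns looks like this (values copied from the sample shown further down):

import pandas as pd

# Hypothetical stand-in for tdata; the real dataframe is produced by the
# imported helpers from the ICDAR data.
tdata = pd.DataFrame(
    [
        {
            "ocr": "troe", "gs": "tree", "ocr_aligned": "troe", "gs_aligned": "tree",
            "start": 13, "len_ocr": 4, "len_gs": 4, "diff": 0,
            "key": "en/eng_sample/1.txt", "language": "en",
            "subset": "eng_sample", "dataset": "train",
            "context_before": "In botany, a ", "context_after": " is a peremial plant",
            "len_mistake_in_context": 37,
        },
    ]
)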

filter_len_ocr_mistake_in_context

 filter_len_ocr_mistake_in_context (data:pandas.core.frame.DataFrame,
                                    context_offset:int)
context_offset = 5

data_context = filter_len_ocr_mistake_in_context(tdata, context_offset=context_offset)

assert tdata.len_mistake_in_context.max() > 10*context_offset
assert data_context.len_mistake_in_context.max() <= 10*context_offset
2024-08-10 22:09:57.075 | INFO     | __main__:filter_len_ocr_mistake_in_context:7 - Max length of input samples: 50
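For reference, a rough reimplementation of what filter_len_ocr_mistake_in_context appears to do, inferred from the assertions and the log message above (keep only samples whose mistake-plus-context length fits within 10 * context_offset); the library's actual implementation may differ:

import pandas as pd
from loguru import logger

def filter_len_ocr_mistake_in_context_sketch(data: pd.DataFrame, context_offset: int) -> pd.DataFrame:
    max_len = 10 * context_offset
    logger.info(f"Max length of input samples: {max_len}")
    # Keep only samples whose mistake plus context fits within the maximum length
    return data[data.len_mistake_in_context <= max_len].copy()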
dataset = DatasetDict(
    {
        "train": Dataset.from_pandas(tdata.query('dataset == "train"')),
        "val": Dataset.from_pandas(tdata.query('dataset == "val"')),
        "test": Dataset.from_pandas(tdata.query('dataset == "test"')),
    }
)
dataset['train'][1]
{'ocr': 'troe',
 'gs': 'tree',
 'ocr_aligned': 'troe',
 'gs_aligned': 'tree',
 'start': 13,
 'len_ocr': 4,
 'key': 'en/eng_sample/1.txt',
 'language': 'en',
 'subset': 'eng_sample',
 'dataset': 'train',
 'len_gs': 4,
 'diff': 0,
 'context_before': 'In botany, a ',
 'context_after': ' is a peremial plant',
 'len_mistake_in_context': 37,
 '__index_level_0__': 14}

filter_max_len

 filter_max_len (example:Dict, max_len:int)
max_len = 5
dataset_max_len = dataset.filter(
    filter_max_len, fn_kwargs={"max_len": max_len}, batched=False
)

for subset, expected in {'train': 9, 'val': 2, 'test': 11}.items():
    actual = len(dataset_max_len[subset])
    assert actual == expected, f"Expected len of {expected} for '{subset}', got {actual}"
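filter_max_len is used with Dataset.filter and batched=False, so it receives one example dict and returns a boolean. A sketch consistent with its signature and the fields available in each example (the exact criterion is an assumption, not the library's code):

from typing import Dict

def filter_max_len_sketch(example: Dict, max_len: int) -> bool:
    # Keep an example only if both the OCR token and the gold standard fit in max_len
    return example["len_ocr"] <= max_len and example["len_gs"] <= max_len

With max_len = 5 only short tokens survive, which is what the assertions above check against.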
model_name = "google/byt5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)
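ByT5 tokenizes at the byte level: each UTF-8 byte maps directly to a token id (byte value plus an offset of 3 for the special tokens pad, </s>, and <unk>), so no subword vocabulary needs to cope with noisy OCR strings. For instance:

tokenizer("troe").input_ids
# [119, 117, 114, 104, 1]: the four bytes of 'troe' shifted by 3, followed by </s> (id 1)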

preprocess_function

 preprocess_function (examples, tokenizer, add_task_prefix:bool=False,
                      context_marker:str='')
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer}, batched=True
)
tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'troe</s>'
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer, "add_task_prefix": True}, batched=True
)
tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'en: troe</s>'
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer, "add_task_prefix": True, "context_marker": "mistake"}, batched=True
)
tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'en: In botany, a <mistake>troe</mistake> is a peremial plant</s>'
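Putting the three decoded examples together, preprocess_function evidently builds each model input from the OCR string, optionally wrapped in its context with <marker>…</marker> tags and optionally prefixed with the language code as a task prefix, and tokenizes both inputs and gold-standard targets. A sketch consistent with those outputs (not necessarily the library's exact code):

def preprocess_function_sketch(examples, tokenizer, add_task_prefix: bool = False, context_marker: str = ""):
    inputs = []
    for ocr, before, after, lang in zip(
        examples["ocr"], examples["context_before"], examples["context_after"], examples["language"]
    ):
        text = ocr
        if context_marker:
            # Wrap the mistake in <marker>...</marker> tags and embed it in its context
            text = f"{before}<{context_marker}>{text}</{context_marker}>{after}"
        if add_task_prefix:
            # Prepend the language code as a task prefix, e.g. 'en: '
            text = f"{lang}: {text}"
        inputs.append(text)
    model_inputs = tokenizer(inputs)
    # The gold-standard strings become the labels
    model_inputs["labels"] = tokenizer(text_target=examples["gs"])["input_ids"]
    return model_inputs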