Error correction using T5

Functionality for error correction with T5.

import os
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from ocrpostcorrection.error_correction import get_tokens_with_OCR_mistakes, get_context_for_dataset
from ocrpostcorrection.icdar_data import generate_data
filter_len_ocr_mistake_in_context

filter_len_ocr_mistake_in_context (data:pandas.core.frame.DataFrame, context_offset:int)

Keep only the samples whose OCR mistake plus surrounding context fits within 10 * context_offset characters.
context_offset = 5
data_context = filter_len_ocr_mistake_in_context(tdata, context_offset=context_offset)

assert tdata.len_mistake_in_context.max() > 10*context_offset
assert data_context.len_mistake_in_context.max() <= 10*context_offset
2024-08-10 22:09:57.075 | INFO | __main__:filter_len_ocr_mistake_in_context:7 - Max length of input samples: 50
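The same constraint can be checked directly on the dataframe. The snippet below is only a sketch of the idea behind the filter, not the library implementation; it uses the len_mistake_in_context column that also appears in the example records further down.

# Sketch of the filtering criterion: samples whose mistake-plus-context span
# exceeds 10 * context_offset characters are dropped.
max_input_len = 10 * context_offset
too_long = tdata[tdata.len_mistake_in_context > max_input_len]
print(f"{len(too_long)} of {len(tdata)} samples exceed {max_input_len} characters")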
dataset = DatasetDict(
    {"train": Dataset.from_pandas(tdata.query('dataset == "train"')),
     "val": Dataset.from_pandas(tdata.query('dataset == "val"')),
     "test": Dataset.from_pandas(tdata.query('dataset == "test"')),
    }
)

dataset['train'][1]
{'ocr': 'troe',
'gs': 'tree',
'ocr_aligned': 'troe',
'gs_aligned': 'tree',
'start': 13,
'len_ocr': 4,
'key': 'en/eng_sample/1.txt',
'language': 'en',
'subset': 'eng_sample',
'dataset': 'train',
'len_gs': 4,
'diff': 0,
'context_before': 'In botany, a ',
'context_after': ' is a peremial plant',
'len_mistake_in_context': 37,
'__index_level_0__': 14}
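To check that every subset made it into the DatasetDict, the number of samples per split can be printed (the counts depend on the sample data used here):

for name, split in dataset.items():
    print(name, len(split))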
filter_max_len

filter_max_len (example:Dict, max_len:int)

Filter function for Dataset.filter that keeps only examples whose length does not exceed max_len.
max_len = 5

dataset_max_len = dataset.filter(
    filter_max_len, fn_kwargs={"max_len": max_len}, batched=False
)
for subset, expected in {'train': 9, 'val': 2, 'test': 11}.items():
    assert len(dataset_max_len[subset]) == expected, f"Expected len of {expected} for '{subset}', got {len(dataset_max_len[subset])}"
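For reference, the sketch below shows what such a length filter could look like. It is a hypothetical re-implementation for illustration only; the actual filter_max_len is defined in ocrpostcorrection.error_correction and may use different criteria.

from typing import Dict

def filter_max_len_sketch(example: Dict, max_len: int) -> bool:
    # Hypothetical criterion: keep an example only if both the OCR token and
    # the gold-standard token fit within max_len characters.
    return example["len_ocr"] <= max_len and example["len_gs"] <= max_len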
= "google/byt5-small"
model_name
= AutoTokenizer.from_pretrained(model_name) tokenizer
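As a quick illustration of ByT5's byte-level tokenization (not part of the pipeline itself), a single string can be tokenized directly; each byte of the input maps to its own token id:

tokenizer("troe")["input_ids"]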
preprocess_function

preprocess_function (examples, tokenizer, add_task_prefix:bool=False, context_marker:str='')

Tokenize examples for (By)T5. With add_task_prefix the language code is prepended as a task prefix; with a context_marker the OCR mistake is embedded in its surrounding context, wrapped in <context_marker> tags.
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer}, batched=True
)

tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'troe</s>'
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer, "add_task_prefix": True}, batched=True
)

tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'en: troe</s>'
tokenized_dataset = dataset.map(
    preprocess_function, fn_kwargs={"tokenizer": tokenizer, "add_task_prefix": True, "context_marker": "mistake"}, batched=True
)

tokenizer.decode(tokenized_dataset['train'][1]['input_ids'])
'en: In botany, a <mistake>troe</mistake> is a peremial plant</s>'
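Assuming preprocess_function also stores the tokenized gold standard under 'labels' (the usual convention for seq2seq training; not shown in the output above), the target can be decoded the same way:

# Hypothetical check: relies on a 'labels' field holding the tokenized gold standard.
tokenizer.decode(tokenized_dataset['train'][1]['labels'])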