Fine tune TrOCR on IAM Handwriting Database using Seq2SeqTrainer #412

Open
johnlockejrr opened this issue Apr 18, 2024 · 13 comments

@johnlockejrr

It seems the IAM dataset is not public anymore; is there any other download location?

Trying to download it, this is the output:

<Error>
<Code>PublicAccessNotPermitted</Code>
<Message>Public access is not permitted on this storage account. RequestId:954279e9-d01e-0066-427a-91a772000000 Time:2024-04-18T10:26:49.5054310Z</Message>
</Error>
@johnlockejrr
Author

johnlockejrr commented Apr 18, 2024

Managed to get the data from the original source. Now another problem: I followed your example and no model gets saved... am I doing anything wrong?

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric
from transformers import default_data_collator

df = pd.read_fwf('./IAM/gt_test.txt', header=None)
df.rename(columns={0: "file_name", 1: "text"}, inplace=True)
print(df.head())

train_df, test_df = train_test_split(df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = IAMDataset(root_dir='./IAM/',
                           df=train_df,
                           processor=processor)
eval_dataset = IAMDataset(root_dir='./IAM/',
                           df=test_df,
                           processor=processor)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

encoding = train_dataset[0]
for k,v in encoding.items():
  print(k, v.shape)

#labels = encoding['labels']
#labels[labels == -100] = processor.tokenizer.pad_token_id
#label_str = processor.decode(labels, skip_special_tokens=True)
#print(label_str)

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,
    output_dir="./iam-train",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    push_to_hub=True,
    hub_token="hf_XXXXXXXXXXXXXXXXXXXXXXXXX",
    logging_steps=2,
    save_steps=1000,
    eval_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,  # CER is an error rate, so lower is better
)

cer_metric = load_metric("cer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

trainer.train()
(TrOCR-py3.10) incognito@DESKTOP-NHKR7QL:~/TrOCR-py3.10$ ls -al iam-train/
total 8
drwxr-xr-x 2 incognito incognito 4096 Apr 18 13:34 .
drwxr-xr-x 9 incognito incognito 4096 Apr 18 14:58 ..

@NielsRogge
Owner

If you provide the save_steps argument, the model should be saved automatically to output_dir every save_steps steps (since save_strategy="steps" by default).
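For example (a minimal sketch; the directory name and step count are illustrative):

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./iam-train",  # checkpoint-<step> folders are written here
    save_strategy="steps",     # the default, spelled out for clarity
    save_steps=500,            # write a checkpoint every 500 optimizer steps
)

Note that nothing is saved if training finishes in fewer than save_steps steps; in that case, call trainer.save_model() yourself at the end.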

@johnlockejrr
Author

johnlockejrr commented Apr 18, 2024

I did a:

processor.save_pretrained('./iam-train')
model.save_pretrained('./iam-train')

And it saved... the old model?

Anyway, I would like to save the best model. I think saving doesn't take the best model into account; it just saves every save_steps regardless of which step had the better loss. Am I wrong? Should I evaluate per epoch instead?
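For reference, this is what I think the best-model setup should look like (a sketch based on my reading of the docs; step counts are illustrative):

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./iam-train",
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,                # kept a multiple of eval_steps
    load_best_model_at_end=True,   # reload the best checkpoint when training ends
    metric_for_best_model="cer",
    greater_is_better=False,       # lower CER is better
)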

@NielsRogge
Owner

@johnlockejrr
Author

Thank you so much!

Side question: do you have any scripts/docs on how to train a TrOCR model for a foreign language? By foreign language I mean especially Hebrew.

@NielsRogge
Owner

Refer to this thread: huggingface/transformers#18163

@johnlockejrr
Author

johnlockejrr commented Apr 19, 2024

Sorry to keep bothering you; I'm a novice with BERT... until now I've worked only with Kraken OCR, which is still neural networks but a little different.
Should I give this code a go? I want to train recognition of Hebrew/Samaritan manuscripts.

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric
from transformers import default_data_collator

df = pd.read_fwf('./SAM/gt_test.txt', header=None)
df.rename(columns={0: "file_name", 1: "text"}, inplace=True)
print(df.head())

train_df, test_df = train_test_split(df, test_size=0.1)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

class SAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

train_dataset = SAMDataset(root_dir='./SAM/',
                           df=train_df,
                           processor=processor)
eval_dataset = SAMDataset(root_dir='./SAM/',
                           df=test_df,
                           processor=processor)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

encoding = train_dataset[0]
for k,v in encoding.items():
  print(k, v.shape)

#labels = encoding['labels']
#labels[labels == -100] = processor.tokenizer.pad_token_id
#label_str = processor.decode(labels, skip_special_tokens=True)
#print(label_str)

#model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "imvladikon/alephbertgimmel-base-512"
# initialize a VisionEncoderDecoderModel from a pretrained ViT encoder
# and a pretrained Hebrew BERT decoder
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
   encoder_checkpoint, decoder_checkpoint
).to("cuda")

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,
    output_dir=f"{encoder_checkpoint}-ft-sam-v1",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    push_to_hub=True,
    hub_token="hf_XXXXXXXXXXXXXXXXXXXXXXXX",
    logging_steps=2,
    save_steps=1000,
    eval_steps=200,
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,  # CER is an error rate, so lower is better
)

cer_metric = load_metric("cer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

trainer.train()
trainer.save_model()
trainer.push_to_hub()

@johnlockejrr
Author

johnlockejrr commented Apr 19, 2024

ValueError: Input image size (384*384) doesn't match model (224*224)
Any idea where I should pass interpolate_pos_encoding=True? Or is there a non-ViT model that would work in my case?
Thanks!
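Edit: the only workaround I've come up with so far is forcing the processor to emit 224x224 images so they match the encoder (a sketch; I'm assuming the size attribute can simply be overridden like this):

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
# assumption: override the target size so pixel_values come out 224x224
# instead of the default 384x384
processor.image_processor.size = {"height": 224, "width": 224}

Though I don't know whether shrinking the line images that much hurts recognition.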

@johnlockejrr
Author

I finally trained it on google/vit-base-patch16-384. After training, it outputs gibberish, even for images it was trained on... the output is in Hebrew, since that's what I trained it on, but gibberish nonetheless...

@NielsRogge
Owner

I'd recommend starting with 5 training examples and seeing whether the model is able to overfit them.
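Something along these lines (a sketch reusing the names from your script):

from transformers import Seq2SeqTrainer

# reuses train_df, processor, model, training_args and default_data_collator
# from your script above
# overfit sanity check: a model that can't memorize 5 samples
# has a bug somewhere upstream
tiny_df = train_df.head(5).reset_index(drop=True)
tiny_dataset = SAMDataset(root_dir='./SAM/', df=tiny_df, processor=processor)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    train_dataset=tiny_dataset,
    eval_dataset=tiny_dataset,   # evaluate on the same 5 samples
    data_collator=default_data_collator,
)
trainer.train()

After training, decoding predictions on those same 5 images should give near-exact transcriptions.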

@johnlockejrr
Author

Ok, I'll do that! Should I keep google/vit-base-patch16-384 or use google/vit-base-patch16-224-in21k with interpolate_pos_encoding=True? Does the script look OK in general? Thank you so much!

@johnlockejrr
Author

Trained with 16 samples:

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
{'loss': 16.0417, 'grad_norm': nan, 'learning_rate': 2e-05, 'epoch': 1.0}
{'loss': 15.6389, 'grad_norm': nan, 'learning_rate': 2e-05, 'epoch': 2.0}
{'loss': 15.0141, 'grad_norm': 173.44268798828125, 'learning_rate': 1.8e-05, 'epoch': 3.0}
{'loss': 11.8002, 'grad_norm': 53.65483856201172, 'learning_rate': 1.6000000000000003e-05, 'epoch': 4.0}
{'loss': 10.464, 'grad_norm': 53.46924591064453, 'learning_rate': 1.4e-05, 'epoch': 5.0}
{'loss': 9.4223, 'grad_norm': 38.62109375, 'learning_rate': 1.2e-05, 'epoch': 6.0}
{'loss': 8.9751, 'grad_norm': 27.501571655273438, 'learning_rate': 1e-05, 'epoch': 7.0}
{'loss': 8.7579, 'grad_norm': 20.1580867767334, 'learning_rate': 8.000000000000001e-06, 'epoch': 8.0}
{'loss': 8.4512, 'grad_norm': 18.931493759155273, 'learning_rate': 6e-06, 'epoch': 9.0}
{'loss': 8.241, 'grad_norm': 20.91577911376953, 'learning_rate': 4.000000000000001e-06, 'epoch': 10.0}
{'train_runtime': 12.2611, 'train_samples_per_second': 11.418, 'train_steps_per_second': 1.631, 'train_loss': 11.280629634857178, 'epoch': 10.0}

Test output:

(huggingface-source-py3.10) incognito@DESKTOP-NHKR7QL:~/TrOCR-py3.10$ python test_ocr.py LINES/sam_gt/2.4.jpg
/home/incognito/huggingface-source-py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:1252: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )
  warnings.warn(
 � � �� ��� � �ת� �ת ��ת���ת � �ב ��ב� �ב��ב � � �� �תת�תת �תב �ת� ���� ����

Even worse.

@NielsRogge
Owner

That means there's a bug in the data preparation, hyperparameter settings, or model configuration.

I recommend this guide for debugging: https://karpathy.github.io/2019/04/25/recipe/
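For instance, the nan grad_norm values at the start of your log are a common fp16 overflow symptom; a quick experiment (just a hunch, not a confirmed diagnosis) is to rerun with mixed precision disabled, keeping your other arguments as they are:

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=f"{encoder_checkpoint}-ft-sam-v1",  # encoder_checkpoint from your script
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=False,  # rule out fp16 overflow as the source of the nan gradients
)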
