Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LayoutLMv3 inference issue: bounding boxes are not plotted correctly #392

Open
rajasekarkrish opened this issue Feb 15, 2024 · 0 comments

Comments

@rajasekarkrish
Copy link

LayoutLMv3 inference issue: bounding boxes are not plotted correctly
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches

# Fine-tuned LayoutLMv3 token-classification checkpoint and the local
# dataset loading script.
model = AutoModelForTokenClassification.from_pretrained("/new_dataset/new_layoutlmv3/checkpoint-3000")
dataset = load_dataset(r"/new_layoutlmv3_dataset/new_dataset/updateddataset.py")

# apply_ocr=False: words and boxes are supplied from the dataset instead of OCR.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# Inspect a single test example.
example = dataset["test"][1]
print(example["image"])
print(example.keys())

image = example["image"]
words, boxes, word_labels = example["words"], example["bboxes"], example["ner_tags"]

encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    truncation=True,
    stride=128,
    return_tensors="pt",
)
for key, tensor in encoding.items():
    print(key, tensor.shape)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(**encoding)
logits = outputs.logits
print(logits.shape)

# Highest-scoring class id per position. squeeze() drops every size-1 axis
# (presumably just the batch axis for this single example).
predictions = logits.argmax(-1).squeeze().tolist()
print(predictions)

labels = encoding.labels.squeeze().tolist()
print(labels)
print("printing five labels", labels[:5])  # Print the first 5 labels
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0]/1000),
height *(bbox[1]/1000),
width * (bbox[2]/1000),
height *(bbox[3]/1000),
]

# Token-level boxes from the processor, on the 0-1000 scale.
token_boxes = encoding.bbox.squeeze().tolist()
width, height = image.size
print("Image size:", width, "x", height)

print("printing boxes", boxes[:5])
print("printing token boxes", token_boxes[:5])  # Print the first 5 bounding boxes

# Keep only positions whose label is not -100 (the ignore index). The three
# lists stay index-aligned, and the boxes are converted to pixel coordinates.
id_to_name = model.config.id2label
true_predictions = [id_to_name[pred] for pred, lab in zip(predictions, labels) if lab != -100]
true_labels = [id_to_name[lab] for _, lab in zip(predictions, labels) if lab != -100]
true_boxes = [unnormalize_box(box, width, height) for box, lab in zip(token_boxes, labels) if lab != -100]

print("printing true Predictions:", true_predictions[:5])  # Print the first 5 predictions
print("printing true Labels:", true_labels[:5])  # Print the first 5 labels

from PIL import ImageDraw, ImageFont

# Annotations are drawn directly onto the example's image object.
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()

# Tagging-scheme prefixes that may precede an entity name (e.g. "B-year").
_TAG_PREFIXES = ("B-", "I-", "E-", "S-")


def iob_to_label(label):
    """Return the entity name of *label* with any tagging prefix removed.

    A leading "B-", "I-", "E-" or "S-" is stripped only when present; labels
    without a prefix (such as this dataset's "base_tax_header") are returned
    unchanged. The outside tag "O" and empty results map to 'other'.

    Args:
        label: predicted label string, e.g. "B-year", "year" or "O".

    Returns:
        The bare entity name, or 'other'.
    """
    # Only strip a real prefix. Slicing off two characters unconditionally
    # mangled un-prefixed names ("base_tax" -> "se_tax", "year" -> "ar").
    if label.startswith(_TAG_PREFIXES):
        label = label[2:]
    if not label or label.upper() == "O":
        return 'other'
    return label


# Drawing colour per entity name. Keys are the ClassLabel names used by the
# dataset script, in lowercase because callers lowercase the label first.
label2color = {
    'irrelevant': 'red',
    'base_tax_header': 'green',
    'base_tax_due_header': 'green',
    'year_header': 'green',
    'base_tax_total': 'blue',
    'base_tax_due_total': 'blue',
    'base_tax': 'yellow',
    'base_tax_due': 'orange',
    'year': 'orange',
}

print(label2color)
print(model.config.id2label)

# Side-by-side dump of predicted vs. reference label for every kept token.
# Reference names come from the `id2label` imported from updatetrain.
for idx, (pred_id, gold_id) in enumerate(zip(predictions, labels)):
    if gold_id != -100:
        predicted_name = model.config.id2label.get(pred_id, "Label not found")
        actual_name = id2label.get(gold_id, "Label not found")
        print(f"Token {idx}: Predicted - {predicted_name}, Actual - {actual_name}")

# Draw each kept token's box and predicted label onto the image.
for prediction, box in zip(true_predictions, true_boxes):
    name = iob_to_label(prediction).lower()
    color = label2color.get(name)
    if color is None:
        print(f"Label {name} not in label2color dictionary.")
        continue
    draw.rectangle(box, outline=color)
    draw.text((box[0] + 10, box[1] - 10), text=name, fill=color, font=font)

plt.imshow(image)
plt.show()

bounding_box_not_ploted_accordingly_layoutlmv3

Dataset preparation code
import json
import os
import numpy as np
from PIL import Image
import datasets

import torch

# Module-level logger for this dataset script. `__name__` is required here:
# the bare `name` is undefined and raised NameError when the script was imported.
logger = datasets.logging.get_logger(__name__)

def normalize_bbox(bbox, size):
    """Scale a pixel box to LayoutLMv3's 0-1000 coordinate range.

    Args:
        bbox: sequence whose first four items are x0, y0, x1, y1 in pixels.
        size: ``(width, height)`` of the image the box refers to.

    Returns:
        ``[x0, y0, x1, y1]`` as ints (truncated toward zero), clamped to
        ``[0, 1000]``.
    """
    width, height = size[0], size[1]
    scaled = (
        int(1000 * bbox[0] / width),
        int(1000 * bbox[1] / height),
        int(1000 * bbox[2] / width),
        int(1000 * bbox[3] / height),
    )
    # Annotation boxes that poke slightly outside the image would otherwise
    # yield negative or >1000 values, which the model's 2-D position
    # embeddings cannot index.
    return [min(max(coord, 0), 1000) for coord in scaled]

def load_image(image_path):
    """Open *image_path* and convert it to RGB.

    Returns:
        ``(image, (width, height))``: the RGB PIL image and its pixel size.

    NOTE(review): EXIF orientation is not applied (no ImageOps.exif_transpose).
    If the annotations were drawn in a tool that honours EXIF rotation, the
    boxes will not line up with this image. Confirm for camera photos.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    width, height = rgb_image.size
    return rgb_image, (width, height)

class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset."""

    def __init__(self, **kwargs):
        """BuilderConfig for CustomDataset.

        The method must be named ``__init__``. A plain ``init`` is never
        invoked by the constructor, and its ``super().init`` call would fail.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)

class CustomDataset(datasets.GeneratorBasedBuilder):
    """Custom dataset for document understanding.

    Each example holds one document image with its words, their boxes
    (normalised to 0-1000 by ``normalize_bbox``) and their NER tags.
    """

    BUILDER_CONFIGS = [
        CustomDatasetConfig(name="custom_dataset", version=datasets.Version("1.0.0"), description="Custom dataset"),
    ]

    def _info(self):
        """Return the DatasetInfo describing the schema of each example."""
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "words": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    # The order of these names fixes the integer id of each tag.
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=[
                                'irrelevant',
                                'base_tax_header',
                                'base_tax_due_header',
                                'year_header',
                                'base_tax_total',
                                'base_tax_due_total',
                                'base_tax',
                                'base_tax_due',
                                'year',
                                # Add more labels as per your requirement
                            ]
                        )
                    ),
                    "image": datasets.features.Image(),
                    "image_path": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators.

        ``dl_manager`` is unused: the data is read from a fixed local directory.
        """
        # Assuming the data is already downloaded/extracted and available in a specific directory
        data_dir = '/new_layoutlmv3_dataset/new_dataset/data/'
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": os.path.join(data_dir, "train.json")},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": os.path.join(data_dir, "test.json")},
            ),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(guid, example)`` pairs from the JSON file at *filepath*.

        The file must contain a list of items. Each item has a ``file_name``
        (image path relative to ``base_dir``) and an ``annotations`` list whose
        entries carry ``text``, ``box`` and ``label``. Items without
        ``file_name`` are skipped with a warning.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, "r", encoding="utf8") as f:
            data = json.load(f)

        # Define the base directory for your images
        base_dir = r"D:\new_layoutlmv3_dataset\new_dataset"  # Update this path to your base directory

        for guid, item in enumerate(data):
            if 'file_name' not in item:
                logger.warning(f"Skipping entry {guid} due to missing 'file_name'")
                continue
            # Drop "../" segments and let normpath choose the OS-native
            # separator. Hard-coding '\\' produced invalid paths on POSIX.
            image_relative_path = os.path.normpath(item['file_name'].replace('../', ''))
            image_path = os.path.join(base_dir, image_relative_path)

            image, size = load_image(image_path)
            words, bboxes, ner_tags = [], [], []

            for annotation in item["annotations"]:
                words.append(annotation["text"])
                # NOTE(review): assumes "box" is [x0, y0, x1, y1] in pixels of
                # this image file. An [x, y, w, h] box or percentage
                # coordinates would be plotted in the wrong place. Confirm
                # against the annotation export.
                bboxes.append(normalize_bbox(annotation["box"], size))
                ner_tags.append(annotation["label"])

            yield guid, {
                "id": str(guid),
                "words": words,
                "bboxes": bboxes,
                "ner_tags": ner_tags,
                "image": image,  # PIL image; encoded by the datasets Image feature
                "image_path": image_path,
            }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant