You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
LayoutLMv3 inference issue: bounding boxes are not plotted correctly
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches
# Load the fine-tuned token-classification checkpoint from disk.
model = AutoModelForTokenClassification.from_pretrained(
    "/new_dataset/new_layoutlmv3/checkpoint-3000"
)

# NOTE(review): token_boxes, predictions, labels, width, height and
# unnormalize_box are not defined anywhere above this point in the excerpt;
# they are expected to come from earlier code that is not shown here.
print("printing token boxes", token_boxes[:5])  # Print the first 5 bounding boxes

# Keep only the positions whose label is not -100, and map ids to names
# through the model's own id2label table.
true_predictions = [
    model.config.id2label[pred]
    for pred, label in zip(predictions, labels)
    if label != -100
]
true_labels = [
    model.config.id2label[label]
    for _, label in zip(predictions, labels)
    if label != -100
]
# Boxes are filtered with the same rule and scaled with unnormalize_box.
true_boxes = [
    unnormalize_box(box, width, height)
    for box, label in zip(token_boxes, labels)
    if label != -100
]

print("printing true Predictions:", true_predictions[:5])  # Print the first 5 predictions
print("printing true Labels:", true_labels[:5])  # Print the first 5 labels

from PIL import ImageDraw, ImageFont

# Drawing happens directly on `image`.
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
def iob_to_label(label):
    """Return the entity name of a tag, dropping an IOB-style prefix if present.

    The previous version always discarded the first two characters, which is
    only correct for tags such as ``"B-HEADER"``.  Labels that carry no prefix
    (for example ``"base_tax"``) were silently mangled (``"se_tax"``).

    Args:
        label: Tag name, with or without a ``B-``/``I-``/``E-``/``S-`` prefix.

    Returns:
        The tag without its prefix, or ``'other'`` for the outside tag
        (``"O"``), an empty string, or a bare prefix such as ``"B-"``.
    """
    if label[:2] in ("B-", "I-", "E-", "S-"):
        # Only strip when there really is a chunk prefix.
        label = label[2:]
    elif label == "O":
        # The outside tag has no entity name.
        label = ""
    if not label:
        return 'other'
    return label
# Print predicted vs. actual label name for every position that is not
# marked with -100.
for i, (prediction, label) in enumerate(zip(predictions, labels)):
    if label != -100:
        # Predicted ids are resolved through the model config, gold ids
        # through the id2label mapping imported from updatetrain.
        predicted_label = model.config.id2label.get(prediction, "Label not found")
        actual_label = id2label.get(label, "Label not found")
        print(f"Token {i}: Predicted - {predicted_label}, Actual - {actual_label}")

# Outline each kept box and write its label next to it.
# NOTE(review): label2color is not defined above this point in the excerpt.
for prediction, box in zip(true_predictions, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    if predicted_label not in label2color:
        print(f"Label {predicted_label} not in label2color dictionary.")
        continue
    color = label2color[predicted_label]
    draw.rectangle(box, outline=color)
    draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=color, font=font)

plt.imshow(image)
plt.show()
Dataset preparation code
import json
import os
import numpy as np
from PIL import Image
import datasets
class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset."""

    def __init__(self, **kwargs):
        """Create the config.

        The method was previously spelled ``init`` (the double underscores
        were lost), so it never ran as a constructor and its ``super().init``
        call would have raised ``AttributeError`` if invoked.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super(CustomDatasetConfig, self).__init__(**kwargs)
class CustomDataset(datasets.GeneratorBasedBuilder):
    """Dataset builder for document understanding.

    Each example carries the words of one document image, one box per word,
    one tag per word, the decoded image and the path it was loaded from.
    """

    BUILDER_CONFIGS = [
        CustomDatasetConfig(
            name="custom_dataset",
            version=datasets.Version("1.0.0"),
            description="Custom dataset",
        ),
    ]

    def _info(self):
        """Describe the features of one example."""
        label_names = [
            'irrelevant',
            'base_tax_header',
            'base_tax_due_header',
            'year_header',
            'base_tax_total',
            'base_tax_due_total',
            'base_tax',
            'base_tax_due',
            'year',
            # Add more labels as per your requirement
        ]
        features = datasets.Features(
            {
                "id": datasets.Value("string"),
                "words": datasets.Sequence(datasets.Value("string")),
                "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                "ner_tags": datasets.Sequence(
                    datasets.features.ClassLabel(names=label_names)
                ),
                "image": datasets.features.Image(),
                "image_path": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(features=features, supervised_keys=None)

    def _split_generators(self, dl_manager):
        """Return the train and test SplitGenerators.

        Nothing is downloaded: the annotation files are read from a fixed
        directory and ``dl_manager`` is not used.
        """
        # Assuming the data is already downloaded/extracted and available in a specific directory
        data_dir = '/new_layoutlmv3_dataset/new_dataset/data/'
        train_file = os.path.join(data_dir, "train.json")
        test_file = os.path.join(data_dir, "test.json")
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"filepath": train_file},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"filepath": test_file},
            ),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs from the JSON file at ``filepath``.

        Entries without a ``'file_name'`` key are skipped with a warning.
        Keys are the entry's position in the JSON list, so skipped entries
        leave gaps in the key sequence.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, "r", encoding="utf8") as f:
            data = json.load(f)

        # Define the base directory for your images
        base_dir = r"D:\new_layoutlmv3_dataset\new_dataset"  # Update this path to your base directory

        for guid, item in enumerate(data):
            if 'file_name' not in item:
                logger.warning(f"Skipping entry {guid} due to missing 'file_name'")
                continue

            # Drop every '../' fragment, switch to backslashes, then anchor
            # the result under base_dir.
            relative_path = item['file_name'].replace('../', '').replace('/', '\\')
            image_path = os.path.join(base_dir, relative_path)
            image, size = load_image(image_path)

            words = []
            bboxes = []
            ner_tags = []
            for ann in item["annotations"]:
                words.append(ann["text"])
                # Boxes are rescaled against the size of the loaded image.
                bboxes.append(normalize_bbox(ann["box"], size))
                ner_tags.append(ann["label"])

            example = {
                "id": str(guid),
                "words": words,
                "bboxes": bboxes,
                "ner_tags": ner_tags,
                "image": image,  # Adjusting according to the expected format
                "image_path": image_path,  # Keep this as 'image_path' for consistency in your dataset features
            }
            yield guid, example
The text was updated successfully, but these errors were encountered:
LayoutLMv3 inference issue: bounding boxes are not plotted correctly
from transformers import AutoModelForTokenClassification
from datasets import load_dataset
import torch
from transformers import AutoProcessor
import matplotlib.pyplot as plt
from updatetrain import id2label
import matplotlib.patches as patches
# Load the fine-tuned checkpoint, the dataset (through its loading script)
# and the base processor with its built-in OCR switched off.
model = AutoModelForTokenClassification.from_pretrained(
    "/new_dataset/new_layoutlmv3/checkpoint-3000"
)
dataset = load_dataset(r"/new_layoutlmv3_dataset/new_dataset/updateddataset.py")
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# Pick one test example and pull out the fields the processor needs.
example = dataset["test"][1]
print(example["image"])
print(example.keys())

image = example["image"]
words = example["words"]
boxes = example["bboxes"]
word_labels = example["ner_tags"]

encoding = processor(
    image,
    words,
    boxes=boxes,
    word_labels=word_labels,
    truncation=True,
    stride=128,
    return_tensors="pt",
)
for key, tensor in encoding.items():
    print(key, tensor.shape)

# Forward pass without gradient tracking.
with torch.no_grad():
    outputs = model(**encoding)

logits = outputs.logits
print(logits.shape)

# Highest-scoring class id per position, as a plain Python list.
predictions = logits.argmax(-1).squeeze().tolist()
print(predictions)

labels = encoding.labels.squeeze().tolist()
print(labels)
print("printing five labels", labels[:5])  # Print the first 5 labels
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid to a ``width`` x ``height`` image.

    Args:
        bbox: Sequence whose first four items are read as x0, y0, x1, y1.
        width: Target width; applied to the two x coordinates.
        height: Target height; applied to the two y coordinates.

    Returns:
        A list ``[x0, y0, x1, y1]`` of the rescaled coordinates.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    return [
        width * (left / 1000),
        height * (top / 1000),
        width * (right / 1000),
        height * (bottom / 1000),
    ]
# Per-position boxes produced by the processor, as a plain list.
token_boxes = encoding.bbox.squeeze().tolist()

width, height = image.size
print("Image size:", width, "x", height)
print("printing boxes", boxes[:5])
print("printing token boxes", token_boxes[:5])  # Print the first 5 bounding boxes

# Keep only the positions whose label is not -100, and map ids to names
# through the model's own id2label table.
true_predictions = [
    model.config.id2label[pred]
    for pred, label in zip(predictions, labels)
    if label != -100
]
true_labels = [
    model.config.id2label[label]
    for _, label in zip(predictions, labels)
    if label != -100
]
# Boxes are filtered with the same rule and scaled to the image size.
true_boxes = [
    unnormalize_box(box, width, height)
    for box, label in zip(token_boxes, labels)
    if label != -100
]

print("printing true Predictions:", true_predictions[:5])  # Print the first 5 predictions
print("printing true Labels:", true_labels[:5])  # Print the first 5 labels

from PIL import ImageDraw, ImageFont

# Drawing happens directly on `image`.
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
def iob_to_label(label):
    """Return the entity name of a tag, dropping an IOB-style prefix if present.

    The previous version always discarded the first two characters, which is
    only correct for tags such as ``"B-HEADER"``.  The labels used here carry
    no prefix, so ``"base_tax"`` came back as ``"se_tax"`` and the colour table
    below had been keyed on those truncated names to compensate.

    Args:
        label: Tag name, with or without a ``B-``/``I-``/``E-``/``S-`` prefix.

    Returns:
        The tag without its prefix, or ``'other'`` for the outside tag
        (``"O"``), an empty string, or a bare prefix such as ``"B-"``.
    """
    if label[:2] in ("B-", "I-", "E-", "S-"):
        # Only strip when there really is a chunk prefix.
        label = label[2:]
    elif label == "O":
        # The outside tag has no entity name.
        label = ""
    if not label:
        return 'other'
    return label


# Outline/text colour per label.  Keys are the full label names declared in
# the dataset builder's ClassLabel list; they replace the truncated keys
# ('relevant', 'se_tax', 'ar', ...) that only matched the mangled output of
# the old iob_to_label.
label2color = {
    'irrelevant': 'red',
    'base_tax_header': 'green',
    'base_tax_due_header': 'green',
    'year_header': 'green',
    'base_tax_total': 'blue',
    'base_tax_due_total': 'blue',
    'base_tax': 'yellow',
    'base_tax_due': 'orange',
    'year': 'orange',
}
# Dump both lookup tables so mismatching label names are easy to spot.
print(label2color)
print(model.config.id2label)

# Print predicted vs. actual label name for every position that is not
# marked with -100.
for i, (prediction, label) in enumerate(zip(predictions, labels)):
    if label != -100:
        # Predicted ids are resolved through the model config, gold ids
        # through the id2label mapping imported from updatetrain.
        predicted_label = model.config.id2label.get(prediction, "Label not found")
        actual_label = id2label.get(label, "Label not found")
        print(f"Token {i}: Predicted - {predicted_label}, Actual - {actual_label}")

# Outline each kept box and write its label next to it.
for prediction, box in zip(true_predictions, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    if predicted_label not in label2color:
        print(f"Label {predicted_label} not in label2color dictionary.")
        continue
    color = label2color[predicted_label]
    draw.rectangle(box, outline=color)
    draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=color, font=font)

plt.imshow(image)
plt.show()
Dataset preparation code
import json
import os
import numpy as np
from PIL import Image
import datasets
import torch
# Module-level logger.  `name` was an undefined identifier (the double
# underscores of `__name__` had been lost), which raises NameError on import.
logger = datasets.logging.get_logger(__name__)
def normalize_bbox(bbox, size):
    """Rescale a pixel-space box to the 0-1000 integer grid.

    Each coordinate is multiplied by 1000, divided by the matching image
    dimension and truncated to an int.  The result is additionally clamped
    to ``[0, 1000]`` so that an annotation overshooting the image edge cannot
    produce a coordinate outside the grid; in-range boxes are unaffected.

    Args:
        bbox: Sequence whose first four items are x0, y0, x1, y1 in pixels.
        size: ``(width, height)`` of the image the box refers to.

    Returns:
        A list ``[x0, y0, x1, y1]`` of ints, each within 0..1000.

    Raises:
        ZeroDivisionError: if ``size`` contains a zero dimension.
    """
    width, height = size[0], size[1]

    def _scale(value, extent):
        # Truncate exactly as before, then clip to the valid grid.
        return min(1000, max(0, int(1000 * value / extent)))

    return [
        _scale(bbox[0], width),
        _scale(bbox[1], height),
        _scale(bbox[2], width),
        _scale(bbox[3], height),
    ]
def load_image(image_path):
    """Open ``image_path``, convert it to RGB and report its size.

    Returns:
        A pair ``(image, (width, height))`` with the converted image and
        its dimensions.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    width, height = rgb_image.size
    return rgb_image, (width, height)
class CustomDatasetConfig(datasets.BuilderConfig):
    """BuilderConfig for CustomDataset."""

    def __init__(self, **kwargs):
        """Create the config.

        The method was previously spelled ``init`` (the double underscores
        were lost), so it never ran as a constructor and its ``super().init``
        call would have raised ``AttributeError`` if invoked.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super(CustomDatasetConfig, self).__init__(**kwargs)
class CustomDataset(datasets.GeneratorBasedBuilder):
"""Custom dataset for document understanding."""
The text was updated successfully, but these errors were encountered: