Skip to content

Commit

Permalink
Add caption support to tables
Browse files Browse the repository at this point in the history
  • Loading branch information
Asad Hasan committed Mar 15, 2024
1 parent 3783b44 commit 8170228
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 24 deletions.
5 changes: 5 additions & 0 deletions unstructured/documents/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,11 @@ class Table(Text):

category = "Table"

class Caption(Text):
    """An element for capturing captions, e.g. the text of an HTML `<caption>` tag."""

    category = "Caption"


class TableChunk(Table):
"""An element for capturing chunks of tables."""
Expand Down
45 changes: 32 additions & 13 deletions unstructured/documents/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Table,
Text,
Title,
Caption
)
from unstructured.documents.xml import VALID_PARSERS, XMLDocument
from unstructured.logger import logger
Expand All @@ -37,11 +38,14 @@
)
from unstructured.utils import htmlify_matrix_of_cell_texts

TEXT_TAGS: Final[List[str]] = ["p", "a", "td", "span", "font"]
# -- name of the HTML table-caption tag, plus its literal opening and closing markers; the
# -- markers are wrapped around caption text so it can be told apart from ordinary cell text
CAPTION_TAG: Final[str] = "caption"
CAPTION_TAG_START: Final[str] = f"<{CAPTION_TAG}>"
CAPTION_TAG_END: Final[str] = f"</{CAPTION_TAG}>"
TEXT_TAGS: Final[List[str]] = ["p", "a", "td", "span", "font", CAPTION_TAG]
LIST_ITEM_TAGS: Final[List[str]] = ["li", "dd"]
LIST_TAGS: Final[List[str]] = ["ul", "ol", "dl"]
HEADING_TAGS: Final[List[str]] = ["h1", "h2", "h3", "h4", "h5", "h6"]
TABLE_TAGS: Final[List[str]] = ["table", "tbody", "td", "tr"]
TABLE_TAGS: Final[List[str]] = ["table", "tbody", "td", "tr", CAPTION_TAG]
TEXTBREAK_TAGS: Final[List[str]] = ["br"]
PAGEBREAK_TAGS: Final[List[str]] = ["hr"]
EMPTY_TAGS: Final[List[str]] = PAGEBREAK_TAGS + TEXTBREAK_TAGS
Expand Down Expand Up @@ -103,6 +107,8 @@ class HTMLListItem(TagsMixin, ListItem):
class HTMLTable(TagsMixin, Table):
    """Table with tag information"""

class HTMLCaption(TagsMixin, Caption):
    """Caption with tag information"""

def has_table_ancestor(element: TagsMixin) -> bool:
"""Checks to see if an element has ancestors that are table elements. If so, we consider
Expand Down Expand Up @@ -159,7 +165,7 @@ def _parse_pages_from_element_tree(self) -> List[Page]:
for article in articles:
descendanttag_elems: Tuple[etree._Element, ...] = ()
for tag_elem in article.iter():
if tag_elem in descendanttag_elems:
if tag_elem in descendanttag_elems and tag_elem.tag != CAPTION_TAG:
# Prevent repeating something that's been flagged as text as we chase it
# down a chain
continue
Expand Down Expand Up @@ -337,7 +343,7 @@ def _parse_HTMLTable_from_table_elem(table_elem: etree._Element) -> Optional[Ele
# -- cell within the table within the cell too.)

trs = cast(
List[etree._Element], table_elem.xpath("./tr | ./thead/tr | ./tbody/tr | ./tfoot/tr")
List[etree._Element], table_elem.xpath("./tr | ./thead/tr | ./tbody/tr | ./tfoot/tr | ./caption")
)

if not trs:
Expand All @@ -346,17 +352,21 @@ def _parse_HTMLTable_from_table_elem(table_elem: etree._Element) -> Optional[Ele
def iter_cell_texts(tr: etree._Element) -> Iterator[str]:
"""Generate the text of each cell in `tr`."""
# -- a cell can be either a "data" cell (td) or a "heading" cell (th) --
tds = cast(List[etree._Element], tr.xpath("./td | ./th"))
for td in tds:
# -- a cell can contain other elements like spans etc. so we can't count on the text
# -- being directly below the `<td>` element. `.itertext()` gets all of it recursively.
# -- Filter out whitespace text nodes that result from HTML formatting.
stripped_text_nodes = (t.strip() for t in cast(Iterator[str], td.itertext()))
yield " ".join(t for t in stripped_text_nodes if t)
if tr.tag == "caption":
stripped_text_nodes = (t.strip() for t in cast(Iterator[str], tr.itertext()))
yield " ".join(CAPTION_TAG_START+t+CAPTION_TAG_END for t in stripped_text_nodes if t)
else:
tds = cast(List[etree._Element], tr.xpath("./td | ./th"))
for td in tds:
# -- a cell can contain other elements like spans etc. so we can't count on the text
# -- being directly below the `<td>` element. `.itertext()` gets all of it recursively.
# -- Filter out whitespace text nodes that result from HTML formatting.
stripped_text_nodes = (t.strip() for t in cast(Iterator[str], td.itertext()))
yield " ".join(t for t in stripped_text_nodes if t)

table_data = [list(iter_cell_texts(tr)) for tr in trs]
html_table = htmlify_matrix_of_cell_texts(table_data)
table_text = " ".join(" ".join(t for t in row if t) for row in table_data).strip()
html_table = htmlify_matrix_of_cell_texts(table_data, CAPTION_TAG_START)
table_text = " ".join(" ".join(t.replace(CAPTION_TAG_START, "").replace(CAPTION_TAG_END, "") for t in row if t) for row in table_data).strip()

if table_text == "":
return None
Expand Down Expand Up @@ -465,6 +475,15 @@ def _text_to_element(
links=links,
emphasized_texts=emphasized_texts,
)

if tag == CAPTION_TAG:
return HTMLCaption(
text,
tag=tag,
ancestortags=ancestortags,
links=links,
emphasized_texts=emphasized_texts,
)

if len(text) < 2:
return None
Expand Down
28 changes: 17 additions & 11 deletions unstructured/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_P = ParamSpec("_P")


def htmlify_matrix_of_cell_texts(matrix: Sequence[Sequence[str]]) -> str:
def htmlify_matrix_of_cell_texts(matrix: Sequence[Sequence[str]], CAPTION_TAG_START: str) -> str:
"""Form an HTML table from "rows" and "columns" of `matrix`.
Character overhead is minimized:
Expand All @@ -55,16 +55,22 @@ def iter_trs(rows_of_cell_strs: Sequence[Sequence[str]]) -> Iterator[str]:
# -- suppress emission of rows with no cells --
if not row_cell_strs:
continue
yield f"<tr>{''.join(iter_tds(row_cell_strs))}</tr>"

def iter_tds(row_cell_strs: Sequence[str]) -> Iterator[str]:
for s in row_cell_strs:
# -- take care of things like '<' and '>' in the text --
s = html.escape(s)
# -- substitute <br/> elements for line-feeds in the text --
s = "<br/>".join(s.split("\n"))
# -- strip leading and trailing whitespace, wrap it up and go --
yield f"<td>{s.strip()}</td>"
tds = ""
for s in row_cell_strs:
if s.find(CAPTION_TAG_START) == 0:
yield s
else:
tds += iter_tds(s)
if tds != "":
yield f"<tr>{tds}</tr>"

def iter_tds(s: str) -> str:
    """Return `s` rendered as a single `<td>` element.

    The text is HTML-escaped first, then each line-feed is replaced by a `<br/>` element;
    surrounding whitespace is trimmed just before the text is wrapped in the cell tags.
    """
    # -- escape first so characters like '<' and '>' in the text are not read as markup --
    escaped = html.escape(s)
    # -- a line-feed inside a cell is rendered as an explicit <br/> element --
    with_breaks = escaped.replace("\n", "<br/>")
    # -- trim leading and trailing whitespace only after the substitutions above --
    return f"<td>{with_breaks.strip()}</td>"

return f"<table>{''.join(iter_trs(matrix))}</table>" if matrix else ""

Expand Down

0 comments on commit 8170228

Please sign in to comment.