Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve load_requirements in setup.py #2861

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## 0.13.3-dev4
## 0.13.3-dev5

### Enhancements

Expand Down
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ USER ${NB_USER}
COPY example-docs example-docs
COPY unstructured unstructured

# Allow unit tests run via Docker to find this file at the repository root
COPY setup_utils.py setup_utils.py

RUN python3.10 -c "from unstructured.partition.model_init import initialize; initialize()"

CMD ["/bin/bash"]
127 changes: 3 additions & 124 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,65 +18,11 @@
limitations under the License.
"""

from typing import List, Optional, Union

from setuptools import find_packages, setup

from setup_utils import get_base_reqs, get_extras
from unstructured.__version__ import __version__


def load_requirements(file_list: Optional[Union[str, List[str]]] = None) -> List[str]:
    """Load requirement specifiers from one or more pip ``.in`` files.

    Args:
        file_list: A path, or list of paths, to requirements files. Defaults
            to ``requirements/base.in`` when omitted.

    Returns:
        The requirement lines, stripped of surrounding whitespace, with
        comment lines (``#``), pip directives (``-r``/``-c``), and blank
        lines removed.
    """
    if file_list is None:
        file_list = ["requirements/base.in"]
    if isinstance(file_list, str):
        file_list = [file_list]
    requirements: List[str] = []
    for file in file_list:
        with open(file, encoding="utf-8") as f:
            # Strip each line: readlines() keeps trailing newlines, and an
            # unstripped blank line ("\n") would otherwise survive the filter
            # below and leak an empty requirement into install_requires.
            requirements.extend(line.strip() for line in f)
    return [
        req for req in requirements if req and not req.startswith(("#", "-"))
    ]


csv_reqs = load_requirements("requirements/extra-csv.in")
doc_reqs = load_requirements("requirements/extra-docx.in")
docx_reqs = load_requirements("requirements/extra-docx.in")
epub_reqs = load_requirements("requirements/extra-epub.in")
image_reqs = load_requirements("requirements/extra-pdf-image.in")
markdown_reqs = load_requirements("requirements/extra-markdown.in")
msg_reqs = load_requirements("requirements/extra-msg.in")
odt_reqs = load_requirements("requirements/extra-odt.in")
org_reqs = load_requirements("requirements/extra-pandoc.in")
pdf_reqs = load_requirements("requirements/extra-pdf-image.in")
ppt_reqs = load_requirements("requirements/extra-pptx.in")
pptx_reqs = load_requirements("requirements/extra-pptx.in")
rtf_reqs = load_requirements("requirements/extra-pandoc.in")
rst_reqs = load_requirements("requirements/extra-pandoc.in")
tsv_reqs = load_requirements("requirements/extra-csv.in")
xlsx_reqs = load_requirements("requirements/extra-xlsx.in")

all_doc_reqs = list(
set(
csv_reqs
+ docx_reqs
+ epub_reqs
+ image_reqs
+ markdown_reqs
+ msg_reqs
+ odt_reqs
+ org_reqs
+ pdf_reqs
+ pptx_reqs
+ rtf_reqs
+ rst_reqs
+ tsv_reqs
+ xlsx_reqs,
),
)


setup(
name="unstructured",
description="A library that prepares raw documents for downstream ML tasks.",
Expand Down Expand Up @@ -106,75 +52,8 @@ def load_requirements(file_list: Optional[Union[str, List[str]]] = None) -> List
entry_points={
"console_scripts": ["unstructured-ingest=unstructured.ingest.main:main"],
},
install_requires=load_requirements(),
extras_require={
# Document specific extra requirements
"all-docs": all_doc_reqs,
"csv": csv_reqs,
"doc": doc_reqs,
"docx": docx_reqs,
"epub": epub_reqs,
"image": image_reqs,
"md": markdown_reqs,
"msg": msg_reqs,
"odt": odt_reqs,
"org": org_reqs,
"pdf": pdf_reqs,
"ppt": ppt_reqs,
"pptx": pptx_reqs,
"rtf": rtf_reqs,
"rst": rst_reqs,
"tsv": tsv_reqs,
"xlsx": xlsx_reqs,
# Extra requirements for data connectors
"airtable": load_requirements("requirements/ingest/airtable.in"),
"astra": load_requirements("requirements/ingest/astra.in"),
"azure": load_requirements("requirements/ingest/azure.in"),
"azure-cognitive-search": load_requirements(
"requirements/ingest/azure-cognitive-search.in",
),
"biomed": load_requirements("requirements/ingest/biomed.in"),
"box": load_requirements("requirements/ingest/box.in"),
"chroma": load_requirements("requirements/ingest/chroma.in"),
"clarifai": load_requirements("requirements/ingest/clarifai.in"),
"confluence": load_requirements("requirements/ingest/confluence.in"),
"delta-table": load_requirements("requirements/ingest/delta-table.in"),
"discord": load_requirements("requirements/ingest/discord.in"),
"dropbox": load_requirements("requirements/ingest/dropbox.in"),
"elasticsearch": load_requirements("requirements/ingest/elasticsearch.in"),
"gcs": load_requirements("requirements/ingest/gcs.in"),
"github": load_requirements("requirements/ingest/github.in"),
"gitlab": load_requirements("requirements/ingest/gitlab.in"),
"google-drive": load_requirements("requirements/ingest/google-drive.in"),
"hubspot": load_requirements("requirements/ingest/hubspot.in"),
"jira": load_requirements("requirements/ingest/jira.in"),
"mongodb": load_requirements("requirements/ingest/mongodb.in"),
"notion": load_requirements("requirements/ingest/notion.in"),
"onedrive": load_requirements("requirements/ingest/onedrive.in"),
"opensearch": load_requirements("requirements/ingest/opensearch.in"),
"outlook": load_requirements("requirements/ingest/outlook.in"),
"pinecone": load_requirements("requirements/ingest/pinecone.in"),
"postgres": load_requirements("requirements/ingest/postgres.in"),
"qdrant": load_requirements("requirements/ingest/qdrant.in"),
"reddit": load_requirements("requirements/ingest/reddit.in"),
"s3": load_requirements("requirements/ingest/s3.in"),
"sharepoint": load_requirements("requirements/ingest/sharepoint.in"),
"salesforce": load_requirements("requirements/ingest/salesforce.in"),
"sftp": load_requirements("requirements/ingest/sftp.in"),
"slack": load_requirements("requirements/ingest/slack.in"),
"wikipedia": load_requirements("requirements/ingest/wikipedia.in"),
"weaviate": load_requirements("requirements/ingest/weaviate.in"),
# Legacy extra requirements
"huggingface": load_requirements("requirements/huggingface.in"),
"local-inference": all_doc_reqs,
"paddleocr": load_requirements("requirements/extra-paddleocr.in"),
"embed-huggingface": load_requirements("requirements/ingest/embed-huggingface.in"),
"embed-octoai": load_requirements("requirements/ingest/embed-octoai.in"),
"embed-vertexai": load_requirements("requirements/ingest/embed-vertexai.in"),
"openai": load_requirements("requirements/ingest/embed-openai.in"),
"bedrock": load_requirements("requirements/ingest/embed-aws-bedrock.in"),
"databricks-volumes": load_requirements("requirements/ingest/databricks-volumes.in"),
},
install_requires=get_base_reqs(),
extras_require=get_extras(),
package_dir={"unstructured": "unstructured"},
package_data={"unstructured": ["nlp/*.txt"]},
)
130 changes: 130 additions & 0 deletions setup_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import collections
from pathlib import Path
from typing import List, Union

# Directory containing this file (assumed to be the repository root, since
# the Dockerfile copies setup_utils.py there).
current_dir = Path(__file__).parent.absolute()
# Location of the pip ".in" requirement source files.
requirements_dir = current_dir / "requirements"
# Requirement files for the ingest data connectors and embedders.
ingest_requirements_dir = requirements_dir / "ingest"


def load_requirements(file: Union[str, Path]) -> List[str]:
    """Recursively load requirement specifiers from a pip ``.in`` file.

    Comment lines (``#``) and pip directives other than ``-r`` (e.g. ``-c``)
    are skipped; ``-r`` includes are resolved relative to the including file
    and loaded recursively.

    Args:
        file: Path to a requirements file with an ``.in`` extension.

    Returns:
        De-duplicated, whitespace-stripped requirement specifiers. Order is
        unspecified because a set is used for de-duplication.

    Raises:
        FileNotFoundError: If ``file`` does not point to an existing file.
        ValueError: If ``file`` does not have a ``.in`` extension.
    """
    path = file if isinstance(file, Path) else Path(file)
    if not path.is_file():
        raise FileNotFoundError(f"path does not point to a valid file: {path}")
    if not path.suffix == ".in":
        raise ValueError(f"file should have .in extension: {path}")
    file_dir = path.parent.resolve()
    # Strip each line so indented comments/specifiers are classified correctly.
    raw = [line.strip() for line in path.read_text(encoding="utf-8").splitlines()]
    requirements: List[str] = [
        r for r in raw if r and not r.startswith("#") and not r.startswith("-")
    ]
    for recursive_req in (r for r in raw if r.startswith("-r")):
        # "-r <file>": the last token is the included file, relative to the
        # directory of the current file.
        file_spec = recursive_req.split()[-1]
        requirements.extend(load_requirements(file=(file_dir / file_spec).resolve()))
    # Remove duplicates.  NOTE(review): circular -r includes would recurse
    # forever — assumed not to occur in this repo's requirement files.
    return list(set(requirements))


def get_base_reqs() -> List[str]:
    """Return the core (non-extra) requirements from requirements/base.in."""
    return load_requirements(requirements_dir / "base.in")


def get_doc_reqs() -> dict[str, List[str]]:
    """Map each document-format extra to its pip requirement specifiers.

    Several formats intentionally share a requirements file (doc/docx,
    ppt/pptx, image/pdf, csv/tsv, and org/rtf/rst via pandoc).
    """
    return {
        "csv": load_requirements(requirements_dir / "extra-csv.in"),
        "doc": load_requirements(requirements_dir / "extra-docx.in"),
        "docx": load_requirements(requirements_dir / "extra-docx.in"),
        "epub": load_requirements(requirements_dir / "extra-epub.in"),
        "image": load_requirements(requirements_dir / "extra-pdf-image.in"),
        # "md" is the extra name the previous setup.py published; keep it so
        # `pip install unstructured[md]` keeps working, with "markdown" as a
        # more descriptive alias.
        "md": load_requirements(requirements_dir / "extra-markdown.in"),
        "markdown": load_requirements(requirements_dir / "extra-markdown.in"),
        "msg": load_requirements(requirements_dir / "extra-msg.in"),
        "odt": load_requirements(requirements_dir / "extra-odt.in"),
        "org": load_requirements(requirements_dir / "extra-pandoc.in"),
        "pdf": load_requirements(requirements_dir / "extra-pdf-image.in"),
        "ppt": load_requirements(requirements_dir / "extra-pptx.in"),
        "pptx": load_requirements(requirements_dir / "extra-pptx.in"),
        "rtf": load_requirements(requirements_dir / "extra-pandoc.in"),
        "rst": load_requirements(requirements_dir / "extra-pandoc.in"),
        "tsv": load_requirements(requirements_dir / "extra-csv.in"),
        "xlsx": load_requirements(requirements_dir / "extra-xlsx.in"),
    }


def get_all_doc_reqs() -> List[str]:
    """Return the union of requirements across all document-format extras."""
    unique: set = set()
    for extra_reqs in get_doc_reqs().values():
        unique.update(extra_reqs)
    return list(unique)


def get_connector_reqs() -> dict[str, List[str]]:
    """Map each data-connector/embedder extra to its requirement specifiers.

    Most extras load "<name>.in" from requirements/ingest/; a few use a
    differently named file (see ``aliased`` below).
    """
    # Extras whose requirements file is named "<extra>.in".
    same_named = [
        "airtable",
        "astra",
        "azure",
        "azure-cognitive-search",
        "biomed",
        "box",
        "chroma",
        "clarifai",
        "confluence",
        "delta-table",
        "discord",
        "dropbox",
        "elasticsearch",
        "gcs",
        "github",
        "gitlab",
        "google-drive",
        "hubspot",
        "jira",
        "mongodb",
        "notion",
        "onedrive",
        "opensearch",
        "outlook",
        "pinecone",
        "postgres",
        "qdrant",
        "reddit",
        "s3",
        "sharepoint",
        "salesforce",
        "sftp",
        "slack",
        "wikipedia",
        "weaviate",
        "embed-huggingface",
        "embed-octoai",
        "embed-vertexai",
    ]
    # Extras whose file name differs from the published extra name.
    aliased = {
        "openai": "embed-openai",
        "bedrock": "embed-aws-bedrock",
    }
    file_stems = {name: name for name in same_named}
    file_stems.update(aliased)
    file_stems["databricks-volumes"] = "databricks-volumes"
    return {
        extra: load_requirements(ingest_requirements_dir / f"{stem}.in")
        for extra, stem in file_stems.items()
    }


def get_extras() -> dict[str, List[str]]:
    """Build the complete ``extras_require`` mapping for setup().

    Merges the legacy/meta extras with the per-document-format and
    per-connector groups, failing loudly if any two groups define the
    same extra name.
    """
    groups = [
        {
            "all-docs": get_all_doc_reqs(),
            # Legacy extra requirements
            "huggingface": load_requirements(requirements_dir / "huggingface.in"),
            "local-inference": get_all_doc_reqs(),
            "paddleocr": load_requirements(requirements_dir / "extra-paddleocr.in"),
        },
        get_doc_reqs(),
        get_connector_reqs(),
    ]
    # Check there aren't any duplicate keys across the groups.
    key_counts: collections.Counter = collections.Counter()
    for group in groups:
        key_counts.update(group)
    duplicates = [key for key, count in key_counts.items() if count > 1]
    if duplicates:
        raise ValueError(
            "duplicate keys found amongst dictionaries: {}".format(", ".join(duplicates))
        )
    merged: dict[str, List[str]] = {}
    for group in groups:
        merged.update(group)
    return merged
3 changes: 3 additions & 0 deletions test_unstructured/files/child_reqs/child_reqs.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# This is a child dependency file in a sub directory
pandas
torch
7 changes: 7 additions & 0 deletions test_unstructured/files/example.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This is a sample file for managing dependencies
-c constraints.txt
-r ./child_reqs/child_reqs.in
-r ./other_reqs.in
requests
httpx
pandas
2 changes: 2 additions & 0 deletions test_unstructured/files/other_reqs.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# This is a child dependency file in the same directory
sphinx<4.3.2
Empty file.
47 changes: 47 additions & 0 deletions test_unstructured/test_setup_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path

import pytest

import setup_utils

# Directory of this test module; fixture requirement files live under files/.
current_dir = Path(__file__).parent.absolute()


def test_load_requirements():
    """example.in plus its two -r children should yield exactly these deps."""
    file = current_dir / "files" / "example.in"
    reqs = setup_utils.load_requirements(file=file)
    # "pandas" appears in both example.in and child_reqs.in — expect one entry.
    desired_deps = ["torch", "httpx", "requests", "sphinx<4.3.2", "pandas"]
    assert len(reqs) == len(desired_deps)
    # Sort both sides: load_requirements de-duplicates via a set, so order
    # is unspecified.
    assert sorted(reqs) == sorted(desired_deps)


def test_load_requirements_not_file():
    """A path that does not exist should raise FileNotFoundError."""
    file = current_dir / "files" / "nothing.in"
    with pytest.raises(FileNotFoundError):
        setup_utils.load_requirements(file=file)


def test_load_requirements_wrong_suffix():
    """An existing file without the .in extension should raise ValueError."""
    file = current_dir / "files" / "wrong_ext.txt"
    with pytest.raises(ValueError):
        setup_utils.load_requirements(file=file)


def test_load_base():
    """Smoke test: requirements/base.in loads to a non-empty list."""
    reqs = setup_utils.get_base_reqs()
    assert reqs


def test_load_doc_reqs():
    """Smoke test: the per-document-format extras mapping is non-empty."""
    reqs = setup_utils.get_doc_reqs()
    assert reqs


def test_load_all_doc_reqs():
    """Smoke test: the union of all document extras is non-empty."""
    reqs = setup_utils.get_all_doc_reqs()
    assert reqs


def test_load_extra_reqs():
    """Smoke test: the merged extras_require mapping builds without errors."""
    reqs = setup_utils.get_extras()
    assert reqs
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.13.3-dev4" # pragma: no cover
__version__ = "0.13.3-dev5" # pragma: no cover