Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update notion extractor #3898

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 16 additions & 22 deletions api/core/rag/extractor/notion_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@

RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']

# If the user wants to split by headings, use the corresponding splitter prefix.
HEADING_SPLITTER = {
'heading_1': '# ',
'heading_2': '## ',
'heading_3': '### ',
}

class NotionExtractor(BaseExtractor):

Expand Down Expand Up @@ -73,8 +77,7 @@ def _load_data_as_documents(
docs.extend(page_text_documents)
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
docs.append(Document(page_content=page_text))
docs.append(Document(page_content='\n'.join(page_text_list)))
else:
raise ValueError("notion page type not supported")

Expand All @@ -96,7 +99,7 @@ def _get_notion_database_data(

data = res.json()

database_content_list = []
database_content = []
if 'results' not in data or data["results"] is None:
return []
for result in data["results"]:
Expand Down Expand Up @@ -131,10 +134,9 @@ def _get_notion_database_data(
row_content = row_content + f'{key}:{value_content}\n'
else:
row_content = row_content + f'{key}:{value}\n'
document = Document(page_content=row_content)
database_content_list.append(document)
database_content.append(row_content)

return database_content_list
return [Document(page_content='\n'.join(database_content))]

def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
Expand All @@ -154,8 +156,6 @@ def _get_notion_block_data(self, page_id: str) -> list[str]:
json=query_dict
)
data = res.json()
# current block's heading
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
Expand All @@ -172,8 +172,6 @@ def _get_notion_block_data(self, page_id: str) -> list[str]:
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result_type in HEADING_TYPE:
heading = text

result_block_id = result["id"]
has_children = result["has_children"]
Expand All @@ -185,11 +183,10 @@ def _get_notion_block_data(self, page_id: str) -> list[str]:
cur_result_text_arr.append(children_text)

cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
if result_type in HEADING_SPLITTER:
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
result_lines_arr.append(cur_result_text + '\n\n')

if data["next_cursor"] is None:
break
Expand Down Expand Up @@ -218,7 +215,6 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
data = res.json()
if 'results' not in data or data["results"] is None:
break
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
Expand All @@ -235,8 +231,6 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
Expand All @@ -247,10 +241,10 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
cur_result_text_arr.append(children_text)

cur_result_text = "\n".join(cur_result_text_arr)
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
if result_type in HEADING_SPLITTER:
result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
result_lines_arr.append(cur_result_text + '\n\n')

if data["next_cursor"] is None:
break
Expand Down
Empty file.
102 changes: 102 additions & 0 deletions api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from unittest import mock

from core.rag.extractor import notion_extractor

# Placeholder identifiers shared by the payload builders and tests below.
user_id = "user1"
database_id = "database1"
page_id = "page1"


# Module-level extractor reused by both tests. Every constructor argument is a
# dummy value: the tests hand the object id and page type straight to
# ``_load_data_as_documents`` and patch ``requests`` before calling it.
extractor = notion_extractor.NotionExtractor(
notion_workspace_id='x',
notion_obj_id='x',
notion_page_type='page',
tenant_id='x',
notion_access_token='x')


def _generate_page(page_title: str):
    """Build a minimal page payload with a single title property.

    The returned dict is used as one entry of the ``results`` list in the
    mocked database response. Its only property is named ``Page`` and carries
    ``page_title`` both as text content and as plain text.
    """
    title_fragment = {
        "type": "text",
        "text": {"content": page_title},
        "plain_text": page_title,
    }
    title_property = {"type": "title", "title": [title_fragment]}
    return {
        "object": "page",
        "id": page_id,
        "properties": {"Page": title_property},
    }


def _generate_block(block_id: str, block_type: str, block_text: str):
    """Build a childless block payload holding one rich-text fragment.

    The block's content lives under a key equal to ``block_type`` and the
    block is parented to the module-level ``page_id``.
    """
    rich_text_item = {
        "type": "text",
        "text": {"content": block_text},
        "plain_text": block_text,
    }
    block = {
        "object": "block",
        "id": block_id,
        "parent": {"type": "page_id", "page_id": page_id},
        "type": block_type,
        "has_children": False,
    }
    # The type-specific payload is keyed by the block type itself.
    block[block_type] = {"rich_text": [rich_text_item]}
    return block


def _mock_response(data):
response = mock.Mock()
response.status_code = 200
response.json.return_value = data
return response


def _remove_multiple_new_lines(text):
while '\n\n' in text:
text = text.replace("\n\n", "\n")
return text.strip()


def test_notion_page(mocker):
    """A page of heading and paragraph blocks loads as one document.

    Heading blocks are expected to carry ``#``-style prefixes matching their
    level, while the paragraph text is left unprefixed.
    """
    texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
    blocks = [
        _generate_block("b1", "heading_1", texts[0]),
        _generate_block("b2", "heading_2", texts[1]),
        _generate_block("b3", "paragraph", texts[2]),
        _generate_block("b4", "heading_3", texts[3]),
    ]
    mocked_notion_page = {
        "object": "list",
        "results": blocks,
        "next_cursor": None,
    }
    fake_response = _mock_response(mocked_notion_page)
    mocker.patch("requests.request", return_value=fake_response)

    page_docs = extractor._load_data_as_documents(page_id, "page")

    assert len(page_docs) == 1
    # Blank-line spacing between blocks is not under test, so normalise it away.
    content = _remove_multiple_new_lines(page_docs[0].page_content)
    assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'


def test_notion_database(mocker):
    """A database of title-only pages loads as one document, one row per page."""
    page_title_list = ["page1", "page2", "page3"]
    pages = []
    for title in page_title_list:
        pages.append(_generate_page(title))
    mocked_notion_database = {
        "object": "list",
        "results": pages,
        "next_cursor": None,
    }
    fake_response = _mock_response(mocked_notion_database)
    mocker.patch("requests.post", return_value=fake_response)

    database_docs = extractor._load_data_as_documents(database_id, "database")

    assert len(database_docs) == 1
    # Blank-line spacing between rows is not under test, so normalise it away.
    content = _remove_multiple_new_lines(database_docs[0].page_content)
    expected_rows = [f'Page:{i}' for i in page_title_list]
    assert content == '\n'.join(expected_rows)