Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DuckDuckGoSearchAPI and TavilySearchAPI as Alternatives to You.com #20

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ Below, we provide a quick start guide to run STORM locally to reproduce our expe
OPENAI_API_TYPE="azure"
AZURE_API_BASE=<your_azure_api_base_url>
AZURE_API_VERSION=<your_azure_api_version>
# Set up You.com search API key.
# Set WEB_SEARCH_API to one of ['DuckDuckGoSearchAPI', 'TavilySearchAPI', 'YouSearchAPI']; YouSearchAPI is used by default.
WEB_SEARCH_API="YouSearchAPI"
# Set up the You.com search API key.
YDC_API_KEY=<your_youcom_api_key>
# Set up the Tavily (api.tavily.com) search API key.
TAVILY_API_KEY=<your_api_tavily_com_key>
```

## Paper Experiments
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ fastchat
wikipedia==1.4.0
Wikipedia-API==0.6.0
rouge-score
toml
toml
newspaper4k
langchain==0.1.16

122 changes: 16 additions & 106 deletions src/modules/topic_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,116 +5,13 @@
from urllib.parse import urlparse

import requests
import dspy
from modules.utils import DialogueTurn, limit_word_count_preserve_newline, remove_uncompleted_sentences_with_citations
from modules.web_search_provider import DuckDuckGoSearchAPI, TavilySearchAPI, YouSearchAPI

import dspy

script_dir = os.path.dirname(os.path.abspath(__file__))

class MyYouRM(dspy.Retrieve):
    """dspy retriever backed by the You.com search API.

    Search hits are filtered against Wikipedia's "perennial sources"
    reliability list (parsed from a locally cached copy of the page), so
    generally-unreliable, deprecated, and blacklisted domains are dropped.
    """

    def __init__(self, ydc_api_key=None, k=3):
        """
        Args:
            ydc_api_key: You.com API key; falls back to the YDC_API_KEY
                environment variable when not given.
            k: Maximum number of results to keep per query.

        Raises:
            RuntimeError: If no API key is supplied or found in the environment.
        """
        super().__init__(k=k)
        if ydc_api_key:
            self.ydc_api_key = ydc_api_key
        elif os.environ.get("YDC_API_KEY"):
            self.ydc_api_key = os.environ["YDC_API_KEY"]
        else:
            raise RuntimeError("You must supply ydc_api_key or set environment variable YDC_API_KEY")

        # Domain sets derived from the Wikipedia standard for sources;
        # populated by _generate_domain_restriction().
        self.generally_unreliable = None
        self.deprecated = None
        self.blacklisted = None

        self._generate_domain_restriction()

    def _generate_domain_restriction(self):
        """Parse the cached Wikipedia 'Perennial sources' page into domain sets."""
        file_path = os.path.join(script_dir, 'Wikipedia_Reliable sources_Perennial sources - Wikipedia.html')
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()

        def _extract_domains(css_class):
            # Rows in the perennial-sources table carry a reliability CSS class,
            # e.g. <tr class="s-gu" id="...">; the id attribute names the source.
            pattern = rf'<tr class="{css_class}" id="[^"]+">|<id="[^"]+" tr class="{css_class}" >'
            matches = re.findall(pattern, content)
            ids = [re.search(r'id="([^"]+)"', m).group(1) for m in matches]
            # Unescape apostrophes, then collapse qualified entries such as
            # "Fox_News_(politics_and_science)" to the bare domain name. For
            # now the whole domain is excluded; the rule can later be refined
            # to distinguish the qualified cases.
            return set(i.replace('&#39;', "'").split('_(')[0] for i in ids)

        self.generally_unreliable = _extract_domains('s-gu')  # class "s-gu"
        self.deprecated = _extract_domains('s-d')             # class "s-d"
        self.blacklisted = _extract_domains('s-b')            # class "s-b"

    def is_valid_wikipedia_source(self, url):
        """Return True if *url* is not on any of the unreliable-domain lists."""
        parsed_url = urlparse(url)
        combined_set = self.generally_unreliable | self.deprecated | self.blacklisted
        # Substring match against the host, so subdomains are excluded too.
        return not any(domain in parsed_url.netloc for domain in combined_set)

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
        """Search with You.com for self.k top passages for query or queries

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.

        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        collected_results = []
        for query in queries:
            try:
                headers = {"X-API-Key": self.ydc_api_key}
                # Pass the query via params= so it is URL-encoded correctly
                # (the original interpolated it raw into the URL), and bound
                # the request with a timeout so a stalled API cannot hang us.
                results = requests.get(
                    "https://api.ydc-index.io/search",
                    params={"query": query},
                    headers=headers,
                    timeout=30,
                ).json()

                # Guard BEFORE indexing: error responses carry no 'hits' key.
                # The original read results['hits'] first, so the membership
                # check below it could never prevent the KeyError.
                if 'hits' in results:
                    authoritative_results = [
                        r for r in results['hits']
                        if self.is_valid_wikipedia_source(r['url'])
                    ]
                    collected_results.extend(authoritative_results[:self.k])
            except Exception as e:
                logging.error(f'Error occurs when searching query {query}: {e}')

        if exclude_urls:
            collected_results = [r for r in collected_results if r['url'] not in exclude_urls]

        return collected_results


class QuestionToQuery(dspy.Signature):
"""You want to answer the question using Google search. What do you type in the search box?
Write the queries you will use in the following format:
def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
             search_top_k):
    """Wire up the query generator, answer generator, and web-search retriever.

    Args:
        engine: dspy language model used for query generation and answering.
        search_top_k: Maximum number of search results to retrieve per query.

    Raises:
        NotImplementedError: If WEB_SEARCH_API names an unsupported provider.
    """
    super().__init__()
    self.generate_queries = dspy.Predict(QuestionToQuery)
    self.answer_question = dspy.Predict(AnswerQuestion)
    self.engine = engine

    # Default to YouSearchAPI when WEB_SEARCH_API is unset — the README
    # documents it as the default, so an unset variable must not be an error.
    web_search_api = os.environ.get("WEB_SEARCH_API", "YouSearchAPI")

    if web_search_api == "DuckDuckGoSearchAPI":
        self.retrieve = DuckDuckGoSearchAPI(max_results=search_top_k, use_snippet=False, timeout=120)
    elif web_search_api == "TavilySearchAPI":
        self.retrieve = TavilySearchAPI(max_results=search_top_k, use_snippet=False, timeout=120)
    elif web_search_api == "YouSearchAPI":
        self.retrieve = YouSearchAPI(max_results=search_top_k)
    else:
        raise NotImplementedError(f"Expected WEB_SEARCH_API must be one of ['DuckDuckGoSearchAPI', 'TavilySearchAPI', 'YouSearchAPI'], but got {web_search_api} instead.")


def forward(self, topic: str, question: str, ground_truth_url: str):
with dspy.settings.context(lm=self.engine):
Expand Down