|
@@ -8,6 +8,7 @@ from fastapi import (
|
|
Form,
|
|
Form,
|
|
)
|
|
)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
+import requests
|
|
import os, shutil, logging, re
|
|
import os, shutil, logging, re
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
|
|
|
|
@@ -716,36 +717,19 @@ def validate_url(url: Union[str, Sequence[str]]):
|
|
if isinstance(validators.url(url), validators.ValidationError):
|
|
if isinstance(validators.url(url), validators.ValidationError):
|
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
|
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
|
- # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
|
|
|
|
- parsed_url = urllib.parse.urlparse(url)
|
|
|
|
- # Get IPv4 and IPv6 addresses
|
|
|
|
- ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
|
|
|
|
- # Check if any of the resolved addresses are private
|
|
|
|
- # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
|
|
|
|
- for ip in ipv4_addresses:
|
|
|
|
- if validators.ipv4(ip, private=True):
|
|
|
|
- raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
|
|
- for ip in ipv6_addresses:
|
|
|
|
- if validators.ipv6(ip, private=True):
|
|
|
|
|
|
+ # Check if the URL exists by making a HEAD request
|
|
|
|
+ try:
|
|
|
|
+ response = requests.head(url, allow_redirects=True)
|
|
|
|
+ if response.status_code != 200:
|
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
|
|
+ except requests.exceptions.RequestException:
|
|
|
|
+ raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
|
return True
|
|
return True
|
|
elif isinstance(url, Sequence):
|
|
elif isinstance(url, Sequence):
|
|
return all(validate_url(u) for u in url)
|
|
return all(validate_url(u) for u in url)
|
|
else:
|
|
else:
|
|
return False
|
|
return False
|
|
|
|
|
|
-
|
|
|
|
-def resolve_hostname(hostname):
|
|
|
|
- # Get address information
|
|
|
|
- addr_info = socket.getaddrinfo(hostname, None)
|
|
|
|
-
|
|
|
|
- # Extract IP addresses from address information
|
|
|
|
- ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
|
|
|
|
- ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
|
|
|
|
-
|
|
|
|
- return ipv4_addresses, ipv6_addresses
|
|
|
|
-
|
|
|
|
-
|
|
|
|
def search_web(engine: str, query: str) -> list[SearchResult]:
|
|
def search_web(engine: str, query: str) -> list[SearchResult]:
|
|
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
|
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
|
Will look for a search engine API key in environment variables in the following order:
|
|
Will look for a search engine API key in environment variables in the following order:
|