"""
pygeoinf/datasets.py
Provides access to built-in datasets for testing, benchmarking,
and visualization across the pygeoinf package.
"""
import os
import csv
import random
import urllib.request
import urllib.parse
from typing import List, Tuple, Union
# Import the centralized path
from .config import DATADIR
# Path of the GSN station CSV file inside the central data directory (DATADIR).
_CSV_PATH = os.path.join(DATADIR, "gsn_stations.csv")
def download_gsn_stations(force: bool = False) -> None:
    """
    Fetches the Global Seismograph Network (GSN) stations from the IRIS
    FDSN API and saves them to a local CSV file in the data/ directory.

    The request asks for networks ``IU`` and ``II`` at station level in the
    pipe-delimited text format. The resulting CSV has the header
    ``Station, Latitude, Longitude``.

    Args:
        force: If True, download even when the local CSV already exists.
            If False and the file exists, the function returns immediately.

    Raises:
        RuntimeError: If the request fails, the response cannot be parsed,
            the response contains no stations, or the CSV cannot be written.
            The underlying exception (if any) is chained as ``__cause__``.
    """
    if os.path.exists(_CSV_PATH) and not force:
        return

    print("pygeoinf: Local dataset missing. Fetching station data from IRIS...")

    # Ensure the central DATADIR exists before writing!
    os.makedirs(DATADIR, exist_ok=True)

    url = "http://service.iris.edu/fdsnws/station/1/query"
    params = {"network": "IU,II", "level": "station", "format": "text"}
    full_url = f"{url}?{urllib.parse.urlencode(params)}"

    stations = []
    try:
        with urllib.request.urlopen(full_url, timeout=10) as response:
            text = response.read().decode("utf-8")

        # The first line is the column header; every other line is
        # pipe-delimited with the station code, latitude and longitude in
        # fields 1, 2 and 3.
        for line in text.strip().splitlines()[1:]:
            parts = line.split("|")
            if len(parts) >= 4:
                stations.append([parts[1], float(parts[2]), float(parts[3])])
    except Exception as e:
        raise RuntimeError(
            f"Failed to download GSN stations from IRIS. Error: {e}"
        ) from e

    # Never cache an empty dataset: the file's mere existence would stop any
    # later call from retrying the download.
    if not stations:
        raise RuntimeError(
            "Failed to download GSN stations from IRIS. "
            "Error: the response contained no stations."
        )

    try:
        with open(_CSV_PATH, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["Station", "Latitude", "Longitude"])
            writer.writerows(stations)
    except OSError as e:
        raise RuntimeError(
            f"Failed to save GSN stations to {_CSV_PATH}. Error: {e}"
        ) from e

    print(f"pygeoinf: Successfully saved {len(stations)} stations to {_CSV_PATH}")
def load_gsn_stations(
    n_stations: Union[int, None] = None, include_names: bool = False
) -> Union[List[Tuple[float, float]], List[Tuple[str, float, float]]]:
    """
    Loads a representative global set of seismic stations from the GSN.

    If the internal CSV file is missing, this function will attempt to
    automatically download it from IRIS into the pygeoinf/data/ directory.

    Args:
        n_stations: If provided, returns a random subsample of this size.
            If greater than or equal to the total number of available
            stations, returns all of them in file order.
        include_names: If True, returns (Name, Latitude, Longitude).
            If False, returns pure (Latitude, Longitude) tuples.

    Returns:
        A list of station tuples in degrees.

    Raises:
        ValueError: If ``n_stations`` is negative.
        RuntimeError: If the station file is missing and cannot be downloaded.
    """
    if n_stations is not None and n_stations < 0:
        raise ValueError(f"n_stations must be non-negative, got {n_stations}")

    # download_gsn_stations() returns immediately when the CSV already
    # exists, so no separate existence check is needed here.
    download_gsn_stations()

    stations = []
    with open(_CSV_PATH, "r", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            lat = float(row["Latitude"])
            lon = float(row["Longitude"])
            if include_names:
                stations.append((row["Station"], lat, lon))
            else:
                stations.append((lat, lon))

    # Sub-sample only when a strictly smaller sample was requested.
    if n_stations is not None and n_stations < len(stations):
        return random.sample(stations, n_stations)
    return stations
def download_usgs_earthquakes(
    min_magnitude: Union[float, None] = 5.0,
    start_time: Union[str, None] = None,
    end_time: Union[str, None] = None,
    min_depth: Union[float, None] = None,
    max_depth: Union[float, None] = None,
    bbox: Union[Tuple[float, float, float, float], None] = None,
    limit: int = 2000,
    force: bool = False,
    filename: str = "usgs_events_filtered.csv",
) -> None:
    """
    Fetches a filtered catalog of earthquakes from the USGS API and saves it
    to a CSV in the centralized DATADIR.

    The response body is stored as received. Note that an existing file is
    reused regardless of the filters it was created with; pass ``force=True``
    or a different ``filename`` to fetch with new filters.

    Args:
        min_magnitude: Sent as ``minmagnitude``; omitted when None.
        start_time: Sent as ``starttime``; omitted when None.
        end_time: Sent as ``endtime``; omitted when None.
        min_depth: Sent as ``mindepth``; omitted when None.
        max_depth: Sent as ``maxdepth``; omitted when None.
        bbox: Optional bounding box given as
            (min_latitude, max_latitude, min_longitude, max_longitude).
        limit: Maximum number of events to request.
        force: If True, download even when the target file already exists.
        filename: Name of the CSV file to create inside DATADIR.

    Raises:
        ValueError: If ``bbox`` does not contain exactly four values.
        RuntimeError: If the request fails or the file cannot be written.
            The underlying exception is chained as ``__cause__``.
    """
    csv_path = os.path.join(DATADIR, filename)
    if os.path.exists(csv_path) and not force:
        return

    params = {"format": "csv", "limit": limit, "orderby": "time"}
    optional_filters = {
        "minmagnitude": min_magnitude,
        "starttime": start_time,
        "endtime": end_time,
        "mindepth": min_depth,
        "maxdepth": max_depth,
    }
    # Only send the filters the caller actually set.
    params.update(
        {key: value for key, value in optional_filters.items() if value is not None}
    )
    if bbox is not None:
        if len(bbox) != 4:
            raise ValueError(
                "bbox must be (min_latitude, max_latitude, "
                f"min_longitude, max_longitude), got {bbox!r}"
            )
        min_lat, max_lat, min_lon, max_lon = bbox
        params["minlatitude"] = min_lat
        params["maxlatitude"] = max_lat
        params["minlongitude"] = min_lon
        params["maxlongitude"] = max_lon

    print(f"pygeoinf: Fetching up to {limit} earthquakes from USGS...")
    os.makedirs(DATADIR, exist_ok=True)

    url = "https://earthquake.usgs.gov/fdsnws/event/1/query"
    full_url = f"{url}?{urllib.parse.urlencode(params)}"

    try:
        with urllib.request.urlopen(full_url, timeout=20) as response:
            data = response.read().decode("utf-8")
    except Exception as e:
        raise RuntimeError(f"Failed to download USGS events. Error: {e}") from e

    try:
        # newline="" stores the payload's line endings untouched instead of
        # letting the platform translate them.
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            f.write(data)
    except OSError as e:
        raise RuntimeError(
            f"Failed to save USGS events to {csv_path}. Error: {e}"
        ) from e

    # One header line plus one line per event; an empty body means zero events.
    num_events = max(len(data.strip().splitlines()) - 1, 0)
    print(f"pygeoinf: Successfully saved {num_events} events to {csv_path}")
# NOTE(review): the USGS event service is understood to reject queries whose
# ``limit`` exceeds 20000 -- confirm against the current API documentation.
_USGS_MAX_QUERY_LIMIT = 20000


def _read_event_cache(cache_path: str) -> List[Tuple[float, float, float]]:
    """Read (latitude, longitude, depth) tuples from a cached USGS event CSV."""
    events = []
    with open(cache_path, "r", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            try:
                events.append(
                    (
                        float(row["latitude"]),
                        float(row["longitude"]),
                        float(row["depth"]),
                    )
                )
            except (KeyError, TypeError, ValueError):
                # Skip rows with a missing or non-numeric coordinate instead
                # of failing the whole load because of one bad record.
                continue
    return events


def sample_earthquakes(
    n_events: int, min_magnitude: float = 5.0
) -> List[Tuple[float, float, float]]:
    """
    Returns a random subsample of real earthquake locations.

    If the local cache does not contain enough events to satisfy the request,
    it automatically fetches a larger catalog from the USGS to rebuild the cache.

    Args:
        n_events: The number of earthquake locations to return. Fewer are
            returned when the USGS cannot supply that many events.
        min_magnitude: The minimum magnitude to use if a new download is
            required. It is NOT applied to an existing cache, so events from
            a cache built with a lower threshold may be returned.

    Returns:
        A list of tuples: (Latitude, Longitude, Depth_in_km).

    Raises:
        ValueError: If ``n_events`` is negative.
        RuntimeError: If a required download fails.
    """
    if n_events < 0:
        raise ValueError(f"n_events must be non-negative, got {n_events}")

    cache_filename = "usgs_event_cache.csv"
    cache_path = os.path.join(DATADIR, cache_filename)

    # 1. Try loading from the existing cache
    events = _read_event_cache(cache_path) if os.path.exists(cache_path) else []

    # 2. Check if the cache is large enough
    if len(events) >= n_events:
        # Use random.sample to grab unique items without replacement
        return random.sample(events, n_events)

    # 3. Cache is too small (or doesn't exist). Download a new one!
    # Always fetch at least 2000, or the requested amount plus a 20% buffer,
    # so slowly increasing n_events does not hit the API on every call.
    # The request is capped at the service's per-query maximum.
    fetch_limit = min(max(2000, int(n_events * 1.2)), _USGS_MAX_QUERY_LIMIT)
    print(
        f"pygeoinf: Local cache only has {len(events)} events. Fetching {fetch_limit} to build a robust cache..."
    )
    download_usgs_earthquakes(
        min_magnitude=min_magnitude,
        limit=fetch_limit,
        filename=cache_filename,
        force=True,  # Overwrite the old, insufficient cache
    )

    # 4. Reload the newly downloaded cache
    events = _read_event_cache(cache_path)

    # 5. Return the requested sample size, or everything we got if the API
    # returned fewer events than requested (e.g. 100,000 magnitude 9.0s).
    return random.sample(events, min(n_events, len(events)))