diff --git a/pyproject.toml b/pyproject.toml index 0bccb61..d5e2a1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,17 @@ pythonpath = "." addopts = '-p no:warnings' # disable pytest warnings log_format = '%(name)s %(levelname)s: %(message)s' +# coverage settings +[tool.coverage.run] +source = ["src"] + +[tool.coverage.report] +omit = ["src/dev_streamlit.py"] +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.:", +] + # ruff global settings [tool.ruff] line-length = 79 diff --git a/src/api.py b/src/api.py index 955f80d..2ad4506 100644 --- a/src/api.py +++ b/src/api.py @@ -2,6 +2,7 @@ Functions that make API calls stored here """ +import logging from datetime import datetime, timedelta from http import HTTPStatus @@ -15,9 +16,18 @@ from src import helper +logger = logging.getLogger(__name__) + testing = 1 +def _create_openmeteo_client(): + """Creates a cached, retry-enabled Open-Meteo API client.""" + cache_session = requests_cache.CachedSession(".cache", expire_after=3600) + retry_session = retry(cache_session, retries=5, backoff_factor=0.2) + return openmeteo_requests.Client(session=retry_session) + + def get_coordinates(args): """ Takes a location (city or address) and returns the coordinates: [lat, long] @@ -37,9 +47,9 @@ def get_coordinates(args): location.raw["name"], ] else: - print( - f"Invalid location '{address}' provided. " - "Using default location." + logger.warning( + "Invalid location '%s' provided. Using default location.", + address, ) return default_location() return default_location() @@ -67,10 +77,7 @@ def get_uv(lat, long, decimal, unit="imperial"): Get UV at coordinates (lat, long) Calling the API here: https://open-meteo.com/en/docs """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() url = "https://air-quality-api.open-meteo.com/v1/air-quality" params = { @@ -115,10 +122,7 @@ def get_uv_history(lat, long, decimal, unit="imperial"): API Documentation: https://open-meteo.com/en/docs/air-quality-api """ - # Set up the Open-Meteo API client with caching and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() # Calculate the date one year ago and the current hour one_year_ago = datetime.now() - timedelta(days=365) @@ -168,10 +172,7 @@ def ocean_information(lat, long, decimal, unit="imperial"): Get Ocean Data at coordinates API: https://open-meteo.com/en/docs/marine-weather-api """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() url = "https://marine-api.open-meteo.com/v1/marine" params = { @@ -197,8 +198,6 @@ def ocean_information(lat, long, decimal, unit="imperial"): current_wave_direction = round(current.Variables(1).Value(), decimal) current_wave_period = round(current.Variables(2).Value(), decimal) - # print(f"Current time {current.Time(``)}") - return [current_wave_height, current_wave_direction, current_wave_period] @@ -227,15 +226,12 @@ def ocean_information_history(lat, long, decimal, unit="imperial"): API Documentation: https://open-meteo.com/en/docs/marine-weather-api """ - # Set up the Open-Meteo API client with caching and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() # Calculate the date and current hour one year ago one_year_ago = datetime.now() - timedelta(days=365) formatted_date_one_year_ago = one_year_ago.strftime("%Y-%m-%d") - current_hour = one_year_ago.hour # Combined calculation here + current_hour = one_year_ago.hour # Define the API request parameters url = "https://marine-api.open-meteo.com/v1/marine" @@ -284,10 +280,7 @@ def current_wind_temp(lat, long, decimal, temp_unit="fahrenheit"): """ Gathers the wind and temperature data """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() url = "https://api.open-meteo.com/v1/forecast" params = { @@ -314,15 +307,12 @@ def current_wind_temp(lat, long, decimal, temp_unit="fahrenheit"): ] -def get_rain(lat, long, decimal): +def get_rain(lat, long): """ Get rain data at coordinates (lat, long) Calling the API here: https://open-meteo.com/en/docs """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() url = "https://api.open-meteo.com/v1/forecast" params = { @@ -336,15 +326,12 @@ def get_rain(lat, long, decimal): # Process daily data. The order of variables needs to be the # same as requested. daily = response.Daily() - daily_rain_sum = daily.Variables(0).ValuesAsNumpy(), decimal - daily_precipitation_probability_max = ( - daily.Variables(1).ValuesAsNumpy(), - decimal, - ) + daily_rain_sum = daily.Variables(0).ValuesAsNumpy() + daily_precipitation_probability_max = daily.Variables(1).ValuesAsNumpy() return ( - float(daily_rain_sum[0][0]), - float(daily_precipitation_probability_max[0][0]), + float(daily_rain_sum[0]), + float(daily_precipitation_probability_max[0]), ) @@ -353,10 +340,7 @@ def forecast(lat, long, decimal, days=0): Number of forecast days. Max is 7 API: https://open-meteo.com/en/docs/marine-weather-api """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() # First URL is the marine API. Second is for general weather/UV index urls = ( @@ -411,7 +395,6 @@ def forecast(lat, long, decimal, days=0): # Extract general weather data using a loop to reduce number of local # variables - general_data = [ helper.round_decimal( response_general.Daily().Variables(i).ValuesAsNumpy(), decimal @@ -454,12 +437,8 @@ def get_hourly_forecast(lat, long, days=1, unit="fahrenheit"): """ Gets hourly weather data """ - # Setup the Open-Meteo API client with cache and retry on error - cache_session = requests_cache.CachedSession(".cache", expire_after=3600) - retry_session = retry(cache_session, retries=5, backoff_factor=0.2) - openmeteo = openmeteo_requests.Client(session=retry_session) + openmeteo = _create_openmeteo_client() - # Make sure all required weather variables are listed here # The order of variables in hourly or daily is important # to assign them correctly below url = "https://api.open-meteo.com/v1/forecast" @@ -491,13 +470,9 @@ def get_hourly_forecast(lat, long, days=1, unit="fahrenheit"): "visibility": hourly_visibility, }) - # Sets variable to get current time current_time = pd.Timestamp.now(tz="UTC") - - # Sets variable to find index of current hour curr_hour = np.argmin(np.abs(hourly_data["date"] - current_time)) - # Creates dictionary for the current hour's weather data curr_hour_data = {} for i in ["cloud_cover", "visibility"]: curr_hour_data[i] = round(float(hourly_data[i].iloc[curr_hour]), 1) @@ -518,13 +493,13 @@ def gather_data(lat, long, arguments): uv_index = get_uv(lat, long, arguments["decimal"], arguments["unit"]) - wind_temp = current_wind_temp(lat, long, arguments["decimal"]) + hourly_dict = get_hourly_forecast(lat, long) - hourly_dict = get_hourly_forecast(lat, long, arguments["decimal"]) + air_temp, wind_speed, wind_dir = current_wind_temp( + lat, long, arguments["decimal"] + ) + rain_sum, precipitation_probability_max = get_rain(lat, long) - rain_data = get_rain(lat, long, arguments["decimal"]) - air_temp, wind_speed, wind_dir = wind_temp[0], wind_temp[1], wind_temp[2] - rain_sum, precipitation_probability_max = rain_data[0], rain_data[1] arguments["ocean_data"] = ocean_data arguments["uv_index"] = uv_index spot_forecast = forecast(lat, long, arguments["decimal"], 7) @@ -532,28 +507,19 @@ def gather_data(lat, long, arguments): spot_forecast, arguments["decimal"] ) + ocean_history = ocean_information_history( + lat, long, arguments["decimal"], arguments["unit"] + ) ocean_data_dict = { "Lat": lat, "Long": long, "Location": arguments["city"], "Height": ocean_data[0], - "Height one year ago": ( - ocean_information_history( - lat, long, arguments["decimal"], arguments["unit"] - )[0] - ), + "Height one year ago": ocean_history[0], "Swell Direction": ocean_data[1], - "Swell Direction one year ago": ( - ocean_information_history( - lat, long, arguments["decimal"], arguments["unit"] - )[1] - ), + "Swell Direction one year ago": ocean_history[1], "Period": ocean_data[2], - "Period one year ago": ( - ocean_information_history( - lat, long, arguments["decimal"], arguments["unit"] - )[2] - ), + "Period one year ago": ocean_history[2], "UV Index": uv_index, "UV Index one year ago": ( get_uv_history(lat, long, arguments["decimal"], arguments["unit"]) @@ -571,7 +537,7 @@ def gather_data(lat, long, arguments): return ocean_data_dict -def seperate_args_and_get_location(args): +def separate_args_and_get_location(args): """ Gets user's coordinates from either the argument(location=) or, if none, @@ -585,3 +551,7 @@ def seperate_args_and_get_location(args): "city": coordinates[2], } return location_data + + +# Backward-compatible alias for the misspelled name +seperate_args_and_get_location = separate_args_and_get_location diff --git a/src/art.py b/src/art.py index c61fa33..b077289 100644 --- a/src/art.py +++ b/src/art.py @@ -2,6 +2,10 @@ All ASCII art in this file """ +import logging + +logger = logging.getLogger(__name__) + # ASCII text colors colors = { "end": "\033[0m", @@ -12,7 +16,7 @@ "blue": "\033[0;34m", "purple": "\033[0;35m", "teal": "\033[0;36m", - "light_blue": "\033[0;34m", + "light_blue": "\033[0;94m", "white": "\033[0;37m", "bold_red": "\033[1;31m", "bold_green": "\033[1;32m", @@ -29,10 +33,10 @@ def print_wave(show_wave, show_large_wave, color): Prints Wave """ if color is not None and color.lower() not in colors: - print("Not a valid color") + logger.warning("Invalid color '%s'. Using default 'blue'.", color) color = "blue" - if int(show_large_wave) == 1: + if show_large_wave: print( colors[color] + """ @@ -48,7 +52,7 @@ def print_wave(show_wave, show_large_wave, color): """ + colors["end"] ) - elif int(show_wave) == 1: + elif show_wave: print( colors[color] + """ diff --git a/src/cli.py b/src/cli.py index cec244c..6eb71a7 100644 --- a/src/cli.py +++ b/src/cli.py @@ -2,62 +2,89 @@ Main module """ +import logging import sys from src import api, helper, settings from src.db import operations -# Load environment variables from .env file -env = settings.GPTSettings() -gpt_prompt = env.GPT_PROMPT -api_key = env.API_KEY -model = env.GPT_MODEL +logger = logging.getLogger(__name__) -# Check for DB -env_db = settings.DatabaseSettings() -db_uri = env_db.DB_URI -if db_uri: - db_handler = operations.SurfReportDatabaseOps() -gpt_info = [api_key, model] - - -def run(lat=0, long=0, args=None): +class SurfReport: """ - Main function + Orchestrates fetching, persisting, and displaying a surf report. + + Initializes settings and optional database connection on construction, + then exposes a run() method to execute a full report cycle. """ - if args is None: - args = helper.separate_args(sys.argv) - else: - args = helper.separate_args(args) - location = api.seperate_args_and_get_location(args) + def __init__(self): + gpt_env = settings.GPTSettings() + self.gpt_prompt = gpt_env.GPT_PROMPT + self.gpt_info = (gpt_env.API_KEY, gpt_env.GPT_MODEL) + self.db_handler = self._init_db() + + @staticmethod + def _init_db(): + """Initializes the database handler, or returns None if unavailable.""" + db_env = settings.DatabaseSettings() + if not db_env.DB_URI: + return None + try: + return operations.SurfReportDatabaseOps() + except Exception: + logger.warning( + "Could not connect to database. Reports will not be saved." + ) + return None + + def run(self, lat=None, long=None, args=None): + """ + Fetches surf data for the given coordinates or parsed location, + optionally persists the report, and renders output. + + Returns the ocean data dict, plus the GPT response when in text mode. + """ + args = helper.separate_args(args if args is not None else sys.argv) + + location = api.separate_args_and_get_location(args) + city, loc_lat, loc_long = helper.set_location(location) + + if lat is None or long is None: + lat, long = loc_lat, loc_long + + arguments = helper.arguments_dictionary(lat, long, city, args) + ocean_data_dict = api.gather_data(lat, long, arguments) - city, loc_lat, loc_long = helper.set_location(location) - if lat == 0 and long == 0: - lat, long = loc_lat, loc_long + self._save_report(ocean_data_dict) + return self._render_output(ocean_data_dict, arguments) - # Sets arguments = dictionary with all the CLI args (show_wave, city, etc.) - arguments = helper.arguments_dictionary(lat, long, city, args) + def _save_report(self, ocean_data_dict): + """Persists the report to the database if a handler is available.""" + if self.db_handler: + self.db_handler.insert_report(ocean_data_dict) - # Makes API calls (ocean, UV) and returns the values in a dictionary - ocean_data_dict = api.gather_data(lat, long, arguments) + def _render_output(self, ocean_data_dict, arguments): + """Renders JSON or human-readable output based on arguments.""" + if not arguments["json_output"]: + response = helper.print_outputs( + ocean_data_dict, arguments, self.gpt_prompt, self.gpt_info + ) + return ocean_data_dict, response + helper.json_output(ocean_data_dict) + return ocean_data_dict - # Build JSON output once — used by both branches and optional DB insert - json_out = helper.json_output(ocean_data_dict, print_output=False) - if arguments["json_output"] == 0: - response = helper.print_outputs( - ocean_data_dict, arguments, gpt_prompt, gpt_info - ) - if db_uri: - db_handler.insert_report(json_out) - return ocean_data_dict, response - else: - if db_uri: - db_handler.insert_report(json_out) - return json_out +def run(lat=None, long=None, args=None): + """Module-level entry point; delegates to SurfReport for convenience.""" + return SurfReport().run(lat=lat, long=long, args=args) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) run() diff --git a/src/db/connection.py b/src/db/connection.py index 38e7fba..357e611 100755 --- a/src/db/connection.py +++ b/src/db/connection.py @@ -4,9 +4,12 @@ from src.settings import DatabaseSettings +logger = logging.getLogger(__name__) + class Database: - # Manages the MongoDB connection + """Manages the MongoDB connection.""" + def __init__(self): settings = DatabaseSettings() self.db_uri = settings.DB_URI @@ -18,19 +21,18 @@ def connect(self, db_name="surf"): try: self.client = MongoClient(self.db_uri) self.db = self.client[db_name] - logging.info("Database connected successfully") + logger.info("Database connected successfully") except Exception as e: - logging.warning(f"Could not connect to MongoDB: {e}") + logger.warning("Could not connect to MongoDB: %s", e) raise return self.db def disconnect(self): - # Close the connection if self.client: self.client.close() self.client = None self.db = None - logging.info("Database connection closed") + logger.info("Database connection closed") db_manager = Database() diff --git a/src/db/operations.py b/src/db/operations.py index 5907db7..2464543 100755 --- a/src/db/operations.py +++ b/src/db/operations.py @@ -2,9 +2,12 @@ from src.db.connection import db_manager +logger = logging.getLogger(__name__) + class SurfReportDatabaseOps: - # Handles operations to the db + """Handles surf report operations against the database.""" + def __init__(self): self.db = db_manager.connect() self.collection = self.db["surfReports"] @@ -12,8 +15,8 @@ def __init__(self): def insert_report(self, report_document): try: rec = self.collection.insert_one(report_document) - logging.info(f"Document inserted with ID: {rec.inserted_id}") + logger.info("Document inserted with ID: %s", rec.inserted_id) return rec.inserted_id except Exception as e: - logging.error(f"Error inserting to the db: {e}") + logger.error("Error inserting to the db: %s", e) raise diff --git a/src/gpt.py b/src/gpt.py index 86f8ae3..a9212ab 100644 --- a/src/gpt.py +++ b/src/gpt.py @@ -2,9 +2,13 @@ GPT Functions stored here """ +import logging + from g4f.client import Client from openai import OpenAI +logger = logging.getLogger(__name__) + def simple_gpt(surf_summary, gpt_prompt): """ @@ -13,12 +17,16 @@ def simple_gpt(surf_summary, gpt_prompt): report the user wants, loaded in from the environment vars Using: https://github.com/xtekky/gpt4free """ - client = Client() - response = client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": surf_summary + gpt_prompt}], - ) - return response.choices[0].message.content + try: + client = Client() + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": surf_summary + gpt_prompt}], + ) + return response.choices[0].message.content + except Exception as e: + logger.error("GPT (free) request failed: %s", e) + return "Unable to generate GPT response." def openai_gpt(surf_summary, gpt_prompt, api_key, model): @@ -29,17 +37,18 @@ def openai_gpt(surf_summary, gpt_prompt, api_key, model): Uses openai's GPT, needs an API key https://platform.openai.com/docs/api-reference/introduction """ - client = OpenAI( - # This is the default and can be omitted - api_key=api_key, - ) - chat_completion = client.chat.completions.create( - messages=[ - { - "role": "user", - "content": surf_summary + gpt_prompt, - } - ], - model=model, - ) - return chat_completion.choices[0].message.content + try: + client = OpenAI(api_key=api_key) + chat_completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": surf_summary + gpt_prompt, + } + ], + model=model, + ) + return chat_completion.choices[0].message.content + except Exception as e: + logger.error("OpenAI request failed: %s", e) + return "Unable to generate GPT response." diff --git a/src/helper.py b/src/helper.py index 4a1d8ab..07ceb35 100644 --- a/src/helper.py +++ b/src/helper.py @@ -3,38 +3,37 @@ """ import json -import sys -from pathlib import Path - -sys.path.append(str(Path(__file__).parent.parent)) +import logging from src import api, art, gpt -# At the top of helper.py, add a constant dict for default args; +logger = logging.getLogger(__name__) + +MAX_FORECAST_DAYS = 7 DEFAULT_ARGUMENTS = { - "show_wave": 1, - "show_large_wave": 0, - "show_uv": 1, - "show_past_uv": 0, - "show_height": 1, - "show_direction": 1, - "show_period": 1, - "show_height_history": 0, - "show_direction_history": 0, - "show_period_history": 0, - "show_city": 1, - "show_date": 1, - "show_air_temp": 0, - "show_wind_speed": 0, - "show_wind_direction": 0, - "json_output": 0, - "show_rain_sum": 0, - "show_precipitation_prob": 0, + "show_wave": True, + "show_large_wave": False, + "show_uv": True, + "show_past_uv": False, + "show_height": True, + "show_direction": True, + "show_period": True, + "show_height_history": False, + "show_direction_history": False, + "show_period_history": False, + "show_city": True, + "show_date": True, + "show_air_temp": False, + "show_wind_speed": False, + "show_wind_direction": False, + "json_output": False, + "show_rain_sum": False, + "show_precipitation_prob": False, "unit": "imperial", - "gpt": 0, - "show_cloud_cover": 0, - "show_visibility": 0, + "gpt": False, + "show_cloud_cover": False, + "show_visibility": False, } @@ -58,70 +57,69 @@ def arguments_dictionary(lat, long, city, args): return arguments -def set_output_values(args, arguments_dictionary): # noqa +def set_output_values(args, args_dict): # noqa """ - Takes a list of command line arguments (args) - and sets the appropriate values in the - arguments_dictionary (show_wave = 1, etc). - Returns the arguments_dictionary dict with the updated CLI args. + Takes a list of command line arguments (args) and sets the appropriate + values in args_dict (show_wave = True, etc). + Returns args_dict with the updated CLI args. """ mappings = { - "hide_wave": ("show_wave", 0), - "hw": ("show_wave", 0), - "show_large_wave": ("show_large_wave", 1), - "slw": ("show_large_wave", 1), - "hide_uv": ("show_uv", 0), - "huv": ("show_uv", 0), - "show_past_uv": ("show_past_uv", 1), - "spuv": ("show_past_uv", 1), - "hide_past_uv": ("show_past_uv", 0), - "hide_height": ("show_height", 0), - "hh": ("show_height", 0), - "show_height_history": ("show_height_history", 1), - "shh": ("show_height_history", 1), - "hide_height_history": ("show_height_history", 0), - "hide_direction": ("show_direction", 0), - "hdir": ("show_direction", 0), - "show_direction_history": ("show_direction_history", 1), - "sdh": ("show_direction_history", 1), - "hide_direction_history": ("show_direction_history", 0), - "hide_period": ("show_period", 0), - "hp": ("show_period", 0), - "show_period_history": ("show_period_history", 1), - "sph": ("show_period_history", 1), - "hide_period_history": ("show_period_history", 0), - "hide_location": ("show_city", 0), - "hl": ("show_city", 0), - "hide_date": ("show_date", 0), - "hdate": ("show_date", 0), + "hide_wave": ("show_wave", False), + "hw": ("show_wave", False), + "show_large_wave": ("show_large_wave", True), + "slw": ("show_large_wave", True), + "hide_uv": ("show_uv", False), + "huv": ("show_uv", False), + "show_past_uv": ("show_past_uv", True), + "spuv": ("show_past_uv", True), + "hide_past_uv": ("show_past_uv", False), + "hide_height": ("show_height", False), + "hh": ("show_height", False), + "show_height_history": ("show_height_history", True), + "shh": ("show_height_history", True), + "hide_height_history": ("show_height_history", False), + "hide_direction": ("show_direction", False), + "hdir": ("show_direction", False), + "show_direction_history": ("show_direction_history", True), + "sdh": ("show_direction_history", True), + "hide_direction_history": ("show_direction_history", False), + "hide_period": ("show_period", False), + "hp": ("show_period", False), + "show_period_history": ("show_period_history", True), + "sph": ("show_period_history", True), + "hide_period_history": ("show_period_history", False), + "hide_location": ("show_city", False), + "hl": ("show_city", False), + "hide_date": ("show_date", False), + "hdate": ("show_date", False), "metric": ("unit", "metric"), "m": ("unit", "metric"), - "json": ("json_output", 1), - "j": ("json_output", 1), - "gpt": ("gpt", 1), - "g": ("gpt", 1), - "show_air_temp": ("show_air_temp", 1), - "sat": ("show_air_temp", 1), - "show_wind_speed": ("show_wind_speed", 1), - "sws": ("show_wind_speed", 1), - "show_wind_direction": ("show_wind_direction", 1), - "swd": ("show_wind_direction", 1), - "show_rain_sum": ("show_rain_sum", 1), - "srs": ("show_rain_sum", 1), - "show_precipitation_prob": ("show_precipitation_prob", 1), - "spp": ("show_precipitation_prob", 1), - "show_cloud_cover": ("show_cloud_cover", 1), - "scc": ("show_cloud_cover", 1), - "show_visibility": ("show_visibility", 1), - "sv": ("show_visibility", 1), + "json": ("json_output", True), + "j": ("json_output", True), + "gpt": ("gpt", True), + "g": ("gpt", True), + "show_air_temp": ("show_air_temp", True), + "sat": ("show_air_temp", True), + "show_wind_speed": ("show_wind_speed", True), + "sws": ("show_wind_speed", True), + "show_wind_direction": ("show_wind_direction", True), + "swd": ("show_wind_direction", True), + "show_rain_sum": ("show_rain_sum", True), + "srs": ("show_rain_sum", True), + "show_precipitation_prob": ("show_precipitation_prob", True), + "spp": ("show_precipitation_prob", True), + "show_cloud_cover": ("show_cloud_cover", True), + "scc": ("show_cloud_cover", True), + "show_visibility": ("show_visibility", True), + "sv": ("show_visibility", True), } for arg in args: if arg in mappings: key, value = mappings[arg] - arguments_dictionary[key] = value + args_dict[key] = value - return arguments_dictionary + return args_dict def separate_args(args): @@ -148,7 +146,7 @@ def _extract_arg(args, keys, default, cast=str): try: return cast(arg_str.split("=", 1)[1]) except (ValueError, IndexError): - print(f"Invalid value for {keys[0]}. Using default.") + logger.warning("Invalid value for %s. Using default.", keys[0]) return default @@ -161,7 +159,9 @@ def extract_decimal(args): try: return int(arg.split("=")[1]) except (ValueError, IndexError): - print("Invalid value for decimal. Please provide an integer.") + logger.warning( + "Invalid value for decimal. Please provide an integer." + ) return 1 @@ -169,10 +169,12 @@ def get_forecast_days(args): """ Extract forecast day count from CLI args. Defaults to 0. Max is 7. """ - MAX_VALUE = 7 value = _extract_arg(args, ["forecast", "fc"], default=0, cast=int) - if value < 0 or value > MAX_VALUE: - print("Must choose a non-negative number <= 7 in forecast!") + if value < 0 or value > MAX_FORECAST_DAYS: + logger.warning( + "Forecast days must be between 0 and %d. Using default.", + MAX_FORECAST_DAYS, + ) return 0 return value @@ -188,7 +190,7 @@ def print_location(city, show_city): """ Prints location. """ - if int(show_city) == 1: + if show_city: print("Location: ", city) print("\n") @@ -232,7 +234,7 @@ def print_ocean_data(arguments_dict, ocean_data_dict): ] for arg_key, data_key, label in mappings: - if int(arguments_dict[arg_key]) == 1: + if arguments_dict[arg_key]: print(f"{label}{ocean_data_dict[data_key]}") @@ -265,7 +267,7 @@ def print_forecast(ocean, forecast): for day in range(ocean["forecast_days"]): for arg_key, data_key, label in mappings: - if int(ocean[arg_key]) == 1: + if ocean[arg_key]: try: data = forecast[data_key][day] formatted = round(float(data), ocean["decimal"]) @@ -284,12 +286,11 @@ def round_decimal(round_list, decimal): def json_output(data_dict, print_output=True): """ - If JSON=TRUE in .args, we print and return the JSON data. - Data dict includes current & forecast data. + Serializes data_dict to JSON. Prints to stdout if print_output is True. + Returns the original dict for programmatic use. """ - json_out = json.dumps(data_dict, indent=4) if print_output: - print(json_out) + print(json.dumps(data_dict, indent=4)) return data_dict @@ -320,7 +321,7 @@ def print_outputs(ocean_data_dict, arguments, gpt_prompt, gpt_info): print_forecast(arguments, forecast) gpt_response = None - if arguments["gpt"] == 1: + if arguments["gpt"]: gpt_response = print_gpt(ocean_data_dict, gpt_prompt, gpt_info) print(gpt_response) return gpt_response diff --git a/src/send_email.py b/src/send_email.py index cc2bfee..23d7425 100644 --- a/src/send_email.py +++ b/src/send_email.py @@ -2,6 +2,7 @@ Module to send surf report emails """ +import logging import smtplib import subprocess from email.mime.multipart import MIMEMultipart @@ -9,40 +10,45 @@ from src.settings import EmailSettings -# Load environment variables from .env file -env = EmailSettings() - -# Create a multipart message and set headers -message = MIMEMultipart() -message["From"] = env.EMAIL -message["To"] = env.EMAIL_RECEIVER -message["Subject"] = env.SUBJECT +logger = logging.getLogger(__name__) def send_user_email(): """ - Sends user an email + Fetches the current surf report via curl and sends it as an email. """ - SURF = subprocess.run( - ["curl", env.COMMAND], - capture_output=True, - text=True, - check=True, - ) - if SURF.returncode == 0: # Check if command executed successfully - BODY = SURF.stdout - else: - BODY = "Failed to execute curl command." - message.attach(MIMEText(BODY, "plain")) + env = EmailSettings() + + message = MIMEMultipart() + message["From"] = env.EMAIL + message["To"] = env.EMAIL_RECEIVER + message["Subject"] = env.SUBJECT + + try: + result = subprocess.run( + ["curl", env.COMMAND], + capture_output=True, + text=True, + check=True, + ) + body = result.stdout + except subprocess.CalledProcessError as e: + logger.error("Failed to fetch surf report via curl: %s", e.stderr) + body = "Failed to fetch surf report." + + message.attach(MIMEText(body, "plain")) - # Connect to the SMTP server with smtplib.SMTP(env.SMTP_SERVER, env.SMTP_PORT) as server: - server.starttls() # Secure the connection + server.starttls() server.login(env.EMAIL, env.EMAIL_PW) - text = message.as_string() - server.sendmail(env.EMAIL, env.EMAIL_RECEIVER, text) - print("Email sent successfully.") + server.sendmail(env.EMAIL, env.EMAIL_RECEIVER, message.as_string()) + logger.info("Email sent successfully.") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) send_user_email() diff --git a/src/server.py b/src/server.py index b3ff417..0f1bcf3 100644 --- a/src/server.py +++ b/src/server.py @@ -2,8 +2,7 @@ Flask Server! """ -import asyncio -import os +import logging import subprocess import sys import urllib.parse @@ -20,86 +19,66 @@ from src.settings import ServerSettings +logger = logging.getLogger(__name__) + def create_app(env): """ Application factory function """ - - # Load environment variables from .env file - app = Flask(__name__) CORS(app) @app.route("/help") def serve_help(): - """ - Servers the help.txt file - """ + """Serves the help.txt file.""" return send_from_directory( Path(__file__).resolve().parents[1], "help.txt" ) @app.route("/home") def serve_index(): - """ - Servers index.html - """ + """Serves index.html.""" return render_template("index.html", env_vars=env.model_dump()) @app.route("/script.js") def serve_script(): - """ - Servers javascript - """ + """Serves the frontend JavaScript.""" return send_file("static/script.js") @app.route("/") def default_route(): - """ - Default route, serves surf report - """ + """Serves the surf report.""" query_parameters = urllib.parse.parse_qsl( request.query_string.decode(), keep_blank_values=True ) - parsed_parameters = [] - - for key, value in query_parameters: - if value: - parsed_parameters.append(f"{key}={value}") - else: - parsed_parameters.append(key) - - # Join the parsed parameters list into a single string + parsed_parameters = [ + f"{key}={value}" if value else key + for key, value in query_parameters + ] args = ",".join(parsed_parameters) - async def run_subprocess(): - try: - script_path = os.path.join("src", "cli.py") - - result = subprocess.run( - [sys.executable, script_path, args], - capture_output=True, - text=True, - check=True, - ) - return result.stdout - except subprocess.CalledProcessError as e: - # Print the error message from the subprocess - print("Error message from subprocess:", e.stderr) - # Raise the error again to propagate it - raise e - - # Run subprocess asynchronously - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - result = loop.run_until_complete(run_subprocess()) - return result + try: + result = subprocess.run( + [sys.executable, Path("src") / "cli.py", args], + capture_output=True, + text=True, + check=True, + ) + return result.stdout + except subprocess.CalledProcessError as e: + logger.error("Subprocess error: %s", e.stderr) + raise return app -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) env = ServerSettings() app = create_app(env) app.run(host="0.0.0.0", port=env.PORT, debug=env.DEBUG) diff --git a/src/settings.py b/src/settings.py index e250f18..ac4a9d4 100644 --- a/src/settings.py +++ b/src/settings.py @@ -41,15 +41,13 @@ class EmailSettings(ServerSettings): EMAIL: EmailStr # required EMAIL_PW: str # required EMAIL_RECEIVER: EmailStr # required - COMMAND: str = Field( - default_factory=lambda cls: f"{cls.IP_ADDRESS}:{cls.PORT}" - ) + COMMAND: str = Field(default="localhost:8000") SUBJECT: str = Field(default="Surf Report") class GPTSettings(CommonSettings): """ - Class for defining server env settings. + Class for defining GPT env settings. """ GPT_PROMPT: str = Field( @@ -69,4 +67,4 @@ class DatabaseSettings(CommonSettings): Class for defining database env settings """ - DB_URI: str + DB_URI: str = Field(default="") diff --git a/src/streamlit_helper.py b/src/streamlit_helper.py index 051894f..4a3d8fb 100644 --- a/src/streamlit_helper.py +++ b/src/streamlit_helper.py @@ -2,43 +2,35 @@ Helper functions for the streamlit frontend """ -import sys -from pathlib import Path - import folium import pandas as pd -sys.path.append(str(Path(__file__).parent.parent)) - from src import cli def extra_args(gpt): """ By default, the location is the only argument when cli.run() - is ran. Extra args outputs and other arguments the user wants, - like using the GPT function + is run. Extra args outputs and other arguments the user wants, + like using the GPT function. """ - # Arguments - extra_args = "" + args = "" if gpt: - extra_args += ",gpt" + args += ",gpt" - return extra_args + return args def get_report(location, extra_args): """ - Executes cli.run(), retrns the report dictionary, - gpt response, lat and long + Executes cli.run(), returns the report dictionary, + gpt response, lat and long. """ - gpt_response = None args = "location=" + location if extra_args: args += extra_args - surf_report = cli.run(args=["placeholder", args]) - report_dict, gpt_response = surf_report[0], surf_report[1] + report_dict, gpt_response = cli.run(args=["placeholder", args]) lat, long = report_dict["Lat"], report_dict["Long"] return report_dict, gpt_response, lat, long diff --git a/tests/test_api.py b/tests/test_api.py index 1b51b35..f7845fd 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,6 +4,7 @@ Run pytest: pytest """ +import logging from http import HTTPStatus from unittest.mock import Mock, patch @@ -204,3 +205,62 @@ def test_ocean_information_history(): result = ocean_information_history(31.9505, 115.8605, 1) expected_result = ["0.6", "0.6", "0.6"] assert result == expected_result + + +# --------------------------------------------------------------------------- +# Error / fallback paths +# --------------------------------------------------------------------------- + + +def test_get_coordinates_no_args_falls_back_to_default(mocker): + """get_coordinates falls back to default when no location= arg is given.""" + mocker.patch( + "src.api.default_location", return_value=[0.0, 0.0, "Default City"] + ) + result = get_coordinates([]) + assert result == [0.0, 0.0, "Default City"] + + +def test_get_coordinates_invalid_location_falls_back_to_default( + mocker, caplog +): + """get_coordinates logs a warning and falls back when geocoding fails.""" + mock_geo = mocker.patch("src.api.Nominatim") + mock_geo.return_value.geocode.return_value = None + mocker.patch( + "src.api.default_location", return_value=[0.0, 0.0, "Default City"] + ) + + with caplog.at_level(logging.WARNING, logger="src.api"): + result = get_coordinates(["location=nowhere_xyz_invalid"]) + + assert result == [0.0, 0.0, "Default City"] + assert "Invalid location" in caplog.text + + +def test_get_uv_returns_no_data_on_value_error(mocker): + """get_uv returns 'No data' when Open-Meteo client raises ValueError.""" + mock_client = mocker.patch("src.api._create_openmeteo_client") + mock_client.return_value.weather_api.side_effect = ValueError("bad coords") + assert get_uv(1000, -2000, 2) == "No data" + + +def test_get_uv_history_returns_no_data_on_value_error(mocker): + """get_uv_history returns 'No data' when the API raises ValueError.""" + mock_client = mocker.patch("src.api._create_openmeteo_client") + mock_client.return_value.weather_api.side_effect = ValueError("bad coords") + assert get_uv_history(31.9505, 115.8605, 2) == "No data" + + +def test_ocean_information_returns_no_data_on_value_error(mocker): + """ocean_information returns 'No data' when the API raises ValueError.""" + mock_client = mocker.patch("src.api._create_openmeteo_client") + mock_client.return_value.weather_api.side_effect = ValueError("bad coords") + assert ocean_information(1000, -2000, 2) == "No data" + + +def test_ocean_information_history_returns_no_data_on_value_error(mocker): + """ocean_information_history returns 'No data' on ValueError.""" + mock_client = mocker.patch("src.api._create_openmeteo_client") + mock_client.return_value.weather_api.side_effect = ValueError("bad coords") + assert ocean_information_history(1000, -2000, 2) == "No data" diff --git a/tests/test_art.py b/tests/test_art.py index 5819053..6e8d569 100644 --- a/tests/test_art.py +++ b/tests/test_art.py @@ -32,3 +32,17 @@ def test_print_wave(): # Perform assertions based on expected output assert "[0;34m" in output, "Blue color code not found in output" assert output, "print_wave() did not print anything" + + +def test_print_large_wave(): + """print_wave prints the large wave art when show_large_wave is True.""" + captured_output = io.StringIO() + sys.stdout = captured_output + + art.print_wave(True, True, "blue") + + sys.stdout = sys.__stdout__ + output = captured_output.getvalue() + + assert "[0;34m" in output, "Blue color code not found in large wave output" + assert output, "print_wave() did not print anything for large wave" diff --git a/tests/test_cli.py b/tests/test_cli.py index 8104eed..582d535 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,21 +1,200 @@ """ -QA tests for cli.py -Make sure pytest is installed: pip install pytest -Run pytest: pytest +Tests for cli.py """ +import logging +from unittest.mock import Mock -# TODO: fix broken test - -# def test_cli_output(): -# """ -# Main() returns a dictionary of: location, height, period, etc. -# This functions checks if the dictionary is returned and is populated -# """ -# expected = 5 -# # Hardcode lat and long for location. -# # If not, when test are ran in Github Actions -# # We get an error(because server probably isn't near ocean) -# data_dict = cli.run(36.95, -121.97)[0] -# time.sleep(5) -# assert len(data_dict) >= expected +from src.cli import SurfReport, run +from src.helper import DEFAULT_ARGUMENTS + +_LAT = 10.0 +_LONG = 20.0 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_settings( + mocker, *, db_uri="", gpt_prompt="prompt", api_key="", model="gpt-3.5" +): + """Patch both settings classes used by SurfReport.__init__.""" + mocker.patch( + "src.cli.settings.GPTSettings", + return_value=Mock( + GPT_PROMPT=gpt_prompt, API_KEY=api_key, GPT_MODEL=model + ), + ) + mocker.patch( + "src.cli.settings.DatabaseSettings", + return_value=Mock(DB_URI=db_uri), + ) + + +def _make_arguments(**overrides): + """Return a full arguments dict suitable for run() calls.""" + args = { + **DEFAULT_ARGUMENTS, + "lat": 36.97, + "long": -122.03, + "city": "Santa Cruz", + "decimal": 1, + "forecast_days": 0, + "color": "blue", + } + args.update(overrides) + return args + + +def _mock_run_pipeline(mocker, arguments, ocean_data): + """Patch all I/O helpers called inside SurfReport.run().""" + mocker.patch("src.cli.helper.separate_args", return_value=[]) + mocker.patch( + "src.cli.api.separate_args_and_get_location", + return_value={"city": "Santa Cruz", "lat": 36.97, "long": -122.03}, + ) + mocker.patch( + "src.cli.helper.set_location", + return_value=("Santa Cruz", 36.97, -122.03), + ) + mocker.patch("src.cli.helper.arguments_dictionary", return_value=arguments) + mocker.patch("src.cli.api.gather_data", return_value=ocean_data) + + +# --------------------------------------------------------------------------- +# Initialisation +# --------------------------------------------------------------------------- + + +def test_init_stores_gpt_settings(mocker): + """SurfReport reads GPT prompt and model from settings on construction.""" + _mock_settings(mocker, gpt_prompt="surf report", api_key="", model="gpt-4") + report = SurfReport() + assert report.gpt_prompt == "surf report" + assert report.gpt_info == ("", "gpt-4") + + +def test_init_db_none_when_uri_empty(mocker): + """db_handler is None when DB_URI is not configured.""" + _mock_settings(mocker, db_uri="") + assert SurfReport().db_handler is None + + +def test_init_db_connected_when_uri_set(mocker): + """db_handler is the SurfReportDatabaseOps instance when DB_URI is set.""" + _mock_settings(mocker, db_uri="mongodb://localhost") + mock_handler = Mock() + mocker.patch( + "src.cli.operations.SurfReportDatabaseOps", return_value=mock_handler + ) + assert SurfReport().db_handler is mock_handler + + +def test_init_db_logs_warning_and_returns_none_on_failure(mocker, caplog): + """db_handler is None and a warning is logged when DB connection fails.""" + _mock_settings(mocker, db_uri="mongodb://localhost") + mocker.patch( + "src.cli.operations.SurfReportDatabaseOps", + side_effect=Exception("timeout"), + ) + with caplog.at_level(logging.WARNING, logger="src.cli"): + report = SurfReport() + assert report.db_handler is None + assert "Could not connect to database" in caplog.text + + +# --------------------------------------------------------------------------- +# run() – text mode +# --------------------------------------------------------------------------- + + +def test_run_text_mode_returns_dict_and_gpt_response(mocker): + """run() in text mode returns (ocean_data_dict, gpt_response).""" + ocean_data = {"Height": 3.0, "Lat": 36.97, "Long": -122.03} + arguments = _make_arguments() + _mock_settings(mocker) + _mock_run_pipeline(mocker, arguments, ocean_data) + mock_print = mocker.patch( + "src.cli.helper.print_outputs", return_value="surf is fun" + ) + + result = SurfReport().run() + + assert result == (ocean_data, "surf is fun") + mock_print.assert_called_once() + + +def test_run_json_mode_returns_dict_only(mocker): + """run() in JSON mode returns only the ocean data dict.""" + ocean_data = {"Height": 3.0, "Lat": 36.97, "Long": -122.03} + arguments = _make_arguments(json_output=True) + _mock_settings(mocker) + _mock_run_pipeline(mocker, arguments, ocean_data) + mock_json = mocker.patch("src.cli.helper.json_output") + + result = SurfReport().run() + + assert result is ocean_data + mock_json.assert_called_once_with(ocean_data) + + +def test_run_uses_explicit_lat_long_over_resolved(mocker): + """Caller-supplied lat/long overrides the location resolved from args.""" + ocean_data = {"Height": 3.0, "Lat": _LAT, "Long": _LONG} + arguments = _make_arguments(lat=_LAT, long=_LONG) + _mock_settings(mocker) + _mock_run_pipeline(mocker, arguments, ocean_data) + mock_gather = mocker.patch( + "src.cli.api.gather_data", return_value=ocean_data + ) + mocker.patch("src.cli.helper.print_outputs", return_value=None) + + SurfReport().run(lat=_LAT, long=_LONG) + + call_lat, call_long = mock_gather.call_args[0][:2] + assert call_lat == _LAT + assert call_long == _LONG + + +# --------------------------------------------------------------------------- +# _save_report +# --------------------------------------------------------------------------- + + +def test_save_report_calls_insert_when_handler_set(mocker): + """_save_report delegates to the db_handler when one is configured.""" + _mock_settings(mocker, db_uri="mongodb://localhost") + mock_handler = Mock() + mocker.patch( + "src.cli.operations.SurfReportDatabaseOps", return_value=mock_handler + ) + data = {"Height": 3} + SurfReport()._save_report(data) + mock_handler.insert_report.assert_called_once_with(data) + + +def test_save_report_is_noop_without_handler(mocker): + """_save_report does nothing when db_handler is None.""" + _mock_settings(mocker) + SurfReport()._save_report({"Height": 3}) # must not raise + + +# --------------------------------------------------------------------------- +# Module-level run() shim +# --------------------------------------------------------------------------- + + +def test_module_run_delegates_to_surf_report(mocker): + """The module-level run() creates a SurfReport and forwards all args.""" + mock_instance = Mock() + mock_instance.run.return_value = {"ocean": "data"} + mock_class = mocker.patch("src.cli.SurfReport", return_value=mock_instance) + + result = run(lat=1.0, long=2.0, args=["placeholder", "json"]) + + mock_class.assert_called_once() + mock_instance.run.assert_called_once_with( + lat=1.0, long=2.0, args=["placeholder", "json"] + ) + assert result == {"ocean": "data"} diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..696c6b5 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,128 @@ +""" +Tests for src/db/connection.py and src/db/operations.py +""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from src.db.connection import Database +from src.db.operations import SurfReportDatabaseOps + +# --------------------------------------------------------------------------- +# Database (connection.py) +# --------------------------------------------------------------------------- + + +def _db_with_mocked_settings(mocker, db_uri="mongodb://localhost"): + """Return a Database instance with settings and MongoClient mocked.""" + mock_client = MagicMock() + mock_cls = mocker.patch( + "src.db.connection.MongoClient", return_value=mock_client + ) + mocker.patch( + "src.db.connection.DatabaseSettings", + return_value=MagicMock(DB_URI=db_uri), + ) + return Database(), mock_client, mock_cls + + +def test_connect_creates_mongo_client(mocker): + """connect() creates a MongoClient and returns the named database.""" + db, mock_client, _ = _db_with_mocked_settings(mocker) + result = db.connect("surf") + assert result is mock_client["surf"] + + +def test_connect_reuses_existing_connection(mocker): + """connect() does not create a second MongoClient if already connected.""" + db, _, mock_cls = _db_with_mocked_settings(mocker) + db.connect() + db.connect() # second call — should not instantiate another client + mock_cls.assert_called_once() + + +def test_connect_logs_warning_and_reraises_on_failure(mocker, caplog): + """connect() logs a warning and re-raises when MongoClient fails.""" + mocker.patch( + "src.db.connection.MongoClient", + side_effect=Exception("connection refused"), + ) + mocker.patch( + "src.db.connection.DatabaseSettings", + return_value=MagicMock(DB_URI="mongodb://localhost"), + ) + db = Database() + with caplog.at_level(logging.WARNING, logger="src.db.connection"): + with pytest.raises(Exception, match="connection refused"): + db.connect() + assert "Could not connect to MongoDB" in caplog.text + + +def test_disconnect_closes_client_and_clears_state(mocker): + """disconnect() closes the MongoClient and resets client/db to None.""" + db, mock_client, _ = _db_with_mocked_settings(mocker) + db.connect() + db.disconnect() + mock_client.close.assert_called_once() + assert db.client is None + assert db.db is None + + +def test_disconnect_is_noop_when_not_connected(mocker): + """disconnect() does nothing when no client exists.""" + mocker.patch( + "src.db.connection.DatabaseSettings", + return_value=MagicMock(DB_URI=""), + ) + Database().disconnect() # must not raise + + +# --------------------------------------------------------------------------- +# SurfReportDatabaseOps (operations.py) +# --------------------------------------------------------------------------- + + +def test_init_connects_and_sets_collection(mocker): + """SurfReportDatabaseOps.__init__ connects and sets self.collection.""" + mock_col = MagicMock() + mock_db = MagicMock() + mock_db.__getitem__.return_value = mock_col + mocker.patch("src.db.operations.db_manager").connect.return_value = mock_db + + ops = SurfReportDatabaseOps() + + assert ops.collection is mock_col + + +def _make_ops(mocker): + """Return SurfReportDatabaseOps with bypassed __init__ and mocked col.""" + mocker.patch.object(SurfReportDatabaseOps, "__init__", return_value=None) + ops = SurfReportDatabaseOps() + mock_col = MagicMock() + ops.collection = mock_col + return ops, mock_col + + +def test_insert_report_returns_inserted_id(mocker): + """insert_report returns the inserted document ID on success.""" + ops, mock_col = _make_ops(mocker) + mock_col.insert_one.return_value = MagicMock(inserted_id="abc123") + + result = ops.insert_report({"Height": 3}) + + assert result == "abc123" + mock_col.insert_one.assert_called_once_with({"Height": 3}) + + +def test_insert_report_logs_error_and_reraises_on_failure(mocker, caplog): + """insert_report logs an error and re-raises the exception on failure.""" + ops, mock_col = _make_ops(mocker) + mock_col.insert_one.side_effect = Exception("write failed") + + with caplog.at_level(logging.ERROR, logger="src.db.operations"): + with pytest.raises(Exception, match="write failed"): + ops.insert_report({"Height": 3}) + + assert "Error inserting to the db" in caplog.text diff --git a/tests/test_gpt.py b/tests/test_gpt.py index c6aa3d6..c15884a 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -1,44 +1,68 @@ """ -QA tests for gpt.py -Make sure pytest is installed: pip install pytest -Run pytest: pytest +Tests for gpt.py """ +from unittest.mock import Mock -# // TODO: mock this api call, bad practice to actually make a call -# commenting out because this is breaking the ci/cd pipeline - -# def test_simple_gpt(): -# """ -# Testing the simple_gpt function -# Calls the simple gpt and asks it to output -# the days of the week. If the output does not contain -# any day of the week, we assume the gpt is non-fucntional -# """ - -# surf_summary = "" -# gpt_prompt = """Please output the days of the week in English. What day -# is your favorite?""" - -# gpt_response = gpt.simple_gpt(surf_summary, gpt_prompt).lower() -# expected_response = set([ -# "monday", -# "tuesday", -# "wednesday", -# "thursday", -# "friday" "saturday", -# "sunday", -# "一", -# "二", -# "三", -# "四", -# "五", -# ]) - -# # Can case the "gpt_response" string into a list, and -# # check for set intersection with the expected response set -# gpt_response_set = set(gpt_response.split()) - -# assert gpt_response_set.intersection( -# expected_response -# ), f"Expected '{expected_response}', but got: {gpt_response}" +from src import gpt + + +def _make_chat_response(content): + """Build a minimal chat completion response mock.""" + message = Mock() + message.content = content + choice = Mock() + choice.message = message + response = Mock() + response.choices = [choice] + return response + + +def test_simple_gpt_returns_model_content(mocker): + """simple_gpt returns the text content from the g4f response.""" + mock_client = Mock() + mock_client.chat.completions.create.return_value = _make_chat_response( + "Great surf day!" + ) + mocker.patch("src.gpt.Client", return_value=mock_client) + + result = gpt.simple_gpt("surf is 4ft", "what board should I ride?") + + assert result == "Great surf day!" + mock_client.chat.completions.create.assert_called_once() + + +def test_simple_gpt_returns_fallback_on_exception(mocker): + """simple_gpt returns the error string when the g4f client raises.""" + mocker.patch("src.gpt.Client", side_effect=Exception("API down")) + + result = gpt.simple_gpt("surf is 4ft", "what board?") + + assert result == "Unable to generate GPT response." + + +def test_openai_gpt_returns_model_content(mocker): + """openai_gpt returns the text content from the OpenAI response.""" + mock_client = Mock() + mock_client.chat.completions.create.return_value = _make_chat_response( + "Bring your longboard." + ) + mocker.patch("src.gpt.OpenAI", return_value=mock_client) + + result = gpt.openai_gpt( + "surf is 2ft", "recommend a board", "sk-testkey", "gpt-4" + ) + + assert result == "Bring your longboard." + mock_client.chat.completions.create.assert_called_once() + + +def test_openai_gpt_returns_fallback_on_exception(mocker): + """openai_gpt returns the error string when the OpenAI client raises.""" + mocker.patch("src.gpt.OpenAI", side_effect=Exception("quota exceeded")) + + result = gpt.openai_gpt( + "surf is 2ft", "recommend a board", "sk-key", "gpt-4" + ) + + assert result == "Unable to generate GPT response." diff --git a/tests/test_helper.py b/tests/test_helper.py index 01fcbfa..43928cd 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -5,21 +5,25 @@ """ import io +import logging from unittest.mock import patch from src import helper from src.helper import set_output_values +_LAT = 36.97 +_LONG = -122.03 -def test_invalid_input(): + +def test_invalid_input(caplog): """ - Test if decimal input prints proper invalid input message + Test if decimal input logs proper invalid input message """ - with patch("sys.stdout", new=io.StringIO()) as fake_stdout: + with caplog.at_level(logging.WARNING, logger="src.helper"): helper.extract_decimal(["decimal=NotADecimal"]) - printed_output = fake_stdout.getvalue().strip() - expected = "Invalid value for decimal. Please provide an integer." - assert printed_output == expected + assert ( + "Invalid value for decimal. Please provide an integer." in caplog.text + ) def test_default_input(): @@ -185,3 +189,279 @@ def test_get_forecast_days_valid(): def test_get_forecast_days_default(): assert helper.get_forecast_days([]) == 0 + + +# --------------------------------------------------------------------------- +# separate_args +# --------------------------------------------------------------------------- + + +def test_separate_args_splits_on_comma(): + """separate_args returns the second element split by comma.""" + result = helper.separate_args(["cmd", "json,location=santa_cruz"]) + assert result == ["json", "location=santa_cruz"] + + +def test_separate_args_returns_empty_for_single_arg(): + """separate_args returns [] when no extra args are passed.""" + assert helper.separate_args(["cmd"]) == [] + + +# --------------------------------------------------------------------------- +# _extract_arg +# --------------------------------------------------------------------------- + + +def test_extract_arg_logs_warning_on_invalid_cast(caplog): + """_extract_arg logs a warning and returns the default when cast fails.""" + with caplog.at_level(logging.WARNING, logger="src.helper"): + result = helper._extract_arg( + ["forecast=abc"], ["forecast"], 0, cast=int + ) + assert result == 0 + assert "Invalid value for forecast" in caplog.text + + +# --------------------------------------------------------------------------- +# get_forecast_days +# --------------------------------------------------------------------------- + + +def test_get_forecast_days_out_of_range_logs_warning(caplog): + """get_forecast_days returns 0 and warns when value exceeds MAX.""" + with caplog.at_level(logging.WARNING, logger="src.helper"): + result = helper.get_forecast_days(["forecast=10"]) + assert result == 0 + assert "Forecast days must be between" in caplog.text + + +def test_get_forecast_days_negative_logs_warning(caplog): + """get_forecast_days returns 0 and warns when value is negative.""" + with caplog.at_level(logging.WARNING, logger="src.helper"): + result = helper.get_forecast_days(["forecast=-1"]) + assert result == 0 + + +# --------------------------------------------------------------------------- +# print_location +# --------------------------------------------------------------------------- + + +def test_print_location_show_city_true(capsys): + """print_location prints the city name when show_city is True.""" + helper.print_location("Santa Cruz", True) + assert "Santa Cruz" in capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# print_ocean_data +# --------------------------------------------------------------------------- + + +def test_print_ocean_data_prints_enabled_fields(capsys): + """print_ocean_data prints only the fields whose flag is True.""" + arguments_dict = { + "show_uv": True, + "show_past_uv": False, + "show_height": True, + "show_height_history": False, + "show_direction": False, + "show_direction_history": False, + "show_period": False, + "show_period_history": False, + "show_air_temp": False, + "show_wind_speed": False, + "show_wind_direction": False, + "show_rain_sum": False, + "show_precipitation_prob": False, + "show_cloud_cover": False, + "show_visibility": False, + } + ocean_data_dict = {"UV Index": 5, "Height": 3.5} + helper.print_ocean_data(arguments_dict, ocean_data_dict) + out = capsys.readouterr().out + assert "UV index: 5" in out + assert "Wave Height: 3.5" in out + + +# --------------------------------------------------------------------------- +# print_forecast +# --------------------------------------------------------------------------- + + +def test_print_forecast_renders_float_values(capsys): + """print_forecast prints rounded float values for enabled fields.""" + ocean = { + "forecast_days": 1, + "decimal": 1, + "show_date": False, + "show_uv": False, + "show_height": True, + "show_direction": False, + "show_period": False, + "show_air_temp": False, + "show_rain_sum": False, + "show_precipitation_prob": False, + "show_wind_speed": False, + "show_wind_direction": False, + } + forecast = {"wave_height_max": [3.567]} + helper.print_forecast(ocean, forecast) + assert "Wave Height: 3.6" in capsys.readouterr().out + + +def test_print_forecast_falls_back_on_type_error(capsys): + """print_forecast prints the raw value when float() raises TypeError.""" + ocean = { + "forecast_days": 1, + "decimal": 1, + "show_date": False, + "show_uv": True, + "show_height": False, + "show_direction": False, + "show_period": False, + "show_air_temp": False, + "show_rain_sum": False, + "show_precipitation_prob": False, + "show_wind_speed": False, + "show_wind_direction": False, + } + forecast = {"uv_index_max": [None]} # float(None) raises TypeError + helper.print_forecast(ocean, forecast) + assert "UV Index: None" in capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# json_output +# --------------------------------------------------------------------------- + + +def test_json_output_prints_when_print_output_true(capsys): + """json_output prints JSON to stdout and returns the original dict.""" + data = {"key": "value"} + result = helper.json_output(data) + assert result is data + assert '"key": "value"' in capsys.readouterr().out + + +def test_json_output_silent_when_print_output_false(capsys): + """json_output returns the dict silently when print_output=False.""" + data = {"key": "value"} + result = helper.json_output(data, print_output=False) + assert result is data + assert not capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# print_outputs +# --------------------------------------------------------------------------- + + +def test_print_outputs_no_ocean_data(mocker, capsys): + """print_outputs shows 'No ocean data' when Height is 'No data'.""" + ocean_data = {"Height": "No data", "Lat": 36.97, "Long": -122.03} + arguments = { + **helper.DEFAULT_ARGUMENTS, + "lat": 36.97, + "long": -122.03, + "city": "Santa Cruz", + "decimal": 1, + "forecast_days": 0, + "color": "blue", + } + mocker.patch("src.api.forecast", return_value={}) + helper.print_outputs(ocean_data, arguments, "", (None, "")) + assert "No ocean data at this location." in capsys.readouterr().out + + +def test_print_outputs_valid_data(mocker, capsys): + """print_outputs renders location, wave art, and ocean fields.""" + ocean_data = { + "Height": 3.0, + "Lat": 36.97, + "Long": -122.03, + "UV Index": 5, + "Swell Direction": 270, + "Period": 12, + } + arguments = { + **helper.DEFAULT_ARGUMENTS, + "lat": 36.97, + "long": -122.03, + "city": "Santa Cruz", + "decimal": 1, + "forecast_days": 0, + "color": "blue", + } + mocker.patch("src.api.forecast", return_value={}) + helper.print_outputs(ocean_data, arguments, "", (None, "")) + out = capsys.readouterr().out + assert "Santa Cruz" in out + assert "Wave Height: 3.0" in out + + +def test_print_outputs_with_gpt(mocker, capsys): + """print_outputs calls print_gpt and prints its response when gpt=True.""" + ocean_data = { + "Height": 3.0, + "Lat": 36.97, + "Long": -122.03, + "UV Index": 5, + "Swell Direction": 270, + "Period": 12, + } + arguments = { + **helper.DEFAULT_ARGUMENTS, + "lat": 36.97, + "long": -122.03, + "city": "Santa Cruz", + "decimal": 1, + "forecast_days": 0, + "color": "blue", + "gpt": True, + } + mocker.patch("src.api.forecast", return_value={}) + mocker.patch("src.helper.print_gpt", return_value="GPT says: go surf!") + + result = helper.print_outputs(ocean_data, arguments, "prompt", (None, "")) + + assert result == "GPT says: go surf!" + assert "GPT says: go surf!" in capsys.readouterr().out + + +# --------------------------------------------------------------------------- +# set_location +# --------------------------------------------------------------------------- + + +def test_set_location_unpacks_dict(): + """set_location returns (city, lat, long) from the location dict.""" + location = {"city": "Santa Cruz", "lat": _LAT, "long": _LONG} + city, lat, long = helper.set_location(location) + assert city == "Santa Cruz" + assert lat == _LAT + assert long == _LONG + + +# --------------------------------------------------------------------------- +# print_gpt +# --------------------------------------------------------------------------- + + +def test_print_gpt_uses_openai_when_key_is_long_enough(mocker): + """print_gpt calls openai_gpt when the API key is at least 5 chars.""" + surf_data = { + "Location": "Santa Cruz", + "Height": "3", + "Swell Direction": "270", + "Period": "12", + "Unit": "ft", + } + mock_openai = mocker.patch( + "src.helper.gpt.openai_gpt", return_value="openai response" + ) + result = helper.print_gpt( + surf_data, "any prompt", ("sk-validkey", "gpt-4") + ) + assert result == "openai response" + mock_openai.assert_called_once() diff --git a/tests/test_send_email.py b/tests/test_send_email.py new file mode 100644 index 0000000..3863f4e --- /dev/null +++ b/tests/test_send_email.py @@ -0,0 +1,70 @@ +""" +Tests for send_email.py +""" + +import subprocess +from unittest.mock import MagicMock + +from src.send_email import send_user_email + + +def _env_defaults(): + return { + "EMAIL": "sender@example.com", + "EMAIL_RECEIVER": "receiver@example.com", + "SUBJECT": "Surf Report", + "COMMAND": "localhost:8000", + "SMTP_SERVER": "smtp.example.com", + "SMTP_PORT": 587, + "EMAIL_PW": "secret", + } + + +def _patch_env(mocker, **overrides): + env = _env_defaults() + env.update(overrides) + return mocker.patch( + "src.send_email.EmailSettings", return_value=MagicMock(**env) + ) + + +def test_send_email_success(mocker): + """send_user_email fetches the report via curl and sends the email.""" + _patch_env(mocker) + + mock_result = MagicMock() + mock_result.stdout = "Surf height: 3ft" + mocker.patch("subprocess.run", return_value=mock_result) + + mock_smtp = MagicMock() + mock_smtp_cls = mocker.patch("smtplib.SMTP", return_value=mock_smtp) + mock_smtp.__enter__ = lambda s: s + mock_smtp.__exit__ = MagicMock(return_value=False) + + send_user_email() + + mock_smtp_cls.assert_called_once() + mock_smtp.sendmail.assert_called_once() + + +def test_send_email_curl_failure_uses_fallback_body(mocker): + """send_user_email falls back to an error message when curl fails.""" + _patch_env(mocker) + + mocker.patch( + "subprocess.run", + side_effect=subprocess.CalledProcessError(1, "curl", stderr="error"), + ) + + mock_smtp = MagicMock() + mocker.patch("smtplib.SMTP", return_value=mock_smtp) + mock_smtp.__enter__ = lambda s: s + mock_smtp.__exit__ = MagicMock(return_value=False) + + # Should not raise; the fallback body is used instead + send_user_email() + + mock_smtp.sendmail.assert_called_once() + # The email body should contain the fallback message + call_args = mock_smtp.sendmail.call_args[0] + assert "Failed to fetch surf report." in call_args[2] diff --git a/tests/test_server.py b/tests/test_server.py index 5e9c6a4..406833f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,30 +1,43 @@ """ -QA tests for server.py -Make sure pytest is installed: pip install pytest -Run pytest: pytest +Tests for server.py """ +import subprocess +from http import HTTPStatus +from unittest.mock import patch -# TODO: fix broken test +from src.server import create_app +from src.settings import ServerSettings -# def test_routes(): -# """ -# Test that the routes are able to be retrieved -# /home, /help, / -# When a page is requested (GET) -# THEN check if the response is valid (200) -# """ -# env = settings.ServerSettings() -# flask_app = create_app(env) -# OK = 200 -# # Create a test client using the Flask application configured for testing -# with flask_app.test_client() as test_client: -# response_help = test_client.get("/help") -# assert response_help.status_code == OK +def _make_app(): + return create_app(ServerSettings()) -# response_home = test_client.get("/home") -# assert response_home.status_code == OK -# response_root = test_client.get("/") -# assert response_root.status_code == OK +def test_serve_index_returns_200(monkeypatch): + """GET /home renders the index template and returns 200.""" + app = _make_app() + with patch("src.server.render_template", return_value="home"): + resp = app.test_client().get("/home") + assert resp.status_code == HTTPStatus.OK + assert b"home" in resp.data + + +def test_serve_script_returns_200(monkeypatch): + """GET /script.js serves the JavaScript file and returns 200.""" + app = _make_app() + with patch("src.server.send_file", return_value="console.log('ok')"): + resp = app.test_client().get("/script.js") + assert resp.status_code == HTTPStatus.OK + + +def test_root_subprocess_error_returns_500(monkeypatch): + """GET / returns 500 and logs the error when the subprocess fails.""" + app = _make_app() + + def fail_run(*args, **kwargs): + raise subprocess.CalledProcessError(1, "cmd", stderr="boom") + + monkeypatch.setattr(subprocess, "run", fail_run) + resp = app.test_client().get("/") + assert resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR diff --git a/tests/test_streamlit_helper.py b/tests/test_streamlit_helper.py new file mode 100644 index 0000000..2ab891b --- /dev/null +++ b/tests/test_streamlit_helper.py @@ -0,0 +1,124 @@ +""" +Tests for streamlit_helper.py +""" + +import folium +import pandas as pd + +from src import streamlit_helper + +_LAT = 36.97 +_LONG = -122.03 +_EXPECTED_ROWS = 2 + + +def test_extra_args_without_gpt(): + """extra_args returns an empty string when GPT is disabled.""" + assert not streamlit_helper.extra_args(gpt=False) + + +def test_extra_args_with_gpt(): + """extra_args appends ',gpt' when GPT is enabled.""" + assert streamlit_helper.extra_args(gpt=True) == ",gpt" + + +def test_get_report_returns_report_dict_and_coords(mocker): + """get_report calls cli.run and unpacks the result correctly.""" + ocean_data = {"Lat": _LAT, "Long": _LONG, "Height": 3.0} + mocker.patch("src.cli.SurfReport") + mocker.patch( + "src.streamlit_helper.cli.run", + return_value=(ocean_data, "gpt response"), + ) + + report_dict, gpt_response, lat, long = streamlit_helper.get_report( + "santa_cruz", "" + ) + + assert report_dict is ocean_data + assert gpt_response == "gpt response" + assert lat == _LAT + assert long == _LONG + + +def test_get_report_appends_extra_args(mocker): + """get_report forwards extra_args to cli.run correctly.""" + ocean_data = {"Lat": 1.0, "Long": 2.0} + mock_run = mocker.patch( + "src.streamlit_helper.cli.run", + return_value=(ocean_data, None), + ) + + streamlit_helper.get_report("santa_cruz", ",gpt") + + call_kwargs = mock_run.call_args[1] + assert "location=santa_cruz,gpt" in call_kwargs["args"][1] + + +def test_map_data_returns_folium_map(): + """map_data returns a folium Map centred on the given coordinates.""" + result = streamlit_helper.map_data(_LAT, _LONG) + assert isinstance(result, folium.Map) + + +def test_graph_data_height_period_returns_dataframe(): + """graph_data returns a DataFrame with date/heights/periods columns.""" + report_dict = { + "Forecast": [ + { + "date": "2024-01-01", + "surf height": 3.0, + "swell period": 12.0, + "swell direction": 270.0, + }, + { + "date": "2024-01-02", + "surf height": 4.0, + "swell period": 10.0, + "swell direction": 260.0, + }, + ] + } + + df = streamlit_helper.graph_data( + report_dict, graph_type="Height/Period :ocean:" + ) + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == ["date", "heights", "periods"] + assert len(df) == _EXPECTED_ROWS + + +def test_graph_data_direction_returns_dataframe(): + """graph_data returns a DataFrame with date/directions columns.""" + report_dict = { + "Forecast": [ + { + "date": "2024-01-01", + "surf height": 3.0, + "swell period": 12.0, + "swell direction": 270.0, + }, + ] + } + + df = streamlit_helper.graph_data(report_dict, graph_type="Direction") + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == ["date", "directions"] + + +def test_graph_data_none_graph_type_defaults_to_height_period(): + """graph_data with graph_type=None defaults to Height/Period layout.""" + report_dict = { + "Forecast": [ + { + "date": "2024-01-01", + "surf height": 3.0, + "swell period": 12.0, + "swell direction": 270.0, + }, + ] + } + + df = streamlit_helper.graph_data(report_dict, graph_type=None) + assert "heights" in df.columns + assert "periods" in df.columns