Skip to content
Snippets Groups Projects
Select Git revision
  • 92f13a83a9247162369c1f41ac8104a4c5b9435a
  • main default protected
2 results

ARIMA_forecast_Code

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ARIMA_forecast_Code 2.69 KiB
    import numpy as np
    from netCDF4 import Dataset
    import matplotlib.pyplot as plt
    from statsmodels.tsa.arima.model import ARIMA
    
    # Function to calculate the count for each year
    def calculate_threshold_count_for_year(nc_file):
        with Dataset(nc_file, 'r') as nc:
            temperature_data = nc.variables['sst'][:]
        
        # Calculate the spatial average for each day
        daily_spatial_averages = np.mean(temperature_data, axis=(1, 2))
        
        # Calculate the 95th percentile of temperatures for the year
        percentile_95 = np.percentile(daily_spatial_averages, 95)
        
        # Identify periods where the 95th percentile threshold is crossed for 3 or more days consecutively
        consecutive_days_count = 0
        threshold_crossed = False
        
        for temp in daily_spatial_averages:
            if temp > percentile_95:
                consecutive_days_count += 1
                if consecutive_days_count >= 3:
                    threshold_crossed = True
            else:
                consecutive_days_count = 0
        
        # Return the result for the current year
        return threshold_crossed, consecutive_days_count
    
    # Per-year results, consumed by the ARIMA fit and the plot further down.
    years, heatwave_counts = [], []

    # 1981 through 2022 inclusive (the range end is exclusive).
    for year in range(1981, 2023):
        nc_file = f'output_file_TWCPO_{year}_dailymean.nc'
        threshold_crossed, count = calculate_threshold_count_for_year(nc_file)

        years.append(year)
        heatwave_counts.append(count)

        # Report this year's outcome on stdout.
        print(
            f"Number of 3-day or longer periods where the spatial average is above 95th percentile for {year}: {count}"
            if threshold_crossed
            else f"No periods with temperatures above 95th percentile for 3 or more consecutive days for {year}."
        )
    
    # Fit ARIMA models to the annual event counts and forecast ahead.
    data = np.array(heatwave_counts, dtype=float)

    # Number of future years to forecast. It also sizes the forecast x-axis, so the
    # plotted forecast always lines up with the number of predicted values.
    forecast_steps = 10
    # Forecast years start immediately after the last observed year.
    forecast_years = np.arange(years[-1] + 1, years[-1] + 1 + forecast_steps)

    # ARIMA(5,1,0): AR order 5, first differencing, no MA terms.
    model_1 = ARIMA(data, order=(5, 1, 0))
    fit_model_1 = model_1.fit()
    forecast_1 = fit_model_1.get_forecast(steps=forecast_steps)

    # ARIMA(10,1,0): AR order 10, first differencing, no MA terms.
    model_2 = ARIMA(data, order=(10, 1, 0))
    fit_model_2 = model_2.fit()
    forecast_2 = fit_model_2.get_forecast(steps=forecast_steps)

    # Plot the observed series together with both model forecasts.
    plt.plot(years, heatwave_counts, marker='o', label='Observed')
    plt.plot(forecast_years, forecast_1.predicted_mean, color='blue', linestyle='dashed', marker='o', label='ARIMA(5,1,0) Forecast')
    plt.plot(forecast_years, forecast_2.predicted_mean, color='green', linestyle='dashed', marker='o', label='ARIMA(10,1,0) Forecast')
    plt.xlabel('Year')
    plt.ylabel('Number of Heatwaves')
    plt.title('Number of Marine Heatwaves (1981-2022) with Forecasts for the next decade using ARIMA')
    # Y ticks every 3 events across the observed range.
    plt.yticks(np.arange(min(heatwave_counts), max(heatwave_counts) + 1, 3))
    plt.grid(False)
    plt.legend()
    plt.show()