dibend commited on
Commit
62b932f
1 Parent(s): 97fe953

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
7
  # Set FRED API Key from environment variable
8
  FRED_API_KEY = os.getenv('FRED_API_KEY')
9
 
10
- # Available FRED data series and their descriptive labels
11
  series_options = {
12
  "UNRATE": "Unemployment Rate",
13
  "GDP": "Gross Domestic Product",
@@ -17,13 +17,12 @@ series_options = {
17
  "M1SL": "M1 Money Supply",
18
  "M2SL": "M2 Money Supply",
19
  "M3SL": "M3 Money Supply",
20
- # Add more series as needed
21
  "HOUST": "Housing Starts",
22
  "PCE": "Personal Consumption Expenditures",
23
- "BAA10YM": "Moody's Baa Corporate Bond Yield Spread",
24
  }
25
 
26
- # Function to fetch data from FRED API
27
  def fetch_fred_data(series_ids):
28
  """
29
  Fetches data for a list of FRED series IDs.
@@ -33,12 +32,16 @@ def fetch_fred_data(series_ids):
33
  for series_id in series_ids:
34
  response = requests.get(
35
  f'https://api.stlouisfed.org/fred/series/observations',
36
- params={'series_id': series_id, 'api_key': FRED_API_KEY, 'file_type': 'json'}
 
 
 
 
37
  )
38
  if response.status_code == 200:
39
  observations = response.json().get('observations', [])
40
  dates = [obs['date'] for obs in observations]
41
- # Safely convert values to float, handling invalid entries
42
  values = [
43
  float(obs['value']) if obs['value'].replace('.', '', 1).isdigit() else float('nan')
44
  for obs in observations
@@ -55,22 +58,31 @@ def standardize_data(df):
55
  """
56
  return (df - df.mean()) / df.std()
57
 
58
- # Function to create 3D correlation matrix
59
  def create_3d_correlation_matrix(df):
 
 
 
 
60
  correlation_matrix = df.corr()
61
- fig = go.Figure(data=[go.Surface(z=correlation_matrix.values,
62
- x=correlation_matrix.columns,
63
- y=correlation_matrix.index)])
64
- fig.update_layout(title='3D Correlation Matrix (Standardized)',
65
- autosize=False,
66
- width=800, height=800,
67
- scene=dict(
68
- xaxis=dict(title='Variables'),
69
- yaxis=dict(title='Variables'),
70
- zaxis=dict(title='Correlation')))
 
 
 
 
 
71
  return fig
72
 
73
- # Gradio function
74
  def visualize_correlation(selected_series):
75
  # Map descriptive labels back to FRED series IDs
76
  series_ids = [series for series in series_options if series_options[series] in selected_series]
@@ -96,7 +108,7 @@ with gr.Blocks() as demo:
96
  series_selector = gr.CheckboxGroup(
97
  choices=list(series_options.values()),
98
  label="Select Economic Indicators",
99
- info="Choose one or more indicators to include in the correlation matrix.",
100
  )
101
  submit_button = gr.Button("Generate Matrix")
102
 
@@ -104,12 +116,12 @@ with gr.Blocks() as demo:
104
  plot_output = gr.Plot(label="3D Correlation Matrix")
105
  error_message = gr.Markdown("", visible=False)
106
 
107
- # Event handler
108
  submit_button.click(
109
  fn=visualize_correlation,
110
  inputs=[series_selector],
111
  outputs=[plot_output, error_message],
112
  )
113
 
114
- # Launch the app
115
  demo.launch(debug=True)
 
7
  # Set FRED API Key from environment variable
8
  FRED_API_KEY = os.getenv('FRED_API_KEY')
9
 
10
+ # List of FRED data series and their descriptive labels
11
  series_options = {
12
  "UNRATE": "Unemployment Rate",
13
  "GDP": "Gross Domestic Product",
 
17
  "M1SL": "M1 Money Supply",
18
  "M2SL": "M2 Money Supply",
19
  "M3SL": "M3 Money Supply",
 
20
  "HOUST": "Housing Starts",
21
  "PCE": "Personal Consumption Expenditures",
22
+ "BAA10YM": "Moody's Baa Corporate Bond Yield Spread"
23
  }
24
 
25
+ # Function to fetch data from the FRED API
26
  def fetch_fred_data(series_ids):
27
  """
28
  Fetches data for a list of FRED series IDs.
 
32
  for series_id in series_ids:
33
  response = requests.get(
34
  f'https://api.stlouisfed.org/fred/series/observations',
35
+ params={
36
+ 'series_id': series_id,
37
+ 'api_key': FRED_API_KEY,
38
+ 'file_type': 'json'
39
+ }
40
  )
41
  if response.status_code == 200:
42
  observations = response.json().get('observations', [])
43
  dates = [obs['date'] for obs in observations]
44
+ # Convert values to float, handling invalid entries
45
  values = [
46
  float(obs['value']) if obs['value'].replace('.', '', 1).isdigit() else float('nan')
47
  for obs in observations
 
58
  """
59
  return (df - df.mean()) / df.std()
60
 
61
+ # Function to create a responsive 3D correlation matrix
62
  def create_3d_correlation_matrix(df):
63
+ """
64
+ Creates a 3D correlation matrix graph using Plotly.
65
+ The graph will automatically adjust its size.
66
+ """
67
  correlation_matrix = df.corr()
68
+ fig = go.Figure(data=[go.Surface(
69
+ z=correlation_matrix.values,
70
+ x=correlation_matrix.columns,
71
+ y=correlation_matrix.index
72
+ )])
73
+ fig.update_layout(
74
+ title='3D Correlation Matrix (Standardized)',
75
+ autosize=True, # Enables auto-resizing
76
+ scene=dict(
77
+ xaxis=dict(title='Variables'),
78
+ yaxis=dict(title='Variables'),
79
+ zaxis=dict(title='Correlation')
80
+ ),
81
+ margin=dict(l=0, r=0, t=50, b=50) # Adjust margins for better fit
82
+ )
83
  return fig
84
 
85
+ # Gradio function to handle user interaction
86
  def visualize_correlation(selected_series):
87
  # Map descriptive labels back to FRED series IDs
88
  series_ids = [series for series in series_options if series_options[series] in selected_series]
 
108
  series_selector = gr.CheckboxGroup(
109
  choices=list(series_options.values()),
110
  label="Select Economic Indicators",
111
+ info="Choose one or more indicators to include in the correlation matrix."
112
  )
113
  submit_button = gr.Button("Generate Matrix")
114
 
 
116
  plot_output = gr.Plot(label="3D Correlation Matrix")
117
  error_message = gr.Markdown("", visible=False)
118
 
119
+ # Event handler for the submit button
120
  submit_button.click(
121
  fn=visualize_correlation,
122
  inputs=[series_selector],
123
  outputs=[plot_output, error_message],
124
  )
125
 
126
+ # Launch the Gradio app
127
  demo.launch(debug=True)