Sujatha commited on
Commit
0901b66
·
verified ·
1 Parent(s): b47b2b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -13,21 +13,33 @@ data = {
13
  }
14
  df = pd.DataFrame(data)
15
 
16
- # Configure pytorch_tabular
17
  data_config = DataConfig(
18
  target=["target"],
19
- continuous_cols=["feature1", "feature2", "feature3"]
 
20
  )
21
- model_config = CategoryEmbeddingModelConfig(task="classification") # No `num_classes`
22
- trainer_config = TrainerConfig(max_epochs=10)
23
-
24
- # Initialize and train model
25
- tabular_model = TabularModel(
26
- data_config=data_config,
27
- model_config=model_config,
28
- trainer_config=trainer_config
 
29
  )
30
- tabular_model.fit(df)
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Define Inference Function
33
  def classify(feature1, feature2, feature3):
@@ -55,4 +67,3 @@ iface = gr.Interface(
55
  # Launch with additional server settings for Hugging Face Spaces
56
  print("Launching Gradio Interface...")
57
  iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
58
-
 
13
  }
14
  df = pd.DataFrame(data)
15
 
16
+ # Ensure all configurations are set correctly
17
  data_config = DataConfig(
18
  target=["target"],
19
+ continuous_cols=["feature1", "feature2", "feature3"],
20
+ task="classification"
21
  )
22
+
23
+ model_config = CategoryEmbeddingModelConfig(
24
+ task="classification",
25
+ layers="64-64", # Example hidden layer sizes
26
+ learning_rate=1e-3
27
+ )
28
+
29
+ trainer_config = TrainerConfig(
30
+ max_epochs=10
31
  )
32
+
33
+ # Initialize and train the model
34
+ try:
35
+ tabular_model = TabularModel(
36
+ data_config=data_config,
37
+ model_config=model_config,
38
+ trainer_config=trainer_config
39
+ )
40
+ tabular_model.fit(df)
41
+ except ValueError as e:
42
+ print(f"Error initializing TabularModel: {e}")
43
 
44
  # Define Inference Function
45
  def classify(feature1, feature2, feature3):
 
67
  # Launch with additional server settings for Hugging Face Spaces
68
  print("Launching Gradio Interface...")
69
  iface.launch(server_name="0.0.0.0", server_port=7860, share=True)