# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces: / Runtime error / Runtime error") captured along with the app
# source; kept here as a comment so the file parses.
import gradio as gr
import pandas as pd
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

# Toy training set: three continuous features and a binary target.
data = {
    'feature1': [0.5, 0.3, 0.7, 0.2],
    'feature2': [1, 0, 1, 1],
    'feature3': [0.6, 0.1, 0.8, 0.4],
    'target': [0, 1, 0, 1]  # Binary classification target
}
df = pd.DataFrame(data)

# DataConfig only describes the columns; the task is declared on the model
# config below (DataConfig has no `task` parameter — passing one raises).
data_config = DataConfig(
    target=["target"],
    continuous_cols=["feature1", "feature2", "feature3"],
)
model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="64-64",  # Example hidden layer sizes
    learning_rate=1e-3,
)
trainer_config = TrainerConfig(
    max_epochs=10,
)

# Initialize and train the model.  Pre-bind `tabular_model` and catch any
# exception so a training failure is logged instead of leaving the name
# undefined (the original `except ValueError` missed config/fit errors and
# `classify` would then crash with NameError).
tabular_model = None
try:
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        trainer_config=trainer_config,
    )
    tabular_model.fit(train=df)
except Exception as e:
    print(f"Error initializing TabularModel: {e}")
# Inference function wired into the Gradio interface below.
def classify(feature1, feature2, feature3):
    """Predict the class label for one row of the three continuous features.

    Returns "Class 1" or "Class 0", or an error message when the model
    failed to train at startup (so the UI degrades gracefully instead of
    raising NameError/AttributeError on the missing global).
    """
    # Look the model up defensively: training at import time may have failed.
    model = globals().get("tabular_model")
    if model is None:
        return "Model unavailable: training failed at startup"
    input_data = pd.DataFrame({
        "feature1": [feature1],
        "feature2": [feature2],
        "feature3": [feature3],
    })
    # predict() returns a DataFrame with a "prediction" column; take row 0.
    prediction = model.predict(input_data)["prediction"].iloc[0]
    return "Class 1" if prediction == 1 else "Class 0"
# Gradio interface.  The gr.inputs.* namespace was removed in Gradio 3.x;
# components are top-level now (gr.Slider), which is what current Hugging
# Face Spaces runtimes ship — gr.inputs.Slider raises AttributeError there.
iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Slider(0, 1, step=0.1, label="Feature 1"),
        gr.Slider(0, 1, step=0.1, label="Feature 2"),
        gr.Slider(0, 1, step=0.1, label="Feature 3"),
    ],
    outputs="text",
    title="Tabular Classification with PyTorch Tabular",
    description="Classify entries based on tabular data",
)

# Launch bound to all interfaces on the port Spaces expects.  `share=True`
# is unsupported inside Spaces (Gradio warns and ignores it), so it is dropped.
print("Launching Gradio Interface...")
iface.launch(server_name="0.0.0.0", server_port=7860)