Sujatha commited on
Commit
4b771a1
·
verified ·
1 Parent(s): eaea443

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from transformers import TabularTransformerForSequenceClassification, TabularTransformerConfig
4
+ from transformers import Trainer, TrainingArguments
5
+ from datasets import Dataset
6
+ import torch
7
+
8
+ # Sample Data
9
+ data = {
10
+ 'feature1': [0.5, 0.3, 0.7, 0.2],
11
+ 'feature2': [1, 0, 1, 1],
12
+ 'feature3': [0.6, 0.1, 0.8, 0.4],
13
+ 'label': [0, 1, 0, 1] # Binary classification
14
+ }
15
+ df = pd.DataFrame(data)
16
+ dataset = Dataset.from_pandas(df)
17
+
18
+ # Configure the Model
19
+ config = TabularTransformerConfig(
20
+ num_labels=2, # Binary classification
21
+ numerical_features=['feature1', 'feature2', 'feature3']
22
+ )
23
+ model = TabularTransformerForSequenceClassification(config)
24
+
25
+ # Define Training Arguments
26
+ training_args = TrainingArguments(
27
+ output_dir="./results",
28
+ evaluation_strategy="epoch",
29
+ learning_rate=2e-5,
30
+ per_device_train_batch_size=4,
31
+ num_train_epochs=3
32
+ )
33
+
34
+ # Define Trainer
35
+ trainer = Trainer(
36
+ model=model,
37
+ args=training_args,
38
+ train_dataset=dataset,
39
+ eval_dataset=dataset
40
+ )
41
+
42
+ # Train the model
43
+ trainer.train()
44
+
45
+ # Define Inference Function
46
+ def classify(feature1, feature2, feature3):
47
+ input_data = {'feature1': feature1, 'feature2': feature2, 'feature3': feature3}
48
+ input_df = pd.DataFrame([input_data])
49
+ test_dataset = Dataset.from_pandas(input_df)
50
+ with torch.no_grad():
51
+ logits = model(**test_dataset[:][0]).logits
52
+ prediction = torch.argmax(logits, dim=1).item()
53
+ return "Class 1" if prediction == 1 else "Class 0"
54
+
55
+ # Gradio Interface
56
+ iface = gr.Interface(
57
+ fn=classify,
58
+ inputs=[
59
+ gr.inputs.Slider(0, 1, step=0.1, label="Feature 1"),
60
+ gr.inputs.Slider(0, 1, step=0.1, label="Feature 2"),
61
+ gr.inputs.Slider(0, 1, step=0.1, label="Feature 3")
62
+ ],
63
+ outputs="text",
64
+ title="Tabular Classification with Hugging Face",
65
+ description="Classify entries based on tabular data"
66
+ )
67
+
68
+ iface.launch()