import streamlit as st
from utils.load_dataset import load_datasets
from utils.load_tasks import load_tasks
from utils.load_models import load_models
from trainer import train_estimtator
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
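
# Streamlit UI that collects a Hugging Face model, dataset, task and training
# hyperparameters plus Amazon SageMaker settings, then hands both dictionaries
# to train_estimtator() from the local trainer module to launch a training job.
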
def main():
    # Read the current URL query parameters and fill in defaults for anything missing.
    parameter = st.experimental_get_query_params()
    parameter["model_name_or_path"] = parameter.get("model_name_or_path", ["none"])
    parameter["dataset"] = parameter.get("dataset", ["none"])
    parameter["task"] = parameter.get("task", ["none"])
    ### hyperparameter defaults
    parameter["epochs"] = parameter.get("epochs", [3])
    parameter["learning_rate"] = parameter.get("learning_rate", [5e-5])
    parameter["per_device_train_batch_size"] = parameter.get("per_device_train_batch_size", [8])
    parameter["per_device_eval_batch_size"] = parameter.get("per_device_eval_batch_size", [8])
    st.experimental_set_query_params(**parameter)
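
    # st.experimental_set_query_params mirrors the current selections back into
    # the URL so that a configured form can be bookmarked or shared as a link.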
    # Options shown in the model, dataset and task select boxes.
    dataset_list = load_datasets()
    task_list = load_tasks()
    model_list = load_models()

    st.header("Hugging Face model & dataset")
    col1, col2 = st.beta_columns(2)
    parameter["model_name_or_path"] = col1.selectbox("Model ID:", parameter["model_name_or_path"] + model_list)
    st.experimental_set_query_params(**parameter)
    parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
    st.experimental_set_query_params(**parameter)
    parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)
    st.experimental_set_query_params(**parameter)
    use_auth_token = col2.text_input("HF auth token to upload your model:", help="api_xxxxx")

    my_expander = st.beta_expander("Hyperparameters")
    col1, col2 = my_expander.beta_columns(2)
    parameter["epochs"] = col1.number_input("Epoch", 3)
    st.experimental_set_query_params(**parameter)
    parameter["learning_rate"] = col2.text_input("Learning Rate", 5e-5)
    st.experimental_set_query_params(**parameter)
    parameter["per_device_train_batch_size"] = col1.number_input("Training Batch Size", 8)
    st.experimental_set_query_params(**parameter)
    parameter["per_device_eval_batch_size"] = col2.number_input("Eval Batch Size", 8)
    st.experimental_set_query_params(**parameter)

    st.markdown("---")
    st.header("Amazon SageMaker configuration")
    config = {}
    config["job_name"] = st.text_input(
        "Model name",
        f"{parameter['model_name_or_path'][0] if isinstance(parameter['model_name_or_path'], list) else parameter['model_name_or_path']}-job-{str(datetime.today()).split()[0]}",
    )
    col1, col2 = st.beta_columns(2)
    config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for the SageMaker job")
    config["instance_type"] = col2.selectbox(
        "Instance type",
        [
            "single-gpu | ml.p3.2xlarge",
            "multi-gpu | ml.p3.16xlarge",
        ],
    )
    config["region"] = col1.selectbox(
        "AWS Region",
        ["eu-central-1", "eu-west-1", "us-east-1", "us-west-1", "us-west-2"],
    )
    config["instance_count"] = col2.number_input("Instance count", 1)
    config["use_spot"] = col1.selectbox("Use spot instances", [False, True])
    config["distributed"] = col2.selectbox("Distributed training", [False, True])

    st.markdown("---")
    st.header("Credentials")
    # AWS credentials for the SageMaker job
    col1, col2 = st.beta_columns(2)
    config["aws_access_key_id"] = col1.text_input("AWS Access Key ID")
    config["aws_secret_accesskey"] = col2.text_input("AWS Secret Access Key")
    if use_auth_token:
        parameter["use_auth_token"] = use_auth_token

    if st.button("Start training on SageMaker"):
        train_estimtator(parameter, config)


if __name__ == "__main__":
    main()
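
# Usage sketch (assuming this file is saved as app.py at the root of the Space):
#   streamlit run app.py
# Selections can also be pre-filled through URL query parameters, for example
# (hypothetical values): ?model_name_or_path=distilbert-base-uncased&dataset=imdb&task=text-classification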