Spaces:
Configuration error
Configuration error
import streamlit as st | |
from utils.load_dataset import load_datasets | |
from utils.load_tasks import load_tasks | |
from utils.load_models import load_models | |
from trainer import train_estimtator | |
from datetime import datetime | |
import logging | |
logger = logging.getLogger(__name__) | |
def _query_value(params, key, default):
    """Return the first value stored under *key* in *params*, or *default*.

    Streamlit hands query parameters over as lists of strings; a bare scalar
    is accepted as well so the helper works on either shape.
    """
    raw = params.get(key, default)
    if isinstance(raw, (list, tuple)):
        return raw[0] if raw else default
    return raw


def _positive_int(raw, default):
    """Coerce *raw* to an ``int`` >= 1, returning *default* when that fails."""
    try:
        number = int(raw)
    except (TypeError, ValueError):
        return default
    return number if number >= 1 else default


def main():
    """Render the training form and start a SageMaker job when requested.

    The model, dataset, task and hyperparameter choices are mirrored into the
    page's query string so that a URL reproduces the form state. SageMaker
    settings and credentials are collected into a separate ``config`` dict
    that is never written to the URL. Pressing the button passes both dicts
    to ``train_estimtator``.
    """
    parameter = st.experimental_get_query_params()

    # Query parameters are lists; a missing selection falls back to the
    # "none" placeholder, which is then offered as the first select option.
    for key in ("model_name_or_path", "dataset", "task"):
        parameter.setdefault(key, ["none"])

    # Hyperparameter seeds come from the URL when present and valid, so that a
    # shared link restores them just like model/dataset/task.
    epochs_seed = _positive_int(_query_value(parameter, "epochs", 3), 3)
    learning_rate_seed = str(_query_value(parameter, "learning_rate", 5e-5)) or str(5e-5)
    train_batch_seed = _positive_int(
        _query_value(parameter, "per_device_train_batch_size", 8), 8
    )
    eval_batch_seed = _positive_int(
        _query_value(parameter, "per_device_eval_batch_size", 8), 8
    )

    dataset_list = load_datasets()
    task_list = load_tasks()
    model_list = load_models()

    st.header("Hugging Face model & dataset")
    col1, col2 = st.beta_columns(2)
    # The value taken from the URL is placed first so it is the preselected option.
    parameter["model_name_or_path"] = col1.selectbox(
        "Model ID:", parameter["model_name_or_path"] + model_list
    )
    parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
    parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)
    # Masked input: the token is a credential. It is added to ``parameter``
    # only after the query string has been written, so it never reaches the URL.
    use_auth_token = col2.text_input(
        "HF auth token to upload your model:", type="password", help="api_xxxxx"
    )

    ### hyperparameter
    my_expander = st.beta_expander("Hyperparameters")
    col1, col2 = my_expander.beta_columns(2)
    # ``number_input``'s second positional argument is ``min_value``, not the
    # default, so the default is passed explicitly as ``value`` and the lower
    # bound is 1 (previously values below 3 epochs / batch size 8 were rejected).
    parameter["epochs"] = col1.number_input("Epoch", min_value=1, value=epochs_seed)
    parameter["learning_rate"] = col2.text_input("Learning Rate", learning_rate_seed)
    parameter["per_device_train_batch_size"] = col1.number_input(
        "Training Batch Size", min_value=1, value=train_batch_seed
    )
    parameter["per_device_eval_batch_size"] = col2.number_input(
        "Eval Batch Size", min_value=1, value=eval_batch_seed
    )

    # One write is enough: it publishes the final state of every URL-backed field.
    st.experimental_set_query_params(**parameter)

    st.markdown("---")
    st.header("Amazon Sagemaker configuration")
    config = {}
    # After the select box the model id is always a single value, never a list.
    job_date = datetime.today().date()
    config["job_name"] = st.text_input(
        "model name",
        f"{parameter['model_name_or_path']}-job-{job_date}",
    )
    col1, col2 = st.beta_columns(2)
    config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for sagemaker job")
    config["instance_type"] = col2.selectbox(
        "Instance type",
        [
            "single-gpu | ml.p3.2xlarge",
            "multi-gpu | ml.p3.16xlarge",
        ],
    )
    config["region"] = col1.selectbox(
        "AWS Region",
        # "us-east-1" used to be listed twice; the second entry is now "us-east-2".
        ["eu-central-1", "eu-west-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"],
    )
    config["instance_count"] = col2.number_input("Instance count", min_value=1, value=1)
    config["use_spot"] = col1.selectbox("use spot instances", [False, True])
    config["distributed"] = col2.selectbox("distributed training", [False, True])

    st.markdown("---")
    st.header("Credentials")
    # sagemaker config
    col1, col2 = st.beta_columns(2)
    config["aws_access_key_id"] = col1.text_input("Aws Access Key ID")
    # Masked so the secret is not shown on screen while typing.
    config["aws_secret_accesskey"] = col2.text_input("Aws Secret Access Key", type="password")

    if use_auth_token:
        parameter["use_auth_token"] = use_auth_token

    if st.button("Start training on SageMaker"):
        train_estimtator(parameter, config)
# Script entry point: build the page only when this file is executed directly
# (for example by `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()