# llama2-13b-cpt / app.py
# Last commit: 85f1687 ("Update app.py") by suolyer.
import gradio as gr
import plotly.express as px
import numpy as np
import wandb
# W&B public-API client, created once when the module is imported and
# reused by get_plot() on every refresh.
api = wandb.Api()
# W&B run whose training loss is displayed, as "entity/project/run_id".
RUN_PATH = "fengshenbang/llama2_13b_cpt_v1/kakfv1ab"
# History key of the metric being plotted.
LOSS_METRIC = "train/lm_loss"


def extract_curve(history, metric=LOSS_METRIC, step_column="_step"):
    """Return ``(steps, values)`` arrays for *metric* from a run-history frame.

    Rows where *metric* is NaN are dropped, since a history row need not
    contain every logged key. The x values come from *step_column* when that
    column is present, so a down-sampled history still lines up with the real
    training steps. Otherwise the row position (0, 1, 2, ...) is used.

    Args:
        history: pandas DataFrame, e.g. the result of ``Run.history()``.
        metric: Column to extract as the y values.
        step_column: Column holding the step index used as the x values.

    Returns:
        Tuple of two 1-D numpy arrays of equal length. Both are empty when
        *history* is empty or has no *metric* column.
    """
    if history.empty or metric not in history.columns:
        empty = np.array([], dtype=float)
        return empty, empty.copy()

    rows = history.dropna(subset=[metric])
    values = rows[metric].to_numpy(dtype=float)
    if step_column in rows.columns:
        steps = rows[step_column].to_numpy()
    else:
        steps = np.arange(len(values))
    return steps, values


def get_plot(period=1, run_path=RUN_PATH, metric=LOSS_METRIC):
    """Fetch *metric* for the W&B run at *run_path* and return a line figure.

    Args:
        period: Unused; kept so existing callers that pass it keep working.
            The refresh cadence is set by the ``every=`` argument of
            ``demo.load``.
        run_path: W&B run path, ``"entity/project/run_id"``.
        metric: History key to plot.

    Returns:
        A Plotly Express line figure with the training step on x and
        *metric* on y. The figure is empty (instead of raising ``KeyError``)
        when the run has not logged *metric* yet.
    """
    run = api.run(run_path)
    # Request only the plotted key (plus the step column) rather than a sample
    # across every logged key. The result may still be down-sampled, so the
    # curve is drawn against `_step`, not the row position.
    history = run.history(keys=[metric])
    steps, values = extract_curve(history, metric)
    return px.line(x=steps, y=values, labels={"x": "step", "y": metric})
# Refresh interval, in seconds, for the loss plot. Both the plot label and
# the `every=` polling interval below are derived from it, so they cannot
# disagree.
REFRESH_SECONDS = 1

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("Ziya-LLaMA2-CPT/train/lm_loss")
            plot = gr.Plot(label=f"Plot (updates every {REFRESH_SECONDS} s)")

    # Call get_plot when the page loads and then every REFRESH_SECONDS,
    # writing the returned figure into `plot`.
    dep = demo.load(get_plot, None, plot, every=REFRESH_SECONDS)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start serving the app.
    # NOTE(review): older Gradio releases require the queue for the periodic
    # `every=` refresh used by demo.load; confirm before removing it.
    queued_demo = demo.queue()
    queued_demo.launch()