import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Cache the model/tokenizer so each size is loaded once per session,
# not on every Streamlit rerun
@st.cache_resource
def load_model(model_size: str = "7B"):
    """
    Load model and tokenizer based on size selection.
    Note: you'll need to replace these with actual Hugging Face model IDs.
    """
    model_map = {
        "0.5B": "Qwen/Qwen2.5-Coder-0.5B",
        "1.5B": "Qwen/Qwen2.5-Coder-1.5B",
        "7B": "Qwen/Qwen2.5-Coder-7B",
        # ... add other model sizes as needed
    }
    # Sizes not present in model_map fall back to the 7B checkpoint
    model_id = model_map.get(model_size, "Qwen/Qwen2.5-Coder-7B")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # half precision to reduce memory use
        device_map="auto",          # place weights on available GPU(s)/CPU
        trust_remote_code=True,
    )
    return model, tokenizer
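
# A minimal usage sketch for load_model (an assumption for illustration:
# the 0.5B checkpoint fits in available memory; larger sizes need a
# correspondingly larger GPU):
#
#   model, tokenizer = load_model("0.5B")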
def process_query(query: str, model_size: str = "7B") -> str:
    """
    Process a single query and return the generated response.
    """
    if not query:
        return ""
    try:
        model, tokenizer = load_model(model_size)
        # Tokenize the prompt and move it to the model's device
        inputs = tokenizer(query, return_tensors="pt").to(model.device)
        # Some tokenizers define no pad token; fall back to EOS
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        # Generate a response; do_sample=True is required for
        # temperature/top_p to take effect
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=pad_id,
            )
        # Decode only the newly generated tokens rather than stripping the
        # echoed prompt with str.replace, which can corrupt the output if
        # the prompt text recurs in the completion
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error: {e}"
def main():
    st.title("Qwen2.5-Coder Interface")
    # Model size selection
    model_size = st.radio(
        "Select Model Size:",
        options=["0.5B", "1.5B", "3B", "7B", "14B", "32B"],
        index=5,  # Default to 32B (last option)
    )
    # Input text area
    query = st.text_area(
        "Input",
        placeholder="Enter your query here...",
        height=150,
    )
    # Generate button
    if st.button("Generate"):
        if query:
            with st.spinner("Generating response..."):
                response = process_query(query, model_size)
            st.text_area("Output", value=response, height=300)
        else:
            st.warning("Please enter a query first.")


if __name__ == "__main__":
    main()
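
# To run this app (standard Streamlit usage, assuming streamlit, torch,
# and transformers are installed and this file is saved as app.py):
#
#   streamlit run app.py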