dad1909 committed on
Commit 55fe000
1 Parent(s): 493a1a4

Update app.py

Files changed (1)
  1. app.py +147 -79
app.py CHANGED
@@ -1,90 +1,158 @@
  import os
- import gradio as gr
  import torch
- from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
- import spaces
-
- # Define the model configurations
- model_configs = {
-     "CyberSentinel": {
-         "model_name": "dad1909/cybersentinal-2.0",
-         "max_seq_length": 1028,
-         "dtype": torch.float16,
-         "load_in_4bit": True
-     }
- }

- # Hugging Face token
  hf_token = os.getenv("HF_TOKEN")

- # Load the model when the application starts
- loaded_models = {}
-
- def load_model(selected_model):
-     if selected_model not in loaded_models:
-         config = model_configs[selected_model]
-         model = AutoModelForCausalLM.from_pretrained(
-             config["model_name"],
-             torch_dtype=config["dtype"],
-             device_map="auto",
-             use_auth_token=hf_token
-         )
-         tokenizer = AutoTokenizer.from_pretrained(
-             config["model_name"],
-             use_auth_token=hf_token
-         )
-         loaded_models[selected_model] = (model, tokenizer)
-     return loaded_models[selected_model]
-
- alpaca_prompts = {
-     "information": "Give me information about the following topic: {}",
-     "vulnerable": """Identify the line of code that is vulnerable and describe the type of software vulnerability.
  ### Code Snippet:
  {}
- ### Vulnerability Description:""",
-     "Chat": "{}"
  }

- @spaces.GPU(duration=100)
- def predict(selected_model, prompt, prompt_type, max_length=128):
-     model, tokenizer = load_model(selected_model)
-     selected_prompt = alpaca_prompts[prompt_type]
-     formatted_prompt = selected_prompt.format(prompt)
-     inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
-     text_streamer = TextStreamer(tokenizer)
-     output = model.generate(**inputs, streamer=text_streamer, max_new_tokens=max_length)
-     return tokenizer.decode(output[0], skip_special_tokens=True)
-
- theme = gr.themes.Default(
-     primary_hue=gr.themes.colors.rose,
-     secondary_hue=gr.themes.colors.blue,
-     font=gr.themes.GoogleFont("Source Sans Pro")
  )

- load_model("CyberSentinel")
-
- with gr.Blocks(theme=theme) as demo:
-     selected_model = gr.Dropdown(choices=list(model_configs.keys()), value="CyberSentinel", label="Model")
-     prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
-     prompt_type = gr.Dropdown(choices=list(alpaca_prompts.keys()), value="Chat", label="Prompt Type")
-     max_length = gr.Slider(minimum=128, maximum=512, step=128, value=128, label="Max Length")
-     generated_text = gr.Textbox(label="Generated Text")
-
-     generate_button = gr.Button("Generate")
-
-     generate_button.click(predict, inputs=[selected_model, prompt, prompt_type, max_length], outputs=generated_text)
-
-     gr.Examples(
-         examples=[
-             ["CyberSentinel", "What is SQL injection?", "information", 128],
-             ["CyberSentinel", "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);", "vulnerable", 128],
-             ["CyberSentinel", "Can you tell me a joke?", "Chat", 128]
-         ],
-         inputs=[selected_model, prompt, prompt_type, max_length]
-     )
-
- demo.queue(default_concurrency_limit=20).launch(
-     server_name="0.0.0.0",
-     allowed_paths=["/"],
-     share=True
- )
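Before the new version of the file, a standalone illustration (not part of the commit) of how the removed predict() assembled its model input: the "vulnerable" entry in alpaca_prompts is a plain str.format template, so the Perl example from gr.Examples expands as follows.

# Hypothetical, self-contained sketch; the template and snippet are copied
# verbatim from the removed app.py above. No model or GPU required.
template = """Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability Description:"""

snippet = "$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"

# predict() passed exactly this text to tokenizer([...]); max_new_tokens
# only capped the generated continuation, not the prompt itself.
print(template.format(snippet))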
  import os
  import torch
+ from unsloth import FastLanguageModel, is_bfloat16_supported
+ from trl import SFTTrainer
+ from transformers import TrainingArguments
+ from datasets import load_dataset
+ import gradio as gr
+ import json
+ from huggingface_hub import HfApi

+ max_seq_length = 4096
+ dtype = None
+ load_in_4bit = True
  hf_token = os.getenv("HF_TOKEN")
+ current_num = os.getenv("NUM")
+
+ print(f"stage {current_num}")
+
+ api = HfApi(token=hf_token)
+ # models = f"dad1909/cybersentinal-2.0-{current_num}"
+ model_base = "unsloth/gemma-2-27b-bnb-4bit"
+
+ print("Starting model and tokenizer loading...")
+
+ # Load the model and tokenizer
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=model_base,
+     max_seq_length=max_seq_length,
+     dtype=dtype,
+     load_in_4bit=load_in_4bit,
+     token=hf_token
+ )
+
+ print("Model and tokenizer loaded successfully.")
+
+ print("Configuring PEFT model...")
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r=16,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     lora_alpha=16,
+     lora_dropout=0,
+     bias="none",
+     use_gradient_checkpointing="unsloth",
+     random_state=3407,
+     use_rslora=False,
+     loftq_config=None,
+ )
+ print("PEFT model configured.")

+ # Updated alpaca_prompt for different types
+ alpaca_prompt = {
+     "learning_from": """Below is a CVE definition.
+ ### CVE definition:
+ {}
+ ### detail CVE:
+ {}""",
+     "definition": """Below is a definition about software vulnerability. Explain it.
+ ### Definition:
+ {}
+ ### Explanation:
+ {}""",
+     "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
  ### Code Snippet:
  {}
+ ### Vulnerability solution:
+ {}"""
  }

+ EOS_TOKEN = tokenizer.eos_token
+
+ def detect_prompt_type(instruction):
+     if instruction.startswith("what is code vulnerable of this code:"):
+         return "code_vulnerability"
+     elif instruction.startswith("Learning from"):
+         return "learning_from"
+     elif instruction.startswith("what is"):
+         return "definition"
+     else:
+         return "unknown"
+
+ def formatting_prompts_func(examples):
+     instructions = examples["instruction"]
+     outputs = examples["output"]
+     texts = []
+
+     for instruction, output in zip(instructions, outputs):
+         prompt_type = detect_prompt_type(instruction)
+         if prompt_type in alpaca_prompt:
+             prompt = alpaca_prompt[prompt_type].format(instruction, output)
+         else:
+             prompt = instruction + "\n\n" + output
+         text = prompt + EOS_TOKEN
+         texts.append(text)
+
+     return {"text": texts}
+
+ print("Loading dataset...")
+ dataset = load_dataset("dad1909/DCSV", split="train")
+ print("Dataset loaded successfully.")
+
+ print("Applying formatting function to the dataset...")
+ dataset = dataset.map(formatting_prompts_func, batched=True)
+ print("Formatting function applied.")
+
+ print("Initializing trainer...")
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=dataset,
+     dataset_text_field="text",
+     max_seq_length=max_seq_length,
+     dataset_num_proc=2,
+     packing=False,
+     args=TrainingArguments(
+         per_device_train_batch_size=1,
+         gradient_accumulation_steps=1,
+         learning_rate=2e-4,
+         fp16=not is_bfloat16_supported(),
+         bf16=is_bfloat16_supported(),
+         warmup_steps=5,
+         logging_steps=10,
+         max_steps=100,
+         optim="adamw_8bit",
+         weight_decay=0.01,
+         lr_scheduler_type="linear",
+         seed=3407,
+         output_dir="outputs"
+     ),
+ )
+ print("Trainer initialized.")
+
+ print("Starting training...")
+ trainer_stats = trainer.train()
+ print("Training completed.")
+
+ num = int(current_num)
+ num += 1
+
+ uploads_models = f"cybersentinal-2.0-{str(num)}"
+
+ up = "sentinal-3.1-70B"
+
+ print("Saving the trained model...")
+ model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
+ print("Model saved successfully.")
+
+ print("Pushing the model to the hub...")
+ model.push_to_hub_merged(
+     up,
+     tokenizer,
+     save_method="merged_16bit",
+     token=hf_token
  )
+ print("Model pushed to hub successfully.")

+ api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
+ api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))
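The closing two calls turn the Space into a staged training loop: each run reads NUM on startup, trains for max_steps, then rewrites the variable. A minimal sketch of that round-trip (not part of the commit; it assumes the same HF_TOKEN and NUM Space variables and uses the huggingface_hub HfApi methods already imported above):

# Standalone sketch of the stage-counter round-trip.
import os
from huggingface_hub import HfApi

api = HfApi(token=os.getenv("HF_TOKEN"))
num = int(os.getenv("NUM", "0")) + 1  # app.py assumes NUM is always set

# Updating a Space variable triggers a Space restart, so after this call
# the script re-runs from the top as the next stage.
api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))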