File size: 8,998 Bytes
7f47926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# main_menu.py

import argparse
import sys
import os
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from twisted.internet import reactor, task
from lightbulb_custom import main as lightbulb_custom_main
from distillation_pipeline import distill_model  # Import the distillation function
from transformers import logging

# Suppress transformers warnings for cleaner output
logging.set_verbosity_error()

def parse_main_args():
    """Build the main-menu argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options for the selected task.
    """
    # Closed set of runnable tasks; the two distill_* entries select the
    # full-model vs. domain-specific distillation paths.
    task_choices = [
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
        'distill_full_model',
        'distill_domain_specific',
    ]

    cli = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")

    # Task selection (mandatory).
    cli.add_argument(
        '--task',
        type=str,
        choices=task_choices,
        required=True,
        help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference, distill_full_model, distill_domain_specific',
    )

    # Options shared by training, inference, and distillation.
    cli.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    cli.add_argument('--student_model_name', type=str, default='distilgpt2', help='Name of the student model for distillation')
    cli.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    cli.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    cli.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    cli.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    cli.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    cli.add_argument('--temperature', type=float, default=2.0, help='Distillation temperature')
    cli.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')

    # Distillation output / checkpointing.
    cli.add_argument('--save_path', type=str, default="./distilled_model", help="Path to save the distilled model")
    cli.add_argument('--log_dir', type=str, default="./logs", help="Directory for TensorBoard logs")
    cli.add_argument('--checkpoint_dir', type=str, default="./checkpoints", help="Directory to save checkpoints")
    cli.add_argument('--early_stopping_patience', type=int, default=3, help="Early stopping patience")

    # Inference options.
    cli.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
    cli.add_argument('--inference_mode', type=str, choices=['without_world_model', 'world_model', 'world_model_tree_of_thought'], help='Inference mode')
    cli.add_argument('--beam_size', type=int, default=5, help='Beam size for beam search during inference')
    cli.add_argument('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step during inference')
    cli.add_argument('--mcts_iterations', type=int, default=10, help='Number of MCTS iterations during inference')
    cli.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS during inference')

    # Distillation mode switches.
    cli.add_argument('--distill_full_model', action="store_true", help="Whether to distill the full model or not")
    cli.add_argument('--query_terms', type=str, nargs="+", help="Query terms for domain-specific distillation")

    # Optional path to a previously distilled model to load for inference.
    cli.add_argument('--load_model', type=str, help='Path to load the distilled model for inference')

    return cli.parse_args()

def main():
    """Parse CLI arguments and dispatch to the selected task.

    Tasks either re-invoke ``lightbulb_custom.main`` by rewriting
    ``sys.argv``, run the agent under the Twisted reactor, or call the
    distillation pipeline directly. Exits with status 1 on an unknown
    task or on missing required distillation arguments.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # lightbulb_custom reads its own argparse config, so we hand it
        # arguments by rewriting sys.argv before calling its main().
        sys.argv = [
            'lightbulb_custom.py',
            '--mode', 'train',
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length)
        ]
        lightbulb_custom_main()

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Schedule train_agent on the Twisted reactor and stop the reactor
        # when it finishes (success or failure).
        d = task.deferLater(reactor, 0, train_agent)
        # BUG FIX: the original errback passed exc_info=True to print(),
        # which is a logging-only kwarg — the errback itself raised
        # TypeError and hid the real failure. Use Failure.getTraceback()
        # to show the traceback instead.
        d.addErrback(lambda failure: print(
            f"An error occurred: {failure}\n{failure.getTraceback()}"))
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # One-shot mode: answer the provided query and return.
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Interactive mode: run the REPL-style session under the reactor.
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")
        # Map the main-menu task name onto lightbulb_custom's
        # --inference_mode vocabulary.
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought'
        }
        selected_inference_mode = inference_mode_map.get(
            args.task, 'world_model_tree_of_thought')

        # Build the argv that lightbulb_custom's own parser will consume.
        lightbulb_inf_args = [
            'lightbulb_custom.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(args.beam_size),
            '--n_tokens_predict', str(args.n_tokens_predict),
            '--mcts_iterations', str(args.mcts_iterations),
            '--mcts_exploration_constant', str(args.mcts_exploration_constant)
        ]
        # Only forward --load_model when the user supplied one.
        if args.load_model:
            lightbulb_inf_args += ['--load_model', args.load_model]

        sys.argv = lightbulb_inf_args
        lightbulb_custom_main()

    elif args.task == 'distill_full_model':
        print("Starting Full Model Distillation...")
        # Full-model path: no query_terms filter is applied.
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=True,
            query_terms=None,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    elif args.task == 'distill_domain_specific':
        print("Starting Domain-Specific Distillation...")
        # query_terms is mandatory for the domain-specific path.
        if not args.query_terms:
            print("Error: --query_terms must be provided for domain-specific distillation.")
            sys.exit(1)
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=False,
            query_terms=args.query_terms,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    else:
        # argparse's choices= should make this unreachable; keep it as a
        # defensive guard.
        print(f"Unknown task: {args.task}")
        sys.exit(1)

# Standard script entry guard: run the menu only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()