RobbiePasquale committed (verified)
Commit 85033dd · 1 Parent(s): 5977163

Update main_menu.py

Files changed (1):
  1. main_menu.py +118 -61
main_menu.py CHANGED
@@ -1,61 +1,118 @@
-# main_menu.py
-
-import argparse
-import sys
-from train_agent import train_agent
-from test_agent import TestAgent, run_test_session
-from lightbulb import main as world_model_main
-
-def parse_main_args():
-    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
-    parser.add_argument('--task', type=str, choices=['train_llm_world', 'train_agent', 'test_agent'],
-                        required=True, help='Choose task to execute: train_llm_world, train_agent, test_agent')
-    # Optional arguments for more granular control
-    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
-    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
-    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
-    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
-    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
-    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
-    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
-    parser.add_argument('--query', type=str, default='', help='Query for the test_agent')
-    return parser.parse_args()
-
-def main():
-    # Parse arguments for the main function
-    args = parse_main_args()
-
-    # Execute tasks based on user input
-    if args.task == 'train_llm_world':
-        print("Starting LLM and World Model Training...")
-        # Directly call the world model main function
-        sys.argv = ['lightbulb.py', '--mode', args.mode, '--model_name', args.model_name,
-                    '--dataset_name', args.dataset_name, '--dataset_config', args.dataset_config,
-                    '--batch_size', str(args.batch_size), '--num_epochs', str(args.num_epochs),
-                    '--max_length', str(args.max_length)]
-        world_model_main()
-
-    elif args.task == 'train_agent':
-        print("Starting Agent Training...")
-        # Call the train_agent function from train_agent.py
-        from twisted.internet import reactor, task
-        d = task.deferLater(reactor, 0, train_agent)
-        d.addErrback(lambda failure: print(f"An error occurred: {failure}", exc_info=True))
-        d.addBoth(lambda _: reactor.stop())
-        reactor.run()
-
-    elif args.task == 'test_agent':
-        print("Starting Test Agent...")
-        test_agent = TestAgent()
-        if args.query:
-            # Directly process a single query
-            result = test_agent.process_query(args.query)
-            print("\nAgent's response:")
-            print(result)
-        else:
-            # Run the interactive session
-            reactor.callWhenRunning(run_test_session)
-            reactor.run()
-
-if __name__ == "__main__":
-    main()
+# main_menu.py
+
+import argparse
+import sys
+import os
+from train_agent import train_agent
+from test_agent import TestAgent, run_test_session
+from lightbulb import main as world_model_main
+from lightbulb_inf import main as inference_main  # Import the inference main function
+from twisted.internet import reactor, task
+
+def parse_main_args():
+    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
+    parser.add_argument('--task', type=str, choices=[
+        'train_llm_world',
+        'train_agent',
+        'test_agent',
+        'inference_llm',
+        'inference_world_model',
+        'advanced_inference'
+    ],
+        required=True,
+        help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference')
+    # Optional arguments for more granular control
+    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
+    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
+    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
+    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
+    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
+    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
+    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
+    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
+    # Additional arguments specific to inference can be added here if needed
+    return parser.parse_args()
+
+def main():
+    # Parse arguments for the main function
+    args = parse_main_args()
+
+    # Execute tasks based on user input
+    if args.task == 'train_llm_world':
+        print("Starting LLM and World Model Training...")
+        # Directly call the world model main function with appropriate arguments
+        sys.argv = [
+            'lightbulb.py',
+            '--mode', args.mode,
+            '--model_name', args.model_name,
+            '--dataset_name', args.dataset_name,
+            '--dataset_config', args.dataset_config,
+            '--batch_size', str(args.batch_size),
+            '--num_epochs', str(args.num_epochs),
+            '--max_length', str(args.max_length)
+        ]
+        world_model_main()
+
+    elif args.task == 'train_agent':
+        print("Starting Agent Training...")
+        # Call the train_agent function from train_agent.py using Twisted reactor
+        d = task.deferLater(reactor, 0, train_agent)
+        d.addErrback(lambda failure: print(f"An error occurred: {failure}"))
+        d.addBoth(lambda _: reactor.stop())
+        reactor.run()
+
+    elif args.task == 'test_agent':
+        print("Starting Test Agent...")
+        test_agent = TestAgent()
+        if args.query:
+            # Directly process a single query
+            result = test_agent.process_query(args.query)
+            print("\nAgent's response:")
+            print(result)
+        else:
+            # Run the interactive session
+            reactor.callWhenRunning(run_test_session)
+            reactor.run()
+
+    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
+        print("Starting Inference Task...")
+        # Prepare the arguments for lightbulb_inf.py based on the selected inference task
+
+        # Map the main_menu task to lightbulb_inf.py's inference_mode
+        inference_mode_map = {
+            'inference_llm': 'without_world_model',
+            'inference_world_model': 'world_model',
+            'advanced_inference': 'world_model_tree_of_thought'
+        }
+
+        selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')
+
+        # Construct sys.argv for lightbulb_inf.py
+        lightbulb_inf_args = [
+            'lightbulb_inf.py',
+            '--mode', 'inference',
+            '--model_name', args.model_name,
+            '--query', args.query,
+            '--max_length', str(args.max_length),
+            '--inference_mode', selected_inference_mode,
+            '--beam_size', str(getattr(args, 'beam_size', 5)),
+            '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
+            '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
+            '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414))
+        ]
+
+        # Include additional arguments if they exist
+        if hasattr(args, 'load_model') and args.load_model:
+            lightbulb_inf_args += ['--load_model', args.load_model]
+
+        # Update sys.argv and call the inference main function
+        sys.argv = lightbulb_inf_args
+        inference_main()
+
+    else:
+        print(f"Unknown task: {args.task}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
+
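
Note on the new inference branch: it reads --beam_size, --n_tokens_predict, --mcts_iterations, --mcts_exploration_constant and --load_model through getattr()/hasattr() fallbacks, but parse_main_args() never declares those flags, so the hard-coded defaults (5, 3, 10, 1.414, no checkpoint) are always passed to lightbulb_inf.py. The in-code comment "Additional arguments specific to inference can be added here if needed" points at the gap. A minimal sketch of what those declarations could look like, assuming the same names and defaults as the fallbacks (this helper is illustrative, not part of the commit):

import argparse

def add_inference_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical flags mirroring the getattr() fallbacks in main_menu.py;
    # they are not defined by this commit.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam width used by the inference search')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Tokens predicted per expansion step')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='Number of MCTS iterations')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='Exploration constant for MCTS')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Optional path to a saved model checkpoint')
    return parser

If parse_main_args() called add_inference_args(parser), an invocation such as python main_menu.py --task advanced_inference --query "..." --beam_size 8 would reach lightbulb_inf.py with the user-supplied values instead of the baked-in defaults.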