File size: 4,418 Bytes
ca4fc4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Run pytest using MCP."""

import argparse
import time

from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
                      stop_run, wait_for_run_status)

if __name__ == '__main__':
    # Launch the llm-foundry test suite as a remote run, stream its logs to
    # stdout, and exit non-zero unless the run ends in the COMPLETED state.

    parser = argparse.ArgumentParser()
    parser.add_argument('--name',
                        type=str,
                        default='mcp-pytest',
                        help='Base name of run')
    parser.add_argument('--cluster',
                        type=str,
                        default='r1z4',
                        help='Cluster to use')
    parser.add_argument('--gpu_type',
                        type=str,
                        default='a100_40gb',
                        help='Type of GPU to use')
    parser.add_argument('--gpu_num',
                        type=int,
                        default=2,
                        help='Number of GPUs to use')
    parser.add_argument('--image',
                        type=str,
                        default='mosaicml/pytorch:latest',
                        help='Docker image to use')
    parser.add_argument('--git_branch',
                        type=str,
                        help='Git branch to check out')
    parser.add_argument(
        '--git_commit',
        type=str,
        help='Git commit to check out. Overrides git_branch if specified')
    parser.add_argument(
        '--pr_number',
        type=int,
        help=
        'PR number to check out. Overrides git_branch/git_commit if specified')
    parser.add_argument('--pytest_markers',
                        type=str,
                        help='Markers to pass to pytest')
    parser.add_argument('--pytest_command',
                        type=str,
                        help='Command to run pytest')
    parser.add_argument('--timeout',
                        type=int,
                        default=1800,
                        help='Timeout for run (in seconds)')
    args = parser.parse_args()

    # The run name starts from --name and gains a suffix for every revision
    # selector that is in effect.
    name = args.name
    git_integration = {
        'integration_type': 'git_repo',
        'git_repo': 'mosaicml/llm-foundry',
        'ssh_clone': 'False',
    }
    # A commit takes precedence over a branch when both are supplied.
    if args.git_commit is not None:
        name += f'-commit-{args.git_commit}'
        git_integration['git_commit'] = args.git_commit
    elif args.git_branch is not None:
        name += f'-branch-{args.git_branch}'
        git_integration['git_branch'] = args.git_branch

    command = 'cd llm-foundry'

    # Checkout a specific PR if specified
    if args.pr_number is not None:
        name += f'-pr-{args.pr_number}'
        command += f'''

        git fetch origin pull/{args.pr_number}/head:pr_branch

        git checkout pr_branch

        '''

    # Shorten name if too long (slicing is a no-op for names within the limit)
    name = name[:56]

    # Only emit the marker filter / pytest override when a value was actually
    # given; otherwise the literal string "None" would be interpolated into
    # the shell command. Without the override, `make` is left to use its own
    # PYTEST setting (assumed to have a default in the Makefile -- confirm).
    marker_args = ''
    if args.pytest_markers is not None:
        marker_args = f" -m '{args.pytest_markers}'"
    pytest_override = ''
    if args.pytest_command is not None:
        pytest_override = f"PYTEST='{args.pytest_command}' "

    command += f'''

    pip install --upgrade --user .[all]

    export COMMON_ARGS="-v --durations=20{marker_args}"

    make test {pytest_override}EXTRA_ARGS="$COMMON_ARGS --codeblocks"

    make test-dist {pytest_override}EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2

    python -m coverage combine

    python -m coverage report
    '''

    config = RunConfig(
        name=name,
        cluster=args.cluster,
        gpu_type=args.gpu_type,
        gpu_num=args.gpu_num,
        image=args.image,
        integrations=[git_integration],
        command=command,
    )

    # Create run
    run = create_run(config)
    print(f'[GHA] Run created: {run.name}')

    # Wait until run starts before fetching logs
    run = wait_for_run_status(run, status='running')
    # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
    start_time = time.monotonic()
    print('[GHA] Run started. Following logs...')

    # Print logs.
    # NOTE(review): the timeout is only evaluated when a log line arrives, so
    # a run that hangs without producing output is not stopped by this loop.
    for line in follow_run_logs(run):
        print(line, end='')
        # Check if args.timeout seconds have elapsed
        if time.monotonic() - start_time > args.timeout:
            print(
                f'[GHA] Run timed out and did not complete in {args.timeout/60} minutes.'
            )
            run = stop_run(run)
            print('[GHA] Run stopped.')
            break
    else:
        # Log stream ended on its own (no timeout).
        print('[GHA] Run completed. Waiting for run to finish...')

    # Refresh the run's final status. After a forced stop this is expected to
    # return with a non-COMPLETED status -- TODO confirm against the SDK docs.
    run = wait_for_run_status(run, status='completed')

    # Fail if command exited with non-zero exit code or timed out. An explicit
    # exit is used instead of `assert`, which is removed under `python -O`.
    if run.status != RunStatus.COMPLETED:
        raise SystemExit(
            f'[GHA] Run {run.name} finished with status {run.status}; '
            f'expected {RunStatus.COMPLETED}.')