Upload 79 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- Andromeda/.DS_Store +0 -0
- Andromeda/.env +6 -0
- Andromeda/.github/ISSUE_TEMPLATE/---bug-report.md +36 -0
- Andromeda/.github/ISSUE_TEMPLATE/---feature-request.md +25 -0
- Andromeda/.github/ISSUE_TEMPLATE/---model-questions.md +17 -0
- Andromeda/.github/mcp/mcp_pytest.py +139 -0
- Andromeda/.github/workflows/FUNDING.md +13 -0
- Andromeda/.github/workflows/code-quality.yaml +44 -0
- Andromeda/.github/workflows/codeql-analysis.yml +70 -0
- Andromeda/.github/workflows/coverage.yaml +32 -0
- Andromeda/.github/workflows/docker.yaml +62 -0
- Andromeda/.github/workflows/pr-cpu.yaml +43 -0
- Andromeda/.github/workflows/pr-gpu.yaml +40 -0
- Andromeda/.github/workflows/pytest-cpu.yaml +48 -0
- Andromeda/.github/workflows/pytest-gpu.yaml +80 -0
- Andromeda/.github/workflows/python-publish.yml +39 -0
- Andromeda/.github/workflows/release.yaml +60 -0
- Andromeda/.gitignore +2 -0
- Andromeda/Andromeda/README.md +121 -0
- Andromeda/Andromeda/__init__.py +3 -0
- Andromeda/Andromeda/configs.py +128 -0
- Andromeda/Andromeda/core/__init__.py +8 -0
- Andromeda/Andromeda/core/attend.py +252 -0
- Andromeda/Andromeda/core/autoregressive_wrapper.py +150 -0
- Andromeda/Andromeda/core/flash.py +289 -0
- Andromeda/Andromeda/core/transformer.py +1376 -0
- Andromeda/Andromeda/dataset_prep/__init__.py +0 -0
- Andromeda/Andromeda/dataset_prep/books.py +12 -0
- Andromeda/Andromeda/inference.py +198 -0
- Andromeda/Andromeda/model.py +118 -0
- Andromeda/Andromeda/old/__init__.py +0 -0
- Andromeda/Andromeda/old/sophia.py +200 -0
- Andromeda/Andromeda/old/training.py +294 -0
- Andromeda/Andromeda/old/training_1.py +350 -0
- Andromeda/Andromeda/old/training_sophia.py +369 -0
- Andromeda/Andromeda/train.py +700 -0
- Andromeda/Andromeda/utils/__init__.py +0 -0
- Andromeda/Andromeda/utils/decoupled_optimizer.py +147 -0
- Andromeda/Andromeda/utils/helpers.py +17 -0
- Andromeda/Andromeda/utils/rf_utils.py +186 -0
- Andromeda/Andromeda/utils/stable_adamw.py +96 -0
- Andromeda/DOCs/Corporation/MONETIZATION.md +51 -0
- Andromeda/DOCs/Design/Dyson.md +26 -0
- Andromeda/DOCs/Design/MODEL_ARCHITECTURE.md +57 -0
- Andromeda/DOCs/Design/SPEED.md +11 -0
- Andromeda/DOCs/Design/Specs.md +196 -0
- Andromeda/DOCs/Docs/DOCUMENTATION.md +145 -0
- Andromeda/DOCs/Docs/TRAINING.md +82 -0
- Andromeda/DOCs/Docs/Training/DATASET_STRATEGY.md +100 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
Andromeda/images/andromeda-banner.png filter=lfs diff=lfs merge=lfs -text
|
Andromeda/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Andromeda/.env
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MASTER_ADDR=""
|
2 |
+
MASTER_PORT=""
|
3 |
+
RANK=""
|
4 |
+
WORLD_SIZE=""
|
5 |
+
# export TORCH_CPP_LOG_LEVEL=INFO NCCL_DEBUG=INFO
|
6 |
+
|
Andromeda/.github/ISSUE_TEMPLATE/---bug-report.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "\U0001F41B Bug report"
|
3 |
+
about: Submit a bug report to improve our library!
|
4 |
+
title: ''
|
5 |
+
labels: bug
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
<!-- Please check for related issues (both open and closed) before filing this issue. -->
|
11 |
+
|
12 |
+
## Environment
|
13 |
+
<!-- Please copy paste the output of running `composer_collect_env` below-->
|
14 |
+
<!--
|
15 |
+
If you can't install composer for some reason, you can also use the PyTorch collect env script
|
16 |
+
|
17 |
+
wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
|
18 |
+
# For security purposes, please check the contents of collect_env.py before running it.
|
19 |
+
python collect_env.py
|
20 |
+
-->
|
21 |
+
|
22 |
+
## To reproduce
|
23 |
+
|
24 |
+
Steps to reproduce the behavior:
|
25 |
+
|
26 |
+
1.
|
27 |
+
2.
|
28 |
+
3.
|
29 |
+
|
30 |
+
## Expected behavior
|
31 |
+
|
32 |
+
<!-- A clear and concise description of what you would expect to happen. -->
|
33 |
+
|
34 |
+
## Additional context
|
35 |
+
|
36 |
+
<!-- Please provide any additional context. -->
|
Andromeda/.github/ISSUE_TEMPLATE/---feature-request.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "\U0001F680 Feature request"
|
3 |
+
about: Suggest an idea for this project
|
4 |
+
title: ''
|
5 |
+
labels: enhancement
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
<!-- Please check for related feature requests (both open and closed) before filing this request. -->
|
11 |
+
|
12 |
+
## 🚀 Feature Request
|
13 |
+
<!-- A clear and concise description of the feature proposal -->
|
14 |
+
|
15 |
+
## Motivation
|
16 |
+
|
17 |
+
<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
|
18 |
+
|
19 |
+
## [Optional] Implementation
|
20 |
+
|
21 |
+
<!-- Optionally, sketch out an implementation or interface needed. -->
|
22 |
+
|
23 |
+
## Additional context
|
24 |
+
|
25 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
Andromeda/.github/ISSUE_TEMPLATE/---model-questions.md
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: "\U00002753 Model-related question"
|
3 |
+
about: Ask a question about using our released models
|
4 |
+
title: ''
|
5 |
+
labels: question
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
<!-- Please check for related question (both open and closed) before filing this question. -->
|
11 |
+
|
12 |
+
## ❓ Question
|
13 |
+
<!-- A clear and concise description of the question -->
|
14 |
+
|
15 |
+
## Additional context
|
16 |
+
|
17 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
Andromeda/.github/mcp/mcp_pytest.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Run pytest using MCP."""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import time
|
8 |
+
|
9 |
+
from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
|
10 |
+
stop_run, wait_for_run_status)
|
11 |
+
|
12 |
+
if __name__ == '__main__':
|
13 |
+
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--name',
|
16 |
+
type=str,
|
17 |
+
default='mcp-pytest',
|
18 |
+
help='Base name of run')
|
19 |
+
parser.add_argument('--cluster',
|
20 |
+
type=str,
|
21 |
+
default='r1z4',
|
22 |
+
help='Cluster to use')
|
23 |
+
parser.add_argument('--gpu_type',
|
24 |
+
type=str,
|
25 |
+
default='a100_40gb',
|
26 |
+
help='Type of GPU to use')
|
27 |
+
parser.add_argument('--gpu_num',
|
28 |
+
type=int,
|
29 |
+
default=2,
|
30 |
+
help='Number of the GPU to use')
|
31 |
+
parser.add_argument('--image',
|
32 |
+
type=str,
|
33 |
+
default='mosaicml/pytorch:latest',
|
34 |
+
help='Docker image to use')
|
35 |
+
parser.add_argument('--git_branch',
|
36 |
+
type=str,
|
37 |
+
help='Git branch to check out')
|
38 |
+
parser.add_argument(
|
39 |
+
'--git_commit',
|
40 |
+
type=str,
|
41 |
+
help='Git commit to check out. Overrides git_branch if specified')
|
42 |
+
parser.add_argument(
|
43 |
+
'--pr_number',
|
44 |
+
type=int,
|
45 |
+
help=
|
46 |
+
'PR number to check out. Overrides git_branch/git_commit if specified')
|
47 |
+
parser.add_argument('--pytest_markers',
|
48 |
+
type=str,
|
49 |
+
help='Markers to pass to pytest')
|
50 |
+
parser.add_argument('--pytest_command',
|
51 |
+
type=str,
|
52 |
+
help='Command to run pytest')
|
53 |
+
parser.add_argument('--timeout',
|
54 |
+
type=int,
|
55 |
+
default=1800,
|
56 |
+
help='Timeout for run (in seconds)')
|
57 |
+
args = parser.parse_args()
|
58 |
+
|
59 |
+
name = args.name
|
60 |
+
git_integration = {
|
61 |
+
'integration_type': 'git_repo',
|
62 |
+
'git_repo': 'mosaicml/llm-foundry',
|
63 |
+
'ssh_clone': 'False',
|
64 |
+
}
|
65 |
+
if args.git_branch is not None and args.git_commit is None:
|
66 |
+
name += f'-branch-{args.git_branch}'
|
67 |
+
git_integration['git_branch'] = args.git_branch
|
68 |
+
if args.git_commit is not None:
|
69 |
+
name += f'-commit-{args.git_commit}'
|
70 |
+
git_integration['git_commit'] = args.git_commit
|
71 |
+
|
72 |
+
command = 'cd llm-foundry'
|
73 |
+
|
74 |
+
# Checkout a specific PR if specified
|
75 |
+
if args.pr_number is not None:
|
76 |
+
name += f'-pr-{args.pr_number}'
|
77 |
+
command += f'''
|
78 |
+
|
79 |
+
git fetch origin pull/{args.pr_number}/head:pr_branch
|
80 |
+
|
81 |
+
git checkout pr_branch
|
82 |
+
|
83 |
+
'''
|
84 |
+
|
85 |
+
# Shorten name if too long
|
86 |
+
if len(name) > 56:
|
87 |
+
name = name[:56]
|
88 |
+
|
89 |
+
command += f'''
|
90 |
+
|
91 |
+
pip install --upgrade --user .[all]
|
92 |
+
|
93 |
+
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'"
|
94 |
+
|
95 |
+
make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
|
96 |
+
|
97 |
+
make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
|
98 |
+
|
99 |
+
python -m coverage combine
|
100 |
+
|
101 |
+
python -m coverage report
|
102 |
+
'''
|
103 |
+
|
104 |
+
config = RunConfig(
|
105 |
+
name=name,
|
106 |
+
cluster=args.cluster,
|
107 |
+
gpu_type=args.gpu_type,
|
108 |
+
gpu_num=args.gpu_num,
|
109 |
+
image=args.image,
|
110 |
+
integrations=[git_integration],
|
111 |
+
command=command,
|
112 |
+
)
|
113 |
+
|
114 |
+
# Create run
|
115 |
+
run = create_run(config)
|
116 |
+
print(f'[GHA] Run created: {run.name}')
|
117 |
+
|
118 |
+
# Wait until run starts before fetching logs
|
119 |
+
run = wait_for_run_status(run, status='running')
|
120 |
+
start_time = time.time()
|
121 |
+
print('[GHA] Run started. Following logs...')
|
122 |
+
|
123 |
+
# Print logs
|
124 |
+
for line in follow_run_logs(run):
|
125 |
+
print(line, end='')
|
126 |
+
# Check if args.timeout seconds have elapsed
|
127 |
+
if time.time() - start_time > args.timeout:
|
128 |
+
print(
|
129 |
+
f'[GHA] Run timed out and did not complete in {args.timeout/60} minutes.'
|
130 |
+
)
|
131 |
+
run = stop_run(run)
|
132 |
+
print('[GHA] Run stopped.')
|
133 |
+
break
|
134 |
+
|
135 |
+
print('[GHA] Run completed. Waiting for run to finish...')
|
136 |
+
run = wait_for_run_status(run, status='completed')
|
137 |
+
|
138 |
+
# Fail if command exited with non-zero exit code or timed out
|
139 |
+
assert run.status == RunStatus.COMPLETED
|
Andromeda/.github/workflows/FUNDING.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# These are supported funding model platforms
|
2 |
+
|
3 |
+
github: [kyegomez]
|
4 |
+
patreon: # Replace with a single Patreon username
|
5 |
+
open_collective: # Replace with a single Open Collective username
|
6 |
+
ko_fi: # Replace with a single Ko-fi username
|
7 |
+
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
8 |
+
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
9 |
+
liberapay: # Replace with a single Liberapay username
|
10 |
+
issuehunt: # Replace with a single IssueHunt username
|
11 |
+
otechie: # Replace with a single Otechie username
|
12 |
+
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
13 |
+
custom: #Nothing
|
Andromeda/.github/workflows/code-quality.yaml
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Code Quality Checks
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
- release/**
|
7 |
+
pull_request:
|
8 |
+
branches:
|
9 |
+
- main
|
10 |
+
- release/**
|
11 |
+
workflow_call:
|
12 |
+
workflow_dispatch:
|
13 |
+
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
|
14 |
+
concurrency:
|
15 |
+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
16 |
+
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
17 |
+
defaults:
|
18 |
+
run:
|
19 |
+
working-directory: .
|
20 |
+
jobs:
|
21 |
+
code-quality:
|
22 |
+
runs-on: ubuntu-20.04
|
23 |
+
timeout-minutes: 10
|
24 |
+
strategy:
|
25 |
+
matrix:
|
26 |
+
python_version:
|
27 |
+
- '3.8'
|
28 |
+
- '3.9'
|
29 |
+
- '3.10'
|
30 |
+
pip_deps:
|
31 |
+
- '[dev]'
|
32 |
+
steps:
|
33 |
+
- uses: actions/checkout@v3
|
34 |
+
- uses: actions/setup-python@v4
|
35 |
+
with:
|
36 |
+
python-version: ${{ matrix.python_version }}
|
37 |
+
- name: Setup
|
38 |
+
run: |
|
39 |
+
set -ex
|
40 |
+
python -m pip install --upgrade 'pip<23' wheel
|
41 |
+
python -m pip install --upgrade .${{ matrix.pip_deps }}
|
42 |
+
- name: Run checks
|
43 |
+
run: |
|
44 |
+
pre-commit run --all-files
|
Andromeda/.github/workflows/codeql-analysis.yml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# For most projects, this workflow file will not need changing; you simply need
|
2 |
+
# to commit it to your repository.
|
3 |
+
#
|
4 |
+
# You may wish to alter this file to override the set of languages analyzed,
|
5 |
+
# or to provide custom queries or build logic.
|
6 |
+
#
|
7 |
+
# ******** NOTE ********
|
8 |
+
# We have attempted to detect the languages in your repository. Please check
|
9 |
+
# the `language` matrix defined below to confirm you have the correct set of
|
10 |
+
# supported CodeQL languages.
|
11 |
+
#
|
12 |
+
name: 'CodeQL'
|
13 |
+
|
14 |
+
on:
|
15 |
+
push:
|
16 |
+
branches: [main]
|
17 |
+
pull_request:
|
18 |
+
# The branches below must be a subset of the branches above
|
19 |
+
branches: [main]
|
20 |
+
schedule:
|
21 |
+
- cron: '0 9 * * 1' # Every Monday at 09:00 (9:00 AM)
|
22 |
+
|
23 |
+
jobs:
|
24 |
+
analyze:
|
25 |
+
name: Analyze
|
26 |
+
runs-on: ubuntu-latest
|
27 |
+
permissions:
|
28 |
+
actions: read
|
29 |
+
contents: read
|
30 |
+
security-events: write
|
31 |
+
|
32 |
+
strategy:
|
33 |
+
fail-fast: false
|
34 |
+
matrix:
|
35 |
+
language: ['python']
|
36 |
+
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
|
37 |
+
# Learn more about CodeQL language support at https://git.io/codeql-language-support
|
38 |
+
|
39 |
+
steps:
|
40 |
+
- name: Checkout repository
|
41 |
+
uses: actions/checkout@v2
|
42 |
+
|
43 |
+
# Initializes the CodeQL tools for scanning.
|
44 |
+
- name: Initialize CodeQL
|
45 |
+
uses: github/codeql-action/init@v2
|
46 |
+
with:
|
47 |
+
languages: ${{ matrix.language }}
|
48 |
+
# If you wish to specify custom queries, you can do so here or in a config file.
|
49 |
+
# By default, queries listed here will override any specified in a config file.
|
50 |
+
# Prefix the list here with "+" to use these queries and those in the config file.
|
51 |
+
# queries: ./path/to/local/query, your-org/your-repo/queries@main
|
52 |
+
|
53 |
+
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
54 |
+
# If this step fails, then you should remove it and run the build manually (see below)
|
55 |
+
- name: Autobuild
|
56 |
+
uses: github/codeql-action/autobuild@v2
|
57 |
+
|
58 |
+
# ℹ️ Command-line programs to run using the OS shell.
|
59 |
+
# 📚 https://git.io/JvXDl
|
60 |
+
|
61 |
+
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
|
62 |
+
# and modify them (or add more) to build your code if your project
|
63 |
+
# uses a compiled language
|
64 |
+
|
65 |
+
# - run: |
|
66 |
+
# make bootstrap
|
67 |
+
# make release
|
68 |
+
|
69 |
+
- name: Perform CodeQL Analysis
|
70 |
+
uses: github/codeql-action/analyze@v2
|
Andromeda/.github/workflows/coverage.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: PyTest Coverage
|
2 |
+
on:
|
3 |
+
workflow_call:
|
4 |
+
inputs:
|
5 |
+
download-path:
|
6 |
+
required: true
|
7 |
+
type: string
|
8 |
+
jobs:
|
9 |
+
coverage:
|
10 |
+
timeout-minutes: 5
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- name: Checkout Repo
|
14 |
+
uses: actions/checkout@v3
|
15 |
+
- name: Setup
|
16 |
+
run: |
|
17 |
+
set -ex
|
18 |
+
python -m pip install --upgrade 'pip<23' wheel
|
19 |
+
pip install coverage[toml]==6.5.0
|
20 |
+
- name: Download artifacts
|
21 |
+
uses: actions/download-artifact@v3
|
22 |
+
with:
|
23 |
+
path: ${{ inputs.download-path }}
|
24 |
+
- name: Generate coverage report
|
25 |
+
run: |
|
26 |
+
set -ex
|
27 |
+
|
28 |
+
# Flatten the coverage files
|
29 |
+
ls ${{ inputs.download-path }} | while read x; do mv ${{ inputs.download-path }}/$x/.coverage .coverage.$x; done
|
30 |
+
|
31 |
+
python -m coverage combine
|
32 |
+
python -m coverage report
|
Andromeda/.github/workflows/docker.yaml
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Docker
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
workflow_dispatch: {}
|
7 |
+
jobs:
|
8 |
+
docker-build:
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
if: github.repository_owner == 'mosaicml'
|
11 |
+
strategy:
|
12 |
+
matrix:
|
13 |
+
include:
|
14 |
+
- name: '1.13.1_cu117'
|
15 |
+
base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
|
16 |
+
- name: '2.0.1_cu118'
|
17 |
+
base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
|
18 |
+
|
19 |
+
steps:
|
20 |
+
- name: Maximize Build Space on Worker
|
21 |
+
uses: easimon/maximize-build-space@v4
|
22 |
+
with:
|
23 |
+
overprovision-lvm: true
|
24 |
+
remove-dotnet: true
|
25 |
+
remove-android: true
|
26 |
+
remove-haskell: true
|
27 |
+
|
28 |
+
- name: Checkout
|
29 |
+
uses: actions/checkout@v3
|
30 |
+
|
31 |
+
- name: Setup QEMU
|
32 |
+
uses: docker/setup-qemu-action@v2
|
33 |
+
|
34 |
+
- name: Setup Docker Buildx
|
35 |
+
uses: docker/setup-buildx-action@v2
|
36 |
+
|
37 |
+
- name: Login to DockerHub
|
38 |
+
uses: docker/login-action@v2
|
39 |
+
with:
|
40 |
+
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
41 |
+
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
42 |
+
|
43 |
+
- name: Calculate Docker Image Variables
|
44 |
+
run: |
|
45 |
+
set -euxo pipefail
|
46 |
+
|
47 |
+
###################
|
48 |
+
# Calculate the tag
|
49 |
+
###################
|
50 |
+
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
|
51 |
+
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
|
52 |
+
|
53 |
+
- name: Build and Push the Docker Image
|
54 |
+
uses: docker/build-push-action@v3
|
55 |
+
with:
|
56 |
+
context: .
|
57 |
+
tags: mosaicml/llm-foundry:${{ matrix.name }}-latest,
|
58 |
+
mosaicml/llm-foundry:${{ matrix.name }}-${{ env.IMAGE_TAG }}
|
59 |
+
push: true
|
60 |
+
cache-from: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache
|
61 |
+
cache-to: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache,mode=max
|
62 |
+
build-args: BASE_IMAGE=${{ matrix.base_image }}
|
Andromeda/.github/workflows/pr-cpu.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: PR CPU tests
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
- release/*
|
7 |
+
pull_request:
|
8 |
+
branches:
|
9 |
+
- main
|
10 |
+
- release/*
|
11 |
+
workflow_dispatch:
|
12 |
+
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
|
13 |
+
concurrency:
|
14 |
+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
15 |
+
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
16 |
+
jobs:
|
17 |
+
pytest-cpu:
|
18 |
+
uses: ./.github/workflows/pytest-cpu.yaml
|
19 |
+
strategy:
|
20 |
+
matrix:
|
21 |
+
include:
|
22 |
+
- name: 'cpu-latest'
|
23 |
+
container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
|
24 |
+
markers: 'not gpu'
|
25 |
+
pytest_command: 'coverage run -m pytest'
|
26 |
+
- name: 'cpu-2.0.1'
|
27 |
+
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
|
28 |
+
markers: 'not gpu'
|
29 |
+
pytest_command: 'coverage run -m pytest'
|
30 |
+
name: ${{ matrix.name }}
|
31 |
+
if: github.repository_owner == 'mosaicml'
|
32 |
+
with:
|
33 |
+
container: ${{ matrix.container }}
|
34 |
+
name: ${{ matrix.name }}
|
35 |
+
pytest-command: ${{ matrix.pytest_command }}
|
36 |
+
pytest-markers: ${{ matrix.markers }}
|
37 |
+
coverage:
|
38 |
+
uses: ./.github/workflows/coverage.yaml
|
39 |
+
name: Coverage Results
|
40 |
+
if: github.repository_owner == 'mosaicml'
|
41 |
+
needs: [pytest-cpu]
|
42 |
+
with:
|
43 |
+
download-path: artifacts
|
Andromeda/.github/workflows/pr-gpu.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: PR GPU tests
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
- release/*
|
7 |
+
pull_request_target:
|
8 |
+
branches:
|
9 |
+
- main
|
10 |
+
- release/**
|
11 |
+
workflow_dispatch:
|
12 |
+
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
|
13 |
+
concurrency:
|
14 |
+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
15 |
+
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
16 |
+
jobs:
|
17 |
+
pytest-gpu:
|
18 |
+
uses: ./.github/workflows/pytest-gpu.yaml
|
19 |
+
strategy:
|
20 |
+
matrix:
|
21 |
+
include:
|
22 |
+
- name: 'gpu-latest'
|
23 |
+
container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
|
24 |
+
markers: 'gpu'
|
25 |
+
pytest_command: 'coverage run -m pytest'
|
26 |
+
- name: 'gpu-2.0.1'
|
27 |
+
container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
|
28 |
+
markers: 'gpu'
|
29 |
+
pytest_command: 'coverage run -m pytest'
|
30 |
+
name: ${{ matrix.name }}
|
31 |
+
if: github.repository_owner == 'mosaicml'
|
32 |
+
with:
|
33 |
+
container: ${{ matrix.container }}
|
34 |
+
mcloud-timeout: 1200
|
35 |
+
name: ${{ matrix.name }}
|
36 |
+
pytest-command: ${{ matrix.pytest_command }}
|
37 |
+
pytest-markers: ${{ matrix.markers }}
|
38 |
+
python-version: 3.9
|
39 |
+
secrets:
|
40 |
+
mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }}
|
Andromeda/.github/workflows/pytest-cpu.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Pytest CPU
|
2 |
+
on:
|
3 |
+
workflow_call:
|
4 |
+
inputs:
|
5 |
+
container:
|
6 |
+
required: true
|
7 |
+
type: string
|
8 |
+
name:
|
9 |
+
required: true
|
10 |
+
type: string
|
11 |
+
pytest-command:
|
12 |
+
required: true
|
13 |
+
type: string
|
14 |
+
pytest-markers:
|
15 |
+
required: true
|
16 |
+
type: string
|
17 |
+
jobs:
|
18 |
+
pytest-cpu:
|
19 |
+
timeout-minutes: 30
|
20 |
+
runs-on: ubuntu-latest
|
21 |
+
container: ${{ inputs.container }}
|
22 |
+
steps:
|
23 |
+
- name: Checkout Repo
|
24 |
+
uses: actions/checkout@v3
|
25 |
+
- name: Setup
|
26 |
+
run: |
|
27 |
+
set -ex
|
28 |
+
export PATH=/composer-python:$PATH
|
29 |
+
python -m pip install --upgrade 'pip<23' wheel
|
30 |
+
python -m pip install --upgrade .[dev]
|
31 |
+
- name: Run Tests
|
32 |
+
id: tests
|
33 |
+
run: |
|
34 |
+
set -ex
|
35 |
+
export PATH=/composer-python:$PATH
|
36 |
+
export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}'"
|
37 |
+
|
38 |
+
# Necessary to run git diff for doctests
|
39 |
+
git config --global --add safe.directory /__w/llm-foundry/llm-foundry
|
40 |
+
|
41 |
+
make test PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
|
42 |
+
# make test-dist PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
|
43 |
+
|
44 |
+
python -m coverage combine
|
45 |
+
- uses: actions/upload-artifact@v3
|
46 |
+
with:
|
47 |
+
name: coverage-${{ github.sha }}-${{ inputs.name }}
|
48 |
+
path: .coverage
|
Andromeda/.github/workflows/pytest-gpu.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Pytest GPU
|
2 |
+
on:
|
3 |
+
workflow_call:
|
4 |
+
inputs:
|
5 |
+
container:
|
6 |
+
required: true
|
7 |
+
type: string
|
8 |
+
mcloud-timeout:
|
9 |
+
required: false
|
10 |
+
type: number
|
11 |
+
default: 1800
|
12 |
+
name:
|
13 |
+
required: true
|
14 |
+
type: string
|
15 |
+
pytest-command:
|
16 |
+
required: true
|
17 |
+
type: string
|
18 |
+
pytest-markers:
|
19 |
+
required: true
|
20 |
+
type: string
|
21 |
+
python-version:
|
22 |
+
required: false
|
23 |
+
type: string
|
24 |
+
default: 3.9
|
25 |
+
secrets:
|
26 |
+
mcloud-api-key:
|
27 |
+
required: true
|
28 |
+
jobs:
|
29 |
+
pytest-gpu:
|
30 |
+
timeout-minutes: 60 # ${{ inputs.gha-timeout }} for some reason not able to turn this into an input
|
31 |
+
runs-on: ubuntu-latest
|
32 |
+
env:
|
33 |
+
MOSAICML_API_KEY: ${{ secrets.mcloud-api-key }}
|
34 |
+
steps:
|
35 |
+
- name: Checkout Repo
|
36 |
+
uses: actions/checkout@v3
|
37 |
+
- name: Setup Python
|
38 |
+
uses: actions/setup-python@v4
|
39 |
+
with:
|
40 |
+
python-version: ${{ inputs.python-version }}
|
41 |
+
- name: Cache pip
|
42 |
+
uses: actions/cache@v3
|
43 |
+
with:
|
44 |
+
# This path is specific to Ubuntu
|
45 |
+
path: ~/.cache/pip
|
46 |
+
# Look to see if there is a cache hit for the corresponding requirements file
|
47 |
+
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
|
48 |
+
restore-keys: |
|
49 |
+
${{ runner.os }}-pip-
|
50 |
+
${{ runner.os }}-
|
51 |
+
- name: Setup MCLI
|
52 |
+
run: |
|
53 |
+
set -ex
|
54 |
+
python -m pip install mosaicml-cli
|
55 |
+
mcli init --mcloud
|
56 |
+
mcli version
|
57 |
+
- name: Submit Run
|
58 |
+
id: tests
|
59 |
+
run: |
|
60 |
+
set -ex
|
61 |
+
|
62 |
+
PR_NUMBER="$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")"
|
63 |
+
REF_ARGS=""
|
64 |
+
|
65 |
+
# Use the PR number if it exists, commit SHA for protected branches and the branch name otherwise
|
66 |
+
if [ -z "$PR_NUMBER" ] || [ "$PR_NUMBER" = "null" ]; then
|
67 |
+
if [[ "$GITHUB_REF" =~ "refs/heads/main" || "$GITHUB_REF" =~ "refs/heads/release" ]]; then
|
68 |
+
REF_ARGS="--git_commit $GITHUB_SHA"
|
69 |
+
else
|
70 |
+
REF_ARGS="--git_branch $GITHUB_REF_NAME"
|
71 |
+
fi
|
72 |
+
else
|
73 |
+
REF_ARGS="--pr_number $PR_NUMBER"
|
74 |
+
fi
|
75 |
+
|
76 |
+
python .github/mcp/mcp_pytest.py \
|
77 |
+
--image '${{ inputs.container }}' \
|
78 |
+
--pytest_markers '${{ inputs.pytest-markers }}' \
|
79 |
+
--pytest_command '${{ inputs.pytest-command }}' \
|
80 |
+
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
|
Andromeda/.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This workflow will upload a Python Package using Twine when a release is created
|
2 |
+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
3 |
+
|
4 |
+
# This workflow uses actions that are not certified by GitHub.
|
5 |
+
# They are provided by a third-party and are governed by
|
6 |
+
# separate terms of service, privacy policy, and support
|
7 |
+
# documentation.
|
8 |
+
|
9 |
+
name: Upload Python Package
|
10 |
+
|
11 |
+
on:
|
12 |
+
release:
|
13 |
+
types: [published]
|
14 |
+
|
15 |
+
permissions:
|
16 |
+
contents: read
|
17 |
+
|
18 |
+
jobs:
|
19 |
+
deploy:
|
20 |
+
|
21 |
+
runs-on: ubuntu-latest
|
22 |
+
|
23 |
+
steps:
|
24 |
+
- uses: actions/checkout@v3
|
25 |
+
- name: Set up Python
|
26 |
+
uses: actions/setup-python@v3
|
27 |
+
with:
|
28 |
+
python-version: '3.x'
|
29 |
+
- name: Install dependencies
|
30 |
+
run: |
|
31 |
+
python -m pip install --upgrade pip
|
32 |
+
pip install build
|
33 |
+
- name: Build package
|
34 |
+
run: python -m build
|
35 |
+
- name: Publish package
|
36 |
+
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
37 |
+
with:
|
38 |
+
user: __token__
|
39 |
+
password: ${{ secrets.PYPI_API_TOKEN }}
|
Andromeda/.github/workflows/release.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
tags:
|
6 |
+
- 'v*'
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
code-quality:
|
11 |
+
uses: ./.github/workflows/code-quality.yaml
|
12 |
+
|
13 |
+
pypi-packaging:
|
14 |
+
name: Build and Publish llm-foundry PyPI Package
|
15 |
+
needs:
|
16 |
+
- code-quality
|
17 |
+
runs-on: ubuntu-latest
|
18 |
+
steps:
|
19 |
+
- name: Checkout source
|
20 |
+
uses: actions/checkout@v3
|
21 |
+
|
22 |
+
- name: Set up Python
|
23 |
+
uses: actions/setup-python@v3
|
24 |
+
with:
|
25 |
+
python-version: '3.9'
|
26 |
+
|
27 |
+
- name: Build source and wheel distributions
|
28 |
+
run: |
|
29 |
+
if [[ "${{ github.ref }}" =~ refs\/tags\/v ]]; then
|
30 |
+
PYPI_PACKAGE_NAME="llm-foundry"
|
31 |
+
else
|
32 |
+
PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
|
33 |
+
fi
|
34 |
+
|
35 |
+
# Remove the peft, xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not
|
36 |
+
# support direct installs. The error message for importing PEFT, FusedCrossEntropy,
|
37 |
+
# and flash_attn_triton gives instructions on how to install if a user tries to use it
|
38 |
+
# without this dependency.
|
39 |
+
sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
|
40 |
+
sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py
|
41 |
+
sed '/peft@git+https:\/\/github.com\/huggingface\/peft.git.*/d' -i setup.py
|
42 |
+
|
43 |
+
python -m pip install --upgrade build twine
|
44 |
+
python -m build
|
45 |
+
twine check --strict dist/*
|
46 |
+
|
47 |
+
- name: Publish 📦 to PyPI
|
48 |
+
uses: pypa/gh-action-pypi-publish@release/v1
|
49 |
+
if: contains(github.ref, 'refs/tags/v')
|
50 |
+
with:
|
51 |
+
user: __token__
|
52 |
+
password: ${{ secrets.PROD_PYPI_API_TOKEN }}
|
53 |
+
|
54 |
+
- name: Publish distribution 📦 to Test PyPI
|
55 |
+
uses: pypa/gh-action-pypi-publish@release/v1
|
56 |
+
if: contains(github.ref, 'refs/heads/') || contains(github.ref, 'refs/pull/')
|
57 |
+
with:
|
58 |
+
user: __token__
|
59 |
+
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
|
60 |
+
repository_url: https://test.pypi.org/legacy/
|
Andromeda/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
dist
|
Andromeda/Andromeda/README.md
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Transformer Model Technical Research Analysis
|
2 |
+
|
3 |
+
This document provides an analysis of the hyperparameters and configurations of the given Transformer model, focusing on dimensions, depth, and heads, as well as an architectural overview of their meanings and use cases.
|
4 |
+
|
5 |
+
## Model Configuration
|
6 |
+
|
7 |
+
```python
|
8 |
+
model = Transformer(
|
9 |
+
num_tokens=20000,
|
10 |
+
max_seq_len=8192,
|
11 |
+
use_abs_pos_emb = False,
|
12 |
+
attn_layers = Decoder(
|
13 |
+
dim=512,
|
14 |
+
depth=6,
|
15 |
+
heads=8,
|
16 |
+
alibi_pos_bias=True,
|
17 |
+
alibi_num_heads=4,
|
18 |
+
rotary_xpos=True,
|
19 |
+
attn_flash = True,
|
20 |
+
deepnorm=True,
|
21 |
+
shift_tokens=1,
|
22 |
+
attn_one_kv_head = True,
|
23 |
+
)
|
24 |
+
)
|
25 |
+
```
|
26 |
+
|
27 |
+
### Hyperparameters
|
28 |
+
|
29 |
+
1. **num_tokens**: The number of unique tokens in the input vocabulary. In this case, the model is configured to handle 20,000 unique tokens.
|
30 |
+
|
31 |
+
2. **max_seq_len**: The maximum sequence length that the model can handle. The current configuration supports sequences of up to 8,192 tokens.
|
32 |
+
|
33 |
+
3. **use_abs_pos_emb**: A boolean flag indicating whether to use absolute positional embeddings. The model is configured not to use absolute positional embeddings (`False`).
|
34 |
+
|
35 |
+
4. **dim**: The dimensionality of the input embeddings and the internal representations within the Transformer layers. The model uses a dimensionality of 512.
|
36 |
+
|
37 |
+
5. **depth**: The number of Transformer layers (or blocks) in the model. This model has a depth of 6, meaning it has 6 layers.
|
38 |
+
|
39 |
+
6. **heads**: The number of attention heads in the multi-head self-attention mechanism. This model uses 8 attention heads.
|
40 |
+
|
41 |
+
### Additional Configurations
|
42 |
+
|
43 |
+
- **alibi_pos_bias**: A boolean flag indicating whether to use the Alibi position bias mechanism. The model is configured to use Alibi position bias (`True`).
|
44 |
+
|
45 |
+
- **alibi_num_heads**: The number of Alibi attention heads to use. The model is configured to use 4 Alibi attention heads.
|
46 |
+
|
47 |
+
- **rotary_xpos**: A boolean flag indicating whether to use the rotary positional encoding mechanism. The model is configured to use rotary positional encoding (`True`).
|
48 |
+
|
49 |
+
- **attn_flash**: A boolean flag indicating whether to use the Flash attention mechanism. The model is configured to use Flash attention (`True`).
|
50 |
+
|
51 |
+
- **deepnorm**: A boolean flag indicating whether to use deep normalization. The model is configured to use deep normalization (`True`).
|
52 |
+
|
53 |
+
- **shift_tokens**: The number of tokens to shift during training to form the target sequence. The model is configured to shift by 1 token (`1`).
|
54 |
+
|
55 |
+
- **attn_one_kv_head**: A boolean flag indicating whether to use one key-value head for attention instead of multiple heads. The model is configured to use one key-value head (`True`).
|
56 |
+
|
57 |
+
## Architectural Overview
|
58 |
+
|
59 |
+
### Dimensions
|
60 |
+
|
61 |
+
- **Input Embedding Dimension (dim)**: This hyperparameter defines the size of the input embeddings and the internal representations within the Transformer layers. A larger dimensionality can capture more complex relationships between tokens but may require more computational resources.
|
62 |
+
|
63 |
+
### Depth
|
64 |
+
|
65 |
+
- **Number of Transformer Layers (depth)**: This hyperparameter defines the number of Transformer layers (or blocks) in the model. Each layer consists of a multi-head self-attention mechanism followed by a position-wise feed-forward network. Increasing the depth allows the model to capture more complex and hierarchical relationships between tokens but may also increase the risk of overfitting and require more computational resources.
|
66 |
+
|
67 |
+
### Heads
|
68 |
+
|
69 |
+
- **Number of Attention Heads (heads)**: This hyperparameter defines the number of attention heads in the multi-head self-attention mechanism. Each head processes the input sequence independently and captures different aspects of the relationships between tokens. The outputs of all heads are then concatenated and transformed to produce the final output. Increasing the number of attention heads can help the model capture more diverse and fine-grained relationships between tokens but may also increase computational complexity and memory requirements.
|
70 |
+
|
71 |
+
## Benefits and Consequences of Increasing Hyperparameters
|
72 |
+
|
73 |
+
### Dimensions
|
74 |
+
|
75 |
+
**Benefits:**
|
76 |
+
|
77 |
+
- Better representation: Increasing the dimensionality of the input embeddings and internal representations allows the model to capture more complex relationships between tokens.
|
78 |
+
|
79 |
+
- Improved model expressiveness: A higher dimensionality may enable the model to learn more expressive features, leading to better performance on complex tasks.
|
80 |
+
|
81 |
+
**Consequences:**
|
82 |
+
|
83 |
+
- Computational complexity: Increasing the dimensionality will increase the computational complexity of the model, which may lead to longer training and inference times.
|
84 |
+
|
85 |
+
- Memory requirements: A higher dimensionality will increase the memory requirements of the model, potentially limiting its applicability on resource-constrained hardware.
|
86 |
+
|
87 |
+
- Risk of overfitting: Models with a higher dimensionality may be more prone to overfitting, especially if the size of the training dataset is small.
|
88 |
+
|
89 |
+
### Depth
|
90 |
+
|
91 |
+
**Benefits:**
|
92 |
+
|
93 |
+
- Hierarchical representation: Increasing the depth of the model allows it to capture more complex and hierarchical relationships between tokens, which can lead to improved performance on tasks that require understanding long-range dependencies.
|
94 |
+
|
95 |
+
- Enhanced feature extraction: Deeper models can extract features at different levels of abstraction, potentially improving their ability to generalize to new data.
|
96 |
+
|
97 |
+
**Consequences:**
|
98 |
+
|
99 |
+
- Computational complexity: Increasing the depth will increase the computational complexity of the model, leading to longer training and inference times.
|
100 |
+
|
101 |
+
- Memory requirements: A deeper model will require more memory, potentially limiting its applicability on resource-constrained hardware.
|
102 |
+
|
103 |
+
- Risk of overfitting: Deeper models may be more prone to overfitting, especially if the size of the training dataset is small.
|
104 |
+
|
105 |
+
- Vanishing/exploding gradients: Deeper models may suffer from vanishing or exploding gradients during training, making it harder to optimize the model. Techniques such as layer normalization or skip connections can help mitigate this issue.
|
106 |
+
|
107 |
+
### Heads
|
108 |
+
|
109 |
+
**Benefits:**
|
110 |
+
|
111 |
+
- Diverse attention: Increasing the number of attention heads allows the model to capture more diverse and fine-grained relationships between tokens, which can improve its ability to understand the input data.
|
112 |
+
|
113 |
+
- Robustness: Multi-head attention can make the model more robust, as each head can focus on different aspects of the input data.
|
114 |
+
|
115 |
+
**Consequences:**
|
116 |
+
|
117 |
+
- Computational complexity: Increasing the number of attention heads will increase the computational complexity of the model, leading to longer training and inference times.
|
118 |
+
|
119 |
+
- Memory requirements: A model with more attention heads will require more memory, potentially limiting its applicability on resource-constrained hardware.
|
120 |
+
|
121 |
+
- Diminishing returns: There may be diminishing returns when increasing the number of attention heads beyond a certain point, as the model may already be capturing most of the relevant information with fewer heads.
|
Andromeda/Andromeda/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# from Andromeda.train import Train
|
2 |
+
from Andromeda.model import AndromedaTokenizer, Andromeda
|
3 |
+
from Andromeda.train import Train, train
|
Andromeda/Andromeda/configs.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Andromeda.model import AndromedaEmbedding, Andromeda
|
2 |
+
|
3 |
+
|
4 |
+
Andromeda1Billion = Andromeda(
|
5 |
+
num_tokens=25000,
|
6 |
+
max_seq_len=4192,
|
7 |
+
dim=2048,
|
8 |
+
depth=16,
|
9 |
+
dim_head=128,
|
10 |
+
heads=8,
|
11 |
+
use_abs_pos_emb=False,
|
12 |
+
alibi_pos_bias=True,
|
13 |
+
alibi_num_heads=4,
|
14 |
+
rotary_xpos=True,
|
15 |
+
attn_flash=True,
|
16 |
+
# shift_tokens=1,
|
17 |
+
attn_one_kv_head=True,
|
18 |
+
qk_norm=True,
|
19 |
+
attn_qk_norm=True,
|
20 |
+
attn_qk_norm_dim_scale=True,
|
21 |
+
embedding_provider=AndromedaEmbedding()
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
Andromeda3Billion = Andromeda(
|
27 |
+
num_tokens=50432,
|
28 |
+
max_seq_len=8192,
|
29 |
+
dim=3072,
|
30 |
+
depth=24,
|
31 |
+
dim_head=128,
|
32 |
+
heads=12,
|
33 |
+
use_abs_pos_emb=False,
|
34 |
+
alibi_pos_bias=True,
|
35 |
+
alibi_num_heads=6,
|
36 |
+
rotary_xpos=True,
|
37 |
+
attn_flash=True,
|
38 |
+
shift_tokens=1,
|
39 |
+
attn_one_kv_head=True,
|
40 |
+
qk_norm=True,
|
41 |
+
attn_qk_norm=True,
|
42 |
+
attn_qk_norm_dim_scale=True,
|
43 |
+
embedding_provider=AndromedaEmbedding()
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
Andromeda7Billion = Andromeda(
|
49 |
+
num_tokens=50432,
|
50 |
+
max_seq_len=8192,
|
51 |
+
dim=4096,
|
52 |
+
depth=32,
|
53 |
+
dim_head=128,
|
54 |
+
heads=16,
|
55 |
+
use_abs_pos_emb=False,
|
56 |
+
alibi_pos_bias=True,
|
57 |
+
alibi_num_heads=8,
|
58 |
+
rotary_xpos=True,
|
59 |
+
attn_flash=True,
|
60 |
+
shift_tokens=1,
|
61 |
+
attn_one_kv_head=True,
|
62 |
+
qk_norm=True,
|
63 |
+
attn_qk_norm=True,
|
64 |
+
attn_qk_norm_dim_scale=True,
|
65 |
+
embedding_provider=AndromedaEmbedding()
|
66 |
+
)
|
67 |
+
|
68 |
+
Andromeda10Billion = Andromeda(
|
69 |
+
num_tokens=50432,
|
70 |
+
max_seq_len=8192,
|
71 |
+
dim=5120,
|
72 |
+
depth=32,
|
73 |
+
dim_head=128,
|
74 |
+
heads=20,
|
75 |
+
use_abs_pos_emb=False,
|
76 |
+
alibi_pos_bias=True,
|
77 |
+
alibi_num_heads=4,
|
78 |
+
rotary_xpos=True,
|
79 |
+
attn_flash=True,
|
80 |
+
shift_tokens=1,
|
81 |
+
attn_one_kv_head=True,
|
82 |
+
qk_norm=True,
|
83 |
+
attn_qk_norm=True,
|
84 |
+
attn_qk_norm_dim_scale=True,
|
85 |
+
embedding_provider=AndromedaEmbedding()
|
86 |
+
)
|
87 |
+
|
88 |
+
Andromeda15Billion = Andromeda(
|
89 |
+
num_tokens=50432,
|
90 |
+
max_seq_len=8192,
|
91 |
+
dim=6144,
|
92 |
+
depth=40,
|
93 |
+
dim_head=128,
|
94 |
+
heads=24,
|
95 |
+
use_abs_pos_emb=False,
|
96 |
+
alibi_pos_bias=True,
|
97 |
+
alibi_num_heads=4,
|
98 |
+
rotary_xpos=True,
|
99 |
+
attn_flash=True,
|
100 |
+
shift_tokens=1,
|
101 |
+
attn_one_kv_head=True,
|
102 |
+
qk_norm=True,
|
103 |
+
attn_qk_norm=True,
|
104 |
+
attn_qk_norm_dim_scale=True,
|
105 |
+
embedding_provider=AndromedaEmbedding()
|
106 |
+
)
|
107 |
+
|
108 |
+
Andromeda20Billion = Andromeda(
|
109 |
+
num_tokens=50432,
|
110 |
+
max_seq_len=8192,
|
111 |
+
dim=7168,
|
112 |
+
depth=48,
|
113 |
+
dim_head=128,
|
114 |
+
heads=28,
|
115 |
+
use_abs_pos_emb=False,
|
116 |
+
alibi_pos_bias=True,
|
117 |
+
alibi_num_heads=4,
|
118 |
+
rotary_xpos=True,
|
119 |
+
attn_flash=True,
|
120 |
+
shift_tokens=1,
|
121 |
+
attn_one_kv_head=True,
|
122 |
+
qk_norm=True,
|
123 |
+
attn_qk_norm=True,
|
124 |
+
attn_qk_norm_dim_scale=True,
|
125 |
+
embedding_provider=AndromedaEmbedding()
|
126 |
+
)
|
127 |
+
|
128 |
+
#to GPT like 176Billion Parameters 122888 dimension, 96 depth, 96 heads, attn dim head 128
|
Andromeda/Andromeda/core/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from packaging import version
|
3 |
+
|
4 |
+
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
5 |
+
from einops._torch_specific import allow_ops_in_compiled_graph
|
6 |
+
allow_ops_in_compiled_graph()
|
7 |
+
|
8 |
+
|
Andromeda/Andromeda/core/attend.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn, einsum, Tensor
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from collections import namedtuple
|
8 |
+
from functools import wraps
|
9 |
+
from packaging import version
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
from Andromeda.core.flash import attention
|
14 |
+
|
15 |
+
# from flash import FlashAttention
|
16 |
+
|
17 |
+
# constants
|
18 |
+
|
19 |
+
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Intermediates:
|
23 |
+
qk_similarities: Tensor = None
|
24 |
+
pre_softmax_attn: Tensor = None
|
25 |
+
post_softmax_attn: Tensor = None
|
26 |
+
|
27 |
+
# helpers
|
28 |
+
|
29 |
+
def exists(val):
|
30 |
+
return val is not None
|
31 |
+
|
32 |
+
def default(val, d):
|
33 |
+
return val if exists(val) else d
|
34 |
+
|
35 |
+
def once(fn):
|
36 |
+
called = False
|
37 |
+
@wraps(fn)
|
38 |
+
def inner(x):
|
39 |
+
nonlocal called
|
40 |
+
if called:
|
41 |
+
return
|
42 |
+
called = True
|
43 |
+
return fn(x)
|
44 |
+
return inner
|
45 |
+
|
46 |
+
print_once = once(print)
|
47 |
+
|
48 |
+
# main class
|
49 |
+
|
50 |
+
class Attend(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
*,
|
54 |
+
dropout = 0.,
|
55 |
+
causal = False,
|
56 |
+
heads = None,
|
57 |
+
talking_heads = False,
|
58 |
+
scale = None,
|
59 |
+
qk_norm = False,
|
60 |
+
flash = False,
|
61 |
+
triton = False,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.scale = scale
|
65 |
+
self.qk_norm = qk_norm
|
66 |
+
self.causal = causal
|
67 |
+
self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
|
68 |
+
|
69 |
+
self.dropout = dropout
|
70 |
+
self.attn_dropout = nn.Dropout(dropout)
|
71 |
+
|
72 |
+
# talking heads
|
73 |
+
|
74 |
+
assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
|
75 |
+
|
76 |
+
self.talking_heads = talking_heads
|
77 |
+
if talking_heads:
|
78 |
+
self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
|
79 |
+
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
|
80 |
+
|
81 |
+
# flash attention
|
82 |
+
self.flash = flash
|
83 |
+
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
84 |
+
|
85 |
+
# determine efficient attention configs for cuda and cpu
|
86 |
+
self.cpu_config = EfficientAttentionConfig(True, True, True)
|
87 |
+
self.cuda_config = None
|
88 |
+
|
89 |
+
if not torch.cuda.is_available() or not flash:
|
90 |
+
return
|
91 |
+
|
92 |
+
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
93 |
+
|
94 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
95 |
+
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
|
96 |
+
self.cuda_config = EfficientAttentionConfig(True, False, False)
|
97 |
+
else:
|
98 |
+
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
|
99 |
+
self.cuda_config = EfficientAttentionConfig(False, True, True)
|
100 |
+
|
101 |
+
def flash_attn(
|
102 |
+
self,
|
103 |
+
q, k, v,
|
104 |
+
mask = None,
|
105 |
+
attn_bias = None
|
106 |
+
):
|
107 |
+
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
|
108 |
+
|
109 |
+
# Recommended for multi-query single-key-value attention by Tri Dao
|
110 |
+
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
111 |
+
|
112 |
+
if k.ndim == 3:
|
113 |
+
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
|
114 |
+
|
115 |
+
if v.ndim == 3:
|
116 |
+
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
|
117 |
+
|
118 |
+
# handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
|
119 |
+
|
120 |
+
if self.qk_norm:
|
121 |
+
default_scale = q.shape[-1] ** -0.5
|
122 |
+
q = q * (default_scale / self.scale)
|
123 |
+
|
124 |
+
# Check if mask exists and expand to compatible shape
|
125 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
126 |
+
|
127 |
+
causal = self.causal
|
128 |
+
|
129 |
+
if exists(mask):
|
130 |
+
assert mask.ndim == 4
|
131 |
+
mask = mask.expand(batch, heads, q_len, k_len)
|
132 |
+
|
133 |
+
# manually handle causal mask, if another mask was given
|
134 |
+
|
135 |
+
if causal:
|
136 |
+
causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
|
137 |
+
mask = mask | causal_mask
|
138 |
+
causal = False
|
139 |
+
|
140 |
+
# handle alibi positional bias
|
141 |
+
# convert from bool to float
|
142 |
+
|
143 |
+
if exists(attn_bias):
|
144 |
+
attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)
|
145 |
+
|
146 |
+
# if mask given, the mask would already contain the causal mask from above logic
|
147 |
+
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
|
148 |
+
|
149 |
+
mask_value = -torch.finfo(q.dtype).max
|
150 |
+
|
151 |
+
if exists(mask):
|
152 |
+
attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
|
153 |
+
elif causal:
|
154 |
+
causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
|
155 |
+
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
|
156 |
+
causal = False
|
157 |
+
|
158 |
+
# scaled_dot_product_attention handles attn_mask either as bool or additive bias
|
159 |
+
# make it an additive bias here
|
160 |
+
|
161 |
+
mask = attn_bias
|
162 |
+
|
163 |
+
# Check if there is a compatible device for flash attention
|
164 |
+
|
165 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
166 |
+
|
167 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
168 |
+
|
169 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
170 |
+
out = F.scaled_dot_product_attention(
|
171 |
+
q, k, v,
|
172 |
+
attn_mask = mask,
|
173 |
+
dropout_p = self.dropout if self.training else 0.,
|
174 |
+
is_causal = causal
|
175 |
+
)
|
176 |
+
|
177 |
+
return out, Intermediates()
|
178 |
+
|
179 |
+
def forward(
|
180 |
+
self,
|
181 |
+
q, k, v,
|
182 |
+
mask = None,
|
183 |
+
attn_bias = None,
|
184 |
+
prev_attn = None
|
185 |
+
):
|
186 |
+
"""
|
187 |
+
einstein notation
|
188 |
+
b - batch
|
189 |
+
h - heads
|
190 |
+
n, i, j - sequence length (base sequence length, source, target)
|
191 |
+
d - feature dimension
|
192 |
+
"""
|
193 |
+
|
194 |
+
n, device = q.shape[-2], q.device
|
195 |
+
|
196 |
+
scale = default(self.scale, q.shape[-1] ** -0.5)
|
197 |
+
|
198 |
+
if self.flash:
|
199 |
+
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
|
200 |
+
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
|
201 |
+
# return FlashAttention(q, k, v, mask=mask, attn_bias=attn_bias )
|
202 |
+
|
203 |
+
if self.triton:
|
204 |
+
return attention(q, k, v, self.casual, scale)
|
205 |
+
|
206 |
+
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
|
207 |
+
|
208 |
+
dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
|
209 |
+
|
210 |
+
if exists(prev_attn):
|
211 |
+
dots = dots + prev_attn
|
212 |
+
|
213 |
+
qk_similarities = dots.clone()
|
214 |
+
|
215 |
+
if self.talking_heads:
|
216 |
+
dots = self.pre_softmax_talking_heads(dots)
|
217 |
+
|
218 |
+
if exists(attn_bias):
|
219 |
+
dots = dots + attn_bias
|
220 |
+
|
221 |
+
dtype = dots.dtype
|
222 |
+
pre_softmax_attn = dots.clone()
|
223 |
+
|
224 |
+
mask_value = -torch.finfo(dots.dtype).max
|
225 |
+
|
226 |
+
if exists(mask):
|
227 |
+
dots = dots.masked_fill(mask, mask_value)
|
228 |
+
|
229 |
+
if self.causal:
|
230 |
+
i, j = dots.shape[-2:]
|
231 |
+
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
|
232 |
+
dots = dots.masked_fill(causal_mask, mask_value)
|
233 |
+
|
234 |
+
attn = self.attn_fn(dots, dim = -1)
|
235 |
+
attn = attn.type(dtype)
|
236 |
+
|
237 |
+
post_softmax_attn = attn.clone()
|
238 |
+
|
239 |
+
attn = self.attn_dropout(attn)
|
240 |
+
|
241 |
+
if self.talking_heads:
|
242 |
+
attn = self.post_softmax_talking_heads(attn)
|
243 |
+
|
244 |
+
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
|
245 |
+
|
246 |
+
intermediates = Intermediates(
|
247 |
+
qk_similarities = qk_similarities,
|
248 |
+
pre_softmax_attn = pre_softmax_attn,
|
249 |
+
post_softmax_attn = post_softmax_attn
|
250 |
+
)
|
251 |
+
|
252 |
+
return out, intermediates
|
Andromeda/Andromeda/core/autoregressive_wrapper.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import ceil
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from einops import rearrange, pack, unpack
|
7 |
+
|
8 |
+
def exists(val):
|
9 |
+
return val is not None
|
10 |
+
|
11 |
+
def eval_decorator(fn):
|
12 |
+
def inner(self, *args, **kwargs):
|
13 |
+
was_training = self.training
|
14 |
+
self.eval()
|
15 |
+
out = fn(self, *args, **kwargs)
|
16 |
+
self.train(was_training)
|
17 |
+
return out
|
18 |
+
return inner
|
19 |
+
|
20 |
+
# nucleus
|
21 |
+
|
22 |
+
def top_p(logits, thres = 0.9):
|
23 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
24 |
+
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
25 |
+
|
26 |
+
sorted_indices_to_remove = cum_probs > (1 - thres)
|
27 |
+
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
28 |
+
sorted_indices_to_remove[:, 0] = 0
|
29 |
+
|
30 |
+
sorted_logits[sorted_indices_to_remove] = float('-inf')
|
31 |
+
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
|
32 |
+
|
33 |
+
# topk
|
34 |
+
|
35 |
+
def top_k(logits, thres = 0.9):
|
36 |
+
k = ceil((1 - thres) * logits.shape[-1])
|
37 |
+
val, ind = torch.topk(logits, k)
|
38 |
+
probs = torch.full_like(logits, float('-inf'))
|
39 |
+
probs.scatter_(1, ind, val)
|
40 |
+
return probs
|
41 |
+
|
42 |
+
# top_a
|
43 |
+
|
44 |
+
def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
|
45 |
+
probs = F.softmax(logits, dim=-1)
|
46 |
+
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
|
47 |
+
logits[probs < limit] = float('-inf')
|
48 |
+
logits[probs >= limit] = 1
|
49 |
+
return logits
|
50 |
+
|
51 |
+
# autoregressive wrapper class
|
52 |
+
|
53 |
+
class AutoregressiveWrapper(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
net,
|
57 |
+
ignore_index = -100,
|
58 |
+
pad_value = 0,
|
59 |
+
mask_prob = 0.
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.pad_value = pad_value
|
63 |
+
self.ignore_index = ignore_index
|
64 |
+
|
65 |
+
self.net = net
|
66 |
+
self.max_seq_len = net.max_seq_len
|
67 |
+
|
68 |
+
# paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
|
69 |
+
assert mask_prob < 1.
|
70 |
+
self.mask_prob = mask_prob
|
71 |
+
|
72 |
+
@torch.no_grad()
|
73 |
+
@eval_decorator
|
74 |
+
def generate(
|
75 |
+
self,
|
76 |
+
start_tokens,
|
77 |
+
seq_len,
|
78 |
+
eos_token = None,
|
79 |
+
temperature = 1.,
|
80 |
+
filter_logits_fn = top_k,
|
81 |
+
filter_thres = 0.9,
|
82 |
+
min_p_pow = 2.0,
|
83 |
+
min_p_ratio = 0.02,
|
84 |
+
**kwargs
|
85 |
+
):
|
86 |
+
|
87 |
+
start_tokens, ps = pack([start_tokens], '* n')
|
88 |
+
|
89 |
+
b, t = start_tokens.shape
|
90 |
+
|
91 |
+
out = start_tokens
|
92 |
+
|
93 |
+
for _ in range(seq_len):
|
94 |
+
x = out[:, -self.max_seq_len:]
|
95 |
+
|
96 |
+
logits = self.net(x, **kwargs)[:, -1]
|
97 |
+
|
98 |
+
if filter_logits_fn in {top_k, top_p}:
|
99 |
+
filtered_logits = filter_logits_fn(logits, thres = filter_thres)
|
100 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
101 |
+
|
102 |
+
elif filter_logits_fn is top_a:
|
103 |
+
filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio)
|
104 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
105 |
+
|
106 |
+
sample = torch.multinomial(probs, 1)
|
107 |
+
|
108 |
+
out = torch.cat((out, sample), dim=-1)
|
109 |
+
|
110 |
+
if exists(eos_token):
|
111 |
+
is_eos_tokens = (out == eos_token)
|
112 |
+
|
113 |
+
if is_eos_tokens.any(dim = -1).all():
|
114 |
+
# mask out everything after the eos tokens
|
115 |
+
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
|
116 |
+
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
|
117 |
+
out = out.masked_fill(mask, self.pad_value)
|
118 |
+
break
|
119 |
+
|
120 |
+
out = out[:, t:]
|
121 |
+
|
122 |
+
out, = unpack(out, ps, '* n')
|
123 |
+
|
124 |
+
return out
|
125 |
+
|
126 |
+
def forward(self, x, return_loss=True, **kwargs):
|
127 |
+
seq, ignore_index = x.shape[1], self.ignore_index
|
128 |
+
|
129 |
+
inp, target = x[:, :-1], x[:, 1:]
|
130 |
+
|
131 |
+
if self.mask_prob > 0.:
|
132 |
+
rand = torch.randn(inp.shape, device = x.device)
|
133 |
+
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
|
134 |
+
num_mask = min(int(seq * self.mask_prob), seq - 1)
|
135 |
+
indices = rand.topk(num_mask, dim = -1).indices
|
136 |
+
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
|
137 |
+
kwargs.update(self_attn_context_mask = mask)
|
138 |
+
|
139 |
+
logits = self.net(inp, **kwargs)
|
140 |
+
|
141 |
+
loss = F.cross_entropy(
|
142 |
+
rearrange(logits, 'b n c -> b c n'),
|
143 |
+
target,
|
144 |
+
ignore_index = ignore_index
|
145 |
+
)
|
146 |
+
|
147 |
+
if return_loss:
|
148 |
+
return logits, loss
|
149 |
+
|
150 |
+
return logits
|
Andromeda/Andromeda/core/flash.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import triton
|
4 |
+
import triton.language as tl
|
5 |
+
|
6 |
+
|
7 |
+
@triton.jit
|
8 |
+
def max_fn(x, y):
|
9 |
+
return tl.math.max(x, y)
|
10 |
+
|
11 |
+
|
12 |
+
@triton.jit
|
13 |
+
def _fwd_kernel(
|
14 |
+
Q, K, V, sm_scale,
|
15 |
+
L,
|
16 |
+
Out,
|
17 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
18 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
19 |
+
stride_vz, stride_vh, stride_vk, stride_vn,
|
20 |
+
stride_oz, stride_oh, stride_om, stride_on,
|
21 |
+
Z, H, N_CTX,
|
22 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
23 |
+
BLOCK_N: tl.constexpr,
|
24 |
+
IS_CAUSAL: tl.constexpr,
|
25 |
+
):
|
26 |
+
start_m = tl.program_id(0)
|
27 |
+
off_hz = tl.program_id(1)
|
28 |
+
qvk_offset = off_hz * stride_qh
|
29 |
+
Q_block_ptr = tl.make_block_ptr(
|
30 |
+
base=Q + qvk_offset,
|
31 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
32 |
+
strides=(stride_qm, stride_qk),
|
33 |
+
offsets=(start_m * BLOCK_M, 0),
|
34 |
+
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
35 |
+
order=(1, 0)
|
36 |
+
)
|
37 |
+
K_block_ptr = tl.make_block_ptr(
|
38 |
+
base=K + qvk_offset,
|
39 |
+
shape=(BLOCK_DMODEL, N_CTX),
|
40 |
+
strides=(stride_kk, stride_kn),
|
41 |
+
offsets=(0, 0),
|
42 |
+
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
43 |
+
order=(0, 1)
|
44 |
+
)
|
45 |
+
V_block_ptr = tl.make_block_ptr(
|
46 |
+
base=V + qvk_offset,
|
47 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
48 |
+
strides=(stride_vk, stride_vn),
|
49 |
+
offsets=(0, 0),
|
50 |
+
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
51 |
+
order=(1, 0)
|
52 |
+
)
|
53 |
+
# initialize offsets
|
54 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
55 |
+
offs_n = tl.arange(0, BLOCK_N)
|
56 |
+
# initialize pointer to m and l
|
57 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
58 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
59 |
+
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
60 |
+
# scale sm_scale by log_2(e) and use
|
61 |
+
# 2^x instead of exp in the loop because CSE and LICM
|
62 |
+
# don't work as expected with `exp` in the loop
|
63 |
+
qk_scale = sm_scale * 1.44269504
|
64 |
+
# load q: it will stay in SRAM throughout
|
65 |
+
q = tl.load(Q_block_ptr)
|
66 |
+
q = (q * qk_scale).to(tl.float16)
|
67 |
+
# loop over k, v and update accumulator
|
68 |
+
lo = 0
|
69 |
+
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
70 |
+
for start_n in range(lo, hi, BLOCK_N):
|
71 |
+
# -- load k, v --
|
72 |
+
k = tl.load(K_block_ptr)
|
73 |
+
v = tl.load(V_block_ptr)
|
74 |
+
# -- compute qk ---
|
75 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
76 |
+
if IS_CAUSAL:
|
77 |
+
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
78 |
+
qk += tl.dot(q, k)
|
79 |
+
# -- compute scaling constant ---
|
80 |
+
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
81 |
+
alpha = tl.math.exp2(m_i - m_i_new)
|
82 |
+
p = tl.math.exp2(qk - m_i_new[:, None])
|
83 |
+
# -- scale and update acc --
|
84 |
+
acc_scale = l_i * 0 + alpha # workaround some compiler bug
|
85 |
+
acc *= acc_scale[:, None]
|
86 |
+
acc += tl.dot(p.to(tl.float16), v)
|
87 |
+
# -- update m_i and l_i --
|
88 |
+
l_i = l_i * alpha + tl.sum(p, 1)
|
89 |
+
m_i = m_i_new
|
90 |
+
# update pointers
|
91 |
+
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
92 |
+
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
93 |
+
# write back l and m
|
94 |
+
acc = acc / l_i[:, None]
|
95 |
+
l_ptrs = L + off_hz * N_CTX + offs_m
|
96 |
+
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
97 |
+
# write back O
|
98 |
+
O_block_ptr = tl.make_block_ptr(
|
99 |
+
base=Out + qvk_offset,
|
100 |
+
shape=(N_CTX, BLOCK_DMODEL),
|
101 |
+
strides=(stride_om, stride_on),
|
102 |
+
offsets=(start_m * BLOCK_M, 0),
|
103 |
+
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
104 |
+
order=(1, 0)
|
105 |
+
)
|
106 |
+
tl.store(O_block_ptr, acc.to(tl.float16))
|
107 |
+
|
108 |
+
|
109 |
+
@triton.jit
|
110 |
+
def _bwd_preprocess(
|
111 |
+
Out, DO,
|
112 |
+
Delta,
|
113 |
+
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
114 |
+
):
|
115 |
+
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
116 |
+
off_n = tl.arange(0, D_HEAD)
|
117 |
+
# load
|
118 |
+
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
119 |
+
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
120 |
+
# compute
|
121 |
+
delta = tl.sum(o * do, axis=1)
|
122 |
+
# write-back
|
123 |
+
tl.store(Delta + off_m, delta)
|
124 |
+
|
125 |
+
|
126 |
+
@triton.jit
|
127 |
+
def _bwd_kernel(
|
128 |
+
Q, K, V, sm_scale, Out, DO,
|
129 |
+
DQ, DK, DV,
|
130 |
+
L,
|
131 |
+
D,
|
132 |
+
stride_qz, stride_qh, stride_qm, stride_qk,
|
133 |
+
stride_kz, stride_kh, stride_kn, stride_kk,
|
134 |
+
stride_vz, stride_vh, stride_vk, stride_vn,
|
135 |
+
Z, H, N_CTX,
|
136 |
+
num_block,
|
137 |
+
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
138 |
+
BLOCK_N: tl.constexpr,
|
139 |
+
CAUSAL: tl.constexpr,
|
140 |
+
):
|
141 |
+
off_hz = tl.program_id(0)
|
142 |
+
off_z = off_hz // H
|
143 |
+
off_h = off_hz % H
|
144 |
+
qk_scale = sm_scale * 1.44269504
|
145 |
+
# offset pointers for batch/head
|
146 |
+
Q += off_z * stride_qz + off_h * stride_qh
|
147 |
+
K += off_z * stride_qz + off_h * stride_qh
|
148 |
+
V += off_z * stride_qz + off_h * stride_qh
|
149 |
+
DO += off_z * stride_qz + off_h * stride_qh
|
150 |
+
DQ += off_z * stride_qz + off_h * stride_qh
|
151 |
+
DK += off_z * stride_qz + off_h * stride_qh
|
152 |
+
DV += off_z * stride_qz + off_h * stride_qh
|
153 |
+
for start_n in range(0, num_block):
|
154 |
+
if CAUSAL:
|
155 |
+
lo = start_n * BLOCK_M
|
156 |
+
else:
|
157 |
+
lo = 0
|
158 |
+
# initialize row/col offsets
|
159 |
+
offs_qm = lo + tl.arange(0, BLOCK_M)
|
160 |
+
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
161 |
+
offs_m = tl.arange(0, BLOCK_N)
|
162 |
+
offs_k = tl.arange(0, BLOCK_DMODEL)
|
163 |
+
# initialize pointers to value-like data
|
164 |
+
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
165 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
166 |
+
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
167 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
168 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
169 |
+
# pointer to row-wise quantities in value-like data
|
170 |
+
D_ptrs = D + off_hz * N_CTX
|
171 |
+
l_ptrs = L + off_hz * N_CTX
|
172 |
+
# initialize dv amd dk
|
173 |
+
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
174 |
+
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
175 |
+
# k and v stay in SRAM throughout
|
176 |
+
k = tl.load(k_ptrs)
|
177 |
+
v = tl.load(v_ptrs)
|
178 |
+
# loop over rows
|
179 |
+
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
180 |
+
offs_m_curr = start_m + offs_m
|
181 |
+
# load q, k, v, do on-chip
|
182 |
+
q = tl.load(q_ptrs)
|
183 |
+
# recompute p = softmax(qk, dim=-1).T
|
184 |
+
if CAUSAL:
|
185 |
+
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
186 |
+
else:
|
187 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
188 |
+
qk += tl.dot(q, tl.trans(k))
|
189 |
+
qk *= qk_scale
|
190 |
+
l_i = tl.load(l_ptrs + offs_m_curr)
|
191 |
+
p = tl.math.exp2(qk - l_i[:, None])
|
192 |
+
# compute dv
|
193 |
+
do = tl.load(do_ptrs)
|
194 |
+
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
195 |
+
# compute dp = dot(v, do)
|
196 |
+
Di = tl.load(D_ptrs + offs_m_curr)
|
197 |
+
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
198 |
+
dp += tl.dot(do, tl.trans(v))
|
199 |
+
# compute ds = p * (dp - delta[:, None])
|
200 |
+
ds = p * dp * sm_scale
|
201 |
+
# compute dk = dot(ds.T, q)
|
202 |
+
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
|
203 |
+
# compute dq
|
204 |
+
dq = tl.load(dq_ptrs)
|
205 |
+
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
|
206 |
+
tl.store(dq_ptrs, dq)
|
207 |
+
# increment pointers
|
208 |
+
dq_ptrs += BLOCK_M * stride_qm
|
209 |
+
q_ptrs += BLOCK_M * stride_qm
|
210 |
+
do_ptrs += BLOCK_M * stride_qm
|
211 |
+
# write-back
|
212 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
213 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
214 |
+
tl.store(dv_ptrs, dv)
|
215 |
+
tl.store(dk_ptrs, dk)
|
216 |
+
|
217 |
+
|
218 |
+
empty = torch.empty(128, device="cuda")
|
219 |
+
|
220 |
+
|
221 |
+
class _attention(torch.autograd.Function):
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal, sm_scale):
|
225 |
+
# shape constraints
|
226 |
+
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
227 |
+
assert Lq == Lk and Lk == Lv
|
228 |
+
assert Lk in {16, 32, 64, 128}
|
229 |
+
o = torch.empty_like(q)
|
230 |
+
BLOCK_M = 128
|
231 |
+
BLOCK_N = 64
|
232 |
+
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
233 |
+
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
234 |
+
|
235 |
+
num_warps = 4 if Lk <= 64 else 8
|
236 |
+
_fwd_kernel[grid](
|
237 |
+
q, k, v, sm_scale,
|
238 |
+
L,
|
239 |
+
o,
|
240 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
241 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
242 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
243 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
244 |
+
q.shape[0], q.shape[1], q.shape[2],
|
245 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
246 |
+
IS_CAUSAL=causal,
|
247 |
+
num_warps=num_warps,
|
248 |
+
num_stages=4)
|
249 |
+
|
250 |
+
ctx.save_for_backward(q, k, v, o, L)
|
251 |
+
ctx.grid = grid
|
252 |
+
ctx.sm_scale = sm_scale
|
253 |
+
ctx.BLOCK_DMODEL = Lk
|
254 |
+
ctx.causal = causal
|
255 |
+
return o
|
256 |
+
|
257 |
+
@staticmethod
|
258 |
+
def backward(ctx, do):
|
259 |
+
BLOCK = 128
|
260 |
+
q, k, v, o, L = ctx.saved_tensors
|
261 |
+
do = do.contiguous()
|
262 |
+
dq = torch.zeros_like(q, dtype=torch.float32)
|
263 |
+
dk = torch.empty_like(k)
|
264 |
+
dv = torch.empty_like(v)
|
265 |
+
delta = torch.empty_like(L)
|
266 |
+
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
267 |
+
o, do,
|
268 |
+
delta,
|
269 |
+
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
270 |
+
)
|
271 |
+
_bwd_kernel[(ctx.grid[1],)](
|
272 |
+
q, k, v, ctx.sm_scale,
|
273 |
+
o, do,
|
274 |
+
dq, dk, dv,
|
275 |
+
L, delta,
|
276 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
277 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
278 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
279 |
+
q.shape[0], q.shape[1], q.shape[2],
|
280 |
+
ctx.grid[0],
|
281 |
+
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
282 |
+
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
283 |
+
CAUSAL=ctx.causal,
|
284 |
+
num_stages=1,
|
285 |
+
)
|
286 |
+
return dq, dk, dv, None, None
|
287 |
+
|
288 |
+
|
289 |
+
attention = _attention.apply
|
Andromeda/Andromeda/core/transformer.py
ADDED
@@ -0,0 +1,1376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from random import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, einsum, Tensor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from functools import partial, wraps
|
9 |
+
from inspect import isfunction
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from typing import List
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from Andromeda.core.attend import Attend, Intermediates
|
16 |
+
from Andromeda.core.autoregressive_wrapper import AutoregressiveWrapper
|
17 |
+
|
18 |
+
from abc import ABC, abstractmethod
|
19 |
+
# import bitsandbytes as bnb
|
20 |
+
|
21 |
+
# constants
|
22 |
+
|
23 |
+
DEFAULT_DIM_HEAD = 64
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class LayerIntermediates:
|
27 |
+
hiddens: List[Tensor] = None
|
28 |
+
attn_intermediates: List[Intermediates] = None
|
29 |
+
|
30 |
+
# helpers
|
31 |
+
|
32 |
+
def exists(val):
|
33 |
+
return val is not None
|
34 |
+
|
35 |
+
def default(val, d):
|
36 |
+
if exists(val):
|
37 |
+
return val
|
38 |
+
return d() if isfunction(d) else d
|
39 |
+
|
40 |
+
def cast_tuple(val, depth):
|
41 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
42 |
+
|
43 |
+
def maybe(fn):
|
44 |
+
@wraps(fn)
|
45 |
+
def inner(x, *args, **kwargs):
|
46 |
+
if not exists(x):
|
47 |
+
return x
|
48 |
+
return fn(x, *args, **kwargs)
|
49 |
+
return inner
|
50 |
+
|
51 |
+
class always():
|
52 |
+
def __init__(self, val):
|
53 |
+
self.val = val
|
54 |
+
def __call__(self, *args, **kwargs):
|
55 |
+
return self.val
|
56 |
+
|
57 |
+
class not_equals():
|
58 |
+
def __init__(self, val):
|
59 |
+
self.val = val
|
60 |
+
def __call__(self, x, *args, **kwargs):
|
61 |
+
return x != self.val
|
62 |
+
|
63 |
+
class equals():
|
64 |
+
def __init__(self, val):
|
65 |
+
self.val = val
|
66 |
+
def __call__(self, x, *args, **kwargs):
|
67 |
+
return x == self.val
|
68 |
+
|
69 |
+
# tensor helpers
|
70 |
+
|
71 |
+
def max_neg_value(tensor):
|
72 |
+
return -torch.finfo(tensor.dtype).max
|
73 |
+
|
74 |
+
def l2norm(t, groups = 1):
|
75 |
+
t = rearrange(t, '... (g d) -> ... g d', g = groups)
|
76 |
+
t = F.normalize(t, p = 2, dim = -1)
|
77 |
+
return rearrange(t, '... g d -> ... (g d)')
|
78 |
+
|
79 |
+
def pad_at_dim(t, pad, dim = -1, value = 0.):
|
80 |
+
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
81 |
+
zeros = ((0, 0) * dims_from_right)
|
82 |
+
return F.pad(t, (*zeros, *pad), value = value)
|
83 |
+
|
84 |
+
def or_reduce(masks):
|
85 |
+
head, *body = masks
|
86 |
+
for rest in body:
|
87 |
+
head = head | rest
|
88 |
+
return head
|
89 |
+
|
90 |
+
# init helpers
|
91 |
+
|
92 |
+
def init_zero_(layer):
|
93 |
+
nn.init.constant_(layer.weight, 0.)
|
94 |
+
if exists(layer.bias):
|
95 |
+
nn.init.constant_(layer.bias, 0.)
|
96 |
+
|
97 |
+
# keyword argument helpers
|
98 |
+
|
99 |
+
def pick_and_pop(keys, d):
|
100 |
+
values = list(map(lambda key: d.pop(key), keys))
|
101 |
+
return dict(zip(keys, values))
|
102 |
+
|
103 |
+
def group_dict_by_key(cond, d):
|
104 |
+
return_val = [dict(),dict()]
|
105 |
+
for key in d.keys():
|
106 |
+
match = bool(cond(key))
|
107 |
+
ind = int(not match)
|
108 |
+
return_val[ind][key] = d[key]
|
109 |
+
return (*return_val,)
|
110 |
+
|
111 |
+
def string_begins_with(prefix, str):
|
112 |
+
return str.startswith(prefix)
|
113 |
+
|
114 |
+
def group_by_key_prefix(prefix, d):
|
115 |
+
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
116 |
+
|
117 |
+
def groupby_prefix_and_trim(prefix, d):
|
118 |
+
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
119 |
+
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
120 |
+
return kwargs_without_prefix, kwargs
|
121 |
+
|
122 |
+
# initializations
|
123 |
+
|
124 |
+
def deepnorm_init(
|
125 |
+
transformer,
|
126 |
+
beta,
|
127 |
+
module_name_match_list = ['.ff.', '.to_v', '.to_out']
|
128 |
+
):
|
129 |
+
for name, module in transformer.named_modules():
|
130 |
+
if type(module) != nn.Linear:
|
131 |
+
continue
|
132 |
+
|
133 |
+
needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
|
134 |
+
gain = beta if needs_beta_gain else 1
|
135 |
+
nn.init.xavier_normal_(module.weight.data, gain = gain)
|
136 |
+
|
137 |
+
if exists(module.bias):
|
138 |
+
nn.init.constant_(module.bias.data, 0)
|
139 |
+
|
140 |
+
# structured dropout, more effective than traditional attention dropouts
|
141 |
+
|
142 |
+
def dropout_seq(seq, mask, dropout):
|
143 |
+
b, n, *_, device = *seq.shape, seq.device
|
144 |
+
logits = torch.randn(b, n, device = device)
|
145 |
+
|
146 |
+
if exists(mask):
|
147 |
+
mask_value = max_neg_value(logits)
|
148 |
+
logits = logits.masked_fill(~mask, mask_value)
|
149 |
+
|
150 |
+
keep_prob = 1. - dropout
|
151 |
+
num_keep = max(1, int(keep_prob * n))
|
152 |
+
keep_indices = logits.topk(num_keep, dim = 1).indices
|
153 |
+
|
154 |
+
batch_indices = torch.arange(b, device = device)
|
155 |
+
batch_indices = rearrange(batch_indices, 'b -> b 1')
|
156 |
+
|
157 |
+
seq = seq[batch_indices, keep_indices]
|
158 |
+
|
159 |
+
if exists(mask):
|
160 |
+
seq_counts = mask.sum(dim = -1)
|
161 |
+
seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
|
162 |
+
keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
|
163 |
+
|
164 |
+
mask = mask[batch_indices, keep_indices] & keep_mask
|
165 |
+
|
166 |
+
return seq, mask
|
167 |
+
|
168 |
+
# activations
|
169 |
+
|
170 |
+
class ReluSquared(nn.Module):
|
171 |
+
def forward(self, x):
|
172 |
+
return F.relu(x) ** 2
|
173 |
+
|
174 |
+
|
175 |
+
#tokenization
|
176 |
+
class BaseTokenizer(ABC):
|
177 |
+
@abstractmethod
|
178 |
+
def tokenize(self, text: str) -> List[int]:
|
179 |
+
pass
|
180 |
+
|
181 |
+
class CustomTokenizer(BaseTokenizer):
|
182 |
+
def tokenize(self, text: str) -> List[int]:
|
183 |
+
# Your custom tokenization algorithm
|
184 |
+
tokens = ...
|
185 |
+
return tokens
|
186 |
+
|
187 |
+
# embedding
|
188 |
+
|
189 |
+
class BaseEmbedding(ABC):
|
190 |
+
@abstractmethod
|
191 |
+
def get_embedding(self, num_tokens: int, dim: int) -> nn.Module:
|
192 |
+
# Custom embedding function or model
|
193 |
+
embedding = ...
|
194 |
+
|
195 |
+
return embedding
|
196 |
+
|
197 |
+
class AndromedaEmbedding(BaseEmbedding):
|
198 |
+
def get_embedding(self, num_tokens: int, dim: int) -> nn.Module:
|
199 |
+
embedding = nn.Embedding(num_tokens, dim)
|
200 |
+
|
201 |
+
return embedding
|
202 |
+
|
203 |
+
# class AndromedaBnBEmbedding(BaseEmbedding):
|
204 |
+
# def get_embedding(self, num_tokens: int, dim: int, padding_idx: int = 0) -> bnb.nn.modules:
|
205 |
+
# embedding = bnb.nn.modules.Embedding(num_tokens, dim, padding_idx)
|
206 |
+
|
207 |
+
# return embedding
|
208 |
+
|
209 |
+
class TokenEmbedding(nn.Module):
|
210 |
+
def __init__(self, dim, num_tokens, embedding_provider: BaseEmbedding, l2norm_embed = False):
|
211 |
+
super().__init__()
|
212 |
+
self.l2norm_embed = l2norm_embed
|
213 |
+
self.emb = embedding_provider.get_embedding(num_tokens, dim)
|
214 |
+
# nn.Embedding(num_tokens, dim)
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
token_emb = self.emb(x)
|
218 |
+
return l2norm(token_emb) if self.l2norm_embed else token_emb
|
219 |
+
|
220 |
+
# positional embeddings
|
221 |
+
|
222 |
+
class AbsolutePositionalEmbedding(nn.Module):
|
223 |
+
def __init__(self, dim, max_seq_len, l2norm_embed = False):
|
224 |
+
super().__init__()
|
225 |
+
self.scale = dim ** -0.5 if not l2norm_embed else 1.
|
226 |
+
self.max_seq_len = max_seq_len
|
227 |
+
self.l2norm_embed = l2norm_embed
|
228 |
+
self.emb = nn.Embedding(max_seq_len, dim)
|
229 |
+
|
230 |
+
def forward(self, x, pos = None):
|
231 |
+
seq_len, device = x.shape[1], x.device
|
232 |
+
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
233 |
+
|
234 |
+
if not exists(pos):
|
235 |
+
pos = torch.arange(seq_len, device = device)
|
236 |
+
|
237 |
+
pos_emb = self.emb(pos)
|
238 |
+
pos_emb = pos_emb * self.scale
|
239 |
+
return l2norm(pos_emb) if self.l2norm_embed else pos_emb
|
240 |
+
|
241 |
+
class ScaledSinusoidalEmbedding(nn.Module):
|
242 |
+
def __init__(self, dim, theta = 10000):
|
243 |
+
super().__init__()
|
244 |
+
assert (dim % 2) == 0
|
245 |
+
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
246 |
+
|
247 |
+
half_dim = dim // 2
|
248 |
+
freq_seq = torch.arange(half_dim).float() / half_dim
|
249 |
+
inv_freq = theta ** -freq_seq
|
250 |
+
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
251 |
+
|
252 |
+
def forward(self, x, pos = None):
|
253 |
+
seq_len, device = x.shape[1], x.device
|
254 |
+
|
255 |
+
if not exists(pos):
|
256 |
+
pos = torch.arange(seq_len, device = device)
|
257 |
+
|
258 |
+
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
259 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
260 |
+
return emb * self.scale
|
261 |
+
|
262 |
+
class RelativePositionBias(nn.Module):
|
263 |
+
def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
|
264 |
+
super().__init__()
|
265 |
+
self.scale = scale
|
266 |
+
self.causal = causal
|
267 |
+
self.num_buckets = num_buckets
|
268 |
+
self.max_distance = max_distance
|
269 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
270 |
+
|
271 |
+
@staticmethod
|
272 |
+
def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
|
273 |
+
ret = 0
|
274 |
+
n = -relative_position
|
275 |
+
if not causal:
|
276 |
+
num_buckets //= 2
|
277 |
+
ret += (n < 0).long() * num_buckets
|
278 |
+
n = torch.abs(n)
|
279 |
+
else:
|
280 |
+
n = torch.max(n, torch.zeros_like(n))
|
281 |
+
|
282 |
+
max_exact = num_buckets // 2
|
283 |
+
is_small = n < max_exact
|
284 |
+
|
285 |
+
val_if_large = max_exact + (
|
286 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
287 |
+
).long()
|
288 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
289 |
+
|
290 |
+
ret += torch.where(is_small, n, val_if_large)
|
291 |
+
return ret
|
292 |
+
|
293 |
+
@property
|
294 |
+
def device(self):
|
295 |
+
return next(self.parameters()).device
|
296 |
+
|
297 |
+
def forward(self, i, j):
|
298 |
+
device = self.device
|
299 |
+
q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
|
300 |
+
k_pos = torch.arange(j, dtype = torch.long, device = device)
|
301 |
+
rel_pos = k_pos[None, :] - q_pos[:, None]
|
302 |
+
rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
303 |
+
values = self.relative_attention_bias(rp_bucket)
|
304 |
+
bias = rearrange(values, 'i j h -> h i j')
|
305 |
+
return bias * self.scale
|
306 |
+
|
307 |
+
class DynamicPositionBias(nn.Module):
|
308 |
+
def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
|
309 |
+
super().__init__()
|
310 |
+
assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
|
311 |
+
self.log_distance = log_distance
|
312 |
+
|
313 |
+
self.mlp = nn.ModuleList([])
|
314 |
+
|
315 |
+
self.mlp.append(nn.Sequential(
|
316 |
+
nn.Linear(1, dim),
|
317 |
+
nn.LayerNorm(dim) if norm else nn.Identity(),
|
318 |
+
nn.SiLU()
|
319 |
+
))
|
320 |
+
|
321 |
+
for _ in range(depth - 1):
|
322 |
+
self.mlp.append(nn.Sequential(
|
323 |
+
nn.Linear(dim, dim),
|
324 |
+
nn.LayerNorm(dim) if norm else nn.Identity(),
|
325 |
+
nn.SiLU()
|
326 |
+
))
|
327 |
+
|
328 |
+
self.mlp.append(nn.Linear(dim, heads))
|
329 |
+
|
330 |
+
@property
|
331 |
+
def device(self):
|
332 |
+
return next(self.parameters()).device
|
333 |
+
|
334 |
+
def forward(self, i, j):
|
335 |
+
assert i == j
|
336 |
+
n, device = j, self.device
|
337 |
+
|
338 |
+
# get the (n x n) matrix of distances
|
339 |
+
seq_arange = torch.arange(n, device = device)
|
340 |
+
context_arange = torch.arange(n, device = device)
|
341 |
+
indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
|
342 |
+
indices += (n - 1)
|
343 |
+
|
344 |
+
# input to continuous positions MLP
|
345 |
+
pos = torch.arange(-n + 1, n, device = device).float()
|
346 |
+
pos = rearrange(pos, '... -> ... 1')
|
347 |
+
|
348 |
+
if self.log_distance:
|
349 |
+
pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
|
350 |
+
|
351 |
+
for layer in self.mlp:
|
352 |
+
pos = layer(pos)
|
353 |
+
|
354 |
+
# get position biases
|
355 |
+
bias = pos[indices]
|
356 |
+
bias = rearrange(bias, 'i j h -> h i j')
|
357 |
+
return bias
|
358 |
+
|
359 |
+
class AlibiPositionalBias(nn.Module):
|
360 |
+
def __init__(self, heads, total_heads, **kwargs):
|
361 |
+
super().__init__()
|
362 |
+
self.heads = heads
|
363 |
+
self.total_heads = total_heads
|
364 |
+
|
365 |
+
slopes = Tensor(self._get_slopes(heads))
|
366 |
+
slopes = rearrange(slopes, 'h -> h 1 1')
|
367 |
+
self.register_buffer('slopes', slopes, persistent = False)
|
368 |
+
self.register_buffer('bias', None, persistent = False)
|
369 |
+
|
370 |
+
def get_bias(self, i, j, device):
|
371 |
+
i_arange = torch.arange(j - i, j, device = device)
|
372 |
+
j_arange = torch.arange(j, device = device)
|
373 |
+
bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
|
374 |
+
return bias
|
375 |
+
|
376 |
+
@staticmethod
|
377 |
+
def _get_slopes(heads):
|
378 |
+
def get_slopes_power_of_2(n):
|
379 |
+
start = (2**(-2**-(math.log2(n)-3)))
|
380 |
+
ratio = start
|
381 |
+
return [start*ratio**i for i in range(n)]
|
382 |
+
|
383 |
+
if math.log2(heads).is_integer():
|
384 |
+
return get_slopes_power_of_2(heads)
|
385 |
+
|
386 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
|
387 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
|
388 |
+
|
389 |
+
@property
|
390 |
+
def device(self):
|
391 |
+
return next(self.buffers()).device
|
392 |
+
|
393 |
+
def forward(self, i, j):
|
394 |
+
h, device = self.total_heads, self.device
|
395 |
+
|
396 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
397 |
+
return self.bias[..., :i, :j]
|
398 |
+
|
399 |
+
bias = self.get_bias(i, j, device)
|
400 |
+
bias = bias * self.slopes
|
401 |
+
|
402 |
+
num_heads_unalibied = h - bias.shape[0]
|
403 |
+
bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
|
404 |
+
self.register_buffer('bias', bias, persistent = False)
|
405 |
+
|
406 |
+
return self.bias
|
407 |
+
|
408 |
+
class LearnedAlibiPositionalBias(AlibiPositionalBias):
|
409 |
+
def __init__(self, heads, total_heads):
|
410 |
+
super().__init__(heads, total_heads)
|
411 |
+
log_slopes = torch.log(self.slopes)
|
412 |
+
self.learned_logslopes = nn.Parameter(log_slopes)
|
413 |
+
|
414 |
+
def forward(self, i, j):
|
415 |
+
h, i, j, device = self.heads, self.device
|
416 |
+
|
417 |
+
def get_slopes(param):
|
418 |
+
return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
|
419 |
+
|
420 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
421 |
+
bias = self.bias[..., :i, :j]
|
422 |
+
else:
|
423 |
+
bias = self.get_bias(i, j, device)
|
424 |
+
self.register_buffer('bias', bias, persistent = False)
|
425 |
+
|
426 |
+
slopes = get_slopes(self.learned_logslopes)
|
427 |
+
bias = bias * slopes
|
428 |
+
|
429 |
+
return bias
|
430 |
+
|
431 |
+
class RotaryEmbedding(nn.Module):
|
432 |
+
def __init__(
|
433 |
+
self,
|
434 |
+
dim,
|
435 |
+
use_xpos = False,
|
436 |
+
scale_base = 512,
|
437 |
+
interpolation_factor=1.,
|
438 |
+
base=10000,
|
439 |
+
base_rescale_factor=1.
|
440 |
+
):
|
441 |
+
super().__init__()
|
442 |
+
base *= base_rescale_factor ** (dim / (dim - 2))
|
443 |
+
|
444 |
+
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
445 |
+
|
446 |
+
self.register_buffer('inv_freq', inv_freq)
|
447 |
+
|
448 |
+
if not use_xpos:
|
449 |
+
self.register_buffer('scale', None)
|
450 |
+
return
|
451 |
+
|
452 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
453 |
+
|
454 |
+
self.scale_base = scale_base
|
455 |
+
self.register_buffer('scale', scale)
|
456 |
+
|
457 |
+
def forward(self, seq_len, device):
|
458 |
+
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
|
459 |
+
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
460 |
+
freqs = torch.cat((freqs, freqs), dim = -1)
|
461 |
+
|
462 |
+
if not exists(self.scale):
|
463 |
+
return freqs, 1.
|
464 |
+
|
465 |
+
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
466 |
+
scale = self.scale ** rearrange(power, 'n -> n 1')
|
467 |
+
scale = torch.cat((scale, scale), dim = -1)
|
468 |
+
|
469 |
+
return freqs, scale
|
470 |
+
|
471 |
+
|
472 |
+
def rotate_half(x):
|
473 |
+
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
474 |
+
x1, x2 = x.unbind(dim = -2)
|
475 |
+
return torch.cat((-x2, x1), dim = -1)
|
476 |
+
|
477 |
+
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
478 |
+
seq_len = t.shape[-2]
|
479 |
+
freqs = freqs[-seq_len:, :]
|
480 |
+
return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
481 |
+
|
482 |
+
# norms
|
483 |
+
|
484 |
+
class Scale(nn.Module):
|
485 |
+
def __init__(self, value, fn):
|
486 |
+
super().__init__()
|
487 |
+
self.value = value
|
488 |
+
self.fn = fn
|
489 |
+
|
490 |
+
def forward(self, x, **kwargs):
|
491 |
+
out = self.fn(x, **kwargs)
|
492 |
+
def scale_fn(t):
|
493 |
+
return t * self.value
|
494 |
+
|
495 |
+
if not isinstance(out, tuple):
|
496 |
+
return scale_fn(out)
|
497 |
+
|
498 |
+
return (scale_fn(out[0]), *out[1:])
|
499 |
+
|
500 |
+
class ScaleNorm(nn.Module):
|
501 |
+
def __init__(self, dim, eps = 1e-5):
|
502 |
+
super().__init__()
|
503 |
+
self.eps = eps
|
504 |
+
self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
|
505 |
+
|
506 |
+
def forward(self, x):
|
507 |
+
norm = torch.norm(x, dim = -1, keepdim = True)
|
508 |
+
return x / norm.clamp(min = self.eps) * self.g
|
509 |
+
|
510 |
+
class RMSNorm(nn.Module):
|
511 |
+
def __init__(self, dim, eps = 1e-8):
|
512 |
+
super().__init__()
|
513 |
+
self.scale = dim ** -0.5
|
514 |
+
self.eps = eps
|
515 |
+
self.g = nn.Parameter(torch.ones(dim))
|
516 |
+
|
517 |
+
def forward(self, x):
|
518 |
+
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
519 |
+
return x / norm.clamp(min = self.eps) * self.g
|
520 |
+
|
521 |
+
# residual and residual gates
|
522 |
+
|
523 |
+
class Residual(nn.Module):
|
524 |
+
def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
|
525 |
+
super().__init__()
|
526 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
527 |
+
self.scale_residual_constant = scale_residual_constant
|
528 |
+
|
529 |
+
def forward(self, x, residual):
|
530 |
+
if exists(self.residual_scale):
|
531 |
+
residual = residual * self.residual_scale
|
532 |
+
|
533 |
+
if self.scale_residual_constant != 1:
|
534 |
+
residual = residual * self.scale_residual_constant
|
535 |
+
|
536 |
+
return x + residual
|
537 |
+
|
538 |
+
class GRUGating(nn.Module):
|
539 |
+
def __init__(self, dim, scale_residual = False, **kwargs):
|
540 |
+
super().__init__()
|
541 |
+
self.gru = nn.GRUCell(dim, dim)
|
542 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
543 |
+
|
544 |
+
def forward(self, x, residual):
|
545 |
+
if exists(self.residual_scale):
|
546 |
+
residual = residual * self.residual_scale
|
547 |
+
|
548 |
+
gated_output = self.gru(
|
549 |
+
rearrange(x, 'b n d -> (b n) d'),
|
550 |
+
rearrange(residual, 'b n d -> (b n) d')
|
551 |
+
)
|
552 |
+
|
553 |
+
return gated_output.reshape_as(x)
|
554 |
+
|
555 |
+
# token shifting
|
556 |
+
|
557 |
+
def shift(t, amount, mask = None):
|
558 |
+
if amount == 0:
|
559 |
+
return t
|
560 |
+
else:
|
561 |
+
amount = min(amount, t.shape[1])
|
562 |
+
|
563 |
+
if exists(mask):
|
564 |
+
t = t.masked_fill(~mask[..., None], 0.)
|
565 |
+
|
566 |
+
return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
|
567 |
+
|
568 |
+
class ShiftTokens(nn.Module):
|
569 |
+
def __init__(self, shifts, fn):
|
570 |
+
super().__init__()
|
571 |
+
self.fn = fn
|
572 |
+
self.shifts = tuple(shifts)
|
573 |
+
|
574 |
+
def forward(self, x, **kwargs):
|
575 |
+
mask = kwargs.get('mask', None)
|
576 |
+
shifts = self.shifts
|
577 |
+
segments = len(shifts)
|
578 |
+
feats_per_shift = x.shape[-1] // segments
|
579 |
+
splitted = x.split(feats_per_shift, dim = -1)
|
580 |
+
segments_to_shift, rest = splitted[:segments], splitted[segments:]
|
581 |
+
segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
|
582 |
+
x = torch.cat((*segments_to_shift, *rest), dim = -1)
|
583 |
+
return self.fn(x, **kwargs)
|
584 |
+
|
585 |
+
# feedforward
|
586 |
+
|
587 |
+
class GLU(nn.Module):
|
588 |
+
def __init__(self, dim_in, dim_out, activation):
|
589 |
+
super().__init__()
|
590 |
+
self.act = activation
|
591 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
592 |
+
|
593 |
+
def forward(self, x):
|
594 |
+
x, gate = self.proj(x).chunk(2, dim = -1)
|
595 |
+
return x * self.act(gate)
|
596 |
+
|
597 |
+
class FeedForward(nn.Module):
|
598 |
+
def __init__(
|
599 |
+
self,
|
600 |
+
dim,
|
601 |
+
dim_out = None,
|
602 |
+
mult = 4,
|
603 |
+
glu = False,
|
604 |
+
swish = False,
|
605 |
+
relu_squared = False,
|
606 |
+
post_act_ln = False,
|
607 |
+
dropout = 0.,
|
608 |
+
no_bias = False,
|
609 |
+
zero_init_output = False
|
610 |
+
):
|
611 |
+
super().__init__()
|
612 |
+
inner_dim = int(dim * mult)
|
613 |
+
dim_out = default(dim_out, dim)
|
614 |
+
|
615 |
+
if relu_squared:
|
616 |
+
activation = ReluSquared()
|
617 |
+
elif swish:
|
618 |
+
activation = nn.SiLU()
|
619 |
+
else:
|
620 |
+
activation = nn.GELU()
|
621 |
+
|
622 |
+
project_in = nn.Sequential(
|
623 |
+
nn.Linear(dim, inner_dim, bias = not no_bias),
|
624 |
+
activation
|
625 |
+
) if not glu else GLU(dim, inner_dim, activation)
|
626 |
+
|
627 |
+
self.ff = nn.Sequential(
|
628 |
+
project_in,
|
629 |
+
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
|
630 |
+
nn.Dropout(dropout),
|
631 |
+
nn.Linear(inner_dim, dim_out, bias = not no_bias)
|
632 |
+
)
|
633 |
+
|
634 |
+
# init last linear layer to 0
|
635 |
+
if zero_init_output:
|
636 |
+
init_zero_(self.ff[-1])
|
637 |
+
|
638 |
+
def forward(self, x):
|
639 |
+
return self.ff(x)
|
640 |
+
|
641 |
+
# attention. it is all we need
|
642 |
+
|
643 |
+
class Attention(nn.Module):
|
644 |
+
def __init__(
|
645 |
+
self,
|
646 |
+
dim,
|
647 |
+
dim_head = DEFAULT_DIM_HEAD,
|
648 |
+
heads = 8,
|
649 |
+
causal = False,
|
650 |
+
flash = False,
|
651 |
+
talking_heads = False,
|
652 |
+
head_scale = False,
|
653 |
+
sparse_topk = None,
|
654 |
+
num_mem_kv = 0,
|
655 |
+
dropout = 0.,
|
656 |
+
on_attn = False,
|
657 |
+
gate_values = False,
|
658 |
+
zero_init_output = False,
|
659 |
+
max_attend_past = None,
|
660 |
+
qk_norm = False,
|
661 |
+
qk_norm_groups = 1,
|
662 |
+
qk_norm_scale = 10,
|
663 |
+
qk_norm_dim_scale = False,
|
664 |
+
one_kv_head = False,
|
665 |
+
shared_kv = False,
|
666 |
+
value_dim_head = None,
|
667 |
+
tensor_product = False # https://arxiv.org/abs/2208.06061
|
668 |
+
):
|
669 |
+
super().__init__()
|
670 |
+
self.scale = dim_head ** -0.5
|
671 |
+
|
672 |
+
self.heads = heads
|
673 |
+
self.causal = causal
|
674 |
+
self.max_attend_past = max_attend_past
|
675 |
+
|
676 |
+
value_dim_head = default(value_dim_head, dim_head)
|
677 |
+
q_dim = k_dim = dim_head * heads
|
678 |
+
v_dim = out_dim = value_dim_head * heads
|
679 |
+
|
680 |
+
self.one_kv_head = one_kv_head
|
681 |
+
if one_kv_head:
|
682 |
+
k_dim = dim_head
|
683 |
+
v_dim = value_dim_head
|
684 |
+
out_dim = v_dim * heads
|
685 |
+
|
686 |
+
self.to_q = nn.Linear(dim, q_dim, bias = False)
|
687 |
+
self.to_k = nn.Linear(dim, k_dim, bias = False)
|
688 |
+
|
689 |
+
# shared key / values, for further memory savings during inference
|
690 |
+
assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
|
691 |
+
self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
|
692 |
+
|
693 |
+
# relations projection from tp-attention
|
694 |
+
self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
|
695 |
+
|
696 |
+
# add GLU gating for aggregated values, from alphafold2
|
697 |
+
self.to_v_gate = None
|
698 |
+
if gate_values:
|
699 |
+
self.to_v_gate = nn.Linear(dim, out_dim)
|
700 |
+
nn.init.constant_(self.to_v_gate.weight, 0)
|
701 |
+
nn.init.constant_(self.to_v_gate.bias, 1)
|
702 |
+
|
703 |
+
# cosine sim attention
|
704 |
+
self.qk_norm = qk_norm
|
705 |
+
self.qk_norm_groups = qk_norm_groups
|
706 |
+
self.qk_norm_scale = qk_norm_scale
|
707 |
+
|
708 |
+
# whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
|
709 |
+
self.qk_norm_dim_scale = qk_norm_dim_scale
|
710 |
+
|
711 |
+
self.qk_norm_q_scale = self.qk_norm_k_scale = 1
|
712 |
+
if qk_norm and qk_norm_dim_scale:
|
713 |
+
self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head))
|
714 |
+
self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head))
|
715 |
+
|
716 |
+
assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups'
|
717 |
+
assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
|
718 |
+
|
719 |
+
# attend class - includes core attention algorithm + talking heads
|
720 |
+
|
721 |
+
self.attend = Attend(
|
722 |
+
heads = heads,
|
723 |
+
causal = causal,
|
724 |
+
talking_heads = talking_heads,
|
725 |
+
dropout = dropout,
|
726 |
+
qk_norm = qk_norm,
|
727 |
+
scale = qk_norm_scale if qk_norm else self.scale,
|
728 |
+
flash = flash
|
729 |
+
)
|
730 |
+
|
731 |
+
# head scaling
|
732 |
+
self.head_scale = head_scale
|
733 |
+
if head_scale:
|
734 |
+
self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
|
735 |
+
|
736 |
+
# explicit topk sparse attention
|
737 |
+
self.sparse_topk = sparse_topk
|
738 |
+
|
739 |
+
# add memory key / values
|
740 |
+
self.num_mem_kv = num_mem_kv
|
741 |
+
if num_mem_kv > 0:
|
742 |
+
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
743 |
+
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
744 |
+
|
745 |
+
# attention on attention
|
746 |
+
self.attn_on_attn = on_attn
|
747 |
+
self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
|
748 |
+
|
749 |
+
# init output projection 0
|
750 |
+
if zero_init_output:
|
751 |
+
init_zero_(self.to_out)
|
752 |
+
|
753 |
+
def forward(
|
754 |
+
self,
|
755 |
+
x,
|
756 |
+
context = None,
|
757 |
+
mask = None,
|
758 |
+
context_mask = None,
|
759 |
+
attn_mask = None,
|
760 |
+
rel_pos = None,
|
761 |
+
rotary_pos_emb = None,
|
762 |
+
prev_attn = None,
|
763 |
+
mem = None
|
764 |
+
):
|
765 |
+
b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context)
|
766 |
+
kv_input = default(context, x)
|
767 |
+
|
768 |
+
q_input = x
|
769 |
+
k_input = kv_input
|
770 |
+
v_input = kv_input
|
771 |
+
r_input = x
|
772 |
+
|
773 |
+
if exists(mem):
|
774 |
+
k_input = torch.cat((mem, k_input), dim = -2)
|
775 |
+
v_input = torch.cat((mem, v_input), dim = -2)
|
776 |
+
|
777 |
+
q = self.to_q(q_input)
|
778 |
+
k = self.to_k(k_input)
|
779 |
+
v = self.to_v(v_input) if exists(self.to_v) else k
|
780 |
+
r = self.to_r(r_input) if exists(self.to_r) else None
|
781 |
+
|
782 |
+
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
783 |
+
|
784 |
+
if not self.one_kv_head:
|
785 |
+
k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))
|
786 |
+
|
787 |
+
if self.qk_norm:
|
788 |
+
qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
|
789 |
+
q, k = map(qk_l2norm, (q, k))
|
790 |
+
|
791 |
+
q = q * self.qk_norm_q_scale
|
792 |
+
k = k * self.qk_norm_k_scale
|
793 |
+
|
794 |
+
if exists(rotary_pos_emb) and not has_context:
|
795 |
+
freqs, xpos_scale = rotary_pos_emb
|
796 |
+
l = freqs.shape[-1]
|
797 |
+
|
798 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
|
799 |
+
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
|
800 |
+
|
801 |
+
ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)))
|
802 |
+
q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr)))
|
803 |
+
|
804 |
+
input_mask = default(context_mask, mask)
|
805 |
+
|
806 |
+
if self.num_mem_kv > 0:
|
807 |
+
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
|
808 |
+
|
809 |
+
if self.qk_norm:
|
810 |
+
mem_k = l2norm(mem_k)
|
811 |
+
mem_k = mem_k * self.qk_norm_k_scale
|
812 |
+
|
813 |
+
k = torch.cat((mem_k, k), dim = -2)
|
814 |
+
v = torch.cat((mem_v, v), dim = -2)
|
815 |
+
|
816 |
+
if exists(input_mask):
|
817 |
+
input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
|
818 |
+
|
819 |
+
|
820 |
+
i, j = map(lambda t: t.shape[-2], (q, k))
|
821 |
+
|
822 |
+
# determine masking
|
823 |
+
|
824 |
+
max_neg_value(q)
|
825 |
+
masks = []
|
826 |
+
final_attn_mask = None
|
827 |
+
|
828 |
+
if exists(input_mask):
|
829 |
+
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
830 |
+
masks.append(~input_mask)
|
831 |
+
|
832 |
+
if exists(attn_mask):
|
833 |
+
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
|
834 |
+
if attn_mask.ndim == 2:
|
835 |
+
attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
|
836 |
+
elif attn_mask.ndim == 3:
|
837 |
+
attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
|
838 |
+
masks.append(~attn_mask)
|
839 |
+
|
840 |
+
if exists(self.max_attend_past):
|
841 |
+
range_q = torch.arange(j - i, j, device = device)
|
842 |
+
range_k = torch.arange(j, device = device)
|
843 |
+
dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
|
844 |
+
max_attend_past_mask = dist > self.max_attend_past
|
845 |
+
masks.append(max_attend_past_mask)
|
846 |
+
|
847 |
+
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
848 |
+
top, _ = dots.topk(self.sparse_topk, dim = -1)
|
849 |
+
vk = rearrange(top[..., -1], '... -> ... 1')
|
850 |
+
sparse_topk_mask = dots < vk
|
851 |
+
masks.append(sparse_topk_mask)
|
852 |
+
|
853 |
+
if len(masks) > 0:
|
854 |
+
final_attn_mask = or_reduce(masks)
|
855 |
+
|
856 |
+
# prepare relative positional bias, if needed
|
857 |
+
|
858 |
+
attn_bias = None
|
859 |
+
if exists(rel_pos):
|
860 |
+
attn_bias = rel_pos(i, j)
|
861 |
+
|
862 |
+
# attention is all we need
|
863 |
+
|
864 |
+
out, intermediates = self.attend(
|
865 |
+
q, k, v,
|
866 |
+
mask = final_attn_mask,
|
867 |
+
attn_bias = attn_bias,
|
868 |
+
prev_attn = prev_attn
|
869 |
+
)
|
870 |
+
|
871 |
+
# https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
|
872 |
+
|
873 |
+
if exists(r):
|
874 |
+
out = out * r + out
|
875 |
+
|
876 |
+
# normformer scaling of heads
|
877 |
+
|
878 |
+
if head_scale:
|
879 |
+
out = out * self.head_scale_params
|
880 |
+
|
881 |
+
# merge heads
|
882 |
+
|
883 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
884 |
+
|
885 |
+
# alphafold2 styled gating of the values
|
886 |
+
|
887 |
+
if exists(self.to_v_gate):
|
888 |
+
gates = self.to_v_gate(x)
|
889 |
+
out = out * gates.sigmoid()
|
890 |
+
|
891 |
+
# combine the heads
|
892 |
+
|
893 |
+
out = self.to_out(out)
|
894 |
+
|
895 |
+
if exists(mask):
|
896 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
897 |
+
out = out.masked_fill(~mask, 0.)
|
898 |
+
|
899 |
+
return out, intermediates
|
900 |
+
|
901 |
+
class AttentionLayers(nn.Module):
|
902 |
+
def __init__(
|
903 |
+
self,
|
904 |
+
dim,
|
905 |
+
depth,
|
906 |
+
heads = None,
|
907 |
+
causal = False,
|
908 |
+
cross_attend = False,
|
909 |
+
only_cross = False,
|
910 |
+
use_scalenorm = False,
|
911 |
+
use_rmsnorm = False,
|
912 |
+
alibi_pos_bias = False,
|
913 |
+
alibi_num_heads = None,
|
914 |
+
alibi_learned = False,
|
915 |
+
rel_pos_bias = False,
|
916 |
+
rel_pos_num_buckets = 32,
|
917 |
+
rel_pos_max_distance = 128,
|
918 |
+
dynamic_pos_bias = False,
|
919 |
+
dynamic_pos_bias_log_distance = False,
|
920 |
+
dynamic_pos_bias_mlp_depth = 2,
|
921 |
+
dynamic_pos_bias_norm = False,
|
922 |
+
rotary_pos_emb = False,
|
923 |
+
rotary_emb_dim = None,
|
924 |
+
rotary_xpos = False,
|
925 |
+
rotary_interpolation_factor=1.,
|
926 |
+
rotary_xpos_scale_base = 512,
|
927 |
+
rotary_base_rescale_factor=1.,
|
928 |
+
custom_layers = None,
|
929 |
+
sandwich_coef = None,
|
930 |
+
par_ratio = None,
|
931 |
+
residual_attn = False,
|
932 |
+
cross_residual_attn = False,
|
933 |
+
macaron = False,
|
934 |
+
pre_norm = True,
|
935 |
+
gate_residual = False,
|
936 |
+
scale_residual = False,
|
937 |
+
scale_residual_constant = 1.,
|
938 |
+
deepnorm = False,
|
939 |
+
shift_tokens = 0,
|
940 |
+
sandwich_norm = False,
|
941 |
+
resi_dual = False,
|
942 |
+
zero_init_branch_output = False,
|
943 |
+
layer_dropout = 0.,
|
944 |
+
cross_attn_tokens_dropout = 0.,
|
945 |
+
**kwargs
|
946 |
+
):
|
947 |
+
super().__init__()
|
948 |
+
rotary_pos_emb = rotary_pos_emb or rotary_xpos
|
949 |
+
|
950 |
+
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
951 |
+
attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
|
952 |
+
|
953 |
+
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
954 |
+
|
955 |
+
self.dim = dim
|
956 |
+
self.depth = depth
|
957 |
+
self.layers = nn.ModuleList([])
|
958 |
+
|
959 |
+
self.has_pos_emb = rel_pos_bias or rotary_pos_emb
|
960 |
+
|
961 |
+
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
|
962 |
+
|
963 |
+
assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
|
964 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor=rotary_interpolation_factor, base_rescale_factor=rotary_base_rescale_factor) if rotary_pos_emb else None
|
965 |
+
|
966 |
+
assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
|
967 |
+
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
968 |
+
|
969 |
+
# relative positional bias
|
970 |
+
|
971 |
+
flash_attn = attn_kwargs.get('flash', False)
|
972 |
+
assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
|
973 |
+
|
974 |
+
self.rel_pos = None
|
975 |
+
if rel_pos_bias:
|
976 |
+
assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
|
977 |
+
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
|
978 |
+
elif dynamic_pos_bias:
|
979 |
+
assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
|
980 |
+
self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
|
981 |
+
elif alibi_pos_bias:
|
982 |
+
alibi_num_heads = default(alibi_num_heads, heads)
|
983 |
+
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
|
984 |
+
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
|
985 |
+
self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
|
986 |
+
|
987 |
+
# determine deepnorm and residual scale
|
988 |
+
|
989 |
+
if deepnorm:
|
990 |
+
assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
|
991 |
+
pre_norm = sandwich_norm = resi_dual = False
|
992 |
+
scale_residual = True
|
993 |
+
scale_residual_constant = (2 * depth) ** 0.25
|
994 |
+
|
995 |
+
assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
|
996 |
+
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
|
997 |
+
assert not (not pre_norm and resi_dual), 'resiDualcannot be used when not using prenorm'
|
998 |
+
|
999 |
+
self.pre_norm = pre_norm
|
1000 |
+
self.sandwich_norm = sandwich_norm
|
1001 |
+
self.resi_dual = resi_dual
|
1002 |
+
|
1003 |
+
self.residual_attn = residual_attn
|
1004 |
+
self.cross_residual_attn = cross_residual_attn
|
1005 |
+
self.cross_attend = cross_attend
|
1006 |
+
|
1007 |
+
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
1008 |
+
norm_class = RMSNorm if use_rmsnorm else norm_class
|
1009 |
+
norm_fn = partial(norm_class, dim)
|
1010 |
+
|
1011 |
+
if cross_attend and not only_cross:
|
1012 |
+
default_block = ('a', 'c', 'f')
|
1013 |
+
elif cross_attend and only_cross:
|
1014 |
+
default_block = ('c', 'f')
|
1015 |
+
else:
|
1016 |
+
default_block = ('a', 'f')
|
1017 |
+
|
1018 |
+
if macaron:
|
1019 |
+
default_block = ('f',) + default_block
|
1020 |
+
|
1021 |
+
# zero init
|
1022 |
+
|
1023 |
+
if zero_init_branch_output:
|
1024 |
+
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
|
1025 |
+
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
|
1026 |
+
|
1027 |
+
# calculate layer block order
|
1028 |
+
|
1029 |
+
if exists(custom_layers):
|
1030 |
+
layer_types = custom_layers
|
1031 |
+
elif exists(par_ratio):
|
1032 |
+
par_depth = depth * len(default_block)
|
1033 |
+
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
1034 |
+
default_block = tuple(filter(not_equals('f'), default_block))
|
1035 |
+
par_attn = par_depth // par_ratio
|
1036 |
+
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
1037 |
+
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
1038 |
+
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
1039 |
+
par_block = default_block + ('f',) * (par_width - len(default_block))
|
1040 |
+
par_head = par_block * par_attn
|
1041 |
+
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
1042 |
+
elif exists(sandwich_coef):
|
1043 |
+
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
1044 |
+
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
1045 |
+
else:
|
1046 |
+
layer_types = default_block * depth
|
1047 |
+
|
1048 |
+
self.layer_types = layer_types
|
1049 |
+
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
1050 |
+
|
1051 |
+
# stochastic depth
|
1052 |
+
|
1053 |
+
self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
|
1054 |
+
|
1055 |
+
# structured dropout for cross attending
|
1056 |
+
|
1057 |
+
self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
|
1058 |
+
|
1059 |
+
# calculate token shifting
|
1060 |
+
|
1061 |
+
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
1062 |
+
|
1063 |
+
# iterate and construct layers
|
1064 |
+
|
1065 |
+
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
1066 |
+
is_last_layer = ind == (len(self.layer_types) - 1)
|
1067 |
+
|
1068 |
+
if layer_type == 'a':
|
1069 |
+
layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
|
1070 |
+
elif layer_type == 'c':
|
1071 |
+
layer = Attention(dim, heads = heads, **attn_kwargs)
|
1072 |
+
elif layer_type == 'f':
|
1073 |
+
layer = FeedForward(dim, **ff_kwargs)
|
1074 |
+
layer = layer if not macaron else Scale(0.5, layer)
|
1075 |
+
else:
|
1076 |
+
raise Exception(f'invalid layer type {layer_type}')
|
1077 |
+
|
1078 |
+
if layer_shift_tokens > 0:
|
1079 |
+
shift_range_upper = layer_shift_tokens + 1
|
1080 |
+
shift_range_lower = -layer_shift_tokens if not causal else 0
|
1081 |
+
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
|
1082 |
+
|
1083 |
+
residual_fn = GRUGating if gate_residual else Residual
|
1084 |
+
residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
|
1085 |
+
|
1086 |
+
pre_branch_norm = norm_fn() if pre_norm else None
|
1087 |
+
post_branch_norm = norm_fn() if sandwich_norm else None
|
1088 |
+
post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None
|
1089 |
+
|
1090 |
+
norms = nn.ModuleList([
|
1091 |
+
pre_branch_norm,
|
1092 |
+
post_branch_norm,
|
1093 |
+
post_main_norm
|
1094 |
+
])
|
1095 |
+
|
1096 |
+
self.layers.append(nn.ModuleList([
|
1097 |
+
norms,
|
1098 |
+
layer,
|
1099 |
+
residual
|
1100 |
+
]))
|
1101 |
+
|
1102 |
+
self.layers_length = len(self.layers) # It doesn't work if called after
|
1103 |
+
|
1104 |
+
if deepnorm:
|
1105 |
+
init_gain = (8 * depth) ** -0.25
|
1106 |
+
deepnorm_init(self, init_gain)
|
1107 |
+
|
1108 |
+
def forward(
|
1109 |
+
self,
|
1110 |
+
x,
|
1111 |
+
context = None,
|
1112 |
+
mask = None,
|
1113 |
+
context_mask = None,
|
1114 |
+
attn_mask = None,
|
1115 |
+
self_attn_context_mask = None,
|
1116 |
+
mems = None,
|
1117 |
+
return_hiddens = False
|
1118 |
+
):
|
1119 |
+
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
|
1120 |
+
|
1121 |
+
hiddens = []
|
1122 |
+
intermediates = []
|
1123 |
+
prev_attn = None
|
1124 |
+
prev_cross_attn = None
|
1125 |
+
|
1126 |
+
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
1127 |
+
|
1128 |
+
rotary_pos_emb = None
|
1129 |
+
if exists(self.rotary_pos_emb):
|
1130 |
+
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
|
1131 |
+
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
1132 |
+
|
1133 |
+
outer_residual = x
|
1134 |
+
|
1135 |
+
for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
|
1136 |
+
ind == (self.layers_length - 1)
|
1137 |
+
|
1138 |
+
if self.training and layer_dropout > 0. and random() < layer_dropout:
|
1139 |
+
continue
|
1140 |
+
|
1141 |
+
if layer_type == 'a':
|
1142 |
+
if return_hiddens:
|
1143 |
+
hiddens.append(x)
|
1144 |
+
layer_mem = mems.pop(0) if mems else None
|
1145 |
+
|
1146 |
+
if layer_type == 'c':
|
1147 |
+
if self.training and self.cross_attn_tokens_dropout > 0.:
|
1148 |
+
context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
|
1149 |
+
|
1150 |
+
inner_residual = x
|
1151 |
+
|
1152 |
+
pre_norm, post_branch_norm, post_main_norm = norm
|
1153 |
+
|
1154 |
+
if exists(pre_norm) and not self.resi_dual:
|
1155 |
+
x = pre_norm(x)
|
1156 |
+
|
1157 |
+
if layer_type == 'a':
|
1158 |
+
out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
|
1159 |
+
elif layer_type == 'c':
|
1160 |
+
out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
|
1161 |
+
elif layer_type == 'f':
|
1162 |
+
out = block(x)
|
1163 |
+
|
1164 |
+
if self.resi_dual:
|
1165 |
+
outer_residual = residual_fn(out, outer_residual)
|
1166 |
+
|
1167 |
+
if exists(post_branch_norm):
|
1168 |
+
out = post_branch_norm(out)
|
1169 |
+
|
1170 |
+
x = residual_fn(out, inner_residual)
|
1171 |
+
|
1172 |
+
if layer_type in ('a', 'c') and return_hiddens:
|
1173 |
+
intermediates.append(inter)
|
1174 |
+
|
1175 |
+
if layer_type == 'a' and self.residual_attn:
|
1176 |
+
prev_attn = inter.pre_softmax_attn
|
1177 |
+
elif layer_type == 'c' and self.cross_residual_attn:
|
1178 |
+
prev_cross_attn = inter.pre_softmax_attn
|
1179 |
+
|
1180 |
+
if exists(post_main_norm):
|
1181 |
+
x = post_main_norm(x)
|
1182 |
+
|
1183 |
+
if self.resi_dual:
|
1184 |
+
x = x + pre_norm(outer_residual)
|
1185 |
+
|
1186 |
+
if return_hiddens:
|
1187 |
+
intermediates = LayerIntermediates(
|
1188 |
+
hiddens = hiddens,
|
1189 |
+
attn_intermediates = intermediates
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
return x, intermediates
|
1193 |
+
|
1194 |
+
return x
|
1195 |
+
|
1196 |
+
|
1197 |
+
class Decoder(AttentionLayers):
|
1198 |
+
def __init__(self, **kwargs):
|
1199 |
+
assert 'causal' not in kwargs, 'cannot set causality on decoder'
|
1200 |
+
super().__init__(causal = True, **kwargs)
|
1201 |
+
|
1202 |
+
|
1203 |
+
|
1204 |
+
class Transformer(nn.Module):
|
1205 |
+
def __init__(
|
1206 |
+
self,
|
1207 |
+
*,
|
1208 |
+
num_tokens,
|
1209 |
+
max_seq_len,
|
1210 |
+
attn_layers,
|
1211 |
+
# tokenizer: BaseTokenizer,
|
1212 |
+
embedding_provider: BaseEmbedding,
|
1213 |
+
emb_dim = None,
|
1214 |
+
max_mem_len = 0.,
|
1215 |
+
shift_mem_down = 0,
|
1216 |
+
emb_dropout = 0.,
|
1217 |
+
post_emb_norm = False,
|
1218 |
+
num_memory_tokens = None,
|
1219 |
+
tie_embedding = False,
|
1220 |
+
logits_dim = None,
|
1221 |
+
use_abs_pos_emb = True,
|
1222 |
+
scaled_sinu_pos_emb = False,
|
1223 |
+
l2norm_embed = False,
|
1224 |
+
emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
|
1225 |
+
):
|
1226 |
+
super().__init__()
|
1227 |
+
|
1228 |
+
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
1229 |
+
|
1230 |
+
dim = attn_layers.dim
|
1231 |
+
emb_dim = default(emb_dim, dim)
|
1232 |
+
|
1233 |
+
self.emb_dim = emb_dim
|
1234 |
+
self.num_tokens = num_tokens
|
1235 |
+
self.max_seq_len = max_seq_len
|
1236 |
+
self.max_mem_len = max_mem_len
|
1237 |
+
self.shift_mem_down = shift_mem_down
|
1238 |
+
|
1239 |
+
self.l2norm_embed = l2norm_embed
|
1240 |
+
self.token_emb = TokenEmbedding(emb_dim, num_tokens, embedding_provider, l2norm_embed=l2norm_embed)
|
1241 |
+
|
1242 |
+
if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
|
1243 |
+
self.pos_emb = always(0)
|
1244 |
+
elif scaled_sinu_pos_emb:
|
1245 |
+
self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
|
1246 |
+
else:
|
1247 |
+
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
|
1248 |
+
|
1249 |
+
self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
|
1250 |
+
|
1251 |
+
self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
|
1252 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
1253 |
+
|
1254 |
+
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
1255 |
+
self.attn_layers = attn_layers
|
1256 |
+
self.norm = nn.LayerNorm(dim)
|
1257 |
+
|
1258 |
+
self.init_()
|
1259 |
+
|
1260 |
+
logits_dim = default(logits_dim, num_tokens)
|
1261 |
+
self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
1262 |
+
|
1263 |
+
# memory tokens (like [cls]) from Memory Transformers paper
|
1264 |
+
num_memory_tokens = default(num_memory_tokens, 0)
|
1265 |
+
self.num_memory_tokens = num_memory_tokens
|
1266 |
+
if num_memory_tokens > 0:
|
1267 |
+
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
1268 |
+
|
1269 |
+
def init_(self):
|
1270 |
+
if self.l2norm_embed:
|
1271 |
+
nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
|
1272 |
+
|
1273 |
+
if not isinstance(self.pos_emb, always):
|
1274 |
+
nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
|
1275 |
+
|
1276 |
+
return
|
1277 |
+
|
1278 |
+
nn.init.kaiming_normal_(self.token_emb.emb.weight)
|
1279 |
+
|
1280 |
+
def forward(
|
1281 |
+
self,
|
1282 |
+
x,
|
1283 |
+
return_embeddings = False,
|
1284 |
+
return_logits_and_embeddings = False,
|
1285 |
+
return_intermediates = False,
|
1286 |
+
mask = None,
|
1287 |
+
return_mems = False,
|
1288 |
+
return_attn = False,
|
1289 |
+
mems = None,
|
1290 |
+
pos = None,
|
1291 |
+
prepend_embeds = None,
|
1292 |
+
sum_embeds = None,
|
1293 |
+
**kwargs
|
1294 |
+
):
|
1295 |
+
b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
|
1296 |
+
return_hiddens = return_mems | return_attn
|
1297 |
+
|
1298 |
+
# absolute positional embedding
|
1299 |
+
|
1300 |
+
external_pos_emb = exists(pos) and pos.dtype != torch.long
|
1301 |
+
pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos
|
1302 |
+
x = self.token_emb(x) + pos_emb
|
1303 |
+
|
1304 |
+
# for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
|
1305 |
+
|
1306 |
+
if exists(sum_embeds):
|
1307 |
+
x = x + sum_embeds
|
1308 |
+
|
1309 |
+
# post embedding norm, purportedly leads to greater stabilization
|
1310 |
+
|
1311 |
+
x = self.post_emb_norm(x)
|
1312 |
+
|
1313 |
+
# whether to append embeds, as in PaLI, for image embeddings
|
1314 |
+
|
1315 |
+
if exists(prepend_embeds):
|
1316 |
+
prepend_seq, prepend_dim = prepend_embeds.shape[1:]
|
1317 |
+
|
1318 |
+
assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
|
1319 |
+
|
1320 |
+
x = torch.cat((prepend_embeds, x), dim = -2)
|
1321 |
+
|
1322 |
+
# whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
|
1323 |
+
|
1324 |
+
if emb_frac_gradient < 1:
|
1325 |
+
assert emb_frac_gradient > 0
|
1326 |
+
|
1327 |
+
x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
|
1328 |
+
|
1329 |
+
# embedding dropout
|
1330 |
+
|
1331 |
+
x = self.emb_dropout(x)
|
1332 |
+
|
1333 |
+
x = self.project_emb(x)
|
1334 |
+
|
1335 |
+
if num_mem > 0:
|
1336 |
+
mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
|
1337 |
+
x = torch.cat((mem, x), dim = 1)
|
1338 |
+
|
1339 |
+
# auto-handle masking after appending memory tokens
|
1340 |
+
if exists(mask):
|
1341 |
+
mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True)
|
1342 |
+
|
1343 |
+
if self.shift_mem_down and exists(mems):
|
1344 |
+
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
|
1345 |
+
mems = [*mems_r, *mems_l]
|
1346 |
+
|
1347 |
+
if return_hiddens:
|
1348 |
+
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
|
1349 |
+
else:
|
1350 |
+
x = self.attn_layers(x, mask = mask, mems = mems, **kwargs)
|
1351 |
+
|
1352 |
+
x = self.norm(x)
|
1353 |
+
|
1354 |
+
mem, x = x[:, :num_mem], x[:, num_mem:]
|
1355 |
+
|
1356 |
+
if return_logits_and_embeddings:
|
1357 |
+
out = (self.to_logits(x), x)
|
1358 |
+
elif return_embeddings:
|
1359 |
+
out = x
|
1360 |
+
else:
|
1361 |
+
out = self.to_logits(x)
|
1362 |
+
|
1363 |
+
if return_intermediates:
|
1364 |
+
return out, intermediates
|
1365 |
+
|
1366 |
+
if return_mems:
|
1367 |
+
hiddens = intermediates.hiddens
|
1368 |
+
new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
|
1369 |
+
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
1370 |
+
return out, new_mems
|
1371 |
+
|
1372 |
+
if return_attn:
|
1373 |
+
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
1374 |
+
return out, attn_maps
|
1375 |
+
|
1376 |
+
return out
|
Andromeda/Andromeda/dataset_prep/__init__.py
ADDED
File without changes
|
Andromeda/Andromeda/dataset_prep/books.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from Andromeda.dataset_builder import DatasetBuilder
|
2 |
+
from build_dataset import DatasetBuilder
|
3 |
+
|
4 |
+
builder = DatasetBuilder(
|
5 |
+
dataset_name="the_pile_books3",
|
6 |
+
seq_len=8192,
|
7 |
+
num_cpu=4,
|
8 |
+
hf_account_repo="kye/the_pile_books3_GPTNeox-8192",
|
9 |
+
tokenizer="EleutherAI/gpt-neox-20b",
|
10 |
+
)
|
11 |
+
|
12 |
+
dataset = builder.build_dataset()
|
Andromeda/Andromeda/inference.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
from einops._torch_specific import allow_ops_in_compiled_graph
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
# class AndromedaEval:
|
8 |
+
# def __init__(self, path, seed=42, device=None):
|
9 |
+
# self.path = path
|
10 |
+
# self.seed = seed
|
11 |
+
|
12 |
+
# self.device = device
|
13 |
+
|
14 |
+
# if self.device is None:
|
15 |
+
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
# set_seed(self.seed)
|
18 |
+
|
19 |
+
# #tokenizer
|
20 |
+
# self.tokenizer = AndromedaTokenizer
|
21 |
+
|
22 |
+
# #model
|
23 |
+
# self.model = Andromeda
|
24 |
+
|
25 |
+
# #checkpoint
|
26 |
+
# self.model.load_state_dict(torch.load(self.path))
|
27 |
+
# self.model.eval()
|
28 |
+
|
29 |
+
# #device
|
30 |
+
# self.model = self.model.to(self.device)
|
31 |
+
|
32 |
+
# #metrics
|
33 |
+
# self.metrics = {}
|
34 |
+
# self.reset_metrics()
|
35 |
+
|
36 |
+
# def reset_metrics(self):
|
37 |
+
# self.metrics = {
|
38 |
+
# "generation_steps": None,
|
39 |
+
# "time_forward": [],
|
40 |
+
# "time_forward_average": None,
|
41 |
+
|
42 |
+
# "memory_usages": [],
|
43 |
+
# "memory_usage_average": None,
|
44 |
+
# "time_end_to_end": None,
|
45 |
+
|
46 |
+
# "throughput": None
|
47 |
+
# }
|
48 |
+
|
49 |
+
# def get_num_params(self):
|
50 |
+
# num_params = sum(param.numel() for param in self.model.parameters() if param.requires_grad)
|
51 |
+
|
52 |
+
# return num_params
|
53 |
+
|
54 |
+
# def generate(self, prompt, generation_steps=32):
|
55 |
+
# #make sure all of the metrics reset at every generation
|
56 |
+
# self.reset_metrics()
|
57 |
+
|
58 |
+
# self.metrics["generation_steps"] = generation_steps
|
59 |
+
|
60 |
+
# tokens = self.tokenizer.encode(prompt)
|
61 |
+
# tokens_new = []
|
62 |
+
|
63 |
+
# time_end_to_end = time.time()
|
64 |
+
|
65 |
+
# #generation loop
|
66 |
+
# for _ in range(generation_steps):
|
67 |
+
# tokens_tensor = torch.tensor([tokens], device=self.device)
|
68 |
+
|
69 |
+
# #forward pass
|
70 |
+
# tracemalloc.start()
|
71 |
+
|
72 |
+
# time_forward_0 = time.time()
|
73 |
+
|
74 |
+
# logits = self.model(tokens_tensor, return_loss=False)[:, -1] # no loss takes the output of the last tokens
|
75 |
+
|
76 |
+
# time_forward_1 = time.time()
|
77 |
+
|
78 |
+
# _, memory_usage = tracemalloc.get_traced_memory()
|
79 |
+
# tracemalloc.stop()
|
80 |
+
|
81 |
+
# self.metrics["memory_usages"].append(memory_usage)
|
82 |
+
|
83 |
+
# time_forward = time_forward_1 - time_forward_0
|
84 |
+
# self.metrics["times_forward"].append(time_forward)
|
85 |
+
|
86 |
+
# next_token = torch.armax(logits).item()
|
87 |
+
|
88 |
+
# #save the newly generated token
|
89 |
+
# tokens.append(next_token)
|
90 |
+
# tokens_new.append(next_token)
|
91 |
+
|
92 |
+
# time_end_to_end_1 = time.time()
|
93 |
+
|
94 |
+
# time_end_to_end = time_end_to_end_1 - time_end_to_end_0
|
95 |
+
# self.metrics["time_end_to_end"] = time_end_to_end
|
96 |
+
|
97 |
+
# decoded = self.tokenizer.decode(tokens)
|
98 |
+
|
99 |
+
# self.metrics["time_forward_average"] = np.mean(self.metrics["times_forward"])
|
100 |
+
# self.metrics["memory_usage_average"] = np.mean(self.metrics["memory_usage"])
|
101 |
+
|
102 |
+
# self.metrics['throughput'] = generation_steps / np.sum(self.metrics["times_forward"])
|
103 |
+
|
104 |
+
# return tokens_new, decoded
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
# def main():
|
110 |
+
# prompt = 'My name is'
|
111 |
+
|
112 |
+
# andromeda = EvalAndromeda(path='checkpoints/step_44927_6656/pytorch_model.bin')
|
113 |
+
|
114 |
+
# num_params = Andromeda.get_num_params()
|
115 |
+
# print(f'The model has {num_params} parameters')
|
116 |
+
|
117 |
+
# _, output = Andromeda.generate(prompt)
|
118 |
+
|
119 |
+
# for metric, value in Andromeda.metrics.items():
|
120 |
+
# print(f'{metric}: {value}\n')
|
121 |
+
|
122 |
+
# print('\n')
|
123 |
+
|
124 |
+
# print(output)
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
def main():
|
132 |
+
allow_ops_in_compiled_graph()
|
133 |
+
|
134 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
135 |
+
|
136 |
+
parser = argparse.ArgumentParser(description="Generate text using Andromeda model")
|
137 |
+
parser.add_argument("prompt", type=str, help="Text prompt to generate text")
|
138 |
+
parser.add_argument(
|
139 |
+
"--seq_len", type=int, default=256, help="Sequence length for generated text"
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--temperature", type=float, default=0.8, help="Sampling temperature"
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--filter_thres", type=float, default=0.9, help="Filter threshold for sampling"
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--model",
|
149 |
+
type=str,
|
150 |
+
default="andromeda-e-1",
|
151 |
+
help="Model to use for generation",
|
152 |
+
)
|
153 |
+
|
154 |
+
parser.add_argument(
|
155 |
+
"--dtype",
|
156 |
+
type=str,
|
157 |
+
default="fp32",
|
158 |
+
help="Data type for the model: 'bf16', or 'fp32'",
|
159 |
+
)
|
160 |
+
|
161 |
+
args = parser.parse_args()
|
162 |
+
|
163 |
+
|
164 |
+
dtype = torch.float32
|
165 |
+
if args.dtype == 'bf16':
|
166 |
+
dtype = torch.bfloat16
|
167 |
+
|
168 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
169 |
+
|
170 |
+
#need to submit to torch hub
|
171 |
+
model = torch.hub.load("apacai/andromeda", args.model).to(device).to(dtype)
|
172 |
+
|
173 |
+
opt_model = torch.compile(model, backend="hidet")
|
174 |
+
|
175 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
176 |
+
|
177 |
+
encoded_text = tokenizer(args.prompt, return_tensors="pt")
|
178 |
+
|
179 |
+
output_tensor = opt_model.generate(
|
180 |
+
seq_len=args.seq_len,
|
181 |
+
prompt=encoded_text["input_ids"].to(device),
|
182 |
+
temperature=args.temperature,
|
183 |
+
filter_thres=args.filter_thres,
|
184 |
+
pad_value=0.0,
|
185 |
+
eos_token=tokenizer.eos_token_id,
|
186 |
+
return_seq_without_prompt=False,
|
187 |
+
use_tqdm=True,
|
188 |
+
)
|
189 |
+
|
190 |
+
decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
|
191 |
+
|
192 |
+
return decoded_output
|
193 |
+
|
194 |
+
|
195 |
+
if __name__ == "__main__":
|
196 |
+
generated_text = main()
|
197 |
+
for text in generated_text:
|
198 |
+
print(f"{text}")
|
Andromeda/Andromeda/model.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Module
|
2 |
+
from Andromeda.core.transformer import Transformer, AutoregressiveWrapper, AndromedaEmbedding, Decoder
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
|
5 |
+
class AndromedaTokenizer:
|
6 |
+
def __init__(self):
|
7 |
+
self.tokenizer= AutoTokenizer.from_pretrained(
|
8 |
+
"EleutherAI/gpt-neox-20b",
|
9 |
+
eos_token="<eos>",
|
10 |
+
pad_token="<pad>",
|
11 |
+
extra_ids=0,
|
12 |
+
model_max_length=8192
|
13 |
+
)
|
14 |
+
|
15 |
+
def tokenize_texts(self, texts):
|
16 |
+
return self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
|
17 |
+
|
18 |
+
def decode(self, texts):
|
19 |
+
return self.tokenizer.decode(texts)
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
num_tokens = len(self.tokenizer)
|
23 |
+
return num_tokens
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
class Andromeda(Module):
|
28 |
+
"""
|
29 |
+
Andromeda is a transformer-based model architecture. It initializes with
|
30 |
+
a Transformer and AutoregressiveWrapper with default or user-specified parameters.
|
31 |
+
"""
|
32 |
+
def __init__(self,
|
33 |
+
num_tokens=50432,
|
34 |
+
max_seq_len=8192,
|
35 |
+
dim=2560,
|
36 |
+
depth=32,
|
37 |
+
dim_head=128,
|
38 |
+
heads=24,
|
39 |
+
use_abs_pos_emb=False,
|
40 |
+
alibi_pos_bias=True,
|
41 |
+
alibi_num_heads=12,
|
42 |
+
rotary_xpos=True,
|
43 |
+
attn_flash=True,
|
44 |
+
# shift_tokens=1,
|
45 |
+
attn_one_kv_head=True, # multiquery attention
|
46 |
+
qk_norm=True,
|
47 |
+
attn_qk_norm=True,
|
48 |
+
attn_qk_norm_dim_scale=True,
|
49 |
+
embedding_provider=AndromedaEmbedding()):
|
50 |
+
"""
|
51 |
+
Initialize the model with specified or default parameters.
|
52 |
+
Args:
|
53 |
+
- num_tokens: Number of tokens in the vocabulary
|
54 |
+
- max_seq_len: Maximum sequence length
|
55 |
+
- dim: Dimension of the model
|
56 |
+
- depth: Depth of the model
|
57 |
+
- dim_head: Dimension of the model head
|
58 |
+
- heads: Number of heads
|
59 |
+
- use_abs_pos_emb: Whether to use absolute position embedding
|
60 |
+
- alibi_pos_bias: Alibi position bias
|
61 |
+
- alibi_num_heads: Number of alibi heads
|
62 |
+
- rotary_xpos: Rotary position
|
63 |
+
- attn_flash: Attention flash
|
64 |
+
- deepnorm: Deep normalization
|
65 |
+
- shift_tokens: Number of tokens to shift
|
66 |
+
- attn_one_kv_head: Attention one key/value head
|
67 |
+
- qk_norm: Query-key normalization
|
68 |
+
- attn_qk_norm: Attention query-key normalization
|
69 |
+
- attn_qk_norm_dim_scale: Attention query-key normalization dimension scale
|
70 |
+
- embedding_provider: Embedding provider module
|
71 |
+
"""
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
try:
|
75 |
+
self.Andromeda = Transformer(
|
76 |
+
num_tokens=num_tokens,
|
77 |
+
max_seq_len=max_seq_len,
|
78 |
+
use_abs_pos_emb=use_abs_pos_emb,
|
79 |
+
embedding_provider=embedding_provider,
|
80 |
+
attn_layers=Decoder(
|
81 |
+
dim=dim,
|
82 |
+
depth=depth,
|
83 |
+
dim_head=dim_head,
|
84 |
+
heads=heads,
|
85 |
+
alibi_pos_bias=alibi_pos_bias,
|
86 |
+
alibi_num_heads=alibi_num_heads,
|
87 |
+
rotary_xpos=rotary_xpos,
|
88 |
+
attn_flash=attn_flash,
|
89 |
+
# deepnorm=deepnorm,
|
90 |
+
# shift_tokens=shift_tokens,
|
91 |
+
attn_one_kv_head=attn_one_kv_head,
|
92 |
+
qk_norm=qk_norm,
|
93 |
+
attn_qk_norm=attn_qk_norm,
|
94 |
+
attn_qk_norm_dim_scale=attn_qk_norm_dim_scale
|
95 |
+
)
|
96 |
+
)
|
97 |
+
|
98 |
+
self.decoder = AutoregressiveWrapper(self.Andromeda)
|
99 |
+
|
100 |
+
except Exception as e:
|
101 |
+
print("Failed to initialize Andromeda: ", e)
|
102 |
+
raise
|
103 |
+
|
104 |
+
def forward(self, text_tokens, **kwargs):
|
105 |
+
"""
|
106 |
+
Forward pass through the model. It expects the input text_tokens.
|
107 |
+
Args:
|
108 |
+
- text_tokens: Input tokens
|
109 |
+
- kwargs: Other arguments
|
110 |
+
Returns:
|
111 |
+
- output from the decoder
|
112 |
+
"""
|
113 |
+
try:
|
114 |
+
model_input = self.decoder.forward(text_tokens)[0]
|
115 |
+
return self.decoder(model_input, padded_x=model_input[0])
|
116 |
+
except Exception as e:
|
117 |
+
print("Failed in forward method: ", e)
|
118 |
+
raise
|
Andromeda/Andromeda/old/__init__.py
ADDED
File without changes
|
Andromeda/Andromeda/old/sophia.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
from torch.optim.optimizer import Optimizer
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
|
7 |
+
class SophiaG(Optimizer):
|
8 |
+
def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04,
|
9 |
+
weight_decay=1e-1, *, maximize: bool = False,
|
10 |
+
capturable: bool = False):
|
11 |
+
if not 0.0 <= lr:
|
12 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
13 |
+
if not 0.0 <= betas[0] < 1.0:
|
14 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
15 |
+
if not 0.0 <= betas[1] < 1.0:
|
16 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
17 |
+
if not 0.0 <= rho:
|
18 |
+
raise ValueError("Invalid rho parameter at index 1: {}".format(rho))
|
19 |
+
if not 0.0 <= weight_decay:
|
20 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
21 |
+
defaults = dict(lr=lr, betas=betas, rho=rho,
|
22 |
+
weight_decay=weight_decay,
|
23 |
+
maximize=maximize, capturable=capturable)
|
24 |
+
super(SophiaG, self).__init__(params, defaults)
|
25 |
+
|
26 |
+
def __setstate__(self, state):
|
27 |
+
super().__setstate__(state)
|
28 |
+
for group in self.param_groups:
|
29 |
+
group.setdefault('maximize', False)
|
30 |
+
group.setdefault('capturable', False)
|
31 |
+
state_values = list(self.state.values())
|
32 |
+
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
|
33 |
+
if not step_is_tensor:
|
34 |
+
for s in state_values:
|
35 |
+
s['step'] = torch.tensor(float(s['step']))
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def update_hessian(self):
|
39 |
+
for group in self.param_groups:
|
40 |
+
beta1, beta2 = group['betas']
|
41 |
+
for p in group['params']:
|
42 |
+
if p.grad is None:
|
43 |
+
continue
|
44 |
+
state = self.state[p]
|
45 |
+
|
46 |
+
if len(state) == 0:
|
47 |
+
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
|
48 |
+
if self.defaults['capturable'] else torch.tensor(0.)
|
49 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
50 |
+
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
51 |
+
|
52 |
+
if 'hessian' not in state.keys():
|
53 |
+
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
54 |
+
|
55 |
+
state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
|
56 |
+
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def step(self, closure=None, bs=5120):
|
60 |
+
loss = None
|
61 |
+
if closure is not None:
|
62 |
+
with torch.enable_grad():
|
63 |
+
loss = closure()
|
64 |
+
|
65 |
+
for group in self.param_groups:
|
66 |
+
params_with_grad = []
|
67 |
+
grads = []
|
68 |
+
exp_avgs = []
|
69 |
+
state_steps = []
|
70 |
+
hessian = []
|
71 |
+
beta1, beta2 = group['betas']
|
72 |
+
|
73 |
+
for p in group['params']:
|
74 |
+
if p.grad is None:
|
75 |
+
continue
|
76 |
+
params_with_grad.append(p)
|
77 |
+
|
78 |
+
if p.grad.is_sparse:
|
79 |
+
raise RuntimeError('Hero does not support sparse gradients')
|
80 |
+
grads.append(p.grad)
|
81 |
+
state = self.state[p]
|
82 |
+
# State initialization
|
83 |
+
if len(state) == 0:
|
84 |
+
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
|
85 |
+
if self.defaults['capturable'] else torch.tensor(0.)
|
86 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
87 |
+
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
88 |
+
|
89 |
+
if 'hessian' not in state.keys():
|
90 |
+
state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
91 |
+
|
92 |
+
exp_avgs.append(state['exp_avg'])
|
93 |
+
state_steps.append(state['step'])
|
94 |
+
hessian.append(state['hessian'])
|
95 |
+
|
96 |
+
if self.defaults['capturable']:
|
97 |
+
bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
|
98 |
+
|
99 |
+
sophiag(params_with_grad,
|
100 |
+
grads,
|
101 |
+
exp_avgs,
|
102 |
+
hessian,
|
103 |
+
state_steps,
|
104 |
+
bs=bs,
|
105 |
+
beta1=beta1,
|
106 |
+
beta2=beta2,
|
107 |
+
rho=group['rho'],
|
108 |
+
lr=group['lr'],
|
109 |
+
weight_decay=group['weight_decay'],
|
110 |
+
maximize=group['maximize'],
|
111 |
+
capturable=group['capturable'])
|
112 |
+
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def sophiag(params: List[Tensor],
|
116 |
+
grads: List[Tensor],
|
117 |
+
exp_avgs: List[Tensor],
|
118 |
+
hessian: List[Tensor],
|
119 |
+
state_steps: List[Tensor],
|
120 |
+
capturable: bool = False,
|
121 |
+
*,
|
122 |
+
bs: int,
|
123 |
+
beta1: float,
|
124 |
+
beta2: float,
|
125 |
+
rho: float,
|
126 |
+
lr: float,
|
127 |
+
weight_decay: float,
|
128 |
+
maximize: bool):
|
129 |
+
|
130 |
+
if not all(isinstance(t, torch.Tensor) for t in state_steps):
|
131 |
+
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
|
132 |
+
|
133 |
+
|
134 |
+
func = _single_tensor_sophiag
|
135 |
+
|
136 |
+
func(params,
|
137 |
+
grads,
|
138 |
+
exp_avgs,
|
139 |
+
hessian,
|
140 |
+
state_steps,
|
141 |
+
bs=bs,
|
142 |
+
beta1=beta1,
|
143 |
+
beta2=beta2,
|
144 |
+
rho=rho,
|
145 |
+
lr=lr,
|
146 |
+
weight_decay=weight_decay,
|
147 |
+
maximize=maximize,
|
148 |
+
capturable=capturable)
|
149 |
+
|
150 |
+
def _single_tensor_sophiag(params: List[Tensor],
|
151 |
+
grads: List[Tensor],
|
152 |
+
exp_avgs: List[Tensor],
|
153 |
+
hessian: List[Tensor],
|
154 |
+
state_steps: List[Tensor],
|
155 |
+
*,
|
156 |
+
bs: int,
|
157 |
+
beta1: float,
|
158 |
+
beta2: float,
|
159 |
+
rho: float,
|
160 |
+
lr: float,
|
161 |
+
weight_decay: float,
|
162 |
+
maximize: bool,
|
163 |
+
capturable: bool):
|
164 |
+
|
165 |
+
for i, param in enumerate(params):
|
166 |
+
grad = grads[i] if not maximize else -grads[i]
|
167 |
+
exp_avg = exp_avgs[i]
|
168 |
+
hess = hessian[i]
|
169 |
+
step_t = state_steps[i]
|
170 |
+
|
171 |
+
if capturable:
|
172 |
+
assert param.is_cuda and step_t.is_cuda and bs.is_cuda
|
173 |
+
|
174 |
+
if torch.is_complex(param):
|
175 |
+
grad = torch.view_as_real(grad)
|
176 |
+
exp_avg = torch.view_as_real(exp_avg)
|
177 |
+
hess = torch.view_as_real(hess)
|
178 |
+
param = torch.view_as_real(param)
|
179 |
+
|
180 |
+
# update step
|
181 |
+
step_t += 1
|
182 |
+
|
183 |
+
# Perform stepweight decay
|
184 |
+
param.mul_(1 - lr * weight_decay)
|
185 |
+
|
186 |
+
# Decay the first and second moment running average coefficient
|
187 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
188 |
+
|
189 |
+
if capturable:
|
190 |
+
step_size = lr
|
191 |
+
step_size_neg = step_size.neg()
|
192 |
+
|
193 |
+
ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
|
194 |
+
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
|
195 |
+
else:
|
196 |
+
step_t.item()
|
197 |
+
step_size_neg = - lr
|
198 |
+
|
199 |
+
ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
|
200 |
+
param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
|
Andromeda/Andromeda/old/training.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#quantization + paralleism
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from accelerate.utils import set_seed
|
6 |
+
from datasets import load_dataset
|
7 |
+
from torch.nn import CrossEntropyLoss
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from transformers import default_data_collator, get_linear_schedule_with_warmup
|
10 |
+
from accelerate import Accelerator
|
11 |
+
|
12 |
+
from rich.progress import Progress
|
13 |
+
|
14 |
+
|
15 |
+
from lion_pytorch import Lion
|
16 |
+
# from x_transformers import Transformer, Decoder, AutoregressiveWrapper
|
17 |
+
from optimus_prim import Transformer, Decoder, AutoregressiveWrapper
|
18 |
+
|
19 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
20 |
+
import torch.distributed as dist
|
21 |
+
|
22 |
+
from torch.distributed.fsdp import (
|
23 |
+
FullyShardedDataParallel,
|
24 |
+
CPUOffload,
|
25 |
+
)
|
26 |
+
|
27 |
+
from torch.distributed.fsdp.wrap import (
|
28 |
+
default_auto_wrap_policy,
|
29 |
+
)
|
30 |
+
|
31 |
+
from transformers import AutoTokenizer
|
32 |
+
|
33 |
+
#logging
|
34 |
+
import boto3
|
35 |
+
|
36 |
+
|
37 |
+
#training
|
38 |
+
import wandb
|
39 |
+
|
40 |
+
from torch.utils.tensorboard import SummaryWriter
|
41 |
+
|
42 |
+
class CustomGPTNeoXTokenizer:
|
43 |
+
def __init__(self):
|
44 |
+
self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
45 |
+
|
46 |
+
def tokenize(self, text):
|
47 |
+
return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
48 |
+
|
49 |
+
custom_tokenizer = CustomGPTNeoXTokenizer()
|
50 |
+
|
51 |
+
Andromeda = Transformer(
|
52 |
+
num_tokens=64007,
|
53 |
+
max_seq_len=8192,
|
54 |
+
use_abs_pos_emb = False,
|
55 |
+
tokenizer=custom_tokenizer,
|
56 |
+
attn_layers = Decoder(
|
57 |
+
dim=2048,
|
58 |
+
depth=6,
|
59 |
+
heads=16,
|
60 |
+
alibi_pos_bias=True,
|
61 |
+
alibi_num_heads=8,
|
62 |
+
rotary_xpos=True,
|
63 |
+
attn_flash = True,
|
64 |
+
deepnorm=True,
|
65 |
+
shift_tokens=1,
|
66 |
+
attn_one_kv_head = True,
|
67 |
+
qk_norm=True
|
68 |
+
)
|
69 |
+
)
|
70 |
+
|
71 |
+
Andromeda = AutoregressiveWrapper(Andromeda)
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
AWS_ACCESS_KEY_ID=""
|
76 |
+
AWS_SECRET_ACCESS_KEY="d"
|
77 |
+
|
78 |
+
|
79 |
+
def save_model_to_s3(model, bucket_name, key_prefix, step):
|
80 |
+
s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
|
81 |
+
model_path = f"checkpoint_at_step_{step}.pt"
|
82 |
+
torch.save(model.state_dict(), model_path)
|
83 |
+
s3.upload_file(model_path, bucket_name, f"{key_prefix}/{model_path}")
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
def count_number_of_parameters(model, only_trainable: bool = True) -> int:
|
88 |
+
if only_trainable:
|
89 |
+
num_params: int = sum(p.numel()
|
90 |
+
for p in model.parameters() if p.requires_grad)
|
91 |
+
else:
|
92 |
+
num_params: int = sum(p.numel() for p in model.parameters() if p)
|
93 |
+
return int(num_params)
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
def prep_sample(sample):
|
98 |
+
title = sample["title"]
|
99 |
+
text = sample["text"]
|
100 |
+
return {
|
101 |
+
"title": title,
|
102 |
+
"text": text
|
103 |
+
}
|
104 |
+
|
105 |
+
|
106 |
+
def train(args):
|
107 |
+
|
108 |
+
if args.use_ddp:
|
109 |
+
dist.init_process_group(backend="nccl")
|
110 |
+
|
111 |
+
|
112 |
+
accelerator = Accelerator(
|
113 |
+
mixed_precision="fp16",
|
114 |
+
gradient_accumulation_steps=1,
|
115 |
+
)
|
116 |
+
|
117 |
+
# If passed along, set the training seed now.
|
118 |
+
if args.seed is not None:
|
119 |
+
set_seed(args.seed)
|
120 |
+
|
121 |
+
#v1
|
122 |
+
model = Andromeda()
|
123 |
+
if args.use_ddp:
|
124 |
+
model = DistributedDataParallel(model)
|
125 |
+
else:
|
126 |
+
model = DataParallel(model)
|
127 |
+
|
128 |
+
fsdp_model = FullyShardedDataParallel(
|
129 |
+
model(),
|
130 |
+
fsdp_auto_wrap_policy=default_auto_wrap_policy,
|
131 |
+
cpu_offload=CPUOffload(offload_params=True),
|
132 |
+
)
|
133 |
+
|
134 |
+
fsdp_model = fsdp_model.to(accelerator.device)
|
135 |
+
|
136 |
+
#device count
|
137 |
+
if torch.cuda.device_count() > 1:
|
138 |
+
print(f"Let's use ${torch.cuda.device_count()} GPUS")
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
optimizer = Lion(model.parameters(), lr=args.learning_rate / 3, weight_decay=args.weight_decay * 3)
|
144 |
+
|
145 |
+
lr_scheduler = get_linear_schedule_with_warmup(
|
146 |
+
optimizer=optimizer,
|
147 |
+
num_warmup_steps=args.warmup_steps,
|
148 |
+
num_training_steps=args.max_steps,
|
149 |
+
)
|
150 |
+
|
151 |
+
# tokenizer = KosmosTokenizer()
|
152 |
+
|
153 |
+
#====================> load data #====================> load data #====================> load data
|
154 |
+
|
155 |
+
|
156 |
+
dataset = load_dataset("the_pile_books3")
|
157 |
+
|
158 |
+
# dataset = dataset.map(prep_sample, num_proc=8)
|
159 |
+
dataset = dataset.map(prep_sample, num_proc=8)
|
160 |
+
|
161 |
+
|
162 |
+
#new removed columns
|
163 |
+
remove_columns = ['title']
|
164 |
+
|
165 |
+
|
166 |
+
dataset = dataset.map(Andromeda.decoder.tokenizer, batched=True,
|
167 |
+
batch_size=128, remove_columns=remove_columns)
|
168 |
+
|
169 |
+
train_dataloader = DataLoader(
|
170 |
+
dataset, collate_fn=default_data_collator, batch_size=args.batch_size, pin_memory=True
|
171 |
+
)
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
#====================> load data #====================> load data #====================> load data #====================> load data
|
176 |
+
|
177 |
+
fsdp_model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(fsdp_model, train_dataloader, optimizer,
|
178 |
+
lr_scheduler)
|
179 |
+
fsdp_model.train()
|
180 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
181 |
+
|
182 |
+
accelerator.print(
|
183 |
+
f"Number of parameters: {count_number_of_parameters(model):,}")
|
184 |
+
accelerator.print(
|
185 |
+
f"Number of trainable parameters: {count_number_of_parameters(model, only_trainable=True):,}")
|
186 |
+
|
187 |
+
# Log model and optimizer parameters to wandb
|
188 |
+
accelerator.init_trackers(project_name="Andromeda")
|
189 |
+
|
190 |
+
#wandb
|
191 |
+
wandb.init(project="Andromeda", config=args)
|
192 |
+
|
193 |
+
#init tensorboard writer
|
194 |
+
tb_writer = SummaryWriter()
|
195 |
+
|
196 |
+
|
197 |
+
train_loader = iter(train_dataloader)
|
198 |
+
epoch_loss = 0
|
199 |
+
total_loss = 0
|
200 |
+
start_time = time.time()
|
201 |
+
|
202 |
+
with Progress() as progress:
|
203 |
+
task = progress.add_task("[red]Training...", total=args.max_steps)
|
204 |
+
for step in range(0, args.max_steps):
|
205 |
+
batch_start = time.time()
|
206 |
+
batch = next(train_loader)
|
207 |
+
outputs = fsdp_model(**batch, self_attn_padding_mask=batch["attention_mask"])
|
208 |
+
# Shift so that tokens < n predict n
|
209 |
+
outputs = torch.cat([outputs[:, :1], outputs[:, 67:]], dim=1).contiguous()
|
210 |
+
# shift_logits = outputs[..., :-1, :].contiguous()
|
211 |
+
# shift_labels = batch["labels"][..., 1:].contiguous()
|
212 |
+
# Flatten the tokens
|
213 |
+
loss_fct = CrossEntropyLoss()
|
214 |
+
one_hot_labels = torch.nn.functional.one_hot(batch["labels"][:, 1:], num_classes=32002).float()
|
215 |
+
loss = loss_fct(outputs[:,:-1], one_hot_labels)
|
216 |
+
|
217 |
+
epoch_loss += loss.detach().float()
|
218 |
+
|
219 |
+
accelerator.backward(loss)
|
220 |
+
optimizer.step()
|
221 |
+
optimizer.zero_grad()
|
222 |
+
|
223 |
+
batch_end = time.time()
|
224 |
+
logs = {
|
225 |
+
"loss": loss.item(),
|
226 |
+
"perplexity": torch.exp(loss).item(),
|
227 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
228 |
+
"examples": args.batch_size * (step + 1),
|
229 |
+
"examples_per_second": args.batch_size / (batch_end - batch_start),
|
230 |
+
}
|
231 |
+
if step % args.log_every == args.log_every - 1:
|
232 |
+
#log metrics to wandb
|
233 |
+
wandb.log(logs, step=step)
|
234 |
+
|
235 |
+
#log metrics to tensorboard
|
236 |
+
# Log metrics to TensorBoard
|
237 |
+
tb_writer.add_scalar("loss", logs["loss"], step)
|
238 |
+
tb_writer.add_scalar("perplexity", logs["perplexity"], step)
|
239 |
+
tb_writer.add_scalar("lr", logs["lr"], step)
|
240 |
+
tb_writer.add_scalar("examples", logs["examples"], step)
|
241 |
+
tb_writer.add_scalar("examples_per_second", logs["examples_per_second"], step)
|
242 |
+
|
243 |
+
#accelerator
|
244 |
+
accelerator.log(logs, step=step)
|
245 |
+
progress.update(task, advance=1, description=f"Step Loss: {loss.item():.5f} "
|
246 |
+
f"| Mean Loss: {(total_loss + epoch_loss) / step:.5f} "
|
247 |
+
f"| Mean PPL: {torch.exp((total_loss + epoch_loss) / step):.2f} "
|
248 |
+
f"| Examples: {args.batch_size * (step + 1)} "
|
249 |
+
f"| Examples/s: {args.batch_size / (batch_end - batch_start):.2f} "
|
250 |
+
f"| Elapsed: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
|
251 |
+
|
252 |
+
if step % args.save_every == args.save_every - 1:
|
253 |
+
train_epoch_loss = epoch_loss / args.save_every
|
254 |
+
total_loss += epoch_loss
|
255 |
+
epoch_loss = 0
|
256 |
+
|
257 |
+
accelerator.log({
|
258 |
+
"train_ppl": torch.exp(train_epoch_loss),
|
259 |
+
"train_epoch_loss": train_epoch_loss,
|
260 |
+
}, step=step)
|
261 |
+
|
262 |
+
progress.print(f"Saving checkpoint at step {step}...")
|
263 |
+
accelerator.save_state(
|
264 |
+
f"{args.checkpoint_dir}/checkpoint_at_step_{step}/")
|
265 |
+
|
266 |
+
#save the model weights to s3
|
267 |
+
save_model_to_s3(model, "kosmostraining", "kosmosv1/checkpoints", step)
|
268 |
+
print(f"Saved to s3: {save_model_to_s3} ")
|
269 |
+
|
270 |
+
#finish tensorboard writer
|
271 |
+
tb_writer.close()
|
272 |
+
|
273 |
+
#finish wnabd run
|
274 |
+
wandb.finish()
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == "__main__":
|
278 |
+
import argparse
|
279 |
+
|
280 |
+
parser = argparse.ArgumentParser()
|
281 |
+
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
282 |
+
parser.add_argument("--learning_rate", type=float, default=1e-5)
|
283 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
284 |
+
parser.add_argument("--warmup_steps", type=int, default=0)
|
285 |
+
parser.add_argument("--max_steps", type=int, default=100000)
|
286 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
287 |
+
parser.add_argument("--log_every", type=int, default=1)
|
288 |
+
parser.add_argument("--save_every", type=int, default=100)
|
289 |
+
parser.add_argument("--seed", type=int, default=None)
|
290 |
+
parser.add_argument("--use_ddp", action="store_true", help="Use DistributedDataParallel")
|
291 |
+
|
292 |
+
args = parser.parse_args()
|
293 |
+
|
294 |
+
train(args)
|
Andromeda/Andromeda/old/training_1.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import multiprocessing
|
3 |
+
import os
|
4 |
+
|
5 |
+
from datetime import timedelta
|
6 |
+
from functools import partial
|
7 |
+
from itertools import chain
|
8 |
+
|
9 |
+
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from accelerate.utils import InitProcessGroupKwargs
|
12 |
+
|
13 |
+
from datasets import concatenate_datasets, load_dataset
|
14 |
+
|
15 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
16 |
+
CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
|
17 |
+
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
from transformers import (AutoTokenizer, default_data_collator,
|
23 |
+
get_cosine_schedule_with_warmup,
|
24 |
+
get_linear_schedule_with_warmup, set_seed)
|
25 |
+
|
26 |
+
|
27 |
+
# from stable_adamw import StableAdamWUnfused
|
28 |
+
# sd
|
29 |
+
|
30 |
+
from optimus_prime import Transformer, Decoder, AutoregressiveWrapper
|
31 |
+
from optimus_prime import AndromedaEmbedding
|
32 |
+
|
33 |
+
from lion_pytorch import Lion
|
34 |
+
|
35 |
+
|
36 |
+
# constants
|
37 |
+
|
38 |
+
class CFG:
|
39 |
+
BATCH_SIZE: int = 3 # 3
|
40 |
+
GRADIENT_ACCUMULATE_EVERY: int = 1
|
41 |
+
SEED: int = 42
|
42 |
+
LEARNING_RATE: float = 1e-4
|
43 |
+
WEIGHT_DECAY: float = 1e-2
|
44 |
+
SEQ_LEN: int = 8192 # 8192
|
45 |
+
NUM_CPU: int = multiprocessing.cpu_count()
|
46 |
+
USE_PRETOKENIZED: bool = True
|
47 |
+
USE_ACTIVATION_CHECKPOINTING: bool = True
|
48 |
+
RESUME_FROM_CHECKPOINT: str = None
|
49 |
+
CHECKPOINTING_STEPS: int = 1000
|
50 |
+
OUTPUT_DIR: str = "output"
|
51 |
+
ENTITY_NAME: str = "wanb" # Put your wandb username here
|
52 |
+
|
53 |
+
# deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY)
|
54 |
+
|
55 |
+
# helpers
|
56 |
+
|
57 |
+
def print_num_params(model, accelerator: Accelerator):
|
58 |
+
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
59 |
+
accelerator.print(f"Number of parameters in model: {n_params}")
|
60 |
+
|
61 |
+
def fsdp_activation_checkpointing(
|
62 |
+
model, accelerator: Accelerator, offload_to_cpu=False
|
63 |
+
):
|
64 |
+
|
65 |
+
accelerator.print("Using FSDP activation checkpointing")
|
66 |
+
|
67 |
+
# check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)
|
68 |
+
|
69 |
+
non_reentrant_wrapper = partial(
|
70 |
+
checkpoint_wrapper,
|
71 |
+
offload_to_cpu=offload_to_cpu,
|
72 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
73 |
+
)
|
74 |
+
|
75 |
+
apply_activation_checkpointing(
|
76 |
+
model, checkpoint_wrapper_fn=non_reentrant_wrapper)
|
77 |
+
|
78 |
+
|
79 |
+
def get_lr_scheduler_with_warmup(
|
80 |
+
optimizer, scheduler_type, num_warmup_steps, max_train_steps, grad_accumulate_every
|
81 |
+
):
|
82 |
+
NUM_WARMUP_STEPS = num_warmup_steps
|
83 |
+
GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
|
84 |
+
|
85 |
+
if scheduler_type == "linear":
|
86 |
+
return get_linear_schedule_with_warmup(
|
87 |
+
optimizer=optimizer,
|
88 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
89 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
|
90 |
+
)
|
91 |
+
elif scheduler_type == "cosine":
|
92 |
+
return get_cosine_schedule_with_warmup(
|
93 |
+
optimizer=optimizer,
|
94 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
95 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
raise ValueError(
|
99 |
+
"Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
|
100 |
+
scheduler_type
|
101 |
+
)
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def build_dataloaders():
|
106 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
107 |
+
dataset = load_dataset("openwebtext", split="train")
|
108 |
+
|
109 |
+
tokenized_dataset = dataset.map(
|
110 |
+
lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
|
111 |
+
batched=True,
|
112 |
+
num_proc=CFG.NUM_CPU,
|
113 |
+
remove_columns=["text"],
|
114 |
+
)
|
115 |
+
|
116 |
+
block_size = CFG.SEQ_LEN
|
117 |
+
|
118 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
119 |
+
def group_texts(examples):
|
120 |
+
# Concatenate all texts.
|
121 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
122 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
123 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
124 |
+
# customize this part to your needs.
|
125 |
+
if total_length >= block_size:
|
126 |
+
total_length = (total_length // block_size) * block_size
|
127 |
+
# Split by chunks of max_len.
|
128 |
+
result = {
|
129 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
130 |
+
for k, t in concatenated_examples.items()
|
131 |
+
}
|
132 |
+
return result
|
133 |
+
|
134 |
+
train_dataset = tokenized_dataset.map(
|
135 |
+
group_texts, batched=True, num_proc=CFG.NUM_CPU,
|
136 |
+
)
|
137 |
+
|
138 |
+
return train_dataset
|
139 |
+
|
140 |
+
# main
|
141 |
+
|
142 |
+
def TrainAndromeda():
|
143 |
+
# accelerator
|
144 |
+
|
145 |
+
timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
|
146 |
+
|
147 |
+
accelerator = Accelerator(
|
148 |
+
gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
|
149 |
+
mixed_precision="fp16",
|
150 |
+
log_with="wandb",
|
151 |
+
kwargs_handlers=[timeout],
|
152 |
+
deepspeed_plugin=deepspeed_plugin
|
153 |
+
)
|
154 |
+
|
155 |
+
accelerator.init_trackers(
|
156 |
+
project_name="andromeda",
|
157 |
+
config={
|
158 |
+
"batch_size": CFG.BATCH_SIZE,
|
159 |
+
"gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
|
160 |
+
"learning_rate": CFG.LEARNING_RATE,
|
161 |
+
"seq_len": CFG.SEQ_LEN,
|
162 |
+
},
|
163 |
+
init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}}
|
164 |
+
)
|
165 |
+
|
166 |
+
accelerator.print(f"Total GPUS: {accelerator.num_processes}")
|
167 |
+
|
168 |
+
# set seed
|
169 |
+
|
170 |
+
set_seed(CFG.SEED)
|
171 |
+
|
172 |
+
# Create the tokenizer
|
173 |
+
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
175 |
+
|
176 |
+
# instantiate andromeda
|
177 |
+
|
178 |
+
model = Transformer(
|
179 |
+
num_tokens=64007,
|
180 |
+
max_seq_len=8192,
|
181 |
+
use_abs_pos_emb=False,
|
182 |
+
tokenizer=tokenizer, # !
|
183 |
+
embedding_provider=AndromedaEmbedding(),
|
184 |
+
attn_layers = Decoder(
|
185 |
+
dim=128, # 2048
|
186 |
+
depth=8, # 16
|
187 |
+
dim_head=128,
|
188 |
+
heads=8,
|
189 |
+
alibi_pos_bias=True,
|
190 |
+
alibi_num_heads=4,
|
191 |
+
rotary_xpos=True,
|
192 |
+
attn_flash = True,
|
193 |
+
deepnorm=True,
|
194 |
+
shift_tokens=1,
|
195 |
+
attn_one_kv_head = True,
|
196 |
+
qk_norm=True,
|
197 |
+
attn_qk_norm=True,
|
198 |
+
attn_qk_norm_dim_scale=True # set this to True, in addition to `attn_qk_norm = True`
|
199 |
+
)
|
200 |
+
).to(accelerator.device)
|
201 |
+
|
202 |
+
model = AutoregressiveWrapper(model).to(accelerator.device)
|
203 |
+
|
204 |
+
optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2, use_triton=True)
|
205 |
+
|
206 |
+
print_num_params(model, accelerator)
|
207 |
+
|
208 |
+
if CFG.USE_ACTIVATION_CHECKPOINTING:
|
209 |
+
fsdp_activation_checkpointing(model, accelerator)
|
210 |
+
|
211 |
+
# dataloaders
|
212 |
+
|
213 |
+
if CFG.USE_PRETOKENIZED:
|
214 |
+
d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train")
|
215 |
+
d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
|
216 |
+
d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
|
217 |
+
d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
|
218 |
+
d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
|
219 |
+
|
220 |
+
train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
|
221 |
+
else:
|
222 |
+
train_dataset = build_dataloaders()
|
223 |
+
|
224 |
+
train_loader = DataLoader(
|
225 |
+
train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
|
226 |
+
)
|
227 |
+
|
228 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
229 |
+
accelerator.print(f"Max train steps: {max_train_steps}")
|
230 |
+
|
231 |
+
# lr scheduler
|
232 |
+
# We cant decide on an actual number
|
233 |
+
|
234 |
+
NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
|
235 |
+
accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
|
236 |
+
|
237 |
+
lr_scheduler = get_lr_scheduler_with_warmup(
|
238 |
+
optimizer=optim,
|
239 |
+
scheduler_type="cosine",
|
240 |
+
num_warmup_steps=NUM_WARMUP_STEPS,
|
241 |
+
max_train_steps=max_train_steps,
|
242 |
+
grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY
|
243 |
+
)
|
244 |
+
|
245 |
+
# prepare
|
246 |
+
|
247 |
+
model, optim, train_loader, lr_scheduler = accelerator.prepare(
|
248 |
+
model, optim, train_loader, lr_scheduler
|
249 |
+
)
|
250 |
+
|
251 |
+
# checkpoint scheduler
|
252 |
+
|
253 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
254 |
+
|
255 |
+
# I do not know why Huggingface recommends recalculation of max_train_steps
|
256 |
+
|
257 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
258 |
+
accelerator.print(f"Max train steps recalculated: {max_train_steps}")
|
259 |
+
|
260 |
+
# Total batch size for logging
|
261 |
+
|
262 |
+
total_batch_size = (
|
263 |
+
CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
|
264 |
+
)
|
265 |
+
accelerator.print(f"Total batch size: {total_batch_size}")
|
266 |
+
|
267 |
+
# resume training
|
268 |
+
|
269 |
+
progress_bar = tqdm(
|
270 |
+
range(max_train_steps), disable=not accelerator.is_local_main_process
|
271 |
+
)
|
272 |
+
completed_steps = 0
|
273 |
+
|
274 |
+
if CFG.RESUME_FROM_CHECKPOINT:
|
275 |
+
if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
|
276 |
+
accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
|
277 |
+
accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
|
278 |
+
path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
|
279 |
+
|
280 |
+
training_difference = os.path.splitext(path)[0]
|
281 |
+
|
282 |
+
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
283 |
+
resume_step = (
|
284 |
+
int(training_difference.replace("step_", ""))
|
285 |
+
* CFG.GRADIENT_ACCUMULATE_EVERY
|
286 |
+
)
|
287 |
+
|
288 |
+
if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
|
289 |
+
train_loader = accelerator.skip_first_batches(train_loader, resume_step)
|
290 |
+
completed_steps += resume_step
|
291 |
+
progress_bar.update(resume_step)
|
292 |
+
|
293 |
+
# training
|
294 |
+
|
295 |
+
model.train()
|
296 |
+
|
297 |
+
for step, batch in enumerate(train_loader):
|
298 |
+
with accelerator.accumulate(model):
|
299 |
+
inputs = batch["input_ids"].to(accelerator.device)
|
300 |
+
_, loss = model(inputs, return_loss=True)
|
301 |
+
accelerator.backward(loss)
|
302 |
+
|
303 |
+
# print(loss.item())
|
304 |
+
|
305 |
+
accelerator.log({"loss": loss.item()}, step=step)
|
306 |
+
|
307 |
+
if accelerator.sync_gradients:
|
308 |
+
accelerator.clip_grad_norm_(model.parameters(), 0.5)
|
309 |
+
|
310 |
+
optim.step()
|
311 |
+
lr_scheduler.step()
|
312 |
+
optim.zero_grad()
|
313 |
+
|
314 |
+
if accelerator.sync_gradients:
|
315 |
+
progress_bar.update(1)
|
316 |
+
completed_steps += 1
|
317 |
+
|
318 |
+
if isinstance(CFG.CHECKPOINTING_STEPS, int):
|
319 |
+
if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
|
320 |
+
output_dir = f"step_{completed_steps }"
|
321 |
+
if CFG.OUTPUT_DIR is not None:
|
322 |
+
output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
|
323 |
+
accelerator.save_state(output_dir)
|
324 |
+
|
325 |
+
if completed_steps >= max_train_steps:
|
326 |
+
break
|
327 |
+
|
328 |
+
# end training
|
329 |
+
|
330 |
+
accelerator.print("Training Finished")
|
331 |
+
accelerator.end_training()
|
332 |
+
|
333 |
+
# save final model
|
334 |
+
|
335 |
+
# accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
|
336 |
+
if CFG.OUTPUT_DIR is not None:
|
337 |
+
base_path = f'{CFG.OUTPUT_DIR}/final'
|
338 |
+
|
339 |
+
if not os.path.exists(base_path):
|
340 |
+
os.makedirs(base_path)
|
341 |
+
|
342 |
+
accelerator.wait_for_everyone()
|
343 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
344 |
+
with accelerator.main_process_first():
|
345 |
+
accelerator.save(
|
346 |
+
unwrapped_model.state_dict(), os.path.join(base_path, 'final_model.pt')
|
347 |
+
)
|
348 |
+
|
349 |
+
if __name__ == "__main__":
|
350 |
+
TrainAndromeda()
|
Andromeda/Andromeda/old/training_sophia.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import multiprocessing
|
3 |
+
import os
|
4 |
+
|
5 |
+
from datetime import timedelta
|
6 |
+
from functools import partial
|
7 |
+
from itertools import chain
|
8 |
+
|
9 |
+
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from accelerate.utils import InitProcessGroupKwargs
|
12 |
+
|
13 |
+
from datasets import concatenate_datasets, load_dataset
|
14 |
+
|
15 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
16 |
+
CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
|
17 |
+
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
from transformers import (AutoTokenizer, default_data_collator,
|
23 |
+
get_cosine_schedule_with_warmup,
|
24 |
+
get_linear_schedule_with_warmup, set_seed)
|
25 |
+
|
26 |
+
|
27 |
+
# from stable_adamw import StableAdamWUnfused
|
28 |
+
# sd
|
29 |
+
|
30 |
+
from optimus_prime import Transformer, Decoder, AutoregressiveWrapper
|
31 |
+
from optimus_prime import AndromedaEmbedding
|
32 |
+
|
33 |
+
from sophia import SophiaG
|
34 |
+
|
35 |
+
# constants
|
36 |
+
|
37 |
+
class CFG:
|
38 |
+
BATCH_SIZE: int = 3 # 3
|
39 |
+
GRADIENT_ACCUMULATE_EVERY: int = 1
|
40 |
+
SEED: int = 42
|
41 |
+
LEARNING_RATE: float = 1e-4
|
42 |
+
WEIGHT_DECAY: float = 1e-2
|
43 |
+
SEQ_LEN: int = 8192 # 8192
|
44 |
+
NUM_CPU: int = multiprocessing.cpu_count()
|
45 |
+
USE_PRETOKENIZED: bool = True
|
46 |
+
USE_ACTIVATION_CHECKPOINTING: bool = True
|
47 |
+
RESUME_FROM_CHECKPOINT: str = None
|
48 |
+
CHECKPOINTING_STEPS: int = 1000
|
49 |
+
OUTPUT_DIR: str = "output"
|
50 |
+
ENTITY_NAME: str = "nicolo" # Put your wandb username here
|
51 |
+
|
52 |
+
# helpers
|
53 |
+
|
54 |
+
def print_num_params(model, accelerator: Accelerator):
|
55 |
+
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
56 |
+
accelerator.print(f"Number of parameters in model: {n_params}")
|
57 |
+
|
58 |
+
def fsdp_activation_checkpointing(
|
59 |
+
model, accelerator: Accelerator, offload_to_cpu=False
|
60 |
+
):
|
61 |
+
|
62 |
+
accelerator.print("Using FSDP activation checkpointing")
|
63 |
+
|
64 |
+
# check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)
|
65 |
+
|
66 |
+
non_reentrant_wrapper = partial(
|
67 |
+
checkpoint_wrapper,
|
68 |
+
offload_to_cpu=offload_to_cpu,
|
69 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
70 |
+
)
|
71 |
+
|
72 |
+
apply_activation_checkpointing(
|
73 |
+
model, checkpoint_wrapper_fn=non_reentrant_wrapper)
|
74 |
+
|
75 |
+
|
76 |
+
def get_lr_scheduler_with_warmup(
|
77 |
+
optimizer, scheduler_type, num_warmup_steps, max_train_steps, grad_accumulate_every
|
78 |
+
):
|
79 |
+
NUM_WARMUP_STEPS = num_warmup_steps
|
80 |
+
GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
|
81 |
+
|
82 |
+
if scheduler_type == "linear":
|
83 |
+
return get_linear_schedule_with_warmup(
|
84 |
+
optimizer=optimizer,
|
85 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
86 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
|
87 |
+
)
|
88 |
+
elif scheduler_type == "cosine":
|
89 |
+
return get_cosine_schedule_with_warmup(
|
90 |
+
optimizer=optimizer,
|
91 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
92 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
raise ValueError(
|
96 |
+
"Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
|
97 |
+
scheduler_type
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def build_dataloaders():
|
103 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
104 |
+
|
105 |
+
content_column = 'text'
|
106 |
+
|
107 |
+
dataset = load_dataset("sentiment140", split="train")
|
108 |
+
dataset = dataset.remove_columns([col for col in dataset.column_names if col != content_column])
|
109 |
+
|
110 |
+
tokenized_dataset = dataset.map(
|
111 |
+
lambda example: tokenizer([t + tokenizer.eos_token for t in example[content_column]]),
|
112 |
+
batched=True,
|
113 |
+
num_proc=CFG.NUM_CPU,
|
114 |
+
remove_columns=[content_column]
|
115 |
+
)
|
116 |
+
|
117 |
+
block_size = CFG.SEQ_LEN
|
118 |
+
|
119 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
120 |
+
def group_texts(examples):
|
121 |
+
# Concatenate all texts.
|
122 |
+
concatenated_examples = {}
|
123 |
+
|
124 |
+
for k in examples.keys():
|
125 |
+
concatenated_examples[k] = list(chain(*examples[k]))
|
126 |
+
|
127 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
128 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
129 |
+
# customize this part to your needs.
|
130 |
+
if total_length >= block_size:
|
131 |
+
total_length = (total_length // block_size) * block_size
|
132 |
+
# Split by chunks of max_len.
|
133 |
+
result = {
|
134 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
135 |
+
for k, t in concatenated_examples.items()
|
136 |
+
}
|
137 |
+
|
138 |
+
return result
|
139 |
+
|
140 |
+
train_dataset = tokenized_dataset.map(
|
141 |
+
group_texts, batched=True, num_proc=CFG.NUM_CPU
|
142 |
+
)
|
143 |
+
|
144 |
+
return train_dataset
|
145 |
+
|
146 |
+
# main
|
147 |
+
|
148 |
+
def TrainAndromeda():
|
149 |
+
# accelerator
|
150 |
+
|
151 |
+
timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
|
152 |
+
|
153 |
+
accelerator = Accelerator(
|
154 |
+
gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
|
155 |
+
mixed_precision="fp16", # Switch to bf16
|
156 |
+
log_with="wandb",
|
157 |
+
kwargs_handlers=[timeout]
|
158 |
+
)
|
159 |
+
|
160 |
+
accelerator.init_trackers(
|
161 |
+
project_name="andromeda",
|
162 |
+
config={
|
163 |
+
"batch_size": CFG.BATCH_SIZE,
|
164 |
+
"gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
|
165 |
+
"learning_rate": CFG.LEARNING_RATE,
|
166 |
+
"seq_len": CFG.SEQ_LEN,
|
167 |
+
},
|
168 |
+
init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}}
|
169 |
+
)
|
170 |
+
|
171 |
+
accelerator.print(f"Total GPUS: {accelerator.num_processes}")
|
172 |
+
|
173 |
+
# set seed
|
174 |
+
|
175 |
+
set_seed(CFG.SEED)
|
176 |
+
|
177 |
+
# Create the tokenizer
|
178 |
+
|
179 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
180 |
+
|
181 |
+
# instantiate andromeda
|
182 |
+
|
183 |
+
model = Transformer(
|
184 |
+
num_tokens=64007,
|
185 |
+
max_seq_len=8192,
|
186 |
+
use_abs_pos_emb=False,
|
187 |
+
tokenizer=tokenizer, # !
|
188 |
+
embedding_provider=AndromedaEmbedding(),
|
189 |
+
attn_layers = Decoder(
|
190 |
+
dim=128, # 2048
|
191 |
+
depth=8, # 16
|
192 |
+
dim_head=128,
|
193 |
+
heads=8,
|
194 |
+
alibi_pos_bias=True,
|
195 |
+
alibi_num_heads=4,
|
196 |
+
rotary_xpos=True,
|
197 |
+
attn_flash = True,
|
198 |
+
# deepnorm=True,
|
199 |
+
shift_tokens=1,
|
200 |
+
attn_one_kv_head = True,
|
201 |
+
qk_norm=True,
|
202 |
+
attn_qk_norm=True,
|
203 |
+
attn_qk_norm_dim_scale=True # set this to True, in addition to `attn_qk_norm = True`
|
204 |
+
)
|
205 |
+
).to(accelerator.device)
|
206 |
+
|
207 |
+
model = AutoregressiveWrapper(model).to(accelerator.device)
|
208 |
+
|
209 |
+
#optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
|
210 |
+
optim = SophiaG(model.parameters(), lr=1e-5, weight_decay=1e-1)
|
211 |
+
|
212 |
+
print_num_params(model, accelerator)
|
213 |
+
|
214 |
+
if CFG.USE_ACTIVATION_CHECKPOINTING:
|
215 |
+
fsdp_activation_checkpointing(model, accelerator)
|
216 |
+
|
217 |
+
# dataloaders
|
218 |
+
|
219 |
+
if CFG.USE_PRETOKENIZED:
|
220 |
+
d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train")
|
221 |
+
d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
|
222 |
+
d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
|
223 |
+
d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
|
224 |
+
d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
|
225 |
+
|
226 |
+
train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
|
227 |
+
else:
|
228 |
+
train_dataset = build_dataloaders()
|
229 |
+
|
230 |
+
train_loader = DataLoader(
|
231 |
+
train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
|
232 |
+
)
|
233 |
+
|
234 |
+
# optimizer
|
235 |
+
|
236 |
+
# optim = decoupled_optimizer(
|
237 |
+
# model,
|
238 |
+
# learning_rate=CFG.LEARNING_RATE,
|
239 |
+
# weight_decay=CFG.WEIGHT_DECAY,
|
240 |
+
# beta_1=0.9,
|
241 |
+
# beta_2=0.95,
|
242 |
+
# use_adamw=False,
|
243 |
+
# )
|
244 |
+
|
245 |
+
# Determine number of training steps
|
246 |
+
|
247 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
248 |
+
accelerator.print(f"Max train steps: {max_train_steps}")
|
249 |
+
|
250 |
+
# lr scheduler
|
251 |
+
# We cant decide on an actual number
|
252 |
+
|
253 |
+
NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
|
254 |
+
accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
|
255 |
+
|
256 |
+
lr_scheduler = get_lr_scheduler_with_warmup(
|
257 |
+
optimizer=optim,
|
258 |
+
scheduler_type="cosine",
|
259 |
+
num_warmup_steps=NUM_WARMUP_STEPS,
|
260 |
+
max_train_steps=max_train_steps,
|
261 |
+
grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY
|
262 |
+
)
|
263 |
+
|
264 |
+
# prepare
|
265 |
+
|
266 |
+
model, optim, train_loader, lr_scheduler = accelerator.prepare(
|
267 |
+
model, optim, train_loader, lr_scheduler
|
268 |
+
)
|
269 |
+
|
270 |
+
# checkpoint scheduler
|
271 |
+
|
272 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
273 |
+
|
274 |
+
# I do not know why Huggingface recommends recalculation of max_train_steps
|
275 |
+
|
276 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
277 |
+
accelerator.print(f"Max train steps recalculated: {max_train_steps}")
|
278 |
+
|
279 |
+
# Total batch size for logging
|
280 |
+
|
281 |
+
total_batch_size = (
|
282 |
+
CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
|
283 |
+
)
|
284 |
+
accelerator.print(f"Total batch size: {total_batch_size}")
|
285 |
+
|
286 |
+
# resume training
|
287 |
+
|
288 |
+
progress_bar = tqdm(
|
289 |
+
range(max_train_steps), disable=not accelerator.is_local_main_process
|
290 |
+
)
|
291 |
+
completed_steps = 0
|
292 |
+
|
293 |
+
if CFG.RESUME_FROM_CHECKPOINT:
|
294 |
+
if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
|
295 |
+
accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
|
296 |
+
accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
|
297 |
+
path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
|
298 |
+
|
299 |
+
training_difference = os.path.splitext(path)[0]
|
300 |
+
|
301 |
+
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
302 |
+
resume_step = (
|
303 |
+
int(training_difference.replace("step_", ""))
|
304 |
+
* CFG.GRADIENT_ACCUMULATE_EVERY
|
305 |
+
)
|
306 |
+
|
307 |
+
if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
|
308 |
+
train_loader = accelerator.skip_first_batches(train_loader, resume_step)
|
309 |
+
completed_steps += resume_step
|
310 |
+
progress_bar.update(resume_step)
|
311 |
+
|
312 |
+
# training
|
313 |
+
|
314 |
+
model.train()
|
315 |
+
|
316 |
+
for step, batch in enumerate(train_loader):
|
317 |
+
with accelerator.accumulate(model):
|
318 |
+
inputs = batch["input_ids"].to(accelerator.device)
|
319 |
+
_, loss = model(inputs, return_loss=True)
|
320 |
+
accelerator.backward(loss)
|
321 |
+
|
322 |
+
# print(loss.item())
|
323 |
+
|
324 |
+
accelerator.log({"loss": loss.item()}, step=step)
|
325 |
+
|
326 |
+
if accelerator.sync_gradients:
|
327 |
+
accelerator.clip_grad_norm_(model.parameters(), 0.5)
|
328 |
+
|
329 |
+
optim.step()
|
330 |
+
lr_scheduler.step()
|
331 |
+
optim.zero_grad()
|
332 |
+
|
333 |
+
if accelerator.sync_gradients:
|
334 |
+
progress_bar.update(1)
|
335 |
+
completed_steps += 1
|
336 |
+
|
337 |
+
if isinstance(CFG.CHECKPOINTING_STEPS, int):
|
338 |
+
if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
|
339 |
+
output_dir = f"step_{completed_steps }"
|
340 |
+
if CFG.OUTPUT_DIR is not None:
|
341 |
+
output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
|
342 |
+
accelerator.save_state(output_dir)
|
343 |
+
|
344 |
+
if completed_steps >= max_train_steps:
|
345 |
+
break
|
346 |
+
|
347 |
+
# end training
|
348 |
+
|
349 |
+
accelerator.print("Training Finished")
|
350 |
+
accelerator.end_training()
|
351 |
+
|
352 |
+
# save final model
|
353 |
+
|
354 |
+
# accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
|
355 |
+
if CFG.OUTPUT_DIR is not None:
|
356 |
+
base_path = f'{CFG.OUTPUT_DIR}/final'
|
357 |
+
|
358 |
+
if not os.path.exists(base_path):
|
359 |
+
os.makedirs(base_path)
|
360 |
+
|
361 |
+
accelerator.wait_for_everyone()
|
362 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
363 |
+
with accelerator.main_process_first():
|
364 |
+
accelerator.save(
|
365 |
+
unwrapped_model.state_dict(), os.path.join(base_path, 'final_model.pt')
|
366 |
+
)
|
367 |
+
|
368 |
+
if __name__ == "__main__":
|
369 |
+
TrainAndromeda()
|
Andromeda/Andromeda/train.py
ADDED
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import multiprocessing
|
3 |
+
import os
|
4 |
+
from datetime import timedelta
|
5 |
+
from functools import partial
|
6 |
+
from itertools import chain
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
########### SETUP CONFIG
|
11 |
+
import torch.distributed as dist
|
12 |
+
from accelerate import Accelerator
|
13 |
+
from accelerate.logging import get_logger
|
14 |
+
from accelerate.state import AcceleratorState
|
15 |
+
from accelerate.utils import DummyOptim, InitProcessGroupKwargs
|
16 |
+
from datasets import load_dataset
|
17 |
+
from lion_pytorch import Lion
|
18 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
19 |
+
CheckpointImpl,
|
20 |
+
apply_activation_checkpointing,
|
21 |
+
checkpoint_wrapper,
|
22 |
+
)
|
23 |
+
|
24 |
+
# import bitsandbytes as bnb
|
25 |
+
from torch.distributed.fsdp import (
|
26 |
+
BackwardPrefetch,
|
27 |
+
FullyShardedDataParallel,
|
28 |
+
MixedPrecision,
|
29 |
+
ShardingStrategy,
|
30 |
+
)
|
31 |
+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
32 |
+
from torch.nn import LayerNorm
|
33 |
+
from torch.optim import AdamW
|
34 |
+
from torch.utils.data import DataLoader
|
35 |
+
from tqdm import tqdm
|
36 |
+
from transformers import (
|
37 |
+
AutoTokenizer,
|
38 |
+
default_data_collator,
|
39 |
+
get_cosine_schedule_with_warmup,
|
40 |
+
get_linear_schedule_with_warmup,
|
41 |
+
set_seed,
|
42 |
+
)
|
43 |
+
|
44 |
+
# from Andromeda.model import Andromeda
|
45 |
+
from Andromeda.configs import Andromeda1Billion
|
46 |
+
from Andromeda.core.transformer import Transformer
|
47 |
+
from Andromeda.utils.stable_adamw import StableAdamWUnfused
|
48 |
+
|
49 |
+
# state = AcceleratorState()
|
50 |
+
|
51 |
+
|
52 |
+
logger = get_logger(__name__, log_level="INFO")
|
53 |
+
|
54 |
+
class CFG:
|
55 |
+
BATCH_SIZE = 1
|
56 |
+
GRADIENT_ACCUMULATE_EVERY: int = 1
|
57 |
+
SEED: int = 42
|
58 |
+
LEARNING_RATE: float = 1e-4 #3e-4 # 1e-4 for lion
|
59 |
+
WEIGHT_DECAY: float = 0.1
|
60 |
+
SEQ_LEN: int = 8192
|
61 |
+
NUM_CPU: int = multiprocessing.cpu_count()
|
62 |
+
USE_DEEPSPEED: bool = True
|
63 |
+
USE_FSDP: bool = True
|
64 |
+
USE_PRETOKENIZED: bool = True
|
65 |
+
USE_ACTIVATION_CHECKPOINTING: bool = True
|
66 |
+
RESUME_FROM_CHECKPOINT: str = False
|
67 |
+
CHECKPOINTING_STEPS: int = 1000
|
68 |
+
OUTPUT_DIR: str = 'checkpoints/' # Folder
|
69 |
+
ENTITY_NAME: str = "Andromeda"
|
70 |
+
LOGGING_STEPS: int = 100
|
71 |
+
|
72 |
+
|
73 |
+
# helpers
|
74 |
+
|
75 |
+
|
76 |
+
def print_num_params(model, accelerator: Accelerator):
|
77 |
+
# n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
78 |
+
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
79 |
+
accelerator.print(f"Number of parameters in model: {n_params}")
|
80 |
+
|
81 |
+
|
82 |
+
# activation checkpointing
|
83 |
+
|
84 |
+
|
85 |
+
def activation_checkpointing(
|
86 |
+
model: torch.nn.Module,
|
87 |
+
offload_to_cpu: bool = False,
|
88 |
+
accelerator: Accelerator = None,
|
89 |
+
):
|
90 |
+
"""
|
91 |
+
Apply activation checkpointing to a model.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
model (Module): The model to which to apply activation checkpointing.
|
95 |
+
offload_to_cpu (bool, optional): Whether to offload the activations to CPU. Defaults to False.
|
96 |
+
accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
|
97 |
+
"""
|
98 |
+
if accelerator is not None:
|
99 |
+
accelerator.print("Using activation checkpointing")
|
100 |
+
def check_fn(submodule):
|
101 |
+
return isinstance(submodule, Transformer)
|
102 |
+
non_reentrant_wrapper = partial(
|
103 |
+
checkpoint_wrapper,
|
104 |
+
offload_to_cpu=offload_to_cpu,
|
105 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
106 |
+
)
|
107 |
+
apply_activation_checkpointing(
|
108 |
+
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
# FSDP
|
113 |
+
|
114 |
+
|
115 |
+
def fsdp(
|
116 |
+
model: torch.nn.Module,
|
117 |
+
auto_wrap: bool = False,
|
118 |
+
mp: str = "fp32",
|
119 |
+
shard_strat: str = "NO_SHARD",
|
120 |
+
):
|
121 |
+
"""
|
122 |
+
This function wraps a given PyTorch model with the FullyShardedDataParallel (FSDP) wrapper to enable efficient data parallelism and model sharding.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP.
|
126 |
+
auto_wrap (bool, optional): If True, it enables automatic wrapping of the model's layers according to the transformer_auto_wrap_policy. Default is False.
|
127 |
+
mp (str, optional): The mixed precision mode to be used. Can be 'bf16' for BFloat16, 'fp16' for Float16 or 'fp32' for Float32 precision. Default is 'fp32'.
|
128 |
+
shard_strat (str, optional): The sharding strategy to be used. Can be 'SHARD_GRAD' for sharding at gradient computation, 'FULL_SHARD' for full model sharding or 'NO_SHARD' for no sharding. Default is 'NO_SHARD'.
|
129 |
+
|
130 |
+
Raises:
|
131 |
+
ValueError: If the provided mp (mixed precision mode) is not 'bf16', 'fp16' or 'fp32'.
|
132 |
+
ValueError: If the provided shard_strat (sharding strategy) is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
torch.nn.Module: The input model wrapped with FSDP.
|
136 |
+
"""
|
137 |
+
if auto_wrap:
|
138 |
+
Andromeda_auto_wrap_policy = partial(
|
139 |
+
transformer_auto_wrap_policy,
|
140 |
+
transformer_layer_cls={
|
141 |
+
Transformer,
|
142 |
+
},
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
Andromeda_auto_wrap_policy = None
|
146 |
+
|
147 |
+
if mp == "bf16":
|
148 |
+
mp_fsdp = MixedPrecision(
|
149 |
+
param_dtype=torch.bfloat16,
|
150 |
+
# Gradient communication precision.
|
151 |
+
reduce_dtype=torch.bfloat16,
|
152 |
+
# Buffer precision.
|
153 |
+
buffer_dtype=torch.bfloat16,
|
154 |
+
)
|
155 |
+
elif mp == "fp16":
|
156 |
+
mp_fsdp = MixedPrecision(
|
157 |
+
param_dtype=torch.float16,
|
158 |
+
# Gradient communication precision.
|
159 |
+
reduce_dtype=torch.float16,
|
160 |
+
# Buffer precision.
|
161 |
+
buffer_dtype=torch.float16,
|
162 |
+
)
|
163 |
+
elif mp == "fp32":
|
164 |
+
mp_fsdp = MixedPrecision(
|
165 |
+
param_dtype=torch.float32,
|
166 |
+
# Gradient communication precision.
|
167 |
+
reduce_dtype=torch.float32,
|
168 |
+
# Buffer precision.
|
169 |
+
buffer_dtype=torch.float32,
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
raise ValueError(
|
173 |
+
"Invalid scheduler_type. Expected 'bf16', 'fp16' or 'fp32', got: {}".format(
|
174 |
+
mp
|
175 |
+
)
|
176 |
+
)
|
177 |
+
|
178 |
+
if shard_strat == "SHARD_GRAD":
|
179 |
+
sharding_strat_fsdp = ShardingStrategy.SHARD_GRAD_OP
|
180 |
+
elif shard_strat == "FULL_SHARD":
|
181 |
+
sharding_strat_fsdp = ShardingStrategy.FULL_SHARD
|
182 |
+
elif shard_strat == "NO_SHARD":
|
183 |
+
sharding_strat_fsdp = ShardingStrategy.NO_SHARD
|
184 |
+
else:
|
185 |
+
raise ValueError(
|
186 |
+
"Invalid scheduler_type. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
|
187 |
+
shard_strat
|
188 |
+
)
|
189 |
+
)
|
190 |
+
|
191 |
+
model = FullyShardedDataParallel(
|
192 |
+
model,
|
193 |
+
auto_wrap_policy=Andromeda_auto_wrap_policy,
|
194 |
+
mixed_precision=mp_fsdp,
|
195 |
+
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
|
196 |
+
sharding_strategy=sharding_strat_fsdp,
|
197 |
+
forward_prefetch=True,
|
198 |
+
use_orig_params=True,
|
199 |
+
)
|
200 |
+
|
201 |
+
return model
|
202 |
+
|
203 |
+
|
204 |
+
# learning rate scheduler
|
205 |
+
|
206 |
+
|
207 |
+
def get_lr_scheduler_with_warmup(
|
208 |
+
optimizer: torch.optim.Optimizer,
|
209 |
+
scheduler_type: str,
|
210 |
+
num_warmup_steps: int,
|
211 |
+
max_train_steps: int,
|
212 |
+
grad_accumulate_every: int = 1,
|
213 |
+
accelerator: Accelerator = None,
|
214 |
+
):
|
215 |
+
"""
|
216 |
+
Get a learning rate scheduler with warmup.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
optimizer (Optimizer): The optimizer for which to create the learning rate scheduler.
|
220 |
+
scheduler_type (str): The type of learning rate scheduler to create, either "linear" or "cosine".
|
221 |
+
num_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
|
222 |
+
max_train_steps (int): The maximum number of training steps.
|
223 |
+
grad_accumulate_every (int, optional): The gradient accumulation factor. Defaults to 1.
|
224 |
+
accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
The learning rate scheduler with warmup.
|
228 |
+
|
229 |
+
Raises:
|
230 |
+
ValueError: If scheduler_type is not "linear" or "cosine".
|
231 |
+
"""
|
232 |
+
NUM_WARMUP_STEPS = num_warmup_steps
|
233 |
+
GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
|
234 |
+
if accelerator is not None:
|
235 |
+
accelerator.print(f"Using {scheduler_type} lr scheduler")
|
236 |
+
if scheduler_type == "linear":
|
237 |
+
return get_linear_schedule_with_warmup(
|
238 |
+
optimizer=optimizer,
|
239 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
240 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
|
241 |
+
)
|
242 |
+
elif scheduler_type == "cosine":
|
243 |
+
return get_cosine_schedule_with_warmup(
|
244 |
+
optimizer=optimizer,
|
245 |
+
num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
|
246 |
+
num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
raise ValueError(
|
250 |
+
"Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
|
251 |
+
scheduler_type
|
252 |
+
)
|
253 |
+
)
|
254 |
+
|
255 |
+
|
256 |
+
# optimizers
|
257 |
+
|
258 |
+
|
259 |
+
def decoupled_optimizer(
|
260 |
+
model: torch.nn.Module,
|
261 |
+
learning_rate: float,
|
262 |
+
weight_decay: float,
|
263 |
+
beta_1: float,
|
264 |
+
beta_2: float,
|
265 |
+
optimizer_type: str,
|
266 |
+
use_fsdp: bool = True,
|
267 |
+
accelerator: Accelerator = None,
|
268 |
+
):
|
269 |
+
"""
|
270 |
+
Decouples the optimizer from the training process.
|
271 |
+
|
272 |
+
This function sets up the optimizer for the model by creating two groups of parameters:
|
273 |
+
one for weight decay and one without weight decay. Then, it initializes the optimizer
|
274 |
+
with these two groups of parameters.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
model (Module): The model whose parameters are optimized.
|
278 |
+
learning_rate (float): The learning rate for the optimizer.
|
279 |
+
weight_decay (float): The weight decay for the optimizer.
|
280 |
+
beta_1 (float): The exponential decay rate for the 1st moment estimates.
|
281 |
+
beta_2 (float): The exponential decay rate for the 2nd moment estimates.
|
282 |
+
optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'.
|
283 |
+
use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True.
|
284 |
+
accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
Optimizer: The initialized optimizer.
|
288 |
+
|
289 |
+
Raises:
|
290 |
+
ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'.
|
291 |
+
"""
|
292 |
+
accelerator.print(f"Using {optimizer_type} optimizer")
|
293 |
+
# Create an empty dictionary called param_dict to store the model's named parameters.
|
294 |
+
param_dict = {}
|
295 |
+
# Iterate over the model's named parameters and populate the param_dict with key-value pairs.
|
296 |
+
for param_name, param in model.named_parameters():
|
297 |
+
param_dict[param_name] = param
|
298 |
+
|
299 |
+
# Separate the model's named modules into two groups: decay and no_decay.
|
300 |
+
|
301 |
+
# Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay.
|
302 |
+
no_decay = []
|
303 |
+
|
304 |
+
if use_fsdp:
|
305 |
+
exclude_module = "_fsdp_wrapped_module.token_emb"
|
306 |
+
else:
|
307 |
+
exclude_module = "token_emb"
|
308 |
+
|
309 |
+
# Iterate through the named modules of the model.
|
310 |
+
for module_name, module in model.named_modules():
|
311 |
+
# Check if the current module is an instance of any of the desired types (LayerNorm or torch.nn.Embedding).
|
312 |
+
for ndim in [LayerNorm, torch.nn.Embedding]:
|
313 |
+
if isinstance(module, ndim):
|
314 |
+
# If torch.nn.Embedding, append its name with a ".weight" suffix to the no_decay list.
|
315 |
+
if module_name == exclude_module:
|
316 |
+
no_decay.append(f"{module_name}.weight")
|
317 |
+
else:
|
318 |
+
# If the module is an instance of LayerNorm
|
319 |
+
no_decay.append(f"{module_name}.gamma")
|
320 |
+
# Exit the inner loop since the desired module has been found.
|
321 |
+
break
|
322 |
+
|
323 |
+
# Create an empty list to store the names of the Linear layer weights with weight decay.
|
324 |
+
decay = []
|
325 |
+
|
326 |
+
# Iterate through the named modules of the model.
|
327 |
+
for module_name, module in model.named_modules():
|
328 |
+
# Check if the current module is an instance of the desired type (torch.nn.Linear).
|
329 |
+
for ndim in [torch.nn.Linear]:
|
330 |
+
if isinstance(module, ndim):
|
331 |
+
# If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list.
|
332 |
+
decay.append(f"{module_name}.weight")
|
333 |
+
# Exit the inner loop since the desired module has been found.
|
334 |
+
break
|
335 |
+
|
336 |
+
# Create two separate lists of model parameters: decay_param and no_decay_param.
|
337 |
+
# The decay_param list contains the parameters that should have weight decay applied.
|
338 |
+
# The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter.
|
339 |
+
|
340 |
+
# Create an empty list called decay_param to store the parameters with weight decay.
|
341 |
+
decay_param = []
|
342 |
+
|
343 |
+
if use_fsdp:
|
344 |
+
exclude_param = "_fsdp_wrapped_module.to_logits.weight"
|
345 |
+
else:
|
346 |
+
exclude_param = "to_logits.weight"
|
347 |
+
|
348 |
+
# Iterate over the decay list, which contains the names of the parameters with weight decay.
|
349 |
+
for param in decay:
|
350 |
+
# Check if the current parameter is not 'to_logits.weight'.
|
351 |
+
# Append the corresponding parameter from param_dict to the decay_param list.
|
352 |
+
|
353 |
+
if param != exclude_param:
|
354 |
+
decay_param.append(param_dict[param])
|
355 |
+
|
356 |
+
# Create an empty list called no_decay_param to store the parameters without weight decay.
|
357 |
+
no_decay_param = []
|
358 |
+
|
359 |
+
# Iterate over the no_decay list, which contains the names of the parameters without weight decay.
|
360 |
+
for param in no_decay:
|
361 |
+
try:
|
362 |
+
|
363 |
+
# Append the corresponding parameter from param_dict to the no_decay_param list.
|
364 |
+
no_decay_param.append(param_dict[param])
|
365 |
+
except KeyError:
|
366 |
+
# print(f"Parameter {param_name} does not exist in the model")
|
367 |
+
pass
|
368 |
+
|
369 |
+
# Create a list called grouped_params that contains two dictionaries.
|
370 |
+
# The first dictionary has the decay_param list and the corresponding weight_decay value.
|
371 |
+
# The second dictionary has the no_decay_param list and a weight_decay value of 0.0.
|
372 |
+
grouped_params = [
|
373 |
+
{"params": decay_param, "weight_decay": weight_decay},
|
374 |
+
{"params": no_decay_param, "weight_decay": 0.0},
|
375 |
+
]
|
376 |
+
|
377 |
+
# Create a variable called optimizer that stores an instance of the optimizer.
|
378 |
+
if optimizer_type == "lion":
|
379 |
+
optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
|
380 |
+
elif optimizer_type == "adamw":
|
381 |
+
optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
|
382 |
+
elif optimizer_type == "deepspeed":
|
383 |
+
optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
|
384 |
+
elif optimizer_type == "stable_adamw":
|
385 |
+
optimizer = StableAdamWUnfused(
|
386 |
+
grouped_params, lr=learning_rate, betas=(beta_1, beta_2),
|
387 |
+
)
|
388 |
+
# elif optimizer_type=="Adam8bit":
|
389 |
+
# optimizer = bnb.optim.Adam8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
|
390 |
+
# elif optimizer_type=="Lion8Bit":
|
391 |
+
# optimizer = bnb.optim.Lion8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
|
392 |
+
else:
|
393 |
+
raise ValueError(
|
394 |
+
"Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
|
395 |
+
optimizer_type
|
396 |
+
)
|
397 |
+
)
|
398 |
+
|
399 |
+
# Return the optimizer.
|
400 |
+
return optimizer
|
401 |
+
|
402 |
+
|
403 |
+
# dataloaders
|
404 |
+
|
405 |
+
|
406 |
+
def build_dataloaders():
|
407 |
+
"""
|
408 |
+
Build data loaders for training.
|
409 |
+
|
410 |
+
This function performs the following steps:
|
411 |
+
1. Load the tokenizer from the pretrained "EleutherAI/gpt-neox-20b" model.
|
412 |
+
2. Load the "openwebtext" dataset.
|
413 |
+
3. Tokenize the dataset, adding the end-of-sentence token to each text.
|
414 |
+
4. Process the tokenized dataset into chunks of a specified block size.
|
415 |
+
|
416 |
+
Returns:
|
417 |
+
Dataset: The processed dataset ready for training.
|
418 |
+
"""
|
419 |
+
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
420 |
+
dataset = load_dataset("openwebtext", split="train")
|
421 |
+
|
422 |
+
tokenized_dataset = dataset.map(
|
423 |
+
lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
|
424 |
+
batched=True,
|
425 |
+
num_proc=CFG.NUM_CPU,
|
426 |
+
remove_columns=["text"],
|
427 |
+
)
|
428 |
+
|
429 |
+
block_size = CFG.SEQ_LEN
|
430 |
+
|
431 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
432 |
+
def group_texts(examples):
|
433 |
+
# Concatenate all texts.
|
434 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
435 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
436 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
437 |
+
# customize this part to your needs.
|
438 |
+
if total_length >= block_size:
|
439 |
+
total_length = (total_length // block_size) * block_size
|
440 |
+
# Split by chunks of max_len.
|
441 |
+
result = {
|
442 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
443 |
+
for k, t in concatenated_examples.items()
|
444 |
+
}
|
445 |
+
return result
|
446 |
+
|
447 |
+
train_dataset = tokenized_dataset.map(
|
448 |
+
group_texts, batched=True, num_proc=CFG.NUM_CPU,
|
449 |
+
)
|
450 |
+
|
451 |
+
return train_dataset
|
452 |
+
|
453 |
+
#switch to falconwebdataset
|
454 |
+
def build_pre_tokenized():
|
455 |
+
d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train[:10]")
|
456 |
+
# d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
|
457 |
+
# d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
|
458 |
+
# d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
|
459 |
+
# d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
|
460 |
+
# train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
|
461 |
+
return d0
|
462 |
+
|
463 |
+
|
464 |
+
|
465 |
+
def Train():
|
466 |
+
# accelerator
|
467 |
+
|
468 |
+
timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
|
469 |
+
|
470 |
+
accelerator = Accelerator(
|
471 |
+
gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
|
472 |
+
mixed_precision="fp16",
|
473 |
+
log_with="wandb",
|
474 |
+
kwargs_handlers=[timeout],
|
475 |
+
)
|
476 |
+
|
477 |
+
state = AcceleratorState()
|
478 |
+
|
479 |
+
state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE #??????
|
480 |
+
|
481 |
+
accelerator.init_trackers(
|
482 |
+
project_name="Andromeda",
|
483 |
+
config={
|
484 |
+
"batch_size": CFG.BATCH_SIZE,
|
485 |
+
"gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
|
486 |
+
"learning_rate": CFG.LEARNING_RATE,
|
487 |
+
"seq_len": CFG.SEQ_LEN,
|
488 |
+
},
|
489 |
+
# init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}},
|
490 |
+
)
|
491 |
+
|
492 |
+
accelerator.print(f"Total GPUS: {accelerator.num_processes}")
|
493 |
+
|
494 |
+
# set seed
|
495 |
+
|
496 |
+
set_seed(CFG.SEED)
|
497 |
+
|
498 |
+
# model = Andromeda(
|
499 |
+
# num_tokens=50432,
|
500 |
+
# max_seq_len=8192,
|
501 |
+
# dim=3072,
|
502 |
+
# depth=24,
|
503 |
+
# dim_head=128,
|
504 |
+
# heads=12,
|
505 |
+
# use_abs_pos_emb=False,
|
506 |
+
# alibi_pos_bias=True,
|
507 |
+
# alibi_num_heads=6,
|
508 |
+
# rotary_xpos=True,
|
509 |
+
# attn_flash=True,
|
510 |
+
# shift_tokens=1,
|
511 |
+
# attn_one_kv_head=True,
|
512 |
+
# qk_norm=True,
|
513 |
+
# attn_qk_norm=True,
|
514 |
+
# attn_qk_norm_dim_scale=True,
|
515 |
+
# embedding_provider=AndromedaEmbedding()
|
516 |
+
# )
|
517 |
+
model = Andromeda1Billion()
|
518 |
+
|
519 |
+
print_num_params(model, accelerator)
|
520 |
+
|
521 |
+
if CFG.USE_FSDP:
|
522 |
+
model = fsdp(
|
523 |
+
model,
|
524 |
+
mp="fp16",
|
525 |
+
shard_strat="SHARD_GRAD"
|
526 |
+
)
|
527 |
+
|
528 |
+
if CFG.USE_ACTIVATION_CHECKPOINTING:
|
529 |
+
activation_checkpointing(model, accelerator)
|
530 |
+
|
531 |
+
model = accelerator.prepare(model)
|
532 |
+
|
533 |
+
# dataloaders
|
534 |
+
|
535 |
+
if CFG.USE_PRETOKENIZED:
|
536 |
+
train_dataset = build_pre_tokenized()
|
537 |
+
else:
|
538 |
+
train_dataset = build_dataloaders()
|
539 |
+
|
540 |
+
train_loader = DataLoader(
|
541 |
+
train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
|
542 |
+
)
|
543 |
+
|
544 |
+
|
545 |
+
# optimizer
|
546 |
+
optim = decoupled_optimizer(
|
547 |
+
model=model,
|
548 |
+
learning_rate=CFG.LEARNING_RATE,
|
549 |
+
weight_decay=CFG.WEIGHT_DECAY,
|
550 |
+
beta_1=0.90,
|
551 |
+
beta_2=0.95,
|
552 |
+
optimizer_type='lion',
|
553 |
+
use_fsdp=True,
|
554 |
+
accelerator=accelerator
|
555 |
+
)
|
556 |
+
|
557 |
+
# Determine number of training steps
|
558 |
+
|
559 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
560 |
+
accelerator.print(f"Max train steps: {max_train_steps}")
|
561 |
+
|
562 |
+
# lr scheduler
|
563 |
+
|
564 |
+
NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
|
565 |
+
accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
|
566 |
+
|
567 |
+
# if False: # if CFG.USE_DEEPSPEED:
|
568 |
+
# lr_scheduler = DummyScheduler(
|
569 |
+
# optim,
|
570 |
+
# total_num_steps=max_train_steps * accelerator.num_processes,
|
571 |
+
# warmup_num_steps=NUM_WARMUP_STEPS
|
572 |
+
# )
|
573 |
+
# else:
|
574 |
+
lr_scheduler = get_lr_scheduler_with_warmup(
|
575 |
+
optimizer=optim,
|
576 |
+
scheduler_type="cosine",
|
577 |
+
num_warmup_steps=NUM_WARMUP_STEPS,
|
578 |
+
max_train_steps=max_train_steps,
|
579 |
+
grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
|
580 |
+
)
|
581 |
+
|
582 |
+
# prepare
|
583 |
+
|
584 |
+
optim, train_loader, lr_scheduler = accelerator.prepare(
|
585 |
+
optim, train_loader, lr_scheduler
|
586 |
+
)
|
587 |
+
|
588 |
+
# checkpoint scheduler
|
589 |
+
|
590 |
+
accelerator.register_for_checkpointing(lr_scheduler)
|
591 |
+
|
592 |
+
# I do not know why Huggingface recommends recalculation of max_train_steps
|
593 |
+
|
594 |
+
max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
|
595 |
+
accelerator.print(f"Max train steps recalculated: {max_train_steps}")
|
596 |
+
|
597 |
+
# Total batch size for logging
|
598 |
+
|
599 |
+
total_batch_size = (
|
600 |
+
CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
|
601 |
+
)
|
602 |
+
accelerator.print(f"Total batch size: {total_batch_size}")
|
603 |
+
|
604 |
+
# resume training
|
605 |
+
|
606 |
+
progress_bar = tqdm(
|
607 |
+
range(max_train_steps), disable=not accelerator.is_local_main_process
|
608 |
+
)
|
609 |
+
completed_steps = 0
|
610 |
+
|
611 |
+
if CFG.RESUME_FROM_CHECKPOINT:
|
612 |
+
if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
|
613 |
+
accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
|
614 |
+
accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
|
615 |
+
path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
|
616 |
+
training_difference = os.path.splitext(path)[0]
|
617 |
+
|
618 |
+
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
619 |
+
resume_step = (
|
620 |
+
int(training_difference.replace("step_", ""))
|
621 |
+
* CFG.GRADIENT_ACCUMULATE_EVERY
|
622 |
+
)
|
623 |
+
|
624 |
+
if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
|
625 |
+
train_loader = accelerator.skip_first_batches(train_loader, resume_step)
|
626 |
+
completed_steps += resume_step
|
627 |
+
progress_bar.update(resume_step)
|
628 |
+
|
629 |
+
# training
|
630 |
+
|
631 |
+
model.train()
|
632 |
+
for step, batch in enumerate(train_loader):
|
633 |
+
with accelerator.accumulate(model):
|
634 |
+
inputs = batch["input_ids"].to(accelerator.device)
|
635 |
+
loss = model(inputs, return_loss=True)
|
636 |
+
accelerator.backward(loss)
|
637 |
+
|
638 |
+
accelerator.log({"loss": loss.item()}, step=step)
|
639 |
+
|
640 |
+
if accelerator.sync_gradients:
|
641 |
+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
642 |
+
|
643 |
+
optim.step()
|
644 |
+
lr_scheduler.step()
|
645 |
+
optim.zero_grad()
|
646 |
+
|
647 |
+
if accelerator.sync_gradients:
|
648 |
+
progress_bar.update(1)
|
649 |
+
completed_steps += 1
|
650 |
+
|
651 |
+
if isinstance(CFG.CHECKPOINTING_STEPS, int):
|
652 |
+
if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
|
653 |
+
output_dir = f"step_{completed_steps }"
|
654 |
+
if CFG.OUTPUT_DIR is not None:
|
655 |
+
output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
|
656 |
+
accelerator.save_state(output_dir)
|
657 |
+
|
658 |
+
if completed_steps >= max_train_steps:
|
659 |
+
break
|
660 |
+
|
661 |
+
#logging every CFG.LOGGING STEPS
|
662 |
+
if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0:
|
663 |
+
logger.info(
|
664 |
+
f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
|
665 |
+
)
|
666 |
+
|
667 |
+
# end training
|
668 |
+
|
669 |
+
# accelerator.print(f"Training Finished")
|
670 |
+
accelerator.end_training()
|
671 |
+
|
672 |
+
# save final model
|
673 |
+
|
674 |
+
# accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
|
675 |
+
if CFG.OUTPUT_DIR is not None:
|
676 |
+
accelerator.wait_for_everyone()
|
677 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
678 |
+
with accelerator.main_process_first():
|
679 |
+
accelerator.save(
|
680 |
+
unwrapped_model.state_dict(), f"{CFG.OUTPUT_DIR}/final/final_model.pt"
|
681 |
+
)
|
682 |
+
|
683 |
+
|
684 |
+
def train():
|
685 |
+
os.environ['MASTER_ADDR'] #'localhost'
|
686 |
+
os.environ['MASTER_PORT'] #= '9994'
|
687 |
+
|
688 |
+
# # [CRITICAL] Pay attention to this when scaling to multiple GPUs and clusters
|
689 |
+
|
690 |
+
# # Pay attention to this, use "accelerate config"
|
691 |
+
|
692 |
+
os.environ['RANK'] #= str(0) # Number of nodes (servers)
|
693 |
+
os.environ['WORLD_SIZE'] # = str(torch.cuda.device_count())
|
694 |
+
|
695 |
+
dist.init_process_group(backend='nccl') #init_method="env://")
|
696 |
+
|
697 |
+
Train()
|
698 |
+
|
699 |
+
if __name__ == '__main__':
|
700 |
+
train()
|
Andromeda/Andromeda/utils/__init__.py
ADDED
File without changes
|
Andromeda/Andromeda/utils/decoupled_optimizer.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from palm_rlhf_pytorch.palm import LayerNorm
|
3 |
+
from torch.nn import LayerNorm
|
4 |
+
from torch.optim import AdamW
|
5 |
+
|
6 |
+
# from palm.utils import print_main
|
7 |
+
from Andromeda.utils.helpers import print_main
|
8 |
+
from Andromeda.utils.stable_adamw import StableAdamWUnfused
|
9 |
+
|
10 |
+
# optimizers
|
11 |
+
|
12 |
+
|
13 |
+
def decoupled_optimizer(
|
14 |
+
model: torch.nn.Module,
|
15 |
+
learning_rate: float,
|
16 |
+
weight_decay: float = 0.1,
|
17 |
+
beta_1: float = 0.90,
|
18 |
+
beta_2: float = 0.95,
|
19 |
+
optimizer_type: str = "adamw",
|
20 |
+
use_fsdp: bool = True,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Decouples the optimizer from the training process.
|
24 |
+
|
25 |
+
This function sets up the optimizer for the model by creating two groups of parameters:
|
26 |
+
one for weight decay and one without weight decay. Then, it initializes the optimizer
|
27 |
+
with these two groups of parameters.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
model (Module): The model whose parameters are optimized.
|
31 |
+
learning_rate (float): The learning rate for the optimizer.
|
32 |
+
weight_decay (float): The weight decay for the optimizer.
|
33 |
+
beta_1 (float): The exponential decay rate for the 1st moment estimates.
|
34 |
+
beta_2 (float): The exponential decay rate for the 2nd moment estimates.
|
35 |
+
optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'.
|
36 |
+
use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True.
|
37 |
+
accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Optimizer: The initialized optimizer.
|
41 |
+
|
42 |
+
Raises:
|
43 |
+
ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'.
|
44 |
+
"""
|
45 |
+
print_main(f"Using {optimizer_type} optimizer")
|
46 |
+
# Create an empty dictionary called param_dict to store the model's named parameters.
|
47 |
+
param_dict = {}
|
48 |
+
# Iterate over the model's named parameters and populate the param_dict with key-value pairs.
|
49 |
+
for param_name, param in model.named_parameters():
|
50 |
+
print_main(param_name)
|
51 |
+
param_dict[param_name] = param
|
52 |
+
|
53 |
+
# Separate the model's named modules into two groups: decay and no_decay.
|
54 |
+
|
55 |
+
# Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay.
|
56 |
+
no_decay = []
|
57 |
+
|
58 |
+
if use_fsdp:
|
59 |
+
exclude_module = "_fsdp_wrapped_module.token_emb"
|
60 |
+
else:
|
61 |
+
exclude_module = "token_emb"
|
62 |
+
|
63 |
+
# Iterate through the named modules of the model.
|
64 |
+
for module_name, module in model.named_modules():
|
65 |
+
# Check if the current module is an instance of any of the desired types (LayerNorm or torch.nn.Embedding).
|
66 |
+
for ndim in [LayerNorm, torch.nn.Embedding]:
|
67 |
+
if isinstance(module, ndim):
|
68 |
+
# If torch.nn.Embedding, append its name with a ".weight" suffix to the no_decay list.
|
69 |
+
if module_name == exclude_module:
|
70 |
+
no_decay.append(f"{module_name}.weight")
|
71 |
+
else:
|
72 |
+
# If the module is an instance of LayerNorm
|
73 |
+
no_decay.append(f"{module_name}.gamma")
|
74 |
+
# Exit the inner loop since the desired module has been found.
|
75 |
+
break
|
76 |
+
|
77 |
+
# Create an empty list to store the names of the Linear layer weights with weight decay.
|
78 |
+
decay = []
|
79 |
+
|
80 |
+
# Iterate through the named modules of the model.
|
81 |
+
for module_name, module in model.named_modules():
|
82 |
+
# Check if the current module is an instance of the desired type (torch.nn.Linear).
|
83 |
+
for ndim in [torch.nn.Linear]:
|
84 |
+
if isinstance(module, ndim):
|
85 |
+
# If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list.
|
86 |
+
decay.append(f"{module_name}.weight")
|
87 |
+
# Exit the inner loop since the desired module has been found.
|
88 |
+
break
|
89 |
+
|
90 |
+
# Create two separate lists of model parameters: decay_param and no_decay_param.
|
91 |
+
# The decay_param list contains the parameters that should have weight decay applied.
|
92 |
+
# The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter.
|
93 |
+
|
94 |
+
# Create an empty list called decay_param to store the parameters with weight decay.
|
95 |
+
decay_param = []
|
96 |
+
|
97 |
+
if use_fsdp:
|
98 |
+
exclude_param = "_fsdp_wrapped_module.to_logits.weight"
|
99 |
+
else:
|
100 |
+
exclude_param = "to_logits.weight"
|
101 |
+
|
102 |
+
# Iterate over the decay list, which contains the names of the parameters with weight decay.
|
103 |
+
for param in decay:
|
104 |
+
# Check if the current parameter is not 'to_logits.weight'.
|
105 |
+
# Append the corresponding parameter from param_dict to the decay_param list.
|
106 |
+
|
107 |
+
if param != exclude_param:
|
108 |
+
decay_param.append(param_dict[param])
|
109 |
+
|
110 |
+
# Create an empty list called no_decay_param to store the parameters without weight decay.
|
111 |
+
no_decay_param = []
|
112 |
+
|
113 |
+
# Iterate over the no_decay list, which contains the names of the parameters without weight decay.
|
114 |
+
for param in no_decay:
|
115 |
+
# Append the corresponding parameter from param_dict to the no_decay_param list.
|
116 |
+
no_decay_param.append(param_dict[param])
|
117 |
+
|
118 |
+
# Create a list called grouped_params that contains two dictionaries.
|
119 |
+
# The first dictionary has the decay_param list and the corresponding weight_decay value.
|
120 |
+
# The second dictionary has the no_decay_param list and a weight_decay value of 0.0.
|
121 |
+
grouped_params = [
|
122 |
+
{"params": decay_param, "weight_decay": weight_decay},
|
123 |
+
{"params": no_decay_param, "weight_decay": 0.0},
|
124 |
+
]
|
125 |
+
|
126 |
+
# Create a variable called optimizer that stores an instance of the optimizer.
|
127 |
+
if optimizer_type == "adamw":
|
128 |
+
optimizer = AdamW(
|
129 |
+
grouped_params,
|
130 |
+
lr=learning_rate,
|
131 |
+
betas=(beta_1, beta_2),
|
132 |
+
)
|
133 |
+
elif optimizer_type == "stable_adamw":
|
134 |
+
optimizer = StableAdamWUnfused(
|
135 |
+
grouped_params,
|
136 |
+
lr=learning_rate,
|
137 |
+
betas=(beta_1, beta_2),
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
raise ValueError(
|
141 |
+
"Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
|
142 |
+
optimizer_type
|
143 |
+
)
|
144 |
+
)
|
145 |
+
|
146 |
+
# Return the optimizer.
|
147 |
+
return optimizer
|
Andromeda/Andromeda/utils/helpers.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.distributed as dist # Add this line
|
2 |
+
|
3 |
+
def print_num_params(model):
|
4 |
+
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
5 |
+
|
6 |
+
if dist.is_available():
|
7 |
+
if dist.get_rank() == 0:
|
8 |
+
print(f"Number of parameters in model: {n_params}")
|
9 |
+
else:
|
10 |
+
print(f"Number of parameters in model: {n_params}")
|
11 |
+
|
12 |
+
def print_main(msg):
|
13 |
+
if dist.is_available():
|
14 |
+
if dist.get_rank() == 0:
|
15 |
+
print(msg)
|
16 |
+
else:
|
17 |
+
print(msg)
|
Andromeda/Andromeda/utils/rf_utils.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import einsum, _nnpack_available
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from einops import rearrange
|
7 |
+
import copy
|
8 |
+
from pathlib import PurePath
|
9 |
+
from tqdm import tqdm_gui
|
10 |
+
from beartype import beartype
|
11 |
+
from beartype.typing import Tuple, Optional
|
12 |
+
|
13 |
+
from einops import rearrange, repeat, reduce, unpack
|
14 |
+
from einops.layers.torch import Rearrange, Reduce
|
15 |
+
|
16 |
+
|
17 |
+
#helpers
|
18 |
+
def exists(val):
|
19 |
+
return val is not None
|
20 |
+
|
21 |
+
|
22 |
+
#decorators
|
23 |
+
def eval_decorator(fn):
|
24 |
+
def inner(self, *args, **kwargs):
|
25 |
+
was_training = self.training
|
26 |
+
self.eval()
|
27 |
+
out = fn(self, *args, **kwargs)
|
28 |
+
self.train(was_training)
|
29 |
+
return out
|
30 |
+
return inner
|
31 |
+
|
32 |
+
def defaults(val, d):
|
33 |
+
return val if exists(val) else d
|
34 |
+
|
35 |
+
#tensor helpers
|
36 |
+
|
37 |
+
def log(t, eps=1e-20):
|
38 |
+
return torch.log(t.clamp(min = eps))
|
39 |
+
|
40 |
+
def masked_mean(seq, mask=None, dim=1, keepdim=True):
|
41 |
+
if not exists(mask):
|
42 |
+
return seq.mean(dim=dim)
|
43 |
+
|
44 |
+
if seq.ndim == 3:
|
45 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
46 |
+
|
47 |
+
masked_seq = seq.masked_fill(~mask, 0.)
|
48 |
+
numer = masked_seq.sum(dim=dim, keepdim=keepdim)
|
49 |
+
denom = mask.sum(dim=dim, keepdim=keepdim)
|
50 |
+
|
51 |
+
masked_mean = numer / denom.clamp(min = 1e-3)
|
52 |
+
masked_mean = masked_mean.masked_fill(denom == 0, 0.)
|
53 |
+
return masked_mean
|
54 |
+
|
55 |
+
|
56 |
+
#sampling helpers
|
57 |
+
|
58 |
+
def gumbel_noise(t):
|
59 |
+
noise = torch.zeros_like(t).uniform(0, 1)
|
60 |
+
return -log(-log(noise))
|
61 |
+
|
62 |
+
|
63 |
+
def gumbel_sample(t, temperature = 1., dim=-1):
|
64 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
|
65 |
+
|
66 |
+
def top_p(logits, thres=0.9):
|
67 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
68 |
+
cum_probs = torch.einsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
69 |
+
|
70 |
+
sorted_indices_to_remove = cum_probs > (1 - thres)
|
71 |
+
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
72 |
+
sorted_indices_to_remove[:, 0] = 0
|
73 |
+
|
74 |
+
sorted_logits[sorted_indices_to_remove] = float("-inf")
|
75 |
+
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
|
76 |
+
|
77 |
+
def top_k(logits, thres=0.9):
|
78 |
+
k = math.ceil((1 - thres) * logits.shape[-1])
|
79 |
+
val, ind = torch.topk(logits, k)
|
80 |
+
probs = torch.full_like(logits, float('-inf'))
|
81 |
+
probs.scatter_(1, ind, val)
|
82 |
+
return probs
|
83 |
+
|
84 |
+
|
85 |
+
class LoRA(nn.Module):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
dim,
|
89 |
+
dim_out,
|
90 |
+
r=8,
|
91 |
+
alpha=None
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
alpha = defaults(alpha, r)
|
95 |
+
self.scale = alpha / r
|
96 |
+
|
97 |
+
self.A = nn.Parameter(torch.randn(dim, r))
|
98 |
+
self.B = nn.Parameter(torch.zeros(r, dim_out))
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
#reward model
|
103 |
+
@beartype
|
104 |
+
|
105 |
+
class RewardModel(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
model: Andromeda,
|
109 |
+
dropout=0.1,
|
110 |
+
num_binned_output = 0.,
|
111 |
+
use_lora = True,
|
112 |
+
lora_r = 8,
|
113 |
+
reward_lora_scope = 'reward',
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
|
117 |
+
self.model = copy.deepcopy(Andromeda)
|
118 |
+
self.model.set_dropout(dropout)
|
119 |
+
|
120 |
+
self.reward_lora_scope = reward_lora_scope is use_lora else None
|
121 |
+
|
122 |
+
if exists(self.reward_lora_scope):
|
123 |
+
self.model.add_finetune_params(reward_lora_scope, lora_r = lora_r)
|
124 |
+
|
125 |
+
dim = model.dim
|
126 |
+
|
127 |
+
self.binned_output = num_binned_output > 1
|
128 |
+
|
129 |
+
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
|
130 |
+
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
|
131 |
+
|
132 |
+
|
133 |
+
if self.binned_output:
|
134 |
+
self.to_pred = nn.Linear(dim, num_binned_output)
|
135 |
+
else:
|
136 |
+
self.to_pred = nn.Sequential(
|
137 |
+
nn.Linear(dim, 1, bias=False),
|
138 |
+
Rearrange('... 1 -> ...')
|
139 |
+
)
|
140 |
+
|
141 |
+
def load(self, path):
|
142 |
+
path = Path(path)
|
143 |
+
assert path.exists()
|
144 |
+
self.load_state_dict(torch.load(str(path)))
|
145 |
+
|
146 |
+
def finetune_parameters(self):
|
147 |
+
return (
|
148 |
+
*self.to_pred.parameters(),
|
149 |
+
*(self.model.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else model.parameters())
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
x,
|
156 |
+
mask=None,
|
157 |
+
prompt_mask=None,
|
158 |
+
prompt_lengths=None,
|
159 |
+
labels=None,
|
160 |
+
sample=False,
|
161 |
+
sample_temperature=1.,
|
162 |
+
disable_lora=False
|
163 |
+
):
|
164 |
+
assert not (exists(prompt_mask) and exists(prompt_lengths))
|
165 |
+
|
166 |
+
#derive prompt mask from prompt lengths
|
167 |
+
|
168 |
+
if exists(prompt_lengths):
|
169 |
+
batch, seq_len = x.shape
|
170 |
+
arange = torch.arange(seq_len, device = x.device)
|
171 |
+
prompt_mask = repeat(arange, 'n -> n n', b = batch) > rearrange(prompt_lengths, 'b -> b 1')
|
172 |
+
|
173 |
+
#rward model should have an understand of which section is prompt and which section is repsonse
|
174 |
+
|
175 |
+
extra_embed = None
|
176 |
+
|
177 |
+
if exists(prompt_mask):
|
178 |
+
extra_embed = torch.where(
|
179 |
+
rearrange(prompt_mask, 'b n -> b n 1'),
|
180 |
+
self.prompt_embed,
|
181 |
+
self.response_embed
|
182 |
+
)
|
183 |
+
|
184 |
+
embeds = self.model(
|
185 |
+
x,
|
186 |
+
)
|
Andromeda/Andromeda/utils/stable_adamw.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
|
4 |
+
# This is the unfused version of StableAdamW. It is slower than the fused version (coming).
|
5 |
+
|
6 |
+
|
7 |
+
class StableAdamWUnfused(torch.optim.Optimizer):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
params,
|
11 |
+
lr=0.002,
|
12 |
+
weight_decay=0.2,
|
13 |
+
betas=(0.9, 0.99),
|
14 |
+
eps=1e-8,
|
15 |
+
clip_thresh=1.0,
|
16 |
+
precision="amp_bfloat16",
|
17 |
+
custom_scalar=65536,
|
18 |
+
):
|
19 |
+
beta1, beta2 = betas[0], betas[1]
|
20 |
+
defaults = dict(lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2)
|
21 |
+
super(StableAdamWUnfused, self).__init__(params, defaults)
|
22 |
+
|
23 |
+
self.eps = eps
|
24 |
+
self.d = clip_thresh
|
25 |
+
|
26 |
+
# Set precision to "custom_fp16" if you want to use a fixed loss scalar, custom_scalar, which is divided out in the update step.
|
27 |
+
# If you do this, call (custom_scalar * loss).backward() instead of loss.backward().
|
28 |
+
self.precision = precision
|
29 |
+
self.custom_scaler = custom_scalar
|
30 |
+
|
31 |
+
for group in self.param_groups:
|
32 |
+
group["step"] = 1.0
|
33 |
+
|
34 |
+
print("Using StableAdamWUnfused-v1")
|
35 |
+
|
36 |
+
def __setstate__(self, state):
|
37 |
+
super(StableAdamWUnfused, self).__setstate__(state)
|
38 |
+
|
39 |
+
def step(self, closure=None):
|
40 |
+
if closure is not None:
|
41 |
+
closure()
|
42 |
+
|
43 |
+
for group in self.param_groups:
|
44 |
+
lr = group["lr"]
|
45 |
+
weight_decay = group["weight_decay"]
|
46 |
+
beta1 = group["beta1"]
|
47 |
+
beta2 = group["beta2"]
|
48 |
+
step = group["step"]
|
49 |
+
|
50 |
+
for p in group["params"]:
|
51 |
+
if p.grad is None:
|
52 |
+
continue
|
53 |
+
theta = p.data
|
54 |
+
param_state = self.state[p]
|
55 |
+
|
56 |
+
if self.precision == "custom_fp16":
|
57 |
+
g = p.grad.data / self.custom_scaler
|
58 |
+
if torch.any(torch.isnan(g) | torch.isinf(g)):
|
59 |
+
continue
|
60 |
+
else:
|
61 |
+
g = p.grad.data
|
62 |
+
|
63 |
+
if "exp_avg" not in param_state:
|
64 |
+
v = param_state["exp_avg"] = torch.zeros_like(theta)
|
65 |
+
u = param_state["exp_avg_sq"] = torch.zeros_like(theta)
|
66 |
+
else:
|
67 |
+
v = param_state["exp_avg"]
|
68 |
+
u = param_state["exp_avg_sq"]
|
69 |
+
|
70 |
+
beta1hat = beta1 * (1 - beta1 ** (step - 1)) / (1 - beta1**step)
|
71 |
+
beta2hat = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2**step)
|
72 |
+
|
73 |
+
v = v.mul_(beta1hat).add_(g, alpha=1.0 - beta1hat)
|
74 |
+
u = u.mul_(beta2hat).addcmul_(g, g, value=1.0 - beta2hat)
|
75 |
+
|
76 |
+
denominator = u.sqrt().add_(self.eps)
|
77 |
+
|
78 |
+
# StableAdamW = AdamW + update clipping (https://arxiv.org/abs/1804.04235) applied tensor-wise.
|
79 |
+
rms = (
|
80 |
+
torch.div(
|
81 |
+
g.pow(2), torch.maximum(u, (self.eps**2) * torch.ones_like(u))
|
82 |
+
)
|
83 |
+
.mean()
|
84 |
+
.sqrt()
|
85 |
+
.item()
|
86 |
+
)
|
87 |
+
|
88 |
+
theta = theta.mul_(1.0 - lr * weight_decay).addcdiv_(
|
89 |
+
v, denominator, value=-lr * (1.0 / max(1.0, rms / self.d))
|
90 |
+
)
|
91 |
+
|
92 |
+
# save current params
|
93 |
+
param_state["exp_avg"] = v
|
94 |
+
param_state["exp_avg_sq"] = u
|
95 |
+
|
96 |
+
group["step"] = step + 1
|
Andromeda/DOCs/Corporation/MONETIZATION.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Andromeda Product Brief and Monetization Strategy Document
|
2 |
+
|
3 |
+
## Product Summary:
|
4 |
+
|
5 |
+
Andromeda is an innovative language model designed for high performance and efficiency. It utilizes advanced techniques that allow it to process and learn from multiple sources and adapt in real-time.
|
6 |
+
|
7 |
+
## Monetization Strategies:
|
8 |
+
|
9 |
+
1. **Usage-based API:** Provide Andromeda as a paid API service where users pay based on the amount of computation they use.
|
10 |
+
2. **Consulting deals:** Offer expert consulting services to businesses looking to incorporate Andromeda's capabilities into their operations.
|
11 |
+
3. **Dedicated capacity:** Sell dedicated computational power to businesses for exclusive usage of Andromeda's capabilities.
|
12 |
+
4. **Licensing the technology:** Allow companies to license the Andromeda model for their proprietary use.
|
13 |
+
5. **Subscription models:** Provide access to Andromeda's capabilities on a subscription basis.
|
14 |
+
6. **Freemium model:** Offer basic usage of Andromeda for free, while charging for advanced features and capabilities.
|
15 |
+
7. **Partnerships:** Form strategic partnerships with tech companies that can leverage Andromeda's capabilities in their products and services.
|
16 |
+
8. **Sponsorships:** Sponsor research projects or tech events to get visibility and promote Andromeda's services.
|
17 |
+
9. **Training and certifications:** Offer training programs and certifications on Andromeda usage and applications.
|
18 |
+
10. **Custom development:** Offer custom development services for businesses that want specialized applications of Andromeda.
|
19 |
+
|
20 |
+
## Potential Customers:
|
21 |
+
|
22 |
+
1. **Tech companies:** Andromeda can be integrated into a wide array of tech products and services.
|
23 |
+
2. **Educational institutions:** Universities and research institutions can use Andromeda for research purposes.
|
24 |
+
3. **Government agencies:** Andromeda can assist in processing and analyzing large amounts of data.
|
25 |
+
4. **Healthcare providers:** Andromeda can be used in data analysis and decision making in healthcare.
|
26 |
+
5. **Media and entertainment industry:** Andromeda's language model can be used in content creation and curation.
|
27 |
+
|
28 |
+
## Potential Cashflow Gains:
|
29 |
+
|
30 |
+
1. **API usage revenues:** Charging per API call can generate substantial revenues with a high number of users.
|
31 |
+
2. **Subscription fees:** A tier-based subscription model can ensure a steady income stream.
|
32 |
+
3. **Licensing fees:** Companies willing to license the technology can provide a significant one-time or recurring revenue.
|
33 |
+
4. **Consulting fees:** Consulting services can yield high-value contracts.
|
34 |
+
5. **Sponsorship revenues:** Sponsoring events or projects can yield returns in the form of new business leads and customers.
|
35 |
+
|
36 |
+
## Expenses:
|
37 |
+
|
38 |
+
1. **Cloud infrastructure costs:** Major expense in maintaining and scaling the Andromeda model.
|
39 |
+
2. **Research and development:** Continual improvement of Andromeda requires ongoing investment.
|
40 |
+
3. **Marketing and sales:** Promoting Andromeda and closing sales deals will be a recurring expense.
|
41 |
+
4. **Operational costs:** Expenses related to managing the company, including salaries, office space, utilities, and more.
|
42 |
+
5. **Open-source contributors:** Andromeda is built on the contributions of numerous developers. Recognizing these contributors through a rewards program is an essential part of maintaining a healthy development ecosystem.
|
43 |
+
|
44 |
+
### Open Source Contributors:
|
45 |
+
|
46 |
+
The following is a representative list of contributors who have helped make Agora what it is today:
|
47 |
+
|
48 |
+
1. Kye
|
49 |
+
2. Nicolo
|
50 |
+
|
51 |
+
Each contributor brings unique expertise and value to the project, helping to shape Andromeda into a powerful, efficient, and intelligent language model that will revolutionize the NLP landscape.
|
Andromeda/DOCs/Design/Dyson.md
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Insights and Techniques:
|
2 |
+
|
3 |
+
1. Flops: The importance of considering the number of floating-point operations (FLOPs) when designing models.
|
4 |
+
2. Flash Attention 2.0: The use of techniques like Flash Attention 2.0 cuda to enable more FLOPs in the model.
|
5 |
+
3. Mixed Precision: Utilizing mixed precision training to improve training speed and memory efficiency.
|
6 |
+
4. Deepspeed 3 with NVMe: Using Deepspeed 3 with NVMe for optimizing training performance.
|
7 |
+
5. 8-bit Optimizer: Employing an 8-bit optimizer for further speed improvements.
|
8 |
+
6. Gradient Clipping: Adding gradient clipping to achieve massive speedup during training.
|
9 |
+
7. XPOS, ALIBI, QK Layernorm: Leveraging advanced techniques for extrapolation, interpolation, and training stabilization.
|
10 |
+
8. Multi Query Attention: Using multi-query attention to boost decoding speed.
|
11 |
+
9. Parallelized Transformer Blocks: Parallelizing transformer blocks to enhance overall model performance.
|
12 |
+
10. Positional Embeddings and Shifted Tokens: The decision to not use positional embeddings and utilization of shifted tokens for sequence length advancement.
|
13 |
+
11. Positional Interpolation: Incorporating positional interpolation for improved sequence handling.
|
14 |
+
12. Optimized CUDA Embedding Function: Utilizing an optimized CUDA embedding function for better performance.
|
15 |
+
13. Nebula Loss Function: Implementing the Nebula loss function, a polymorphic loss function for multi-task training.
|
16 |
+
|
17 |
+
Possible Improvements:
|
18 |
+
|
19 |
+
1. Clearer Metrics: To validate the model's claims, it would be beneficial to establish specific metrics for monitoring across training, especially regarding reasoning capabilities.
|
20 |
+
2. Validation and Testing Environment: Further development and description of the exhaustive testing environment to validate the model's performance and capabilities.
|
21 |
+
3. Comprehensive Documentation: Provide detailed documentation of the model's architecture, training methodology, and testing procedures to ensure transparency and replicability.
|
22 |
+
4. Benchmarking Against Competitors: Perform benchmarking against existing models to showcase the advantages and differentiation offered by the proposed architecture and training techniques.
|
23 |
+
5. Real-World Applications: Highlight potential real-world applications or use cases where the proposed model can provide superior performance compared to existing solutions.
|
24 |
+
6. Explainability and Interpretability: Consider incorporating methods for model explainability and interpretability, especially in applications where these aspects are crucial.
|
25 |
+
7. Addressing Specific Niche Needs: Identify specific niches or use cases where the model can excel and tailor marketing and development efforts accordingly.
|
26 |
+
8. Collaboration and Peer Review: Engage with the research community, participate in peer review, and seek collaboration opportunities to gain additional insights and validation.
|
Andromeda/DOCs/Design/MODEL_ARCHITECTURE.md
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
### Alibi Positional Bias
|
3 |
+
|
4 |
+
Alibi positional bias allows the model to learn relative positions between tokens, enabling it to better capture the relationships and dependencies between tokens in a sequence.
|
5 |
+
|
6 |
+
Usage example:
|
7 |
+
|
8 |
+
```python
|
9 |
+
attn_layers = Decoder(
|
10 |
+
...
|
11 |
+
alibi_pos_bias=True,
|
12 |
+
alibi_num_heads=4,
|
13 |
+
...
|
14 |
+
)
|
15 |
+
```
|
16 |
+
|
17 |
+
### Rotary Position Encodings (xpos)
|
18 |
+
|
19 |
+
Rotary position encodings introduce a more efficient way to encode positions in the input sequence. They avoid the need for absolute positional embeddings, reducing the model's memory footprint and improving training speed.
|
20 |
+
|
21 |
+
Usage example:
|
22 |
+
|
23 |
+
```python
|
24 |
+
attn_layers = Decoder(
|
25 |
+
...
|
26 |
+
rotary_xpos=True,
|
27 |
+
...
|
28 |
+
)
|
29 |
+
```
|
30 |
+
|
31 |
+
### Flash Attention
|
32 |
+
|
33 |
+
Flash attention speeds up the self-attention mechanism by reducing the number of attention computations. It accelerates training and inference while maintaining a high level of performance.
|
34 |
+
|
35 |
+
Usage example:
|
36 |
+
|
37 |
+
```python
|
38 |
+
attn_layers = Decoder(
|
39 |
+
...
|
40 |
+
attn_flash=True,
|
41 |
+
...
|
42 |
+
)
|
43 |
+
```
|
44 |
+
|
45 |
+
Usage example:
|
46 |
+
|
47 |
+
```python
|
48 |
+
attn_layers = Decoder(
|
49 |
+
...
|
50 |
+
deepnorm=True,
|
51 |
+
...
|
52 |
+
)
|
53 |
+
```
|
54 |
+
|
55 |
+
### Deep Normalization (deepnorm)
|
56 |
+
|
57 |
+
Deep normalization is a technique that normalizes the activations within a layer, helping with training stability and convergence. It allows the model to better learn complex patterns and generalize to unseen data.
|
Andromeda/DOCs/Design/SPEED.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Increasing Speed
|
2 |
+
|
3 |
+
* Integrate Flash Attention 2.0 cuda, significant speed up
|
4 |
+
|
5 |
+
* Utilize 8BIT Optimizer from BNB, big speed up weakness => bnb isn't compatible with all gpus
|
6 |
+
|
7 |
+
* Use a better tokenizer TokenMonster?
|
8 |
+
|
9 |
+
* Parallelize the transformer blocks similar to that of [PALMS](https://github.com/conceptofmind/PaLM)
|
10 |
+
|
11 |
+
* Look into MPTS config for LION for pretraining, did they use high batch size?
|
Andromeda/DOCs/Design/Specs.md
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## **Andromeda Specs**: Unveiling Mastery
|
2 |
+
|
3 |
+
**Overview**
|
4 |
+
Elegantly marrying craftsmanship and technology, Andromeda is not just another step in AI evolution. It's a giant leap. Driven by precision, powered by innovation, and defined by excellence, Andromeda is the epitome of intelligence realized. Here, we detail the marvel that is Andromeda, in numbers, facts, and logic.
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
### **Specifications**
|
9 |
+
|
10 |
+
| **Feature** | **Specification** |
|
11 |
+
|----------------------------------------------|-----------------------------------------------|
|
12 |
+
| **Sequence Handling** | Ultra Long (32,000 - 200,000+ context lengths)|
|
13 |
+
| **Processing Speed** | Ultra Fast (32,000+ tokens in < 100ms) |
|
14 |
+
| **Reasoning Abilities** | Creativity, Quantitative |
|
15 |
+
| **Attention Mechanism** | Flash Attention 2.0 Triton |
|
16 |
+
| **Memory Consumption** (compared to GPT-3) | 100x Less |
|
17 |
+
| **Memory Consumption** (compared to LLAMA) | 30x Less |
|
18 |
+
| **Max Sequence Processing Speed** | 100,000+ sequences in < 300ms |
|
19 |
+
| **Dataset Strategy** | Books, Falcon, Redpajama, Math, Code |
|
20 |
+
| **Functionality** | FSDP, HF Accelerate, Poetry Composition, API Calls, and more |
|
21 |
+
|
22 |
+
---
|
23 |
+
|
24 |
+
### **Benchmarks**
|
25 |
+
**Speed**: At the heart of Andromeda's unparalleled capabilities is its raw speed. Leveraging the prowess of Flash Attention 2.0 Triton, it doesn't merely process data; it blazes through it. This power allows it to consume 50x less memory than its predecessor, GPT-3, and 10x less than LLAMA.
|
26 |
+
|
27 |
+
---
|
28 |
+
|
29 |
+
### **Why Andromeda?**
|
30 |
+
- **Performance**: Andromeda isn't about doing things faster; it's about doing them the best. Reliable processing of sequences, even as extensive as 100,000+ lengths, is realized in the blink of an eye, under 300ms.
|
31 |
+
|
32 |
+
- **Precision and Creativity**: The dataset strategy is no mere algorithm. It's a symphony, meticulously crafted to offer both creativity and quantitative reasoning.
|
33 |
+
|
34 |
+
- **Versatility**: Andromeda doesn't just compute; it contemplates. Whether you need the flair of a poet or the precision of an API call, Andromeda delivers, seamlessly.
|
35 |
+
|
36 |
+
---
|
37 |
+
|
38 |
+
### **Andromeda Principles**
|
39 |
+
- **Efficiency**: It's not just about doing more; it's about doing better. Techniques like attention flashing, rotary position encodings, and deep normalization ensure every cycle, every operation, every byte is optimized for performance.
|
40 |
+
|
41 |
+
- **Flexibility**: In the ever-evolving world of technology, adaptability is king. Andromeda is designed to mold, adapt, and excel, irrespective of the task or domain.
|
42 |
+
|
43 |
+
- **Scalability**: Grow with you, for you. Andromeda isn't static. It's dynamic, designed to scale, accommodating growing resources and expanding data sizes.
|
44 |
+
|
45 |
+
- **Community-Driven**: Behind Andromeda's machine brain is the human heart of the community. It doesn't just utilize open source; it thrives on it, constantly evolving, learning, and improving with contributions from around the world.
|
46 |
+
|
47 |
+
|
48 |
+
For enthusiasts, developers, and thinkers looking to dive deeper, the Model Architecture documentation offers an exhaustive, detailed view into the intricate marvel that is Andromeda. Dive in, and witness engineering and artistry in harmony.
|
49 |
+
|
50 |
+
---
|
51 |
+
|
52 |
+
### **Andromeda: A Detailed Technical Overview**
|
53 |
+
|
54 |
+
At the intersection of technological ingenuity and groundbreaking design principles, Andromeda emerges. Representing the zenith of years of research and development, it promises a transformative leap in AI performance, efficiency, and versatility. In this technical specifications document, we deconstruct the intricacies of Andromeda, presenting a meticulous overview of its structure, performance metrics, and underlying methodologies.
|
55 |
+
|
56 |
+
## **Feature Insights**
|
57 |
+
|
58 |
+
### **Alibi Positional Bias**
|
59 |
+
Empowering Andromeda to discern relative positions between tokens, this feature accentuates its ability to grasp intricate relationships within a sequence.
|
60 |
+
|
61 |
+
### **Rotary Position Encodings (xpos)**
|
62 |
+
This is a revolutionary means of encoding positions, shrinking the model's memory demands and propelling training speeds.
|
63 |
+
|
64 |
+
### **Flash Attention**
|
65 |
+
This is the linchpin of Andromeda's speed prowess, minimizing attention computations, thus boosting training and inference phases.
|
66 |
+
|
67 |
+
### **Deep Normalization (deepnorm)**
|
68 |
+
By normalizing activations, deep normalization shores up training stability, allowing Andromeda to identify intricate patterns with finesse.
|
69 |
+
|
70 |
+
## **Feature Insights (Contd.)**
|
71 |
+
|
72 |
+
### **Attn One KV Head (Multiquery Attention)**
|
73 |
+
A breakthrough in attention mechanism design, this feature allows for simultaneous computation of multiple queries against the same set of key-values, fostering speed and efficiency.
|
74 |
+
|
75 |
+
### **QK Norm & Attention QK Norm**
|
76 |
+
These two features introduce a normalization step in the query and key matrices. This step facilitates stabilization in the attention mechanism, rendering it more robust and enabling it to scale with larger input sizes.
|
77 |
+
|
78 |
+
### **Attention QK Norm Dimension Scale**
|
79 |
+
A sophisticated adjustment to the attention mechanism, it modulates the normalization scale in accordance to the dimensions of the model. The result is a more adaptive and responsive attention framework.
|
80 |
+
|
81 |
+
### **Embedding Provider**
|
82 |
+
At the foundation of Andromeda, this module facilitates the embedding process, converting token sequences into dense vectors. Tailored for Andromeda, it ensures rapid and efficient embedding processes.
|
83 |
+
|
84 |
+
---
|
85 |
+
|
86 |
+
## **Deeper Dive: Model Parameters**
|
87 |
+
|
88 |
+
Unpacking Andromeda means diving deep into the parameters that shape its capabilities. Here's a granular view:
|
89 |
+
|
90 |
+
| **Parameter** | **Description** | **Default Value** |
|
91 |
+
|-----------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------|
|
92 |
+
| **num_tokens** | Total number of tokens in the vocabulary. | 50432 |
|
93 |
+
| **max_seq_len** | Maximum sequence length the model can process. | 8192 |
|
94 |
+
| **dim** | Dimension size of the model. It represents the size of embeddings and general depth in neural layers. | 2560 |
|
95 |
+
| **depth** | Represents the number of transformer layers in the architecture. | 32 |
|
96 |
+
| **dim_head** | Dimension size of each head in multi-head attention mechanism. | 128 |
|
97 |
+
| **heads** | Total number of heads in multi-head attention. | 24 |
|
98 |
+
| **use_abs_pos_emb** | Boolean flag to determine if absolute positional embeddings are used. | False |
|
99 |
+
| **alibi_pos_bias** | Enables the alibi positional bias in attention mechanisms. | True |
|
100 |
+
| **alibi_num_heads** | Specifies the number of heads for the alibi positional bias. | 12 |
|
101 |
+
| **rotary_xpos** | Determines if rotary positional encodings are utilized. | True |
|
102 |
+
| **attn_flash** | Flag to activate the Flash Attention mechanism, minimizing computations in the attention phase. | True |
|
103 |
+
| **shift_tokens** | The number of tokens by which input sequences are shifted. Essential for certain sequence-to-sequence tasks. | 1 |
|
104 |
+
| **attn_one_kv_head** | Activates multiquery attention by computing multiple queries against a singular key-value pair. | True |
|
105 |
+
| **qk_norm** | Enables the query-key normalization mechanism in the attention phase. | True |
|
106 |
+
| **attn_qk_norm** | A more advanced version of query-key normalization that scales according to the model's dimensions. | True |
|
107 |
+
| **attn_qk_norm_dim_scale** | Modulates the scale of the aforementioned attention normalization based on the model's dimensionality. | True |
|
108 |
+
| **embedding_provider** | The module responsible for providing embeddings. Custom providers can be passed for tailored embedding processes. | AndromedaEmbedding|
|
109 |
+
|
110 |
+
---
|
111 |
+
|
112 |
+
|
113 |
+
## **Insights and Techniques**
|
114 |
+
|
115 |
+
#### **1. Floating-Point Operations (FLOPs)**
|
116 |
+
Considering the number of FLOPs is paramount. It provides a metric to gauge the computational intensity and, by extension, the potential speed of the model.
|
117 |
+
|
118 |
+
#### **2. Flash Attention 2.0 Triton**
|
119 |
+
Enhanced with CUDA, this method offers a significant surge in the number of FLOPs the model can handle, amplifying its overall efficiency.
|
120 |
+
|
121 |
+
#### **3. Mixed Precision Training**
|
122 |
+
By embracing mixed precision, Andromeda realizes a noteworthy uptick in training speed while achieving commendable memory efficiency.
|
123 |
+
|
124 |
+
#### **4. Deepspeed 3 with NVMe Integration**
|
125 |
+
This powerful combination paves the way for superlative optimization during the training phase.
|
126 |
+
|
127 |
+
#### **5. 8-bit Optimizer**
|
128 |
+
Further pushing the boundaries of speed, the 8-bit optimizer boosts processing times without compromising the integrity of results.
|
129 |
+
|
130 |
+
#### **6. Gradient Clipping**
|
131 |
+
This technique has been integrated into the training regimen, achieving a massive speedup and preventing undesirable spikes during the process.
|
132 |
+
|
133 |
+
#### **7. Advanced Techniques: XPOS, ALIBI, QK Layernorm**
|
134 |
+
These sophisticated techniques are harnessed for superior extrapolation, interpolation, and stabilization during training.
|
135 |
+
|
136 |
+
#### **8. Multi Query Attention**
|
137 |
+
This approach has been adopted to supercharge decoding speeds.
|
138 |
+
|
139 |
+
#### **9. Parallelized Transformer Blocks**
|
140 |
+
Ensuring that the model's performance is consistently high, these blocks run in tandem to provide a smooth and efficient operational experience.
|
141 |
+
|
142 |
+
#### **10. Shifted Tokens**
|
143 |
+
In a strategic move, Andromeda sidesteps traditional positional embeddings, relying instead on shifted tokens for sequence length progression.
|
144 |
+
|
145 |
+
#### **11. Positional Interpolation**
|
146 |
+
This innovative technique augments the model's ability to manage sequences more effectively.
|
147 |
+
|
148 |
+
#### **12. Optimized CUDA Embedding Function**
|
149 |
+
This function is tailored for peak performance, ensuring rapid and accurate computations.
|
150 |
+
|
151 |
+
#### **13. Nebula Loss Function**
|
152 |
+
Integrated into Andromeda, this polymorphic loss function is adept at handling multi-task training scenarios.
|
153 |
+
|
154 |
+
## **A Word on Optimization and Future Iterations**
|
155 |
+
|
156 |
+
As with any state-of-the-art model, Andromeda's design is an ever-evolving tapestry. This means iterative refinement. As feedback streams in and technology progresses, expect advancements in:
|
157 |
+
|
158 |
+
- **Model Pruning**: Trimming redundancies, bolstering efficiency.
|
159 |
+
- **Knowledge Distillation**: Harnessing the wisdom of larger models in smaller, more agile architectures.
|
160 |
+
- **Zero-Shot and Few-Shot Learning**: Broadening adaptability horizons.
|
161 |
+
- **Enhanced Data Augmentation**: Fortifying the model's grasp on varied, nuanced contexts.
|
162 |
+
- **Decentralized Training**: Tapping into the global hive-mind, harnessing the collaborative power of the community.
|
163 |
+
|
164 |
+
|
165 |
+
## **Potential Other Future Trajectories**
|
166 |
+
|
167 |
+
#### **1. Clearer Metrics**
|
168 |
+
There's always room to elevate the benchmarking rigor, especially concerning reasoning abilities.
|
169 |
+
|
170 |
+
#### **2. Robust Validation and Testing Environment**
|
171 |
+
Further fine-tuning of the testing environment can offer even more reliable validations of Andromeda's capabilities.
|
172 |
+
|
173 |
+
#### **3. Comprehensive Documentation**
|
174 |
+
To bolster transparency and replicability, detailed documentation covering every facet of Andromeda is on the horizon.
|
175 |
+
|
176 |
+
#### **4. Benchmarking Against Peers**
|
177 |
+
By juxtaposing Andromeda against its counterparts, its distinctive advantages can be spotlighted more effectively.
|
178 |
+
|
179 |
+
#### **5. Spotlight on Real-World Applications**
|
180 |
+
By highlighting tangible use-cases, the versatility and prowess of Andromeda can be showcased in palpable contexts.
|
181 |
+
|
182 |
+
#### **6. Model Interpretability**
|
183 |
+
Future iterations might delve deeper into model interpretability, especially for critical applications.
|
184 |
+
|
185 |
+
#### **7. Niche Customizations**
|
186 |
+
By tailoring Andromeda to meet specific niche needs, its adaptability and value proposition can be further enhanced.
|
187 |
+
|
188 |
+
#### **8. Collaborative Endeavors**
|
189 |
+
Engaging more intimately with the global research community could spawn collaborative projects, bringing diverse insights to the fore.
|
190 |
+
|
191 |
+
|
192 |
+
As we voyage further into the AI frontier, Andromeda stands as a beacon, illuminating the path forward, promising marvels yet to come. It's not just about machine intelligence; it's about the dance between human curiosity and machine capability.
|
193 |
+
|
194 |
+
---
|
195 |
+
|
196 |
+
Join us on this journey. Dive deeper, ask questions, innovate, and let's redefine what's possible, together.
|
Andromeda/DOCs/Docs/DOCUMENTATION.md
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Documentation
|
2 |
+
|
3 |
+
## `DatasetBuilder`
|
4 |
+
|
5 |
+
### DatasetBuilder
|
6 |
+
|
7 |
+
DatasetBuilder provides a convenient way to build datasets for training the Andromeda model.
|
8 |
+
|
9 |
+
#### Constructor
|
10 |
+
|
11 |
+
```python
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
dataset_name,
|
15 |
+
seq_len=8192,
|
16 |
+
num_cpu=None,
|
17 |
+
hf_account_repo=None,
|
18 |
+
tokenizer="EleutherAI/gpt-neox-20b",
|
19 |
+
)
|
20 |
+
```
|
21 |
+
|
22 |
+
Initialize the DatasetBuilder.
|
23 |
+
|
24 |
+
**Args:**
|
25 |
+
|
26 |
+
- `dataset_name` (str): Name of the dataset to process.
|
27 |
+
- `seq_len` (int): Maximum sequence length.
|
28 |
+
- `num_cpu` (int, optional): Number of CPU cores to use for multiprocessing. Defaults to None.
|
29 |
+
- `hf_account_repo` (str, optional): Hugging Face account name and repository to push the processed dataset. Defaults to None.
|
30 |
+
- `tokenizer` (str, optional): Tokenizer model to use. Defaults to "EleutherAI/gpt-neox-20b".
|
31 |
+
|
32 |
+
#### Methods
|
33 |
+
|
34 |
+
##### build_dataset
|
35 |
+
|
36 |
+
```python
|
37 |
+
def build_dataset(self) -> torch.utils.data.Dataset
|
38 |
+
```
|
39 |
+
|
40 |
+
Build and process the dataset.
|
41 |
+
|
42 |
+
**Returns:**
|
43 |
+
|
44 |
+
- `torch.utils.data.Dataset`: The processed dataset ready for training.
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
## AndromedaTokenizer
|
49 |
+
|
50 |
+
### Purpose
|
51 |
+
|
52 |
+
The `AndromedaTokenizer` class provides tokenization functionality using the Hugging Face tokenizer. It allows you to tokenize texts using the specified tokenizer model.
|
53 |
+
|
54 |
+
### Systems Understanding
|
55 |
+
|
56 |
+
The `AndromedaTokenizer` class initializes a tokenizer model from the Hugging Face library. It uses the `AutoTokenizer.from_pretrained` method to load the tokenizer model with specific parameters such as the EOS token, pad token, extra IDs, and model maximum length. The `tokenize_texts` method tokenizes input texts using the tokenizer model and returns the tokenized input IDs.
|
57 |
+
|
58 |
+
### Usage Example
|
59 |
+
|
60 |
+
```python
|
61 |
+
from Andromeda import AndromedaTokenizer
|
62 |
+
|
63 |
+
# Initialize the tokenizer
|
64 |
+
tokenizer = AndromedaTokenizer()
|
65 |
+
|
66 |
+
# Tokenize texts
|
67 |
+
texts = ["This is an example sentence.", "Another example sentence."]
|
68 |
+
tokenized_ids = tokenizer.tokenize_texts(texts)
|
69 |
+
|
70 |
+
print(tokenized_ids)
|
71 |
+
```
|
72 |
+
|
73 |
+
## Andromeda
|
74 |
+
|
75 |
+
### Purpose
|
76 |
+
|
77 |
+
The `Andromeda` class is a transformer-based model architecture. It consists of a `Transformer` and `AutoregressiveWrapper` with default or user-specified parameters.
|
78 |
+
|
79 |
+
### Systems Understanding
|
80 |
+
|
81 |
+
The `Andromeda` class initializes with a `Transformer` and `AutoregressiveWrapper`. The `Transformer` encapsulates the main transformer model, and the `AutoregressiveWrapper` enables autoregressive generation using the transformer model.
|
82 |
+
|
83 |
+
The constructor of the `Andromeda` class takes various parameters that define the architecture of the model, such as the number of tokens, maximum sequence length, model dimension, depth, number of heads, etc. These parameters are used to initialize the `Transformer` and `AutoregressiveWrapper` with the specified configuration.
|
84 |
+
|
85 |
+
The `forward` method performs a forward pass through the model. It takes the input `text_tokens` as input and passes it through the `Decoder` module inside the `Andromeda` model. The output from the decoder is returned as the result.
|
86 |
+
|
87 |
+
### Usage Example
|
88 |
+
|
89 |
+
```python
|
90 |
+
from Andromeda import Andromeda
|
91 |
+
|
92 |
+
# Create an instance of the Andromeda model
|
93 |
+
model = Andromeda()
|
94 |
+
|
95 |
+
# Define the input text tokens
|
96 |
+
text_tokens = [1, 2, 3, 4, 5] # Example input tokens
|
97 |
+
|
98 |
+
# Perform a forward pass through the model
|
99 |
+
output = model.forward(text_tokens)
|
100 |
+
|
101 |
+
print(output)
|
102 |
+
```
|
103 |
+
|
104 |
+
### Constructor
|
105 |
+
|
106 |
+
```python
|
107 |
+
def __init__(self, num_tokens=50304, max_seq_len=8192, dim=2560, depth=32, dim_head=128, heads=24, use_abs_pos_emb=False, alibi_pos_bias=True, alibi_num_heads=12, rotary_xpos=True, attn_flash=True, deepnorm=True, shift_tokens=1, attn_one_kv_head=True, qk_norm=True, attn_qk_norm=True, attn_qk_norm_dim_scale=True, embedding_provider=AndromedaEmbedding())
|
108 |
+
```
|
109 |
+
|
110 |
+
- `num_tokens` (optional): Number of tokens in the vocabulary.
|
111 |
+
- `max_seq_len` (optional): Maximum sequence length.
|
112 |
+
- `dim` (optional): Dimension of the model.
|
113 |
+
- `depth` (optional): Depth of the model.
|
114 |
+
- `dim_head` (optional): Dimension of the model head.
|
115 |
+
- `heads` (optional): Number of heads.
|
116 |
+
- `use_abs_pos_emb` (optional): Whether to use absolute position embedding.
|
117 |
+
- `alibi_pos_bias` (optional): Alibi position bias.
|
118 |
+
- `alibi_num_heads` (optional): Number of alibi heads.
|
119 |
+
- `rotary_xpos` (optional): Rotary position.
|
120 |
+
- `attn_flash` (optional): Attention flash.
|
121 |
+
- `deepnorm` (optional): Deep normalization.
|
122 |
+
- `shift_tokens` (optional): Number of tokens to shift.
|
123 |
+
- `attn_one_kv_head` (optional): Attention one key/value head.
|
124 |
+
- `qk_norm` (optional): Query-key normalization.
|
125 |
+
- `attn_qk_norm` (optional): Attention query-key normalization.
|
126 |
+
- `attn_qk_norm_dim_scale` (optional): Attention query-key normalization dimension scale.
|
127 |
+
- `embedding_provider` (optional): Embedding provider module.
|
128 |
+
|
129 |
+
### Methods
|
130 |
+
|
131 |
+
- `forward(text_tokens, **kwargs)`: Performs a forward pass through the model.
|
132 |
+
- `text_tokens` (required): Input tokens.
|
133 |
+
- `kwargs` (optional): Other arguments.
|
134 |
+
|
135 |
+
### Args
|
136 |
+
|
137 |
+
- `text_tokens` (list): Input tokens.
|
138 |
+
|
139 |
+
### Returns
|
140 |
+
|
141 |
+
- Output from the decoder module.
|
142 |
+
|
143 |
+
## Conclusion
|
144 |
+
|
145 |
+
The Andromeda module provides a transformer-based model architecture for text generation. The `AndromedaTokenizer` class allows you to tokenize texts using the specified tokenizer model. The `Andromeda` class initializes with a transformer and autoregressive wrapper, providing the functionality for text generation. By using the provided classes and methods, you can generate text using the Andromeda model.
|
Andromeda/DOCs/Docs/TRAINING.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Andromeda Model Training Standard Operating Procedure
|
2 |
+
|
3 |
+
This document provides instructions on how to train the Andromeda model end-to-end using the provided code. The training procedure consists of three main scripts: `build_dataset.py`, `model.py`, and `train_distributed.py`. Follow the steps below to train the Andromeda model.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before starting the training process, ensure that you have the following requirements:
|
8 |
+
|
9 |
+
- Python 3.7 or higher
|
10 |
+
- PyTorch 1.9 or higher
|
11 |
+
- Transformers library
|
12 |
+
- Datasets library
|
13 |
+
- Accelerate library
|
14 |
+
- Wandb library (optional, for logging)
|
15 |
+
|
16 |
+
## Step 1: Building the Dataset
|
17 |
+
|
18 |
+
The first step is to build the dataset required for training. The `build_dataset.py` script processes the training data and prepares it for training. Follow the instructions below to build the dataset:
|
19 |
+
|
20 |
+
1. Open the `build_dataset.py` script.
|
21 |
+
2. Set the configuration parameters in the `CFG` class according to your requirements:
|
22 |
+
- `HF_ACCOUNT_REPO`: Replace with your Hugging Face API key.
|
23 |
+
- `TOKENIZER`: Choose the tokenizer model to use (e.g., "EleutherAI/gpt-neox-20b").
|
24 |
+
- `DATASET_NAME`: Choose the dataset to process (e.g., "tiiuae/falcon-refinedweb").
|
25 |
+
- `SEQ_LEN`: Set the desired sequence length.
|
26 |
+
3. Save the changes to the script.
|
27 |
+
4. Open a terminal or command prompt and navigate to the directory containing the `build_dataset.py` script.
|
28 |
+
5. Run the following command to execute the script:
|
29 |
+
```
|
30 |
+
python build_dataset.py
|
31 |
+
```
|
32 |
+
6. The script will process the dataset and push it to your Hugging Face account repository specified by `HF_ACCOUNT_REPO`.
|
33 |
+
|
34 |
+
## Step 2: Defining the Andromeda Model
|
35 |
+
|
36 |
+
The second step is to define the Andromeda model architecture. The `model.py` script contains the model definition and configuration. Follow the instructions below to configure the Andromeda model:
|
37 |
+
|
38 |
+
1. Open the `model.py` script.
|
39 |
+
2. Set the configuration parameters in the `AndromedaTokenizer` and `Andromeda` classes according to your requirements:
|
40 |
+
- `tokenizer`: Configure the tokenizer with the desired parameters.
|
41 |
+
- `Andromeda`: Configure the Andromeda model with the desired architecture.
|
42 |
+
3. Save the changes to the script.
|
43 |
+
|
44 |
+
## Step 3: Training the Andromeda Model
|
45 |
+
|
46 |
+
The final step is to train the Andromeda model using the `train_distributed.py` script. Follow the instructions below to start the training process:
|
47 |
+
|
48 |
+
1. Open the `train_distributed.py` script.
|
49 |
+
2. Set the configuration parameters in the `TrainAndromeda.CFG` class according to your requirements:
|
50 |
+
- `BATCH_SIZE`: Set the batch size for training.
|
51 |
+
- `GRADIENT_ACCUMULATE_EVERY`: Set the number of gradient accumulation steps.
|
52 |
+
- `LEARNING_RATE`: Set the learning rate for the optimizer.
|
53 |
+
- `WEIGHT_DECAY`: Set the weight decay for the optimizer.
|
54 |
+
- `SEQ_LEN`: Set the desired sequence length.
|
55 |
+
- `USE_DEEPSPEED`: Set to `True` if using DeepSpeed for optimization.
|
56 |
+
- `USE_FSDP`: Set to `True` if using Fully Sharded Data Parallelism.
|
57 |
+
- `USE_PRETOKENIZED`: Set to `True` if using a pre-tokenized dataset.
|
58 |
+
- `USE_ACTIVATION_CHECKPOINTING`: Set to `True` if using activation checkpointing.
|
59 |
+
- `RESUME_FROM_CHECKPOINT`: Set to the path of a checkpoint to resume training from.
|
60 |
+
- `CHECKPOINTING_STEPS`: Set the number of steps between checkpoints.
|
61 |
+
- `OUTPUT_DIR`: Set the output directory for saving the model checkpoints and logs.
|
62 |
+
- `ENTITY_NAME`: Set the Wandb entity name for logging (optional).
|
63 |
+
3. Save the changes to the script.
|
64 |
+
4. Open a terminal or command prompt and navigate to the directory containing the `train_distributed.py` script.
|
65 |
+
5. Run the following command to start the training:
|
66 |
+
```
|
67 |
+
python train_distributed.py
|
68 |
+
```
|
69 |
+
6. The script will train the Andromeda model using the specified configuration and dataset.
|
70 |
+
7. During training, the progress will be displayed in the terminal, and logs will be saved to the specified output directory.
|
71 |
+
|
72 |
+
# Other Training methods
|
73 |
+
|
74 |
+
First:
|
75 |
+
|
76 |
+
`Accelerate Config`
|
77 |
+
|
78 |
+
Enable Deepspeed 3:
|
79 |
+
|
80 |
+
`Accelerate launch train_distributed_accelerate.py`
|
81 |
+
|
82 |
+
|
Andromeda/DOCs/Docs/Training/DATASET_STRATEGY.md
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Andromeda
|
2 |
+
|
3 |
+
We should train an 100m param, 500m, 1billion parameters verisions with similiar hyperparameters from these 2 similiar models
|
4 |
+
|
5 |
+
[concept of mind's PALM](https://github.com/conceptofmind/PaLM)
|
6 |
+
Model Size Num Tokens Dim Depth Dim Head Heads Flash Attention Learning Rate
|
7 |
+
150 M 50304 768 12 128 8 True 6e-4
|
8 |
+
410 M 50304 1024 24 128 8 True 3e-4
|
9 |
+
1 B 50304 2048 16 128 8 True 3e-4
|
10 |
+
|
11 |
+
|
12 |
+
[MPT HF](https://huggingface.co/mosaicml/mpt-7b)
|
13 |
+
|
14 |
+
Hyperparameter Value
|
15 |
+
n_parameters 6.7B
|
16 |
+
n_layers 32
|
17 |
+
n_heads 32
|
18 |
+
d_model 4096
|
19 |
+
vocab size 50432
|
20 |
+
sequence length 2048
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
## Data prioritization: Prioritize datasets based on their relevance to the desired AI capabilities and the quality of the data.
|
26 |
+
|
27 |
+
High priority: C4, openwebtext, super_glue, piqa, Falcon-40B (RefinedWeb-English, RefinedWeb-Europe, Books, Conversations, Code, Technical), glue, tiiuae/falcon-refinedweb, math_dataset
|
28 |
+
|
29 |
+
Medium priority: bigcode/ta-prompt, bigcode/the-stack-dedup, OpenAssistant/oasst1, ehartford/wizard_vicuna_70k_unfiltered, tiiuae/falcon-refinedweb
|
30 |
+
|
31 |
+
Low priority: timdettmers/openassistant-guanaco, JosephusCheung/GuanacoDataset, JosephusCheung/GuanacoDataset, anon8231489123/ShareGPT_Vicuna_unfiltered, togethercomputer/RedPajama-Data, togethercomputer/RedPajama-Data-1T, Anthropic/hh-rlhf, databricks/databricks-dolly-15k, QingyiSi/Alpaca-CoT, alpaca,
|
32 |
+
distillation, timdettmers/openassistant-guanaco, OpenAssistant/oasst1, dmayhem93/toolformer-v0-postprocessed, openai_humaneval, yahma/alpaca-cleaned,
|
33 |
+
|
34 |
+
## Data preprocessing: Clean, preprocess, and tokenize the datasets to ensure consistency and compatibility with the AI model.
|
35 |
+
|
36 |
+
Remove duplicates, irrelevant content, and low-quality data.
|
37 |
+
|
38 |
+
Tokenize the text using a suitable tokenizer, such as GPT Neox tokenizer or potentially falcon's tokenizer
|
39 |
+
|
40 |
+
Split the datasets into training, validation, and testing sets.
|
41 |
+
|
42 |
+
|
43 |
+
## Training strategy: Train the AI model using the prioritized datasets in a multi-stage process.
|
44 |
+
|
45 |
+
Stage 1: Pretrain the model on high-priority datasets (openwebtext, super_glue, piqa, Falcon-40B, glue) to build a strong language understanding foundation.
|
46 |
+
|
47 |
+
Stage 2: Fine-tune the model on medium-priority datasets (bigcode/ta-prompt, bigcode/the-stack-dedup, OpenAssistant/oasst1, ehartford/wizard_vicuna_70k_unfiltered, tiiuae/falcon-refinedweb) to enhance its performance in specific domains and tasks.
|
48 |
+
|
49 |
+
Stage 3: Further fine-tune the model on low-priority datasets (JosephusCheung/GuanacoDataset, anon8231489123/ShareGPT_Vicuna_unfiltered, togethercomputer/RedPajama-Data, togethercomputer/RedPajama-Data-1T, Anthropic/hh-rlhf, databricks/databricks-dolly-15k, QingyiSi/Alpaca-CoT) to capture any additional knowledge and nuances. PRM800K: A Process Supervision Dataset
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
Evaluation and iteration: Continuously evaluate the model's performance on the validation and testing sets, and iterate the training process to improve its performance.
|
54 |
+
|
55 |
+
Monitor the model's performance using relevant metrics, such as perplexity, F1 score, or BLEU score, depending on the task.
|
56 |
+
Adjust hyperparameters, learning rate, and training duration as needed to optimize the model's performance.
|
57 |
+
If necessary, revisit the data prioritization and preprocessing steps to refine the training data.
|
58 |
+
|
59 |
+
|
60 |
+
# Evaluations and Benchmarks:
|
61 |
+
|
62 |
+
[Chain of thought hub](https://github.com/FranxYao/chain-of-thought-hub)
|
63 |
+
SFT stands for Style Fine-tuning and RLHF stands for Reinforcement Learning and Human Feedback. These are techniques used in natural language processing to improve the quality and accuracy of generated text. The statement suggests that if these techniques are applied correctly to the 65B LLaMA dataset, it is possible to recreate ChatGPT.
|
64 |
+
|
65 |
+
|
66 |
+
# Analysis of Existing Models
|
67 |
+
|
68 |
+
### MPT-7B
|
69 |
+
|
70 |
+
```python
|
71 |
+
Data Source Number of Tokens in Source Proportion Effective Number of Tokens Epochs
|
72 |
+
mC4 3.1.0 - English 417.99 B 0.33 330 B 0.14
|
73 |
+
C4 - English - SemDedup 80% 100.42 B 0.299 299 B 2.98
|
74 |
+
RedPajama - CommonCrawl 878.45 B 0.1 100 B 0.11
|
75 |
+
The Stack - Selected Languages 463.78 B 0.1 100 B 0.22
|
76 |
+
RedPajama - Wikipedia - En 4.87 B 0.04 40 B 8.21
|
77 |
+
The Stack - Markdown 107.07 B 0.035 35 B 0.33
|
78 |
+
S2ORC 48.85 B 0.033 33 B 0.68
|
79 |
+
RedPajama - Books 26.02 B 0.03 30B 1.15
|
80 |
+
RedPajama - arXiv 28.10 B 0.019 19 B 0.68
|
81 |
+
RedPajama - StackExchange 20.54 B 0.014 14 B 0.68
|
82 |
+
```
|
83 |
+
|
84 |
+
# MPT-1B
|
85 |
+
|
86 |
+
```
|
87 |
+
Training Data
|
88 |
+
The model was trained for 200B tokens (batch size 2200, sequence length 2048). It was trained on the following data mix:
|
89 |
+
|
90 |
+
67% RedPajama Common Crawl
|
91 |
+
15% C4
|
92 |
+
4.5% RedPajama GitHub
|
93 |
+
4.5% RedPajama Wikipedia
|
94 |
+
4.5% RedPajama Books
|
95 |
+
2.5% RedPajama Arxiv
|
96 |
+
2% RedPajama StackExchange
|
97 |
+
|
98 |
+
Each sample was chosen from one of the datasets, with the dataset selected with the probability specified above. The examples were shuffled within each dataset. Each example was constructed from as many sequences from that dataset as were necessary to fill the 2048 sequence length.
|
99 |
+
|
100 |
+
```
|