Commit d8b31a6 by EgorShibaev
1 Parent(s): e71def3

scripts

Files changed:
- prep_scripts/chunking.py +58 -0
- prep_scripts/lancedb_setup.py +96 -0
- prep_scripts/markdown_to_text.py +62 -0
- prep_scripts/requirements.txt +9 -0
prep_scripts/chunking.py
ADDED
@@ -0,0 +1,58 @@
from langchain.text_splitter import CharacterTextSplitter, NLTKTextSplitter
import argparse
from pathlib import Path
import os
from tqdm import tqdm

def fixed_size_chunking(text, chunk_size=256) -> list[str]:
    splitter = CharacterTextSplitter(
        separator=" ",
        chunk_size=chunk_size,
        chunk_overlap=20
    )
    return splitter.split_text(text)

def content_aware_chunking(text, chunk_size=256) -> list[str]:
    splitter = NLTKTextSplitter(
        separator=".",
        chunk_size=chunk_size,
        chunk_overlap=20
    )
    return splitter.split_text(text)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", help="input directory with text files", type=str,
                        default="docs")
    parser.add_argument("--output-dir", help="output directory to store chunked texts", type=str,
                        default="chunked_docs")
    parser.add_argument("--chunk-size", help="chunk size", type=int, default=256)
    parser.add_argument("--chunking-type", help="fixed_size or content_aware", type=str, default="fixed_size")

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    assert os.path.isdir(input_dir), "Input directory doesn't exist"

    os.makedirs(output_dir, exist_ok=True)

    for file in tqdm(input_dir.rglob("*")):
        if file.is_file():
            with open(file, 'r', encoding='utf8') as f:
                text = f.read()

            if args.chunking_type == "fixed_size":
                chunked_text = fixed_size_chunking(text, args.chunk_size)
            elif args.chunking_type == "content_aware":
                chunked_text = content_aware_chunking(text, args.chunk_size)
            else:
                raise ValueError("Invalid chunking type. Choose from 'fixed_size' or 'content_aware'")

            for i, chunk in enumerate(chunked_text):
                with open(output_dir / f"{file.stem}_chunk_{i}.txt", "w", encoding='utf8') as f:
                    f.write(chunk)

if __name__ == "__main__":
    main()
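For a quick sanity check, here is a minimal sketch of how the two chunkers differ, assuming the script's directory is on the import path and nltk's punkt data is downloaded (the sample text is made up):

import nltk
nltk.download("punkt")  # NLTKTextSplitter needs the sentence tokenizer

from chunking import fixed_size_chunking, content_aware_chunking

sample = "Transformers provides thousands of pretrained models. " * 30

# fixed-size: packs space-separated pieces into ~256-character windows with 20 chars of overlap
print(len(fixed_size_chunking(sample, chunk_size=256)))

# content-aware: splits on sentence boundaries (NLTK) before packing chunks
print(len(content_aware_chunking(sample, chunk_size=256)))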
prep_scripts/lancedb_setup.py
ADDED
@@ -0,0 +1,96 @@
import argparse
import lancedb
import torch
import pyarrow as pa
import pandas as pd
from pathlib import Path
import tqdm
import numpy as np
import logging

from transformers import AutoConfig
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--emb-model", help="embedding model name on HF hub", type=str)
    parser.add_argument("--table", help="table name in DB", type=str)
    parser.add_argument("--input-dir", help="input directory with documents to ingest", type=str)
    parser.add_argument("--vec-column", help="vector column name in the table", type=str, default="vector")
    parser.add_argument("--text-column", help="text column name in the table", type=str, default="text")
    parser.add_argument("--db-loc", help="database location", type=str,
                        default=str(Path().resolve() / "gradio_app" / ".lancedb"))
    parser.add_argument("--batch-size", help="batch size for embedding model", type=int, default=32)
    parser.add_argument("--num-partitions", help="number of partitions for index", type=int, default=256)
    parser.add_argument("--num-sub-vectors", help="number of sub-vectors for index", type=int, default=96)

    args = parser.parse_args()

    emb_config = AutoConfig.from_pretrained(args.emb_model)
    emb_dimension = emb_config.hidden_size

    assert emb_dimension % args.num_sub_vectors == 0, \
        "Embedding size must be divisible by the number of sub-vectors"

    model = SentenceTransformer(args.emb_model)
    model.eval()

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"using {device} device")

    db = lancedb.connect(args.db_loc)

    schema = pa.schema(
        [
            pa.field(args.vec_column, pa.list_(pa.float32(), emb_dimension)),
            pa.field(args.text_column, pa.string())
        ]
    )
    tbl = db.create_table(args.table, schema=schema, mode="overwrite")

    input_dir = Path(args.input_dir)
    files = list(input_dir.rglob("*"))

    sentences = []
    for file in files:
        if file.is_file():  # skip directories returned by rglob
            with open(file, 'r', encoding='utf8') as f:
                sentences.append(f.read())

    for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / args.batch_size)))):
        try:
            batch = [sent for sent in sentences[i * args.batch_size:(i + 1) * args.batch_size] if len(sent) > 0]
            encoded = model.encode(batch, normalize_embeddings=True, device=device)
            encoded = [list(vec) for vec in encoded]

            df = pd.DataFrame({
                args.vec_column: encoded,
                args.text_column: batch
            })

            tbl.add(df)
        except Exception as e:
            logger.warning(f"batch {i} was skipped: {e}")

    # create an IVF_PQ index: https://lancedb.github.io/lancedb/ann_indexes/
    # with the size of the transformers docs an index is not really needed,
    # but we build one for demonstration purposes
    tbl.create_index(
        num_partitions=args.num_partitions,
        num_sub_vectors=args.num_sub_vectors,
        vector_column_name=args.vec_column
    )


if __name__ == "__main__":
    main()
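After ingestion, the table can be queried with vectors from the same embedding model. A minimal sketch, where the table name is hypothetical (whatever was passed via --table), the model is just an example, and search/limit/to_pandas follow the LanceDB Python query-builder API:

import lancedb
from sentence_transformers import SentenceTransformer

db = lancedb.connect("gradio_app/.lancedb")
tbl = db.open_table("docs")  # hypothetical: use the name passed via --table

# use the same model as at ingestion time (example model shown here)
model = SentenceTransformer("BAAI/bge-base-en-v1.5")
query = model.encode("How do I save a fine-tuned model?", normalize_embeddings=True)

# top-5 nearest neighbours over the default "vector" column
hits = tbl.search(query).limit(5).to_pandas()
print(hits["text"].tolist())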
prep_scripts/markdown_to_text.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import os
import re

from tqdm import tqdm
from bs4 import BeautifulSoup
from markdown import markdown
from pathlib import Path


def markdown_to_text(markdown_string):
    """ Converts a markdown string to plaintext """

    # md -> html -> text since BeautifulSoup can extract text cleanly
    html = markdown(markdown_string)

    # strip HTML comments (non-greedy, so content between separate comments survives)
    html = re.sub(r'<!--(.|\n)*?-->', '', html)
    html = re.sub('<code>bash', '<code>', html)

    # extract text
    soup = BeautifulSoup(html, "html.parser")
    text = ''.join(soup.find_all(string=True))

    # drop code-fence markers and leftover "- " bullet lines
    text = re.sub('```(py|diff|python)', '', text)
    text = re.sub('```\n', '\n', text)
    text = re.sub('- .*', '', text)
    text = text.replace('...', '')
    text = re.sub('\n(\n)+', '\n\n', text)

    return text


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", help="input directory with markdown", type=str,
                        default="transformers/docs/source/en/")
    parser.add_argument("--output-dir", help="output directory to store raw texts", type=str,
                        default="docs")

    args = parser.parse_args()
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    assert os.path.isdir(input_dir), "Input directory doesn't exist"

    files = input_dir.rglob("*")
    os.makedirs(output_dir, exist_ok=True)

    for file in tqdm(files):
        parent = file.parent.stem if file.parent.stem != input_dir.stem else ""
        if file.is_file():
            with open(file, 'r', encoding='utf8') as f:
                md = f.read()

            text = markdown_to_text(md)

            with open(output_dir / f"{parent}_{file.stem}.txt", "w", encoding='utf8') as f:
                f.write(text)


if __name__ == "__main__":
    main()
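A quick illustration of the conversion (the sample markdown is made up; run from inside prep_scripts/ so the module imports):

from markdown_to_text import markdown_to_text

sample = "# Quick tour\n<!-- internal note -->\nLoad a model with `AutoModel`.\n"
print(markdown_to_text(sample))
# prints roughly:
# Quick tour
#
# Load a model with AutoModel.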
prep_scripts/requirements.txt
ADDED
@@ -0,0 +1,9 @@
bs4==0.0.1
lancedb==0.5.3
markdown==3.5.1
numpy==1.26.2
pandas==2.1.3
pyarrow==14.0.1
sentence-transformers==2.3.1
tqdm==4.66.1
torch==2.1.1
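Note: prep_scripts/chunking.py additionally imports langchain (and its NLTKTextSplitter needs nltk), neither of which is pinned here, so install them alongside `pip install -r prep_scripts/requirements.txt`; transformers comes in as a dependency of sentence-transformers.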