cassanof's picture
Update README.md
bcbf898 verified
|
raw
history blame
2.03 kB
metadata
datasets:
  - nuprl/EditPackFT-Multi
tags:
  - code

What is this?

This is a DeepSeek Coder 7B model trained to generate a commit message from a code diff.

Languages the model was trained on:

# Programming languages included in the training data, in the model card's order.
LANGS = [
    "Python", "Rust", "JavaScript", "Java", "Go", "C++", "C#", "Ruby",
    "PHP", "TypeScript", "C", "Scala", "Swift", "Kotlin", "Objective-C",
    "Perl", "Haskell", "Bash", "Sh", "Lua", "R", "Julia",
]

How to prompt (the snippet below assumes `model` and `tokenizer` have already been loaded):

import difflib
class NDiff:
    """Line-level diff between two strings, in ``difflib.ndiff`` format.

    Both inputs are split on newline characters and compared line by line.
    The intraline hint lines (those starting with ``"?"``) that ``ndiff``
    emits are dropped from every rendering and count.

    Attributes:
        s1: The original ("old") text.
        s2: The modified ("new") text.
        diff: The ``ndiff`` output lines, each starting with ``"  "``,
            ``"- "``, ``"+ "`` or ``"? "``.
    """

    def __init__(self, s1, s2):
        self.s1 = s1
        self.s2 = s2
        # difflib.ndiff returns a one-shot generator. Materialize it so the
        # diff can be rendered and counted any number of times; otherwise the
        # first consumer exhausts it and later calls see an empty diff.
        self.diff = list(difflib.ndiff(s1.split("\n"), s2.split("\n")))

    def __str__(self):
        """Return the diff lines joined by newlines, without ``?`` hints."""
        return "\n".join(line for line in self.diff if not line.startswith("?"))

    def str_colored(self):
        """Return the diff with removed lines in red and added lines in green.

        Requires the third-party ``colored`` package (imported lazily so the
        rest of the class works without it). Every emitted line, including
        the last, is followed by a newline.
        """
        import colored

        parts = []
        for line in self.diff:
            if line.startswith("?"):
                continue
            if line.startswith("-"):
                line = colored.stylize(line, colored.fg("red"))
            elif line.startswith("+"):
                line = colored.stylize(line, colored.fg("green"))
            parts.append(line + "\n")
        return "".join(parts)

    def num_removed(self):
        """Return the number of lines present only in ``s1``."""
        return sum(1 for line in self.diff if line.startswith("-"))

    def num_added(self):
        """Return the number of lines present only in ``s2``."""
        return sum(1 for line in self.diff if line.startswith("+"))

    def __repr__(self):
        return self.__str__()

def format_prompt(old, new):
    """Build the model prompt describing the change from ``old`` to ``new``.

    Layout: a ``<diff>`` line, the ndiff-style rendering of the change, a
    ``<commit_message>`` line, and a trailing newline. ``gen`` feeds this
    prompt to the model and keeps only the tokens generated after it.
    """
    rendered_diff = str(NDiff(old, new))
    return "\n".join(["<diff>", rendered_diff, "<commit_message>", ""])

def gen(old, new, max_new_tokens=200, temperature=0.45, top_p=0.90):
    """Sample commit-message completions for the change from ``old`` to ``new``.

    Uses module-level ``model`` and ``tokenizer`` objects, which this snippet
    does not define; load them before calling.

    Returns:
        A list with one decoded string per generated sequence. Each string
        contains only the newly generated tokens, with the prompt stripped and
        special tokens skipped.
    """
    input_ids = tokenizer.encode(
        format_prompt(old, new), return_tensors="pt"
    ).to(model.device)
    sequences = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # The generated sequences start with the prompt tokens; slice them off.
    prompt_len = len(input_ids[0])
    return [
        tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
        for seq in sequences
    ]

Call the `gen` function with the old and new versions of the code; it returns a list of generated commit messages.