|
import pytest |
|
|
|
from mergekit.common import ModelPath, ModelReference |
|
|
|
|
|
class TestModelReference:
    """Tests for ``ModelReference.parse`` and the ``str()`` round-trip of its result."""

    @staticmethod
    def _check_parse(text, model, lora=None):
        """Parse *text* and check each component plus the string round-trip.

        ``model`` and ``lora`` are ``(path, revision)`` pairs used to build the
        expected ``ModelPath`` values; ``lora=None`` means the parsed reference
        must have no LoRA component. Checks run in a fixed order: base model,
        then LoRA, then ``str(parsed) == text``.
        """
        parsed = ModelReference.parse(text)

        model_path, model_revision = model
        assert parsed.model == ModelPath(path=model_path, revision=model_revision)

        if lora is None:
            assert parsed.lora is None
        else:
            lora_path, lora_revision = lora
            assert parsed.lora == ModelPath(path=lora_path, revision=lora_revision)

        # Formatting the parsed reference must reproduce the input exactly.
        assert str(parsed) == text

    def test_parse_simple(self):
        self._check_parse("hf_user/model", model=("hf_user/model", None))

    def test_parse_lora(self):
        self._check_parse(
            "hf_user/model+hf_user/lora",
            model=("hf_user/model", None),
            lora=("hf_user/lora", None),
        )

    def test_parse_revision(self):
        self._check_parse(
            "hf_user/model@v0.0.1",
            model=("hf_user/model", "v0.0.1"),
        )

    def test_parse_lora_plus_revision(self):
        self._check_parse(
            "hf_user/model@v0.0.1+hf_user/lora@main",
            model=("hf_user/model", "v0.0.1"),
            lora=("hf_user/lora", "main"),
        )

    def test_parse_bad(self):
        # Every malformed reference must independently raise RuntimeError;
        # they are checked in this order.
        malformed_inputs = ("@@@@@", "a+b+c", "a+b+c@d+e@f@g")
        for bad_text in malformed_inputs:
            with pytest.raises(RuntimeError):
                ModelReference.parse(bad_text)
|
|