File size: 3,087 Bytes
cb34746
d7ab57a
cb34746
 
 
de7de27
cb34746
 
 
 
 
 
 
 
 
 
 
 
 
de7de27
cb34746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be85521
cb34746
 
 
d7ab57a
 
 
 
 
cb34746
 
de7de27
8787dd3
d7ab57a
 
 
 
cb34746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import tempfile

import pyclamd
import urllib.response, requests

from picklescan.scanner import (
    scan_url,
    scan_file_path,
    ScanResult, SafetyLevel
)


# def scan_file(file_path: str):
#     ret = scan_pickle_bytes(io.BytesIO(pickle.dumps(file_path)), "file.pkl")
#     print(ret)


def scan_file(file_path: str):
    """Run picklescan on a local path or an http(s) URL and summarise the result.

    Args:
        file_path: Local filesystem path, or a URL beginning with ``http://``
            or ``https://`` (fetched by picklescan's own ``scan_url``).

    Returns:
        A dict with:
          - ``url``: ``file_path`` echoed back.
          - ``fileExists``: always ``True`` (existence is not checked here).
          - ``picklescanExitCode``: 0 = no dangerous globals, 1 = at least one
            global rated ``SafetyLevel.Dangerous``, 2 = picklescan reported a
            scan error (same 0/1/2 meaning as the picklescan CLI exit status).
          - ``picklescanGlobalImports``: every global the pickle references,
            formatted with ``fmt_import``.
          - ``picklescanDangerousImports``: the subset rated Dangerous.
    """
    # Match the full scheme so a local file such as "http_model.pt" is not
    # mistaken for a URL.
    if file_path.startswith(("http://", "https://")):
        scan_result: ScanResult = scan_url(file_path)
    else:
        scan_result: ScanResult = scan_file_path(file_path)

    global_imports = [fmt_import(g.module, g.name) for g in scan_result.globals]
    dangerous_imports = [
        fmt_import(g.module, g.name)
        for g in scan_result.globals
        if g.safety == SafetyLevel.Dangerous
    ]

    if getattr(scan_result, "scan_err", False):
        # A pickle picklescan could not parse must not be reported as clean.
        picklescan_exit_code = 2
    elif dangerous_imports:
        picklescan_exit_code = 1
    else:
        picklescan_exit_code = 0

    return {
        'url': file_path,
        'fileExists': True,
        'picklescanExitCode': picklescan_exit_code,
        'picklescanGlobalImports': global_imports,
        'picklescanDangerousImports': dangerous_imports,
        # TODO: not filled in yet —
        # 'clamscanExitCode': ScanExitCode,
        # 'clamscanOutput': string,
        # hashes: Record < ModelHashType, string >;
        # conversions: Record < 'safetensors' | 'ckpt', ConversionResult >;
    }


def init_clamd():
    """Create a pyclamd client that talks to clamd over a Unix socket.

    No socket path is passed, so pyclamd's own default path is used.

    Returns:
        A ``pyclamd.ClamdUnixSocket`` instance.
    """
    return pyclamd.ClamdUnixSocket()


# HTTP headers sent when downloading files for ClamAV scanning.
# NOTE(review): presumably a browser User-Agent is used because some hosts
# reject the default python-requests agent — confirm it is still required.
headers = {
    "user-agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/78.0.3904.108 Safari/537.36"
    ),
}


def _download_to_temp(url: str) -> str:
    """Stream ``url`` into a new temporary file and return that file's path.

    ``tempfile.mkstemp`` gives an unpredictable name, so an existing /tmp
    entry (or a planted symlink) is never overwritten. The file is then made
    world-readable because clamd typically runs as its own user and opens the
    path itself. The caller owns the returned file and must delete it.
    """
    name = url.split("/")[-1].split("?")[0]
    fd, tmp_path = tempfile.mkstemp(prefix="clamd_", suffix=f"_{name}")
    try:
        with os.fdopen(fd, "wb") as out, requests.get(
            url, headers=headers, stream=True, timeout=60
        ) as resp:
            # Never scan an HTTP error page and report it as "No virus found".
            resp.raise_for_status()
            # Stream in 1 MiB chunks so multi-GB model files are not held in memory.
            for chunk in resp.iter_content(chunk_size=1 << 20):
                out.write(chunk)
        os.chmod(tmp_path, 0o644)
    except BaseException:
        os.remove(tmp_path)
        raise
    return tmp_path


def _summarise_clamd(ret, scanned_path: str) -> dict:
    """Turn a pyclamd ``scan_file`` result into the clamscan report fields.

    ``ret`` is ``None`` when nothing was found; otherwise it maps a scanned
    path to a ``(status, detail)`` tuple such as ``('FOUND', 'Eicar-...')``.
    Exit codes: 0 = clean, 1 = virus found, 2 = clamd reported an ERROR.
    """
    if not ret:
        return {
            'clamscanExitCode': 0,
            'clamscanOutput': "No virus found",
        }
    # pyclamd keys the result by the path clamd scanned, not the original URL;
    # fall back to the single reported entry if clamd normalised the path.
    status = ret.get(scanned_path)
    if status is None:
        status = next(iter(ret.values()))
    return {
        'clamscanExitCode': 2 if status[0] == 'ERROR' else 1,
        'clamscanOutput': ' '.join(status),
    }


def clamd_file(file_path: str, clamd=None):
    """Scan a local file or an http(s) URL with ClamAV via clamd.

    Args:
        file_path: Local path, or URL starting with ``http://``/``https://``.
            URLs are downloaded to a temporary file that is deleted after the scan.
        clamd: A pyclamd client; when omitted, one is created with ``init_clamd()``.

    Returns:
        ``{'clamscanExitCode': int, 'clamscanOutput': str}``.

    Raises:
        requests.HTTPError / requests.RequestException if the download fails.
    """
    if clamd is None:
        clamd = init_clamd()

    if file_path.startswith(("http://", "https://")):
        tmp_path = _download_to_temp(file_path)
        print("tmp_path ", tmp_path)
        try:
            return _summarise_clamd(clamd.scan_file(tmp_path), tmp_path)
        finally:
            os.remove(tmp_path)

    # pyclamd's scan_file requires an absolute path.
    scanned_path = os.path.abspath(file_path)
    return _summarise_clamd(clamd.scan_file(scanned_path), scanned_path)


def fmt_import(module: str, name: str) -> str:
    """Render a pickle global reference as a Python import statement.

    >>> fmt_import("collections", "OrderedDict")
    'from collections import OrderedDict'
    """
    # Plain f-string placeholders; no trailing comma, so a str (not a 1-tuple)
    # is returned.
    return f"from {module} import {name}"


if __name__ == "__main__":
    # Smoke test against a public model; the sample picklescan result below
    # lists only innocuous globals for it.
    url = "https://huggingface.co/yesyeahvh/bad-hands-5/resolve/main/bad-hands-5.pt"
    detail = scan_file(url)
    # clamd_file needs a clamd client, so pass one explicitly.
    clamd_detail = clamd_file(url, init_clamd())
    print(detail)
    print(clamd_detail)
    # ScanResult(
    #     globals=[Global(module='torch', name='FloatStorage', safety= < SafetyLevel.Innocuous: 'innocuous' >),
    #              Global(module='collections', name='OrderedDict', safety= < SafetyLevel.Innocuous: 'innocuous' >),
    #              Global(module='torch._utils', name='_rebuild_tensor_v2',safety= < SafetyLevel.Innocuous: 'innocuous' >)],
    #     scanned_files = 1, issues_count = 0, infected_files = 0, scan_err = False)