CorvaeOboro commited on
Commit
8245392
Β·
1 Parent(s): 0c92011

Upload custom_ops.py

Browse files
Files changed (1) hide show
  1. dnnlib/tflib/custom_ops.py +181 -0
dnnlib/tflib/custom_ops.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """TensorFlow custom ops builder.
10
+ """
11
+
12
+ import glob
13
+ import os
14
+ import re
15
+ import uuid
16
+ import hashlib
17
+ import tempfile
18
+ import shutil
19
+ import tensorflow as tf
20
+ from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
21
+
22
+ from .. import util
23
+
24
+ #----------------------------------------------------------------------------
25
+ # Global options.
26
+
27
+ cuda_cache_path = None
28
+ cuda_cache_version_tag = 'v1'
29
+ do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change.
30
+ verbose = True # Print status messages to stdout.
31
+
32
+ #----------------------------------------------------------------------------
33
+ # Internal helper funcs.
34
+
35
+ def _find_compiler_bindir():
36
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
37
+ if hostx64_paths != []:
38
+ return hostx64_paths[0]
39
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
40
+ if hostx64_paths != []:
41
+ return hostx64_paths[0]
42
+ hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
43
+ if hostx64_paths != []:
44
+ return hostx64_paths[0]
45
+ vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
46
+ if os.path.isdir(vc_bin_dir):
47
+ return vc_bin_dir
48
+ return None
49
+
50
+ def _get_compute_cap(device):
51
+ caps_str = device.physical_device_desc
52
+ m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
53
+ major = m.group(1)
54
+ minor = m.group(2)
55
+ return (major, minor)
56
+
57
+ def _get_cuda_gpu_arch_string():
58
+ gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
59
+ if len(gpus) == 0:
60
+ raise RuntimeError('No GPU devices found')
61
+ (major, minor) = _get_compute_cap(gpus[0])
62
+ return 'sm_%s%s' % (major, minor)
63
+
64
+ def _run_cmd(cmd):
65
+ with os.popen(cmd) as pipe:
66
+ output = pipe.read()
67
+ status = pipe.close()
68
+ if status is not None:
69
+ raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
70
+
71
+ def _prepare_nvcc_cli(opts):
72
+ cmd = 'nvcc ' + opts.strip()
73
+ cmd += ' --disable-warnings'
74
+ cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
75
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
76
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
77
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
78
+
79
+ compiler_bindir = _find_compiler_bindir()
80
+ if compiler_bindir is None:
81
+ # Require that _find_compiler_bindir succeeds on Windows. Allow
82
+ # nvcc to use whatever is the default on Linux.
83
+ if os.name == 'nt':
84
+ raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
85
+ else:
86
+ cmd += ' --compiler-bindir "%s"' % compiler_bindir
87
+ cmd += ' 2>&1'
88
+ return cmd
89
+
90
+ #----------------------------------------------------------------------------
91
+ # Main entry point.
92
+
93
+ _plugin_cache = dict()
94
+
95
+ def get_plugin(cuda_file, extra_nvcc_options=[]):
96
+ cuda_file_base = os.path.basename(cuda_file)
97
+ cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
98
+
99
+ # Already in cache?
100
+ if cuda_file in _plugin_cache:
101
+ return _plugin_cache[cuda_file]
102
+
103
+ # Setup plugin.
104
+ if verbose:
105
+ print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
106
+ try:
107
+ # Hash CUDA source.
108
+ md5 = hashlib.md5()
109
+ with open(cuda_file, 'rb') as f:
110
+ md5.update(f.read())
111
+ md5.update(b'\n')
112
+
113
+ # Hash headers included by the CUDA code by running it through the preprocessor.
114
+ if not do_not_hash_included_headers:
115
+ if verbose:
116
+ print('Preprocessing... ', end='', flush=True)
117
+ with tempfile.TemporaryDirectory() as tmp_dir:
118
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
119
+ _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
120
+ with open(tmp_file, 'rb') as f:
121
+ bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
122
+ good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
123
+ for ln in f:
124
+ if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
125
+ ln = ln.replace(bad_file_str, good_file_str)
126
+ md5.update(ln)
127
+ md5.update(b'\n')
128
+
129
+ # Select compiler options.
130
+ compile_opts = ''
131
+ if os.name == 'nt':
132
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
133
+ elif os.name == 'posix':
134
+ compile_opts += f' --compiler-options \'-fPIC\''
135
+ compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
136
+ compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
137
+ else:
138
+ assert False # not Windows or Linux, w00t?
139
+ compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
140
+ compile_opts += ' --use_fast_math'
141
+ for opt in extra_nvcc_options:
142
+ compile_opts += ' ' + opt
143
+ nvcc_cmd = _prepare_nvcc_cli(compile_opts)
144
+
145
+ # Hash build configuration.
146
+ md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
147
+ md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
148
+ md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
149
+
150
+ # Compile if not already compiled.
151
+ cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
152
+ bin_file_ext = '.dll' if os.name == 'nt' else '.so'
153
+ bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
154
+ if not os.path.isfile(bin_file):
155
+ if verbose:
156
+ print('Compiling... ', end='', flush=True)
157
+ with tempfile.TemporaryDirectory() as tmp_dir:
158
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
159
+ _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
160
+ os.makedirs(cache_dir, exist_ok=True)
161
+ intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
162
+ shutil.copyfile(tmp_file, intermediate_file)
163
+ os.rename(intermediate_file, bin_file) # atomic
164
+
165
+ # Load.
166
+ if verbose:
167
+ print('Loading... ', end='', flush=True)
168
+ plugin = tf.load_op_library(bin_file)
169
+
170
+ # Add to cache.
171
+ _plugin_cache[cuda_file] = plugin
172
+ if verbose:
173
+ print('Done.', flush=True)
174
+ return plugin
175
+
176
+ except:
177
+ if verbose:
178
+ print('Failed!', flush=True)
179
+ raise
180
+
181
+ #----------------------------------------------------------------------------