|
import os |
|
import re |
|
import sys |
|
import time |
|
import subprocess |
|
|
|
from distutils import sysconfig |
|
from setuptools import setup, Extension, find_packages |
|
from Cython.Build import build_ext, cythonize |
|
|
|
|
|
class CMakeExtension(Extension):
    """Placeholder ``Extension`` whose native code comes from a CMake project.

    It declares no sources of its own (``sources=[]``). ``CMakeBuild`` instead
    passes ``sourcedir`` to ``cmake`` as the project's source directory.
    """

    def __init__(self, name, sourcedir=''):
        """Create the extension descriptor.

        Args:
            name: Module name of the built extension.
            sourcedir: Directory containing the CMake project. It is made
                absolute against the current working directory at construction
                time, so the default ``''`` means the current directory.
        """
        super().__init__(name, sources=[])
        absolute_dir = os.path.abspath(sourcedir)
        self.sourcedir = absolute_dir
|
|
|
|
|
class CMakeBuild(build_ext):
    """``build_ext`` command that builds ``CMakeExtension`` targets with CMake.

    Any other extension is handed to the parent ``build_ext`` unchanged.
    """

    # Job count used for ``cmake --build --parallel`` when ``build_ext`` was not
    # given ``-j/--parallel``. This is the same value as the old hard-coded -j8.
    default_parallel_jobs = 8

    def _parallel_jobs(self):
        """Return the number of parallel build jobs to request from CMake."""
        # distutils' build_ext stores its -j/--parallel option here. It is None
        # when unset, an int when a count was given, or True for "all CPUs".
        jobs = getattr(self, 'parallel', None)
        if jobs is True:
            return os.cpu_count() or self.default_parallel_jobs
        if not jobs:
            return self.default_parallel_jobs
        return int(jobs)

    def build_extension(self, ext):
        """Build ``ext``, configuring and building it with CMake if it is a ``CMakeExtension``.

        Raises:
            subprocess.CalledProcessError: if the CMake configure or build step fails.
        """
        if not isinstance(ext, CMakeExtension):
            super().build_extension(ext)
            return

        # This is the directory where setuptools expects the built module.
        # CMake is told to write its shared libraries there.
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(
            [
                'cmake', ext.sourcedir,
                '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                # Makefile generators print each command, just as with make
                # VERBOSE=1. Unlike VERBOSE=1, this is not forwarded to the
                # native build tool, so other generators are unaffected.
                '-DCMAKE_VERBOSE_MAKEFILE=ON',
            ],
            cwd=self.build_temp,
        )
        # CMake (>= 3.12) translates --parallel for whichever native tool the
        # generator uses. The old make-only "-- VERBOSE=1 -j8" broke the Ninja
        # and Visual Studio generators.
        subprocess.check_call(
            ['cmake', '--build', '.', '--parallel', str(self._parallel_jobs())],
            cwd=self.build_temp,
        )
|
|
|
|
|
# The package ships one native extension, which CMake builds (see CMakeExtension).
ext_modules = [
    CMakeExtension('greedrl_c'),
]

setup(
    name='greedrl',
    version='1.0.0',
    # Route build_ext through CMakeBuild so CMakeExtension targets are built by CMake.
    cmdclass={'build_ext': CMakeBuild},
    ext_modules=ext_modules,
    packages=find_packages(),
    # Exact pin including the local version label "+cu113" (presumably the
    # CUDA 11.3 torch build). NOTE(review): confirm the install instructions
    # point pip at an index that serves this local version.
    install_requires=["torch==1.12.1+cu113"],
)
|
|