LTT commited on
Commit
e654a4e
·
verified ·
1 Parent(s): 1817aae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py CHANGED
@@ -47,8 +47,35 @@ def install_cuda_toolkit():
47
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
48
  print("==> finfish install")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  install_cuda_toolkit()
51
 
 
 
 
 
 
 
52
 
53
  print(f"GPU: {torch.cuda.is_available()}")
54
  a = torch.tensor([0]).cuda()
 
47
  os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
48
  print("==> finfish install")
49
 
50
+ import shutil
51
+
52
+ def find_cuda():
53
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
54
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
55
+
56
+ if cuda_home and os.path.exists(cuda_home):
57
+ return cuda_home
58
+
59
+ # Search for the nvcc executable in the system's PATH
60
+ nvcc_path = shutil.which('nvcc')
61
+
62
+ if nvcc_path:
63
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
64
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
65
+ return cuda_path
66
+
67
+ return None
68
+
69
+
70
+
71
  install_cuda_toolkit()
72
 
73
+ cuda_path = find_cuda()
74
+
75
+ if cuda_path:
76
+ print(f"CUDA installation found at: {cuda_path}")
77
+ else:
78
+ print("CUDA installation not found")
79
 
80
  print(f"GPU: {torch.cuda.is_available()}")
81
  a = torch.tensor([0]).cuda()