jax[cuda12]
flax==0.9.0
flash_attn_jax
torch
torchvision
opencv-python-headless
einops
huggingface_hub
transformers
tokenizers
sentencepiece
fire
invisible-watermark
ml-dtypes