EasyDeL/EasyDeL-Llama-3.1-8B-Instruct
A model implemented using the EasyDeL framework, designed to deliver optimal performance for large-scale natural language processing tasks.
Overview
EasyDeL provides an efficient, customizable implementation of this model that runs on both GPU and TPU. Built with JAX, it supports sharded model parallelism and custom kernels, which makes it suitable for distributed training and inference.
Features
- Efficient Implementation: Built with JAX/Flax for high-performance computation.
- Multi-Device Support: Runs on TPU, GPU, and CPU, with the model sharded across a power-of-two number of devices, from 2 to 1000+.
- Sharded Model Parallelism: Supports model parallelism across multiple devices for scalability.
- Customizable Precision: Allows specification of floating-point precision for performance optimization.
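To make the precision and sharding points concrete, here is a back-of-the-envelope sketch of how much memory the weights alone occupy at different dtypes. The parameter count is an assumption (a round 8 billion, read off the "8B" in the model name), and the per-device figure assumes the weights are split perfectly evenly, which real sharding layouts may not achieve.

```python
# Bytes per element for common floating-point dtypes.
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2}

# Assumption: round figure taken from the "8B" in the model name.
NUM_PARAMS = 8_000_000_000


def param_memory_gib(num_params: int, dtype: str) -> float:
    """Memory needed to hold the weights alone, in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


def per_device_gib(num_params: int, dtype: str, num_devices: int) -> float:
    """Weights per device, assuming a perfectly even split (idealized)."""
    return param_memory_gib(num_params, dtype) / num_devices


for dtype in BYTES_PER_ELEMENT:
    total = param_memory_gib(NUM_PARAMS, dtype)
    split = per_device_gib(NUM_PARAMS, dtype, num_devices=8)
    print(f"{dtype}: {total:.1f} GiB total, {split:.1f} GiB per device on 8 devices")
```

Activations, optimizer state, and caches come on top of these numbers, so treat them as a lower bound.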
Installation
To install EasyDeL, simply run:
pip install easydel
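After installing, a quick way to confirm that the packages are visible to the current interpreter is to look them up without importing them. This is a minimal sketch using only the standard library; the helper name `is_installed` is hypothetical, and the list of packages checked (`easydel`, `jax`, `flax`) is an assumption based on the dependencies this card mentions.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if a top-level package can be found by the import system."""
    return importlib.util.find_spec(package) is not None


# Assumption: these are the packages this model card relies on.
for name in ("easydel", "jax", "flax"):
    status = "found" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```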
Usage
Loading the Pre-trained Model
To load a pre-trained version of the model with EasyDeL:
import easydel as ed  # needed for ed.AttentionMechanisms below
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

max_length = None  # can be set to use lower memory for caching

# Load model and parameters
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "EasyDeL/EasyDeL-Llama-3.1-8B-Instruct",
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_dtype=jnp.float16,
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2,
    ),
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
)
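The comment on `max_length` notes that a smaller value lowers cache memory. As a rough, framework-independent illustration, the sketch below estimates the size of a key/value cache as a function of sequence length. The architecture numbers (32 layers, 8 key/value heads, head dimension 128) are assumptions taken from the commonly published Llama 3.1 8B configuration, not from this card, and the formula does not claim to describe how EasyDeL actually allocates its buffers.

```python
def kv_cache_bytes(
    seq_len: int,
    num_layers: int = 32,     # assumption: Llama 3.1 8B layer count
    num_kv_heads: int = 8,    # assumption: grouped-query attention KV heads
    head_dim: int = 128,      # assumption: per-head dimension
    bytes_per_elem: int = 2,  # float16, matching dtype in the loading code
    batch_size: int = 1,
) -> int:
    """Estimate KV cache size: one key and one value tensor per layer."""
    return (
        2 * num_layers * num_kv_heads * head_dim
        * seq_len * bytes_per_elem * batch_size
    )


for seq_len in (4096, 32768, 131072):
    gib = kv_cache_bytes(seq_len) / 2**30
    print(f"max_length={seq_len}: about {gib:.1f} GiB of KV cache")
```

Under these assumptions the cache grows linearly with sequence length, so capping `max_length` at the longest prompt plus generation you actually need is the simplest way to reclaim memory.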
Supported Tasks
[Need more information]
Limitations
- Hardware Dependency: Performance can vary significantly based on the hardware used.
- JAX/Flax Setup Required: The environment must support JAX/Flax for optimal use.
- Experimental Features: Some features (like custom kernel usage or ed-ops) may require additional configuration and tuning.