---
tags:
- EasyDeL
- FlaxLlamaForCausalLM
- safetensors
- TPU
- GPU
- XLA
- Flax
---

# EasyDeL/EasyDeL-Llama-3.2-3B-Instruct

[![EasyDeL](https://img.shields.io/badge/🤗_EasyDeL-0.0.80-blue.svg)](https://github.com/erfanzar/EasyDeL)
[![Model Type](https://img.shields.io/badge/Model_Type-FlaxLlamaForCausalLM-green.svg)](https://github.com/erfanzar/EasyDeL)

A model implemented with the EasyDeL framework and designed for high-performance, large-scale natural language processing.

## Overview

EasyDeL provides an efficient, highly optimized, and customizable implementation of this model that runs on both GPU and TPU. Built with JAX, it supports advanced features such as sharded model parallelism and custom kernels, making it suitable for distributed training and inference.

## Features

- **Efficient Implementation**: Built with JAX/Flax for high-performance computation.
- **Multi-Device Support**: Runs on TPU, GPU, and CPU, and can shard the model across device counts from 2 up to 1,000 or more.
- **Sharded Model Parallelism**: Splits the model across multiple devices for scalability.
- **Customizable Precision**: Lets you choose the floating-point precision to trade numerical accuracy for speed and memory.

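
To make the precision and sharding features concrete, here is a rough back-of-the-envelope sketch of parameter memory per device. It is illustrative arithmetic only: the helper `param_memory_gib` is hypothetical, the parameter count takes "3B" at face value, and the even split across devices is an assumption rather than a description of how EasyDeL actually places parameters.

```python
# Illustrative arithmetic only; not part of the EasyDeL API.
BYTES_PER_DTYPE = {"float32": 4, "bfloat16": 2, "float16": 2}


def param_memory_gib(num_params: int, dtype: str, num_devices: int = 1) -> float:
    """Approximate per-device parameter memory in GiB, assuming an even split."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    total_bytes = num_params * BYTES_PER_DTYPE[dtype]
    return total_bytes / num_devices / 2**30


NOMINAL_PARAMS = 3_000_000_000  # assumption: "3B" taken at face value

for dtype in ("float32", "float16"):
    for devices in (1, 8):
        gib = param_memory_gib(NOMINAL_PARAMS, dtype, devices)
        print(f"{dtype:>8} on {devices} device(s): ~{gib:.2f} GiB per device")
```

The estimate covers parameters only; activations and any caches add to it.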

## Installation

To install EasyDeL, simply run:

```bash
pip install easydel
```
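
After installing, it can help to confirm that JAX is importable and which accelerator it sees. The following is a minimal sketch using standard JAX calls; the helper name `describe_jax_environment` is hypothetical and not part of EasyDeL.

```python
def describe_jax_environment() -> dict:
    """Report whether JAX is importable and, if so, which devices it sees."""
    try:
        import jax  # imported lazily so a missing install is reported, not raised
    except ImportError:
        return {"jax_installed": False, "backend": None, "device_count": 0}
    return {
        "jax_installed": True,
        "backend": jax.default_backend(),  # platform name, e.g. "cpu", "gpu", or "tpu"
        "device_count": jax.device_count(),  # total devices visible to JAX
    }


print(describe_jax_environment())
```

If the reported backend is `cpu` on a machine that has a GPU or TPU, the installed JAX build probably lacks accelerator support.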

## Usage

### Loading the Pre-trained Model

To load a pre-trained version of the model with EasyDeL:

```python
import easydel as ed  # needed for ed.AttentionMechanisms below
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Optional cap on cached positions; set an integer to use less memory for caching.
max_length = None

# Load the model definition and its parameters.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "EasyDeL/EasyDeL-Llama-3.2-3B-Instruct",
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_dtype=jnp.float16,
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2,
    ),
    dtype=jnp.float16,        # computation dtype
    param_dtype=jnp.float16,  # dtype the parameters are stored in
    precision=lax.Precision("fastest"),
    auto_shard_params=True,   # shard parameters across available devices
)
```
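
The `max_length` comment above says a smaller value reduces memory used for caching. One way to see why a cap can matter is the rough sketch below, which sizes a dense square mask buffer. It rests on assumptions: that the mask is dense with one byte per entry, and that 131,072 is a plausible long-context value (an illustrative figure, not one taken from this card). How EasyDeL actually stores its masks and frequency tables may differ, and `dense_mask_bytes` is a hypothetical helper.

```python
def dense_mask_bytes(max_positions: int, bytes_per_entry: int = 1) -> int:
    """Size in bytes of a dense (max_positions x max_positions) mask buffer."""
    return max_positions * max_positions * bytes_per_entry


# Assumed context lengths, chosen only to show how the size scales.
for length in (131_072, 8_192, 2_048):
    mib = dense_mask_bytes(length) / 2**20
    print(f"max_length={length:>7}: ~{mib:,.0f} MiB for a dense boolean mask")
```

Because the size grows quadratically, halving `max_length` cuts such a buffer to a quarter under these assumptions.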

## Supported Tasks

[Need more information]

## Limitations

- **Hardware Dependency**: Performance can vary significantly with the hardware used.
- **JAX/Flax Setup Required**: The environment must have a working JAX/Flax installation.
- **Experimental Features**: Some features, such as custom kernels or ed-ops, may require additional configuration and tuning.