
Commit

Add pruning configuration and update requirements; enhance logging and example files

mkurman committed Nov 14, 2024
1 parent e99b269 commit 5006781
Showing 18 changed files with 699 additions and 131 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -4,4 +4,6 @@ venv
/**/.vscode/
/**/.idea/
/**/.mypy_cache/
prune*
prune*/
logs/*
configs/*
69 changes: 65 additions & 4 deletions README.md
@@ -2,7 +2,7 @@

![Llama-pruning-image](/assets/llama-pruning.jpg "Llama pruning")

This project provides tools to load and prune large language models using a structured pruning method. The pruning method is based on the work of [Pere Martra](https://github.com/peremartra) with modifications by [Mariusz Kurman](https://github.com/mkurman).
This project provides tools to load and prune large language models using a structured pruning method. The method is based on the work of [Pere Martra](https://github.com/peremartra) with multiple modifications by [Mariusz Kurman](https://github.com/mkurman), including improved adjusted importance calculation, weight normalization, and enhanced layer normalization techniques.

This method is applicable to all models with a Llama-like architecture that includes MLP gating, such as Llama, Phi, Mistral, Qwen, SmolLM, and others.

@@ -17,8 +17,10 @@ Pere Martra's book: [Large Language Models: Apply and Implement Strategies for L
- [Installation](#installation)
- [Usage](#usage)
- [Arguments](#arguments)
- [AutoML](#automl) *WIP*
- [Example](#example)
- [License](#license)
- [TODO](#todo)

## Installation

@@ -43,28 +45,74 @@ Pere Martra's book: [Large Language Models: Apply and Implement Strategies for L

To load and prune a model, run the `main.py` script with the appropriate arguments:

Provide a `config.yaml` configuration file:
```sh
python src/main.py --config config.yaml
```

or use the CLI:
```sh
python src/main.py --model_name <model_name> --prune_percent <prune_percent> --dtype <dtype> --cache_dir <cache_dir> --device <device> --output <output> --prompt <prompt> --max_new_tokens <max_new_tokens> [--apply_chat_template]
```

You can find an example `config.yaml` in the `examples` directory.
It is recommended to store your configuration files in the `configs` directory for better organization.

## Arguments

- `--model_name`: Name of the model to load (default: `meditsolutions/Llama-3.2-SUN-2.5B-chat`).
- `--config`: Path to the YAML configuration file. If provided, other arguments will be ignored. You can find an example configuration file at `config/prune_config.yaml`.
- `--model_name`: Name of the model to load (default: `meditsolutions/Llama-3.2-SUN-1B-chat`).
- `--prune_percent`: Percentage of MLP neurons to prune (default: `0.2`).
- `--dtype`: Data type to use (default: `float32`).
- `--cache_dir`: Directory to cache the model (default: `None`).
- `--device`: Device to use (default: `cuda`).
- `--output`: Directory to save the pruned model (default: `pruned_model`).
- `--output`: Directory to save the pruned model (default: `results/pruned_model`).
- `--apply_chat_template`: Apply chat template to the model (default: `None`).
- `--prompt`: Prompt to generate the output (default: `What is the capital of France?`).
- `--max_new_tokens`: Maximum number of tokens to generate (default: `50`).
- `--target_size`: Target size for the MLP intermediate layer (`prune_percent` will be ignored).
- `--prune_method`: Method to use for pruning. Currently, only "mk_prune" (alias: "mk") and "mk_prune_adjusted" (alias: "mka") are supported. (default: `mk_prune`)
- `--use_normalized_weights`: Use normalized weights to calculate the final weights. (default: `False`)
- `--use_layer_norm_tweaks`: Apply layer normalization changes to account for the impact of pruned neurons. (default: `False`)
- `--layer_norm_scale`: Layer normalization scale. Only used if use_layer_norm_tweaks is True. (default: 4.0)
- `--layer_norm_scale`: Layer normalization scale. Only used if use_layer_norm_tweaks is True. (default: `4.0`)
- `--gate_up_weight_weights`: Relative weights for the gate and up projection weights. (default: `[1.0, 1.0]`)
- `--log_dir`: Directory to save the logs. (default: `logs`)
- `--stop_logging`: Stop logging to the file. (default: `False`)
- `--test_only`: Run the test only; do not save the model. (default: `False`)
- `--print_summary`: Print the pruned model summary. (default: `False`)
- `--quiet`: Do not print logs.
- `--eval_dataset`: Hugging Face dataset to evaluate the model. *WIP*
- `--eval_dataset_size`: Size of the evaluation dataset. (default: 20) *WIP*

## AutoML *WIP*

The AutoML feature runs an automated grid search over multiple pruning parameters to find the best pruning configuration for a given model. The grid-search parameters are specified in a YAML configuration file.

The AutoML feature is currently under development and will be available in version **1.1.0**.

### Example YAML Configuration
```yaml
model_name: "meditsolutions/Llama-3.2-SUN-2.5B-chat"
dtype: "torch.float32"
device: "cuda"
output_dir: "pruned_model"
prune_grid:
prune_percent_list: [0.1, 0.2, 0.3]
prune_method_list: ["mk_prune", "mk_prune_adjusted"]
use_normalized_weights_list: [true, false]
use_layer_norm_tweaks_list: [true, false]
layer_norm_scale_list: [2.0, 4.0]
target_size_list: [null, 512]
gate_up_weight_weights_list: [[1.0, 1.0], [0.5, 0.5]]
```

### Running the Grid Search
To run the grid search, use the following command:

```sh
python src/main.py --config examples/grid_search.yaml
```

The script will load the configuration from the YAML file and perform a grid search over the specified parameters to find the best pruned model based on KL divergence loss.
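
Conceptually, the grid search iterates over the Cartesian product of the parameter lists and keeps the candidate with the lowest KL divergence. The sketch below illustrates the idea only; `prune_model` and `evaluate_kl_divergence` are hypothetical placeholders, not functions exposed by this project.

```python
import itertools


def grid_search(base_model, eval_dataset, prune_model, evaluate_kl_divergence):
    """Illustrative sketch: try every parameter combination, keep the lowest KL loss.

    `prune_model` and `evaluate_kl_divergence` are hypothetical callables standing
    in for the project's internal pruning and evaluation routines.
    """
    prune_percents = [0.1, 0.2, 0.3]
    prune_methods = ["mk_prune", "mk_prune_adjusted"]
    layer_norm_scales = [2.0, 4.0]

    best_loss, best_params = float("inf"), None
    for percent, method, scale in itertools.product(
        prune_percents, prune_methods, layer_norm_scales
    ):
        # Prune a copy of the base model with this parameter set ...
        pruned = prune_model(
            base_model,
            prune_percent=percent,
            prune_method=method,
            layer_norm_scale=scale,
        )
        # ... and score it against the original model's outputs.
        loss = evaluate_kl_divergence(base_model, pruned, eval_dataset)
        if loss < best_loss:
            best_loss, best_params = loss, (percent, method, scale)
    return best_params, best_loss
```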

## Example

@@ -75,3 +123,16 @@ python src/main.py --model_name meditsolutions/Llama-3.2-SUN-2.5B-chat --prune_p
## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.

## TODO

- [ ] AutoML functionality
- [ ] More pruning methods
- [ ] Model merging integration
- [ ] Eval harness integration
- [ ] Support for additional model architectures
- [ ] Improved logging and monitoring
- [ ] Documentation and examples for custom pruning strategies
- [ ] User-friendly WebUI
- [ ] Performance benchmarking and comparison with other pruning techniques
- [ ] Visualization tools for model pruning and evaluation results
20 changes: 20 additions & 0 deletions examples/basic.yaml
@@ -0,0 +1,20 @@
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat"
prune_percent: 0.2
dtype: "float32"
cache_dir: null
device: "cuda"
output: "results/pruned_model"
apply_chat_template: true
prompt: "What is the capital of France?"
max_new_tokens: 50
target_size: null
log_dir: "logs"
test_only: false
prune_method: "mk_prune"
use_normalized_weights: false
use_layer_norm_tweaks: false
layer_norm_scale: 4.0
gate_up_weight_weights: [1.0, 1.0]
print_summary: false
quiet: false
stop_logging: false
30 changes: 30 additions & 0 deletions examples/gird_search.yaml
@@ -0,0 +1,30 @@
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat"
prune_percent: 0.2
dtype: "float32"
cache_dir: null
device: "cuda"
output: "results/pruned_model"
apply_chat_template: true
prompt: "What is the capital of France?"
max_new_tokens: 50
target_size: null
log_dir: "logs"
test_only: false
prune_method: "mk"
use_normalized_weights: false
use_layer_norm_tweaks: false
layer_norm_scale: 4.0
gate_up_weight_weights: [1.0, 1.0]
print_summary: false
quiet: false
stop_logging: false

# Prune grid search parameters
grid_search:
  prune_percent_list: [0.1, 0.2, 0.3]
  prune_method_list: ["mk_prune", "mk_prune_adjusted"]
  use_normalized_weights_list: [true, false]
  use_layer_norm_tweaks_list: [true, false]
  layer_norm_scale_list: [2.0, 4.0]
  target_size_list: [null, 512]
  gate_up_weight_weights_list: [[1.0, 1.0], [0.4, 0.6]]
20 changes: 20 additions & 0 deletions examples/mka.yaml
@@ -0,0 +1,20 @@
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat"
prune_percent: 0.2
dtype: "float32"
cache_dir: null
device: "cuda"
output: "results/pruned_model"
apply_chat_template: true
prompt: "What is the capital of France?"
max_new_tokens: 50
target_size: null
log_dir: "logs"
test_only: false
prune_method: "mka"
use_normalized_weights: false
use_layer_norm_tweaks: true
layer_norm_scale: 2.0
gate_up_weight_weights: [0.3, 0.7]
print_summary: false
quiet: false
stop_logging: false
20 changes: 20 additions & 0 deletions examples/quiet_no_logs.yaml
@@ -0,0 +1,20 @@
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat"
prune_percent: 0.2
dtype: "float32"
cache_dir: null
device: "cuda"
output: "results/pruned_model"
apply_chat_template: true
prompt: "What is the capital of France?"
max_new_tokens: 50
target_size: null
log_dir: "logs"
test_only: true
prune_method: "mka"
use_normalized_weights: false
use_layer_norm_tweaks: true
layer_norm_scale: 5.0
gate_up_weight_weights: [0.3, 0.7]
print_summary: false
quiet: true
stop_logging: true
20 changes: 20 additions & 0 deletions examples/test_only.yaml
@@ -0,0 +1,20 @@
model_name: "meditsolutions/Llama-3.2-SUN-1B-chat"
prune_percent: 0.2
dtype: "float32"
cache_dir: null
device: "cuda"
output: "results/pruned_model"
apply_chat_template: true
prompt: "What is the capital of France?"
max_new_tokens: 50
target_size: null
log_dir: "logs"
test_only: true
prune_method: "mka"
use_normalized_weights: false
use_layer_norm_tweaks: true
layer_norm_scale: 5.0
gate_up_weight_weights: [0.3, 0.7]
print_summary: false
quiet: false
stop_logging: false
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,4 +3,5 @@ transformers==4.46.2
sentencepiece==0.2.0
accelerate==0.26.0
numpy==1.26.0
protobuf==5.28.3
protobuf==5.28.3
pyyaml==6.0.2
86 changes: 86 additions & 0 deletions src/config/prune_config.py
@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import Optional, List
from config.prune_grid_search_config import PruneGridSearchConfig

import logging
import os

logger = logging.getLogger(__name__)


# Dataclass to store the prune configuration.
@dataclass
class PruneConfig:
    model_name: str = "meditsolutions/Llama-3.2-SUN-1B-chat"
    dtype: str = "float32"
    device: str = "cuda"
    output: str = "results/pruned_model"
    apply_chat_template: bool = False
    prompt: str = "What is the capital of France?"
    max_new_tokens: int = 50
    prune_percent: float = 0.2
    prune_method: str = "mk_prune"
    use_normalized_weights: bool = False
    use_layer_norm_tweaks: bool = False
    layer_norm_scale: float = 4.0
    log_dir: str = "logs"
    stop_logging: bool = False
    test_only: bool = False
    print_summary: bool = False
    quiet: bool = False
    cache_dir: Optional[str] = None
    target_size: Optional[int] = None
    gate_up_weight_weights: Optional[List[float]] = field(
        default_factory=lambda: [1.0, 1.0]
    )
    eval_dataset: Optional[str] = None
    eval_dataset_size: Optional[int] = None
    grid_search: Optional[PruneGridSearchConfig] = None

    def __post_init__(self):
        # Normalize the prune method and reject unknown values.
        if self.prune_method not in ["mk_prune", "mk", "mk_prune_adjusted", "mka"]:
            raise ValueError(f"Unknown prune method: {self.prune_method}")
        elif self.prune_method in ["mk", "mk_prune"]:
            self.prune_method = "mk_prune"
        elif self.prune_method in ["mka", "mk_prune_adjusted"]:
            self.prune_method = "mk_prune_adjusted"

        # Validate parameters.
        if self.use_layer_norm_tweaks and self.layer_norm_scale <= 0:
            raise ValueError("layer_norm_scale must be greater than 0.")

        if self.gate_up_weight_weights is not None:
            if len(self.gate_up_weight_weights) != 2:
                raise ValueError("gate_up_weight_weights must have 2 values.")
            if any(
                [weight < 0 or weight > 1.0 for weight in self.gate_up_weight_weights]
            ):
                raise ValueError("gate_up_weight_weights must be between 0 and 1")

        if self.target_size is not None:
            if self.target_size <= 0:
                raise ValueError("target_size must be greater than 0.")

        if self.prune_percent < 0 or self.prune_percent > 1.0:
            raise ValueError("prune_percent must be between 0 and 1")

        if self.max_new_tokens <= 10:
            raise ValueError("max_new_tokens must be greater than 10")

        if self.cache_dir is not None:
            if not os.path.exists(self.cache_dir):
                raise ValueError("cache_dir does not exist.")

        if not os.path.exists(self.output):
            os.makedirs(self.output)

        # Grid search (AutoML) needs an evaluation dataset to score candidates.
        if self.grid_search is not None and self.eval_dataset is None:
            raise ValueError("Grid search (AutoML) requires an eval_dataset.")

        if self.eval_dataset is not None and self.eval_dataset_size is None:
            logger.warning("eval_dataset_size is not set. Defaulting to 20.")
            self.eval_dataset_size = 20
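
For illustration, a minimal sketch of how a file such as `examples/basic.yaml` could be mapped onto this dataclass, assuming the optional `grid_search` section is parsed into `PruneGridSearchConfig` separately; this is not necessarily how `main.py` loads its configuration.

```python
# Illustrative only: one way a YAML config could be turned into a PruneConfig.
import yaml

from config.prune_config import PruneConfig
from config.prune_grid_search_config import PruneGridSearchConfig

with open("examples/basic.yaml") as f:
    data = yaml.safe_load(f)

# Pull out the optional grid-search section before building the dataclass.
grid = data.pop("grid_search", None)
config = PruneConfig(
    **data, grid_search=PruneGridSearchConfig(**grid) if grid else None
)
print(config.prune_method, config.output)
```
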
57 changes: 57 additions & 0 deletions src/config/prune_grid_search_config.py
@@ -0,0 +1,57 @@
from dataclasses import dataclass, field
from typing import Optional, List

import itertools


# Dataclass to store the prune grid search configuration.
@dataclass
class PruneGridSearchConfig:
    prune_percent_list: List[float] = field(default_factory=lambda: [0.2])
    prune_method_list: List[str] = field(default_factory=lambda: ["mk_prune"])
    use_normalized_weights_list: List[bool] = field(default_factory=lambda: [False])
    use_layer_norm_tweaks_list: List[bool] = field(default_factory=lambda: [False])
    layer_norm_scale_list: List[float] = field(default_factory=lambda: [4.0])
    target_size_list: List[Optional[int]] = field(default_factory=lambda: [None])
    gate_up_weight_weights_list: List[List[float]] = field(
        default_factory=lambda: [[1.0, 1.0]]
    )

    def __post_init__(self):
        self.generate_param_combinations()

        assert all(
            0.0 <= p <= 1.0 for p in self.prune_percent_list
        ), "Prune percent must be between 0 and 1"
        assert all(
            m in ["mk_prune", "mk_prune_adjusted", "mk", "mka"]
            for m in self.prune_method_list
        ), "Prune method must be 'mk_prune' or 'mk_prune_adjusted' or 'mk' or 'mka'"
        assert all(
            isinstance(u, bool) for u in self.use_normalized_weights_list
        ), "Use normalized weights must be a boolean"
        assert all(
            isinstance(u, bool) for u in self.use_layer_norm_tweaks_list
        ), "Use layer norm tweaks must be a boolean"
        assert all(
            0.0 <= s for s in self.layer_norm_scale_list
        ), "Layer norm scale must be non-negative"
        assert all(
            t is None or t > 0 for t in self.target_size_list
        ), "Target size must be None or a positive integer"
        assert all(
            all(0.0 <= w <= 1.0 for w in weights)
            for weights in self.gate_up_weight_weights_list
        ), "Gate up weight weights must be between 0 and 1"

    def generate_param_combinations(self):
        # Cartesian product of all parameter lists; one tuple per candidate configuration.
        return itertools.product(
            self.prune_percent_list,
            self.prune_method_list,
            self.use_normalized_weights_list,
            self.use_layer_norm_tweaks_list,
            self.layer_norm_scale_list,
            self.target_size_list,
            self.gate_up_weight_weights_list,
        )
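
A short usage sketch, assuming the instance-method form above: each tuple yielded by `generate_param_combinations` is one candidate set of pruning parameters, in the same order as the fields.

```python
# Illustrative only: expand the grid into individual parameter sets.
from config.prune_grid_search_config import PruneGridSearchConfig

grid = PruneGridSearchConfig(
    prune_percent_list=[0.1, 0.2],
    prune_method_list=["mk_prune", "mk_prune_adjusted"],
    layer_norm_scale_list=[2.0, 4.0],
)

for (prune_percent, prune_method, use_normalized_weights, use_layer_norm_tweaks,
     layer_norm_scale, target_size, gate_up_weight_weights) in grid.generate_param_combinations():
    print(prune_percent, prune_method, layer_norm_scale)
```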
