This project provides a GFlowNet implementation that can save the trajectories generated by the model during training. The primary focus is extracting and saving these trajectories with the save_gflownet_trajectories function, so that they can be analyzed in detail and reused for research or model improvement.
The save_gflownet_trajectories functionality allows users to:
- Generate trajectories using a GFlowNet model.
- Save these trajectories in a JSON format for post-processing.
- Customize the number of trajectories, environment settings, and file paths via command-line arguments.
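As a rough illustration of saving trajectories as JSON, the sketch below converts array-like values to plain Python types before writing them out. The helper names (to_serializable, dump_trajectories) and the record layout used in the usage note are assumptions for illustration only; the schema actually written by save_gflownet_trajectories may differ.

```python
import json
from pathlib import Path


def to_serializable(value):
    """Recursively convert array-like values into plain Python types.

    Anything exposing `.tolist()` (for example NumPy arrays and scalars,
    or PyTorch tensors) is converted through that method.
    """
    if hasattr(value, "tolist"):
        return value.tolist()
    if isinstance(value, dict):
        return {str(k): to_serializable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_serializable(v) for v in value]
    return value


def dump_trajectories(trajectories, path):
    """Write `trajectories` to `path` as JSON (hypothetical helper)."""
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        json.dump(to_serializable(trajectories), f, indent=2)
    return path
```

A call such as dump_trajectories([{"states": [...], "actions": [...], "reward": 1.0}], "trajectories.json") would then produce a file that can be read back with json.load.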
Requirements:
- Python 3.8+
- PyTorch
- numpy
- gymnasium
- wandb (optional for logging)
Clone the repository and install the required dependencies:
$ git clone <repository_url>
$ cd <repository_folder>
$ pip install -r requirements.txt
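After installation, a quick check like the following can confirm that the listed dependencies are importable. This is a minimal sketch: the import names (torch for PyTorch, plus numpy, gymnasium, and wandb) are the usual ones for these packages, and the helper name missing_packages is hypothetical.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the subset of top-level package names that cannot be found."""
    return [name for name in names if find_spec(name) is None]


required = ["torch", "numpy", "gymnasium"]
optional = ["wandb"]  # only needed when WandB logging is enabled

print("missing required:", missing_packages(required))
print("missing optional:", missing_packages(optional))
```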
Argument | Description | Default Value |
---|---|---|
--save_trajectories | File path to save the trajectories (JSON format) | None |
--num_trajectories | Number of trajectories to generate | 100 |
--env_mode | Environment mode (entire or partial) | entire |
--prob_index | Problem index for ARC tasks | 178 |
--num_actions | Number of available actions in the environment | 5 |
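The arguments in the table could be declared with argparse roughly as follows. This is a sketch built only from the names and defaults listed above; the actual definitions in main.py may differ, and restricting --env_mode to the two listed values via choices is an assumption.

```python
import argparse


def build_parser():
    """Build a parser mirroring the argument table (illustrative only)."""
    parser = argparse.ArgumentParser(
        description="Generate and save GFlowNet trajectories."
    )
    parser.add_argument("--save_trajectories", type=str, default=None,
                        help="File path to save the trajectories (JSON format)")
    parser.add_argument("--num_trajectories", type=int, default=100,
                        help="Number of trajectories to generate")
    parser.add_argument("--env_mode", type=str, default="entire",
                        choices=["entire", "partial"],  # assumed restriction
                        help="Environment mode")
    parser.add_argument("--prob_index", type=int, default=178,
                        help="Problem index for ARC tasks")
    parser.add_argument("--num_actions", type=int, default=5,
                        help="Number of available actions in the environment")
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(
    ["--save_trajectories", "trajectories.json", "--num_trajectories", "50"]
)
print(args.num_trajectories, args.env_mode)  # prints "50 entire"
```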
To save trajectories using the save_gflownet_trajectories function, execute the following command:
python main.py --save_trajectories "trajectories.json" --num_trajectories 100 \
--env_mode "entire" --prob_index 178 --num_actions 5
- --save_trajectories: Specifies the output JSON file for the trajectories.
- --num_trajectories: Sets the number of trajectories to generate.
- --env_mode: Configures the environment mode (entire includes all tasks).
- --prob_index: Selects the specific ARC problem index.
- --num_actions: Defines the number of actions available in the environment.
The command generates 100 trajectories for the specified environment and saves them to trajectories.json.
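Once a file has been written, it can be inspected with the standard json module. The sketch below assumes, without confirmation from the project, that the top-level JSON value is a list with one entry per trajectory; the helper name summarize_trajectory_file is hypothetical.

```python
import json


def summarize_trajectory_file(path):
    """Return basic facts about a saved trajectory file.

    Assumes (unverified) that the top-level JSON value is a list with
    one entry per trajectory.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"expected a JSON list, got {type(data).__name__}")
    # Report the field names of the first entry when it is a JSON object.
    first_keys = sorted(data[0]) if data and isinstance(data[0], dict) else []
    return {"num_trajectories": len(data), "first_entry_keys": first_keys}
```

For example, summarize_trajectory_file("trajectories.json") returns the number of saved trajectories together with the field names of the first entry, which is a convenient way to discover the real schema before writing analysis code.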
project/
├── main.py # Main execution script
├── train.py # Training and trajectory saving logic
├── replay_buffer.py # Replay Buffer for off-policy learning
├── config.py # Configuration file
├── utils.py # Utility functions
├── gflow/
│ ├── gflownet_target.py
│ ├── utils.py
├── ARCenv/
│ ├── wrapper.py
├── arcle/
│ ├── loaders.py
└── policy_target.py # Policy model definitions
The config.py file provides a central location for modifying default settings, such as:
Key | Description | Default Value |
---|---|---|
CUDANUM | CUDA device index | 0 |
DEVICE | Device configuration (CPU/GPU) | cuda:0 |
TASKNUM | ARC Task problem index | 178 |
ACTIONNUM | Number of actions in the environment | 5 |
WANDB_USE | Enable WandB logging | True |
REPLAY_BUFFER_CAPACITY | Maximum size of the replay buffer | 10000 |
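Written out as Python, the settings listed above might look like the following. This is a sketch based only on the keys and defaults in the table; deriving DEVICE from CUDANUM is an assumption, and the real config.py may define these values differently.

```python
# Sketch of config.py using the keys and defaults from the table above.

CUDANUM = 0                      # CUDA device index
DEVICE = f"cuda:{CUDANUM}"       # assumption: built from CUDANUM, gives "cuda:0"
TASKNUM = 178                    # ARC task problem index
ACTIONNUM = 5                    # number of actions in the environment
WANDB_USE = True                 # enable WandB logging
REPLAY_BUFFER_CAPACITY = 10000   # maximum size of the replay buffer
```

On a machine without a GPU, DEVICE would need to be set to "cpu" instead.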
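Generate 50 trajectories and save them to output_trajectories.json:
python main.py --save_trajectories "output_trajectories.json" --num_trajectories 50
Generate the default 100 trajectories in the partial environment mode and save them to trajectories_partial.json:
python main.py --save_trajectories "trajectories_partial.json" --env_mode "partial"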
For further information or questions, please contact:
- Name: [Your Name]
- Email: [Your Email]
- GitHub: [Your GitHub Profile]