stljax

A toolbox to compute the robustness of Signal Temporal Logic (STL) formulas using computation graphs. This is the jax version of the STLCG toolbox, which was originally implemented in PyTorch.

Installation

Requires Python 3.10+

Install the repo:

pip install git+https://github.com/UW-CTRL/stljax.git

Alternatively, if you would like to install the package in editable mode:

git clone https://github.com/UW-CTRL/stljax.git
cd stljax
pip install -e .

(Best to use a virtual environment.)

Usage

demo.ipynb is a Jupyter notebook that showcases the basic functionality of the toolbox:

  • Setting up signals for the formulas, including the use of Expressions and Predicates
  • Defining STL formulas and visualizing them
  • Evaluating the STL robustness value and the robustness trace

(New) Features

stljax leverages the benefits of jax and automatic differentiation!

Aside from using jax as the backend, stljax is a more recent and tidier implementation of stlcg, which was originally implemented in PyTorch around 2019.

  • Removed the distributed_mean hack from the original stlcg implementation. jax keeps track of multiple max/min values and will distribute the gradients across all of them!
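
The gradient behavior described above can be checked with plain jax, without stljax. The sketch below is not stljax code: eventually_robustness is a hypothetical stand-in that uses the max over time as the robustness of "eventually (signal > 0)", and the signal is constructed to contain a tie.

```python
import jax
import jax.numpy as jnp


def eventually_robustness(signal):
    # Hypothetical stand-in: robustness of "eventually (signal > 0)"
    # is the maximum of the signal over time.
    return jnp.max(signal)


# Two time steps tie for the maximum value.
signal = jnp.array([1.0, 3.0, 3.0, 0.5])

# The gradient of the max is shared between the tied entries
# instead of being assigned to a single arbitrary index.
gradient = jax.grad(eventually_robustness)(signal)
print(gradient)
```

Each tied entry receives an equal share of the gradient, and the entries sum to one.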

Tags

Tags 🏷️    Description
v.1.1.0    General code improvements. Included recurrent implementation and example notebooks.
v.1.0.0    Removed awkward expected signal dimension & leverage vmap for batched inputs. Masking for temporal operations & remove need to reverse signals.
v0.0.0     A transfer from the 2019 PyTorch implementation to Jax + some tidying + adding Predicates + reversing signal automatically.

Quick intro

Defining STL formulas and computing robustness values

There are two ways to define an STL formula: using either the Expression class or the Predicate class.

Using Expression

With Expression, you are essentially defining a signal whose values are the output of a predicate function computed outside of the STL robustness computation. In other words, you process your desired signal first (e.g., from a state trajectory, you compute velocity), and then you pass it directly into the STL formula.

A step-by-step breakdown:

  1. Suppose you have a trajectory that is an array of size [time_steps, state_dim]

  2. Suppose we have a get_velocity() function and a get_acceleration() function:
    velocity_value = get_velocity(trajectory) # [time_steps]
    acceleration_value = get_acceleration(trajectory) # [time_steps]

  3. Then, we can define the following two Expression objects:
    velocity_exp = Expression("velocity", value=velocity_value)
    acceleration_exp = Expression("acceleration", value=acceleration_value)

  4. With these two expressions, we can define an STL formula ϕ = □ (velocity_exp > 5.0) ∧ ◊ (acceleration_exp > 5.0), which is written in code as ϕ = Always(velocity_exp > 5.0) & Eventually(acceleration_exp > 5.0).

  5. To compute the robustness trace of ϕ, we run ϕ((velocity_exp, acceleration_exp)) where the input is a tuple since the first part of the formula depends on velocity, and the second part depends on acceleration.

This means that the user needs to compute the velocity and acceleration values before calling ϕ to compute the robustness trace (or ϕ.robustness((velocity_exp, acceleration_exp)) for the robustness value).

NOTE: Expressions are used to define an STL formula. While you can, you don't necessarily need to use Expressions as inputs for computing robustness values. So ϕ((velocity_value, acceleration_value)) should also work.
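
To make the Expression workflow concrete, here is a minimal sketch of the underlying semantics in plain jax. It does not use stljax, and it is not how stljax is implemented internally. It assumes the conjunction form Always(velocity > 5.0) & Eventually(acceleration > 5.0), the standard STL quantitative semantics (min over future time for Always, max over future time for Eventually, min for conjunction), and hypothetical get_velocity and get_acceleration functions that read columns 0 and 1 of the state.

```python
import jax.numpy as jnp
from jax import lax


def get_velocity(trajectory):
    # Hypothetical: assumes column 0 of the state stores velocity.
    return trajectory[:, 0]


def get_acceleration(trajectory):
    # Hypothetical: assumes column 1 of the state stores acceleration.
    return trajectory[:, 1]


def always_trace(margin):
    # trace[t] = min of margin over all time steps t' >= t.
    return lax.cummin(margin, axis=0, reverse=True)


def eventually_trace(margin):
    # trace[t] = max of margin over all time steps t' >= t.
    return lax.cummax(margin, axis=0, reverse=True)


def phi_trace(velocity_value, acceleration_value, threshold=5.0):
    # The signals are extracted beforehand and passed in separately,
    # mirroring the tuple input of an Expression-based formula.
    always_velocity = always_trace(velocity_value - threshold)
    eventually_acceleration = eventually_trace(acceleration_value - threshold)
    return jnp.minimum(always_velocity, eventually_acceleration)


trajectory = jnp.array([[6.0, 1.0], [7.0, 8.0], [5.5, 2.0]])  # [time_steps, state_dim]
velocity_value = get_velocity(trajectory)          # [time_steps]
acceleration_value = get_acceleration(trajectory)  # [time_steps]

trace = phi_trace(velocity_value, acceleration_value)
robustness = trace[0]  # robustness value of the whole signal
```

In this sketch the robustness value is the first entry of the robustness trace, since that entry accounts for the entire signal.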

Using Predicate

Using Predicate is closer to the formal STL definition: you pass a predicate function when defining an STL formula, rather than passing the signal that would be the output of a predicate function. You then pass your N-D input (e.g., a state trajectory) directly into the formula when computing robustness values.

A step-by-step breakdown:

  1. Suppose you have a trajectory that is an array of size [time_steps, state_dim]

  2. Suppose we have a get_velocity() function and a get_acceleration() function:
    velocity_value = get_velocity(trajectory) # [time_steps]
    acceleration_value = get_acceleration(trajectory) # [time_steps]

  3. Then, we can define the following two Predicate objects:
    velocity_pred = Predicate("velocity", predicate_function=get_velocity)
    acceleration_pred = Predicate("acceleration", predicate_function=get_acceleration)

  4. With these two Predicate objects, we can define an STL formula ϕ = □ (velocity_pred > 5.0) ∧ ◊ (acceleration_pred > 5.0), which is written in code as ϕ = Always(velocity_pred > 5.0) & Eventually(acceleration_pred > 5.0).

  5. To compute the robustness trace of ϕ, we run ϕ(trajectory), where the input is whatever the predicate functions expect as input.

In summary:
When using Predicates to define STL formulas, the velocity and acceleration values are extracted inside the robustness computation, whereas when using Expressions, you need to extract the velocity and acceleration values outside of the robustness computation.
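
The difference between the two styles can be sketched in plain jax. The helper names below are hypothetical and are not part of stljax. They only illustrate where the signal extraction happens for a single "always (velocity > threshold)" formula, with get_velocity assumed to read column 0 of the state.

```python
import jax.numpy as jnp


def get_velocity(trajectory):
    # Hypothetical predicate function: assumes column 0 stores velocity.
    return trajectory[:, 0]


def always_greater_expression_style(signal, threshold):
    # Expression style: the caller has already extracted the 1-D signal.
    return jnp.min(signal - threshold)


def always_greater_predicate_style(predicate_function, threshold):
    # Predicate style: the returned formula extracts the signal itself,
    # so it can be called on the raw trajectory.
    def formula(trajectory):
        return jnp.min(predicate_function(trajectory) - threshold)

    return formula


trajectory = jnp.array([[6.0, 0.0], [7.5, 1.0], [5.5, -1.0]])  # [time_steps, state_dim]

# Extraction happens outside the robustness computation.
rho_expression = always_greater_expression_style(get_velocity(trajectory), 5.0)

# Extraction happens inside the robustness computation.
rho_predicate = always_greater_predicate_style(get_velocity, 5.0)(trajectory)
```

Both calls produce the same robustness value. Only the place where get_velocity is applied differs.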

Handling multiple signals

We can use jax.vmap to handle multiple signals at once.

jax.vmap(formula)(signals) # signals has shape [batch_size, time_dim, ...]

NOTE: Take care with formulas that are defined with Expressions and need multiple inputs. A wrapper is needed, since jax.vmap doesn't like tuples in a single argument.
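
One possible shape for such a wrapper is sketched below. The formula function here is a hypothetical stand-in written in plain jax, not an stljax formula object: it takes a tuple of two signals and returns the robustness of "always (velocity > 5.0) and eventually (acceleration > 5.0)". The wrapper takes the batched signals as separate arguments and packs them into a tuple for each batch element.

```python
import jax
import jax.numpy as jnp


def formula(signals):
    # Hypothetical stand-in for an Expression-based formula that
    # expects a tuple (velocity, acceleration) of 1-D signals.
    velocity, acceleration = signals
    always_velocity = jnp.min(velocity - 5.0)
    eventually_acceleration = jnp.max(acceleration - 5.0)
    return jnp.minimum(always_velocity, eventually_acceleration)


def batched_formula(velocity_batch, acceleration_batch):
    # The wrapper receives each signal as its own argument, so vmap maps
    # over the leading batch axis of both and rebuilds the tuple per element.
    def wrapper(velocity, acceleration):
        return formula((velocity, acceleration))

    return jax.vmap(wrapper)(velocity_batch, acceleration_batch)


velocity_batch = jnp.array([[6.0, 7.0, 8.0], [4.0, 6.0, 6.0]])      # [bs, time_dim]
acceleration_batch = jnp.array([[1.0, 9.0, 2.0], [0.0, 1.0, 2.0]])  # [bs, time_dim]

robustness_batch = batched_formula(velocity_batch, acceleration_batch)  # [bs]
```

The result has one robustness value per batch element.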

TODOs

  • manage reversing of signals internally for recurrent cases.

Publications

Here is a list of publications that use stlcg/stljax. Please file an issue or open a pull request to add your publication to the list.

P. Kapoor, K. Mizuta, E. Kang, and K. Leung, "STLCG++: A Masking Approach for Differentiable Signal Temporal Logic Specification," ArXiv Preprint, 2025.

K. Leung and M. Pavone, "Semi-Supervised Trajectory-Feedback Controller Synthesis for Signal Temporal Logic Specifications," in American Control Conference, 2022.

K. Leung, N. Aréchiga, and M. Pavone, "Backpropagation through STL specifications: Infusing logical structure into gradient-based methods," International Journal of Robotics Research, 2022.

J. DeCastro, K. Leung, N. Aréchiga, and M. Pavone, "Interpretable Policies from Formally-Specified Temporal Properties," in Proc. IEEE Int. Conf. on Intelligent Transportation Systems, Rhodes, Greece, 2020.

K. Leung, N. Aréchiga, and M. Pavone, "Backpropagation for Parametric STL," in IEEE Intelligent Vehicles Symposium: Workshop on Unsupervised Learning for Automated Driving, Paris, France, 2019.

Citing stljax

When citing stljax or stlcg++ (masking approach), please use the following citation:

@article{KapoorMizutaEtAl2025,
  author = {Kapoor, P. and Mizuta, K. and Kang, E. and Leung, K.},
  journal = {{{Available at }\url{https://arxiv.org/abs/2501.04194}}},
  title = {{STLCG++}: A Masking Approach for Differentiable Signal Temporal Logic Specification},
  year = {2025}
}

When citing stlcg (recurrent approach), please use the following citations:

# journal paper
@Article{LeungArechigaEtAl2020,
  author       = {Leung, K. and Ar\'{e}chiga, N. and Pavone, M.},
  title        = {Backpropagation through signal temporal logic specifications: Infusing logical structure into gradient-based methods},
  journal      = {{Int. Journal of Robotics Research}},
  year         = {2022},
}

# conference paper
@Inproceedings{LeungArechigaEtAl2020,
  author       = {Leung, K. and Ar\'{e}chiga, N. and Pavone, M.},
  title        = {Backpropagation through signal temporal logic specifications: Infusing logical structure into gradient-based methods},
  booktitle    = {{Workshop on Algorithmic Foundations of Robotics}},
  year         = {2020},
}

Feedback

If there are any issues with the code, please file an issue or make a pull request.