diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f40fbd8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +_site +.sass-cache +.jekyll-cache +.jekyll-metadata +vendor diff --git a/404.html b/404.html new file mode 100644 index 0000000..086a5c9 --- /dev/null +++ b/404.html @@ -0,0 +1,25 @@ +--- +permalink: /404.html +layout: default +--- + + + +
+

404

+ +

Page not found :(

+

The requested page could not be found.

+
diff --git a/Gemfile b/Gemfile new file mode 100644 index 0000000..07ac8bd --- /dev/null +++ b/Gemfile @@ -0,0 +1,35 @@ +source "https://rubygems.org" +# Hello! This is where you manage which Jekyll version is used to run. +# When you want to use a different version, change it below, save the +# file and run `bundle install`. Run Jekyll with `bundle exec`, like so: +# +# bundle exec jekyll serve +# +# This will help ensure the proper Jekyll version is running. +# Happy Jekylling! +#gem "jekyll", "~> 4.2.2" +# This is the default theme for new Jekyll sites. You may change this to anything you like. +gem "minima", "~> 2.5" +# If you want to use GitHub Pages, remove the "gem "jekyll"" above and +# uncomment the line below. To upgrade, run `bundle update github-pages`. +gem "github-pages", "~> 228", group: :jekyll_plugins +# If you have any plugins, put them here! +group :jekyll_plugins do + gem "jekyll-feed", "~> 0.12" +end + +# Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem +# and associated library. +platforms :mingw, :x64_mingw, :mswin, :jruby do + gem "tzinfo", "~> 1.2" + gem "tzinfo-data" +end + +# Performance-booster for watching directories on Windows +gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin] + +# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem +# do not have a Java counterpart. +gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby] + +gem "webrick", "~> 1.8" diff --git a/Gemfile.lock b/Gemfile.lock new file mode 100644 index 0000000..a6dd7a8 --- /dev/null +++ b/Gemfile.lock @@ -0,0 +1,278 @@ +GEM + remote: https://rubygems.org/ + specs: + activesupport (7.1.1) + base64 + bigdecimal + concurrent-ruby (~> 1.0, >= 1.0.2) + connection_pool (>= 2.2.5) + drb + i18n (>= 1.6, < 2) + minitest (>= 5.1) + mutex_m + tzinfo (~> 2.0) + addressable (2.8.5) + public_suffix (>= 2.0.2, < 6.0) + base64 (0.1.1) + bigdecimal (3.1.4) + coffee-script (2.4.1) + coffee-script-source + execjs + coffee-script-source (1.11.1) + colorator (1.1.0) + commonmarker (0.23.10) + concurrent-ruby (1.2.2) + connection_pool (2.4.1) + dnsruby (1.70.0) + simpleidn (~> 0.2.1) + drb (2.1.1) + ruby2_keywords + em-websocket (0.5.3) + eventmachine (>= 0.12.9) + http_parser.rb (~> 0) + ethon (0.16.0) + ffi (>= 1.15.0) + eventmachine (1.2.7) + execjs (2.9.1) + faraday (2.7.11) + base64 + faraday-net_http (>= 2.0, < 3.1) + ruby2_keywords (>= 0.0.4) + faraday-net_http (3.0.2) + ffi (1.16.3) + forwardable-extended (2.6.0) + gemoji (3.0.1) + github-pages (228) + github-pages-health-check (= 1.17.9) + jekyll (= 3.9.3) + jekyll-avatar (= 0.7.0) + jekyll-coffeescript (= 1.1.1) + jekyll-commonmark-ghpages (= 0.4.0) + jekyll-default-layout (= 0.1.4) + jekyll-feed (= 0.15.1) + jekyll-gist (= 1.5.0) + jekyll-github-metadata (= 2.13.0) + jekyll-include-cache (= 0.2.1) + jekyll-mentions (= 1.6.0) + jekyll-optional-front-matter (= 0.3.2) + jekyll-paginate (= 1.1.0) + jekyll-readme-index (= 0.3.0) + jekyll-redirect-from (= 0.16.0) + jekyll-relative-links (= 0.6.1) + jekyll-remote-theme (= 0.4.3) + jekyll-sass-converter (= 1.5.2) + jekyll-seo-tag (= 2.8.0) + jekyll-sitemap (= 1.4.0) + jekyll-swiss (= 1.0.0) + jekyll-theme-architect (= 0.2.0) + jekyll-theme-cayman (= 0.2.0) + jekyll-theme-dinky (= 0.2.0) + jekyll-theme-hacker (= 0.2.0) + jekyll-theme-leap-day (= 0.2.0) + jekyll-theme-merlot (= 0.2.0) + jekyll-theme-midnight (= 0.2.0) + jekyll-theme-minimal (= 0.2.0) + jekyll-theme-modernist (= 0.2.0) + jekyll-theme-primer (= 0.6.0) + jekyll-theme-slate (= 0.2.0) + jekyll-theme-tactile (= 0.2.0) + jekyll-theme-time-machine (= 0.2.0) + jekyll-titles-from-headings (= 0.5.3) + jemoji (= 0.12.0) + kramdown (= 2.3.2) + kramdown-parser-gfm (= 1.1.0) + liquid (= 4.0.4) + mercenary (~> 0.3) + minima (= 2.5.1) + nokogiri (>= 1.13.6, < 2.0) + rouge (= 3.26.0) + terminal-table (~> 1.4) + github-pages-health-check (1.17.9) + addressable (~> 2.3) + dnsruby (~> 1.60) + octokit (~> 4.0) + public_suffix (>= 3.0, < 5.0) + typhoeus (~> 1.3) + html-pipeline (2.14.3) + activesupport (>= 2) + nokogiri (>= 1.4) + http_parser.rb (0.8.0) + i18n (1.14.1) + concurrent-ruby (~> 1.0) + jekyll (3.9.3) + addressable (~> 2.4) + colorator (~> 1.0) + em-websocket (~> 0.5) + i18n (>= 0.7, < 2) + jekyll-sass-converter (~> 1.0) + jekyll-watch (~> 2.0) + kramdown (>= 1.17, < 3) + liquid (~> 4.0) + mercenary (~> 0.3.3) + pathutil (~> 0.9) + rouge (>= 1.7, < 4) + safe_yaml (~> 1.0) + jekyll-avatar (0.7.0) + jekyll (>= 3.0, < 5.0) + jekyll-coffeescript (1.1.1) + coffee-script (~> 2.2) + coffee-script-source (~> 1.11.1) + jekyll-commonmark (1.4.0) + commonmarker (~> 0.22) + jekyll-commonmark-ghpages (0.4.0) + commonmarker (~> 0.23.7) + jekyll (~> 3.9.0) + jekyll-commonmark (~> 1.4.0) + rouge (>= 2.0, < 5.0) + jekyll-default-layout (0.1.4) + jekyll (~> 3.0) + jekyll-feed (0.15.1) + jekyll (>= 3.7, < 5.0) + jekyll-gist (1.5.0) + octokit (~> 4.2) + jekyll-github-metadata (2.13.0) + jekyll (>= 3.4, < 5.0) + octokit (~> 4.0, != 4.4.0) + jekyll-include-cache (0.2.1) + jekyll (>= 3.7, < 5.0) + jekyll-mentions (1.6.0) + html-pipeline (~> 2.3) + jekyll (>= 3.7, < 5.0) + jekyll-optional-front-matter (0.3.2) + jekyll (>= 3.0, < 5.0) + jekyll-paginate (1.1.0) + jekyll-readme-index (0.3.0) + jekyll (>= 3.0, < 5.0) + jekyll-redirect-from (0.16.0) + jekyll (>= 3.3, < 5.0) + jekyll-relative-links (0.6.1) + jekyll (>= 3.3, < 5.0) + jekyll-remote-theme (0.4.3) + addressable (~> 2.0) + jekyll (>= 3.5, < 5.0) + jekyll-sass-converter (>= 1.0, <= 3.0.0, != 2.0.0) + rubyzip (>= 1.3.0, < 3.0) + jekyll-sass-converter (1.5.2) + sass (~> 3.4) + jekyll-seo-tag (2.8.0) + jekyll (>= 3.8, < 5.0) + jekyll-sitemap (1.4.0) + jekyll (>= 3.7, < 5.0) + jekyll-swiss (1.0.0) + jekyll-theme-architect (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-cayman (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-dinky (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-hacker (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-leap-day (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-merlot (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-midnight (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-minimal (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-modernist (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-primer (0.6.0) + jekyll (> 3.5, < 5.0) + jekyll-github-metadata (~> 2.9) + jekyll-seo-tag (~> 2.0) + jekyll-theme-slate (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-tactile (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-theme-time-machine (0.2.0) + jekyll (> 3.5, < 5.0) + jekyll-seo-tag (~> 2.0) + jekyll-titles-from-headings (0.5.3) + jekyll (>= 3.3, < 5.0) + jekyll-watch (2.2.1) + listen (~> 3.0) + jemoji (0.12.0) + gemoji (~> 3.0) + html-pipeline (~> 2.2) + jekyll (>= 3.0, < 5.0) + kramdown (2.3.2) + rexml + kramdown-parser-gfm (1.1.0) + kramdown (~> 2.0) + liquid (4.0.4) + listen (3.8.0) + rb-fsevent (~> 0.10, >= 0.10.3) + rb-inotify (~> 0.9, >= 0.9.10) + mercenary (0.3.6) + minima (2.5.1) + jekyll (>= 3.5, < 5.0) + jekyll-feed (~> 0.9) + jekyll-seo-tag (~> 2.1) + minitest (5.20.0) + mutex_m (0.1.2) + nokogiri (1.15.4-arm64-darwin) + racc (~> 1.4) + octokit (4.25.1) + faraday (>= 1, < 3) + sawyer (~> 0.9) + pathutil (0.16.2) + forwardable-extended (~> 2.6) + public_suffix (4.0.7) + racc (1.7.3) + rb-fsevent (0.11.2) + rb-inotify (0.10.1) + ffi (~> 1.0) + rexml (3.2.6) + rouge (3.26.0) + ruby2_keywords (0.0.5) + rubyzip (2.3.2) + safe_yaml (1.0.5) + sass (3.7.4) + sass-listen (~> 4.0.0) + sass-listen (4.0.0) + rb-fsevent (~> 0.9, >= 0.9.4) + rb-inotify (~> 0.9, >= 0.9.7) + sawyer (0.9.2) + addressable (>= 2.3.5) + faraday (>= 0.17.3, < 3) + simpleidn (0.2.1) + unf (~> 0.1.4) + terminal-table (1.8.0) + unicode-display_width (~> 1.1, >= 1.1.1) + typhoeus (1.4.0) + ethon (>= 0.9.0) + tzinfo (2.0.6) + concurrent-ruby (~> 1.0) + unf (0.1.4) + unf_ext + unf_ext (0.0.8.2) + unicode-display_width (1.8.0) + webrick (1.8.1) + +PLATFORMS + arm64-darwin-21 + +DEPENDENCIES + github-pages (~> 228) + http_parser.rb (~> 0.6.0) + jekyll-feed (~> 0.12) + minima (~> 2.5) + tzinfo (~> 1.2) + tzinfo-data + wdm (~> 0.1.1) + webrick (~> 1.8) + +BUNDLED WITH + 2.3.11 diff --git a/_config.yml b/_config.yml new file mode 100644 index 0000000..45a35cf --- /dev/null +++ b/_config.yml @@ -0,0 +1,52 @@ +# Welcome to Jekyll! +# +# This config file is meant for settings that affect your whole blog, values +# which you are expected to set up once and rarely edit after that. If you find +# yourself editing this file very often, consider using Jekyll's data files +# feature for the data you need to update frequently. +# +# For technical reasons, this file is *NOT* reloaded automatically when you use +# 'bundle exec jekyll serve'. If you change this file, please restart the server process. +# +# If you need help with YAML syntax, here are some quick references for you: +# https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml +# https://learnxinyminutes.com/docs/yaml/ +# +# Site settings +# These are used to personalize your new site. If you look in the HTML files, +# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. +# You can create any custom variable you would like, and they will be accessible +# in the templates via {{ site.myvariable }}. + +title: SFR +email: aidan.scannell@aalto.fi +description: >- # this means to ignore newlines until "baseurl:" + SFR: Sparse Function-space Representation of Neural Networks +# baseurl: "/" # the subpath of your site, e.g. /blog +url: "https://aaltoml.github.io/sfr" # the base hostname & protocol for your site, e.g. http://example.com +github_username: AaltoML + +# Build settings +theme: minima +plugins: + - jekyll-feed + +# Exclude from processing. +# The following items will not be processed, by default. +# Any item listed under the `exclude:` key here will be automatically added to +# the internal "default list". +# +# Excluded items can be processed by explicitly listing the directories or +# their entries' file path in the `include:` list. +# +# exclude: +# - .sass-cache/ +# - .jekyll-cache/ +# - gemfiles/ +# - Gemfile +# - Gemfile.lock +# - node_modules/ +# - vendor/bundle/ +# - vendor/cache/ +# - vendor/gems/ +# - vendor/ruby/ diff --git a/index.markdown b/index.markdown new file mode 100644 index 0000000..e67c63f --- /dev/null +++ b/index.markdown @@ -0,0 +1,129 @@ +--- +# Feel free to add content and custom Front Matter to this file. +# To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults + +# layout: home +layout: post +title: "SFR: Sparse Function-space Representation of Neural Networks" +date: 2023-11-04 14:36:41 +0200 +categories: jekyll update +author: "Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin" +--- +Workshop Paper +Code + + + +![SFR](regression.png "SFR") + + +## Abstract +> Deep neural networks (NNs) are known to lack uncertainty estimates and struggle to incorporate new data. We present a method that mitigates these issues by converting NNs from weight space to function space, via a dual parameterization. Importantly, the dual parameterization enables us to formulate a sparse representation that captures information from the entire data set. This offers a compact and principled way of capturing uncertainty and enables us to incorporate new data without retraining whilst retaining predictive performance. We provide proof-of-concept demonstrations with the proposed approach for quantifying uncertainty in supervised learning on UCI benchmark tasks. + + +## TL;DR +- `SFR` is a "posthoc" Bayesian deep learning method + - equip any trained NN with uncertainty estimates + + + + +- `SFR` can be viewed as a function-space Laplace approximation for NNs +- `SFR` has several benefits over [weight-space Laplace approximation for NNs](https://arxiv.org/abs/2106.14806): + - Its function-space representation is effective for regularization in continual learning (CL) + - It can incorporate new data without retraining the NN + - It has good uncertainty estimates + - We use them to guide exploration in model-based reinforcement learning (RL) + + + +| | **SFR** | **GP** | **Laplace BNN** | +|-------------------------------|:-------:|:------:|:--------------------------:| +| **Function-space** | ✅ | ✅ | ❌ (*weight space*) | +| **Image inputs** | ✅ | ❌ | ✅ | +| **Large data** | ✅ | ❌ | ✅ | +| **Incorporate new data fast** | ✅/❌ | ✅ | ❌ (*requires retraining*) | + +## Useage +See the [notebooks](https://github.com/AaltoML/sfr/tree/main/notebooks) for how to use our code for both regression and classification. + +### Minimal example +Here's a short example: +{% highlight python %} +import src +import torch + +torch.set_default_dtype(torch.float64) + +def func(x, noise=True): + return torch.sin(x * 5) / x + torch.cos(x * 10) + +# Toy data set +X_train = torch.rand((100, 1)) * 2 +Y_train = func(X_train, noise=True) +data = (X_train, Y_train) + +# Training config +width = 64 +num_epochs = 1000 +batch_size = 16 +learning_rate = 1e-3 +delta = 0.00005 # prior precision +data_loader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*data), batch_size=batch_size +) + +# Create a neural network +network = torch.nn.Sequential( + torch.nn.Linear(1, width), + torch.nn.Tanh(), + torch.nn.Linear(width, width), + torch.nn.Tanh(), + torch.nn.Linear(width, 1), +) + +# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood) +sfr = src.SFR( + network=network, + prior=src.priors.Gaussian(params=network.parameters, delta=delta), + likelihood=src.likelihoods.Gaussian(sigma_noise=2), + output_dim=1, + num_inducing=32, + dual_batch_size=None, # this reduces the memory required for computing dual parameters + jitter=1e-4, +) + +sfr.train() +optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate) +for epoch_idx in range(num_epochs): + for batch_idx, batch in enumerate(data_loader): + x, y = batch + loss = sfr.loss(x, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +sfr.set_data(data) # This builds the dual parameters + +# Make predictions in function space +X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1) +f_mean, f_var = sfr.predict_f(X_test) + +# Make predictions in output space +y_mean, y_var = sfr.predict(X_test) +{% endhighlight %} + + +## Citation +Please consider citing our workshop paper. + +{% highlight bibtex %} +@inproceedings{scannellSparse2023, + title = {Sparse Function-space Representation of Neural Networks}, + maintitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning}, + author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tami and Joni Pajarinen and Arno Solin}, + year = {2023}, + month = {7}, +} +{% endhighlight %} + diff --git a/regression.png b/regression.png new file mode 100644 index 0000000..b45895c Binary files /dev/null and b/regression.png differ